반응형
PyTorch 모델에 대한 정보를 보기 쉽게 확인하기 위한 파이썬 라이브러리 torchinfo을 살펴보자.
torchinfo는 모델 구조나 레이어의 텐서 모양 등을 빠르고 쉽게 볼 수 있어 디버깅 및 최적화에 도움이 된다.
torchinfo 설치
pip install torchinfo
위 명령어로 설치 가능하다.
가상 환경에서 파이토치를 사용 중이면 가상 환경을 활성화한 다음 설치하자.
torchinfo 사용 방법
아래와 같은 방식으로 모듈 torchinfo
에서 summary
함수를 가져오면 모델의 summary을 출력할 수 있다.
from torchinfo import summary
model = ...
summary(model, input_size=(batch_size, channels, height, width))
위의 코드에서 model
은 우리가 보고싶은 PyTorch 모델이고, input_size
는 모델에 입력할 텐서의 모양을 넣는 튜플이다.
summary
의 결과는 모델의 병목 현상이나 메모리 사용 및 속도에 대해 모델을 최적화하는데 도움이 된다.
torchinfo 사용 예시
torchinfo을 사용해서 PyTorch 모델을 보는 방법을 예시를 통해 살펴보자.
import torch
import torch.nn as nn
from torchinfo import summary
# Define a simple PyTorch model
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
self.fc1 = nn.Linear(32 * 8 * 8, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 32 * 8 * 8)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# Create an instance of the model
net = Net()
# Print a summary of the model using torchinfo
print(summary(net, (3, 32, 32)))
코드를 실행하면 아래와 같은 결과가 나온다.
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Net [1, 10] --
├─Conv2d: 1-1 [16, 32, 32] 448
├─MaxPool2d: 1-2 [16, 16, 16] --
├─Conv2d: 1-3 [32, 16, 16] 4,640
├─MaxPool2d: 1-4 [32, 8, 8] --
├─Linear: 1-5 [1, 128] 262,272
├─Linear: 1-6 [1, 10] 1,290
==========================================================================================
Total params: 268,650
Trainable params: 268,650
Non-trainable params: 0
Total mult-adds (M): 2.87
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.20
Params size (MB): 1.07
Estimated Total Size (MB): 1.28
==========================================================================================
모델에 따라 인풋이 여러 개일 수 있다.
인풋이 여러 개인 경우는 다음과 같이 입력하면 된다.
import torch
from torch import nn
from torchinfo import summary
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, input1, input2):
x1 = self.fc1(input1)
x2 = self.fc1(input2)
x = x1 + x2
x = self.fc2(x)
return x
model = MyModel()
input1 = torch.randn(1, 10)
input2 = torch.randn(1, 10)
summary(model, input_data=[input1, input2])
# or
summary(model, [torch.randn(1, 10), torch.randn(1, 10)])
참고로 결과는 아래처럼 나온다.
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
MyModel [1, 1] --
├─Linear: 1-1 [1, 5] 55
├─Linear: 1-2 [1, 5] (recursive)
├─Linear: 1-3 [1, 1] 6
==========================================================================================
Total params: 61
Trainable params: 61
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
이 외에도 인풋에 대한 몇 가지 파라미터들이 있다.
summary(
model: 'nn.Module',
input_size: 'INPUT_SIZE_TYPE | None' = None,
input_data: 'INPUT_DATA_TYPE | None' = None,
batch_dim: 'int | None' = None,
cache_forward_pass: 'bool | None' = None,
col_names: 'Iterable[str] | None' = None,
col_width: 'int' = 25,
depth: 'int' = 3,
device: 'torch.device | str | None' = None,
dtypes: 'list[torch.dtype] | None' = None,
mode: 'str | None' = None,
row_settings: 'Iterable[str] | None' = None,
verbose: 'int | None' = None,
**kwargs: 'Any',
)
주피터랩 등에서 summary?
을 실행시켜서 다른 파라미터들에 대한 자세한 설명을 직접 확인해보자.
반응형
'코딩 환경 > PyTorch' 카테고리의 다른 글
[PyTorch] GPU을 사용할 때 to(device)와 cuda() 차이 (2) | 2023.02.27 |
---|---|
[PyTorch] MNIST 데이터셋 다운로드하고 열어보기 (0) | 2023.02.26 |
[PyTorch] IProgress not found 에러 해결 방법 (0) | 2023.02.26 |
[PyTorch] onnx을 이용한 모델 시각화 및 프레임워크 전환 (0) | 2023.02.25 |
[PyTorch] m1 맥에 GPU 사용 Pytorch 설치 및 에러 해결 / ERROR: Could not find a version that satisfies the requirement torch (0) | 2023.02.15 |