반응형
이번 포스팅은 PyTorch에서 모델을 저장하고 로드하는 방법에 대해 알아본다.
간단한 신경망 모델을 만들고 모델의 파라미터를 저장한 후 다시 로드하는 과정을 알아본다.
1. 간단한 모델 생성
먼저 간단한 모델을 정의하자.
torch.nn.Module을 상속받아 신경망 구조를 정의한다.
import torch
import torch.nn as nn
import torch.optim as optim
# 간단한 모델 정의
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
# 입력이 10차원이고 출력이 1차원인 완전연결층
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 모델 인스턴스 생성
model = SimpleModel()
# 손실 함수 및 옵티마이저 설정
criterion = nn.MSELoss() # MSE 손실 함수 사용
optimizer = optim.SGD(model.parameters(), lr=0.01) # SGD 옵티마이저 사용
2. 모델 파라미터 저장하기
모델의 파라미터는 state_dict() 메서드를 통해 저장할 수 있다.
이를 통해 모델의 모든 가중치 및 버퍼를 딕셔너리 형태로 반환하며, torch.save() 함수를 사용해 이를 파일로 저장한다.
# 모델 파라미터 저장
torch.save(model.state_dict(), 'model_weights.pth')
# 모델의 파라미터가 저장된 'model_weights.pth' 파일 생성
- model.state_dict()는 모델의 파라미터와 버퍼를 포함한 모든 가중치를 딕셔너리 형태로 반환한다.
- torch.save()는 파이토치 객체를 파일로 저장하는 함수이다.
3. 모델 파라미터 로드하기
새로운 모델 인스턴스를 생성한 후, 해당 모델에 저장된 파라미터를 불러온다.
# 모델 파라미터 로드
model = SimpleModel() # 새로운 모델 인스턴스 생성
model.load_state_dict(torch.load('model_weights.pth'))
# 로드된 파라미터가 모델에 적용
- torch.load('model_weights.pth')는 파일에서 저장된 모델 딕셔너리를 불러온다.
- model.load_state_dict()는 불러온 딕셔너리를 모델에 적용하여 학습된 파라미터를 복원한다.
그 외의 방법 / 모델의 전체 저장과 로드
모델을 통으로 저장하고 불러올 수 있다.
# 모델 전체 저장 (구조 포함)
torch.save(model, 'model_complete.pth')
# 모델 전체 로드 (구조 포함)
model = torch.load('model_complete.pth')
- 모델 전체를 저장하면 가중치뿐만 아니라 모델의 구조 자체도 저장한다.
- torch.load()를 사용해서 저장된 모델을 로드한다.
그러나 복잡한 모델의 경우 state_dict()을 사용하는 것이 더 유연하다.
전체 코드
import torch
import torch.nn as nn
import torch.optim as optim
# 간단한 모델 정의
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 모델 인스턴스 생성
model = SimpleModel()
# 손실 함수 및 옵티마이저 설정
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 모델 파라미터 저장
torch.save(model.state_dict(), 'model_weights.pth')
# 모델 파라미터 로드
model = SimpleModel() # 새로운 모델 인스턴스 생성
model.load_state_dict(torch.load('model_weights.pth'))
# 모델 전체 저장
torch.save(model, 'model_complete.pth')
# 모델 전체 로드
model = torch.load('model_complete.pth')
반응형
'코딩 환경 > PyTorch' 카테고리의 다른 글
[PyTorch] eval() 함수 개념 및 사용법 (0) | 2024.10.29 |
---|---|
[PyTorch] 파이토치에서 특정 GPU 선택, 지정하는 방법 (0) | 2023.05.27 |
[PyTorch] 파이토치에서 .to('cuda'), .cuda() 차이점 (2) | 2023.05.27 |
[PyTorch] 파이토치 코드 실행 속도 높이기 / .cuda()와 device=torch.device('cuda')의 차이 + 애플 m1, m2 칩 (0) | 2023.05.27 |
[PyTorch] num_workers 설명과 빠른 트레이닝을 위한 값 최적화 (0) | 2023.05.24 |