Custom Dataset의 필요성
MNIST 또는 CIFAR-10과 같이 PyTorch의 기본 제공 데이터 세트에서 직접 로드할 수 있는 형식이 아닌 데이터 세트가 있는 경우 CustomDataset이 필요하다.
즉, 사용자 지정 형식의 데이터셋이 있을 때, 데이터를 제대로 처리하려면 Dataset의 하위 클래스인 CustomDataset을 만들어야 한다.
예를 들어, 이미지 데이터 세트가 있는 경우 데이터 세트에서 이미지와 해당 레이블을 로드할 CustomDataset을 만들어야한다는 이야기다.
이 글에서는 Dataset와 DataLoader 클래스를 알아보고 어떻게 Custom Dataset을 사용할 수 있는지 알아보자.
Dataset와 CustomDataset의 다른 점 이해하기
PyTorch Dataset과 CustomDataset의 주요 차이점은 특정 데이터 세트 또는 작업에 대해 특별하게 CustomDataset이 생성되는 반면, Dataset은 모든 데이터 세트에 대해 사용할 수 있는 보다 일반적인 클래스라는 것이다.
Dataset 클래스는 데이터에 액세스하기 위한 인터페이스를 제공하며, 필요한 메서드를 구현하는 하위 클래스를 만들어야 한다.
반면 CustomDataset은 특정 데이터 세트 또는 작업을 위해 특별히 생성되는 데이터 세트의 하위 클래스이며, 우리는 데이터를 로드하고 그에 따라 개별 샘플에 액세스하는 방법을 정의해야 한다.
Dataset 사용하기
Dataset은 머신 러닝에서 모델을 트리이닝하고 테스트하는 데 사용되는 데이터의 모음이다.
PyTorch에서 torch.utils.data.Dataset
클래스는 데이터에 액세스할 수 있는 인터페이스를 제공하며 __len__()
및 __getitem__()
에서 사용자 정의 데이터셋을 정의할 수 있다.
Dataset의 구성은 다음과 같다.
__init__(self)
: 필요한 변수들을 선언한다. 데이터 파일을 로드하거나 x, y을 설정해준다. 데이터의 전처리도 가능하다.__len__(self)
: 데이터셋의 길이, 즉 총 샘플 수를 써준다.__getitem__(self, index)
: 데이터셋에서 특정 index 한 개의 샘플을 가져오는 함수로 여기서 텐서를 return해준다.
예제를 통해서 Dataset의 사용 방법을 알아보자.
import torch
from torch.utils.data import Dataset
class ImageDataset(Dataset):
def __init__(self, image_paths, labels):
self.image_paths = image_paths
self.labels = labels
def __len__(self):
return len(self.image_paths)
def __getitem__(self, index):
image_path = self.image_paths[index]
label = self.labels[index]
image = torch.load(image_path)
return image, label
위 예제에서는 ImageDataset
을 만들었다.
이 Dataset에는 image_paths
및 lables
이라는 두 가지 인수가 사용된다.
__len__()
메서드는 데이터 세트의 총 이미지 수를 리턴하고 __getitem__()
메서드는 이미지 텐서와 해당 레이블을 포함하는 튜플을 리턴한다.
이 dataset를 사용하려면 다음과 같이 해보자.
image_paths = ['path/to/image1.jpg', 'path/to/image2.jpg', ...]
labels = [0, 1, ...]
dataset = ImageDataset(image_paths, labels)
image_paths
을 어떻게 주든 __len__()
와 __getitem__()
에서 잘 작동하게끔 만들면 거의 된다.
그런 다음 torch.utils.data
모듈의 DataLoader
클래스를 사용하여 데이터를 일괄 로드할 수 있다.
from torch.utils.data import DataLoader
batch_size = 32
num_workers = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
이때 shuffle
은 카드 덱에서 카드를 섞는 것처럼 데이터를 무작위로 섞은 다음 모델을 트레이닝 하겠다는 것이다.
위 예에서는 데이터를 크기 32의 배치로 로드하고 각 에포크 전에 데이터를 섞는 DataLoader
를 정의했다.
이제 DataLoader
를 반복하여 데이터를 가져올 수 있다.
for images, labels in dataloader:
# do something with the data
CustomDataset 사용하기
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data_path, transform=None):
self.data_path = data_path
self.transform = transform
self.data = []
self.targets = []
# Load data from file
with open(data_path, "r") as f:
for line in f:
line = line.strip().split(",")
self.data.append(line[0])
self.targets.append(int(line[1]))
def __len__(self):
return len(self.data)
def __getitem__(self, index):
# Load image from file
img_path = self.data[index]
img = Image.open(img_path).convert("RGB")
# Apply transformation if provided
if self.transform:
img = self.transform(img)
# Get target label
target = self.targets[index]
return img, target
위 예제에서 먼저 주목할 것은 CustomDataset 클래스가 Dataset 클래스의 하위 클래스라는 것이다.
Dataset을 서브클래싱하여 데이터셋의 길이를 가져오고 인덱싱을 통해 각각의 데이터에 엑세스 하는 기능과 같은 기본적인 기능들을 상속받는다.
이제 각 메서드를 살펴보자.
__init__()
메서드는 데이터 파일의 경로와 해당 레이블을 메모리로 로드한다.
로드하는 방식의 예시도 적기는 했는데 각자 상황에 맞게 변경하면 된다.
위 코드에서는 파일을 열고 각 행을 읽어 이미지 경로와 레이블을 분할한다.
그런 다음 이미지의 경로와 레이블을 클래스 속성으로 저장되는 두 개의 리스트에 추가한다.
__len__()
메서드는 데이터 파일의 이미지 수에 해당하는 집합의 길이를 리턴한다.
__getitem__()
메서드는 데이터 세트에서 개별 항목을 로드한다.
이 경우 각 항목은 이미지와 해당 레이블을 포함하는 튜플이다.
이 메서드는 먼저 제공된 인덱스를 사용하여 해당 목록에서 이미지 경로 및 대상 레이블을 검색한다.
그런 다음 Pillow 라이브러리를 사용해서 디스크에서 이미지를 로드하고 RGB 형식으로 변환한다.
살펴본 것처럼 Dataset은 모든 유형의 데이터셋을 다룰 수 있는 범용적인 클래스이고, 특정 데이터 형식이나 유형을 처리할 때 CustomDataset을 사용할 수 있다.
위에서 살펴본 예는 이미지 데이터를 처리하도록 설계되었는데, 이를 위해서는 이미지를 디스크에서 로드하고 변환을 적용해야 PyTorch 모델에서 사용할 수 있다.
데이터에 특정 형식이 있거나 사용자의 입맛에 맞게 전처리가 필요한 경우 Custom Dataset을 사용해보자.
'코딩 환경 > PyTorch' 카테고리의 다른 글
[PyTorch] 서버에서 TensorBoard(텐서보드) 실행하고 Port Forwarding(포트 포워딩)으로 로컬에서 모니터링하는 방법 (1) | 2023.05.15 |
---|---|
[PyTorch] 예제를 통해 알아보는 TensorBoard(텐서보드) 사용법 및 Duplicate plugins for name projector 에러 해결 (2) | 2023.05.15 |
[PyTorch] GPU을 사용할 때 to(device)와 cuda() 차이 (2) | 2023.02.27 |
[PyTorch] MNIST 데이터셋 다운로드하고 열어보기 (0) | 2023.02.26 |
[PyTorch] IProgress not found 에러 해결 방법 (0) | 2023.02.26 |