지금까지 살펴본 대부분의 머신러닝은 타겟을 예측하는 방법에 대한 내용이었다.
이런 종류의 학습을 discriminative learning이라고 하는데 classifier(분류기)나 regressor(회귀기?) 모두 discriminative learning의 한 종류이다.
그리고 이런 종류의 학습은 모두 정답이나 레이블이 제공되는 지도학습(supervised learning)이라고 불린다.
하지만 머신 러닝에는 데이터의 레이블이 없이 학습을 하는 비지도 학습(Unsupervised learning)도 있다.
오늘 살펴볼 GAN(Generative Adversarial Network)은 비지도 학습에 혁명을 일으킨 신경망이다.
2014년 Ian Goodfellow가 소개한 GAN은 두 개의 신경망을 동시에 훈련하여 사실적인 이미지, 동영상, 심지어 음악이나 텍스트를 생성하는 데 사용할 수 있다.
현재 GAN은 예술, 패션에서부터 필자가 전공한 물리학까지 다양한 분야에서 인상적인 결과를 보여주고 있다.
이제 GAN이 어떻게 작동하는지 자세히 알아보자.
Generative Adversarial Network(GAN, 적대적 생성 네트워크)의 훈련 과정
먼저 GAN의 훈련을 아주 간단히 설명하자면 다음과 같다.
GAN은 이름에서 알 수 있는 것처럼 두 네트워크가 서로 적대적인 방식으로 훈련한다.
Generator(생성기)라고 불리는 네트워크는 새 샘플을 생성하고 discriminator(판별기)라고 불리는 다른 네트워크는 트레이닝 데이터의 샘플과 생성기가 만든 샘플을 입력받아 진위를 판단한다.
트레이닝 과정에서 생성기는 더 진짜 같은 가짜를 생성해서 판별기를 속이려고 하고, 판별기는 계속하여 진짜와 가짜를 구분하기 위해 노력한다.
이제 생성기가 진짜같은 가짜를 생성할 수 있게 되고 판별기의 정답률이 50%가 되면 즉, 입력받은 샘플이 진짜인지 가짜인지를 구별할 수 없는 순간이 오면 생성기를 가짜 데이터를 만들어내는 Generator로 사용할 수 있는 것이다.
이 과정을 Ian Goodfellow는 논문에서 생성기를 위조지폐범으로, 판별기를 경찰에 비유해서 설명했다.
위조 지폐범은 위조 지폐를 만들고, 경찰은 위조 지폐와 진짜 지폐를 받아 두 지폐를 구분한다.
처음엔 위조 지폐범과 경찰 모두 형편없는 실력을 보여줄 것이다.
하지만 둘을 붙여놓고 계속 위조 지폐를 만들고 구분을 시키면, 위조 지폐범은 점점 그럴듯한 위조 지폐를 만들고 경찰도 구분하는 실력이 늘 것이다.
훈련을 많이 진행해서 이제 위조 지폐범이 진짜와 매우 비슷한 위조 지폐를 만들고, 경찰의 정답률이 50%가 되는 순간이 올 것이다.
드디어 위조 지폐범이 자유롭게 위조 지폐를 생성하여 사용할 수 있는 것이다.
(진짜 이해하기 쉬운 기가 막힌 예다.)
이제 감이 잡혔다면 코드에서 이 과정이 어떻게 구현되는지 알아보자.
1. 생성기 및 판별기 네트워크를 초기화
GAN을 교육하는 첫 번째 단계는 무작위 가중치로 생성기 및 판별기 네트워크를 모두 초기화하는 것이다.
생성기는 랜덤 노이즈를 입력으로 사용하고 실제 데이터와 유사한 출력을 생성한다.
판별기는 실제 데이터와 생성기의 출력을 모두 입력으로 받아 둘을 구별하려고 한다.
2. 판별기 훈련
훈련의 첫 단계에서는 판별기 네트워크만 훈련한다.
우리는 실제 데이터와 생성기에서 생성된 가짜 데이터를 판별기에 넣은 다음 역전파를 사용하여 판별기의 가중치를 업데이트하여 판별기가 예측한 출력과 실제 레이블(예를 들어, 진짜는 1, 가짜는 0) 사이의 이진 교차 엔트로피 손실을 최소화한다.
3. 생성기 훈련
훈련의 두 번째 단계에서는 생성기 네트워크만 훈련한다.
우리는 무작위 노이즈를 생성기에 입력한 다음 생성기의 출력을 판별기에 공급하여 가짜 데이터를 생성한다.
그런 다음 역전파를 사용하여 생성기의 파라미터를 업데이트하여 판별기의 출력과 실제 레이블 사이의 이진 교차 엔트로피 손실을 최대한다.
4. 반복
이제 2, 3번의 과정을 반복한다.
두 네트워크의 손실 함수가 수렴될 때까지 GAN을 계속 훈련한다.
즉, 판별기가 더 이상 실제 데이터와 가짜 데이터를 구별할 수 없으며 생성기가 실제 데이터와 구별할 수 없는 출력을 생성할 수 있을 때까지 훈련한다.
이때 중요한 것은 각각의 과정에서 다른 네트워크는 고정시키고 한 네트워크만 훈련하는 것이다.
이제 직접 코드를 써보자.
PyTorch로 Vanilla GAN 구현하기
아무런 기술이 들어가있지 않은 Vanilla GAN을 직접 만들어보자.
먼저 필요한 라이브러리를 임포트하자.
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
훈련에 사용할 디바이스를 설정한다. (m1 맥 유저도 사용 가능한 코드이다.)
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
훈련 과정에서 생성기의 출력을 저장할 디렉토리를 만들자.
# Directory to save the output of the generator
save_path = "./generated_images"
if not os.path.exists(save_path):
os.makedirs(save_path)
나중에 코드를 돌릴 때 같은 결과가 나오도록 랜덤시드를 설정하자. (안 해도 된다.)
# Set random seed for reproducibility
torch.manual_seed(0)
이제 판별기와 생성기를 만들자.
28*28의 MNIST 데이터 세트를 사용할 것이므로 판별기는 28*28=784의 입력을 받아 256개의 노드를 가지는 히든 레이어를 거쳐 진짜인지(1) 가짜인지(0)의 여부를 나타내는 하나의 아웃풋을 가지도록 설정한다.
생성기는 길이가 100인 랜덤 벡터를 받아 256개의 노드를 가지는 히든 레이어를 거쳐 이미지의 크기인 28*28의 아웃풋을 가지도록 설정하였다.
# Define discriminator network
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(28 * 28, 256)
self.fc2 = nn.Linear(256, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = x.view(-1, 28 * 28)
x = self.fc1(x)
x = self.sigmoid(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x.squeeze()
# Define generator network
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fc1 = nn.Linear(100, 256)
self.fc2 = nn.Linear(256, 28 * 28)
self.tanh = nn.Tanh()
def forward(self, x):
x = self.fc1(x)
x = self.tanh(x)
x = self.fc2(x)
x = self.tanh(x)
x = x.view(-1, 1, 28, 28)
return x
# Create discriminator and generator objects
D = Discriminator().to(device)
G = Generator().to(device)
이제 손실 함수와 옵티마이저를 설정하고 MNIST 데이터도 받아오자.
그리고 fixed_noise
도 설정하는데, 변하지 않는 노이즈를 훈련 과정에서 생성기에 입력해봄으로 같은 노이즈에 대해 훈련이 진행될수록 생성기의 출력이 어떻게 변하는지 확인하기 위한 용도이다.
# Define loss function and optimizer for discriminator
criterion = nn.BCELoss()
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Define loss function and optimizer for generator
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Fixed noise settings to monitor the training process
fixed_noise = torch.randn(100, 100).to(device)
Adam 옵티마이저에서 beta1을 기본값인 0.9가 아닌 0.5을 사용한 이유는 "이렇게 하면 훈련이 더 잘 되더라"라는 경험법칙에서 나온 것이다.
이를 포함해 GAN을 안정적으로 더 잘 훈련시키는 여러 팁은 다른 글에서 다뤄보겠다.
이제 트레이닝 루프를 만들자.
매 에포크의 마지막에는 fixed_noise
을 생성기에 넣어 출력을 확인해볼 수 있게 설정했다.
# Training loop
num_epochs = 200
for epoch in range(num_epochs):
for i, data in enumerate(trainloader):
# Train discriminator with real data
D.zero_grad()
real_images, _ = data
batch_size = real_images.size(0)
label = torch.full((batch_size,), 1.).to(device)
output = D(real_images.to(device))
errD_real = criterion(output, label)
errD_real.backward()
# Train discriminator with fake data
noise = torch.randn(batch_size, 100).to(device)
fake_images = G(noise).to(device)
label.fill_(0.)
output = D(fake_images.detach())
errD_fake = criterion(output, label)
errD_fake.backward()
errD = errD_real + errD_fake
optimizer_D.step()
# Train generator
G.zero_grad()
label.fill_(1.)
output = D(fake_images)
errG = criterion(output, label)
errG.backward()
optimizer_G.step()
# Print statistics
if i % 100 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
% (epoch + 1, num_epochs, i, len(trainloader), errD.item(), errG.item()))
# Save generated images
with torch.no_grad():
save_image(G(fixed_noise), f"generated_images/{epoch}.png", nrow=10, normalize=True)
위 코드를 전부 복사해서 돌려보자.
결과 확인
결과를 확인해보자.
이번 GAN은 아무런 기술이 들어가지있지 않기 때문에 결과가 불만족스러울 수 있다.
여러 기술이 들어간 GAN은 다른 글에서 차차 소개해보겠다.
먼저 트레이닝 루프를 한 번 돌았을 때 Generator의 출력이다.
가운데 덩어리만 보인다.
두 번 돌았을 때의 출력이다.
뭔가 보일 듯 말 듯 하지만 여전히 덩어리이다.
5 epoch에서의 출력이다.
숫자인지 뭔지는 모르겠지만 뭔가 형태를 갖춰가려는 것이 보인다.
20 에포크만큼 훈련했을 때 출력이다.
숫자처럼 보이는 것과 그렇지 않은 것이 섞여있다.
그리고 주변에 노이즈도 있다.
50 에포크의 출력이다.
제법 읽을만한 숫자이다.
200 에포크의 출력이다.
훨씬 선명하고 깨끗하게 숫자가 나오고 있음을 확인할 수 있다.
결과를 보면서 7이 많은 거 같은 느낌이 들 수 있다.
재수가 없으면 모든 출력이 7일 수도 있다.
별다른 조치를 취하지 않으면 생성기는 보통 판별기를 가장 자신있게 속일 수 있는 방향으로 훈련이 진행되어 결국 생성기가 어떤 특정한 이미지만 계속 생성하게 될 수 있다.
이 현상을 Mode Collapse라고 부른다.
Mode Collapse는 판별기의 훈련이 완벽하지 못하거나 모델이 진동하기 때문일 수 있다.
GAN에 관한 더 많은 이야기는 이후 작성할 글에서 다뤄보도록 하겠다.