GAN 설명

GAN은 일반적인 신경망 학습과 달리 생성 모델 G(generator)와 식별 모델 D(discriminator)를 사용하여 상호 간 학습을 진행한다. G는 k차원의 잠재적 특이 벡터를 입력값으로 받아서 대상(ex. 64 x 64)과 동일 형식의 데이터를 생성하는 신경망이다. D는 지금까지 본 것과 마찬 가지로 대상 데이터를 입력값으로 받아 진위를 식별하는 신경망이다. GAN의 학습 순서를 수식을 사용하지 않고 설명해 보겠다.

 

    1. 잠재적 특이 벡터 z를 난수로 생성하고, G(z)를 사용해 가짜 데이터(fake_data)를 생성

    2. fake_data를 D로 식별

    3. 진짜 데이터의 샘플(real_data)을 D로 식별한다. 

    4. fake_out이 진짜 데이터(1)라고 간주하고, 크로스 엔트로피 함수를 계산해서 G의 파라미터를 갱신한다.

    5. real_out이 진짜 데이터이고, fake_out이 가짜 데이터(0)라고 간주하고, 크로스 엔트로피 함수를 계산해서 D의 파라미터를 갱신한다.

 

이 처럼 서로를 훈련 시키는 것이 GAN의 핵심이다. G나 D에 깊은 CNN을 사용한 것이 DCGAN이다.

 

필요한 데이터를 다운로드 합니다.

!wget http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!tar xf 102flowers.tgz
!mkdir oxford-102
!mkdir oxford-102/jpg
!mv jpg/*.jpg oxford-102/jpg

필요한 라이브러리를 다운로드 합니다.

import torch
from torch import nn, optim
from torch.utils.data import (Dataset, DataLoader, TensorDataset)
import tqdm

from torchvision.datasets import ImageFolder
from torchvision import transforms


from torchvision.utils import save_image

 

디렉터리와 파일 준비가 완료 되었고, DataLoader를 만듭니다. (Pytorch 전용 DataLoader) 여기서는 64 X 64픽셀의 이미지를 생성하므로 데이터의 짧은 변을 80픽셀로 조절한 후, 가운데를 64 X 64픽셀로 잘랐습니다.

img_data = ImageFolder("oxford-102/",
    transform=transforms.Compose([
        transforms.Resize(80),
        transforms.CenterCrop(64),
        transforms.ToTensor()
]))

batch_size = 64
img_loader = DataLoader(img_data, batch_size=batch_size,
                        shuffle=True)

 

이미지 생성 모델

잠재적 특이 벡터 z를 100 차원으로 구성하고 이 z로 부터 3 X 64 X 64(3은 채널)의 이미지를 만드는 생성 모델을 구축한다.

Transposed Convolution을 총 5회 반복하고 있다. 이를 통해 처음에는 100 X 1 X 1 의 z가 256 X 4 X 로 변환되며 최종으로 3 X 64 X 64가 됩니다. 

nz = 100
ngf = 32
class GNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 
                               4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf * 2, ngf,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf, 3,
                               4, 2, 1, bias=False),
            nn.Tanh()
        )
        
    def forward(self, x):
        out = self.main(x)
        return out
# 이미지 크기 
out_size = (in_size - 1) * stride -2 * padding \ + kernel_size + output_padding


첫 번째 Transposed Convolution
in_size = 1
stride = 1
padding = 0
kernel_size = 4
output_padding = 0

 

식별 모델 작성

3 X 64 X 64 이미지를 최종적으로는 1차원의 스칼라로 변환하는 신경망을 만든다.

5회 합성곱 연산으로 3 X 64 X 64 이미지가 최종적으로 1 X 1 X 1이 된다. 그리고 forward 마지막에 있는 squeeze는 A X 1 X B X 1 처럼 불필요한 1이 들어 있는 shape를 A X B로 조정하는 처리이다.

ndf = 32

class DNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )
    
    def forward(self, x):
        out = self.main(x)
        return out.squeeze()

Model Train

fixed_z는 훈련 모니터링용이고 훈련이 진행됨에 따라 어떤 이미지가 만들어지는지 확인한다.

from statistics import mean

def train_dcgan(g, d, opt_g, opt_d, loader):
    # 생성 모델, 식별 모델의 목적 함수 추적용 배열
    log_loss_g = []
    log_loss_d = []
    for real_img, _ in tqdm.tqdm(loader):
        batch_len = len(real_img)
        
        # 실제 이미지를 GPU로 복사
        real_img = real_img.to("cuda:0")
        
        # 가짜 이미지를 난수와 생성 모델을 사용해 만든다
        z = torch.randn(batch_len, nz, 1, 1).to("cuda:0")
        fake_img = g(z)
        
        # 나중에 사용하기 위해서 가짜 아미지의 값만 별도로 저장해둠
        fake_img_tensor = fake_img.detach()
        
        # 가짜 이미지에 대한 생성 모델의 평가 함수 계산
        out = d(fake_img)
        loss_g = loss_f(out, ones[: batch_len])
        log_loss_g.append(loss_g.item())

        # 계산 그래프가 생성 모델과 식별 모델 양쪽에
        # 의존하므로 양쪽 모두 경사하강을 끝낸 후에
        # 미분 계산과 파라미터 갱신을 실시
        d.zero_grad(), g.zero_grad()
        loss_g.backward()
        opt_g.step()
        
        #실제 이미지에 대한 식별 모델의 평가 함수 계산
        real_out = d(real_img)
        loss_d_real = loss_f(real_out, ones[: batch_len])
        
        # PyTorch에선 동일 Tensor를 포함한 계산 그래프에
        # 2회 backward를 할 수 없으므로 저장된 Tensor를
        # 사용해서 불필요한 계산은 생략
        fake_img = fake_img_tensor
        
        # 가짜 아미지에 대한 식별 모델의 평가 함수 계산
        fake_out = d(fake_img_tensor)
        loss_d_fake = loss_f(fake_out, zeros[: batch_len])
        
        # 진위 평가 함수의 합계
        loss_d = loss_d_real + loss_d_fake
        log_loss_d.append(loss_d.item())
        
        # 식별 모델의 미분 계산고 파라미터 갱신
        d.zero_grad(), g.zero_grad()
        loss_d.backward()
        opt_d.step()
                                             
    return mean(log_loss_g), mean(log_loss_d)
for epoch in range(300):
    train_dcgan(g, d, opt_g, opt_d, img_loader)
    # 10회 반복마다 학습 결과를 저장
    if epoch % 10 == 0:
        # 파라미터 저장
        torch.save(
            g.state_dict(),
            "g_{:03d}.prm".format(epoch),
            pickle_protocol=4)
        torch.save(
            d.state_dict(),
            "d_{:03d}.prm".format(epoch),
            pickle_protocol=4)
        # 모니터링용 z로부터 생성한 이미지 저장
        generated_img = g(fixed_z)
        save_image(generated_img,
                   "{:03d}.jpg".format(epoch))

10 에폭마다 결과물이 저장된다.

DCGAN을 통해 생성된 이미지

처음 생성된 이미지를 예시로 가져와 흐릿하지만 모델 후반으로 갈 수록 생성된 이미지가 선명한 것을 확인할 수 있을 것이다.

 

+ Recent posts