관리 메뉴

Joonas' Note

[딥러닝 일지] WGAN-GP (Gradient Penalty) 본문

AI/딥러닝

[딥러닝 일지] WGAN-GP (Gradient Penalty)

joonas 2022. 6. 12. 22:43

    이전 글 - [딥러닝 일지] WGAN (Wasserstein GAN)


    WGAN-GP

    논문: https://arxiv.org/abs/1704.00028

     

    앞선 WGAN에서 애매하게 넘어간 것이 있다. 바로 weight clipping 이다.

    논문 Figure 1.b

    얼마만큼의 weight로 clipping을 제한할 것인가는 매직 넘버였다. 논문에서는 [-0.01, 0.01]을 사용했지만, 대부분의 가중치들이 양쪽 끝에 걸린 것을 볼 수 있다.

    Gradient penalty는 weight clipping처럼 한 쪽으로 몰려있지 않고 가중치가 고르게 퍼져있다.

    변경 사항으로는, loss function을 gradient penalty를 계산해서 새로 정의한 것과, 판별자 모델에서 배치 정규화(Batch normalization) 층이 빠지는 것이다.

     

    Gradient penalty

    이미 인터넷에 충분히 많은 자료와 자세한 설명이 많으므로, 사용한 코드만 잘라서 붙인다.

    def gradient_penalty(D, real_samples, fake_samples):
        alpha = torch.randn(real_samples.size(0), 1, 1, 1, device=device)
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = D(interpolates)
        fake = torch.ones(real_samples.shape[0], 1, device=device)
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    확실히 수식보다는 코드가 비교적 이해하기는 쉬운 것 같다.

    netD.zero_grad()
    
    real_imgs = data[0].float().to(device)
    fake_imgs = netG(noise)
    
    real_validity = netD(real_imgs)
    fake_validity = netD(fake_imgs)
    
    # Gradient penalty
    gp = gradient_penalty(netD, real_imgs.data, fake_imgs.data)
    
    # Adversarial loss
    errD = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
    errD.backward()

    그리고 WGAN-GP 버전은 모멘텀 기반의 optimizer인 Adam을 사용해도 문제 없이 동작한다. (아래에 Figure 3 참고)

     

    결과

    WGAN-GP로 CelebA 학습한 loss 그래프

    LSUN 데이터셋을 학습할 때의 두 모델의 손실은 아래와 같았다.

    많이 흔들리지만 이정도면 안정적이다
    LSUN과 CelebA 학습 결과

    단순히 Linear 레이어를 쌓은 FC 인데 이정도까지 표현이 되는 게 신기하다!

     

    삽질 기록

    중간에 놓친 것이 있어서 계속 잘못된 학습을 했다. 손실 함수를 잘못 작성했는 지 수식을 붙잡아도 보고, 여러 구현체를 참고해보기도 했는데 여전히 학습이 안정화되지 못하고 loss는 계속 발산했다.

    loss diverge

    인터넷에서도 손실이 발산하는 사람들은 여럿 있었는데, 그런 질문들마다 답변은 없었다. [링크]
    그나마 있는 답변들도, 기존에 구현된 모델을 불러와서 해보라고 했는데 지금 생각해보니 이게 정답이었다(!)

    이대로 넘어갈까 고민했는데, 아무리 해봐도 논문처럼 재현되지 않아서 해결하고 싶었다. 근데 loss가 이렇게 발산해도 학습은 나름 된다.

    발산하지만 학습은 된다.

    그러다 아래 그림을 보고 이상한 것을 알아차렸다.

    논문 Figure 3.

    DCGAN과 Weight clipping, 그리고 Gradient penalty가 서로 구분되어 표기되어 있다.

    그렇다. 이유는, DCGAN에서 WGAN으로 손실 함수만 변경하면서, Convolution Net 구조로 학습한 것이 문제였다.
    논문도 그렇고 각종 WGAN 구현체에서는 생성자(Generator)와 판별자(Discriminator, 비평자;Critic)는 모두 단순한 MLP(Multi-Layer Perceptron)로 되어있었다.

    이것이 가장 큰 원인이라고 생각하고, 모델 구조를 아래와 같이 변경했더니 바로 해결되었다.

    1. 배치 정규화 층을 모두 제거

    2. Convolutional 2D 레이어를 모두 제거하고 Fully connected로 구성

     

    노트북

    https://github.com/joonas-yoon/practice-on-dl/blob/main/04_WGAN/WGAN-GP.ipynb

     

    GitHub - joonas-yoon/practice-on-dl: practice on deep learning including ML

    practice on deep learning including ML. Contribute to joonas-yoon/practice-on-dl development by creating an account on GitHub.

    github.com

    https://www.kaggle.com/code/joonasyoon/wgan-gp-with-celeba

     

    참고

    반응형
    0 Comments
    댓글쓰기 폼