관리 메뉴

Joonas' Note

[딥러닝 일지] WGAN (Wasserstein GAN) 본문

AI/딥러닝

[딥러닝 일지] WGAN (Wasserstein GAN)

joonas 2022. 6. 11. 02:03

    이전 글 - [딥러닝 일지] PyTorch로 DCGAN 훈련해보기


    WGAN

    논문: https://arxiv.org/pdf/1701.07875.pdf

    DCGAN의 한계와 차이점

    WGAN은 기존의 DCGAN 네트워크 구조는 거의 그대로 두고, 손실 함수만 바꿔서 학습을 안정화시켰다.

    모드 붕괴

    이진 분류(Binary Cross Entropy)로 진짜/가짜 여부만 판별하기 때문에 발생하는 문제가 있다.

    WGAN 논문 Figure 7

    결국은 판별자를 속이는 이미지를 만들도록 학습했기 때문에, 잘 속이는 일부 샘플(mode)을 발견하면 그것만 계속 만들어낸다.
    잠재 공간의 모든 포인트가 일부 샘플로 모이면, 손실 함수의 미분값이 0에 가까운 값으로 무너지게 된다. (mode collapse)

    이전 글에서도 직접 학습해 본 결과에서 확인할 수 있다.

    비슷하게 생성된 이미지들

    오로지 판별자에 의한 역전파

    DCGAN에서는 생성자를 판별자의 결과만으로 업데이트되었기 때문에, 판별자에 비하면 학습이 잘 안될 수밖에 없다.

     

    그래서 WGAN 논문은, GAN을 훈련하면서 아래 두 가지를 다루는 내용을 제시한다.

    • 생성자에게 의미 있는 손실 측정 방법
    • 안정적인 학습

     


     

    새로운 손실 함수

    GAN과 WGAN의 손실 함수를 비교해보자. 먼저 GAN의 손실 함수 정의이다.

    \( \min_G \max_D V(D,G)=\mathbb{E}_{x\sim p_{data(x)}} [\log D(x)]+\mathbb{E}_{z\sim p_{z}(z)}[\log (1-D(G(z))] \)

    다음으로 WGAN의 손실 함수이다.

    \( \min_G \max_D V(D,G)=\mathbb{E}_{x\sim p_{data(x)}} [D(x)]+\mathbb{E}_{z\sim p_{z}(z)}[D(G(z))] \)

    로그 함수가 빠지면서 손실을 더 큰 값으로 본다.

    판별자의 마지막 레이어였던 sigmoid 활성화 함수로 제거해서, 예측 값의 범위가 \([0, 1]\) 에서 \([-\infty, \infty]\) 가 된다.
    이런 이유로 WGAN의 판별자(Discriminator)는 0/1 판별이 아니라 점수를 주다 보니 비평자(Critic)라고 불린다고 한다.

    그래서 WGAN에서는 learning rate가 GAN 보다 작은 값을 선택한다고 한다.

    그럼 이대로 학습을 시켜보면 어떻게 될까?

    걷잡을 수 없이 발산하는 loss 값

    Weight Clipping

    이제 비평자의 결과가 \([-\infty, \infty]\) 사이의 값이다보니, 사실상 손실이 무한으로 나올 수 있다. 그럼 local minimum을 찾지 못하고 튕겨져 나갈 수 있다는 것이다.

    여기서부터 갑자기 립시츠 제약, 립시츠 연속 함수(Lipschitz-continuous function) 같은 개념이 등장하는데, 완전히 이해하지는 못하겠다.
    변화량에 제한을 두어서 기울기를 한정하겠다는 내용을 수학적으로 증명한 것으로 이해하고 넘겼다.

    WGAN 논문에서는 이것을 가중치 클리핑(weight clipping)이라고 했다. 모델이 업데이트 될 때 가중치를 \([-0.01, 0.01]\) 사이로 제한해서 (clipping) 안정적인 학습을 이끌어냈다.

    그래서 이 버전의 WGAN 모델을 WGAN-CP 라고도 부른다.
    왜냐하면 이후에 WGAN-GP가 나왔기 때문. (Improved Training of Wasserstein GANs, NIPS 2017)

    그리고 그 WGAN-GP 논문에서 clipping과 gradient norm의 관계에 대한 내용이 언급된다.

    WGAN-GP 논문 Figure 1b

    아무튼 클리핑을 적용하면 아주 안정적으로 loss가 수렴하는 것을 확인할 수 있다.

    가중치 클리핑 후 loss 그래프

     

    비평가를 더 훈련

    한 가지 더 바뀐 점이 있다. 논문에서는 생성자보다 비평가를 더 많이 훈련했을 때 미분값이 더 괜찮았다고 한다.

    the more we train the critic, the more reliable gradient of the Wasserstein we get

    논문에서는 생성자가 1번 훈련할 동안 비평자(판별자)를 5번 훈련해서 둘 사이의 균형을 맞추고 있다.

     

    코드로 보기

    매 스텝마다 아래와 같이 비평자(판별자)를 훈련한다. 판별자에서 가짜 이미지에 대한 손실 함수가 \(\log(1-D(G(z)))\) 에서 \(D(G(z))\)로 바뀌면서 코드가 짧아졌다.

    netD.zero_grad()
    
    real_imgs = data[0].to(device)
    b_size = real_imgs.size(0)
    noise = torch.randn(b_size, nz, 1, 1, device=device)
    fake_imgs = netG(noise)
    
    errD = -torch.mean(netD(real_img)) + torch.mean(netD(fake_img)) # Wasserstein
    errD.backward()
    optimizerD.step()
    
    # Clip weights of discriminator
    for p in netD.parameters():
        p.data.clamp_(-0.01, 0.01)

    생성자는 5번째 배치마다 훈련한다. 그 외에 BCE(Binary Cross Entropy)가 바뀐 것 말고는 달라진 것은 없다.

    if step % 5 == 0:
        netG.zero_grad()
        gen_img = netG(noise)
        errG = -torch.mean(netD(gen_img))
        errG.backward()
        optimizerG.step()

     

    결과

    배치 크기는 128로 잡고 10 epochs 정도 학습한 결과이다.

    128 batch size, 10 epochs

    결과물이 생각만큼 나아지지 않아서 논문대로 배치 크기를 64로 잡고, DCGAN에서 10 epochs 학습했던 것을 이번에는 30 epochs 만큼 해봤는데도 결과는 비슷했다.

    64 batch size, 30 epochs
    64 batch size, 30 epochs

    논문에서 LSUN-bedroom 데이터셋으로 비교했을 때는 엄청 깔끔하게 정돈된 이미지가 나왔는데, 내가 실행한 것은 CelebA 데이터 셋이라 그런건지 놓친 하이퍼파라미터가 있는 지 모르겠다.

    WGAN-CP(Clipping)은 DCGAN에 비해서 결과물은 나아진 지 잘 모르겠으나 손실 측정만 바꿔서 안정적으로 학습되는 것은 무척 의미있어 보인다.

     

    여기까지는 모두 Adam 옵티마이저를 사용했는데, WGAN과 모멘텀 기반의 최적화 함수들(Adam 등)은 잘 맞지 않다고 한다.

    그래서인지 LSUN 데이터셋은 학습이 굉장히 잘 안되었는데, 최적화를 RMSProp(Adam에서 momentum이 빠짐) 함수로 바꿨더니 잘 되었다.

    (위) Adam / (아래) RMSProp

    자세한 결과는 아래 노트북에 있다.

     

    노트북

    GitHub

     

    GitHub - joonas-yoon/practice-on-dl: practice on deep learning including ML

    practice on deep learning including ML. Contribute to joonas-yoon/practice-on-dl development by creating an account on GitHub.

    github.com

    Kaggle

     

    🔥 WGAN-CP with CelebA and LSUN dataset

    Explore and run machine learning code with Kaggle Notebooks | Using data from multiple data sources

    www.kaggle.com

     

    참고

     

    반응형
    0 Comments
    댓글쓰기 폼