Joonas' Note

Joonas' Note

[딥러닝 일지] CycleGAN 본문

AI/딥러닝

[딥러닝 일지] CycleGAN

2022. 6. 23. 14:22 joonas

    (이전글 작성 중)


    논문: https://arxiv.org/pdf/1703.10593.pdf

    논문 저자 PyTorch 구현체: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/

    CycleGAN

    Image-to-Image Translation에 대한 논문 중 pix2pix를 기반으로 나온 이후의 논문이다. (pix2pix는 작성중)

    그러다보니, Discriminator는 PatchGAN을 쓰고, Generator로는 LSGAN을 사용하고 있다고 한다.

    pix2pix는 데이터 셋에서 변환하려는 두 이미지 쌍을 pair하게 가지고 있어야했지만, CycleGAN에서는 unpair한 데이터 셋 A, B를 주어도 잘 변환하는 점이 특징이다.
    즉, 실제 사진을 모네풍의 그림으로 바꾸기 위해서 모네가 그린 그림의 실제 배경 사진이 없어도 된다는 것이다.

    논문 Figure 2.

    그리고 CycleGAN은 Unpaired한 데이터로 학습했음에도 불구하고, Paired한 데이터셋으로 학습하는 pix-2-pix와 비교가 가능할 정도의 성능을 낸다.

    참고로 이 글에서는 bdd100k dataset에서 차량 주행 사진들의 낮과 밤을 바꿔보기로 한다.

    구조

    Generator와 Discriminator가 서로 경쟁하는 GAN의 구조는 동일하지만, Cycle이라는 개념을 사용한다.

    구조를 이런 식으로 많이 그리던데, 이미지 A를 주면 B 스타일의 이미지를 생성하는 함수 \(G: X \rightarrow Y\) 와 그 반대인 함수 \(F: Y \rightarrow X\)를 학습한다.

    \(G\)를 통해서 만든 이미지를 다시 \(F\)를 거치면 원래 이미지가 나와야 한다는 것이 한 Cycle 이다.( \(A \approx F(G(A))\) )

    그리고 A 스타일의 이미지에 대해서 Real/Fake를 판별하는 \(D_A\)와, B 스타일에 대해서 판별하는 \(D_B\)가 있다.

    즉, 네트워크가 4개이다. 엄청 크다보니 학습 속도가 그렇게 빠르지는 않다. 데이터는 적으면 500장, 보통 1000장 정도만 있어도 translation을 할 수 있다고 한다.

    이런 사이클의 구조로 학습을 하다보니 A → B image translation 뿐 아니라 B → A image translation 도 가능한 모델이다.

    판별자로 넘어가는 Adversarial Loss 말고도 Cycle Consistency Loss를 추가로 사용한다. ( \(A \approx F(G(A))\) ) 에 대해서도 backpropagation을 해준다는 것인데, 이 손실이 없으면 \(G\)가 주어진 이미지 A와 상관없이 진짜라고 속일 수 있는 몇 장의 이미지만 계속 생성할 수 있기 때문이다.

    논문 Figure 7.

    논문에서는, 이미지를 \(x \rightarrow y\) 로 변환한다고 치면, \(F(G(x)) \approx x\)를 forward, \(G(F(y)) \approx y\)를 backward 라고 부르고 있다.

     

    \(G\)와 \(F\)를 중심으로 다시 그려봤는데, 실제 논문의 구현체에서는 Identity A/B 결과도 사용하고 있다. 그런데 이건 optional이다. 이것을 사용하지 않아도 학습에는 문제가 없다고 한다.

    Loss_identity

    사진과 그림 간의 image translation이 있을 때 색조가 변하는 경우가 있는데, 그걸 막기 위해서 원본의 색감을 유지해야할 때 Identity A/B 결과의 loss를 쓰면 좋은 결과를 보인다고 한다.

     

    궁금증 정리

    몇 가지 내용이 궁금해서 찾아보기도 했고, 아래 네이버 랩 영상에서의 질의 응답도 정리해봤다.

     

    Q1) CycleGAN에서 WGAN/WGAN-GP를 loss 함수로 사용하는 것은 별로인가? 안정적으로 학습되지 않을까?

    A1) 별로라고 한다. 왜냐하면 Discriminator가 PatchGAN에 있는 걸 사용하는데, 그게 패치단위로 뜯어서 보니까 각 샘플이 독립적이지 않아서 WGAN의 가정이 깨지므로 소용없다고 한다. 물론 해볼수는 있고 논문의 pytorch 코드에서도 제공하고 있지만, 학습도 잘 되지 않고 결과가 수렴하지도 않는다고 한다.
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/1103

     

    Q2) LSGAN vs DCGAN 중에 LSGAN을 사용한 이유는?

    A2) 정확한 이유는 모르겠지만 LSGAN이 더 결과가 좋아서 사용했다고 한다.

     

    Q3) Cycle을 1개가 아니라 여러 개를 하면?

    A3) mode collapse도 잘 발생하고, 결과가 좋지 않았다고 한다.

     

    Q4) U-Net vs ResNet

    A4) U-Net의 경우에는 skip connection에 의존도가 많이 높아서 인코더로 들어갔다가 디코더로 다시 나오는 값들이 디테일한 부분을 잃는 문제가 있었다고 한다. 그래서 Fully Convolutional 이면서 residual block을 쌓은 ResNet을 사용했다고 한다.

     

    손실 함수로 cross entropy loss는 미분값을 사용하는 특성 때문에 결국에는 gradient vanishing 문제가 발생했기 때문에,  Least square GANs [Mao et al. 2016] 에서 제안한 Mean Square Error 처럼 생긴 loss를 사용했다.

     

    강화학습에서 주로 사용하는 Replay buffer를 사용하는데, 이전의 Generator가 만들었던 이미지들을 Discriminator에게 다시 보여주는 방식을 사용했다.

    이것 때문에 학습 중에 메모리가 한번 터졌다. 학습이 조금 진행되면 중간에 메모리 사용량이 한번 튀어오르는데, 이 replay buffer에 의한 것임이 틀림없다.

     

    학습 결과

    학습을 직접 안 돌려볼 수가 없다. bdd100k 데이터셋으로 낮과 밤을 바꾸는 학습을 해봤고 결과는 아래와 같다.

    18 epochs, batch_size=6, lr=0.002, G=resnet6
    결과물

    위 결과물을 만드는 데 사용된 하이퍼 파라미터는 다음과 같다.

    더보기
    ----------------- Options ---------------
                   batch_size: 6                             	[default: 1]
                        beta1: 0.5                           
              checkpoints_dir: ./checkpoints                 
               continue_train: True                          	[default: False]
                    crop_size: 256                           
                     dataroot: ../bdd100k/bdd100k/bdd100k/images/100k/train/	[default: None]
                 dataset_mode: unaligned                     
                    direction: AtoB                          
                  display_env: main                          
                 display_freq: 400                           
                   display_id: 1                             
                display_ncols: 4                             
                 display_port: 8097                          
               display_server: http://localhost              
              display_winsize: 256                           
                        epoch: latest                        
                  epoch_count: 1                             
                     gan_mode: lsgan                         
                      gpu_ids: 0                             
                    init_gain: 0.02                          
                    init_type: normal                        
                     input_nc: 3                             
                      isTrain: True                          	[default: None]
                     lambda_A: 10.0                          
                     lambda_B: 10.0                          
              lambda_identity: 0.5                           
                    load_iter: 0                             	[default: 0]
                    load_size: 286                           
                           lr: 0.0002                        
               lr_decay_iters: 50                            
                    lr_policy: linear                        
             max_dataset_size: inf                           
                        model: cycle_gan                     
                     n_epochs: 100                           
               n_epochs_decay: 100                           
                   n_layers_D: 3                             
                         name: bdd100k                       	[default: experiment_name]
                          ndf: 64                            
                         netD: basic                         
                         netG: resnet_6blocks                	[default: resnet_9blocks]
                          ngf: 64                            
                   no_dropout: True                          
                      no_flip: False                         
                      no_html: False                         
                         norm: instance                      
                  num_threads: 12                            	[default: 4]
                    output_nc: 3                             
                        phase: train                         
                    pool_size: 50                            
                   preprocess: resize_and_crop               
                   print_freq: 100                           
                 save_by_iter: False                         
              save_epoch_freq: 5                             
             save_latest_freq: 5000                          
               serial_batches: False                         
                       suffix:                               
             update_html_freq: 1000                          
                    use_wandb: True                          	[default: False]
                      verbose: False                         
    ----------------- End -------------------

    낮 사진을 어둡게 만드는 낮→밤 변환은 자연스럽게 잘 되었다. 가로등처럼 밝게 빛나는 부분을 제외하면 잘 변환되는 편이다. 반대로 밤→낮 변환의 경우는 하늘이나 건물 쪽을 새롭게 생성하는 경우에는 어색한 경우가 종종 있다.

    하지만 만족스러운 결과이다.

    Anime2Celeb

    다른 데이터셋으로 학습을 확인해보고싶어서 데이터셋을 고르다가, 애니메이션의 인물 그림을 사진처럼 바꾸는 것을 해보기로 했다.

    학습 데이터셋 샘플들

    Kaggle에 있는 Anime Face DatasetCelebA Dataset을 각각 trainA와 trainB로 사용했다.

    전부 학습하기에는 데이터가 너무 많아서 상당히 오래걸렸다. 그리고 이미지들의 픽셀값 분포가 잘 맞는 지도 모르겠어서 한번 정리했다.
    그래서 데이터셋의 일부만 사용했는데, CelebA 데이터셋 20만장 중에서 attributes.csv 파일을 참고하여 눈코입의 위치가 평균에 가까운 것만 모아서 1600장 정도로 추렸다.
    Anime Face의 경우에는 크기가 너무 작은 것들(300x300 이하)은 삭제했고, 2019년 그림들로만 1100장 정도 사용했다.

    그리고 이미지는 128x128로 resize하고 crop은 따로 하지 않았다.

    200 epochs 정도 학습했고 4시간 반쯤 걸린 것 같다.

    anim2celeb 결과

    기존의 이미지에서 나름 각 도메인에 맞게 잘 바꾸고 있다! 무척 신기하다.

    물론 계속 잘 되지는 않았고 파라미터를 몇 번 튜닝해서 이 정도까지의 결과는 얻었다.

    Image A, B, 그리고 G(A)
    이것도 잘 된 것 같아서 추가

    실사 → 애니메이션의 경우에는 반대 변환에 비해 상대적으로 잘 되었다. 얼굴을 지우고 큰 눈만 넣어도 결과가 괜찮아서 그런 듯 하다. 사실상 지우고 새로 그리는 것에 가깝다.
    그리고 실사 이미지 A의 머리 색깔을 그대로 가져갈 때도 있고 아예 덮어버리고 새로 그리는 경우도 있다.

    애니메이션 → 실사는 몇 개를 제외하고는 거의 변환이 잘 안되었다. 실사 이미지가 픽셀이 더 복잡해서 그런 것이라고 생각한다.

    A, B와 F(B)

    그래도 이 정도면 만족스러운 결과를 확인했다.

     

    참고

    P. S.

    cycleGAN 이전에 같은 연구실의 CUT 이라는 논문이 있긴 한데, 이것도 읽어봐야 하나 고민

     

    Comments