관리 메뉴

Joonas' Note

[강화학습 일지] DQN Tutorial 살펴보기 본문

AI

[강화학습 일지] DQN Tutorial 살펴보기

joonas 2023. 1. 13. 00:20

    PyTorch 공식 문서에서 강화학습(Reinforcement Learning)의 한 예시로 DQN 튜토리얼이 있어서 살펴보기로 했다.

    시간이 많이 지나서 깨달은 사실은, 한글 문서와 영어 문서의 내용과 도메인이 다르다는 것이었다.

     

    한글 문서의 경우에는 Cart-Pole-v0 을 기준으로 작성된 예전 내용이라서, Cart-Pole-v1로 그대로 옮기면 학습도 잘 안되고 동작 방식에도 큰 차이가 있었다.

    참고로 한글 문서는 1200 episode 돌리는 데 약 1시간 정도 걸렸는데 결과가 좋지 않았는데, 영어 문서의 경우에는 수월하게 학습이 진행되었다. GPU 점유는 둘 다 10%를 채 넘기지 않았다.

    (좌) v0 기준인 한글 문서 / (우) v1 기준인 영어 문서

    튜토리얼을 따라했는데, 학습이 잘 되지 않는다면 이 글이 도움이 되기를 바라면서 작성해본다.

     

    동작 비교

    먼저, 한글 문서의 경우에는 env에 액션을 수행하고 렌더링해서 나온 RGB 이미지를 Agent가 CNN 네트워크로 학습하는 방식이었다.

    렌더링 되는 이미지는 40x90 이고, CNN 네트워크은 2개의 클래스로 분류된다. 이 결과는 다음 step의 action이 된다.
    (CartPole의 경우에는 action이 왼쪽=0 또는 오른쪽=1로 2개)

    한글 문서 기준 (render + CNN 기반)

    v1로 바뀌면서 그런 것인지는 모르겠지만, 영어 문서에서는 더 이상 CNN 기반으로 학습하지 않고,
    env의 한 스텝에서 반환되는 관측 결과(observation)를 그대로 네트워크에 넣는 방식으로 바뀌었다.

    observation (object): this will be an element of the environment's `observation_space`.
    This may, for instance, a numpy array containing the positions and velocities of certain

    영어 문서 기준 (observaion + FC Layer 기반)

    그렇다보니 네트워크는 엄청 가벼워졌는데, 관측 결과로 넘어오는 위 4가지 값만으로도 학습이 충분한 모양이었다.

    하나 더 다른 부분은 Target Net을 업데이트할 때 soft update 방식을 사용한다. [관련 논문 링크] 

     

    실행 결과

    하이퍼파라미터들을 조금씩 고치다보면, 최대 3325 episodes 까지 버티는 결과도 있었다.

    Max duration: 1472

    https://github.com/joonas-yoon/practice-on-rl/tree/f0707db982dd5e2302c52fa3245265e0f6b94569/cartpole

     

    GitHub - joonas-yoon/practice-on-rl: Practice on Reinforcement Learning

    Practice on Reinforcement Learning. Contribute to joonas-yoon/practice-on-rl development by creating an account on GitHub.

    github.com

    실행 환경

    Python 3.8.3

    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 461.09       Driver Version: 461.09       CUDA Version: 11.2     |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |                               |                      |               MIG M. |
    |===============================+======================+======================|
    |   0  GeForce GTX 1660   WDDM  | 00000000:01:00.0  On |                  N/A |
    | 46%   40C    P0    19W / 130W |   2203MiB /  6144MiB |     20%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    0 Comments
    댓글쓰기 폼