Joonas' Note
[강화학습 일지] DQN Tutorial 살펴보기 본문
PyTorch 공식 문서에서 강화학습(Reinforcement Learning)의 한 예시로 DQN 튜토리얼이 있어서 살펴보기로 했다.
시간이 많이 지나서 깨달은 사실은, 한글 문서와 영어 문서의 내용과 도메인이 다르다는 것이었다.
- 한글 문서: https://tutorials.pytorch.kr/intermediate/reinforcement_q_learning.html
- 영어 문서: https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
한글 문서의 경우에는 Cart-Pole-v0 을 기준으로 작성된 예전 내용이라서, Cart-Pole-v1로 그대로 옮기면 학습도 잘 안되고 동작 방식에도 큰 차이가 있었다.
참고로 한글 문서는 1200 episode 돌리는 데 약 1시간 정도 걸렸는데 결과가 좋지 않았는데, 영어 문서의 경우에는 수월하게 학습이 진행되었다. GPU 점유는 둘 다 10%를 채 넘기지 않았다.
튜토리얼을 따라했는데, 학습이 잘 되지 않는다면 이 글이 도움이 되기를 바라면서 작성해본다.
동작 비교
먼저, 한글 문서의 경우에는 env에 액션을 수행하고 렌더링해서 나온 RGB 이미지를 Agent가 CNN 네트워크로 학습하는 방식이었다.
렌더링 되는 이미지는 40x90 이고, CNN 네트워크은 2개의 클래스로 분류된다. 이 결과는 다음 step의 action이 된다.
(CartPole의 경우에는 action이 왼쪽=0 또는 오른쪽=1로 2개)
v1로 바뀌면서 그런 것인지는 모르겠지만, 영어 문서에서는 더 이상 CNN 기반으로 학습하지 않고,
env의 한 스텝에서 반환되는 관측 결과(observation)를 그대로 네트워크에 넣는 방식으로 바뀌었다.
observation (object): this will be an element of the environment's `observation_space`.
This may, for instance, a numpy array containing the positions and velocities of certain
그렇다보니 네트워크는 엄청 가벼워졌는데, 관측 결과로 넘어오는 위 4가지 값만으로도 학습이 충분한 모양이었다.
하나 더 다른 부분은 Target Net을 업데이트할 때 soft update 방식을 사용한다. [관련 논문 링크]
실행 결과
하이퍼파라미터들을 조금씩 고치다보면, 최대 3325 episodes 까지 버티는 결과도 있었다.
https://github.com/joonas-yoon/practice-on-rl/tree/f0707db982dd5e2302c52fa3245265e0f6b94569/cartpole
실행 환경
Python 3.8.3
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 461.09 Driver Version: 461.09 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name TCC/WDDM | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 GeForce GTX 1660 WDDM | 00000000:01:00.0 On | N/A |
| 46% 40C P0 19W / 130W | 2203MiB / 6144MiB | 20% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
'AI' 카테고리의 다른 글
[강화학습 메모] Proximal Policy Optimization (PPO, 2017) (0) | 2023.03.11 |
---|---|
[강화학습 메모] A3C (Asynchronous A2C, 2016) (0) | 2023.03.10 |
Loss 또는 모델 output이 NaN인 경우 확인해볼 것 (0) | 2022.04.23 |
[부동산 가격 예측] LightGBM에서 DNN Regression으로 (0) | 2022.04.21 |
[PyTorch] Tensor, NumPy, Pandas 타입 표 (0) | 2022.04.19 |