Joonas' Note
[강화학습 메모] Proximal Policy Optimization (PPO, 2017) 본문
Proximal Policy Optimization (PPO, 2017)
목적 함수를 완성하기 위한 gradient 식은 아래와 같다.
$$ \nabla_{\theta}J_{\theta} \cong \sum_{t=0}^{\infty} \int_{s_t,a_t,s_{t+1}}\nabla_{\theta}lnp_{\theta}(a_t|s_t)~A_t~p_{\theta}(s_t,a_t)~p(s_{t+1}|s_t,a_t)~ds_t,a_t,s_{t+1} $$
여기서 \(A_t\)는 n-step TD error 인데, \(A_t\)의 정의에 따라서 아래와 같이 달라졌었다.
$$ A_t = \begin{cases} G_t & \longrightarrow REINFORCE \\ Q(s_t) & \longrightarrow Actor-Critic \\ Q(s_t,a_t)-V(s_t) & \longrightarrow A2C \end{cases} $$
(1) Importance sampling
굳이 \(p_{\theta}(s_t,a_t)~p(s_{t+1}|s_t,a_t)\) 에서 뽑지 않고, 다른 pdf 에서 뽑고 보상을 해줌으로써 같은 효과를 누리겠다.
$$ \begin{aligned}E_{x\sim p(x)}[x] &= \int_x xp(x)dx \\ &= \int_x (x \frac{p(x)} {q(x)})q(x)dx \\ &= E_{x\sim q(x)}[x \frac{p(x)}{q(x)}] \end{aligned} $$
좌항은 \(\frac{1}{N}\sum_{i=1}^{N} x_i\) 과 근사하고, 우항은 \(\frac{1}{N}\sum_{i=1}^{N} x_i \frac{p(x_i)}{q(x_i)}\) 과 근사하다. 그럼 두 근삿값이 비슷하다고 어림할 수 있으니 아래처럼 된다.
$$ \frac{1}{N}\sum_{i=1}^{N} x_i ~\approx~ \frac{1}{N}\sum_{i=1}^{N} x_i \frac{p(x_i)}{q(x_i)} $$
위와 같은 importance sampling 을 기반으로 해서 식을 전개한다.
\(p(x_i)\) 를 \(p_{\theta}(a_t|s_t)\) 로 보고, \(q(x_i)\) 를 어떤 옛날 값인 \(p_{\theta_{prev}}(a_t|s_t)\) 로 보는 것이다.
$$ \begin{aligned} \nabla_{\theta}J_{\theta} & \cong \sum_{t=0}^{\infty} \int_{s_t,a_t,s_{t+1}}\nabla_{\theta}lnp_{\theta}(a_t|s_t)~A_t~p_{\theta}(s_t,a_t)~p(s_{t+1}|s_t,a_t)~ds_t,a_t,s_{t+1} \\ & \approx \sum_{t=0}^{\infty} \int_{s_t,a_t,s_{t+1}}\nabla_{\theta} \frac {lnp_{\theta}(a_t|s_t)}{p_{\theta_{prev}}(a_t|s_t)} ~A_t~p_{\theta_{prev}}(s_t,a_t)~p(s_{t+1}|s_t,a_t)~ds_t,a_t,s_{t+1} \end{aligned} $$
이제 더 이상 \(p_{\theta_{prev}}(s_t,a_t)\) 를 버리지 않고 다시 사용할 수 있게 되었다.
Policy update는 아래와 같다.
$$ \theta \longleftarrow \theta + \alpha \nabla_\theta \sum_{i=t-N+1}^{t} \frac {p_\theta(a_i|s_i)} {p_{\theta_{prev}}(a_i|s_i)} A_i $$
위에서 전개한 식이 근삿값인 이유는, \(p_{\theta_{prev}}(s_t) \approx p_\theta(s_t)\) 를 만족해야하기 때문이고, 그러므로 \(\theta\)도 너무 차이가 나면 안되고 \(\theta_{prev} \approx \theta\) 이여야한다.
목표는 \(max~ \sum_{i=t-N+1}^{t} \frac {p_\theta(a_i|s_i)} {p_{\theta_{prev}}(a_i|s_i)} A_i\) 이고 제약사항으로는 \(E_{s_t}~KL[P_{\theta_{prev}},P_\theta] \leqq \delta\) 이다.
💡 여기서 KL 함수는 쿨백-라이블러 발산(Kullback–Leibler divergence, KLD)은 두 확률 분포의 차이를 계산하는 데에 사용하는 함수이다. 어떤 이상적인 분포에 대해 그 분포를 근사하는 다른 분포를 이용해서 샘플링을 했을때 발생할 수 있는 엔트로피의 차이를 계산한다.
💡 라그랑주 완화법(Lagrangian Relaxation)을 사용해서 constraint set을 확장하여 근사적으로 문제를 해결할 수 있다고 한다.
하지만 PPO 논문에서는 클리핑(clipping)을 사용했다.
(2) Clipping
원래 TRPO(Trust Region Policy Optimization, 2015) 논문이 먼저 있었는데, 이걸 클리핑으로 간단하게 해결한 게 PPO 라고 한다.
현재 정책을 \(\pi_{\theta_{OLD}}(u_t\mid x_t)\)라 하자.
또 둘의 비율을 \(r_t(\theta) = \frac{\pi_\theta(u_t\mid x_t)}{\pi_{\theta_{OLD}}(u_t\mid x_t)}\)이라 하자.
그려면 \(r_t(\theta)\)는 1에 매우 가까운 숫자여야 한다. 두 정책이 비슷해야하기 때문이다.
\(1-\epsilon < r_t(\theta)<1+\epsilon\)이로 한정할 것이다.
\(r_t(\theta)\)가 위 범위 안에 있으면 그대로 \(r_t(\theta)\)를 채택할 것이다. 또 \(1+\epsilon\)보다 크면, \(1+\epsilon\)을
\(1-\epsilon\)보다 작으면 \(1-\epsilon\)을 채택할 것이다.
(3) GAE (Generalized Advantage Estimate)
$$ \begin{aligned}
R_t + \gamma V(s_{t+1})-V(s_t) & : 1~step~TD~error \\
R_t + \gamma V(s_{t+1}) + \gamma^2 V(s_{t+2}) - V(s_t) & : 2~stepTD~error \\
& \vdots\\ R_t + \dots + \gamma^n V(s_{t+n}) - V(s_t)
& : n~step~~TD~error \\
\end{aligned} $$
모든 n step 에 대해서 EMA(Exponential moving average)를 구해서 사용한다.
$$ TD(\lambda) \rightarrow \sum_{n=1}^\infty (1-\lambda) \lambda^{n-1}~A_{t}^{(n)} $$
\(A_t^{(n)}\) 을 전개해보면,
$$ \begin{aligned}
A_t^{(1)} & = R_t + \gamma V(s_{t+1})
V(s_t) \triangleq \delta_t \\
A_t^{(2)} & = R_t + \gamma V(s_{t+1})
\gamma^2 V(s_{t+2})
V(s_t) \\ & = (R_t + \gamma V(s_{t+1}) - V(s_t)) + \gamma (R_{t+1} + \gamma V(s_{t+2}) - V(s_{t+1})) \\ &= \delta_t + \gamma \delta_{t_1} \\
A_t^{(n)} & = \sum_{k=t}^{t+n-1} \gamma^{k-t} \delta_k
\end{aligned} $$
위와 같이 표현할 수 있고, 이걸 \(TD(\lambda)\) 로 적은 식에 대입해보면,
$$ \begin{aligned}
TD(\lambda) & \rightarrow \sum_{n=1}^\infty (1-\lambda) \lambda^{n-1}~A_{t}^{(n)} \\
& = (1-\lambda) (\delta_t
\lambda(\delta+\gamma \delta_{t+1})
\lambda^2(\delta+\gamma \delta_{t+1}+ \gamma^2 \delta_{t+2})
\dots ) \\
& = (1-\lambda) ( \delta_t (1+\lambda+\lambda^2+\dots)
\gamma \delta_{t+1} (\lambda+\lambda^2+\dots)
\dots ) \\
& = \sum_{k=t}^{\infty} (\gamma \lambda)^{k-t} \delta_k ~ \triangleq A_t^{GAE}
\end{aligned} $$
중간에는 등비급수 사용하면 \((1+\lambda+\lambda^2+\dots)\) 는 \(\frac{\lambda}{1-\lambda}\) 가 된다. (단, \(|\lambda| < 1\) )
정리
- \(\theta, w~\) 초기화
- 다음을 반복
- N개의 샘플 수집 (\(a^{(i)}$ sample : $\{s_i, a_i, s_{i+1}\}\))
- 다음을 반복 — (epoch)
- Actor update; \(\theta \leftarrow \theta + \alpha \nabla_{\theta} \sum_{i=t-N+1}^{t} J_i^{clip}\)
- Critic update; \(w \leftarrow w + \beta \nabla_{w} \sum_{i=t-N+1}^{t} (\hat A_i^{GAE})^2\)
- 배치 클리어
수식에 action의 확률을 계산하는 부분이 남아있다보니, 완전한 off policy 는 아니고 on policy 라고 한다.
참고
강화학습 알아보기(4) - Actor-Critic, A2C, A3C · greentec's blog
'AI' 카테고리의 다른 글
logit vs. sigmoid vs. softmax (0) | 2023.07.05 |
---|---|
[강화학습 메모] A3C (Asynchronous A2C, 2016) (0) | 2023.03.10 |
[강화학습 일지] DQN Tutorial 살펴보기 (0) | 2023.01.13 |
Loss 또는 모델 output이 NaN인 경우 확인해볼 것 (0) | 2022.04.23 |
[부동산 가격 예측] LightGBM에서 DNN Regression으로 (0) | 2022.04.21 |