Skip to content

Commit c1e9b0c

Browse files
committed
📚 ppo intro
1 parent 5442dfb commit c1e9b0c

File tree

6 files changed

+364
-35
lines changed

6 files changed

+364
-35
lines changed

docs/rl/ppo/index.html

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,22 @@
7575
<h1>Proximal Policy Optimization (PPO)</h1>
7676
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of
7777
<a href="https://arxiv.org/abs/1707.06347">Proximal Policy Optimization - PPO</a>.</p>
78+
<p>PPO is a policy gradient method for reinforcement learning.
79+
Simple policy gradient methods one do a single gradient update per sample (or a set of samples).
80+
Doing multiple gradient steps for a singe sample causes problems
81+
because the policy deviates too much producing a bad policy.
82+
PPO lets us do multiple gradient updates per sample by trying to keep the
83+
policy close to the policy that was used to sample data.
84+
It does so by clipping gradient flow if the updated policy
85+
is not close to the policy used to sample the data.</p>
7886
<p>You can find an experiment that uses it <a href="experiment.html">here</a>.
7987
The experiment uses <a href="gae.html">Generalized Advantage Estimation</a>.</p>
8088
</div>
8189
<div class='code'>
82-
<div class="highlight"><pre><span class="lineno">17</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
83-
<span class="lineno">18</span>
84-
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
85-
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
90+
<div class="highlight"><pre><span class="lineno">26</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
91+
<span class="lineno">27</span>
92+
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
93+
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
8694
</div>
8795
</div>
8896
<div class='section' id='section-1'>
@@ -91,6 +99,7 @@ <h1>Proximal Policy Optimization (PPO)</h1>
9199
<a href='#section-1'>#</a>
92100
</div>
93101
<h2>PPO Loss</h2>
102+
<p>Here&rsquo;s how the PPO update rule is derived.</p>
94103
<p>We want to maximize policy reward
95104
<script type="math/tex; mode=display">\max_\theta J(\pi_\theta) =
96105
\mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t r_t \Biggr]</script>
@@ -186,7 +195,7 @@ <h2>PPO Loss</h2>
186195
</p>
187196
</div>
188197
<div class='code'>
189-
<div class="highlight"><pre><span class="lineno">23</span><span class="k">class</span> <span class="nc">ClippedPPOLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
198+
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">ClippedPPOLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
190199
</div>
191200
</div>
192201
<div class='section' id='section-2'>
@@ -197,8 +206,8 @@ <h2>PPO Loss</h2>
197206

198207
</div>
199208
<div class='code'>
200-
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
201-
<span class="lineno">123</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
209+
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
210+
<span class="lineno">134</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
202211
</div>
203212
</div>
204213
<div class='section' id='section-3'>
@@ -209,8 +218,8 @@ <h2>PPO Loss</h2>
209218

210219
</div>
211220
<div class='code'>
212-
<div class="highlight"><pre><span class="lineno">125</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
213-
<span class="lineno">126</span> <span class="n">advantage</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
221+
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
222+
<span class="lineno">137</span> <span class="n">advantage</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
214223
</div>
215224
</div>
216225
<div class='section' id='section-4'>
@@ -222,15 +231,16 @@ <h2>PPO Loss</h2>
222231
<em>this is different from rewards</em> $r_t$.</p>
223232
</div>
224233
<div class='code'>
225-
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">ratio</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">-</span> <span class="n">sampled_log_pi</span><span class="p">)</span></pre></div>
234+
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">ratio</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">-</span> <span class="n">sampled_log_pi</span><span class="p">)</span></pre></div>
226235
</div>
227236
</div>
228237
<div class='section' id='section-5'>
229238
<div class='docs'>
230239
<div class='section-link'>
231240
<a href='#section-5'>#</a>
232241
</div>
233-
<p>
242+
<h3>Cliping the policy ratio</h3>
243+
<p>
234244
<script type="math/tex; mode=display">\begin{align}
235245
\mathcal{L}^{CLIP}(\theta) =
236246
\mathbb{E}_{a_t, s_t \sim \pi_{\theta{OLD}}} \biggl[
@@ -257,14 +267,14 @@ <h2>PPO Loss</h2>
257267
but it reduces variance a lot.</p>
258268
</div>
259269
<div class='code'>
260-
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">clipped_ratio</span> <span class="o">=</span> <span class="n">ratio</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">clip</span><span class="p">,</span>
261-
<span class="lineno">157</span> <span class="nb">max</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">clip</span><span class="p">)</span>
262-
<span class="lineno">158</span> <span class="n">policy_reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">,</span>
263-
<span class="lineno">159</span> <span class="n">clipped_ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">)</span>
264-
<span class="lineno">160</span>
265-
<span class="lineno">161</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_fraction</span> <span class="o">=</span> <span class="p">(</span><span class="nb">abs</span><span class="p">((</span><span class="n">ratio</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">))</span> <span class="o">&gt;</span> <span class="n">clip</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
266-
<span class="lineno">162</span>
267-
<span class="lineno">163</span> <span class="k">return</span> <span class="o">-</span><span class="n">policy_reward</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
270+
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">clipped_ratio</span> <span class="o">=</span> <span class="n">ratio</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">clip</span><span class="p">,</span>
271+
<span class="lineno">170</span> <span class="nb">max</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">clip</span><span class="p">)</span>
272+
<span class="lineno">171</span> <span class="n">policy_reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">,</span>
273+
<span class="lineno">172</span> <span class="n">clipped_ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">)</span>
274+
<span class="lineno">173</span>
275+
<span class="lineno">174</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_fraction</span> <span class="o">=</span> <span class="p">(</span><span class="nb">abs</span><span class="p">((</span><span class="n">ratio</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">))</span> <span class="o">&gt;</span> <span class="n">clip</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
276+
<span class="lineno">175</span>
277+
<span class="lineno">176</span> <span class="k">return</span> <span class="o">-</span><span class="n">policy_reward</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
268278
</div>
269279
</div>
270280
<div class='section' id='section-6'>
@@ -273,6 +283,7 @@ <h2>PPO Loss</h2>
273283
<a href='#section-6'>#</a>
274284
</div>
275285
<h2>Clipped Value Function Loss</h2>
286+
<p>Similarly we clip the value function update also.</p>
276287
<p>
277288
<script type="math/tex; mode=display">\begin{align}
278289
V^{\pi_\theta}_{CLIP}(s_t)
@@ -289,7 +300,7 @@ <h2>Clipped Value Function Loss</h2>
289300
significantly from $V_{\theta_{OLD}}$.</p>
290301
</div>
291302
<div class='code'>
292-
<div class="highlight"><pre><span class="lineno">166</span><span class="k">class</span> <span class="nc">ClippedValueFunctionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
303+
<div class="highlight"><pre><span class="lineno">179</span><span class="k">class</span> <span class="nc">ClippedValueFunctionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
293304
</div>
294305
</div>
295306
<div class='section' id='section-7'>
@@ -300,10 +311,10 @@ <h2>Clipped Value Function Loss</h2>
300311

301312
</div>
302313
<div class='code'>
303-
<div class="highlight"><pre><span class="lineno">185</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_return</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
304-
<span class="lineno">186</span> <span class="n">clipped_value</span> <span class="o">=</span> <span class="n">sampled_value</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_value</span><span class="p">)</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">clip</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">clip</span><span class="p">)</span>
305-
<span class="lineno">187</span> <span class="n">vf_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">clipped_value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
306-
<span class="lineno">188</span> <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">vf_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
314+
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_return</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
315+
<span class="lineno">201</span> <span class="n">clipped_value</span> <span class="o">=</span> <span class="n">sampled_value</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_value</span><span class="p">)</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">clip</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">clip</span><span class="p">)</span>
316+
<span class="lineno">202</span> <span class="n">vf_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">clipped_value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
317+
<span class="lineno">203</span> <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">vf_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
307318
</div>
308319
</div>
309320
</div>

0 commit comments

Comments
 (0)