75
75
< h1 > Proximal Policy Optimization (PPO)</ h1 >
76
76
< p > This is a < a href ="https://pytorch.org "> PyTorch</ a > implementation of
77
77
< a href ="https://arxiv.org/abs/1707.06347 "> Proximal Policy Optimization - PPO</ a > .</ p >
78
+ < p > PPO is a policy gradient method for reinforcement learning.
79
+ Simple policy gradient methods one do a single gradient update per sample (or a set of samples).
80
+ Doing multiple gradient steps for a singe sample causes problems
81
+ because the policy deviates too much producing a bad policy.
82
+ PPO lets us do multiple gradient updates per sample by trying to keep the
83
+ policy close to the policy that was used to sample data.
84
+ It does so by clipping gradient flow if the updated policy
85
+ is not close to the policy used to sample the data.</ p >
78
86
< p > You can find an experiment that uses it < a href ="experiment.html "> here</ a > .
79
87
The experiment uses < a href ="gae.html "> Generalized Advantage Estimation</ a > .</ p >
80
88
</ div >
81
89
< div class ='code '>
82
- < div class ="highlight "> < pre > < span class ="lineno "> 17 </ span > < span > </ span > < span class ="kn "> import</ span > < span class ="nn "> torch</ span >
83
- < span class ="lineno "> 18 </ span >
84
- < span class ="lineno "> 19 </ span > < span class ="kn "> from</ span > < span class ="nn "> labml_helpers.module</ span > < span class ="kn "> import</ span > < span class ="n "> Module</ span >
85
- < span class ="lineno "> 20 </ span > < span class ="kn "> from</ span > < span class ="nn "> labml_nn.rl.ppo.gae</ span > < span class ="kn "> import</ span > < span class ="n "> GAE</ span > </ pre > </ div >
90
+ < div class ="highlight "> < pre > < span class ="lineno "> 26 </ span > < span > </ span > < span class ="kn "> import</ span > < span class ="nn "> torch</ span >
91
+ < span class ="lineno "> 27 </ span >
92
+ < span class ="lineno "> 28 </ span > < span class ="kn "> from</ span > < span class ="nn "> labml_helpers.module</ span > < span class ="kn "> import</ span > < span class ="n "> Module</ span >
93
+ < span class ="lineno "> 29 </ span > < span class ="kn "> from</ span > < span class ="nn "> labml_nn.rl.ppo.gae</ span > < span class ="kn "> import</ span > < span class ="n "> GAE</ span > </ pre > </ div >
86
94
</ div >
87
95
</ div >
88
96
< div class ='section ' id ='section-1 '>
@@ -91,6 +99,7 @@ <h1>Proximal Policy Optimization (PPO)</h1>
91
99
< a href ='#section-1 '> #</ a >
92
100
</ div >
93
101
< h2 > PPO Loss</ h2 >
102
+ < p > Here’s how the PPO update rule is derived.</ p >
94
103
< p > We want to maximize policy reward
95
104
< script type ="math/tex; mode=display "> \max_ \theta J ( \pi_ \theta ) =
96
105
\mathop { \mathbb { E } } _ { \tau \sim \pi_ \theta } \Biggl [ \sum_ { t= 0 } ^ \infty \gamma ^ t r_t \Biggr ] </ script >
@@ -186,7 +195,7 @@ <h2>PPO Loss</h2>
186
195
</ p >
187
196
</ div >
188
197
< div class ='code '>
189
- < div class ="highlight "> < pre > < span class ="lineno "> 23 </ span > < span class ="k "> class</ span > < span class ="nc "> ClippedPPOLoss</ span > < span class ="p "> (</ span > < span class ="n "> Module</ span > < span class ="p "> ):</ span > </ pre > </ div >
198
+ < div class ="highlight "> < pre > < span class ="lineno "> 32 </ span > < span class ="k "> class</ span > < span class ="nc "> ClippedPPOLoss</ span > < span class ="p "> (</ span > < span class ="n "> Module</ span > < span class ="p "> ):</ span > </ pre > </ div >
190
199
</ div >
191
200
</ div >
192
201
< div class ='section ' id ='section-2 '>
@@ -197,8 +206,8 @@ <h2>PPO Loss</h2>
197
206
198
207
</ div >
199
208
< div class ='code '>
200
- < div class ="highlight "> < pre > < span class ="lineno "> 122 </ span > < span class ="k "> def</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ):</ span >
201
- < span class ="lineno "> 123 </ span > < span class ="nb "> super</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="fm "> __init__</ span > < span class ="p "> ()</ span > </ pre > </ div >
209
+ < div class ="highlight "> < pre > < span class ="lineno "> 133 </ span > < span class ="k "> def</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ):</ span >
210
+ < span class ="lineno "> 134 </ span > < span class ="nb "> super</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="fm "> __init__</ span > < span class ="p "> ()</ span > </ pre > </ div >
202
211
</ div >
203
212
</ div >
204
213
< div class ='section ' id ='section-3 '>
@@ -209,8 +218,8 @@ <h2>PPO Loss</h2>
209
218
210
219
</ div >
211
220
< div class ='code '>
212
- < div class ="highlight "> < pre > < span class ="lineno "> 125 </ span > < span class ="k "> def</ span > < span class ="fm "> __call__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> log_pi</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> sampled_log_pi</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span >
213
- < span class ="lineno "> 126 </ span > < span class ="n "> advantage</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> clip</ span > < span class ="p "> :</ span > < span class ="nb "> float</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> :</ span > </ pre > </ div >
221
+ < div class ="highlight "> < pre > < span class ="lineno "> 136 </ span > < span class ="k "> def</ span > < span class ="fm "> __call__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> log_pi</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> sampled_log_pi</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span >
222
+ < span class ="lineno "> 137 </ span > < span class ="n "> advantage</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> clip</ span > < span class ="p "> :</ span > < span class ="nb "> float</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> :</ span > </ pre > </ div >
214
223
</ div >
215
224
</ div >
216
225
< div class ='section ' id ='section-4 '>
@@ -222,15 +231,16 @@ <h2>PPO Loss</h2>
222
231
< em > this is different from rewards</ em > $r_t$.</ p >
223
232
</ div >
224
233
< div class ='code '>
225
- < div class ="highlight "> < pre > < span class ="lineno "> 129 </ span > < span class ="n "> ratio</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> exp</ span > < span class ="p "> (</ span > < span class ="n "> log_pi</ span > < span class ="o "> -</ span > < span class ="n "> sampled_log_pi</ span > < span class ="p "> )</ span > </ pre > </ div >
234
+ < div class ="highlight "> < pre > < span class ="lineno "> 140 </ span > < span class ="n "> ratio</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> exp</ span > < span class ="p "> (</ span > < span class ="n "> log_pi</ span > < span class ="o "> -</ span > < span class ="n "> sampled_log_pi</ span > < span class ="p "> )</ span > </ pre > </ div >
226
235
</ div >
227
236
</ div >
228
237
< div class ='section ' id ='section-5 '>
229
238
< div class ='docs '>
230
239
< div class ='section-link '>
231
240
< a href ='#section-5 '> #</ a >
232
241
</ div >
233
- < p >
242
+ < h3 > Cliping the policy ratio</ h3 >
243
+ < p >
234
244
< script type ="math/tex; mode=display "> \begin { align }
235
245
\mathcal { L } ^ { CLIP } ( \theta ) =
236
246
\mathbb { E} _ { a_t, s_t \sim \pi_ { \theta { OLD } } } \biggl [
@@ -257,14 +267,14 @@ <h2>PPO Loss</h2>
257
267
but it reduces variance a lot.</ p >
258
268
</ div >
259
269
< div class ='code '>
260
- < div class ="highlight "> < pre > < span class ="lineno "> 156 </ span > < span class ="n "> clipped_ratio</ span > < span class ="o "> =</ span > < span class ="n "> ratio</ span > < span class ="o "> .</ span > < span class ="n "> clamp</ span > < span class ="p "> (</ span > < span class ="nb "> min</ span > < span class ="o "> =</ span > < span class ="mf "> 1.0</ span > < span class ="o "> -</ span > < span class ="n "> clip</ span > < span class ="p "> ,</ span >
261
- < span class ="lineno "> 157 </ span > < span class ="nb "> max</ span > < span class ="o "> =</ span > < span class ="mf "> 1.0</ span > < span class ="o "> +</ span > < span class ="n "> clip</ span > < span class ="p "> )</ span >
262
- < span class ="lineno "> 158 </ span > < span class ="n "> policy_reward</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> min</ span > < span class ="p "> (</ span > < span class ="n "> ratio</ span > < span class ="o "> *</ span > < span class ="n "> advantage</ span > < span class ="p "> ,</ span >
263
- < span class ="lineno "> 159 </ span > < span class ="n "> clipped_ratio</ span > < span class ="o "> *</ span > < span class ="n "> advantage</ span > < span class ="p "> )</ span >
264
- < span class ="lineno "> 160 </ span >
265
- < span class ="lineno "> 161 </ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> clip_fraction</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="nb "> abs</ span > < span class ="p "> ((</ span > < span class ="n "> ratio</ span > < span class ="o "> -</ span > < span class ="mf "> 1.0</ span > < span class ="p "> ))</ span > < span class ="o "> ></ span > < span class ="n "> clip</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> to</ span > < span class ="p "> (</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> float</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ()</ span >
266
- < span class ="lineno "> 162 </ span >
267
- < span class ="lineno "> 163 </ span > < span class ="k "> return</ span > < span class ="o "> -</ span > < span class ="n "> policy_reward</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ()</ span > </ pre > </ div >
270
+ < div class ="highlight "> < pre > < span class ="lineno "> 169 </ span > < span class ="n "> clipped_ratio</ span > < span class ="o "> =</ span > < span class ="n "> ratio</ span > < span class ="o "> .</ span > < span class ="n "> clamp</ span > < span class ="p "> (</ span > < span class ="nb "> min</ span > < span class ="o "> =</ span > < span class ="mf "> 1.0</ span > < span class ="o "> -</ span > < span class ="n "> clip</ span > < span class ="p "> ,</ span >
271
+ < span class ="lineno "> 170 </ span > < span class ="nb "> max</ span > < span class ="o "> =</ span > < span class ="mf "> 1.0</ span > < span class ="o "> +</ span > < span class ="n "> clip</ span > < span class ="p "> )</ span >
272
+ < span class ="lineno "> 171 </ span > < span class ="n "> policy_reward</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> min</ span > < span class ="p "> (</ span > < span class ="n "> ratio</ span > < span class ="o "> *</ span > < span class ="n "> advantage</ span > < span class ="p "> ,</ span >
273
+ < span class ="lineno "> 172 </ span > < span class ="n "> clipped_ratio</ span > < span class ="o "> *</ span > < span class ="n "> advantage</ span > < span class ="p "> )</ span >
274
+ < span class ="lineno "> 173 </ span >
275
+ < span class ="lineno "> 174 </ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> clip_fraction</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="nb "> abs</ span > < span class ="p "> ((</ span > < span class ="n "> ratio</ span > < span class ="o "> -</ span > < span class ="mf "> 1.0</ span > < span class ="p "> ))</ span > < span class ="o "> ></ span > < span class ="n "> clip</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> to</ span > < span class ="p "> (</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> float</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ()</ span >
276
+ < span class ="lineno "> 175 </ span >
277
+ < span class ="lineno "> 176 </ span > < span class ="k "> return</ span > < span class ="o "> -</ span > < span class ="n "> policy_reward</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ()</ span > </ pre > </ div >
268
278
</ div >
269
279
</ div >
270
280
< div class ='section ' id ='section-6 '>
@@ -273,6 +283,7 @@ <h2>PPO Loss</h2>
273
283
< a href ='#section-6 '> #</ a >
274
284
</ div >
275
285
< h2 > Clipped Value Function Loss</ h2 >
286
+ < p > Similarly we clip the value function update also.</ p >
276
287
< p >
277
288
< script type ="math/tex; mode=display "> \begin { align }
278
289
V ^ { \pi_ \theta} _ { CLIP } ( s_t )
@@ -289,7 +300,7 @@ <h2>Clipped Value Function Loss</h2>
289
300
significantly from $V_{\theta_{OLD}}$.</ p >
290
301
</ div >
291
302
< div class ='code '>
292
- < div class ="highlight "> < pre > < span class ="lineno "> 166 </ span > < span class ="k "> class</ span > < span class ="nc "> ClippedValueFunctionLoss</ span > < span class ="p "> (</ span > < span class ="n "> Module</ span > < span class ="p "> ):</ span > </ pre > </ div >
303
+ < div class ="highlight "> < pre > < span class ="lineno "> 179 </ span > < span class ="k "> class</ span > < span class ="nc "> ClippedValueFunctionLoss</ span > < span class ="p "> (</ span > < span class ="n "> Module</ span > < span class ="p "> ):</ span > </ pre > </ div >
293
304
</ div >
294
305
</ div >
295
306
< div class ='section ' id ='section-7 '>
@@ -300,10 +311,10 @@ <h2>Clipped Value Function Loss</h2>
300
311
301
312
</ div >
302
313
< div class ='code '>
303
- < div class ="highlight "> < pre > < span class ="lineno "> 185 </ span > < span class ="k "> def</ span > < span class ="fm "> __call__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> value</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> sampled_value</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> sampled_return</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> clip</ span > < span class ="p "> :</ span > < span class ="nb "> float</ span > < span class ="p "> ):</ span >
304
- < span class ="lineno "> 186 </ span > < span class ="n "> clipped_value</ span > < span class ="o "> =</ span > < span class ="n "> sampled_value</ span > < span class ="o "> +</ span > < span class ="p "> (</ span > < span class ="n "> value</ span > < span class ="o "> -</ span > < span class ="n "> sampled_value</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> clamp</ span > < span class ="p "> (</ span > < span class ="nb "> min</ span > < span class ="o "> =-</ span > < span class ="n "> clip</ span > < span class ="p "> ,</ span > < span class ="nb "> max</ span > < span class ="o "> =</ span > < span class ="n "> clip</ span > < span class ="p "> )</ span >
305
- < span class ="lineno "> 187 </ span > < span class ="n "> vf_loss</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> max</ span > < span class ="p "> ((</ span > < span class ="n "> value</ span > < span class ="o "> -</ span > < span class ="n "> sampled_return</ span > < span class ="p "> )</ span > < span class ="o "> **</ span > < span class ="mi "> 2</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> clipped_value</ span > < span class ="o "> -</ span > < span class ="n "> sampled_return</ span > < span class ="p "> )</ span > < span class ="o "> **</ span > < span class ="mi "> 2</ span > < span class ="p "> )</ span >
306
- < span class ="lineno "> 188 </ span > < span class ="k "> return</ span > < span class ="mf "> 0.5</ span > < span class ="o "> *</ span > < span class ="n "> vf_loss</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ()</ span > </ pre > </ div >
314
+ < div class ="highlight "> < pre > < span class ="lineno "> 200 </ span > < span class ="k "> def</ span > < span class ="fm "> __call__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> value</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> sampled_value</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> sampled_return</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> clip</ span > < span class ="p "> :</ span > < span class ="nb "> float</ span > < span class ="p "> ):</ span >
315
+ < span class ="lineno "> 201 </ span > < span class ="n "> clipped_value</ span > < span class ="o "> =</ span > < span class ="n "> sampled_value</ span > < span class ="o "> +</ span > < span class ="p "> (</ span > < span class ="n "> value</ span > < span class ="o "> -</ span > < span class ="n "> sampled_value</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> clamp</ span > < span class ="p "> (</ span > < span class ="nb "> min</ span > < span class ="o "> =-</ span > < span class ="n "> clip</ span > < span class ="p "> ,</ span > < span class ="nb "> max</ span > < span class ="o "> =</ span > < span class ="n "> clip</ span > < span class ="p "> )</ span >
316
+ < span class ="lineno "> 202 </ span > < span class ="n "> vf_loss</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> max</ span > < span class ="p "> ((</ span > < span class ="n "> value</ span > < span class ="o "> -</ span > < span class ="n "> sampled_return</ span > < span class ="p "> )</ span > < span class ="o "> **</ span > < span class ="mi "> 2</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> clipped_value</ span > < span class ="o "> -</ span > < span class ="n "> sampled_return</ span > < span class ="p "> )</ span > < span class ="o "> **</ span > < span class ="mi "> 2</ span > < span class ="p "> )</ span >
317
+ < span class ="lineno "> 203 </ span > < span class ="k "> return</ span > < span class ="mf "> 0.5</ span > < span class ="o "> *</ span > < span class ="n "> vf_loss</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ()</ span > </ pre > </ div >
307
318
</ div >
308
319
</ div >
309
320
</ div >
0 commit comments