|
118 | 118 | }, |
119 | 119 | { |
120 | 120 | "cell_type": "code", |
121 | | - "execution_count": 33, |
| 121 | + "execution_count": 47, |
122 | 122 | "metadata": { |
123 | 123 | "collapsed": false |
124 | 124 | }, |
|
158 | 158 | " # Loss and train op\n", |
159 | 159 | " self.loss = -self.normal_dist.log_prob(self.action) * self.target\n", |
160 | 160 | " # Add cross entropy cost to encourage exploration\n", |
161 | | - " self.loss -= 1e-4 * self.normal_dist.entropy()\n", |
| 161 | + " self.loss -= 1e-2 * self.normal_dist.entropy()\n", |
162 | 162 | " \n", |
163 | 163 | " self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n", |
164 | 164 | " self.train_op = self.optimizer.minimize(\n", |
|
179 | 179 | }, |
180 | 180 | { |
181 | 181 | "cell_type": "code", |
182 | | - "execution_count": 34, |
| 182 | + "execution_count": 48, |
183 | 183 | "metadata": { |
184 | 184 | "collapsed": false |
185 | 185 | }, |
|
224 | 224 | }, |
225 | 225 | { |
226 | 226 | "cell_type": "code", |
227 | | - "execution_count": 35, |
| 227 | + "execution_count": 49, |
228 | 228 | "metadata": { |
229 | 229 | "collapsed": true |
230 | 230 | }, |
|
303 | 303 | }, |
304 | 304 | { |
305 | 305 | "cell_type": "code", |
306 | | - "execution_count": null, |
| 306 | + "execution_count": 50, |
307 | 307 | "metadata": { |
308 | 308 | "collapsed": false |
309 | 309 | }, |
|
312 | 312 | "name": "stdout", |
313 | 313 | "output_type": "stream", |
314 | 314 | "text": [ |
315 | | - "Step 16840 @ Episode 1/100 (0.0)" |
| 315 | + "Step 7291 @ Episode 1/50 (0.0)" |
| 316 | + ] |
| 317 | + }, |
| 318 | + { |
| 319 | + "ename": "KeyboardInterrupt", |
| 320 | + "evalue": "", |
| 321 | + "output_type": "error", |
| 322 | + "traceback": [ |
| 323 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 324 | + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", |
| 325 | + "\u001b[0;32m<ipython-input-50-37680e81248d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;31m# Note, due to randomness in the policy the number of episodes you need to learn a good\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m# TODO: Sometimes the algorithm gets stuck, I'm not sure what exactly is happening there.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mstats\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mactor_critic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_estimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue_estimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscount_factor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.95\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", |
| 326 | + "\u001b[0;32m<ipython-input-49-47f45a7578ed>\u001b[0m in \u001b[0;36mactor_critic\u001b[0;34m(env, estimator_policy, estimator_value, num_episodes, discount_factor)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;31m# Take a step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mestimator_policy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 38\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", |
| 327 | + "\u001b[0;32m<ipython-input-47-78ba8ef0f2bb>\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, state, sess)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0msess\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_default_session\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfeaturize_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 328 | + "\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 381\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 382\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 383\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 329 | + "\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 653\u001b[0m \u001b[0mmovers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_with_movers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeed_dict_string\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_map\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 654\u001b[0m results = self._do_run(handle, target_list, unique_fetches,\n\u001b[0;32m--> 655\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 656\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 657\u001b[0m \u001b[0;31m# User may have fetched the same tensor multiple times, but we\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 330 | + "\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 721\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 722\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m--> 723\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 724\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 725\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n", |
| 331 | + "\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 728\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 729\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 730\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 731\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 732\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 332 | + "\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 710\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 711\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 712\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 713\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 714\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 333 | + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " |
316 | 334 | ] |
317 | 335 | } |
318 | 336 | ], |
|
327 | 345 | " sess.run(tf.initialize_all_variables())\n", |
328 | 346 | " # Note, due to randomness in the policy the number of episodes you need to learn a good\n", |
329 | 347 | " # TODO: Sometimes the algorithm gets stuck, I'm not sure what exactly is happening there.\n", |
330 | | - " stats = actor_critic(env, policy_estimator, value_estimator, 100, discount_factor=0.95)" |
| 348 | + " stats = actor_critic(env, policy_estimator, value_estimator, 50, discount_factor=0.95)" |
331 | 349 | ] |
332 | 350 | }, |
333 | 351 | { |
|
0 commit comments