|
11 | 11 | "%matplotlib inline\n", |
12 | 12 | "\n", |
13 | 13 | "import gym\n", |
| 14 | + "from gym.wrappers import Monitor\n", |
14 | 15 | "import itertools\n", |
15 | 16 | "import numpy as np\n", |
16 | 17 | "import os\n", |
17 | 18 | "import random\n", |
18 | 19 | "import sys\n", |
19 | 20 | "import tensorflow as tf\n", |
20 | 21 | "\n", |
21 | | - "from distutils.version import StrictVersion\n", |
22 | | - "tf_version_info = tf.__version__.split(\"-\")[0] # 0.12.0-rc0 to 0.12.0,future release alpha 1.*.*\n", |
23 | | - "\n", |
24 | 22 | "if \"../\" not in sys.path:\n", |
25 | 23 | " sys.path.append(\"../\")\n", |
26 | 24 | "\n", |
|
69 | 67 | " self.input_state = tf.placeholder(shape=[210, 160, 3], dtype=tf.uint8)\n", |
70 | 68 | " self.output = tf.image.rgb_to_grayscale(self.input_state)\n", |
71 | 69 | " self.output = tf.image.crop_to_bounding_box(self.output, 34, 0, 160, 160)\n", |
72 | | - " \n", |
73 | | - " # tf.image.resize_images() function changed after tf version 0.11.0\n", |
74 | | - " if (StrictVersion(tf_version_info) >= StrictVersion('0.11.0')):\n", |
75 | | - " self.output = tf.image.resize_images(\n", |
76 | | - " self.output, [84, 84], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", |
77 | | - " else:\n", |
78 | | - " self.output = tf.image.resize_images(\n", |
79 | | - " self.output, 84, 84, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", |
| 70 | + " self.output = tf.image.resize_images(\n", |
| 71 | + " self.output, [84, 84], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", |
80 | 72 | " self.output = tf.squeeze(self.output)\n", |
81 | | - " \n", |
| 73 | + "\n", |
82 | 74 | " def process(self, sess, state):\n", |
83 | 75 | " \"\"\"\n", |
84 | 76 | " Args:\n", |
|
116 | 108 | " summary_dir = os.path.join(summaries_dir, \"summaries_{}\".format(scope))\n", |
117 | 109 | " if not os.path.exists(summary_dir):\n", |
118 | 110 | " os.makedirs(summary_dir)\n", |
119 | | - " self.summary_writer = tf.train.SummaryWriter(summary_dir)\n", |
| 111 | + " self.summary_writer = tf.summary.FileWriter(summary_dir)\n", |
120 | 112 | "\n", |
121 | 113 | " def _build_model(self):\n", |
122 | 114 | " \"\"\"\n", |
|
160 | 152 | " self.train_op = self.optimizer.minimize(self.loss, global_step=tf.contrib.framework.get_global_step())\n", |
161 | 153 | "\n", |
162 | 154 | " # Summaries for Tensorboard\n", |
163 | | - " self.summaries = tf.merge_summary([\n", |
164 | | - " tf.scalar_summary(\"loss\", self.loss),\n", |
165 | | - " tf.histogram_summary(\"loss_hist\", self.losses),\n", |
166 | | - " tf.histogram_summary(\"q_values_hist\", self.predictions),\n", |
167 | | - " tf.scalar_summary(\"max_q_value\", tf.reduce_max(self.predictions))\n", |
| 155 | + " self.summaries = tf.summary.merge([\n", |
| 156 | + " tf.summary.scalar(\"loss\", self.loss),\n", |
| 157 | + " tf.summary.histogram(\"loss_hist\", self.losses),\n", |
| 158 | + " tf.summary.histogram(\"q_values_hist\", self.predictions),\n", |
| 159 | + " tf.summary.scalar(\"max_q_value\", tf.reduce_max(self.predictions))\n", |
168 | 160 | " ])\n", |
169 | 161 | "\n", |
170 | | - "\n", |
171 | 162 | " def predict(self, sess, s):\n", |
172 | 163 | " \"\"\"\n", |
173 | 164 | " Predicts action values.\n", |
|
221 | 212 | "sp = StateProcessor()\n", |
222 | 213 | "\n", |
223 | 214 | "with tf.Session() as sess:\n", |
224 | | - " sess.run(tf.initialize_all_variables())\n", |
| 215 | + " sess.run(tf.global_variables_initializer())\n", |
225 | 216 | " \n", |
226 | 217 | " # Example observation batch\n", |
227 | 218 | " observation = env.reset()\n", |
|
366 | 357 | " checkpoint_dir = os.path.join(experiment_dir, \"checkpoints\")\n", |
367 | 358 | " checkpoint_path = os.path.join(checkpoint_dir, \"model\")\n", |
368 | 359 | " monitor_path = os.path.join(experiment_dir, \"monitor\")\n", |
369 | | - "\n", |
| 360 | + " # Add env Monitor wrapper\n", |
| 361 | + " env = Monitor(env, directory=monitor_path, video_callable=lambda count: count % record_video_every == 0, resume=True)\n", |
| 362 | + "\n", |
370 | 363 | " if not os.path.exists(checkpoint_dir):\n", |
371 | 364 | " os.makedirs(checkpoint_dir)\n", |
372 | 365 | " if not os.path.exists(monitor_path):\n", |
|
409 | 402 | " else:\n", |
410 | 403 | " state = next_state\n", |
411 | 404 | "\n", |
412 | | - " # Record videos\n", |
413 | | - " env.monitor.start(monitor_path,\n", |
414 | | - " resume=True,\n", |
415 | | - " video_callable=lambda count: count % record_video_every == 0)\n", |
416 | 405 | "\n", |
417 | 406 | " for i_episode in range(num_episodes):\n", |
418 | 407 | "\n", |
|
493 | 482 | " episode_lengths=stats.episode_lengths[:i_episode+1],\n", |
494 | 483 | " episode_rewards=stats.episode_rewards[:i_episode+1])\n", |
495 | 484 | "\n", |
496 | | - " env.monitor.close()\n", |
497 | 485 | " return stats" |
498 | 486 | ] |
499 | 487 | }, |
|
522 | 510 | "\n", |
523 | 511 | "# Run it!\n", |
524 | 512 | "with tf.Session() as sess:\n", |
525 | | - " sess.run(tf.initialize_all_variables())\n", |
| 513 | + " sess.run(tf.global_variables_initializer())\n", |
526 | 514 | " for t, stats in deep_q_learning(sess,\n", |
527 | 515 | " env,\n", |
528 | 516 | " q_estimator=q_estimator,\n", |
|
0 commit comments