Skip to content

Commit fd66ca1

Browse files
committed
fix tensorflow and gym version incompatibility
1 parent 7390a50 commit fd66ca1

File tree

1 file changed

+15
-27
lines changed

1 file changed

+15
-27
lines changed

DQN/Deep Q Learning Solution.ipynb

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,14 @@
1111
"%matplotlib inline\n",
1212
"\n",
1313
"import gym\n",
14+
"from gym.wrappers import Monitor\n",
1415
"import itertools\n",
1516
"import numpy as np\n",
1617
"import os\n",
1718
"import random\n",
1819
"import sys\n",
1920
"import tensorflow as tf\n",
2021
"\n",
21-
"from distutils.version import StrictVersion\n",
22-
"tf_version_info = tf.__version__.split(\"-\")[0] # 0.12.0-rc0 to 0.12.0,future release alpha 1.*.*\n",
23-
"\n",
2422
"if \"../\" not in sys.path:\n",
2523
" sys.path.append(\"../\")\n",
2624
"\n",
@@ -69,16 +67,10 @@
6967
" self.input_state = tf.placeholder(shape=[210, 160, 3], dtype=tf.uint8)\n",
7068
" self.output = tf.image.rgb_to_grayscale(self.input_state)\n",
7169
" self.output = tf.image.crop_to_bounding_box(self.output, 34, 0, 160, 160)\n",
72-
" \n",
73-
" # tf.image.resize_images() function changed after tf version 0.11.0\n",
74-
" if (StrictVersion(tf_version_info) >= StrictVersion('0.11.0')):\n",
75-
" self.output = tf.image.resize_images(\n",
76-
" self.output, [84, 84], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
77-
" else:\n",
78-
" self.output = tf.image.resize_images(\n",
79-
" self.output, 84, 84, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
70+
" self.output = tf.image.resize_images(\n",
71+
" self.output, [84, 84], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
8072
" self.output = tf.squeeze(self.output)\n",
81-
" \n",
73+
"\n",
8274
" def process(self, sess, state):\n",
8375
" \"\"\"\n",
8476
" Args:\n",
@@ -116,7 +108,7 @@
116108
" summary_dir = os.path.join(summaries_dir, \"summaries_{}\".format(scope))\n",
117109
" if not os.path.exists(summary_dir):\n",
118110
" os.makedirs(summary_dir)\n",
119-
" self.summary_writer = tf.train.SummaryWriter(summary_dir)\n",
111+
" self.summary_writer = tf.summary.FileWriter(summary_dir)\n",
120112
"\n",
121113
" def _build_model(self):\n",
122114
" \"\"\"\n",
@@ -160,14 +152,13 @@
160152
" self.train_op = self.optimizer.minimize(self.loss, global_step=tf.contrib.framework.get_global_step())\n",
161153
"\n",
162154
" # Summaries for Tensorboard\n",
163-
" self.summaries = tf.merge_summary([\n",
164-
" tf.scalar_summary(\"loss\", self.loss),\n",
165-
" tf.histogram_summary(\"loss_hist\", self.losses),\n",
166-
" tf.histogram_summary(\"q_values_hist\", self.predictions),\n",
167-
" tf.scalar_summary(\"max_q_value\", tf.reduce_max(self.predictions))\n",
155+
" self.summaries = tf.summary.merge([\n",
156+
" tf.summary.scalar(\"loss\", self.loss),\n",
157+
" tf.summary.histogram(\"loss_hist\", self.losses),\n",
158+
" tf.summary.histogram(\"q_values_hist\", self.predictions),\n",
159+
" tf.summary.scalar(\"max_q_value\", tf.reduce_max(self.predictions))\n",
168160
" ])\n",
169161
"\n",
170-
"\n",
171162
" def predict(self, sess, s):\n",
172163
" \"\"\"\n",
173164
" Predicts action values.\n",
@@ -221,7 +212,7 @@
221212
"sp = StateProcessor()\n",
222213
"\n",
223214
"with tf.Session() as sess:\n",
224-
" sess.run(tf.initialize_all_variables())\n",
215+
" sess.run(tf.global_variables_initializer())\n",
225216
" \n",
226217
" # Example observation batch\n",
227218
" observation = env.reset()\n",
@@ -366,7 +357,9 @@
366357
" checkpoint_dir = os.path.join(experiment_dir, \"checkpoints\")\n",
367358
" checkpoint_path = os.path.join(checkpoint_dir, \"model\")\n",
368359
" monitor_path = os.path.join(experiment_dir, \"monitor\")\n",
369-
"\n",
360+
" # Add env Monitor wrapper\n",
361+
" env = Monitor(env, directory=monitor_path, video_callable=lambda x: True, resume=True)\n",
362+
" \n",
370363
" if not os.path.exists(checkpoint_dir):\n",
371364
" os.makedirs(checkpoint_dir)\n",
372365
" if not os.path.exists(monitor_path):\n",
@@ -409,10 +402,6 @@
409402
" else:\n",
410403
" state = next_state\n",
411404
"\n",
412-
" # Record videos\n",
413-
" env.monitor.start(monitor_path,\n",
414-
" resume=True,\n",
415-
" video_callable=lambda count: count % record_video_every == 0)\n",
416405
"\n",
417406
" for i_episode in range(num_episodes):\n",
418407
"\n",
@@ -493,7 +482,6 @@
493482
" episode_lengths=stats.episode_lengths[:i_episode+1],\n",
494483
" episode_rewards=stats.episode_rewards[:i_episode+1])\n",
495484
"\n",
496-
" env.monitor.close()\n",
497485
" return stats"
498486
]
499487
},
@@ -522,7 +510,7 @@
522510
"\n",
523511
"# Run it!\n",
524512
"with tf.Session() as sess:\n",
525-
" sess.run(tf.initialize_all_variables())\n",
513+
" sess.run(tf.global_variables_initializer())\n",
526514
" for t, stats in deep_q_learning(sess,\n",
527515
" env,\n",
528516
" q_estimator=q_estimator,\n",

0 commit comments

Comments (0)