Skip to content

Commit 016512c

Browse files
Fix 2 bugs where errors reported directly to the coordinator were not propagated
out of a "with tf.MonitoredSession()" block. 1: If the _CoordinatedSession detected a request for stop during a call to run() the request was not honored. (Missing a finally: clause) 2: If a thread requested a stop but the "with MonitoredSession" block terminated without calling run(), the coordinated threads were not joined and the optional exception was not raised. Added 2 tests, one for each bug. Modernize the code to take advantage of the fact that the threads now register themselves with the Coordinator directly. Change: 130890075
1 parent 3cc1ef6 commit 016512c

File tree

2 files changed

+71
-21
lines changed

2 files changed

+71
-21
lines changed

tensorflow/contrib/learn/python/learn/monitored_session.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def __init__(self,
267267
self._config = config
268268
self._hooks = hooks or []
269269
self._scaffold = scaffold or Scaffold()
270+
self._coord = None
270271
for h in self._hooks:
271272
h.begin()
272273
# Create the session.
@@ -299,12 +300,10 @@ def _create_session(self):
299300
# Keep the tf_sess for quick runs of global step when needed.
300301
self._tf_sess = tf_sess
301302
# We don't want coordinator to suppress any exception.
302-
coord = coordinator.Coordinator(clean_stop_exception_types=[])
303-
coordinated_threads_to_join = queue_runner.start_queue_runners(
304-
sess=tf_sess, coord=coord)
305-
return _CoordinatedSession(
306-
_HookedSession(tf_sess, self._hooks), coord,
307-
coordinated_threads_to_join)
303+
self._coord = coordinator.Coordinator(clean_stop_exception_types=[])
304+
queue_runner.start_queue_runners(sess=tf_sess, coord=self._coord)
305+
return _CoordinatedSession(_HookedSession(tf_sess, self._hooks),
306+
self._coord)
308307

309308
@property
310309
def scaffold(self):
@@ -346,10 +345,20 @@ def _close_internal(self, exception_type=None):
346345
if not exception_type:
347346
for h in self._hooks:
348347
h.end(self._tf_sess)
348+
if not self._coord.joined:
349+
# We exited cleanly without stopping. Stop things now. This will also
350+
# re-raise exceptions from the coordinated threads, as needed.
351+
self._coord.request_stop()
352+
self._coord.join()
349353
finally:
350354
self._sess.close()
351355
self._sess = None
352356
self._tf_sess = None
357+
self._coord = None
358+
359+
@property
360+
def coord(self):
361+
return self._coord
353362

354363
def _is_closed(self):
355364
"""Return True if the supervised session is closed. For tests only.
@@ -491,24 +500,22 @@ class _CoordinatedSession(_WrappedSession):
491500
raises an exception, the exception is reported to the coordinator.
492501
493502
In addition, after each call to `run()` this session ask the coordinator if
494-
the session should stop. In that case it will will join all the coordinated
495-
threads passed to the constructor before returning.
503+
the session should stop. In that case it will join all the threads
504+
registered with the coordinator before returning.
496505
497506
If the coordinator was requested to stop with an exception, that exception
498507
will be re-raised from the call to `run()`.
499508
"""
500509

501-
def __init__(self, sess, coord, coordinated_threads_to_join):
510+
def __init__(self, sess, coord):
502511
"""Create a new `_CoordinatedSession`.
503512
504513
Args:
505514
sess: A `tf.Session` object. The wrapped session.
506515
coord: A `tf.train.Coordinator` object.
507-
coordinated_threads_to_join: A list of threads.
508516
"""
509517
_WrappedSession.__init__(self, sess)
510518
self._coord = coord
511-
self._coordinated_threads_to_join = coordinated_threads_to_join
512519

513520
def _check_stop(self):
514521
# Check with the coordinator if we should stop.
@@ -518,7 +525,7 @@ def close(self):
518525
try:
519526
if not self._coord.should_stop():
520527
self._coord.request_stop()
521-
self._coord.join(self._coordinated_threads_to_join)
528+
self._coord.join()
522529
except Exception: # pylint: disable=broad-except
523530
# Don't raise exception at close
524531
pass
@@ -530,8 +537,9 @@ def run(self, *args, **kwargs):
530537
return self._sess.run(*args, **kwargs)
531538
except Exception as e: # pylint: disable=broad-except
532539
self._coord.request_stop(e)
533-
if self._coord.should_stop():
534-
self._coord.join(self._coordinated_threads_to_join)
540+
finally:
541+
if self._coord.should_stop():
542+
self._coord.join()
535543

536544

537545
class _HookedSession(_WrappedSession):

tensorflow/contrib/learn/python/learn/tests/monitored_session_test.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_properties(self):
187187
with self.test_session() as sess:
188188
tf.constant(0.0)
189189
coord = tf.train.Coordinator()
190-
coord_sess = monitored_session._CoordinatedSession(sess, coord, [])
190+
coord_sess = monitored_session._CoordinatedSession(sess, coord)
191191
self.assertEquals(sess.graph, coord_sess.graph)
192192
self.assertEquals(sess.sess_str, coord_sess.sess_str)
193193

@@ -196,21 +196,21 @@ def test_run(self):
196196
c = tf.constant(0)
197197
v = tf.identity(c)
198198
coord = tf.train.Coordinator()
199-
coord_sess = monitored_session._CoordinatedSession(sess, coord, [])
199+
coord_sess = monitored_session._CoordinatedSession(sess, coord)
200200
self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42}))
201201

202202
def test_should_stop_on_close(self):
203203
with self.test_session() as sess:
204204
coord = tf.train.Coordinator()
205-
coord_sess = monitored_session._CoordinatedSession(sess, coord, [])
205+
coord_sess = monitored_session._CoordinatedSession(sess, coord)
206206
self.assertFalse(coord_sess.should_stop())
207207
coord_sess.close()
208208
self.assertTrue(coord_sess.should_stop())
209209

210210
def test_should_stop_on_coord_stop(self):
211211
with self.test_session() as sess:
212212
coord = tf.train.Coordinator()
213-
coord_sess = monitored_session._CoordinatedSession(sess, coord, [])
213+
coord_sess = monitored_session._CoordinatedSession(sess, coord)
214214
self.assertFalse(coord_sess.should_stop())
215215
coord.request_stop()
216216
self.assertTrue(coord_sess.should_stop())
@@ -220,7 +220,7 @@ def test_request_stop_on_exception(self):
220220
c = tf.constant(0)
221221
v = tf.identity(c)
222222
coord = tf.train.Coordinator()
223-
coord_sess = monitored_session._CoordinatedSession(sess, coord, [])
223+
coord_sess = monitored_session._CoordinatedSession(sess, coord)
224224
self.assertFalse(coord_sess.should_stop())
225225
self.assertEqual(0, coord_sess.run(c))
226226
self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
@@ -237,8 +237,9 @@ def test_stop_threads_on_exception(self):
237237
threads = [threading.Thread(
238238
target=busy_wait_for_coord_stop, args=(coord,)) for _ in range(3)]
239239
for t in threads:
240+
coord.register_thread(t)
240241
t.start()
241-
coord_sess = monitored_session._CoordinatedSession(sess, coord, threads)
242+
coord_sess = monitored_session._CoordinatedSession(sess, coord)
242243
self.assertFalse(coord_sess.should_stop())
243244
for t in threads:
244245
self.assertTrue(t.is_alive())
@@ -261,8 +262,9 @@ def test_stop_threads_on_close(self):
261262
threads = [threading.Thread(
262263
target=busy_wait_for_coord_stop, args=(coord,)) for _ in range(3)]
263264
for t in threads:
265+
coord.register_thread(t)
264266
t.start()
265-
coord_sess = monitored_session._CoordinatedSession(sess, coord, threads)
267+
coord_sess = monitored_session._CoordinatedSession(sess, coord)
266268
coord_sess.close()
267269
for t in threads:
268270
self.assertFalse(t.is_alive())
@@ -766,6 +768,46 @@ def test_regular_exception_pass_through_run(self):
766768
self.assertTrue(hook.raised)
767769
self.assertTrue(session.should_stop())
768770

771+
def test_regular_exception_reported_to_coord_pass_through_run(self):
772+
# Tests that regular exceptions reported to the coordinator from a thread
773+
# passes through a "run()" call within a "with MonitoredSession" block and
774+
# set the session in stop mode.
775+
with tf.Graph().as_default():
776+
gstep = tf.contrib.framework.get_or_create_global_step()
777+
scaffold = monitored_session.Scaffold()
778+
session = monitored_session.MonitoredSession('', scaffold=scaffold)
779+
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
780+
with session:
781+
self.assertEqual(0, session.run(gstep))
782+
# Report an exception through the coordinator.
783+
try:
784+
raise RuntimeError('a thread wants to stop')
785+
except RuntimeError as e:
786+
session.coord.request_stop(e)
787+
# Call run() which should raise the reported exception.
788+
self.assertEqual(0, session.run(gstep))
789+
# We should not hit this
790+
self.assertFalse(True)
791+
792+
def test_regular_exception_reported_to_coord_pass_through_return(self):
793+
# Tests that regular exceptions reported to the coordinator from a thread
794+
# passes through returning from a "with MonitoredSession" block and
795+
# set the session in stop mode.
796+
with tf.Graph().as_default():
797+
gstep = tf.contrib.framework.get_or_create_global_step()
798+
scaffold = monitored_session.Scaffold()
799+
session = monitored_session.MonitoredSession('', scaffold=scaffold)
800+
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
801+
with session:
802+
self.assertEqual(0, session.run(gstep))
803+
# Report an exception through the coordinator.
804+
try:
805+
raise RuntimeError('a thread wants to stop')
806+
except RuntimeError as e:
807+
session.coord.request_stop(e)
808+
# Do not call run, just terminate the with.session block cleanly.
809+
self.assertTrue(session.should_stop())
810+
769811
# This set of tests, verifies the session behavior when exceptions are raised
770812
# from code inside a "with MonitoredSession:" context.
771813

0 commit comments

Comments
 (0)