@@ -187,7 +187,7 @@ def test_properties(self):
187187 with self .test_session () as sess :
188188 tf .constant (0.0 )
189189 coord = tf .train .Coordinator ()
190- coord_sess = monitored_session ._CoordinatedSession (sess , coord , [] )
190+ coord_sess = monitored_session ._CoordinatedSession (sess , coord )
191191 self .assertEquals (sess .graph , coord_sess .graph )
192192 self .assertEquals (sess .sess_str , coord_sess .sess_str )
193193
@@ -196,21 +196,21 @@ def test_run(self):
196196 c = tf .constant (0 )
197197 v = tf .identity (c )
198198 coord = tf .train .Coordinator ()
199- coord_sess = monitored_session ._CoordinatedSession (sess , coord , [] )
199+ coord_sess = monitored_session ._CoordinatedSession (sess , coord )
200200 self .assertEqual (42 , coord_sess .run (v , feed_dict = {c : 42 }))
201201
202202 def test_should_stop_on_close (self ):
203203 with self .test_session () as sess :
204204 coord = tf .train .Coordinator ()
205- coord_sess = monitored_session ._CoordinatedSession (sess , coord , [] )
205+ coord_sess = monitored_session ._CoordinatedSession (sess , coord )
206206 self .assertFalse (coord_sess .should_stop ())
207207 coord_sess .close ()
208208 self .assertTrue (coord_sess .should_stop ())
209209
210210 def test_should_stop_on_coord_stop (self ):
211211 with self .test_session () as sess :
212212 coord = tf .train .Coordinator ()
213- coord_sess = monitored_session ._CoordinatedSession (sess , coord , [] )
213+ coord_sess = monitored_session ._CoordinatedSession (sess , coord )
214214 self .assertFalse (coord_sess .should_stop ())
215215 coord .request_stop ()
216216 self .assertTrue (coord_sess .should_stop ())
@@ -220,7 +220,7 @@ def test_request_stop_on_exception(self):
220220 c = tf .constant (0 )
221221 v = tf .identity (c )
222222 coord = tf .train .Coordinator ()
223- coord_sess = monitored_session ._CoordinatedSession (sess , coord , [] )
223+ coord_sess = monitored_session ._CoordinatedSession (sess , coord )
224224 self .assertFalse (coord_sess .should_stop ())
225225 self .assertEqual (0 , coord_sess .run (c ))
226226 self .assertEqual (1 , coord_sess .run (v , feed_dict = {c : 1 }))
@@ -237,8 +237,9 @@ def test_stop_threads_on_exception(self):
237237 threads = [threading .Thread (
238238 target = busy_wait_for_coord_stop , args = (coord ,)) for _ in range (3 )]
239239 for t in threads :
240+ coord .register_thread (t )
240241 t .start ()
241- coord_sess = monitored_session ._CoordinatedSession (sess , coord , threads )
242+ coord_sess = monitored_session ._CoordinatedSession (sess , coord )
242243 self .assertFalse (coord_sess .should_stop ())
243244 for t in threads :
244245 self .assertTrue (t .is_alive ())
@@ -261,8 +262,9 @@ def test_stop_threads_on_close(self):
261262 threads = [threading .Thread (
262263 target = busy_wait_for_coord_stop , args = (coord ,)) for _ in range (3 )]
263264 for t in threads :
265+ coord .register_thread (t )
264266 t .start ()
265- coord_sess = monitored_session ._CoordinatedSession (sess , coord , threads )
267+ coord_sess = monitored_session ._CoordinatedSession (sess , coord )
266268 coord_sess .close ()
267269 for t in threads :
268270 self .assertFalse (t .is_alive ())
@@ -766,6 +768,46 @@ def test_regular_exception_pass_through_run(self):
766768 self .assertTrue (hook .raised )
767769 self .assertTrue (session .should_stop ())
768770
771+ def test_regular_exception_reported_to_coord_pass_through_run (self ):
772+ # Tests that regular exceptions reported to the coordinator from a thread
773+ # passes through a "run()" call within a "with MonitoredSession" block and
774+ # set the session in stop mode.
775+ with tf .Graph ().as_default ():
776+ gstep = tf .contrib .framework .get_or_create_global_step ()
777+ scaffold = monitored_session .Scaffold ()
778+ session = monitored_session .MonitoredSession ('' , scaffold = scaffold )
779+ with self .assertRaisesRegexp (RuntimeError , 'a thread wants to stop' ):
780+ with session :
781+ self .assertEqual (0 , session .run (gstep ))
782+ # Report an exception through the coordinator.
783+ try :
784+ raise RuntimeError ('a thread wants to stop' )
785+ except RuntimeError as e :
786+ session .coord .request_stop (e )
787+ # Call run() which should raise the reported exception.
788+ self .assertEqual (0 , session .run (gstep ))
789+ # We should not hit this
790+ self .assertFalse (True )
791+
792+ def test_regular_exception_reported_to_coord_pass_through_return (self ):
793+ # Tests that regular exceptions reported to the coordinator from a thread
794+ # passes through returning from a "with MonitoredSession" block and
795+ # set the session in stop mode.
796+ with tf .Graph ().as_default ():
797+ gstep = tf .contrib .framework .get_or_create_global_step ()
798+ scaffold = monitored_session .Scaffold ()
799+ session = monitored_session .MonitoredSession ('' , scaffold = scaffold )
800+ with self .assertRaisesRegexp (RuntimeError , 'a thread wants to stop' ):
801+ with session :
802+ self .assertEqual (0 , session .run (gstep ))
803+ # Report an exception through the coordinator.
804+ try :
805+ raise RuntimeError ('a thread wants to stop' )
806+ except RuntimeError as e :
807+ session .coord .request_stop (e )
808+ # Do not call run, just terminate the with.session block cleanly.
809+ self .assertTrue (session .should_stop ())
810+
769811 # This set of tests, verifies the session behavior when exceptions are raised
770812 # from code inside a "with MonitoredSession:" context.
771813
0 commit comments