@@ -227,14 +227,14 @@ def bw_hook(inc, h_module, grad_input, grad_output):
227227 self .assertEqual (grad_output [0 ], torch .ones (5 , 5 ) * 2 )
228228 counter ['backwards' ] += inc
229229
230- module .register_forward_hook ('test' , lambda * args : fw_hook (1 , * args ))
230+ test_fwd = module .register_forward_hook (lambda * args : fw_hook (1 , * args ))
231231
232232 module (input )
233233 module (input )
234234 self .assertEqual (counter ['forwards' ], 2 )
235235 self .assertEqual (counter ['backwards' ], 0 )
236236
237- module .register_backward_hook ('test' , lambda * args : bw_hook (1 , * args ))
237+ test_bwd = module .register_backward_hook (lambda * args : bw_hook (1 , * args ))
238238
239239 output = module (input )
240240 self .assertEqual (counter ['forwards' ], 3 )
@@ -248,32 +248,32 @@ def bw_hook(inc, h_module, grad_input, grad_output):
248248 self .assertEqual (counter ['forwards' ], 3 )
249249 self .assertEqual (counter ['backwards' ], 2 )
250250
251- module .register_forward_hook ('test2' , lambda * args : fw_hook (2 , * args ))
251+ test2_fwd = module .register_forward_hook (lambda * args : fw_hook (2 , * args ))
252252
253253 output = module (input )
254254 self .assertEqual (counter ['forwards' ], 6 )
255255 self .assertEqual (counter ['backwards' ], 2 )
256256
257- module .register_backward_hook ('test2' , lambda * args : bw_hook (2 , * args ))
257+ test2_bwd = module .register_backward_hook (lambda * args : bw_hook (2 , * args ))
258258
259259 module (input ).backward (torch .ones (5 , 5 ) * 2 )
260260 self .assertEqual (counter ['forwards' ], 9 )
261261 self .assertEqual (counter ['backwards' ], 5 )
262262
263- module . remove_backward_hook ( 'test2' )
263+ test2_bwd . remove ( )
264264
265265 module (input ).backward (torch .ones (5 , 5 ) * 2 )
266266 self .assertEqual (counter ['forwards' ], 12 )
267267 self .assertEqual (counter ['backwards' ], 6 )
268268
269- module . remove_forward_hook ( 'test2' )
269+ test2_fwd . remove ( )
270270
271271 module (input ).backward (torch .ones (5 , 5 ) * 2 )
272272 self .assertEqual (counter ['forwards' ], 13 )
273273 self .assertEqual (counter ['backwards' ], 7 )
274274
275- module . remove_forward_hook ( 'test' )
276- module . remove_backward_hook ( 'test' )
275+ test_fwd . remove ( )
276+ test_bwd . remove ( )
277277
278278 def test_hook_fail (self ):
279279 module = nn .Sigmoid ()
@@ -291,33 +291,29 @@ def bw_fail1(self, grad_input, grad_output):
291291 def bw_fail2 (self , grad_input , grad_output ):
292292 return grad_input + (torch .randn (2 , 2 ),)
293293
294- module .register_forward_hook ('fw_fail' , fw_fail1 )
295- with self .assertRaises (RuntimeError ) as err :
296- module (input )
297- self .assertIn ("fw_fail" , err .exception .args [0 ])
298- self .assertIn ("didn't return None" , err .exception .args [0 ])
299- module .remove_forward_hook ('fw_fail' )
294+ with module .register_forward_hook (fw_fail1 ):
295+ with self .assertRaises (RuntimeError ) as err :
296+ module (input )
297+ self .assertIn ("fw_fail" , err .exception .args [0 ])
298+ self .assertIn ("didn't return None" , err .exception .args [0 ])
300299
301- module .register_forward_hook ('fw_fail2' , fw_fail2 )
302- with self .assertRaises (RuntimeError ) as err :
303- module (input )
304- self .assertIn ("fw_fail2" , err .exception .args [0 ])
305- self .assertIn ("didn't return None" , err .exception .args [0 ])
306- module .remove_forward_hook ('fw_fail2' )
307-
308- module .register_backward_hook ('bw_fail' , bw_fail1 )
309- with self .assertRaises (RuntimeError ) as err :
310- module (input ).sum ().backward ()
311- self .assertIn ("bw_fail" , err .exception .args [0 ])
312- self .assertIn ("got 0, but expected 1" , err .exception .args [0 ])
313- module .remove_backward_hook ('bw_fail' )
314-
315- module .register_backward_hook ('bw_fail2' , bw_fail2 )
316- with self .assertRaises (RuntimeError ) as err :
317- module (input ).sum ().backward ()
318- self .assertIn ("bw_fail2" , err .exception .args [0 ])
319- self .assertIn ("got 2, but expected 1" , err .exception .args [0 ])
320- module .remove_backward_hook ('bw_fail2' )
300+ with module .register_forward_hook (fw_fail2 ):
301+ with self .assertRaises (RuntimeError ) as err :
302+ module (input )
303+ self .assertIn ("fw_fail2" , err .exception .args [0 ])
304+ self .assertIn ("didn't return None" , err .exception .args [0 ])
305+
306+ with module .register_backward_hook (bw_fail1 ):
307+ with self .assertRaises (RuntimeError ) as err :
308+ module (input ).sum ().backward ()
309+ self .assertIn ("bw_fail" , err .exception .args [0 ])
310+ self .assertIn ("got 0, but expected 1" , err .exception .args [0 ])
311+
312+ with module .register_backward_hook (bw_fail2 ):
313+ with self .assertRaises (RuntimeError ) as err :
314+ module (input ).sum ().backward ()
315+ self .assertIn ("bw_fail2" , err .exception .args [0 ])
316+ self .assertIn ("got 2, but expected 1" , err .exception .args [0 ])
321317
322318 def test_hook_writeable (self ):
323319 module = nn .Linear (5 , 5 )
@@ -326,7 +322,7 @@ def test_hook_writeable(self):
326322 def bw_hook (self , grad_input , grad_output ):
327323 return tuple (gi * 2 for gi in grad_input )
328324
329- module .register_backward_hook ('test' , bw_hook )
325+ module .register_backward_hook (bw_hook )
330326 module (input ).backward (torch .ones (5 , 5 ))
331327 expected_grad = torch .ones (5 , 5 ).mm (module .weight .data ) * 2
332328 self .assertEqual (input .grad , expected_grad )
0 commit comments