Commit 7e4ddcf

colesbury authored and soumith committed
Remove names from register_hook calls (pytorch#446)
The register hook calls now return an object that can be used to remove the hook. For example:

    >>> h = module.register_forward_hook(callback)
    >>> h.remove()  # removes hook

Or as a context manager:

    >>> with module.register_forward_hook(callback):
    ...     pass

This makes it easier for libraries to use hooks without worrying about name collisions.
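As a quick illustration of the pattern this commit enables, here is a minimal sketch (not code from the diff; `attach_debug_hook` and `print_shapes` are made-up names, and `nn.Linear` is just a stand-in module), using the era-appropriate Variable API:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    def print_shapes(module, input, output):
        # Forward hook: runs after every forward pass of `module`.
        print(type(module).__name__, output.data.size())

    def attach_debug_hook(module):
        # A library can keep the returned handle instead of inventing a
        # globally unique hook name.
        return module.register_forward_hook(print_shapes)

    linear = nn.Linear(5, 3)
    handle = attach_debug_hook(linear)
    linear(Variable(torch.randn(2, 5)))   # hook fires here
    handle.remove()                       # hook no longer fires
    linear(Variable(torch.randn(2, 5)))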
1 parent 3152be5 commit 7e4ddcf

File tree: 7 files changed (+161 -141 lines)

test/test_autograd.py

Lines changed: 7 additions & 8 deletions
@@ -72,28 +72,28 @@ def bw_hook(inc, grad):
             counter[0] += inc

         z = x ** 2 + x * 2 + x * y + y
-        z.register_hook('test', lambda *args: bw_hook(1, *args))
+        test = z.register_hook(lambda *args: bw_hook(1, *args))
         z.backward(torch.ones(5, 5), retain_variables=True)
         self.assertEqual(counter[0], 1)

-        z.register_hook('test2', lambda *args: bw_hook(2, *args))
+        test2 = z.register_hook(lambda *args: bw_hook(2, *args))
         z.backward(torch.ones(5, 5), retain_variables=True)
         self.assertEqual(counter[0], 4)

-        z.remove_hook('test2')
+        test2.remove()
         z.backward(torch.ones(5, 5), retain_variables=True)
         self.assertEqual(counter[0], 5)

         def bw_hook_modify(grad):
             return grad.mul(2)

-        z.remove_hook('test')
-        z.register_hook('test', bw_hook_modify)
+        test.remove()
+        z.register_hook(bw_hook_modify)
         y.grad.zero_()
         z.backward(torch.ones(5, 5), retain_variables=True)
         self.assertEqual(y.grad, (x.data + 1) * 2)

-        y.register_hook('test', bw_hook_modify)
+        y.register_hook(bw_hook_modify)
         y.grad.zero_()
         z.backward(torch.ones(5, 5))
         self.assertEqual(y.grad, (x.data + 1) * 4)

@@ -420,8 +420,7 @@ def backward(self, grad_a, grad_b):

         q, p = Identity()(x, y)
         # Make sure hooks only receive grad from usage of q, not x.
-        q.register_hook(
-            'test', lambda grad: self.assertEqual(grad, torch.ones(5, 5)))
+        q.register_hook(lambda grad: self.assertEqual(grad, torch.ones(5, 5)))
         (q + p + x).sum().backward()
         self.assertEqual(x.grad, torch.ones(5, 5) * 3)
         self.assertEqual(y.grad, torch.ones(5, 5))

test/test_nn.py

Lines changed: 31 additions & 35 deletions
@@ -227,14 +227,14 @@ def bw_hook(inc, h_module, grad_input, grad_output):
             self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
             counter['backwards'] += inc

-        module.register_forward_hook('test', lambda *args: fw_hook(1, *args))
+        test_fwd = module.register_forward_hook(lambda *args: fw_hook(1, *args))

         module(input)
         module(input)
         self.assertEqual(counter['forwards'], 2)
         self.assertEqual(counter['backwards'], 0)

-        module.register_backward_hook('test', lambda *args: bw_hook(1, *args))
+        test_bwd = module.register_backward_hook(lambda *args: bw_hook(1, *args))

         output = module(input)
         self.assertEqual(counter['forwards'], 3)

@@ -248,32 +248,32 @@ def bw_hook(inc, h_module, grad_input, grad_output):
         self.assertEqual(counter['forwards'], 3)
         self.assertEqual(counter['backwards'], 2)

-        module.register_forward_hook('test2', lambda *args: fw_hook(2, *args))
+        test2_fwd = module.register_forward_hook(lambda *args: fw_hook(2, *args))

         output = module(input)
         self.assertEqual(counter['forwards'], 6)
         self.assertEqual(counter['backwards'], 2)

-        module.register_backward_hook('test2', lambda *args: bw_hook(2, *args))
+        test2_bwd = module.register_backward_hook(lambda *args: bw_hook(2, *args))

         module(input).backward(torch.ones(5, 5) * 2)
         self.assertEqual(counter['forwards'], 9)
         self.assertEqual(counter['backwards'], 5)

-        module.remove_backward_hook('test2')
+        test2_bwd.remove()

         module(input).backward(torch.ones(5, 5) * 2)
         self.assertEqual(counter['forwards'], 12)
         self.assertEqual(counter['backwards'], 6)

-        module.remove_forward_hook('test2')
+        test2_fwd.remove()

         module(input).backward(torch.ones(5, 5) * 2)
         self.assertEqual(counter['forwards'], 13)
         self.assertEqual(counter['backwards'], 7)

-        module.remove_forward_hook('test')
-        module.remove_backward_hook('test')
+        test_fwd.remove()
+        test_bwd.remove()

     def test_hook_fail(self):
         module = nn.Sigmoid()

@@ -291,33 +291,29 @@ def bw_fail1(self, grad_input, grad_output):
         def bw_fail2(self, grad_input, grad_output):
             return grad_input + (torch.randn(2, 2),)

-        module.register_forward_hook('fw_fail', fw_fail1)
-        with self.assertRaises(RuntimeError) as err:
-            module(input)
-        self.assertIn("fw_fail", err.exception.args[0])
-        self.assertIn("didn't return None", err.exception.args[0])
-        module.remove_forward_hook('fw_fail')
+        with module.register_forward_hook(fw_fail1):
+            with self.assertRaises(RuntimeError) as err:
+                module(input)
+            self.assertIn("fw_fail", err.exception.args[0])
+            self.assertIn("didn't return None", err.exception.args[0])

-        module.register_forward_hook('fw_fail2', fw_fail2)
-        with self.assertRaises(RuntimeError) as err:
-            module(input)
-        self.assertIn("fw_fail2", err.exception.args[0])
-        self.assertIn("didn't return None", err.exception.args[0])
-        module.remove_forward_hook('fw_fail2')
-
-        module.register_backward_hook('bw_fail', bw_fail1)
-        with self.assertRaises(RuntimeError) as err:
-            module(input).sum().backward()
-        self.assertIn("bw_fail", err.exception.args[0])
-        self.assertIn("got 0, but expected 1", err.exception.args[0])
-        module.remove_backward_hook('bw_fail')
-
-        module.register_backward_hook('bw_fail2', bw_fail2)
-        with self.assertRaises(RuntimeError) as err:
-            module(input).sum().backward()
-        self.assertIn("bw_fail2", err.exception.args[0])
-        self.assertIn("got 2, but expected 1", err.exception.args[0])
-        module.remove_backward_hook('bw_fail2')
+        with module.register_forward_hook(fw_fail2):
+            with self.assertRaises(RuntimeError) as err:
+                module(input)
+            self.assertIn("fw_fail2", err.exception.args[0])
+            self.assertIn("didn't return None", err.exception.args[0])
+
+        with module.register_backward_hook(bw_fail1):
+            with self.assertRaises(RuntimeError) as err:
+                module(input).sum().backward()
+            self.assertIn("bw_fail", err.exception.args[0])
+            self.assertIn("got 0, but expected 1", err.exception.args[0])
+
+        with module.register_backward_hook(bw_fail2):
+            with self.assertRaises(RuntimeError) as err:
+                module(input).sum().backward()
+            self.assertIn("bw_fail2", err.exception.args[0])
+            self.assertIn("got 2, but expected 1", err.exception.args[0])

     def test_hook_writeable(self):
         module = nn.Linear(5, 5)

@@ -326,7 +322,7 @@ def test_hook_writeable(self):
         def bw_hook(self, grad_input, grad_output):
             return tuple(gi * 2 for gi in grad_input)

-        module.register_backward_hook('test', bw_hook)
+        module.register_backward_hook(bw_hook)
         module(input).backward(torch.ones(5, 5))
         expected_grad = torch.ones(5, 5).mm(module.weight.data) * 2
         self.assertEqual(input.grad, expected_grad)

torch/autograd/function.py

Lines changed: 7 additions & 10 deletions
@@ -1,5 +1,6 @@
 import torch
 import torch._C as _C
+import torch.utils.hooks as hooks
 from collections import OrderedDict
 from itertools import chain

@@ -106,16 +107,12 @@ def mark_non_differentiable(self, *args):
         """
         self.non_differentiable = args

-    def register_hook(self, name, hook):
-        self._backward_hooks = self._backward_hooks or OrderedDict()
-        assert name not in self._backward_hooks, \
-            "Trying to register a second hook with name {}".format(name)
-        self._backward_hooks[name] = hook
-
-    def remove_hook(self, name):
-        assert self._backward_hooks and name in self._backward_hooks, \
-            "Trying to remove an inexistent hook with name {}".format(name)
-        del self._backward_hooks[name]
+    def register_hook(self, hook):
+        if self._backward_hooks is None:
+            self._backward_hooks = OrderedDict()
+        handle = hooks.RemovableHandle(self._backward_hooks)
+        self._backward_hooks[id(handle)] = hook
+        return handle

     def forward(self, *input):
         """Performs the operation.

torch/autograd/variable.py

Lines changed: 26 additions & 36 deletions
@@ -1,6 +1,7 @@
 import sys
 import torch._C as _C
 from collections import OrderedDict
+import torch.utils.hooks as hooks

 from ._functions import *

@@ -155,54 +156,43 @@ def backward(self, gradient=None, retain_variables=False):
             gradient = self.data.new().resize_as_(self.data).fill_(1)
         self._execution_engine.run_backward((self,), (gradient,), retain_variables)

-    def register_hook(self, name, hook):
-        """Registers a named backward hook.
+    def register_hook(self, hook):
+        """Registers a backward hook.

-        Given hook will be saved and called with the gradient w.r.t. the
-        variable at every backward pass. To remove a hook use
-        :func:`remove_hook`. Saved hooks are called in the same order
-        in which they were registered.
+        The hook will be called every time a gradient with respect to the
+        variable is computed. The hook should have the following signature::

-        You should never modify the data of gradient tensor given to your hook,
-        but you can use it in out-of-place operations and return a new tensor::
+            hook(grad) -> Tensor or None

-            variable.register_hook('double_grad', lambda grad: grad * 2)
+        The hook should not modify its argument, but it can optionally return
+        a new gradient which will be used in place of :attr:`grad`.

-        The returned value will replace the original tensor. Note that you
-        don't need to return anything from the hook, in which case it won't
-        change the gradient.
+        This function returns a handle with a method ``handle.remove()``
+        that removes the hook from the module.

-
-        Parameters:
-            name(str): Name of the hook.
-            hook(callable): Hook callable. It will be given a single argument
-                that's gradient w.r.t. the variable its registered on.
+        Example:
+            >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
+            >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
+            >>> v.backward(torch.Tensor([1, 1, 1]))
+            >>> v.grad
+             2
+             2
+             2
+            [torch.FloatTensor of size 3]
+            >>> h.remove()  # removes the hook
         """
         if self.volatile:
-            raise RuntimeError("registering hook on a volatile variable")
+            raise RuntimeError("cannot register a hook on a volatile variable")
         if not self.requires_grad:
-            raise RuntimeError("registering hook on a variable that doesn't require gradient")
+            raise RuntimeError("cannot register a hook on a variable that "
+                               "doesn't require gradient")
         if self._backward_hooks is None:
             self._backward_hooks = OrderedDict()
         if self.creator is not None:
             self.creator._register_hook_dict(self)
-        assert name not in self._backward_hooks, \
-            "Trying to register a second hook with name {}".format(name)
-        self._backward_hooks[name] = hook
-
-    def remove_hook(self, name):
-        """Removes a previously registered backward hook.
-
-        Raises RuntimeError if there's no hook registered under a given name.
-
-        Parameters:
-            name(str): Name of the hook.
-        """
-        if self.volatile:
-            raise RuntimeError("volatile variables don't support hooks")
-        assert self._backward_hooks and name in self._backward_hooks, \
-            "Trying to remove an inexistent hook with name {}".format(name)
-        del self._backward_hooks[name]
+        handle = hooks.RemovableHandle(self._backward_hooks)
+        self._backward_hooks[id(handle)] = hook
+        return handle

     def _do_backward(self, grad_output, retain_variables):
         assert len(grad_output) == 1

torch/csrc/autograd/function.cpp

Lines changed: 14 additions & 13 deletions
@@ -461,14 +461,15 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *inputs)

 // We need a reference to a smart pointer that will outlive the duration of
 // a function call, so that the char* pointer is valid even after it returns
-static char* _try_get_name(PyObject *key, THPObjectPtr& tmp) {
+static char* _try_get_name(PyObject *hook, THPObjectPtr& tmp) {
+  tmp = PyObject_GetAttrString(hook, "__name__");
 #if PY_MAJOR_VERSION == 2
-  if (PyString_Check(key)) {
-    return PyString_AS_STRING(key);
+  if (tmp && PyString_Check(tmp.get())) {
+    return PyString_AS_STRING(tmp.get());
   }
 #else
-  if (PyUnicode_Check(key)) {
-    tmp = PyUnicode_AsASCIIString(key);
+  if (tmp && PyUnicode_Check(tmp.get())) {
+    tmp = PyUnicode_AsASCIIString(tmp.get());
     return PyBytes_AS_STRING(tmp.get());
   }
 #endif

@@ -481,7 +482,7 @@ static char* _try_get_name(PyObject *key, THPObjectPtr& tmp) {
   hook_name ? "' " : ""

 static void _ensure_correct_hook_result_single(PyObject *original,
-    PyObject *returned, PyObject *key)
+    PyObject *returned, PyObject *hook)
 {
 #if PY_MAJOR_VERSION == 2
   static PyObject *IS_SAME_SIZE_NAME = PyString_FromString("is_same_size");

@@ -491,7 +492,7 @@ static void _ensure_correct_hook_result_single(PyObject *original,
   THPObjectPtr tmp;
   // Check that the type matches
   if(Py_TYPE(original) != Py_TYPE(returned)) {
-    char *hook_name = _try_get_name(key, tmp);
+    char *hook_name = _try_get_name(hook, tmp);
     THPUtils_setError("backward hook %s%s%shas changed the type of "
         "grad_input (was %s, but got %s)",
         OPTIONAL_HOOK_NAME,

@@ -505,7 +506,7 @@ static void _ensure_correct_hook_result_single(PyObject *original,
   THPObjectPtr is_same_size = PyObject_CallMethodObjArgs(original,
       IS_SAME_SIZE_NAME, returned, NULL);
   if(is_same_size.get() != Py_True) {
-    char *hook_name = _try_get_name(key, tmp);
+    char *hook_name = _try_get_name(hook, tmp);
     THPUtils_setError("backward hook %s%s%shas changed the size of "
         "grad_input",
         OPTIONAL_HOOK_NAME

@@ -515,12 +516,12 @@ static void _ensure_correct_hook_result_single(PyObject *original,
 }

 static void _ensure_correct_hook_result(THPObjectPtr& grad_input,
-    THPObjectPtr& result, PyObject *key)
+    THPObjectPtr& result, PyObject *hook)
 {
   THPObjectPtr tmp;
   // Check that the tuple sizes match
   if (PyTuple_GET_SIZE(result.get()) != PyTuple_GET_SIZE(grad_input.get())) {
-    char *hook_name = _try_get_name(key, tmp);
+    char *hook_name = _try_get_name(hook, tmp);
     THPUtils_setError("backward hook %s%s%sreturned an incorrect number "
         "of gradients (got %ld, but expected %ld)",
         OPTIONAL_HOOK_NAME,

@@ -534,7 +535,7 @@ static void _ensure_correct_hook_result(THPObjectPtr& grad_input,
   for (int i = 0; i < size; i++) {
     PyObject *original = PyTuple_GET_ITEM(grad_input.get(), i);
     PyObject *returned = PyTuple_GET_ITEM(result.get(), i);
-    _ensure_correct_hook_result_single(original, returned, key);
+    _ensure_correct_hook_result_single(original, returned, hook);
   }
 }

@@ -570,7 +571,7 @@ static void _call_output_hooks(THPFunction *self, THPObjectPtr& grad_output)
     if (result.get() != Py_None) {
       // Check all possible inconsistencies of the output that we can detect
       // (sizes, types, etc.)
-      _ensure_correct_hook_result_single(old_grad, result, key);
+      _ensure_correct_hook_result_single(old_grad, result, value);

       // Replace the old gradient
       PyTuple_SET_ITEM(new_grad_output.get(), i, result.release());

@@ -603,7 +604,7 @@ static void _call_function_hooks(THPFunction *self, THPObjectPtr& grad_input, TH
     _ensure_tuple(result);
     // Check all possible inconsistencies of the output that we can detect
     // (sizes, types, etc.)
-    _ensure_correct_hook_result(grad_input, result, key);
+    _ensure_correct_hook_result(grad_input, result, value);
     grad_input = result.release();
   }
 }
