Skip to content

Commit 6336300

Browse files
committed
Fix bug where adding a hook could replace an existing hook.
We were keying hooks by RemovableHandle id. However, we don't hold onto handles and ids of dead objects can be reused. This replaces id(handle) with a global counter.
1 parent 5073132 commit 6336300

File tree

4 files changed

+10
-7
lines changed

4 files changed

+10
-7
lines changed

torch/autograd/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _register_hook(backward_hooks, hook):
111111
if backward_hooks is None:
112112
backward_hooks = OrderedDict()
113113
handle = hooks.RemovableHandle(backward_hooks)
114-
backward_hooks[id(handle)] = hook
114+
backward_hooks[handle.id] = hook
115115
return backward_hooks, handle
116116

117117
def forward(self, *input):

torch/autograd/variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def register_hook(self, hook):
180180
if self.creator is not None:
181181
self.creator._register_hook_dict(self)
182182
handle = hooks.RemovableHandle(self._backward_hooks)
183-
self._backward_hooks[id(handle)] = hook
183+
self._backward_hooks[handle.id] = hook
184184
return handle
185185

186186
def reinforce(self, reward):

torch/nn/modules/module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def register_backward_hook(self, hook):
179179
that removes the hook from the module.
180180
"""
181181
handle = hooks.RemovableHandle(self._backward_hooks)
182-
self._backward_hooks[id(handle)] = hook
182+
self._backward_hooks[handle.id] = hook
183183
return handle
184184

185185
def register_forward_hook(self, hook):
@@ -195,7 +195,7 @@ def register_forward_hook(self, hook):
195195
that removes the hook from the module.
196196
"""
197197
handle = hooks.RemovableHandle(self._forward_hooks)
198-
self._forward_hooks[id(handle)] = hook
198+
self._forward_hooks[handle.id] = hook
199199
return handle
200200

201201
def __call__(self, *input, **kwargs):

torch/utils/hooks.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@
44
class RemovableHandle(object):
55
"""A handle which provides the capability to remove a hook."""
66

7+
next_id = 0
8+
79
def __init__(self, hooks_dict):
810
self.hooks_dict_ref = weakref.ref(hooks_dict)
11+
self.id = RemovableHandle.next_id
12+
RemovableHandle.next_id += 1
913

1014
def remove(self):
1115
hooks_dict = self.hooks_dict_ref()
12-
key = id(self)
13-
if hooks_dict is not None and key in hooks_dict:
14-
del hooks_dict[key]
16+
if hooks_dict is not None and self.id in hooks_dict:
17+
del hooks_dict[self.id]
1518

1619
def __enter__(self):
1720
return self

0 commit comments

Comments
 (0)