
Commit 2a83798

add reentrancy checking for gradcheck.
1 parent eb1ac73 commit 2a83798
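Reentrancy here means that running backward twice with the same grad_output yields identical gradients; get_analytical_jacobian now builds the analytical Jacobian twice and gradcheck fails if the two copies differ. A minimal sketch of the idea (hedged illustration using the Variable-era API; the variable names are made up and this is not the commit's code):

import torch
from torch.autograd import Variable

# Backward is "reentrant" if repeating it with the same grad_output
# reproduces the same gradients.
x = Variable(torch.randn(3).double(), requires_grad=True)
y = (x * x).sum()

y.backward(retain_graph=True)   # first pass, keep the graph around
first = x.grad.data.clone()

x.grad.data.zero_()
y.backward()                    # second pass over the same graph
second = x.grad.data.clone()

# A backward that mutates saved state between calls would make this non-zero.
print((first - second).abs().max())  # 0.0 for a well-behaved op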

1 file changed: +23 -13 lines

torch/autograd/gradcheck.py

Lines changed: 23 additions & 13 deletions
@@ -3,13 +3,13 @@
 from collections import Iterable
 
 
-def iter_gradients(x):
+def iter_variables(x):
     if isinstance(x, Variable):
         if x.requires_grad:
-            yield x.grad.data if x.grad is not None else None
+            yield (x.grad.data, x.data) if x.grad is not None else (None, None)
     elif isinstance(x, Iterable):
         for elem in x:
-            for result in iter_gradients(elem):
+            for result in iter_variables(elem):
                 yield result
 
 

@@ -94,21 +94,28 @@ def get_numerical_jacobian(fn, input, target, eps=1e-3):
 
 def get_analytical_jacobian(input, output):
     jacobian = make_jacobian(input, output.numel())
+    jacobian_reentrant = make_jacobian(input, output.numel())
     grad_output = output.data.clone().zero_()
     flat_grad_output = grad_output.view(-1)
+    reentrant = True
 
     for i in range(flat_grad_output.numel()):
         flat_grad_output.zero_()
         flat_grad_output[i] = 1
-        zero_gradients(input)
-        output.backward(grad_output, retain_graph=True)
-        for jacobian_x, d_x in zip(jacobian, iter_gradients(input)):
-            if d_x is None:
-                jacobian_x[:, i].zero_()
-            else:
-                jacobian_x[:, i] = d_x.to_dense() if d_x.is_sparse else d_x
+        for jacobian_c in (jacobian, jacobian_reentrant):
+            zero_gradients(input)
+            output.backward(grad_output, create_graph=True)
+            for jacobian_x, (d_x, _) in zip(jacobian_c, iter_variables(input)):
+                if d_x is None:
+                    jacobian_x[:, i].zero_()
+                else:
+                    jacobian_x[:, i] = d_x.to_dense() if d_x.is_sparse else d_x
 
-    return jacobian
+    for jacobian_x, jacobian_reentrant_x in zip(jacobian, jacobian_reentrant):
+        if (jacobian_x - jacobian_reentrant_x).abs().max() != 0:
+            reentrant = False
+
+    return jacobian, reentrant
 
 
 def _as_tuple(x):
@@ -151,13 +158,16 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3):
         def fn(input):
             return _as_tuple(func(*input))[i].data
 
-        analytical = get_analytical_jacobian(_as_tuple(inputs), o)
+        analytical, reentrant = get_analytical_jacobian(_as_tuple(inputs), o)
        numerical = get_numerical_jacobian(fn, inputs, inputs, eps)
 
         for a, n in zip(analytical, numerical):
             if not ((a - n).abs() <= (atol + rtol * n.abs())).all():
                 return False
 
+        if not reentrant:
+            return False
+
     # check if the backward multiplies by grad_output
     zero_gradients(inputs)
     output = _as_tuple(func(*inputs))
@@ -202,7 +212,7 @@ def new_func(*input_args):
         input_args = input_args[:-len(grad_outputs)]
         outputs = func(*input_args)
         outputs = _as_tuple(outputs)
-        input_args = tuple(x for x in input_args if isinstance(x, Variable) if x.requires_grad)
+        input_args = tuple(x for x in input_args if isinstance(x, Variable) and x.requires_grad)
         grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs)
         return grad_inputs
 
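For context, a hedged usage sketch (hypothetical inputs; not part of the commit): gradcheck still returns a single bool, but after this change it also returns False when the Jacobians accumulated over the repeated backward passes disagree, i.e. when backward is not reentrant.

import torch
from torch.autograd import Variable
from torch.autograd.gradcheck import gradcheck

# Double-precision inputs keep the numerical Jacobian accurate enough
# for the default eps/atol/rtol tolerances.
inputs = (Variable(torch.randn(4, 3).double(), requires_grad=True),)
print(gradcheck(lambda x: x * 2, inputs, eps=1e-6, atol=1e-4))  # expected: True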
