 from collections import Iterable


-def iter_gradients(x):
+def iter_variables(x):
     if isinstance(x, Variable):
         if x.requires_grad:
-            yield x.grad.data if x.grad is not None else None
+            yield (x.grad.data, x.data) if x.grad is not None else (None, None)
     elif isinstance(x, Iterable):
         for elem in x:
-            for result in iter_gradients(elem):
+            for result in iter_variables(elem):
                 yield result


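For context, a minimal sketch of how the renamed generator might be consumed after this change (the toy Variable below is illustrative, not part of the patch): it now yields (gradient, data) pairs instead of bare gradients.

import torch
from torch.autograd import Variable

# iter_variables (defined above) now yields (grad, data) pairs
x = Variable(torch.randn(3), requires_grad=True)
(x * 2).sum().backward()
for d_x, x_data in iter_variables([x]):
    # d_x is x.grad.data (or None if no gradient was produced), x_data is x.data
    print(d_x, x_data)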
@@ -94,21 +94,28 @@ def get_numerical_jacobian(fn, input, target, eps=1e-3):

 def get_analytical_jacobian(input, output):
     jacobian = make_jacobian(input, output.numel())
+    jacobian_reentrant = make_jacobian(input, output.numel())
     grad_output = output.data.clone().zero_()
     flat_grad_output = grad_output.view(-1)
+    reentrant = True

     for i in range(flat_grad_output.numel()):
         flat_grad_output.zero_()
         flat_grad_output[i] = 1
-        zero_gradients(input)
-        output.backward(grad_output, retain_graph=True)
-        for jacobian_x, d_x in zip(jacobian, iter_gradients(input)):
-            if d_x is None:
-                jacobian_x[:, i].zero_()
-            else:
-                jacobian_x[:, i] = d_x.to_dense() if d_x.is_sparse else d_x
+        for jacobian_c in (jacobian, jacobian_reentrant):
+            zero_gradients(input)
+            output.backward(grad_output, create_graph=True)
+            for jacobian_x, (d_x, _) in zip(jacobian_c, iter_variables(input)):
+                if d_x is None:
+                    jacobian_x[:, i].zero_()
+                else:
+                    jacobian_x[:, i] = d_x.to_dense() if d_x.is_sparse else d_x

-    return jacobian
+    for jacobian_x, jacobian_reentrant_x in zip(jacobian, jacobian_reentrant):
+        if (jacobian_x - jacobian_reentrant_x).abs().max() != 0:
+            reentrant = False
+
+    return jacobian, reentrant


 def _as_tuple(x):
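To show what the new reentrancy check catches, here is a hedged toy example (the class and its counter are made up, written against the legacy Function API of this era): its backward reads mutable global state, so the two identical backward passes run above produce different Jacobians and `reentrant` ends up False.

from torch.autograd import Function

class NotReentrantMul(Function):
    """Hypothetical non-reentrant op: its gradient depends on a call counter."""
    calls = 0

    def forward(self, x):
        return x * 2

    def backward(self, grad_output):
        NotReentrantMul.calls += 1
        # the gradient differs on every call, so gradcheck should now reject it
        return grad_output * (2 + NotReentrantMul.calls)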
@@ -151,13 +158,16 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3):
         def fn(input):
             return _as_tuple(func(*input))[i].data

-        analytical = get_analytical_jacobian(_as_tuple(inputs), o)
+        analytical, reentrant = get_analytical_jacobian(_as_tuple(inputs), o)
         numerical = get_numerical_jacobian(fn, inputs, inputs, eps)

         for a, n in zip(analytical, numerical):
             if not ((a - n).abs() <= (atol + rtol * n.abs())).all():
                 return False

+        if not reentrant:
+            return False
+
     # check if the backward multiplies by grad_output
     zero_gradients(inputs)
     output = _as_tuple(func(*inputs))
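A hedged usage sketch (the input shape and tolerances are illustrative, and the import path assumes this file is torch/autograd/gradcheck.py): with this patch, gradcheck also returns False when the analytical Jacobian is not reproducible across repeated backward passes.

import torch
from torch.autograd import Variable
from torch.autograd.gradcheck import gradcheck

inputs = (Variable(torch.randn(4, 3).double(), requires_grad=True),)
# fails on an analytical/numerical mismatch or, after this patch, on a non-reentrant backward
assert gradcheck(lambda x: x.tanh(), inputs, eps=1e-6, atol=1e-4)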
@@ -202,7 +212,7 @@ def new_func(*input_args):
         input_args = input_args[:-len(grad_outputs)]
         outputs = func(*input_args)
         outputs = _as_tuple(outputs)
-        input_args = tuple(x for x in input_args if isinstance(x, Variable) if x.requires_grad)
+        input_args = tuple(x for x in input_args if isinstance(x, Variable) and x.requires_grad)
         grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs)
         return grad_inputs

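Side note on the one-line change above: in a generator expression, stacked `if` clauses and a single `and` condition filter identically, so this is purely a readability fix. A quick illustration with made-up values:

items = [1, None, 2, 0, 3]
old_style = tuple(x for x in items if x is not None if x > 1)
new_style = tuple(x for x in items if x is not None and x > 1)
assert old_style == new_style == (2, 3)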