
Commit 22963ed

Fix gradient checkpointing test (huggingface#797)
* Fix gradient checkpointing test
* more tests
1 parent fab1752 commit 22963ed

1 file changed: +23 -21 lines

tests/test_models_unet.py

Lines changed: 23 additions & 21 deletions
@@ -273,37 +273,39 @@ def test_gradient_checkpointing(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
+        assert not model.is_gradient_checkpointing and model.training
+
         out = model(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
         model.zero_grad()
-        out.sum().backward()
 
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_not_checkpointed = out.data.clone()
-        grad_not_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_not_checkpointed[name] = param.grad.data.clone()
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
 
-        model.enable_gradient_checkpointing()
-        out = model(**inputs_dict).sample
+        # re-instantiate the model now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-        out.sum().backward()
-
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_checkpointed = out.data.clone()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
 
         # compare the output and parameters gradients
-        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
-        for name in grad_checkpointed:
-            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
 
 
     # TODO(Patrick) - Re-add this test after having cleaned up LDM
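
For context, the new test follows a common pattern for validating gradient checkpointing: run the same forward and backward pass on two models with identical weights, one with checkpointing enabled, and check that the losses and parameter gradients agree. Below is a minimal, self-contained sketch of that pattern using plain torch.utils.checkpoint rather than the diffusers enable_gradient_checkpointing API from the diff; the TinyBlock module and the use_reentrant=False flag (available in recent PyTorch releases) are illustrative assumptions, not part of the commit.

# Illustrative sketch only -- TinyBlock is a hypothetical module, not from the commit.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    def __init__(self, use_checkpointing=False):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.net = nn.Sequential(nn.Linear(8, 8), nn.SiLU(), nn.Linear(8, 8))

    def forward(self, x):
        if self.use_checkpointing:
            # drop intermediate activations and recompute them during the backward pass
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)


torch.manual_seed(0)
x = torch.randn(4, 8)
labels = torch.randn(4, 8)

model = TinyBlock(use_checkpointing=False)
model_2 = TinyBlock(use_checkpointing=True)
model_2.load_state_dict(model.state_dict())  # identical weights in both copies

loss = (model(x) - labels).mean()
loss.backward()

loss_2 = (model_2(x) - labels).mean()
loss_2.backward()

# checkpointing changes memory use, not the math: losses and gradients should agree
assert (loss - loss_2).abs() < 1e-5
for (name, p), (_, p_2) in zip(model.named_parameters(), model_2.named_parameters()):
    assert torch.allclose(p.grad, p_2.grad, atol=5e-5), name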

0 commit comments
