@@ -273,37 +273,39 @@ def test_gradient_checkpointing(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
+        assert not model.is_gradient_checkpointing and model.training
+
         out = model(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
         model.zero_grad()
-        out.sum().backward()
 
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_not_checkpointed = out.data.clone()
-        grad_not_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_not_checkpointed[name] = param.grad.data.clone()
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
 
-        model.enable_gradient_checkpointing()
-        out = model(**inputs_dict).sample
+        # re-instantiate the model now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-        out.sum().backward()
-
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_checkpointed = out.data.clone()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
 
         # compare the output and parameters gradients
-        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
-        for name in grad_checkpointed:
-            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
 
 
 # TODO(Patrick) - Re-add this test after having cleaned up LDM
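For context, the comparison pattern the new test uses (two copies of the same weights, one checkpointed, identical loss and gradients expected) can be reproduced outside the test suite. The sketch below is a minimal, hypothetical example: TinyBlock and grads_for are made-up names, and torch.utils.checkpoint stands in for a model's enable_gradient_checkpointing; it is not part of this PR.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)


class TinyBlock(nn.Module):
    """Toy two-layer block; purely illustrative, not a diffusers model."""

    def __init__(self, dim=8):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.use_checkpointing = False

    def forward(self, x):
        if self.use_checkpointing:
            # Recompute the activations of self.net during backward instead of storing them.
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)


def grads_for(model, x, labels):
    # Run one forward/backward pass and collect loss plus parameter gradients.
    model.zero_grad()
    loss = (model(x) - labels).mean()
    loss.backward()
    return loss.detach(), {n: p.grad.detach().clone() for n, p in model.named_parameters()}


x = torch.randn(4, 8)
labels = torch.randn(4, 8)

plain = TinyBlock()
ckpt = TinyBlock()
ckpt.load_state_dict(plain.state_dict())  # identical weights, mirroring model_2 in the test
ckpt.use_checkpointing = True

loss_a, grads_a = grads_for(plain, x, labels)
loss_b, grads_b = grads_for(ckpt, x, labels)

# Checkpointing only changes memory usage, not the math, so both runs must agree.
assert (loss_a - loss_b).abs() < 1e-5
for name in grads_a:
    assert torch.allclose(grads_a[name], grads_b[name], atol=5e-5)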