@@ -426,16 +426,18 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
426426 scheduler = scheduler_class (** scheduler_config )
427427 scheduler .set_timesteps (num_inference_steps )
428428
429- # copy over dummy past residuals
429+ # copy over dummy past residuals (must be after setting timesteps)
430430 scheduler .ets = dummy_past_residuals [:]
431431
432432 with tempfile .TemporaryDirectory () as tmpdirname :
433433 scheduler .save_config (tmpdirname )
434434 new_scheduler = scheduler_class .from_config (tmpdirname )
435435 # copy over dummy past residuals
436- new_scheduler .ets = dummy_past_residuals [:]
437436 new_scheduler .set_timesteps (num_inference_steps )
438437
438+ # copy over dummy past residual (must be after setting timesteps)
439+ new_scheduler .ets = dummy_past_residuals [:]
440+
439441 output = scheduler .step_prk (residual , time_step , sample , ** kwargs )["prev_sample" ]
440442 new_output = new_scheduler .step_prk (residual , time_step , sample , ** kwargs )["prev_sample" ]
441443
@@ -461,19 +463,19 @@ def test_pytorch_equal_numpy(self):
461463
462464 scheduler_config = self .get_scheduler_config ()
463465 scheduler = scheduler_class (tensor_format = "np" , ** scheduler_config )
464- # copy over dummy past residuals
465- scheduler .ets = dummy_past_residuals [:]
466466
467467 scheduler_pt = scheduler_class (tensor_format = "pt" , ** scheduler_config )
468- # copy over dummy past residuals
469- scheduler_pt .ets = dummy_past_residuals_pt [:]
470468
471469 if num_inference_steps is not None and hasattr (scheduler , "set_timesteps" ):
472470 scheduler .set_timesteps (num_inference_steps )
473471 scheduler_pt .set_timesteps (num_inference_steps )
474472 elif num_inference_steps is not None and not hasattr (scheduler , "set_timesteps" ):
475473 kwargs ["num_inference_steps" ] = num_inference_steps
476474
475+ # copy over dummy past residuals (must be done after set_timesteps)
476+ scheduler .ets = dummy_past_residuals [:]
477+ scheduler_pt .ets = dummy_past_residuals_pt [:]
478+
477479 output = scheduler .step_prk (residual , 1 , sample , ** kwargs )["prev_sample" ]
478480 output_pt = scheduler_pt .step_prk (residual_pt , 1 , sample_pt , ** kwargs )["prev_sample" ]
479481 assert np .sum (np .abs (output - output_pt .numpy ())) < 1e-4 , "Scheduler outputs are not identical"
@@ -494,15 +496,16 @@ def test_step_shape(self):
494496
495497 sample = self .dummy_sample
496498 residual = 0.1 * sample
497- # copy over dummy past residuals
498- dummy_past_residuals = [residual + 0.2 , residual + 0.15 , residual + 0.1 , residual + 0.05 ]
499- scheduler .ets = dummy_past_residuals [:]
500499
501500 if num_inference_steps is not None and hasattr (scheduler , "set_timesteps" ):
502501 scheduler .set_timesteps (num_inference_steps )
503502 elif num_inference_steps is not None and not hasattr (scheduler , "set_timesteps" ):
504503 kwargs ["num_inference_steps" ] = num_inference_steps
505504
505+ # copy over dummy past residuals (must be done after set_timesteps)
506+ dummy_past_residuals = [residual + 0.2 , residual + 0.15 , residual + 0.1 , residual + 0.05 ]
507+ scheduler .ets = dummy_past_residuals [:]
508+
506509 output_0 = scheduler .step_prk (residual , 0 , sample , ** kwargs )["prev_sample" ]
507510 output_1 = scheduler .step_prk (residual , 1 , sample , ** kwargs )["prev_sample" ]
508511
0 commit comments