@@ -287,6 +287,11 @@ def check_over_configs(self, time_step=0, **config):
287287 elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
288288 kwargs["num_inference_steps"] = num_inference_steps
289289
290+ # Make sure `scale_model_input` is invoked to prevent a warning
291+ if scheduler_class != VQDiffusionScheduler:
292+ _ = scheduler.scale_model_input(sample, 0)
293+ _ = new_scheduler.scale_model_input(sample, 0)
294+
290295 # Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
291296 if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
292297 kwargs["generator"] = torch.manual_seed(0)
@@ -597,7 +602,7 @@ def test_trained_betas(self):
597602 continue
598603
599604 scheduler_config = self.get_scheduler_config()
600- scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.0, 0.1]))
605+ scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.1, 0.3]))
601606
602607 with tempfile.TemporaryDirectory() as tmpdirname:
603608 scheduler.save_pretrained(tmpdirname)
@@ -2648,6 +2653,7 @@ def get_scheduler_config(self, **kwargs):
26482653 "beta_end": 0.02,
26492654 "beta_schedule": "linear",
26502655 "solver_order": 2,
2656+ "solver_type": "bh1",
26512657 }
26522658
26532659 config.update(**kwargs)
0 commit comments