@@ -230,10 +230,13 @@ def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
230230 pipe_2 = StableDiffusionXLImg2ImgPipeline (** components ).to (torch_device )
231231 pipe_2 .unet .set_default_attn_processor ()
232232
233- def assert_run_mixture (num_steps , split , scheduler_cls ):
233+ def assert_run_mixture (num_steps , split , scheduler_cls_orig ):
234234 inputs = self .get_dummy_inputs (torch_device )
235235 inputs ["num_inference_steps" ] = num_steps
236236
237+ class scheduler_cls (scheduler_cls_orig ):
238+ pass
239+
237240 pipe_1 .scheduler = scheduler_cls .from_config (pipe_1 .scheduler .config )
238241 pipe_2 .scheduler = scheduler_cls .from_config (pipe_2 .scheduler .config )
239242
@@ -287,10 +290,13 @@ def test_stable_diffusion_three_xl_mixture_of_denoiser(self):
287290 pipe_3 = StableDiffusionXLImg2ImgPipeline (** components ).to (torch_device )
288291 pipe_3 .unet .set_default_attn_processor ()
289292
290- def assert_run_mixture (num_steps , split_1 , split_2 , scheduler_cls ):
293+ def assert_run_mixture (num_steps , split_1 , split_2 , scheduler_cls_orig ):
291294 inputs = self .get_dummy_inputs (torch_device )
292295 inputs ["num_inference_steps" ] = num_steps
293296
297+ class scheduler_cls (scheduler_cls_orig ):
298+ pass
299+
294300 pipe_1 .scheduler = scheduler_cls .from_config (pipe_1 .scheduler .config )
295301 pipe_2 .scheduler = scheduler_cls .from_config (pipe_2 .scheduler .config )
296302 pipe_3 .scheduler = scheduler_cls .from_config (pipe_3 .scheduler .config )
0 commit comments