@@ -398,6 +398,179 @@ def test_save_pretrained_raise_not_implemented_exception(self):
398398 pass
399399
400400
401+ class StableDiffusionMultiControlNetOneModelPipelineFastTests (
402+ PipelineTesterMixin , PipelineKarrasSchedulerTesterMixin , unittest .TestCase
403+ ):
404+ pipeline_class = StableDiffusionControlNetPipeline
405+ params = TEXT_TO_IMAGE_PARAMS
406+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
407+ image_params = frozenset ([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
408+
409+ def get_dummy_components (self ):
410+ torch .manual_seed (0 )
411+ unet = UNet2DConditionModel (
412+ block_out_channels = (32 , 64 ),
413+ layers_per_block = 2 ,
414+ sample_size = 32 ,
415+ in_channels = 4 ,
416+ out_channels = 4 ,
417+ down_block_types = ("DownBlock2D" , "CrossAttnDownBlock2D" ),
418+ up_block_types = ("CrossAttnUpBlock2D" , "UpBlock2D" ),
419+ cross_attention_dim = 32 ,
420+ )
421+ torch .manual_seed (0 )
422+
423+ def init_weights (m ):
424+ if isinstance (m , torch .nn .Conv2d ):
425+ torch .nn .init .normal (m .weight )
426+ m .bias .data .fill_ (1.0 )
427+
428+ controlnet = ControlNetModel (
429+ block_out_channels = (32 , 64 ),
430+ layers_per_block = 2 ,
431+ in_channels = 4 ,
432+ down_block_types = ("DownBlock2D" , "CrossAttnDownBlock2D" ),
433+ cross_attention_dim = 32 ,
434+ conditioning_embedding_out_channels = (16 , 32 ),
435+ )
436+ controlnet .controlnet_down_blocks .apply (init_weights )
437+
438+ torch .manual_seed (0 )
439+ scheduler = DDIMScheduler (
440+ beta_start = 0.00085 ,
441+ beta_end = 0.012 ,
442+ beta_schedule = "scaled_linear" ,
443+ clip_sample = False ,
444+ set_alpha_to_one = False ,
445+ )
446+ torch .manual_seed (0 )
447+ vae = AutoencoderKL (
448+ block_out_channels = [32 , 64 ],
449+ in_channels = 3 ,
450+ out_channels = 3 ,
451+ down_block_types = ["DownEncoderBlock2D" , "DownEncoderBlock2D" ],
452+ up_block_types = ["UpDecoderBlock2D" , "UpDecoderBlock2D" ],
453+ latent_channels = 4 ,
454+ )
455+ torch .manual_seed (0 )
456+ text_encoder_config = CLIPTextConfig (
457+ bos_token_id = 0 ,
458+ eos_token_id = 2 ,
459+ hidden_size = 32 ,
460+ intermediate_size = 37 ,
461+ layer_norm_eps = 1e-05 ,
462+ num_attention_heads = 4 ,
463+ num_hidden_layers = 5 ,
464+ pad_token_id = 1 ,
465+ vocab_size = 1000 ,
466+ )
467+ text_encoder = CLIPTextModel (text_encoder_config )
468+ tokenizer = CLIPTokenizer .from_pretrained ("hf-internal-testing/tiny-random-clip" )
469+
470+ controlnet = MultiControlNetModel ([controlnet ])
471+
472+ components = {
473+ "unet" : unet ,
474+ "controlnet" : controlnet ,
475+ "scheduler" : scheduler ,
476+ "vae" : vae ,
477+ "text_encoder" : text_encoder ,
478+ "tokenizer" : tokenizer ,
479+ "safety_checker" : None ,
480+ "feature_extractor" : None ,
481+ }
482+ return components
483+
484+ def get_dummy_inputs (self , device , seed = 0 ):
485+ if str (device ).startswith ("mps" ):
486+ generator = torch .manual_seed (seed )
487+ else :
488+ generator = torch .Generator (device = device ).manual_seed (seed )
489+
490+ controlnet_embedder_scale_factor = 2
491+
492+ images = [
493+ randn_tensor (
494+ (1 , 3 , 32 * controlnet_embedder_scale_factor , 32 * controlnet_embedder_scale_factor ),
495+ generator = generator ,
496+ device = torch .device (device ),
497+ ),
498+ ]
499+
500+ inputs = {
501+ "prompt" : "A painting of a squirrel eating a burger" ,
502+ "generator" : generator ,
503+ "num_inference_steps" : 2 ,
504+ "guidance_scale" : 6.0 ,
505+ "output_type" : "numpy" ,
506+ "image" : images ,
507+ }
508+
509+ return inputs
510+
511+ def test_control_guidance_switch (self ):
512+ components = self .get_dummy_components ()
513+ pipe = self .pipeline_class (** components )
514+ pipe .to (torch_device )
515+
516+ scale = 10.0
517+ steps = 4
518+
519+ inputs = self .get_dummy_inputs (torch_device )
520+ inputs ["num_inference_steps" ] = steps
521+ inputs ["controlnet_conditioning_scale" ] = scale
522+ output_1 = pipe (** inputs )[0 ]
523+
524+ inputs = self .get_dummy_inputs (torch_device )
525+ inputs ["num_inference_steps" ] = steps
526+ inputs ["controlnet_conditioning_scale" ] = scale
527+ output_2 = pipe (** inputs , control_guidance_start = 0.1 , control_guidance_end = 0.2 )[0 ]
528+
529+ inputs = self .get_dummy_inputs (torch_device )
530+ inputs ["num_inference_steps" ] = steps
531+ inputs ["controlnet_conditioning_scale" ] = scale
532+ output_3 = pipe (
533+ ** inputs ,
534+ control_guidance_start = [0.1 ],
535+ control_guidance_end = [0.2 ],
536+ )[0 ]
537+
538+ inputs = self .get_dummy_inputs (torch_device )
539+ inputs ["num_inference_steps" ] = steps
540+ inputs ["controlnet_conditioning_scale" ] = scale
541+ output_4 = pipe (** inputs , control_guidance_start = 0.4 , control_guidance_end = [0.5 ])[0 ]
542+
543+ # make sure that all outputs are different
544+ assert np .sum (np .abs (output_1 - output_2 )) > 1e-3
545+ assert np .sum (np .abs (output_1 - output_3 )) > 1e-3
546+ assert np .sum (np .abs (output_1 - output_4 )) > 1e-3
547+
548+ def test_attention_slicing_forward_pass (self ):
549+ return self ._test_attention_slicing_forward_pass (expected_max_diff = 2e-3 )
550+
551+ @unittest .skipIf (
552+ torch_device != "cuda" or not is_xformers_available (),
553+ reason = "XFormers attention is only available with CUDA and `xformers` installed" ,
554+ )
555+ def test_xformers_attention_forwardGenerator_pass (self ):
556+ self ._test_xformers_attention_forwardGenerator_pass (expected_max_diff = 2e-3 )
557+
558+ def test_inference_batch_single_identical (self ):
559+ self ._test_inference_batch_single_identical (expected_max_diff = 2e-3 )
560+
561+ def test_save_pretrained_raise_not_implemented_exception (self ):
562+ components = self .get_dummy_components ()
563+ pipe = self .pipeline_class (** components )
564+ pipe .to (torch_device )
565+ pipe .set_progress_bar_config (disable = None )
566+ with tempfile .TemporaryDirectory () as tmpdir :
567+ try :
568+ # save_pretrained is not implemented for Multi-ControlNet
569+ pipe .save_pretrained (tmpdir )
570+ except NotImplementedError :
571+ pass
572+
573+
401574@slow
402575@require_torch_gpu
403576class ControlNetPipelineSlowTests (unittest .TestCase ):
0 commit comments