11# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
22
33import inspect
4- from typing import Any , Callable , Dict , List , Optional , Union
4+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
55
66import numpy as np
77import PIL .Image
1010
1111from diffusers import AutoencoderKL , ControlNetModel , DiffusionPipeline , UNet2DConditionModel , logging
1212from diffusers .pipelines .stable_diffusion import StableDiffusionPipelineOutput , StableDiffusionSafetyChecker
13+ from diffusers .pipelines .stable_diffusion .pipeline_stable_diffusion_controlnet import MultiControlNetModel
1314from diffusers .schedulers import KarrasDiffusionSchedulers
1415from diffusers .utils import (
1516 PIL_INTERPOLATION ,
@@ -86,7 +87,14 @@ def prepare_image(image):
8687
8788
8889def prepare_controlnet_conditioning_image (
89- controlnet_conditioning_image , width , height , batch_size , num_images_per_prompt , device , dtype
90+ controlnet_conditioning_image ,
91+ width ,
92+ height ,
93+ batch_size ,
94+ num_images_per_prompt ,
95+ device ,
96+ dtype ,
97+ do_classifier_free_guidance ,
9098):
9199 if not isinstance (controlnet_conditioning_image , torch .Tensor ):
92100 if isinstance (controlnet_conditioning_image , PIL .Image .Image ):
@@ -116,6 +124,9 @@ def prepare_controlnet_conditioning_image(
116124
117125 controlnet_conditioning_image = controlnet_conditioning_image .to (device = device , dtype = dtype )
118126
127+ if do_classifier_free_guidance :
128+ controlnet_conditioning_image = torch .cat ([controlnet_conditioning_image ] * 2 )
129+
119130 return controlnet_conditioning_image
120131
121132
@@ -132,7 +143,7 @@ def __init__(
132143 text_encoder : CLIPTextModel ,
133144 tokenizer : CLIPTokenizer ,
134145 unet : UNet2DConditionModel ,
135- controlnet : ControlNetModel ,
146+ controlnet : Union [ ControlNetModel , List [ ControlNetModel ], Tuple [ ControlNetModel ], MultiControlNetModel ] ,
136147 scheduler : KarrasDiffusionSchedulers ,
137148 safety_checker : StableDiffusionSafetyChecker ,
138149 feature_extractor : CLIPImageProcessor ,
@@ -156,6 +167,9 @@ def __init__(
156167 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
157168 )
158169
170+ if isinstance (controlnet , (list , tuple )):
171+ controlnet = MultiControlNetModel (controlnet )
172+
159173 self .register_modules (
160174 vae = vae ,
161175 text_encoder = text_encoder ,
@@ -424,6 +438,42 @@ def prepare_extra_step_kwargs(self, generator, eta):
424438 extra_step_kwargs ["generator" ] = generator
425439 return extra_step_kwargs
426440
441+ def check_controlnet_conditioning_image (self , image , prompt , prompt_embeds ):
442+ image_is_pil = isinstance (image , PIL .Image .Image )
443+ image_is_tensor = isinstance (image , torch .Tensor )
444+ image_is_pil_list = isinstance (image , list ) and isinstance (image [0 ], PIL .Image .Image )
445+ image_is_tensor_list = isinstance (image , list ) and isinstance (image [0 ], torch .Tensor )
446+
447+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list :
448+ raise TypeError (
449+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
450+ )
451+
452+ if image_is_pil :
453+ image_batch_size = 1
454+ elif image_is_tensor :
455+ image_batch_size = image .shape [0 ]
456+ elif image_is_pil_list :
457+ image_batch_size = len (image )
458+ elif image_is_tensor_list :
459+ image_batch_size = len (image )
460+ else :
461+ raise ValueError ("controlnet condition image is not valid" )
462+
463+ if prompt is not None and isinstance (prompt , str ):
464+ prompt_batch_size = 1
465+ elif prompt is not None and isinstance (prompt , list ):
466+ prompt_batch_size = len (prompt )
467+ elif prompt_embeds is not None :
468+ prompt_batch_size = prompt_embeds .shape [0 ]
469+ else :
470+ raise ValueError ("prompt or prompt_embeds are not valid" )
471+
472+ if image_batch_size != 1 and image_batch_size != prompt_batch_size :
473+ raise ValueError (
474+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: { image_batch_size } , prompt batch size: { prompt_batch_size } "
475+ )
476+
427477 def check_inputs (
428478 self ,
429479 prompt ,
@@ -438,6 +488,7 @@ def check_inputs(
438488 strength = None ,
439489 controlnet_guidance_start = None ,
440490 controlnet_guidance_end = None ,
491+ controlnet_conditioning_scale = None ,
441492 ):
442493 if height % 8 != 0 or width % 8 != 0 :
443494 raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
@@ -476,58 +527,51 @@ def check_inputs(
476527 f" { negative_prompt_embeds .shape } ."
477528 )
478529
479- controlnet_cond_image_is_pil = isinstance (controlnet_conditioning_image , PIL .Image .Image )
480- controlnet_cond_image_is_tensor = isinstance (controlnet_conditioning_image , torch .Tensor )
481- controlnet_cond_image_is_pil_list = isinstance (controlnet_conditioning_image , list ) and isinstance (
482- controlnet_conditioning_image [0 ], PIL .Image .Image
483- )
484- controlnet_cond_image_is_tensor_list = isinstance (controlnet_conditioning_image , list ) and isinstance (
485- controlnet_conditioning_image [0 ], torch .Tensor
486- )
530+ # check controlnet condition image
487531
488- if (
489- not controlnet_cond_image_is_pil
490- and not controlnet_cond_image_is_tensor
491- and not controlnet_cond_image_is_pil_list
492- and not controlnet_cond_image_is_tensor_list
493- ):
494- raise TypeError (
495- "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
496- )
532+ if isinstance (self .controlnet , ControlNetModel ):
533+ self .check_controlnet_conditioning_image (controlnet_conditioning_image , prompt , prompt_embeds )
534+ elif isinstance (self .controlnet , MultiControlNetModel ):
535+ if not isinstance (controlnet_conditioning_image , list ):
536+ raise TypeError ("For multiple controlnets: `image` must be type `list`" )
497537
498- if controlnet_cond_image_is_pil :
499- controlnet_cond_image_batch_size = 1
500- elif controlnet_cond_image_is_tensor :
501- controlnet_cond_image_batch_size = controlnet_conditioning_image .shape [0 ]
502- elif controlnet_cond_image_is_pil_list :
503- controlnet_cond_image_batch_size = len (controlnet_conditioning_image )
504- elif controlnet_cond_image_is_tensor_list :
505- controlnet_cond_image_batch_size = len (controlnet_conditioning_image )
538+ if len (controlnet_conditioning_image ) != len (self .controlnet .nets ):
539+ raise ValueError (
540+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
541+ )
506542
507- if prompt is not None and isinstance (prompt , str ):
508- prompt_batch_size = 1
509- elif prompt is not None and isinstance (prompt , list ):
510- prompt_batch_size = len (prompt )
511- elif prompt_embeds is not None :
512- prompt_batch_size = prompt_embeds .shape [0 ]
543+ for image_ in controlnet_conditioning_image :
544+ self .check_controlnet_conditioning_image (image_ , prompt , prompt_embeds )
545+ else :
546+ assert False
513547
514- if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size :
515- raise ValueError (
516- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: { controlnet_cond_image_batch_size } , prompt batch size: { prompt_batch_size } "
517- )
548+ # Check `controlnet_conditioning_scale`
549+
550+ if isinstance (self .controlnet , ControlNetModel ):
551+ if not isinstance (controlnet_conditioning_scale , float ):
552+ raise TypeError ("For single controlnet: `controlnet_conditioning_scale` must be type `float`." )
553+ elif isinstance (self .controlnet , MultiControlNetModel ):
554+ if isinstance (controlnet_conditioning_scale , list ) and len (controlnet_conditioning_scale ) != len (
555+ self .controlnet .nets
556+ ):
557+ raise ValueError (
558+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
559+ " the same length as the number of controlnets"
560+ )
561+ else :
562+ assert False
518563
519564 if isinstance (image , torch .Tensor ):
520565 if image .ndim != 3 and image .ndim != 4 :
521566 raise ValueError ("`image` must have 3 or 4 dimensions" )
522567
523- # if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
524- # raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
525-
526568 if image .ndim == 3 :
527569 image_batch_size = 1
528570 image_channels , image_height , image_width = image .shape
529571 elif image .ndim == 4 :
530572 image_batch_size , image_channels , image_height , image_width = image .shape
573+ else :
574+ assert False
531575
532576 if image_channels != 3 :
533577 raise ValueError ("`image` must have 3 channels" )
@@ -659,7 +703,7 @@ def __call__(
659703 callback : Optional [Callable [[int , int , torch .FloatTensor ], None ]] = None ,
660704 callback_steps : int = 1 ,
661705 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
662- controlnet_conditioning_scale : float = 1.0 ,
706+ controlnet_conditioning_scale : Union [ float , List [ float ]] = 1.0 ,
663707 controlnet_guidance_start : float = 0.0 ,
664708 controlnet_guidance_end : float = 1.0 ,
665709 ):
@@ -759,7 +803,6 @@ def __call__(
759803 self .check_inputs (
760804 prompt ,
761805 image ,
762- # mask_image,
763806 controlnet_conditioning_image ,
764807 height ,
765808 width ,
@@ -770,6 +813,7 @@ def __call__(
770813 strength ,
771814 controlnet_guidance_start ,
772815 controlnet_guidance_end ,
816+ controlnet_conditioning_scale ,
773817 )
774818
775819 # 2. Define call parameters
@@ -786,6 +830,9 @@ def __call__(
786830 # corresponds to doing no classifier free guidance.
787831 do_classifier_free_guidance = guidance_scale > 1.0
788832
833+ if isinstance (self .controlnet , MultiControlNetModel ) and isinstance (controlnet_conditioning_scale , float ):
834+ controlnet_conditioning_scale = [controlnet_conditioning_scale ] * len (self .controlnet .nets )
835+
789836 # 3. Encode input prompt
790837 prompt_embeds = self ._encode_prompt (
791838 prompt ,
@@ -797,22 +844,41 @@ def __call__(
797844 negative_prompt_embeds = negative_prompt_embeds ,
798845 )
799846
800- # 4. Prepare mask, image, and controlnet_conditioning_image
847+ # 4. Prepare image, and controlnet_conditioning_image
801848 image = prepare_image (image )
802849
803- # mask_image = prepare_mask_image(mask_image)
850+ # condition image(s)
851+ if isinstance (self .controlnet , ControlNetModel ):
852+ controlnet_conditioning_image = prepare_controlnet_conditioning_image (
853+ controlnet_conditioning_image = controlnet_conditioning_image ,
854+ width = width ,
855+ height = height ,
856+ batch_size = batch_size * num_images_per_prompt ,
857+ num_images_per_prompt = num_images_per_prompt ,
858+ device = device ,
859+ dtype = self .controlnet .dtype ,
860+ do_classifier_free_guidance = do_classifier_free_guidance ,
861+ )
862+ elif isinstance (self .controlnet , MultiControlNetModel ):
863+ controlnet_conditioning_images = []
864+
865+ for image_ in controlnet_conditioning_image :
866+ image_ = prepare_controlnet_conditioning_image (
867+ controlnet_conditioning_image = image_ ,
868+ width = width ,
869+ height = height ,
870+ batch_size = batch_size * num_images_per_prompt ,
871+ num_images_per_prompt = num_images_per_prompt ,
872+ device = device ,
873+ dtype = self .controlnet .dtype ,
874+ do_classifier_free_guidance = do_classifier_free_guidance ,
875+ )
804876
805- controlnet_conditioning_image = prepare_controlnet_conditioning_image (
806- controlnet_conditioning_image ,
807- width ,
808- height ,
809- batch_size * num_images_per_prompt ,
810- num_images_per_prompt ,
811- device ,
812- self .controlnet .dtype ,
813- )
877+ controlnet_conditioning_images .append (image_ )
814878
815- # masked_image = image * (mask_image < 0.5)
879+ controlnet_conditioning_image = controlnet_conditioning_images
880+ else :
881+ assert False
816882
817883 # 5. Prepare timesteps
818884 self .scheduler .set_timesteps (num_inference_steps , device = device )
@@ -830,9 +896,6 @@ def __call__(
830896 generator ,
831897 )
832898
833- if do_classifier_free_guidance :
834- controlnet_conditioning_image = torch .cat ([controlnet_conditioning_image ] * 2 )
835-
836899 # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
837900 extra_step_kwargs = self .prepare_extra_step_kwargs (generator , eta )
838901
@@ -862,15 +925,10 @@ def __call__(
862925 t ,
863926 encoder_hidden_states = prompt_embeds ,
864927 controlnet_cond = controlnet_conditioning_image ,
928+ conditioning_scale = controlnet_conditioning_scale ,
865929 return_dict = False ,
866930 )
867931
868- down_block_res_samples = [
869- down_block_res_sample * controlnet_conditioning_scale
870- for down_block_res_sample in down_block_res_samples
871- ]
872- mid_block_res_sample *= controlnet_conditioning_scale
873-
874932 # predict the noise residual
875933 noise_pred = self .unet (
876934 latent_model_input ,
0 commit comments