 # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/

 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import PIL.Image
@@ -11,6 +11,7 @@

 from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
     PIL_INTERPOLATION,
@@ -184,7 +185,14 @@ def prepare_mask_image(mask_image):


 def prepare_controlnet_conditioning_image(
-    controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
+    controlnet_conditioning_image,
+    width,
+    height,
+    batch_size,
+    num_images_per_prompt,
+    device,
+    dtype,
+    do_classifier_free_guidance,
 ):
     if not isinstance(controlnet_conditioning_image, torch.Tensor):
         if isinstance(controlnet_conditioning_image, PIL.Image.Image):
@@ -214,6 +222,9 @@ def prepare_controlnet_conditioning_image(

     controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)

+    if do_classifier_free_guidance:
+        controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
+
     return controlnet_conditioning_image


@@ -230,7 +241,7 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
@@ -254,6 +265,9 @@ def __init__(
254265 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
255266 )
256267
268+ if isinstance (controlnet , (list , tuple )):
269+ controlnet = MultiControlNetModel (controlnet )
270+
257271 self .register_modules (
258272 vae = vae ,
259273 text_encoder = text_encoder ,
@@ -264,6 +278,7 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
+
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

@@ -522,6 +537,42 @@ def prepare_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

+    def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
+        image_is_pil = isinstance(image, PIL.Image.Image)
+        image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+
+        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+            raise TypeError(
+                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+            )
+
+        if image_is_pil:
+            image_batch_size = 1
+        elif image_is_tensor:
+            image_batch_size = image.shape[0]
+        elif image_is_pil_list:
+            image_batch_size = len(image)
+        elif image_is_tensor_list:
+            image_batch_size = len(image)
+        else:
+            raise ValueError("controlnet condition image is not valid")
+
+        if prompt is not None and isinstance(prompt, str):
+            prompt_batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            prompt_batch_size = len(prompt)
+        elif prompt_embeds is not None:
+            prompt_batch_size = prompt_embeds.shape[0]
+        else:
+            raise ValueError("prompt or prompt_embeds are not valid")
+
+        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+            raise ValueError(
+                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+            )
+
     def check_inputs(
         self,
         prompt,
@@ -534,6 +585,7 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        controlnet_conditioning_scale=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -572,45 +624,35 @@ def check_inputs(
572624 f" { negative_prompt_embeds .shape } ."
573625 )
574626
-        controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
-        controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
-        controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
-            controlnet_conditioning_image[0], PIL.Image.Image
-        )
-        controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
-            controlnet_conditioning_image[0], torch.Tensor
-        )
-
-        if (
-            not controlnet_cond_image_is_pil
-            and not controlnet_cond_image_is_tensor
-            and not controlnet_cond_image_is_pil_list
-            and not controlnet_cond_image_is_tensor_list
-        ):
-            raise TypeError(
-                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
-            )
-
-        if controlnet_cond_image_is_pil:
-            controlnet_cond_image_batch_size = 1
-        elif controlnet_cond_image_is_tensor:
-            controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
-        elif controlnet_cond_image_is_pil_list:
-            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
-        elif controlnet_cond_image_is_tensor_list:
-            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
-
-        if prompt is not None and isinstance(prompt, str):
-            prompt_batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            prompt_batch_size = len(prompt)
-        elif prompt_embeds is not None:
-            prompt_batch_size = prompt_embeds.shape[0]
-
-        if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
-            raise ValueError(
-                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
-            )
+        # check controlnet condition image
+        if isinstance(self.controlnet, ControlNetModel):
+            self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if not isinstance(controlnet_conditioning_image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+            if len(controlnet_conditioning_image) != len(self.controlnet.nets):
+                raise ValueError(
+                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                )
+            for image_ in controlnet_conditioning_image:
+                self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
+        else:
+            assert False
+
+        # Check `controlnet_conditioning_scale`
+        if isinstance(self.controlnet, ControlNetModel):
+            if not isinstance(controlnet_conditioning_scale, float):
+                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
+        else:
+            assert False

         if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
             raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
@@ -630,6 +672,8 @@ def check_inputs(
                 image_channels, image_height, image_width = image.shape
             elif image.ndim == 4:
                 image_batch_size, image_channels, image_height, image_width = image.shape
+            else:
+                assert False

             if mask_image.ndim == 2:
                 mask_image_batch_size = 1
@@ -797,7 +841,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_conditioning_scale: float = 1.0,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -897,6 +941,7 @@ def __call__(
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
+            controlnet_conditioning_scale,
         )

         # 2. Define call parameters
@@ -913,6 +958,9 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

+        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
@@ -929,15 +977,37 @@ def __call__(

         mask_image = prepare_mask_image(mask_image)

-        controlnet_conditioning_image = prepare_controlnet_conditioning_image(
-            controlnet_conditioning_image,
-            width,
-            height,
-            batch_size * num_images_per_prompt,
-            num_images_per_prompt,
-            device,
-            self.controlnet.dtype,
-        )
+        # condition image(s)
+        if isinstance(self.controlnet, ControlNetModel):
+            controlnet_conditioning_image = prepare_controlnet_conditioning_image(
+                controlnet_conditioning_image=controlnet_conditioning_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.controlnet.dtype,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+            )
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            controlnet_conditioning_images = []
+
+            for image_ in controlnet_conditioning_image:
+                image_ = prepare_controlnet_conditioning_image(
+                    controlnet_conditioning_image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=self.controlnet.dtype,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                )
+                controlnet_conditioning_images.append(image_)
+
+            controlnet_conditioning_image = controlnet_conditioning_images
+        else:
+            assert False

         masked_image = image * (mask_image < 0.5)

@@ -979,9 +1049,6 @@ def __call__(
             do_classifier_free_guidance,
         )

-        if do_classifier_free_guidance:
-            controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
-
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

@@ -1007,15 +1074,10 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     controlnet_cond=controlnet_conditioning_image,
+                    conditioning_scale=controlnet_conditioning_scale,
                     return_dict=False,
                 )

-                down_block_res_samples = [
-                    down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
-                ]
-                mid_block_res_sample *= controlnet_conditioning_scale
-
                 # predict the noise residual
                 noise_pred = self.unet(
                     inpainting_latent_model_input,
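A minimal usage sketch of the multi-controlnet path this diff adds. The checkpoint names, the `stable_diffusion_controlnet_inpaint` community-pipeline name, and the placeholder images below are illustrative assumptions, not part of the diff:

import PIL.Image
import torch
from diffusers import ControlNetModel, DiffusionPipeline

# Two ControlNets; these checkpoint names are placeholders for illustration.
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)

# Passing a list (or tuple) of ControlNets gets wrapped into a MultiControlNetModel in __init__.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    custom_pipeline="stable_diffusion_controlnet_inpaint",
    controlnet=[controlnet_canny, controlnet_pose],
    torch_dtype=torch.float16,
).to("cuda")

# Blank dummy images so the snippet is self-contained; substitute real inputs.
init_image = PIL.Image.new("RGB", (512, 512))
mask_image = PIL.Image.new("L", (512, 512), 255)
canny_image = PIL.Image.new("RGB", (512, 512))
pose_image = PIL.Image.new("RGB", (512, 512))

# One conditioning image per ControlNet, and optionally one scale per ControlNet.
result = pipe(
    prompt="a cat sitting on a park bench",
    image=init_image,
    mask_image=mask_image,
    controlnet_conditioning_image=[canny_image, pose_image],
    controlnet_conditioning_scale=[1.0, 0.5],
    num_inference_steps=30,
).images[0]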