Skip to content

Commit 6290668

Browse files
authored
Add multiple conditions to StableDiffusionControlNetInpaintPipeline (huggingface#3125)
* try multi controlnet inpaint * multi controlnet inpaint * multi controlnet inpaint
1 parent 73cc431 commit 6290668

File tree

1 file changed

+123
-61
lines changed

1 file changed

+123
-61
lines changed

examples/community/stable_diffusion_controlnet_inpaint.py

Lines changed: 123 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
22

33
import inspect
4-
from typing import Any, Callable, Dict, List, Optional, Union
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
55

66
import numpy as np
77
import PIL.Image
@@ -11,6 +11,7 @@
1111

1212
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
1313
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
14+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
1415
from diffusers.schedulers import KarrasDiffusionSchedulers
1516
from diffusers.utils import (
1617
PIL_INTERPOLATION,
@@ -184,7 +185,14 @@ def prepare_mask_image(mask_image):
184185

185186

186187
def prepare_controlnet_conditioning_image(
187-
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
188+
controlnet_conditioning_image,
189+
width,
190+
height,
191+
batch_size,
192+
num_images_per_prompt,
193+
device,
194+
dtype,
195+
do_classifier_free_guidance,
188196
):
189197
if not isinstance(controlnet_conditioning_image, torch.Tensor):
190198
if isinstance(controlnet_conditioning_image, PIL.Image.Image):
@@ -214,6 +222,9 @@ def prepare_controlnet_conditioning_image(
214222

215223
controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
216224

225+
if do_classifier_free_guidance:
226+
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
227+
217228
return controlnet_conditioning_image
218229

219230

@@ -230,7 +241,7 @@ def __init__(
230241
text_encoder: CLIPTextModel,
231242
tokenizer: CLIPTokenizer,
232243
unet: UNet2DConditionModel,
233-
controlnet: ControlNetModel,
244+
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
234245
scheduler: KarrasDiffusionSchedulers,
235246
safety_checker: StableDiffusionSafetyChecker,
236247
feature_extractor: CLIPImageProcessor,
@@ -254,6 +265,9 @@ def __init__(
254265
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
255266
)
256267

268+
if isinstance(controlnet, (list, tuple)):
269+
controlnet = MultiControlNetModel(controlnet)
270+
257271
self.register_modules(
258272
vae=vae,
259273
text_encoder=text_encoder,
@@ -264,6 +278,7 @@ def __init__(
264278
safety_checker=safety_checker,
265279
feature_extractor=feature_extractor,
266280
)
281+
267282
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
268283
self.register_to_config(requires_safety_checker=requires_safety_checker)
269284

@@ -522,6 +537,42 @@ def prepare_extra_step_kwargs(self, generator, eta):
522537
extra_step_kwargs["generator"] = generator
523538
return extra_step_kwargs
524539

540+
def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
    """Validate a single ControlNet conditioning image against the prompt batch.

    Args:
        image: A PIL image, a torch tensor, or a non-empty list whose first
            element is a PIL image or a torch tensor.
        prompt: A string, a list of strings, or ``None``.
        prompt_embeds: Consulted only when ``prompt`` is neither a ``str`` nor
            a ``list``; its ``shape[0]`` is then used as the prompt batch size.

    Raises:
        TypeError: If ``image`` is none of the accepted kinds (this includes
            an empty list).
        ValueError: If neither ``prompt`` nor ``prompt_embeds`` yields a batch
            size, or if the image batch size is neither 1 nor equal to the
            prompt batch size.
    """
    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    # Guard against an empty list before peeking at the first element, so the
    # caller gets the descriptive TypeError below instead of an IndexError.
    image_is_nonempty_list = isinstance(image, list) and len(image) > 0
    image_is_pil_list = image_is_nonempty_list and isinstance(image[0], PIL.Image.Image)
    image_is_tensor_list = image_is_nonempty_list and isinstance(image[0], torch.Tensor)

    if not (image_is_pil or image_is_tensor or image_is_pil_list or image_is_tensor_list):
        raise TypeError(
            "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
        )

    if image_is_pil:
        image_batch_size = 1
    elif image_is_tensor:
        image_batch_size = image.shape[0]
    else:
        # Only the two list cases can reach this point after the type guard.
        image_batch_size = len(image)

    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]
    else:
        raise ValueError("prompt or prompt_embeds are not valid")

    # A single conditioning image is accepted for any prompt batch size;
    # anything else must match the prompt batch size exactly.
    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
575+
525576
def check_inputs(
526577
self,
527578
prompt,
@@ -534,6 +585,7 @@ def check_inputs(
534585
negative_prompt=None,
535586
prompt_embeds=None,
536587
negative_prompt_embeds=None,
588+
controlnet_conditioning_scale=None,
537589
):
538590
if height % 8 != 0 or width % 8 != 0:
539591
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -572,45 +624,35 @@ def check_inputs(
572624
f" {negative_prompt_embeds.shape}."
573625
)
574626

575-
controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
576-
controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
577-
controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
578-
controlnet_conditioning_image[0], PIL.Image.Image
579-
)
580-
controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
581-
controlnet_conditioning_image[0], torch.Tensor
582-
)
583-
584-
if (
585-
not controlnet_cond_image_is_pil
586-
and not controlnet_cond_image_is_tensor
587-
and not controlnet_cond_image_is_pil_list
588-
and not controlnet_cond_image_is_tensor_list
589-
):
590-
raise TypeError(
591-
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
592-
)
593-
594-
if controlnet_cond_image_is_pil:
595-
controlnet_cond_image_batch_size = 1
596-
elif controlnet_cond_image_is_tensor:
597-
controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
598-
elif controlnet_cond_image_is_pil_list:
599-
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
600-
elif controlnet_cond_image_is_tensor_list:
601-
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
602-
603-
if prompt is not None and isinstance(prompt, str):
604-
prompt_batch_size = 1
605-
elif prompt is not None and isinstance(prompt, list):
606-
prompt_batch_size = len(prompt)
607-
elif prompt_embeds is not None:
608-
prompt_batch_size = prompt_embeds.shape[0]
609-
610-
if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
611-
raise ValueError(
612-
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
613-
)
627+
# check controlnet condition image
628+
if isinstance(self.controlnet, ControlNetModel):
629+
self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
630+
elif isinstance(self.controlnet, MultiControlNetModel):
631+
if not isinstance(controlnet_conditioning_image, list):
632+
raise TypeError("For multiple controlnets: `image` must be type `list`")
633+
if len(controlnet_conditioning_image) != len(self.controlnet.nets):
634+
raise ValueError(
635+
"For multiple controlnets: `image` must have the same length as the number of controlnets."
636+
)
637+
for image_ in controlnet_conditioning_image:
638+
self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
639+
else:
640+
assert False
641+
642+
# Check `controlnet_conditioning_scale`
643+
if isinstance(self.controlnet, ControlNetModel):
644+
if not isinstance(controlnet_conditioning_scale, float):
645+
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
646+
elif isinstance(self.controlnet, MultiControlNetModel):
647+
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
648+
self.controlnet.nets
649+
):
650+
raise ValueError(
651+
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
652+
" the same length as the number of controlnets"
653+
)
654+
else:
655+
assert False
614656

615657
if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
616658
raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
@@ -630,6 +672,8 @@ def check_inputs(
630672
image_channels, image_height, image_width = image.shape
631673
elif image.ndim == 4:
632674
image_batch_size, image_channels, image_height, image_width = image.shape
675+
else:
676+
assert False
633677

634678
if mask_image.ndim == 2:
635679
mask_image_batch_size = 1
@@ -797,7 +841,7 @@ def __call__(
797841
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
798842
callback_steps: int = 1,
799843
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
800-
controlnet_conditioning_scale: float = 1.0,
844+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
801845
):
802846
r"""
803847
Function invoked when calling the pipeline for generation.
@@ -897,6 +941,7 @@ def __call__(
897941
negative_prompt,
898942
prompt_embeds,
899943
negative_prompt_embeds,
944+
controlnet_conditioning_scale,
900945
)
901946

902947
# 2. Define call parameters
@@ -913,6 +958,9 @@ def __call__(
913958
# corresponds to doing no classifier free guidance.
914959
do_classifier_free_guidance = guidance_scale > 1.0
915960

961+
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
962+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
963+
916964
# 3. Encode input prompt
917965
prompt_embeds = self._encode_prompt(
918966
prompt,
@@ -929,15 +977,37 @@ def __call__(
929977

930978
mask_image = prepare_mask_image(mask_image)
931979

932-
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
933-
controlnet_conditioning_image,
934-
width,
935-
height,
936-
batch_size * num_images_per_prompt,
937-
num_images_per_prompt,
938-
device,
939-
self.controlnet.dtype,
940-
)
980+
# condition image(s)
981+
if isinstance(self.controlnet, ControlNetModel):
982+
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
983+
controlnet_conditioning_image=controlnet_conditioning_image,
984+
width=width,
985+
height=height,
986+
batch_size=batch_size * num_images_per_prompt,
987+
num_images_per_prompt=num_images_per_prompt,
988+
device=device,
989+
dtype=self.controlnet.dtype,
990+
do_classifier_free_guidance=do_classifier_free_guidance,
991+
)
992+
elif isinstance(self.controlnet, MultiControlNetModel):
993+
controlnet_conditioning_images = []
994+
995+
for image_ in controlnet_conditioning_image:
996+
image_ = prepare_controlnet_conditioning_image(
997+
controlnet_conditioning_image=image_,
998+
width=width,
999+
height=height,
1000+
batch_size=batch_size * num_images_per_prompt,
1001+
num_images_per_prompt=num_images_per_prompt,
1002+
device=device,
1003+
dtype=self.controlnet.dtype,
1004+
do_classifier_free_guidance=do_classifier_free_guidance,
1005+
)
1006+
controlnet_conditioning_images.append(image_)
1007+
1008+
controlnet_conditioning_image = controlnet_conditioning_images
1009+
else:
1010+
assert False
9411011

9421012
masked_image = image * (mask_image < 0.5)
9431013

@@ -979,9 +1049,6 @@ def __call__(
9791049
do_classifier_free_guidance,
9801050
)
9811051

982-
if do_classifier_free_guidance:
983-
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
984-
9851052
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
9861053
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
9871054

@@ -1007,15 +1074,10 @@ def __call__(
10071074
t,
10081075
encoder_hidden_states=prompt_embeds,
10091076
controlnet_cond=controlnet_conditioning_image,
1077+
conditioning_scale=controlnet_conditioning_scale,
10101078
return_dict=False,
10111079
)
10121080

1013-
down_block_res_samples = [
1014-
down_block_res_sample * controlnet_conditioning_scale
1015-
for down_block_res_sample in down_block_res_samples
1016-
]
1017-
mid_block_res_sample *= controlnet_conditioning_scale
1018-
10191081
# predict the noise residual
10201082
noise_pred = self.unet(
10211083
inpainting_latent_model_input,

0 commit comments

Comments
 (0)