
Commit 5049599

[Core] feat: MultiControlNet support for SDXL ControlNet pipeline (huggingface#4597)
* core: add multicontrolnet support to sdxl controlnet
* modify checks.
* fix: original_size determination
* add: tests for multi controlnet sdxl.
* remove unnecessary prints.
1 parent 7b93c2a commit 5049599
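In practical terms, the commit lets StableDiffusionXLControlNetPipeline take several ControlNets at once, with one conditioning image and one scale per net. A minimal usage sketch of the intended call pattern follows; the checkpoint ids and the blank placeholder conditioning images are assumptions for illustration, not something this commit ships.

# Minimal sketch of multi-ControlNet SDXL usage, assuming two SDXL ControlNet
# checkpoints exist at the placeholder ids below.
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

controlnets = [
    ControlNetModel.from_pretrained("some-org/controlnet-canny-sdxl", torch_dtype=torch.float16),
    ControlNetModel.from_pretrained("some-org/controlnet-depth-sdxl", torch_dtype=torch.float16),
]

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnets,  # a list/tuple is now wrapped into MultiControlNetModel
    torch_dtype=torch.float16,
).to("cuda")

# Placeholders; in practice these would be a real Canny edge map and a real depth map.
canny_image = Image.new("RGB", (1024, 1024))
depth_image = Image.new("RGB", (1024, 1024))

result = pipe(
    "a futuristic city at sunset",
    image=[canny_image, depth_image],          # one conditioning image per ControlNet
    controlnet_conditioning_scale=[0.5, 0.5],  # one scale per ControlNet
).images[0]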

File tree

3 files changed: +472 −6 lines changed

src/diffusers/pipelines/controlnet/multicontrolnet.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,7 @@ def forward(
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guess_mode: bool = False,
         return_dict: bool = True,
@@ -53,6 +54,7 @@ def forward(
                 class_labels=class_labels,
                 timestep_cond=timestep_cond,
                 attention_mask=attention_mask,
+                added_cond_kwargs=added_cond_kwargs,
                 cross_attention_kwargs=cross_attention_kwargs,
                 guess_mode=guess_mode,
                 return_dict=return_dict,
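For context, the added_cond_kwargs dict that now flows through MultiControlNetModel carries SDXL's extra conditioning and is forwarded unchanged to every nested ControlNet. A rough, self-contained sketch of its typical contents follows; the dummy tensors and their shapes are an assumption based on the SDXL base configuration, not part of this diff.

import torch

# Illustrative only: what the SDXL pipeline packs into added_cond_kwargs before
# calling the (multi-)ControlNet. Values here are dummy tensors.
added_cond_kwargs = {
    "text_embeds": torch.zeros(2, 1280),  # pooled prompt embeddings (batch of 2 with CFG)
    "time_ids": torch.zeros(2, 6),        # original_size + crops_coords_top_left + target_size
}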

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 88 additions & 5 deletions
@@ -149,15 +149,15 @@ def __init__(
         tokenizer: CLIPTokenizer,
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
     ):
         super().__init__()

         if isinstance(controlnet, (list, tuple)):
-            raise ValueError("MultiControlNet is not yet supported.")
+            controlnet = MultiControlNetModel(controlnet)

         self.register_modules(
             vae=vae,
@@ -530,6 +530,15 @@ def check_inputs(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
         # Check `image`
         is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
             self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -540,6 +549,25 @@ def check_inputs(
             and isinstance(self.controlnet._orig_mod, ControlNetModel)
         ):
             self.check_image(image, prompt, prompt_embeds)
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            # When `image` is a nested list:
+            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+            elif any(isinstance(i, list) for i in image):
+                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
         else:
             assert False

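To make the `image` check above concrete, here is a small illustration of what is accepted and what is rejected for a pipeline with two ControlNets; the variable names and blank placeholder images are assumptions for illustration.

from PIL import Image

# Placeholders standing in for a real Canny map and a real depth map.
canny_image = Image.new("RGB", (1024, 1024))
depth_image = Image.new("RGB", (1024, 1024))

image = [canny_image, depth_image]      # accepted: one conditioning image per ControlNet
# image = canny_image                   # rejected: must be a list for multiple ControlNets
# image = [[canny_image, depth_image]]  # rejected: nested (batched) lists are not supported yet
# image = [canny_image]                 # rejected: length must equal len(self.controlnet.nets)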
@@ -551,14 +579,41 @@ def check_inputs(
         ):
             if not isinstance(controlnet_conditioning_scale, float):
                 raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if isinstance(controlnet_conditioning_scale, list):
+                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+                    raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
         else:
             assert False

+        if not isinstance(control_guidance_start, (tuple, list)):
+            control_guidance_start = [control_guidance_start]
+
+        if not isinstance(control_guidance_end, (tuple, list)):
+            control_guidance_end = [control_guidance_end]
+
         if len(control_guidance_start) != len(control_guidance_end):
             raise ValueError(
                 f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
             )

+        if isinstance(self.controlnet, MultiControlNetModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
         for start, end in zip(control_guidance_start, control_guidance_end):
             if start >= end:
                 raise ValueError(
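Putting the scale and guidance-window checks together, the following per-ControlNet arguments would pass for two ControlNets; the concrete values are illustrative assumptions.

# Illustrative values for a pipeline with two ControlNets.
controlnet_conditioning_scale = [0.5, 0.8]  # one float per ControlNet; a bare float is broadcast later in __call__
control_guidance_start = [0.0, 0.0]         # bare scalars are wrapped into lists before the length checks
control_guidance_end = [1.0, 0.5]           # the second ControlNet only steers the first half of the schedule

# The lists must pair up and match the number of ControlNets.
assert len(control_guidance_start) == len(control_guidance_end) == 2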
@@ -569,6 +624,7 @@ def check_inputs(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
@@ -606,6 +662,7 @@ def check_image(self, image, prompt, prompt_embeds):
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )

+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_image(
         self,
         image,
@@ -888,6 +945,9 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

+        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
         global_pool_conditions = (
             controlnet.config.global_pool_conditions
             if isinstance(controlnet, ControlNetModel)
@@ -933,6 +993,26 @@ def __call__(
                 guess_mode=guess_mode,
             )
             height, width = image.shape[-2:]
+        elif isinstance(controlnet, MultiControlNetModel):
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                images.append(image_)
+
+            image = images
+            height, width = image[0].shape[-2:]
         else:
             assert False

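After the new prepare_image loop, `image` is a plain Python list with one tensor per ControlNet, each already resized to (height, width) and, with classifier-free guidance on and guess mode off, duplicated along the batch dimension. A dummy-shape illustration follows; the shapes are an assumption, not taken from the diff.

import torch

batch_size, num_images_per_prompt, height, width = 1, 1, 1024, 1024

# Dummy tensors standing in for the prepared conditioning images.
image = [
    torch.zeros(batch_size * num_images_per_prompt * 2, 3, height, width),  # e.g. Canny condition
    torch.zeros(batch_size * num_images_per_prompt * 2, 3, height, width),  # e.g. depth condition
]
height, width = image[0].shape[-2:]  # mirrors the line added in the diff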
@@ -963,12 +1043,15 @@ def __call__(
                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
-            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

-        original_size = original_size or image.shape[-2:]
+        # 7.2 Prepare added time ids & embeddings
+        if isinstance(image, list):
+            original_size = original_size or image[0].shape[-2:]
+        else:
+            original_size = original_size or image.shape[-2:]
         target_size = target_size or (height, width)

-        # 7.2 Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
         add_time_ids = self._get_add_time_ids(
             original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
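The controlnet_keep computation above can be replayed standalone; with ten denoising steps and control_guidance_end=[1.0, 0.5], the second ControlNet stops contributing halfway through the schedule. The snippet reruns the same list comprehension with concrete numbers.

# Standalone rerun of the keep-flag formula from the diff, with concrete numbers.
timesteps = list(range(10))
control_guidance_start, control_guidance_end = [0.0, 0.0], [1.0, 0.5]

controlnet_keep = []
for i in range(len(timesteps)):
    keeps = [
        1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps)  # multi-ControlNet case: keep the whole list per step

print(controlnet_keep[4])  # [1.0, 1.0]: both ControlNets still active
print(controlnet_keep[7])  # [1.0, 0.0]: the second ControlNet is switched off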
