Commit 5c9dd0a
Add support for Guess Mode in StableDiffusionControlNetPipeline (huggingface#2998)
* add guess mode (WIP)
* fix uncond/cond order
* support guidance_scale=1.0 and batch != 1
* remove magic coeff
* add docstring
* add integration test
* add document to controlnet.mdx
* made the comments a bit more explanatory
* fix table
1 parent d0f2582 commit 5c9dd0a

File tree: 4 files changed, +115 −6 lines

docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Lines changed: 36 additions & 0 deletions

````diff
@@ -242,6 +242,42 @@ image.save("./multi_controlnet_output.png")
 
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/multi_controlnet_output.png" width=600/>
 
+### Guess Mode
+
+Guess Mode is [a ControlNet feature that was implemented](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode) after the publication of [the paper](https://arxiv.org/abs/2302.05543). The description states:
+
+> In this mode, the ControlNet encoder will try best to recognize the content of the input control map, like depth map, edge map, scribbles, etc, even if you remove all prompts.
+
+#### The core implementation:
+
+It adjusts the scale of the output residuals from ControlNet by a fixed ratio depending on the block depth. The shallowest DownBlock corresponds to `0.1`; as the blocks get deeper, the scale increases exponentially, until the scale for the output of the MidBlock reaches `1.0`.
+
+Since this is the entire core implementation, **it has no impact on prompt conditioning**. While Guess Mode is commonly used without any prompt, it is also possible to provide one if desired.
+
+#### Usage:
+
+Just specify `guess_mode=True` in the `pipe()` call. A `guidance_scale` between 3.0 and 5.0 is [recommended](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode).
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers.utils import load_image
+import torch
+
+# load a canny edge map to condition on (the original snippet assumed `canny_image` was already defined)
+canny_image = load_image(
+    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+)
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet).to(
+    "cuda"
+)
+image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
+image.save("guess_mode_generated.png")
+```
+
+#### Output image comparison:
+
+Canny Control Example
+
+|no guess_mode with prompt|guess_mode without prompt|
+|---|---|
+|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"/></a>|
+
 ## Available checkpoints
 
 ControlNet requires a *control image* in addition to the text-to-image *prompt*.
````
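Since the added docs note that Guess Mode leaves prompt conditioning untouched, it can also be combined with a prompt. A minimal sketch, reusing the `pipe` and `canny_image` from the example above; the prompt text is illustrative, not from the commit:

```py
# Guess Mode does not disable prompt conditioning, so a prompt can still steer the result.
image = pipe(
    "a colorful bird, high quality",  # illustrative prompt, not part of the commit
    image=canny_image,
    guess_mode=True,
    guidance_scale=3.0,  # 3.0-5.0 is the recommended range for guess mode
).images[0]
image.save("guess_mode_with_prompt.png")
```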

src/diffusers/models/controlnet.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -456,6 +456,7 @@ def forward(
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         # check channel order
@@ -556,8 +557,14 @@ def forward(
         mid_block_res_sample = self.controlnet_mid_block(sample)
 
         # 6. scaling
-        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-        mid_block_res_sample *= conditioning_scale
+        if guess_mode:
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
+            scales *= conditioning_scale
+            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+            mid_block_res_sample *= scales[-1]  # last one
+        else:
+            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+            mid_block_res_sample *= conditioning_scale
 
         if not return_dict:
             return (down_block_res_samples, mid_block_res_sample)
```
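To make the "exponentially increasing scales" concrete, here is a small standalone sketch (not part of the commit) of what that `torch.logspace` call yields, assuming the Stable Diffusion ControlNet's usual 12 down-block residuals plus one mid-block residual:

```py
import torch

# 12 down-block residuals + 1 mid-block residual -> 13 scales,
# evenly spaced in log space from 10^-1 to 10^0 (each ~1.21x the previous)
scales = torch.logspace(-1, 0, 12 + 1)
print(scales)
# tensor([0.1000, 0.1212, 0.1468, 0.1778, 0.2154, 0.2610, 0.3162, 0.3831,
#         0.4642, 0.5623, 0.6813, 0.8254, 1.0000])
```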

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Lines changed: 38 additions & 4 deletions

```diff
@@ -118,6 +118,7 @@ def forward(
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@ def forward(
                 timestep_cond,
                 attention_mask,
                 cross_attention_kwargs,
+                guess_mode,
                 return_dict,
             )
 
@@ -627,7 +629,16 @@ def check_image(self, image, prompt, prompt_embeds):
             )
 
     def prepare_image(
-        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance,
+        guess_mode,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -664,7 +675,7 @@ def prepare_image(
 
         image = image.to(device=device, dtype=dtype)
 
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and not guess_mode:
             image = torch.cat([image] * 2)
 
         return image
@@ -747,6 +758,7 @@ def __call__(
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        guess_mode: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -819,6 +831,10 @@ def __call__(
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                 to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                 corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try its best to recognize the content of the input image
+                even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+
         Examples:
 
         Returns:
@@ -883,6 +899,7 @@ def __call__(
                 device=device,
                 dtype=self.controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
+                guess_mode=guess_mode,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -897,6 +914,7 @@ def __call__(
                     device=device,
                     dtype=self.controlnet.dtype,
                     do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
                 )
 
                 images.append(image_)
@@ -934,15 +952,31 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    controlnet_latent_model_input = latents
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                else:
+                    controlnet_latent_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
+                    controlnet_latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,
                     conditioning_scale=controlnet_conditioning_scale,
+                    guess_mode=guess_mode,
                     return_dict=False,
                 )
 
+                if guess_mode and do_classifier_free_guidance:
+                    # ControlNet was inferred only for the conditional batch.
+                    # To apply its output to both the unconditional and conditional batches,
+                    # prepend zeros to the unconditional half so it stays unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
```
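The zero-padding step above is easy to see with dummy tensors. A minimal standalone sketch (the tensor shape is illustrative): because ControlNet residuals are *added* to the UNet's activations, concatenating zeros for the unconditional half leaves that half of the CFG batch unchanged:

```py
import torch

# Under classifier-free guidance the UNet input batch is [uncond, cond].
# In guess mode, ControlNet ran on the conditional half only:
cond_residual = torch.randn(1, 320, 64, 64)  # illustrative down-block residual shape

# Prepend zeros for the unconditional half; a zero residual added to the
# UNet activations leaves the unconditional pass untouched.
both_halves = torch.cat([torch.zeros_like(cond_residual), cond_residual])
assert both_halves.shape[0] == 2
assert both_halves[0].abs().max().item() == 0  # unconditional half gets no ControlNet signal
```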

tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py

Lines changed: 32 additions & 0 deletions

```diff
@@ -553,6 +553,38 @@ def test_sequential_cpu_offloading(self):
         # make sure that less than 4 GB is allocated
         assert mem_bytes < 4 * 10**9
 
+    def test_canny_guess_mode(self):
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+        )
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = ""
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+
+        output = pipe(
+            prompt,
+            image,
+            generator=generator,
+            output_type="np",
+            num_inference_steps=3,
+            guidance_scale=3.0,
+            guess_mode=True,
+        )
+
+        image = output.images[0]
+        assert image.shape == (768, 512, 3)
+
+        image_slice = image[-3:, -3:, -1]
+        expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
 
 @slow
 @require_torch_gpu
```
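The test pins down the output with a tiny corner slice rather than a full reference image. A minimal sketch of that fingerprinting pattern with a stand-in array (shapes and values are illustrative, not from the commit):

```py
import numpy as np

# stand-in for a generated image of shape (height, width, channels)
image = np.random.rand(768, 512, 3)

# the bottom-right 3x3 patch of the last channel serves as a cheap fingerprint
image_slice = image[-3:, -3:, -1]
assert image_slice.shape == (3, 3)

# regression check: compare the 9 values against stored references, within tolerance
expected_slice = image_slice.flatten()  # in a real test, hard-coded reference values
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
```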
