Skip to content

Commit 3be0ff9

Browse files
authored
[Core] Support negative conditions in SDXL (huggingface#4774)
* add: support negative conditions. * fix: key * add: tests * address PR feedback. * add documentation * add img2img support. * add inpainting support. * ad controlnet support * Apply suggestions from code review * modify wording in the doc.
1 parent 2764db3 commit 3be0ff9

File tree

9 files changed

+269
-12
lines changed

9 files changed

+269
-12
lines changed

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ The abstract of the paper is the following:
2323
- Stable Diffusion XL works especially well with images between 768 and 1024.
2424
- Stable Diffusion XL can pass a different prompt for each of the text encoders it was trained on as shown below. We can even pass different parts of the same prompt to the text encoders.
2525
- Stable Diffusion XL output image can be improved by making use of a refiner as shown below.
26+
- One can make use of `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` to influence the generation process.
2627

2728
### Available checkpoints:
2829

@@ -74,6 +75,37 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
7475
image = pipe(prompt=prompt).images[0]
7576
```
7677

78+
You can additionally pass negative conditions about an image's size and position to avoid undesirable cropping behavior in the generated image, and improve image resolution. Let's take an example:
79+
80+
```python
81+
from diffusers import StableDiffusionXLPipeline
82+
import torch
83+
84+
pipe = StableDiffusionXLPipeline.from_pretrained(
85+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
86+
)
87+
pipe.to("cuda")
88+
89+
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
90+
image = pipe(
91+
prompt=prompt,
92+
negative_original_size=(512, 512),
93+
negative_crops_coords_top_left=(0, 0),
94+
negative_target_size=(1024, 1024),
95+
).images[0]
96+
```
97+
98+
Here is a comparative example that shows the influence of using three `negative_original_size`s of
99+
(128, 128), (256, 256), and (512, 512) respectively:
100+
101+
![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/negative_conditions.png)
102+
103+
<Tip>
104+
105+
One can use these negative conditions in the other SDXL pipelines ([Image-To-Image](#image-to-image), [Inpainting](#inpainting), [ControlNet](../controlnet_sdxl.md)) too!
106+
107+
</Tip>
108+
77109
### Image-to-image
78110

79111
You can use SDXL as follows for *image-to-image*:

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,9 @@ def __call__(
789789
original_size: Tuple[int, int] = None,
790790
crops_coords_top_left: Tuple[int, int] = (0, 0),
791791
target_size: Tuple[int, int] = None,
792+
negative_original_size: Optional[Tuple[int, int]] = None,
793+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
794+
negative_target_size: Optional[Tuple[int, int]] = None,
792795
):
793796
r"""
794797
Function invoked when calling the pipeline for generation.
@@ -895,6 +898,22 @@ def __call__(
895898
For most cases, `target_size` should be set to the desired height and width of the generated image. If
896899
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
897900
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
901+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
902+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
903+
micro-conditioning as explained in section 2.2 of
904+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
905+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
906+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
907+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
908+
micro-conditioning as explained in section 2.2 of
909+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
910+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
911+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
912+
To negatively condition the generation process based on a target image resolution. It should be as same
913+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
914+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
915+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
916+
898917
Examples:
899918
900919
Returns:
@@ -1058,10 +1077,20 @@ def __call__(
10581077
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
10591078
)
10601079

1080+
if negative_original_size is not None and negative_target_size is not None:
1081+
negative_add_time_ids = self._get_add_time_ids(
1082+
negative_original_size,
1083+
negative_crops_coords_top_left,
1084+
negative_target_size,
1085+
dtype=prompt_embeds.dtype,
1086+
)
1087+
else:
1088+
negative_add_time_ids = add_time_ids
1089+
10611090
if do_classifier_free_guidance:
10621091
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
10631092
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1064-
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
1093+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
10651094

10661095
prompt_embeds = prompt_embeds.to(device)
10671096
add_text_embeds = add_text_embeds.to(device)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,9 @@ def __call__(
589589
original_size: Optional[Tuple[int, int]] = None,
590590
crops_coords_top_left: Tuple[int, int] = (0, 0),
591591
target_size: Optional[Tuple[int, int]] = None,
592+
negative_original_size: Optional[Tuple[int, int]] = None,
593+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
594+
negative_target_size: Optional[Tuple[int, int]] = None,
592595
):
593596
r"""
594597
Function invoked when calling the pipeline for generation.
@@ -688,6 +691,21 @@ def __call__(
688691
For most cases, `target_size` should be set to the desired height and width of the generated image. If
689692
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
690693
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
694+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
695+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
696+
micro-conditioning as explained in section 2.2 of
697+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
698+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
699+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
700+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
701+
micro-conditioning as explained in section 2.2 of
702+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
703+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
704+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
705+
To negatively condition the generation process based on a target image resolution. It should be as same
706+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
707+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
708+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
691709
692710
Examples:
693711
@@ -783,11 +801,20 @@ def __call__(
783801
add_time_ids = self._get_add_time_ids(
784802
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
785803
)
804+
if negative_original_size is not None and negative_target_size is not None:
805+
negative_add_time_ids = self._get_add_time_ids(
806+
negative_original_size,
807+
negative_crops_coords_top_left,
808+
negative_target_size,
809+
dtype=prompt_embeds.dtype,
810+
)
811+
else:
812+
negative_add_time_ids = add_time_ids
786813

787814
if do_classifier_free_guidance:
788815
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
789816
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
790-
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
817+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
791818

792819
prompt_embeds = prompt_embeds.to(device)
793820
add_text_embeds = add_text_embeds.to(device)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -601,14 +601,25 @@ def prepare_latents(
601601
return latents
602602

603603
def _get_add_time_ids(
604-
self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
604+
self,
605+
original_size,
606+
crops_coords_top_left,
607+
target_size,
608+
aesthetic_score,
609+
negative_aesthetic_score,
610+
negative_original_size,
611+
negative_crops_coords_top_left,
612+
negative_target_size,
613+
dtype,
605614
):
606615
if self.config.requires_aesthetics_score:
607616
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
608-
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
617+
add_neg_time_ids = list(
618+
negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
619+
)
609620
else:
610621
add_time_ids = list(original_size + crops_coords_top_left + target_size)
611-
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
622+
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
612623

613624
passed_add_embed_dim = (
614625
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
@@ -690,6 +701,9 @@ def __call__(
690701
original_size: Tuple[int, int] = None,
691702
crops_coords_top_left: Tuple[int, int] = (0, 0),
692703
target_size: Tuple[int, int] = None,
704+
negative_original_size: Optional[Tuple[int, int]] = None,
705+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
706+
negative_target_size: Optional[Tuple[int, int]] = None,
693707
aesthetic_score: float = 6.0,
694708
negative_aesthetic_score: float = 2.5,
695709
):
@@ -804,6 +818,21 @@ def __call__(
804818
For most cases, `target_size` should be set to the desired height and width of the generated image. If
805819
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
806820
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
821+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
822+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
823+
micro-conditioning as explained in section 2.2 of
824+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
825+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
826+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
827+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
828+
micro-conditioning as explained in section 2.2 of
829+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
830+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
831+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
832+
To negatively condition the generation process based on a target image resolution. It should be as same
833+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
834+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
835+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
807836
aesthetic_score (`float`, *optional*, defaults to 6.0):
808837
Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
809838
Part of SDXL's micro-conditioning as explained in section 2.2 of
@@ -908,13 +937,21 @@ def denoising_value_valid(dnv):
908937
target_size = target_size or (height, width)
909938

910939
# 8. Prepare added time ids & embeddings
940+
if negative_original_size is None:
941+
negative_original_size = original_size
942+
if negative_target_size is None:
943+
negative_target_size = target_size
944+
911945
add_text_embeds = pooled_prompt_embeds
912946
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
913947
original_size,
914948
crops_coords_top_left,
915949
target_size,
916950
aesthetic_score,
917951
negative_aesthetic_score,
952+
negative_original_size,
953+
negative_crops_coords_top_left,
954+
negative_target_size,
918955
dtype=prompt_embeds.dtype,
919956
)
920957
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -813,14 +813,25 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
813813

814814
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
815815
def _get_add_time_ids(
816-
self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
816+
self,
817+
original_size,
818+
crops_coords_top_left,
819+
target_size,
820+
aesthetic_score,
821+
negative_aesthetic_score,
822+
negative_original_size,
823+
negative_crops_coords_top_left,
824+
negative_target_size,
825+
dtype,
817826
):
818827
if self.config.requires_aesthetics_score:
819828
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
820-
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
829+
add_neg_time_ids = list(
830+
negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
831+
)
821832
else:
822833
add_time_ids = list(original_size + crops_coords_top_left + target_size)
823-
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
834+
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
824835

825836
passed_add_embed_dim = (
826837
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
@@ -905,6 +916,9 @@ def __call__(
905916
original_size: Tuple[int, int] = None,
906917
crops_coords_top_left: Tuple[int, int] = (0, 0),
907918
target_size: Tuple[int, int] = None,
919+
negative_original_size: Optional[Tuple[int, int]] = None,
920+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
921+
negative_target_size: Optional[Tuple[int, int]] = None,
908922
aesthetic_score: float = 6.0,
909923
negative_aesthetic_score: float = 2.5,
910924
):
@@ -1025,6 +1039,21 @@ def __call__(
10251039
For most cases, `target_size` should be set to the desired height and width of the generated image. If
10261040
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
10271041
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1042+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1043+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1044+
micro-conditioning as explained in section 2.2 of
1045+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1046+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1047+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1048+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
1049+
micro-conditioning as explained in section 2.2 of
1050+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1051+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1052+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1053+
To negatively condition the generation process based on a target image resolution. It should be as same
1054+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1055+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1056+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
10281057
aesthetic_score (`float`, *optional*, defaults to 6.0):
10291058
Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
10301059
Part of SDXL's micro-conditioning as explained in section 2.2 of
@@ -1199,13 +1228,21 @@ def denoising_value_valid(dnv):
11991228
target_size = target_size or (height, width)
12001229

12011230
# 10. Prepare added time ids & embeddings
1231+
if negative_original_size is None:
1232+
negative_original_size = original_size
1233+
if negative_target_size is None:
1234+
negative_target_size = target_size
1235+
12021236
add_text_embeds = pooled_prompt_embeds
12031237
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
12041238
original_size,
12051239
crops_coords_top_left,
12061240
target_size,
12071241
aesthetic_score,
12081242
negative_aesthetic_score,
1243+
negative_original_size,
1244+
negative_crops_coords_top_left,
1245+
negative_target_size,
12091246
dtype=prompt_embeds.dtype,
12101247
)
12111248
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

tests/pipelines/controlnet/test_controlnet_sdxl.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def get_dummy_inputs(self, device, seed=0):
160160
"generator": generator,
161161
"num_inference_steps": 2,
162162
"guidance_scale": 6.0,
163-
"output_type": "numpy",
163+
"output_type": "np",
164164
"image": image,
165165
}
166166

@@ -680,6 +680,25 @@ def test_xformers_attention_forwardGenerator_pass(self):
680680
def test_inference_batch_single_identical(self):
681681
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
682682

683+
def test_negative_conditions(self):
684+
components = self.get_dummy_components()
685+
pipe = self.pipeline_class(**components)
686+
pipe.to(torch_device)
687+
688+
inputs = self.get_dummy_inputs(torch_device)
689+
image = pipe(**inputs).images
690+
image_slice_without_neg_cond = image[0, -3:, -3:, -1]
691+
692+
image = pipe(
693+
**inputs,
694+
negative_original_size=(512, 512),
695+
negative_crops_coords_top_left=(0, 0),
696+
negative_target_size=(1024, 1024),
697+
).images
698+
image_slice_with_neg_cond = image[0, -3:, -3:, -1]
699+
700+
self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2)
701+
683702

684703
@slow
685704
@require_torch_gpu

0 commit comments

Comments
 (0)