Skip to content

Commit 0ec7a02

Browse files
[StableDiffusionXLAdapterPipeline] allow negative micro conds (huggingface#4941)
* allow negative micro conds in t2i pipeline * Empty-Commit --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 626284f commit 0ec7a02

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,9 @@ def __call__(
655655
original_size: Optional[Tuple[int, int]] = None,
656656
crops_coords_top_left: Tuple[int, int] = (0, 0),
657657
target_size: Optional[Tuple[int, int]] = None,
658+
negative_original_size: Optional[Tuple[int, int]] = None,
659+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
660+
negative_target_size: Optional[Tuple[int, int]] = None,
658661
adapter_conditioning_scale: Union[float, List[float]] = 1.0,
659662
adapter_conditioning_factor: float = 1.0,
660663
):
@@ -764,6 +767,22 @@ def __call__(
764767
For most cases, `target_size` should be set to the desired height and width of the generated image. If
765768
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
766769
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
770+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
771+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
772+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
773+
micro-conditioning as explained in section 2.2 of
774+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
775+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
776+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
777+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
778+
micro-conditioning as explained in section 2.2 of
779+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
780+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
781+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
782+
To negatively condition the generation process based on a target image resolution. It should be as same
783+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
784+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
785+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
767786
adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
768787
The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
769788
residual in the original unet. If multiple adapters are specified in init, you can set the
@@ -876,11 +895,20 @@ def __call__(
876895
add_time_ids = self._get_add_time_ids(
877896
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
878897
)
898+
if negative_original_size is not None and negative_target_size is not None:
899+
negative_add_time_ids = self._get_add_time_ids(
900+
negative_original_size,
901+
negative_crops_coords_top_left,
902+
negative_target_size,
903+
dtype=prompt_embeds.dtype,
904+
)
905+
else:
906+
negative_add_time_ids = add_time_ids
879907

880908
if do_classifier_free_guidance:
881909
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
882910
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
883-
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
911+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
884912

885913
prompt_embeds = prompt_embeds.to(device)
886914
add_text_embeds = add_text_embeds.to(device)

0 commit comments

Comments
 (0)