Commit aed30df

apolinario, multimodalart, patrickvonplaten, and sayakpaul authored
Allow passing different prompts to each text_encoder on stable_diffusion_xl pipelines (huggingface#4156)
* sdxl prompt2
* Improve checks
* doc linting
* whoops
* remove cat
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* Add other pipelines and tests
* Add multi-prompting to docs
* doc and copies check
* Fix copied froms
* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* Bring back the original code for unrelated files
* Fix tests
* Fix img2img
* Fix all
* fix

---------

Co-authored-by: multimodalart <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent e2bbaa4 commit aed30df

9 files changed: 469 additions and 68 deletions

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.mdx

Lines changed: 23 additions & 0 deletions
@@ -21,6 +21,7 @@ The abstract of the paper is the following:
 ## Tips

 - Stable Diffusion XL works especially well with images between 768 and 1024.
+- Stable Diffusion XL can pass a different prompt for each of the text encoders it was trained on as shown below. We can even pass different parts of the same prompt to the text encoders.
 - Stable Diffusion XL output image can be improved by making use of a refiner as shown below.

 ### Available checkpoints:
@@ -362,3 +363,25 @@ pip install xformers
 [[autodoc]] StableDiffusionXLInpaintPipeline
 	- all
 	- __call__
+
+### Passing different prompts to each text-encoder
+
+Stable Diffusion XL was trained on two text encoders. The default behavior is to pass the same prompt to each. But it is possible to pass a different prompt for each text-encoder, as [some users](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201) noted that it can boost quality.
+To do so, you can pass `prompt_2` and `negative_prompt_2` in addition to `prompt` and `negative_prompt`. By doing that, you will pass the original prompts and negative prompts (as in `prompt` and `negative_prompt`) to `text_encoder` (in official SDXL 0.9/1.0 that is [OpenAI CLIP-ViT/L-14](https://huggingface.co/openai/clip-vit-large-patch14)),
+and `prompt_2` and `negative_prompt_2` to `text_encoder_2` (in official SDXL 0.9/1.0 that is [OpenCLIP-ViT/bigG-14](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe.to("cuda")
+
+# prompt will be passed to OAI CLIP-ViT/L-14
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+# prompt_2 will be passed to OpenCLIP-ViT/bigG-14
+prompt_2 = "monet painting"
+image = pipe(prompt=prompt, prompt_2=prompt_2).images[0]
+```
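
The fallback rule described in the added docs can be sketched without loading a model. The helper below is hypothetical (`route_prompts` is not part of diffusers). It only mirrors the `prompt_2 = prompt_2 or prompt` and `negative_prompt_2 = negative_prompt_2 or negative_prompt` defaults from this commit, to show which string each text encoder would receive.

```python
from typing import Dict, Optional, Tuple


def route_prompts(
    prompt: str,
    prompt_2: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
) -> Dict[str, Tuple[str, str]]:
    """Return the (positive, negative) prompt pair each text encoder would see."""
    # An unset (or empty) second prompt falls back to the first one.
    prompt_2 = prompt_2 or prompt
    # An unset negative prompt becomes the empty string, as in encode_prompt.
    negative_prompt = negative_prompt or ""
    negative_prompt_2 = negative_prompt_2 or negative_prompt
    return {
        "text_encoder": (prompt, negative_prompt),
        "text_encoder_2": (prompt_2, negative_prompt_2),
    }


routing = route_prompts(
    prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    prompt_2="monet painting",
    negative_prompt="blurry",
)
for encoder_name, (positive, negative) in routing.items():
    print(f"{encoder_name}: prompt={positive!r}, negative_prompt={negative!r}")
```

Because the fallback uses `or`, an empty string passed as `prompt_2` behaves the same as leaving it unset.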

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 49 additions & 10 deletions
@@ -196,11 +196,13 @@ def enable_model_cpu_offload(self, gpu_id=0):
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(
         self,
-        prompt,
+        prompt: str,
+        prompt_2: Optional[str] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
-        negative_prompt=None,
+        negative_prompt: Optional[str] = None,
+        negative_prompt_2: Optional[str] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -211,8 +213,11 @@ def encode_prompt(
         Encodes the prompt into text encoder hidden states.

         Args:
-             prompt (`str` or `List[str]`, *optional*):
+            prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
             device: (`torch.device`):
                 torch device
             num_images_per_prompt (`int`):
@@ -223,6 +228,9 @@ def encode_prompt(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
@@ -261,9 +269,11 @@ def encode_prompt(
         )

         if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
-            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+            prompts = [prompt, prompt_2]
+            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                 if isinstance(self, TextualInversionLoaderMixin):
                     prompt = self.maybe_convert_prompt(prompt, tokenizer)

@@ -274,8 +284,10 @@ def encode_prompt(
                     truncation=True,
                     return_tensors="pt",
                 )
+
                 text_input_ids = text_inputs.input_ids
                 untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                 if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                     text_input_ids, untruncated_ids
@@ -311,32 +323,33 @@ def encode_prompt(
             negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
         elif do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
+                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                     " the batch size of `prompt`."
                 )
             else:
-                uncond_tokens = negative_prompt
+                uncond_tokens = [negative_prompt, negative_prompt_2]

             negative_prompt_embeds_list = []
-            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
-                # textual inversion: procecss multi-vector tokens if necessary
+            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                 if isinstance(self, TextualInversionLoaderMixin):
-                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

                 max_length = prompt_embeds.shape[1]
                 uncond_input = tokenizer(
-                    uncond_tokens,
+                    negative_prompt,
                     padding="max_length",
                     max_length=max_length,
                     truncation=True,
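
The `encode_prompt` hunks above pair each prompt with its own tokenizer and text encoder via `zip`. The following is a minimal sketch of that pairing using stub encoders, not the real CLIP models. `make_stub_encoder`, `encode_with_two_prompts`, `SEQ_LEN`, and the widths 3 and 5 are all made up for illustration. The final step, joining the two encoders' outputs along the feature axis, reflects my understanding of the SDXL pipeline and is not shown in this diff.

```python
from typing import Callable, List, Optional, Sequence

Encoder = Callable[[str], List[List[float]]]

SEQ_LEN = 4  # stand-in for the tokenizer's fixed max_length (hypothetical value)


def make_stub_encoder(width: int) -> Encoder:
    """Build a fake encoder that emits SEQ_LEN vectors of size `width`."""

    def encode(text: str) -> List[List[float]]:
        tokens = text.split()[:SEQ_LEN]  # truncate long prompts
        rows = [[float(len(token))] * width for token in tokens]
        # Pad short prompts so every encoder returns the same sequence length.
        rows.extend([0.0] * width for _ in range(SEQ_LEN - len(rows)))
        return rows

    return encode


def encode_with_two_prompts(
    prompt: str, prompt_2: Optional[str], encoders: Sequence[Encoder]
) -> List[List[float]]:
    """Send `prompt` to the first encoder and `prompt_2` to the second."""
    prompt_2 = prompt_2 or prompt  # same fallback as in the commit
    prompts = [prompt, prompt_2]
    first, second = (encoder(text) for text, encoder in zip(prompts, encoders))
    # Assumption: per-position features from both encoders are concatenated.
    return [row_1 + row_2 for row_1, row_2 in zip(first, second)]


encoders = [make_stub_encoder(3), make_stub_encoder(5)]
embeds = encode_with_two_prompts("astronaut in a jungle", "monet painting", encoders)
print(len(embeds), len(embeds[0]))
```

The stub pads every prompt to `SEQ_LEN` because the two prompts generally tokenize to different lengths. Without a shared sequence length, the per-position join would not line up.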
@@ -401,9 +414,11 @@ def prepare_extra_step_kwargs(self, generator, eta):
     def check_inputs(
         self,
         prompt,
+        prompt_2,
         image,
         callback_steps,
         negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
@@ -423,18 +438,30 @@ def check_inputs(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                 " only forward one of the two."
             )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
         elif prompt is None and prompt_embeds is None:
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )

         if prompt_embeds is not None and negative_prompt_embeds is not None:
             if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -610,6 +637,7 @@ def upcast_vae(self):
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
         image: Union[
             torch.FloatTensor,
             PIL.Image.Image,
@@ -623,6 +651,7 @@ def __call__(
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -649,6 +678,9 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
             image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                     `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
@@ -674,6 +706,9 @@ def __call__(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -749,9 +784,11 @@ def __call__(
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
+            prompt_2,
             image,
             callback_steps,
             negative_prompt,
+            negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
             controlnet_conditioning_scale,
@@ -791,10 +828,12 @@ def __call__(
             negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt,
+            prompt_2,
             device,
             num_images_per_prompt,
             do_classifier_free_guidance,
             negative_prompt,
+            negative_prompt_2,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
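
The `check_inputs` changes in this file reject `prompt_2` and `negative_prompt_2` when precomputed embeddings are also supplied. Below is a standalone sketch of just those prompt checks, assuming simplified error messages. `check_prompt_inputs` and `describe` are hypothetical names, and the real method also validates the image, callback steps, and embedding shapes.

```python
from typing import List, Optional, Union

PromptType = Optional[Union[str, List[str]]]


def check_prompt_inputs(
    prompt: PromptType,
    prompt_2: PromptType = None,
    negative_prompt: PromptType = None,
    negative_prompt_2: PromptType = None,
    prompt_embeds: Optional[object] = None,
    negative_prompt_embeds: Optional[object] = None,
) -> None:
    """Raise ValueError for the conflicting prompt/embedding combinations."""
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt_2` and `prompt_embeds`.")
    elif prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be `str` or `list` but is {type(prompt_2)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.")
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError("Cannot forward both `negative_prompt_2` and `negative_prompt_embeds`.")


def describe(**kwargs) -> str:
    """Report whether a combination of arguments passes the checks."""
    try:
        check_prompt_inputs(**kwargs)
    except ValueError as err:
        return f"rejected: {err}"
    return "accepted"


print(describe(prompt="a cat", prompt_2="monet painting"))
print(describe(prompt=None, prompt_2="monet painting", prompt_embeds=[0.0]))
```

The second call is rejected because `prompt_embeds` already holds the encoded output of both text encoders, so a separate `prompt_2` would have nowhere to go.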
