Commit b562b66

Allow directly passing text embeddings to Stable Diffusion Pipeline for prompt weighting (huggingface#2071)
* add text embeds to sd
* add text embeds to sd
* finish tests
* finish
* finish
* make style
* fix tests
* make style
* make style
* up
* better docs
* fix
* fix
* new try
* up
* up
* finish
1 parent: c118491 · commit: b562b66

30 files changed (+1738, −790 lines)
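For context, a minimal sketch of what this change enables (the model id, prompt, and weighting factor below are illustrative, not taken from the diff): encode the prompt yourself, tweak the embeddings, and hand the result to the pipeline through the new `prompt_embeds` argument instead of a string prompt.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# tokenize and encode the prompt manually instead of passing a string
text_inputs = pipe.tokenizer(
    "a photo of an astronaut riding a horse",
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids.to(pipe.device))[0]

# hypothetical prompt weighting: uniformly upweight the token embeddings
prompt_embeds = prompt_embeds * 1.1

# pass the precomputed embeddings directly; `prompt` stays unset
image = pipe(prompt_embeds=prompt_embeds).images[0]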

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 148 additions & 58 deletions
Large diffs are not rendered by default.

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 133 additions & 55 deletions
Large diffs are not rendered by default.

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 5 additions & 5 deletions
@@ -129,11 +129,11 @@ def __call__(
         uncond_input = self.tokenizer(
             [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
         )
-        uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
+        negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self.device))[0]

         # get prompt text embeddings
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
-        text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
+        prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0]

         # get the initial random noise unless the user supplied it
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
@@ -144,7 +144,7 @@ def __call__(
         )

         if latents is None:
-            latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
+            latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=prompt_embeds.dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -163,13 +163,13 @@ def __call__(
         if guidance_scale == 1.0:
             # guidance_scale of 1 means no guidance
             latents_input = latents
-            context = text_embeddings
+            context = prompt_embeds
         else:
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             latents_input = torch.cat([latents] * 2)
-            context = torch.cat([uncond_embeddings, text_embeddings])
+            context = torch.cat([negative_prompt_embeds, prompt_embeds])

         # predict the noise residual
         noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
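The renamed variables above implement standard classifier-free guidance batching: one UNet call on a doubled batch instead of two separate calls. A standalone sketch with toy tensor shapes (assumed, not from the diff) of how the concatenated context is consumed and split afterwards:

import torch

negative_prompt_embeds = torch.randn(1, 77, 768)  # unconditional ("") embeddings
prompt_embeds = torch.randn(1, 77, 768)           # text-conditional embeddings

context = torch.cat([negative_prompt_embeds, prompt_embeds])  # (2, 77, 768)
latents_input = torch.cat([torch.randn(1, 4, 64, 64)] * 2)    # doubled latents

# stand-in for the UNet output on the doubled batch; split and recombine
noise_pred_uncond, noise_pred_text = torch.randn(2, 4, 64, 64).chunk(2)
guidance_scale = 7.5
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)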

src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py

Lines changed: 4 additions & 4 deletions
@@ -364,21 +364,21 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free
         image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

         image = image.to(device=device, dtype=dtype)
-        image_embeddings, uncond_embeddings = self.image_encoder(image, return_uncond_vector=True)
+        image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True)

         # duplicate image embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = image_embeddings.shape
         image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
         image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

         if do_classifier_free_guidance:
-            uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
-            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

         return image_embeddings
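The `repeat(...).view(...)` pair above is the "mps friendly" way to duplicate each embedding `num_images_per_prompt` times; a toy sketch (shapes assumed) of the resulting layout:

import torch

num_images_per_prompt = 2
image_embeddings = torch.randn(1, 1, 768)  # (batch, seq_len, dim)

bs_embed, seq_len, _ = image_embeddings.shape
# tile along the sequence axis, then fold the copies back into the batch axis
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
print(image_embeddings.shape)  # torch.Size([2, 1, 768])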

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 128 additions & 53 deletions
@@ -261,60 +261,89 @@ def _execution_device(self):
         return self.device

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
         r"""
         Encodes the prompt into text encoder hidden states.

         Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
             device: (`torch.device`):
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
+                input argument.
         """
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]

-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
+        if prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

-        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            attention_mask = text_inputs.attention_mask.to(device)
-        else:
-            attention_mask = None
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )

-        text_embeddings = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
-        text_embeddings = text_embeddings[0]
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            prompt_embeds = self.text_encoder(
+                text_input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]

+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

         # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
@@ -334,7 +363,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             else:
                 uncond_tokens = negative_prompt

-            max_length = text_input_ids.shape[-1]
+            max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
@@ -348,26 +377,32 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             else:
                 attention_mask = None

-            uncond_embeddings = self.text_encoder(
+            negative_prompt_embeds = self.text_encoder(
                 uncond_input.input_ids.to(device),
                 attention_mask=attention_mask,
             )
-            uncond_embeddings = uncond_embeddings[0]
+            negative_prompt_embeds = negative_prompt_embeds[0]

+        if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

-        return text_embeddings
+        return prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
-    def check_inputs(self, prompt, strength, callback_steps):
+    def check_inputs(
+        self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+    ):
         if not isinstance(prompt, str) and not isinstance(prompt, list):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

@@ -382,6 +417,32 @@ def check_inputs(self, prompt, strength, callback_steps):
                 f" {type(callback_steps)}."
             )

+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
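A standalone restatement of the validation rules `check_inputs` gains above (plain Python, function name hypothetical), to make the accepted combinations explicit: exactly one of `prompt`/`prompt_embeds`, at most one of `negative_prompt`/`negative_prompt_embeds`, and matching shapes when both embedding tensors are given.

import torch

def check_prompt_inputs(prompt=None, prompt_embeds=None, negative_prompt=None, negative_prompt_embeds=None):
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
    if prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.")
    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError("`prompt_embeds` and `negative_prompt_embeds` must have the same shape.")

check_prompt_inputs(prompt="a cat")                                  # ok: prompt only
check_prompt_inputs(prompt_embeds=torch.randn(1, 77, 768),
                    negative_prompt_embeds=torch.randn(1, 77, 768))  # ok: embeddings only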
@@ -492,6 +553,7 @@ def __call__(
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -533,6 +595,13 @@ def __call__(
         generator (`torch.Generator`, *optional*):
             One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
             to make generation deterministic.
+        prompt_embeds (`torch.FloatTensor`, *optional*):
+            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+            provided, text embeddings will be generated from the `prompt` input argument.
+        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+            weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
+            input argument.
         output_type (`str`, *optional*, defaults to `"pil"`):
             The output format of the generated image. Choose between
             [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -569,8 +638,14 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0

         # 3. Encode input prompt
-        text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
-        source_text_embeddings = self._encode_prompt(
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            prompt_embeds=prompt_embeds,
+        )
+        source_prompt_embeds = self._encode_prompt(
             source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
         )

@@ -584,7 +659,7 @@ def __call__(

         # 6. Prepare latent variables
         latents, clean_latents = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
         )
         source_latents = latents

@@ -612,17 +687,17 @@ def __call__(
                     ],
                     dim=0,
                 )
-                concat_text_embeddings = torch.stack(
+                concat_prompt_embeds = torch.stack(
                     [
-                        source_text_embeddings[0],
-                        text_embeddings[0],
-                        source_text_embeddings[1],
-                        text_embeddings[1],
+                        source_prompt_embeds[0],
+                        prompt_embeds[0],
+                        source_prompt_embeds[1],
+                        prompt_embeds[1],
                     ],
                     dim=0,
                 )
                 concat_noise_pred = self.unet(
-                    concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
+                    concat_latent_model_input, t, encoder_hidden_states=concat_prompt_embeds
                 ).sample

                 # perform guidance
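With classifier-free guidance active, both embedding tensors carry a batch of two ([uncond, cond]); the stack above interleaves source and target so a single UNet call yields all four noise predictions. A toy illustration (shapes assumed):

import torch

source_prompt_embeds = torch.randn(2, 77, 768)  # [source uncond, source cond]
prompt_embeds = torch.randn(2, 77, 768)         # [target uncond, target cond]

concat_prompt_embeds = torch.stack(
    [source_prompt_embeds[0], prompt_embeds[0], source_prompt_embeds[1], prompt_embeds[1]],
    dim=0,
)
print(concat_prompt_embeds.shape)  # torch.Size([4, 77, 768])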
@@ -662,7 +737,7 @@ def __call__(
         image = self.decode_latents(latents)

         # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

         # 11. Convert to PIL
         if output_type == "pil":
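One subtlety the reworked `_encode_prompt` handles: user-supplied embeddings may arrive in a different dtype or on a different device than the text encoder, so they are cast before use. A CPU-only toy sketch (stand-in values for the encoder dtype and pipeline device):

import torch

text_encoder_dtype = torch.float16  # stand-in for self.text_encoder.dtype
device = torch.device("cpu")        # stand-in for the pipeline's device

prompt_embeds = torch.randn(1, 77, 768, dtype=torch.float32)
prompt_embeds = prompt_embeds.to(dtype=text_encoder_dtype, device=device)
print(prompt_embeds.dtype)  # torch.float16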

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 3 additions & 3 deletions
@@ -196,7 +196,7 @@ def _generate(
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

         # get prompt text embeddings
-        text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
+        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

         # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
         # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
@@ -210,8 +210,8 @@ def _generate(
             ).input_ids
         else:
             uncond_input = neg_prompt_ids
-        uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
-        context = jnp.concatenate([uncond_embeddings, text_embeddings])
+        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
+        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

         latents_shape = (
             batch_size,
