Skip to content

Commit 3768d4d

Browse files
[Core] refactor encode_prompt (huggingface#4617)
* refactoring of encode_prompt() * better handling of device. * fix: device determination * fix: device determination 2 * handle num_images_per_prompt * revert changes in loaders.py and give birth to encode_prompt(). * minor refactoring for encode_prompt()/ * make backward compatible. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * fix: concatenation of the neg and pos embeddings. * incorporate encode_prompt() in test_stable_diffusion.py * turn it into big PR. * make it bigger * gligen fixes. * more fixes to fligen * _encode_prompt -> encode_prompt in tests * first batch * second batch * fix blasphemous mistake * fix * fix: hopefully for the final time. --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 8ccb619 commit 3768d4d

30 files changed

+1180
-290
lines changed

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -258,12 +258,45 @@ def _encode_prompt(
258258
prompt_embeds: Optional[torch.FloatTensor] = None,
259259
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
260260
lora_scale: Optional[float] = None,
261+
):
262+
deprecation_message = (
263+
"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
264+
" instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
265+
)
266+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
267+
268+
prompt_embeds_tuple = self.encode_prompt(
269+
prompt=prompt,
270+
device=device,
271+
num_images_per_prompt=num_images_per_prompt,
272+
do_classifier_free_guidance=do_classifier_free_guidance,
273+
negative_prompt=negative_prompt,
274+
prompt_embeds=prompt_embeds,
275+
negative_prompt_embeds=negative_prompt_embeds,
276+
lora_scale=lora_scale,
277+
)
278+
279+
# concatenate for backwards comp
280+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
281+
282+
return prompt_embeds
283+
284+
def encode_prompt(
285+
self,
286+
prompt,
287+
device,
288+
num_images_per_prompt,
289+
do_classifier_free_guidance,
290+
negative_prompt=None,
291+
prompt_embeds: Optional[torch.FloatTensor] = None,
292+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
293+
lora_scale: Optional[float] = None,
261294
):
262295
r"""
263296
Encodes the prompt into text encoder hidden states.
264297
265298
Args:
266-
prompt (`str` or `List[str]`, *optional*):
299+
prompt (`str` or `List[str]`, *optional*):
267300
prompt to be encoded
268301
device: (`torch.device`):
269302
torch device
@@ -402,12 +435,7 @@ def _encode_prompt(
402435
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
403436
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
404437

405-
# For classifier free guidance, we need to do two forward passes.
406-
# Here we concatenate the unconditional and text embeddings into a single batch
407-
# to avoid doing two forward passes
408-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
409-
410-
return prompt_embeds
438+
return prompt_embeds, negative_prompt_embeds
411439

412440
def run_safety_checker(self, image, device, dtype):
413441
if self.safety_checker is None:
@@ -634,7 +662,7 @@ def __call__(
634662
text_encoder_lora_scale = (
635663
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
636664
)
637-
prompt_embeds = self._encode_prompt(
665+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
638666
prompt,
639667
device,
640668
num_images_per_prompt,
@@ -644,6 +672,11 @@ def __call__(
644672
negative_prompt_embeds=negative_prompt_embeds,
645673
lora_scale=text_encoder_lora_scale,
646674
)
675+
# For classifier free guidance, we need to do two forward passes.
676+
# Here we concatenate the unconditional and text embeddings into a single batch
677+
# to avoid doing two forward passes
678+
if do_classifier_free_guidance:
679+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
647680

648681
# 4. Prepare timesteps
649682
self.scheduler.set_timesteps(num_inference_steps, device=device)

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,45 @@ def _encode_prompt(
259259
prompt_embeds: Optional[torch.FloatTensor] = None,
260260
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
261261
lora_scale: Optional[float] = None,
262+
):
263+
deprecation_message = (
264+
"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
265+
" instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
266+
)
267+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
268+
269+
prompt_embeds_tuple = self.encode_prompt(
270+
prompt=prompt,
271+
device=device,
272+
num_images_per_prompt=num_images_per_prompt,
273+
do_classifier_free_guidance=do_classifier_free_guidance,
274+
negative_prompt=negative_prompt,
275+
prompt_embeds=prompt_embeds,
276+
negative_prompt_embeds=negative_prompt_embeds,
277+
lora_scale=lora_scale,
278+
)
279+
280+
# concatenate for backwards comp
281+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
282+
283+
return prompt_embeds
284+
285+
def encode_prompt(
286+
self,
287+
prompt,
288+
device,
289+
num_images_per_prompt,
290+
do_classifier_free_guidance,
291+
negative_prompt=None,
292+
prompt_embeds: Optional[torch.FloatTensor] = None,
293+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
294+
lora_scale: Optional[float] = None,
262295
):
263296
r"""
264297
Encodes the prompt into text encoder hidden states.
265298
266299
Args:
267-
prompt (`str` or `List[str]`, *optional*):
300+
prompt (`str` or `List[str]`, *optional*):
268301
prompt to be encoded
269302
device: (`torch.device`):
270303
torch device
@@ -403,12 +436,7 @@ def _encode_prompt(
403436
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
404437
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
405438

406-
# For classifier free guidance, we need to do two forward passes.
407-
# Here we concatenate the unconditional and text embeddings into a single batch
408-
# to avoid doing two forward passes
409-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
410-
411-
return prompt_embeds
439+
return prompt_embeds, negative_prompt_embeds
412440

413441
def run_safety_checker(self, image, device, dtype):
414442
if self.safety_checker is None:
@@ -668,7 +696,7 @@ def __call__(
668696
text_encoder_lora_scale = (
669697
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
670698
)
671-
prompt_embeds = self._encode_prompt(
699+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
672700
prompt,
673701
device,
674702
num_images_per_prompt,
@@ -678,6 +706,11 @@ def __call__(
678706
negative_prompt_embeds=negative_prompt_embeds,
679707
lora_scale=text_encoder_lora_scale,
680708
)
709+
# For classifier free guidance, we need to do two forward passes.
710+
# Here we concatenate the unconditional and text embeddings into a single batch
711+
# to avoid doing two forward passes
712+
if do_classifier_free_guidance:
713+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
681714

682715
# 4. Preprocess image
683716
image = self.image_processor.preprocess(image)

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
2929
from ...schedulers import KarrasDiffusionSchedulers
3030
from ...utils import (
31+
deprecate,
3132
is_accelerate_available,
3233
is_accelerate_version,
3334
is_compiled_module,
@@ -250,12 +251,43 @@ def _encode_prompt(
250251
prompt_embeds: Optional[torch.FloatTensor] = None,
251252
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
252253
lora_scale: Optional[float] = None,
254+
):
255+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
256+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
257+
258+
prompt_embeds_tuple = self.encode_prompt(
259+
prompt=prompt,
260+
device=device,
261+
num_images_per_prompt=num_images_per_prompt,
262+
do_classifier_free_guidance=do_classifier_free_guidance,
263+
negative_prompt=negative_prompt,
264+
prompt_embeds=prompt_embeds,
265+
negative_prompt_embeds=negative_prompt_embeds,
266+
lora_scale=lora_scale,
267+
)
268+
269+
# concatenate for backwards comp
270+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
271+
272+
return prompt_embeds
273+
274+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
275+
def encode_prompt(
276+
self,
277+
prompt,
278+
device,
279+
num_images_per_prompt,
280+
do_classifier_free_guidance,
281+
negative_prompt=None,
282+
prompt_embeds: Optional[torch.FloatTensor] = None,
283+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
284+
lora_scale: Optional[float] = None,
253285
):
254286
r"""
255287
Encodes the prompt into text encoder hidden states.
256288
257289
Args:
258-
prompt (`str` or `List[str]`, *optional*):
290+
prompt (`str` or `List[str]`, *optional*):
259291
prompt to be encoded
260292
device: (`torch.device`):
261293
torch device
@@ -394,12 +426,7 @@ def _encode_prompt(
394426
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
395427
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
396428

397-
# For classifier free guidance, we need to do two forward passes.
398-
# Here we concatenate the unconditional and text embeddings into a single batch
399-
# to avoid doing two forward passes
400-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
401-
402-
return prompt_embeds
429+
return prompt_embeds, negative_prompt_embeds
403430

404431
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
405432
def run_safety_checker(self, image, device, dtype):
@@ -842,7 +869,7 @@ def __call__(
842869
text_encoder_lora_scale = (
843870
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
844871
)
845-
prompt_embeds = self._encode_prompt(
872+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
846873
prompt,
847874
device,
848875
num_images_per_prompt,
@@ -852,6 +879,11 @@ def __call__(
852879
negative_prompt_embeds=negative_prompt_embeds,
853880
lora_scale=text_encoder_lora_scale,
854881
)
882+
# For classifier free guidance, we need to do two forward passes.
883+
# Here we concatenate the unconditional and text embeddings into a single batch
884+
# to avoid doing two forward passes
885+
if do_classifier_free_guidance:
886+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
855887

856888
# 4. Prepare image
857889
if isinstance(controlnet, ControlNetModel):

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,43 @@ def _encode_prompt(
276276
prompt_embeds: Optional[torch.FloatTensor] = None,
277277
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
278278
lora_scale: Optional[float] = None,
279+
):
280+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
281+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
282+
283+
prompt_embeds_tuple = self.encode_prompt(
284+
prompt=prompt,
285+
device=device,
286+
num_images_per_prompt=num_images_per_prompt,
287+
do_classifier_free_guidance=do_classifier_free_guidance,
288+
negative_prompt=negative_prompt,
289+
prompt_embeds=prompt_embeds,
290+
negative_prompt_embeds=negative_prompt_embeds,
291+
lora_scale=lora_scale,
292+
)
293+
294+
# concatenate for backwards comp
295+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
296+
297+
return prompt_embeds
298+
299+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
300+
def encode_prompt(
301+
self,
302+
prompt,
303+
device,
304+
num_images_per_prompt,
305+
do_classifier_free_guidance,
306+
negative_prompt=None,
307+
prompt_embeds: Optional[torch.FloatTensor] = None,
308+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
309+
lora_scale: Optional[float] = None,
279310
):
280311
r"""
281312
Encodes the prompt into text encoder hidden states.
282313
283314
Args:
284-
prompt (`str` or `List[str]`, *optional*):
315+
prompt (`str` or `List[str]`, *optional*):
285316
prompt to be encoded
286317
device: (`torch.device`):
287318
torch device
@@ -420,12 +451,7 @@ def _encode_prompt(
420451
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
421452
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
422453

423-
# For classifier free guidance, we need to do two forward passes.
424-
# Here we concatenate the unconditional and text embeddings into a single batch
425-
# to avoid doing two forward passes
426-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
427-
428-
return prompt_embeds
454+
return prompt_embeds, negative_prompt_embeds
429455

430456
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
431457
def run_safety_checker(self, image, device, dtype):
@@ -921,7 +947,7 @@ def __call__(
921947
text_encoder_lora_scale = (
922948
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
923949
)
924-
prompt_embeds = self._encode_prompt(
950+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
925951
prompt,
926952
device,
927953
num_images_per_prompt,
@@ -931,6 +957,12 @@ def __call__(
931957
negative_prompt_embeds=negative_prompt_embeds,
932958
lora_scale=text_encoder_lora_scale,
933959
)
960+
# For classifier free guidance, we need to do two forward passes.
961+
# Here we concatenate the unconditional and text embeddings into a single batch
962+
# to avoid doing two forward passes
963+
if do_classifier_free_guidance:
964+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
965+
934966
# 4. Prepare image
935967
image = self.image_processor.preprocess(image).to(dtype=torch.float32)
936968

0 commit comments

Comments
 (0)