@@ -261,60 +261,89 @@ def _execution_device(self):
         return self.device
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
         r"""
         Encodes the prompt into text encoder hidden states.
 
         Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
             device: (`torch.device`):
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
+                input argument.
         """
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
 
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+        if prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
             )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            attention_mask = text_inputs.attention_mask.to(device)
-        else:
-            attention_mask = None
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
 
-        text_embeddings = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
-        text_embeddings = text_embeddings[0]
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            prompt_embeds = self.text_encoder(
+                text_input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]
 
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
         # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
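The hunk above replaces the one-line batch-size computation with an explicit three-way rule, so callers who pass only `prompt_embeds` still get a correct batch size. The rule can be isolated as follows (a standalone sketch; `infer_batch_size` is a hypothetical helper, not part of the pipeline):

```python
from typing import List, Optional, Union

import torch


def infer_batch_size(
    prompt: Optional[Union[str, List[str]]],
    prompt_embeds: Optional[torch.FloatTensor],
) -> int:
    # A string prompt counts as a batch of 1, a list counts per element, and
    # when only embeddings are passed, the leading tensor dim is the batch size.
    if prompt is not None and isinstance(prompt, str):
        return 1
    elif prompt is not None and isinstance(prompt, list):
        return len(prompt)
    # check_inputs guarantees prompt_embeds is set whenever prompt is None
    return prompt_embeds.shape[0]


assert infer_batch_size("a photo of a cat", None) == 1
assert infer_batch_size(["a cat", "a dog"], None) == 2
assert infer_batch_size(None, torch.zeros(4, 77, 768)) == 4
```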
@@ -334,7 +363,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             else:
                 uncond_tokens = negative_prompt
 
-            max_length = text_input_ids.shape[-1]
+            max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
@@ -348,26 +377,32 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             else:
                 attention_mask = None
 
-            uncond_embeddings = self.text_encoder(
+            negative_prompt_embeds = self.text_encoder(
                 uncond_input.input_ids.to(device),
                 attention_mask=attention_mask,
             )
-            uncond_embeddings = uncond_embeddings[0]
+            negative_prompt_embeds = negative_prompt_embeds[0]
 
+        if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
-        return text_embeddings
+        return prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
-    def check_inputs(self, prompt, strength, callback_steps):
+    def check_inputs(
+        self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+    ):
-        if not isinstance(prompt, str) and not isinstance(prompt, list):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
@@ -382,6 +417,32 @@ def check_inputs(self, prompt, strength, callback_steps):
                 f" {type(callback_steps)}."
             )
 
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
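The new `check_inputs` branches make `prompt`/`prompt_embeds` (and their negative counterparts) mutually exclusive and require matching embedding shapes. A minimal standalone restatement of those rules (`validate_prompt_inputs` is a hypothetical name; the real error messages are longer):

```python
from typing import List, Optional, Union

import torch


def validate_prompt_inputs(
    prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
) -> None:
    # Exactly one of prompt / prompt_embeds must be supplied.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
    if prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")
    # The same exclusivity applies to the negative inputs.
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.")
    # When both embedding tensors are given, their shapes must agree.
    if (
        prompt_embeds is not None
        and negative_prompt_embeds is not None
        and prompt_embeds.shape != negative_prompt_embeds.shape
    ):
        raise ValueError("`prompt_embeds` and `negative_prompt_embeds` must have the same shape.")


validate_prompt_inputs(prompt="a cat")  # ok
validate_prompt_inputs(prompt_embeds=torch.zeros(1, 77, 768))  # ok
try:
    validate_prompt_inputs(prompt="a cat", prompt_embeds=torch.zeros(1, 77, 768))
except ValueError as err:
    print(err)  # Cannot forward both `prompt` and `prompt_embeds`.
```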
@@ -492,6 +553,7 @@ def __call__(
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -533,6 +595,13 @@ def __call__(
             generator (`torch.Generator`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
+                input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
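For callers, the payoff of `prompt_embeds` is that text encoding can be done once, tweaked arbitrarily, and reused; a plain string `prompt` cannot express something like blending two prompts. A hedged sketch, assuming `pipe` is an already-loaded Stable Diffusion-style pipeline (see the end-to-end example at the bottom of this section for loading code); the blend weights are illustrative:

```python
import torch


def encode(pipe, text: str) -> torch.FloatTensor:
    # Tokenize and run the text encoder; returns (1, seq_len, hidden_size).
    inputs = pipe.tokenizer(
        text,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return pipe.text_encoder(inputs.input_ids.to(pipe.device))[0]


# A 70/30 interpolation of two prompts, passed later via `prompt_embeds=...`.
prompt_embeds = 0.7 * encode(pipe, "a photo of a cat") + 0.3 * encode(pipe, "a watercolor painting")
```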
@@ -569,8 +638,14 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
-        source_text_embeddings = self._encode_prompt(
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            prompt_embeds=prompt_embeds,
+        )
+        source_prompt_embeds = self._encode_prompt(
             source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
         )
 
@@ -584,7 +659,7 @@ def __call__(
 
         # 6. Prepare latent variables
         latents, clean_latents = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
         )
         source_latents = latents
 
@@ -612,17 +687,17 @@ def __call__(
                     ],
                     dim=0,
                 )
-                concat_text_embeddings = torch.stack(
+                concat_prompt_embeds = torch.stack(
                     [
-                        source_text_embeddings[0],
-                        text_embeddings[0],
-                        source_text_embeddings[1],
-                        text_embeddings[1],
+                        source_prompt_embeds[0],
+                        prompt_embeds[0],
+                        source_prompt_embeds[1],
+                        prompt_embeds[1],
                     ],
                     dim=0,
                 )
                 concat_noise_pred = self.unet(
-                    concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
+                    concat_latent_model_input, t, encoder_hidden_states=concat_prompt_embeds
                 ).sample
 
                 # perform guidance
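The stacked batch keeps source and target predictions in a single UNet forward pass, ordered source-uncond, target-uncond, source-cond, target-cond (index 0 of each embedding tensor is the unconditional half produced by `_encode_prompt`). The unpacking happens just after this hunk and is not shown in the diff; a sketch of how it is typically written, with illustrative variable names:

```python
# Split the fused prediction back into its four components (continuation
# sketch; concat_noise_pred, guidance_scale and source_guidance_scale come
# from the surrounding pipeline code).
(
    source_noise_pred_uncond,
    noise_pred_uncond,
    source_noise_pred_text,
    noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)

# Standard classifier-free guidance, applied separately to source and target.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
    source_noise_pred_text - source_noise_pred_uncond
)
```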
@@ -662,7 +737,7 @@ def __call__(
         image = self.decode_latents(latents)
 
         # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
         # 11. Convert to PIL
         if output_type == "pil":
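Taken together, the changes let the whole pipeline run from precomputed embeddings. An end-to-end sketch, assuming this diff targets the CycleDiffusion pipeline (the `source_prompt` and `clean_latents` names suggest it); the checkpoint id, input image, and hyperparameters are illustrative, not prescribed by this commit:

```python
import torch
from PIL import Image
from diffusers import CycleDiffusionPipeline, DDIMScheduler

model_id = "CompVis/stable-diffusion-v1-4"  # illustrative checkpoint
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler)

init_image = Image.open("horse.png").convert("RGB").resize((512, 512))  # any local RGB image

# Encode the target prompt once, outside the pipeline call.
with torch.no_grad():
    text_inputs = pipe.tokenizer(
        "a photo of a zebra",
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids.to(pipe.device))[0]

result = pipe(
    prompt=None,                  # the precomputed embeddings stand in for the string prompt
    prompt_embeds=prompt_embeds,
    source_prompt="a photo of a horse",
    image=init_image,
    strength=0.8,
    num_inference_steps=100,
    eta=0.1,
    guidance_scale=2.0,
)
result.images[0].save("zebra.png")
```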