@@ -150,41 +150,57 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
 
     def encode_prompt(
         self,
-        prompt,
         device,
         num_images_per_prompt,
         do_classifier_free_guidance,
+        prompt=None,
         negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
     ):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        attention_mask = text_inputs.attention_mask
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # get prompt text embeddings
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            attention_mask = text_inputs.attention_mask
 
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+                attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
 
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            text_encoder_output = self.text_encoder(
+                text_input_ids.to(device), attention_mask=attention_mask.to(device)
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-            attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
+            prompt_embeds = text_encoder_output.last_hidden_state
 
-        text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
-        text_encoder_hidden_states = text_encoder_output.last_hidden_state
-        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
 
-        uncond_text_encoder_hidden_states = None
-        if do_classifier_free_guidance:
+        if negative_prompt_embeds is None and do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
@@ -215,17 +231,17 @@ def encode_prompt(
                 uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
             )
 
-            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.last_hidden_state
 
+        if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_text_encoder_hidden_states.shape[1]
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
-                batch_size * num_images_per_prompt, seq_len, -1
-            )
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
             # done duplicates
 
-        return text_encoder_hidden_states, uncond_text_encoder_hidden_states
+        return prompt_embeds, negative_prompt_embeds
 
     def check_inputs(
         self,
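
Side note on the refactor above: since `prompt` is now optional, `encode_prompt` can be called on its own to precompute embeddings that are later passed back in through `prompt_embeds` / `negative_prompt_embeds`. A minimal sketch of that flow, assuming `pipe` is an already-loaded instance of this pipeline (the variable names and prompt strings are illustrative, not part of the diff):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Precompute conditional and unconditional embeddings once.
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        prompt="an astronaut riding a horse",
        negative_prompt="blurry, low quality",
    )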
@@ -264,13 +280,15 @@ def check_inputs(
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Optional[Union[str, List[str]]] = None,
         height: int = 1024,
         width: int = 1024,
         num_inference_steps: int = 60,
         timesteps: List[float] = None,
         guidance_scale: float = 8.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -304,6 +322,13 @@ def __call__(
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `decoder_guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
+                argument.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
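
The docstring entries above describe the intended use: tensors prepared ahead of time (e.g. for prompt weighting) stand in for the string prompts. Continuing the illustrative sketch with the hypothetical `pipe` instance, the precomputed embeddings from `encode_prompt` can be fed straight into the pipeline call; the exact output attribute depends on the concrete pipeline, so only the call itself is shown:

    # Reuse the precomputed embeddings; `prompt` and `negative_prompt` stay None.
    output = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        num_inference_steps=60,
        guidance_scale=8.0,
        num_images_per_prompt=1,
    )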
@@ -345,7 +370,13 @@ def __call__(
 
         # 2. Encode caption
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
 
         # For classifier free guidance, we need to do two forward passes.