@@ -156,15 +156,15 @@ def _encode_prompt(
156156 text_encoder_output = self .text_encoder (text_input_ids .to (device ))
157157
158158 prompt_embeds = text_encoder_output .text_embeds
159- text_encoder_hidden_states = text_encoder_output .last_hidden_state
159+ text_enc_hid_states = text_encoder_output .last_hidden_state
160160
161161 else :
162162 batch_size = text_model_output [0 ].shape [0 ]
163- prompt_embeds , text_encoder_hidden_states = text_model_output [0 ], text_model_output [1 ]
163+ prompt_embeds , text_enc_hid_states = text_model_output [0 ], text_model_output [1 ]
164164 text_mask = text_attention_mask
165165
166166 prompt_embeds = prompt_embeds .repeat_interleave (num_images_per_prompt , dim = 0 )
167- text_encoder_hidden_states = text_encoder_hidden_states .repeat_interleave (num_images_per_prompt , dim = 0 )
167+ text_enc_hid_states = text_enc_hid_states .repeat_interleave (num_images_per_prompt , dim = 0 )
168168 text_mask = text_mask .repeat_interleave (num_images_per_prompt , dim = 0 )
169169
170170 if do_classifier_free_guidance :
@@ -181,17 +181,17 @@ def _encode_prompt(
181181 negative_prompt_embeds_text_encoder_output = self .text_encoder (uncond_input .input_ids .to (device ))
182182
183183 negative_prompt_embeds = negative_prompt_embeds_text_encoder_output .text_embeds
184- uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output .last_hidden_state
184+ uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output .last_hidden_state
185185
186186 # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
187187
188188 seq_len = negative_prompt_embeds .shape [1 ]
189189 negative_prompt_embeds = negative_prompt_embeds .repeat (1 , num_images_per_prompt )
190190 negative_prompt_embeds = negative_prompt_embeds .view (batch_size * num_images_per_prompt , seq_len )
191191
192- seq_len = uncond_text_encoder_hidden_states .shape [1 ]
193- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states .repeat (1 , num_images_per_prompt , 1 )
194- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states .view (
192+ seq_len = uncond_text_enc_hid_states .shape [1 ]
193+ uncond_text_enc_hid_states = uncond_text_enc_hid_states .repeat (1 , num_images_per_prompt , 1 )
194+ uncond_text_enc_hid_states = uncond_text_enc_hid_states .view (
195195 batch_size * num_images_per_prompt , seq_len , - 1
196196 )
197197 uncond_text_mask = uncond_text_mask .repeat_interleave (num_images_per_prompt , dim = 0 )
@@ -202,11 +202,11 @@ def _encode_prompt(
202202 # Here we concatenate the unconditional and text embeddings into a single batch
203203 # to avoid doing two forward passes
204204 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ])
205- text_encoder_hidden_states = torch .cat ([uncond_text_encoder_hidden_states , text_encoder_hidden_states ])
205+ text_enc_hid_states = torch .cat ([uncond_text_enc_hid_states , text_enc_hid_states ])
206206
207207 text_mask = torch .cat ([uncond_text_mask , text_mask ])
208208
209- return prompt_embeds , text_encoder_hidden_states , text_mask
209+ return prompt_embeds , text_enc_hid_states , text_mask
210210
211211 @torch .no_grad ()
212212 def __call__ (
@@ -293,7 +293,7 @@ def __call__(
293293
294294 do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
295295
296- prompt_embeds , text_encoder_hidden_states , text_mask = self ._encode_prompt (
296+ prompt_embeds , text_enc_hid_states , text_mask = self ._encode_prompt (
297297 prompt , device , num_images_per_prompt , do_classifier_free_guidance , text_model_output , text_attention_mask
298298 )
299299
@@ -321,7 +321,7 @@ def __call__(
321321 latent_model_input ,
322322 timestep = t ,
323323 proj_embedding = prompt_embeds ,
324- encoder_hidden_states = text_encoder_hidden_states ,
324+ encoder_hidden_states = text_enc_hid_states ,
325325 attention_mask = text_mask ,
326326 ).predicted_image_embedding
327327
@@ -352,10 +352,10 @@ def __call__(
352352
353353 # decoder
354354
355- text_encoder_hidden_states , additive_clip_time_embeddings = self .text_proj (
355+ text_enc_hid_states , additive_clip_time_embeddings = self .text_proj (
356356 image_embeddings = image_embeddings ,
357357 prompt_embeds = prompt_embeds ,
358- text_encoder_hidden_states = text_encoder_hidden_states ,
358+ text_encoder_hidden_states = text_enc_hid_states ,
359359 do_classifier_free_guidance = do_classifier_free_guidance ,
360360 )
361361
@@ -377,7 +377,7 @@ def __call__(
377377
378378 decoder_latents = self .prepare_latents (
379379 (batch_size , num_channels_latents , height , width ),
380- text_encoder_hidden_states .dtype ,
380+ text_enc_hid_states .dtype ,
381381 device ,
382382 generator ,
383383 decoder_latents ,
@@ -391,7 +391,7 @@ def __call__(
391391 noise_pred = self .decoder (
392392 sample = latent_model_input ,
393393 timestep = t ,
394- encoder_hidden_states = text_encoder_hidden_states ,
394+ encoder_hidden_states = text_enc_hid_states ,
395395 class_labels = additive_clip_time_embeddings ,
396396 attention_mask = decoder_text_mask ,
397397 ).sample
0 commit comments