@@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
             text_input_ids = text_input_ids_list[i]
 
         prompt_embeds = text_encoder(
-            text_input_ids.to(text_encoder.device),
-            output_hidden_states=True,
+            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
         )
 
         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds.hidden_states[-2]
+        prompt_embeds = prompt_embeds[-1][-2]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
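Why `prompt_embeds[-1][-2]` replaces `prompt_embeds.hidden_states[-2]`: with `return_dict=False`, the CLIP text encoder returns a plain tuple instead of a `ModelOutput`, and when `output_hidden_states=True` the hidden-states tuple is the last element of that tuple. A minimal sketch of the access pattern, assuming the SDXL base checkpoint layout this script trains against (checkpoint name and prompt are illustrative):

```python
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection

repo = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed SDXL base checkpoint
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
text_encoder = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")

text_input_ids = tokenizer(
    "a photo of a dog",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    # Tuple layout: (text_embeds, last_hidden_state, hidden_states)
    prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True, return_dict=False)

pooled = prompt_embeds[0]            # projected pooled embedding, formerly .text_embeds
penultimate = prompt_embeds[-1][-2]  # hidden_states is last; take the second-to-last layer
print(pooled.shape, penultimate.shape)  # torch.Size([1, 1280]) torch.Size([1, 77, 1280])
```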
@@ -1429,7 +1428,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     timesteps,
                     prompt_embeds_input,
                     added_cond_kwargs=unet_added_conditions,
-                ).sample
+                    return_dict=False,
+                )[0]
             else:
                 unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
                 prompt_embeds, pooled_prompt_embeds = encode_prompt(
@@ -1443,8 +1443,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 )
                 prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                 model_pred = unet(
-                    noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
-                ).sample
+                    noisy_model_input,
+                    timesteps,
+                    prompt_embeds_input,
+                    added_cond_kwargs=unet_added_conditions,
+                    return_dict=False,
+                )[0]
 
             # Get the target for loss depending on the prediction type
             if noise_scheduler.config.prediction_type == "epsilon":
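The same pattern applies to the UNet: with `return_dict=False`, `UNet2DConditionModel.forward` returns the plain tuple `(sample,)`, so indexing `[0]` replaces the old `.sample` attribute access. A minimal sketch with a tiny randomly initialized UNet (the config below is an illustrative small test-style config, not the SDXL one, so `added_cond_kwargs` is omitted):

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny illustrative config; the script itself uses the pretrained SDXL UNet.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
).eval()

noisy_model_input = torch.randn(1, 4, 32, 32)
timesteps = torch.tensor([10])
prompt_embeds_input = torch.randn(1, 77, 32)

with torch.no_grad():
    model_pred = unet(noisy_model_input, timesteps, prompt_embeds_input, return_dict=False)[0]
    # Equivalent to the old attribute access on the ModelOutput:
    assert torch.equal(model_pred, unet(noisy_model_input, timesteps, prompt_embeds_input).sample)
```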