File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed
src/diffusers/pipelines/pixart_alpha Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -156,6 +156,8 @@ def encode_prompt(
156156 mask_feature: (bool, defaults to `True`):
157157 If `True`, the function will mask the text embeddings.
158158 """
159+ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
160+
159161 if device is None :
160162 device = self ._execution_device
161163
@@ -253,7 +255,7 @@ def encode_prompt(
253255 negative_prompt_embeds = None
254256
255257 # Perform additional masking.
256- if mask_feature and prompt_embeds is None and negative_prompt_embeds is None :
258+ if mask_feature and not embeds_initially_provided :
257259 prompt_embeds = prompt_embeds .unsqueeze (1 )
258260 masked_prompt_embeds , keep_indices = self .mask_text_embeddings (prompt_embeds , prompt_embeds_attention_mask )
259261 masked_prompt_embeds = masked_prompt_embeds .squeeze (1 )
Original file line number Diff line number Diff line change @@ -120,7 +120,6 @@ def test_save_load_optional_components(self):
120120 "generator" : generator ,
121121 "num_inference_steps" : num_inference_steps ,
122122 "output_type" : output_type ,
123- "mask_feature" : False ,
124123 }
125124
126125 # set all optional components to None
@@ -155,7 +154,6 @@ def test_save_load_optional_components(self):
155154 "generator" : generator ,
156155 "num_inference_steps" : num_inference_steps ,
157156 "output_type" : output_type ,
158- "mask_feature" : False ,
159157 }
160158
161159 output_loaded = pipe_loaded (** inputs )[0 ]
You can’t perform that action at this time.
0 commit comments