Skip to content

Commit 11d4d9a

Browse files
committed
return select index
1 parent 7eff1f4 commit 11d4d9a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def encode_prompt(
235235

236236
# See Section 3.1. of the paper. (???)
237237
max_length = max_sequence_length
238-
#select_index = [0] + list(range(-max_length + 1, 0))
238+
select_index = [0] + list(range(-max_length + 1, 0))
239239

240240
if prompt_embeds is None:
241241
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
@@ -263,8 +263,8 @@ def encode_prompt(
263263
prompt_attention_mask = prompt_attention_mask.to(device)
264264

265265
prompt_embeds = self.text_encoder(input_ids=text_input_ids.to(device), attention_mask=prompt_attention_mask)
266-
prompt_embeds = prompt_embeds[0]#[:, select_index]
267-
prompt_attention_mask = prompt_attention_mask#[:, select_index]
266+
prompt_embeds = prompt_embeds[0][:, select_index]
267+
prompt_attention_mask = prompt_attention_mask[:, select_index]
268268

269269
if self.transformer is not None:
270270
dtype = self.transformer.dtype

0 commit comments

Comments
 (0)