
Commit 71ac636

Author: Jianfeng Wang
Parent: c754911
Commit message: fix the mask shape mismatch


src/diffusers/models/attention.py

Lines changed: 11 additions & 1 deletion
@@ -613,7 +613,17 @@ def _attention(self, query, key, value, mask=None):
         )
         if mask is not None:
             # we assumed the mask is either 0 or -inf
-            attention_scores += mask
+            # ipdb> pp attention_scores.shape
+            # torch.Size([16, 4096, 17])
+            # ipdb> mask.shape
+            # torch.Size([2, 17])
+            origin_shape = attention_scores.shape
+            attention_scores = attention_scores.reshape(mask.shape[0],
+                                                        attention_scores.shape[0] // mask.shape[0],
+                                                        attention_scores.shape[1],
+                                                        attention_scores.shape[2])
+            attention_scores += mask.unsqueeze(1).unsqueeze(1)
+            attention_scores = attention_scores.reshape(*origin_shape)
         attention_probs = attention_scores.softmax(dim=-1)

         # cast back to the original dtype
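
Why the change works: attention_scores is flattened to (batch * heads, seq_len, kv_len) while the padding mask is (batch, kv_len), so the old `attention_scores += mask` cannot broadcast (leading dimensions 16 vs. 2 in the ipdb shapes above). The fix unfolds the head dimension out of the leading axis, broadcasts the mask over heads and query positions, then flattens back. Below is a minimal standalone sketch of the same trick, not the diffusers implementation; the helper name add_broadcast_mask and the choice of 8 heads (2 images × 8 heads = 16, matching the debug shapes) are assumptions for illustration:

import torch

def add_broadcast_mask(attention_scores, mask):
    # attention_scores: (batch * heads, seq_len, kv_len)
    # mask: (batch, kv_len), holding 0 for visible positions and -inf for masked ones
    origin_shape = attention_scores.shape
    batch = mask.shape[0]
    heads = origin_shape[0] // batch  # heads were folded into the leading dim
    scores = attention_scores.reshape(batch, heads, origin_shape[1], origin_shape[2])
    scores = scores + mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, kv_len) broadcasts
    return scores.reshape(*origin_shape)

# Shapes from the debug session above: batch 2, 8 heads, 4096 queries, 17 keys.
scores = torch.randn(16, 4096, 17)
mask = torch.zeros(2, 17)
mask[:, -3:] = float("-inf")  # hide the last 3 key/value positions
out = add_broadcast_mask(scores, mask)
assert out.shape == scores.shape
assert torch.isinf(out[:, :, -3:]).all()  # masked columns pushed to -inf

After the subsequent softmax(dim=-1), the -inf columns receive zero probability, so the padded tokens contribute nothing to the attention output.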
