Skip to content

Commit b2cfc7a

Browse files
authored
Fix slow tests (huggingface#689)
* revert using baddbmm in attention - to fix `test_stable_diffusion_memory_chunking` test * styling
1 parent 552b967 commit b2cfc7a

File tree

1 file changed

+2
-7
lines changed

1 file changed

+2
-7
lines changed

src/diffusers/models/attention.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -274,13 +274,8 @@ def forward(self, hidden_states, context=None, mask=None):
274274
return self.to_out(hidden_states)
275275

276276
def _attention(self, query, key, value):
277-
attention_scores = torch.baddbmm(
278-
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
279-
query,
280-
key.transpose(-1, -2),
281-
beta=0,
282-
alpha=self.scale,
283-
)
277+
# TODO: use baddbmm for better performance
278+
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
284279
attention_probs = attention_scores.softmax(dim=-1)
285280
# compute attention output
286281
hidden_states = torch.matmul(attention_probs, value)

0 commit comments

Comments
 (0)