Skip to content

Commit 6a89a6c

Browse files
authored
Update custom diffusion attn processor (huggingface#5663)
update custom diffusion attn processor
1 parent 9bafef3 commit 6a89a6c

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,7 @@ def __call__(
13611361
hidden_states = attn.to_out[0](hidden_states)
13621362
# dropout
13631363
hidden_states = attn.to_out[1](hidden_states)
1364+
13641365
return hidden_states
13651366

13661367

@@ -1433,8 +1434,11 @@ def __call__(
14331434
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
14341435

14351436
if self.train_kv:
1436-
key = self.to_k_custom_diffusion(encoder_hidden_states)
1437-
value = self.to_v_custom_diffusion(encoder_hidden_states)
1437+
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
1438+
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
1439+
key = key.to(attn.to_q.weight.dtype)
1440+
value = value.to(attn.to_q.weight.dtype)
1441+
14381442
else:
14391443
key = attn.to_k(encoder_hidden_states)
14401444
value = attn.to_v(encoder_hidden_states)

0 commit comments

Comments
 (0)