File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -1361,6 +1361,7 @@ def __call__(
13611361 hidden_states = attn .to_out [0 ](hidden_states )
13621362 # dropout
13631363 hidden_states = attn .to_out [1 ](hidden_states )
1364+
13641365 return hidden_states
13651366
13661367
@@ -1433,8 +1434,11 @@ def __call__(
14331434 encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
14341435
14351436 if self .train_kv :
1436- key = self .to_k_custom_diffusion (encoder_hidden_states )
1437- value = self .to_v_custom_diffusion (encoder_hidden_states )
1437+ key = self .to_k_custom_diffusion (encoder_hidden_states .to (self .to_k_custom_diffusion .weight .dtype ))
1438+ value = self .to_v_custom_diffusion (encoder_hidden_states .to (self .to_v_custom_diffusion .weight .dtype ))
1439+ key = key .to (attn .to_q .weight .dtype )
1440+ value = value .to (attn .to_q .weight .dtype )
1441+
14381442 else :
14391443 key = attn .to_k (encoder_hidden_states )
14401444 value = attn .to_v (encoder_hidden_states )
You can’t perform that action at this time.
0 commit comments