1 file changed, +3 -3 lines changed

@@ -71,6 +71,7 @@ def __init__(
         self.proj_attn = nn.Linear(channels, channels, bias=True)
 
         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None
 
     def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
@@ -142,9 +143,8 @@ def forward(self, hidden_states):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        use_torch_2_0_attn = (
-            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
-        )
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
 
         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
         key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
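For context: the change adds a `_use_2_0_attn` escape hatch, so the block dispatches to PyTorch 2.0's `F.scaled_dot_product_attention` only when the function exists, the flag is enabled, and xformers memory-efficient attention is not active. Below is a minimal standalone sketch of that dispatch; the tensor shapes, variable names, and the manual fallback path are illustrative assumptions, not the file's actual code:

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical flags mirroring the diff:
_use_2_0_attn = True   # the new escape-hatch flag added in __init__
_use_xformers = False  # stands in for _use_memory_efficient_attention_xformers

# Same gating as the rewritten forward(): prefer the fused 2.0 kernel
# only if it exists and nothing else was explicitly requested.
use_torch_2_0_attn = (
    hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn and not _use_xformers
)

# Assumed shapes for illustration.
batch, heads, seq, head_dim = 2, 4, 16, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

if use_torch_2_0_attn:
    # Fused kernel keeps the (batch, heads, seq, head_dim) layout, which is
    # why the diff passes merge_head_and_batch=not use_torch_2_0_attn.
    out = F.scaled_dot_product_attention(q, k, v)
else:
    # Manual fallback: merge heads into the batch dim, then
    # softmax(Q K^T * scale) V, with scale = 1 / sqrt(head_dim)
    # as in the diff's `scale` line.
    scale = 1 / math.sqrt(head_dim)
    q2 = q.reshape(batch * heads, seq, head_dim)
    k2 = k.reshape(batch * heads, seq, head_dim)
    v2 = v.reshape(batch * heads, seq, head_dim)
    scores = torch.baddbmm(
        torch.empty(q2.shape[0], q2.shape[1], k2.shape[1], dtype=q2.dtype),
        q2,
        k2.transpose(-1, -2),
        beta=0,       # beta=0: the empty input tensor is ignored
        alpha=scale,
    )
    attn = scores.softmax(dim=-1)
    out = torch.bmm(attn, v2).reshape(batch, heads, seq, head_dim)

print(out.shape)  # torch.Size([2, 4, 16, 8])
```

Keeping the 4-D layout on the fused path avoids a pair of reshapes, while the flag still lets callers fall back to the merged-batch path for benchmarking or debugging.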