@@ -60,7 +60,6 @@ def __init__(
         self.channels = channels
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.num_head_size = num_head_channels
         self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
 
         # define q,k,v as linear layers
@@ -74,18 +73,25 @@ def __init__(
         self._use_memory_efficient_attention_xformers = False
         self._attention_op = None
 
-    def reshape_heads_to_batch_dim(self, tensor):
+    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+        if merge_head_and_batch:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
         return tensor
 
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+
+        if unmerge_head_and_batch:
+            batch_size, seq_len, dim = tensor.shape
+            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        else:
+            batch_size, _, seq_len, dim = tensor.shape
+
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
         return tensor
 
     def set_use_memory_efficient_attention_xformers(
@@ -134,14 +140,25 @@ def forward(self, hidden_states):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        use_torch_2_0_attn = (
+            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
+        )
+
+        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
             hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
+            )
+            hidden_states = hidden_states.to(query_proj.dtype)
+        elif use_torch_2_0_attn:
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
             )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
@@ -162,7 +179,7 @@ def forward(self, hidden_states):
         hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
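
For context on why the batch/head merge is skipped on the new path: `F.scaled_dot_product_attention` operates on `(batch, num_heads, seq_len, head_dim)` tensors and applies its default `1 / sqrt(head_dim)` scaling, which equals the `scale` computed above as `1 / sqrt(channels / num_heads)`; that is why the TODO defers passing an explicit scale until Torch 2.1. Below is a minimal standalone sketch (not part of the diff, with made-up tensor sizes and a simplified softmax path instead of the block's baddbmm) comparing the two routes:

```python
import math

import torch
import torch.nn.functional as F  # requires torch >= 2.0 for scaled_dot_product_attention

batch, heads, seq_len, head_dim = 2, 4, 16, 32  # made-up sizes
q = torch.randn(batch, seq_len, heads * head_dim)
k = torch.randn(batch, seq_len, heads * head_dim)
v = torch.randn(batch, seq_len, heads * head_dim)


def split_heads(t, merge_head_and_batch):
    # mirrors reshape_heads_to_batch_dim: (B, S, H*D) -> (B, H, S, D),
    # optionally merging batch and head dims for the bmm path
    b, s, _ = t.shape
    t = t.reshape(b, s, heads, head_dim).permute(0, 2, 1, 3)
    return t.reshape(b * heads, s, head_dim) if merge_head_and_batch else t


scale = 1 / math.sqrt(head_dim)  # same value as 1 / sqrt(channels / num_heads)

# original path: (B*H, S, D) tensors, explicit softmax(Q K^T * scale) @ V
q3, k3, v3 = (split_heads(t, True) for t in (q, k, v))
probs = (torch.bmm(q3, k3.transpose(-1, -2)) * scale).softmax(dim=-1)
out_bmm = torch.bmm(probs, v3).reshape(batch, heads, seq_len, head_dim)

# torch 2.0 path: (B, H, S, D) tensors, fused kernel, default 1/sqrt(D) scaling
q4, k4, v4 = (split_heads(t, False) for t in (q, k, v))
out_sdpa = F.scaled_dot_product_attention(q4, k4, v4, dropout_p=0.0, is_causal=False)

print(torch.allclose(out_bmm, out_sdpa, atol=1e-5))  # expected: True
```

The two outputs agree up to floating-point tolerance, so the reshape helpers only need the new flags to keep the head dimension separate on the SDPA branch; `reshape_batch_dim_to_heads(..., unmerge_head_and_batch=False)` then folds `(B, H, S, D)` straight back to `(B, S, H*D)`.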