                                            SequenceClassifierOutput)
 from transformers.models.bert.modeling_bert import BertPreTrainedModel

+IMPL_USE_FLASH2 = False
 try:
-    import flash_attn_triton as flash_attn_triton
-    flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
+    import importlib
+
+    from flash_attn import flash_attn_qkvpacked_func
+    installed_version = importlib.metadata.version('flash_attn')
+    if installed_version < '2.4.2':
+        raise ImportError('newer version of flash_attn required (>= 2.4.2)')
+    IMPL_USE_FLASH2 = True
 except ImportError as e:
-    flash_attn_qkvpacked_func = None
+    warnings.warn(
+        f'Failed to import flash_attn. Will try to import triton implementation: {e}',
+        stacklevel=2)
+    try:
+        import flash_attn_triton as flash_attn_triton
+        flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
+    except ImportError as e:
+        flash_attn_qkvpacked_func = None
+        warnings.warn(f'Failed to import flash_attn_triton as a fallback: {e}',
+                      stacklevel=2)

 logger = logging.getLogger(__name__)

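A note on the version gate above: `importlib.metadata.version` returns a plain string, so `installed_version < '2.4.2'` compares lexicographically (for example, a hypothetical '2.10.0' would sort below '2.4.2'). A minimal sketch of an equivalent gate that parses the version instead, assuming the `packaging` library is available (it is not imported by this file), could look like this:

import importlib.metadata

from packaging import version

IMPL_USE_FLASH2 = False
try:
    from flash_attn import flash_attn_qkvpacked_func

    # Parse the installed version so that e.g. 2.10.x compares correctly
    # against the 2.4.2 minimum this patch requires.
    if version.parse(importlib.metadata.version('flash_attn')) < version.parse('2.4.2'):
        raise ImportError('flash_attn >= 2.4.2 is required')
    IMPL_USE_FLASH2 = True
except ImportError:
    flash_attn_qkvpacked_func = None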
@@ -183,7 +198,8 @@ def __init__(self, config):

     def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 max_seqlen_in_batch: int, indices: torch.Tensor,
-                attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+                attn_mask: torch.Tensor, bias: torch.Tensor,
+                slopes: torch.Tensor) -> torch.Tensor:
         """Perform self-attention.

         If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
@@ -201,6 +217,7 @@ def forward(hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
             indices: (total_nnz,)
             attn_mask: (batch, max_seqlen_in_batch)
             bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+            slopes: (heads) or (batch, heads)

         Returns:
             attention: (total_nnz, dim)
@@ -213,7 +230,8 @@ def forward(hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                         'b s (t h d) -> b s t h d',
                         t=3,
                         h=self.num_attention_heads)
-        if self.p_dropout or flash_attn_qkvpacked_func is None:
+        if (not IMPL_USE_FLASH2 and
+                self.p_dropout) or flash_attn_qkvpacked_func is None:
             # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
             q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
             k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
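The rewritten condition above routes around the fused kernels only when it has to: FlashAttention 2 handles attention dropout itself, so the dense PyTorch path is now reserved for the case where dropout is nonzero and FA2 is unavailable, or where no fused kernel could be imported at all. A small sketch of the resulting dispatch (the function name and return labels are illustrative, not from the patch):

def pick_attention_impl(impl_use_flash2: bool, p_dropout: float,
                        kernel_available: bool) -> str:
    """Mirror of the branch structure in the forward pass above."""
    if (not impl_use_flash2 and p_dropout) or not kernel_available:
        return 'pytorch'          # dense matmul with the full ALiBi bias tensor
    if impl_use_flash2:
        return 'flash_attn_2'     # packed QKV; dropout_p and alibi_slopes supported
    return 'flash_attn_triton'    # additive bias tensor; zero dropout only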
@@ -226,19 +244,41 @@ def forward(hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
             attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
                                                                  3)  # b s h d
         else:
-            # Triton implementation only supports 0 attention dropout
-            convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
-            if convert_dtype:
-                # Triton implementation only supports fp16 and bf16
-                orig_dtype = qkv.dtype
-                qkv = qkv.to(torch.float16)
-                bias_dtype = bias.dtype
-                bias = bias.to(torch.float16)
-                attention = flash_attn_qkvpacked_func(qkv, bias)
-                attention = attention.to(orig_dtype)
-                bias = bias.to(bias_dtype)
+            if IMPL_USE_FLASH2:
+                assert 1 <= len(slopes.shape) <= 2, f'{slopes=}'
+                assert slopes.shape[
+                    -1] == self.num_attention_heads, f'{slopes=}'
+
+                # FlashAttention 2 supports attention dropout, but only fp16/bf16 inputs
+                convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
+                if convert_dtype:
+                    # FlashAttention 2 implementation only supports fp16 and bf16
+                    orig_dtype = qkv.dtype
+                    qkv = qkv.to(torch.float16)
+                    bias_dtype = bias.dtype
+                    bias = bias.to(torch.float16)
+
+                    attention = flash_attn_qkvpacked_func(
+                        qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
+                    attention = attention.to(orig_dtype)
+                    bias = bias.to(bias_dtype)
+                else:
+                    attention = flash_attn_qkvpacked_func(
+                        qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
             else:
-                attention = flash_attn_qkvpacked_func(qkv, bias)
+                # Triton implementation only supports 0 attention dropout
+                convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
+                if convert_dtype:
+                    # Triton implementation only supports fp16 and bf16
+                    orig_dtype = qkv.dtype
+                    qkv = qkv.to(torch.float16)
+                    bias_dtype = bias.dtype
+                    bias = bias.to(torch.float16)
+                    attention = flash_attn_qkvpacked_func(qkv, bias)
+                    attention = attention.to(orig_dtype)
+                    bias = bias.to(bias_dtype)
+                else:
+                    attention = flash_attn_qkvpacked_func(qkv, bias)

         # attn_mask is 1 for attend and 0 for don't
         attention = bert_padding_module.unpad_input_only(
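For reference, the FlashAttention-2 entry point used in the new branch takes packed QKV plus a per-head slope vector, instead of the dense (batch, heads, seq, seq) ALiBi bias that the Triton and PyTorch paths consume. A minimal, self-contained sketch of that call (shapes and values are illustrative; it needs a CUDA device and flash_attn >= 2.4.2 installed):

import torch
from flash_attn import flash_attn_qkvpacked_func

batch, seqlen, n_heads, head_dim = 2, 128, 8, 64
# Packed QKV: (batch, seqlen, 3, n_heads, head_dim); the kernel only accepts fp16/bf16.
qkv = torch.randn(batch, seqlen, 3, n_heads, head_dim,
                  dtype=torch.float16, device='cuda')
# One ALiBi slope per head, shape (n_heads,) or (batch, n_heads), in fp32.
slopes = torch.tensor([2.0**(-(i + 1)) for i in range(n_heads)],
                      dtype=torch.float32, device='cuda')
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, alibi_slopes=slopes)
print(out.shape)  # torch.Size([2, 128, 8, 64]) == (batch, seqlen, n_heads, head_dim)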
@@ -291,6 +331,7 @@ def forward(
         indices: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
+        slopes: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for scaled self-attention without padding.

@@ -303,9 +344,11 @@ def forward(
             indices: None or (total_nnz,)
             attn_mask: None or (batch, max_seqlen_in_batch)
             bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+            slopes: None or (batch, heads) or (heads,)
         """
+        assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
         self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
-                                attn_mask, bias)
+                                attn_mask, bias, slopes)
         if subset_idx is not None:
             return self.output(
                 bert_padding_module.index_first_axis(self_output, subset_idx),
@@ -379,6 +422,7 @@ def forward(
         indices: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
+        slopes: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for a BERT layer, including both attention and MLP.

@@ -391,9 +435,12 @@ def forward(
             indices: None or (total_nnz,)
             attn_mask: None or (batch, max_seqlen_in_batch)
             bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+            slopes: None or (batch, heads) or (heads,)
         """
+        assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
         attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
-                                          subset_idx, indices, attn_mask, bias)
+                                          subset_idx, indices, attn_mask, bias,
+                                          slopes)
         layer_output = self.mlp(attention_output)
         return layer_output

@@ -463,6 +510,7 @@ def get_slopes_power_of_2(n_heads: int) -> List[float]:
         relative_position = relative_position.unsqueeze(0).expand(
             n_heads, -1, -1)
         slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
+        self.slopes = slopes
         alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
         # [1, n_heads, max_token_length, max_token_length]
         alibi = alibi.unsqueeze(0)
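The new `self.slopes = slopes` line caches the per-head slope vector alongside the precomputed `self.alibi` bias so it can later be handed to FlashAttention 2. For reference, the slopes produced by `_get_alibi_head_slopes` for a power-of-two head count follow the standard ALiBi geometric sequence; a small sketch of that case (the helper name below is hypothetical, and non-power-of-two head counts interpolate a second sequence in the ALiBi recipe):

import math
from typing import List

def alibi_slopes_power_of_2(n_heads: int) -> List[float]:
    # slope_i = 2 ** (-8 * (i + 1) / n_heads), written as a geometric sequence.
    start = 2**(-(2**-(math.log2(n_heads) - 3)))
    return [start * start**i for i in range(n_heads)]

print(alibi_slopes_power_of_2(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]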
@@ -504,6 +552,7 @@ def forward(
         elif self.alibi.device != hidden_states.device:
             # Device catch-up
             self.alibi = self.alibi.to(hidden_states.device)
+            self.slopes = self.slopes.to(hidden_states.device)
         alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
         attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
         alibi_attn_mask = attn_bias + alibi_bias
@@ -517,7 +566,8 @@ def forward(
                                              None,
                                              indices,
                                              attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                                             bias=alibi_attn_mask,
+                                             slopes=self.slopes)
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             # Pad inputs and mask. It will insert back zero-padded tokens.
@@ -536,7 +586,8 @@ def forward(
                                              None,
                                              indices,
                                              attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                                             bias=alibi_attn_mask,
+                                             slopes=self.slopes)
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
@@ -547,7 +598,8 @@ def forward(
                                            subset_idx=subset_idx,
                                            indices=indices,
                                            attn_mask=attention_mask,
-                                           bias=alibi_attn_mask)
+                                           bias=alibi_attn_mask,
+                                           slopes=self.slopes)

         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
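One follow-up on the device handling above: `self.slopes`, like `self.alibi`, is stored as a plain tensor attribute, so the encoder has to move it by hand in the device catch-up branch. A sketch of an alternative, offered only as an assumption about a possible refactor and not something in this patch, is to register the slopes as a non-persistent buffer so that `Module.to()` carries them along automatically:

import torch
from torch import nn

class AlibiSlopes(nn.Module):  # hypothetical illustration, not a class from the patch
    def __init__(self, slopes: torch.Tensor):
        super().__init__()
        # persistent=False keeps the slopes out of the state_dict, matching the
        # current plain-attribute behavior while letting .to()/.cuda() move them.
        self.register_buffer('slopes', slopes, persistent=False)

m = AlibiSlopes(torch.tensor([0.5, 0.25, 0.125, 0.0625]))
m = m.to('cpu')  # buffers follow the module across devices and dtypes
print(m.slopes)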