@@ -325,27 +325,18 @@ def forward_linux_triton(
         key_states[:, :, :, self.qk_nope_head_dim :] = k_pe.view(bsz, kv_seq_len, 1, -1)

         value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
+        value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

-        # for bsz = 1
-        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
-        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
-        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
-
-        max_input_len = q_len
-
-        context_attention_fwd(
-            q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-            k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-            v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
-            o=attn_output,
-            b_start_loc=b_start_loc,
-            b_seq_len=b_seq_len,
-            max_input_len=max_input_len,
-            is_causal=True
+        attn_output = flash_attn_func(
+            query_states,
+            key_states,
+            value_states_padded,
+            softmax_scale=self.softmax_scale,
+            causal=True,
         )

         if self.q_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, : self.v_head_dim]
+            attn_output = attn_output[:, :, :, : self.v_head_dim]

         attn_output = attn_output.reshape(
             bsz, q_len, self.num_heads * self.v_head_dim
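
The new path calls FlashAttention's flash_attn_func, which (at least in flash-attn 2.x) expects q, k, and v to share the same head dimension; that is why value_states (head dim v_head_dim) is zero-padded up to the query head dim before the call and the padded tail is sliced off the output afterwards. Below is a minimal, self-contained sketch of that padding/slicing pattern, not code from this commit: the head dims 192/128 and the softmax scale are illustrative, and it assumes a CUDA device with the flash-attn package installed.

    import torch
    from flash_attn import flash_attn_func  # assumes flash-attn 2.x is installed

    # Illustrative shapes: queries/keys use a 192-dim head, values a 128-dim head.
    bsz, q_len, num_heads = 1, 16, 8
    q_head_dim, v_head_dim = 192, 128

    q = torch.randn(bsz, q_len, num_heads, q_head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(bsz, q_len, num_heads, q_head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(bsz, q_len, num_heads, v_head_dim, dtype=torch.float16, device="cuda")

    # flash_attn_func requires matching head dims, so zero-pad v on its last axis.
    v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)

    out = flash_attn_func(q, k, v_padded, softmax_scale=q_head_dim ** -0.5, causal=True)

    # The zero-padded columns of v only ever produce zeros in the output, so slicing
    # recovers the unpadded result (this mirrors attn_output[:, :, :, : self.v_head_dim] above).
    out = out[:, :, :, :v_head_dim]
    print(out.shape)  # torch.Size([1, 16, 8, 128])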