
Commit c51818c

Merge pull request #902 from kvcache-ai/rollback-triton-prefill
rollback-triton-prefill
2 parents: bda9cf1 + 3934b9d

File tree

1 file changed: +8 -17 lines


ktransformers/operators/attention.py

Lines changed: 8 additions & 17 deletions
@@ -325,27 +325,18 @@ def forward_linux_triton(
         key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
 
         value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
+        value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
 
-        # for bsz = 1
-        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
-        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
-        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
-
-        max_input_len = q_len
-
-        context_attention_fwd(
-            q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-            k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-            v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
-            o=attn_output,
-            b_start_loc=b_start_loc,
-            b_seq_len=b_seq_len,
-            max_input_len=max_input_len,
-            is_causal=True
+        attn_output = flash_attn_func(
+            query_states,
+            key_states,
+            value_states_padded,
+            softmax_scale=self.softmax_scale,
+            causal=True,
         )
 
         if self.q_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, : self.v_head_dim]
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
 
         attn_output = attn_output.reshape(
             bsz, q_len, self.num_heads * self.v_head_dim
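
For context: this commit rolls the prefill path back from the Triton context_attention_fwd kernel to flash_attn_func. flash_attn_func expects q, k, and v to share a single head dimension, while the value heads here (v_head_dim) are narrower than the query/key heads (q_head_dim), so v is zero-padded up to q_head_dim before the call and the padded tail of the output is sliced off afterwards. Below is a minimal, self-contained sketch of that pad-then-slice pattern, assuming the flash-attn package and a CUDA device; the concrete head dims (192/128) and the softmax scale are illustrative assumptions, not values taken from this diff (the real code passes self.softmax_scale).

# Minimal sketch of the pad-then-slice pattern shown in the diff above.
# Assumptions (not from the diff): head dims 192/128, scale = q_head_dim ** -0.5.
import torch
from flash_attn import flash_attn_func

bsz, q_len, num_heads = 1, 16, 4
q_head_dim, v_head_dim = 192, 128  # assumed MLA-style dims: q/k wider than v

q = torch.randn(bsz, q_len, num_heads, q_head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(bsz, q_len, num_heads, q_head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(bsz, q_len, num_heads, v_head_dim, dtype=torch.float16, device="cuda")

# flash_attn_func requires q, k, v to have the same head dim,
# so zero-pad v on its last axis from v_head_dim up to q_head_dim.
v_padded = torch.nn.functional.pad(v, [0, q_head_dim - v_head_dim], value=0)

out = flash_attn_func(q, k, v_padded, softmax_scale=q_head_dim ** -0.5, causal=True)

# The padded output columns are exact zeros (attention weights applied to
# the zero padding of v), so slicing them off recovers the true result.
out = out[:, :, :, :v_head_dim]
print(out.shape)  # torch.Size([1, 16, 4, 128])

This also explains why the trailing slice gains a dimension in the diff: flash_attn_func returns a (batch, seqlen, heads, head_dim) tensor, whereas the Triton kernel wrote into a flattened (tokens, heads, head_dim) buffer, so the slice changes from three indices to four.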
