
Commit a88391a

feat: update infer_ext ops interface
1 parent 7fd57c9 commit a88391a

File tree

11 files changed: +40 -317 lines changed


lmdeploy/pytorch/engine/devices/ascend.py

Lines changed: 5 additions & 2 deletions

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch

-from .dipu import DIPUDeviceUtils
+from .base_device_utils import BaseDeviceUtils


-class ASCENDDeviceUtils(DIPUDeviceUtils):
+class ASCENDDeviceUtils(BaseDeviceUtils):

     device = 'ascend'

@@ -38,4 +38,7 @@ def update_step_context(cls, step_context):
             kv_start_indices, device=step_context.block_offsets.device)
         setattr(step_context, 'kv_start_indices', kv_start_indices)
         setattr(step_context, 'attention_mask', attention_mask)
+        is_unpaged_prefill = not step_context.is_decoding or all(
+            (step_context.q_seq_length == step_context.kv_seq_length).tolist())
+        setattr(step_context, 'is_unpaged_prefill', is_unpaged_prefill)
         return step_context
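For context, a minimal sketch of how the new is_unpaged_prefill predicate behaves, using bare tensors as stand-ins for the engine's step_context fields (the standalone variables below are illustrative assumptions, not the engine's actual objects):

import torch

# Stand-ins for step_context fields; in the engine these are populated
# by update_step_context above.
is_decoding = False
q_seq_length = torch.tensor([4, 7, 3])
kv_seq_length = torch.tensor([4, 7, 3])

# Mirrors the predicate added in this commit: the batch counts as an
# unpaged prefill when it is not a decode step, or when every request's
# query length equals its cached KV length.
is_unpaged_prefill = not is_decoding or all(
    (q_seq_length == kv_seq_length).tolist())

print(is_unpaged_prefill)  # True -> the unpaged prefill_attention path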

lmdeploy/pytorch/engine/devices/dipu.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ def apply_rotary_pos_emb(
     cached_cos = context.cos if context else cos
     cached_sin = context.sin if context else sin
     ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
-                                 cached_cos, cached_sin, None, None, None)
+                                 cached_cos, cached_sin, None, None)
     if q_embed is None:
         q_embed = query_states
     else:
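The same one-argument trim lands in fused_rotary_emb.py below: ext_ops.apply_rotary_pos_emb now takes six positional arguments instead of seven. A hedged sketch of the narrowed call shape, using a stub in place of the real infer_ext operator (parameter names and tensor shapes are assumptions for illustration only):

import torch

def apply_rotary_pos_emb_stub(query_states, key_states, cos, sin,
                              extra_a, extra_b):
    # Stub only; the real kernel applies the rotary embedding in place
    # on the device. Both trailing slots are passed as None here.
    assert extra_a is None and extra_b is None
    return query_states, key_states

# Assumed (tokens, heads, head_dim) layout for the reshaped states.
q = torch.randn(16, 8, 128)
k = torch.randn(16, 8, 128)
cos = torch.randn(16, 1, 128)
sin = torch.randn(16, 1, 128)

# Six positional arguments after this commit; the seventh None is gone.
apply_rotary_pos_emb_stub(q, k, cos, sin, None, None)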

lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ def fused_rotary_emb(
     cached_cos = context.cos if context else cos
     cached_sin = context.sin if context else sin
     ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
-                                 cached_cos, cached_sin, None, None, None)
+                                 cached_cos, cached_sin, None, None)
     if out_q is None:
         out_q = query_states
     else:

lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py

Lines changed: 33 additions & 33 deletions

@@ -21,43 +21,41 @@ def flash_context_attention(
 ):
     num_q_heads, dim = query_states.shape[1:3]
     num_kv_heads = value_states.shape[1]
-    batch = q_start_loc.shape[0]

-    for i in range(batch):
-        if torch.equal(q_seq_len[i], kv_seq_len[i]):
-            ext_ops.context_attention(
-                query_states,
-                key_states,
-                value_states,
-                q_start_loc[i:i + 1],
-                q_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len,
-                                              num_kv_heads * dim)
-            ext_ops.paged_prefill_attention(
-                query_states,
-                key_cache,
-                value_cache,
-                block_offsets,
-                block_size,
-                q_start_loc[i:i + 1],
-                q_seq_len[i:i + 1],
-                kv_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
+    if context.is_unpaged_prefill:
+        ext_ops.prefill_attention(
+            query_states,
+            key_states,
+            value_states,
+            q_start_loc,
+            q_seq_len,
+            context.max_q_seq_length,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
+    else:
+        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        ext_ops.paged_prefill_attention(
+            query_states,
+            key_cache,
+            value_cache,
+            block_offsets,
+            block_size,
+            q_start_loc,
+            q_seq_len,
+            kv_seq_len,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )


 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
-                          block_offsets, block_size):
+                          max_kv_seq_len, block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
         q,
@@ -66,6 +64,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         block_offsets,
         block_size,
         kv_seq_len,
+        max_kv_seq_len,
         num_q_heads,
         num_kv_heads,
         attn_output=attn_output.view(q.shape),
@@ -115,6 +114,7 @@ def paged_attention_fwd(
         v,
         attn_output,
         kv_seqlens,
+        context.max_kv_seq_length,
         block_offsets,
         block_size,
     )
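Taken together, this rewrite replaces the old per-request Python loop with a single whole-batch kernel call chosen by the precomputed context.is_unpaged_prefill flag. A runnable sketch of that dispatch shape, with stub kernels standing in for infer_ext (the stubs, tensor shapes, and the locally derived kv_cache_len are assumptions):

import torch
from types import SimpleNamespace

def prefill_attention(q, k, v, q_start_loc, q_seq_len, max_q_seq_length,
                      num_q_heads, num_kv_heads, attn_mask, attn_output):
    attn_output.copy_(q)  # stub for the fused unpaged-prefill kernel

def paged_prefill_attention(q, k_cache, v_cache, block_offsets, block_size,
                            q_start_loc, q_seq_len, kv_seq_len, num_q_heads,
                            num_kv_heads, attn_mask, attn_output):
    attn_output.copy_(q)  # stub for the fused paged-prefill kernel

def flash_context_attention(ctx, q, k, v, k_cache, v_cache, block_offsets,
                            block_size, q_start_loc, q_seq_len, kv_seq_len,
                            attn_output):
    num_q_heads, dim = q.shape[1:3]
    num_kv_heads = v.shape[1]
    if ctx.is_unpaged_prefill:
        # One batched call instead of a loop over requests.
        prefill_attention(q, k, v, q_start_loc, q_seq_len,
                          ctx.max_q_seq_length, num_q_heads, num_kv_heads,
                          attn_mask=ctx.attention_mask,
                          attn_output=attn_output)
    else:
        kv_cache_len = k_cache.shape[0]  # assumed cache layout
        k_cache = k_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
        v_cache = v_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
        paged_prefill_attention(q, k_cache, v_cache, block_offsets,
                                block_size, q_start_loc, q_seq_len,
                                kv_seq_len, num_q_heads, num_kv_heads,
                                attn_mask=ctx.attention_mask,
                                attn_output=attn_output)

# Smoke test with made-up shapes: the unpaged path just echoes q here.
ctx = SimpleNamespace(is_unpaged_prefill=True, max_q_seq_length=8,
                      attention_mask=None)
q = torch.randn(8, 4, 16)
out = torch.empty_like(q)
flash_context_attention(ctx, q, q.clone(), q.clone(), q.clone(), q.clone(),
                        None, 16, torch.tensor([0]), torch.tensor([8]),
                        torch.tensor([8]), out)
assert torch.equal(out, q)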

lmdeploy/pytorch/kernels/dipu/__init__.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py

Lines changed: 0 additions & 36 deletions
This file was deleted.
