
Commit e528255

Merge branch 'infer_ext' into daoxin/support-cogvlm

2 parents: 70fd41c + 51ec61c

11 files changed (+48, -320 lines)

lmdeploy/pytorch/engine/devices/ascend.py

Lines changed: 5 additions & 2 deletions

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch

-from .dipu import DIPUDeviceUtils
+from .base_device_utils import BaseDeviceUtils


-class ASCENDDeviceUtils(DIPUDeviceUtils):
+class ASCENDDeviceUtils(BaseDeviceUtils):

     device = 'ascend'

@@ -38,4 +38,7 @@ def update_step_context(cls, step_context):
             kv_start_indices, device=step_context.block_offsets.device)
         setattr(step_context, 'kv_start_indices', kv_start_indices)
         setattr(step_context, 'attention_mask', attention_mask)
+        is_unpaged_prefill = (not step_context.is_decoding) and all(
+            (step_context.q_seq_length == step_context.kv_seq_length).tolist())
+        setattr(step_context, 'is_unpaged_prefill', is_unpaged_prefill)
         return step_context
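The new `is_unpaged_prefill` flag marks a prefill step in which no sequence carries cached history, i.e. every query length equals its key/value length, so attention can read the contiguous key/value tensors directly instead of gathering from the paged cache. A minimal sketch of the condition, with hypothetical tensors standing in for a real step_context:

import torch

# Hypothetical prefill batch (not from this commit): per-sequence counts
# of new query tokens vs. total KV tokens.
q_seq_length = torch.tensor([5, 7, 3])
kv_seq_length = torch.tensor([5, 7, 3])  # equal -> no cached prefix
is_decoding = False

# Same condition the commit stores on step_context.
is_unpaged_prefill = (not is_decoding) and all(
    (q_seq_length == kv_seq_length).tolist())
print(is_unpaged_prefill)  # True; any cached prefix makes kv_seq_length larger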

lmdeploy/pytorch/engine/devices/dipu.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py

Lines changed: 10 additions & 3 deletions

@@ -19,15 +19,22 @@ def apply_rotary_pos_emb(
     query_states_reshaped = query_states.reshape(1, bs, head, dim)
     key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim)
     if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        cos = cos[position_ids_1d].view(1, bs, 1, -1)
-        sin = sin[position_ids_1d].view(1, bs, 1, -1)
+        if len(cos.shape) == 3 and len(sin.shape) == 3:
+            cos = cos[:, position_ids_1d].view(1, bs, 1, -1)
+            sin = sin[:, position_ids_1d].view(1, bs, 1, -1)
+        elif len(cos.shape) == 2 and len(sin.shape) == 2:
+            cos = cos[position_ids_1d].view(1, bs, 1, -1)
+            sin = sin[position_ids_1d].view(1, bs, 1, -1)
+        else:
+            raise RuntimeError("Cannot handle cos/sin shape dims!")
+
         if context:
             setattr(context, 'cos', cos)
             setattr(context, 'sin', sin)
     cached_cos = context.cos if context else cos
     cached_sin = context.sin if context else sin
     ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
-                                 cached_cos, cached_sin, None, None, None)
+                                 cached_cos, cached_sin, None, None)
     if q_embed is None:
         q_embed = query_states
     else:
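The added branch accepts rotary tables that carry a leading unit dimension (shape [1, max_seq, dim]), presumably for the models this branch targets, alongside the original 2-D [max_seq, dim] layout; both paths gather the rows at position_ids_1d and reshape to (1, bs, 1, dim). A small sketch of the two gather paths, using made-up shapes:

import torch

bs, dim = 4, 8                        # made-up sizes for illustration
position_ids_1d = torch.tensor([0, 1, 2, 3])

cos_2d = torch.randn(16, dim)         # [max_seq, dim] layout
cos_3d = cos_2d.unsqueeze(0)          # [1, max_seq, dim] layout

a = cos_2d[position_ids_1d].view(1, bs, 1, -1)     # 2-D path
b = cos_3d[:, position_ids_1d].view(1, bs, 1, -1)  # 3-D path
assert torch.equal(a, b)              # both yield (1, bs, 1, dim)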

lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ def fused_rotary_emb(
     cached_cos = context.cos if context else cos
     cached_sin = context.sin if context else sin
     ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
-                                 cached_cos, cached_sin, None, None, None)
+                                 cached_cos, cached_sin, None, None)
     if out_q is None:
         out_q = query_states
     else:

lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py

Lines changed: 32 additions & 34 deletions

@@ -23,44 +23,40 @@ def flash_context_attention(
 ):
     num_q_heads, dim = query_states.shape[1:3]
     num_kv_heads = value_states.shape[1]
-    batch = q_start_loc.shape[0]

-    qkv_eq = query_states.shape[0] == key_states.shape[0]
-    for i in range(batch):
-        if qkv_eq:
-            ext_ops.context_attention(
-                query=query_states,
-                key=key_states,
-                value=value_states,
-                q_start_loc=q_start_loc[i:i + 1],
-                seq_len_list=q_seq_len_list[i:i + 1],
-                num_q_heads=num_q_heads,
-                num_kv_heads=num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len,
-                                              num_kv_heads * dim)
-            ext_ops.paged_prefill_attention(
-                query_states,
-                key_cache,
-                value_cache,
-                block_offsets,
-                block_size,
-                q_start_loc[i:i + 1],
-                q_seq_len_list[i:i + 1],
-                kv_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
+    if context.is_unpaged_prefill:
+        ext_ops.context_attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            q_start_loc=q_start_loc,
+            seq_len_list=q_seq_len_list,
+            num_q_heads=num_q_heads,
+            num_kv_heads=num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
+    else:
+        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        ext_ops.paged_prefill_attention(
+            query_states,
+            key_cache,
+            value_cache,
+            block_offsets,
+            block_size,
+            q_start_loc,
+            q_seq_len_list,
+            kv_seq_len,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )


 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
-                          block_offsets, block_size):
+                          max_kv_seq_len, block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
         query=q,
@@ -69,6 +65,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         block_table=block_offsets,
         block_size=block_size,
         kv_seq_len=kv_seq_len,
+        max_kv_seq_len=max_kv_seq_len,
         num_q_heads=num_q_heads,
         num_kv_heads=num_kv_heads,
         attn_output=attn_output.view(q.shape),
@@ -120,6 +117,7 @@ def paged_attention_fwd(
         v,
         attn_output,
         kv_seqlens,
+        context.max_kv_seq_length,
         block_offsets,
         block_size,
     )
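paged_token_attention now takes the batch-wide maximum KV length explicitly, and paged_attention_fwd forwards context.max_kv_seq_length, so the decode kernel does not have to re-reduce kv_seq_len on device each call. A hedged sketch of how such a value could be derived once per step (the real plumbing lives in the engine's step context, not in this file):

import torch

# Hypothetical per-sequence KV lengths for a decode batch.
kv_seq_len = torch.tensor([17, 33, 9])

# One host-side reduction, reusable by every attention call in the step.
max_kv_seq_length = int(kv_seq_len.max().item())
print(max_kv_seq_length)  # 33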

lmdeploy/pytorch/kernels/dipu/__init__.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py

Lines changed: 0 additions & 36 deletions
This file was deleted.
