Commit 44a0cd3

Authored by CyCle1024, jinminxi104 and tangzhiyi11
[ascend] add ascend graph mode (InternLM#2647)
* [pytorch] ascend enable atbgraph
* add paged prefill attention
* refine ascend-update-step-ctx (#26)
* fix: rewrite enable graph for ascend
* fix backend error due to folder refactor
* remove unnecessary comment
* fix rotary_embedding (#27)

Co-authored-by: CyCle1024 <[email protected]>
Co-authored-by: jinminxi104 <[email protected]>
Co-authored-by: tangzhiyi11 <[email protected]>
1 parent c25520a commit 44a0cd3

9 files changed: +315 / -29 lines changed
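The graph path added by this commit is opt-in: it is taken only when eager mode is disabled, tensor parallelism is not initialized (tp=1), and the model family is supported (see check_enable_graph in the new graph runner below). As a rough orientation, a user-side sketch of enabling it through the PyTorch engine might look like the following; `pipeline` and `PytorchEngineConfig` with `device_type` / `eager_mode` / `tp` fields are assumptions about lmdeploy's public API, not part of this diff.

```python
# Hedged, user-side sketch (not part of this diff) of opting into the Ascend
# graph path. `pipeline`/`PytorchEngineConfig` and the exact field names are
# assumptions about lmdeploy's public API; verify against your installed version.
from lmdeploy import pipeline, PytorchEngineConfig

backend_config = PytorchEngineConfig(
    device_type='ascend',  # route to the dlinfer/ascend backend
    eager_mode=False,      # False -> graph mode; True falls back to eager
    tp=1,                  # graph mode currently supports tp=1 only
)
pipe = pipeline('internlm/internlm2-chat-7b', backend_config=backend_config)
print(pipe(['Hello from Ascend graph mode!']))
```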
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.kernels.dlinfer.activation import silu_and_mul

from ..activation import SiluAndMulBuilder, SiluAndMulImpl


class DlinferSiluAndMulImpl(SiluAndMulImpl):
    """silu + multiple fused implementation."""

    def forward(self, x):
        """forward."""
        return silu_and_mul(x)


class DlinferSiluAndMulBuilder(SiluAndMulBuilder):
    """silu and mul implementation builder."""

    @staticmethod
    def build(inplace: bool = False):
        """build."""
        return DlinferSiluAndMulImpl()
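The implementation above simply delegates to the dlinfer `silu_and_mul` kernel. For readers unfamiliar with the op, here is a hedged reference sketch of what such a fused kernel conventionally computes (the input carries the gate and up projections concatenated on the last dimension); the function name below is made up for illustration:

```python
# Reference semantics (assumed, for illustration) of a fused silu_and_mul:
# split the last dim into gate/up halves and return silu(gate) * up.
import torch


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up


print(silu_and_mul_reference(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```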
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from importlib import import_module

import torch
import torch.distributed

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.utils import get_logger

from ...graph_runner import GraphRunner

logger = get_logger('lmdeploy')


class AscendGraphRunner(GraphRunner):
    """ascend graph runner."""

    def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
                 cache_config: CacheConfig, backend_config: BackendConfig,
                 device: torch.device):
        super().__init__(model, model_config, cache_config, backend_config,
                         device)

        self.enable_graph = self.check_enable_graph()
        if self.enable_graph:
            import dlinfer.graph
            dlinfer.graph.config.enable_graph_mode = True
            self.patch_kernels_custom_op()
            self.patch_kvcache_static_shape()
            self.model = torch.compile(self.model,
                                       fullgraph=True,
                                       dynamic=True,
                                       backend='atbgraph')

    def check_enable_graph(self):
        """check enable graph."""
        # eager_mode
        if self.backend_config.eager_mode:
            return False
        # tp
        if torch.distributed.is_initialized():
            warnings.warn(
                "Graph mode of device_type 'ascend' only supports tp=1 "
                'for now, fallback to eager mode', RuntimeWarning)
            return False
        # model support
        self.supported_model = {
            'Llama2': 'LlamaConfig',
            'InternLM2': 'InternLM2Config',
            'Qwen2': 'Qwen2Config',
        }
        is_model_support = True
        model_config_name = str(type(self.model_config.hf_config).__name__)
        if model_config_name not in self.supported_model.values():
            is_model_support = False
        if not is_model_support:
            warnings.warn(
                "Graph mode of device_type 'ascend' only supports models: "
                f"{', '.join(self.supported_model.keys())} when tp=1 for now",
                RuntimeWarning)
        return True

    def patch_kernels_custom_op(self):
        from dlinfer.graph.custom_op import register_custom_op
        dlinfer_kernels_module = import_module(
            'lmdeploy.pytorch.kernels.dlinfer')
        dlinfer_backends_module = import_module(
            'lmdeploy.pytorch.backends.dlinfer')

        # prefill_attention
        module_str = 'pagedattention'
        paged_attn_module = getattr(dlinfer_kernels_module, module_str)
        func_str = 'prefill_attention'
        prefill_attn_origin = getattr(paged_attn_module, func_str)
        prefill_attn_registered = register_custom_op(
            f'lmdeploy::{func_str}', ['attn_output'])(prefill_attn_origin)
        setattr(paged_attn_module, func_str, prefill_attn_registered)

        # apply_rotary_pos_emb
        def apply_rotary_emb_abstract_impl(q, k, cos, sin, q_out, k_out):
            result = [q, k]
            if q_out is not None:
                result[0] = q_out
            if k_out is not None:
                result[1] = k_out
            return tuple(result)

        module_str = 'apply_rotary_emb'
        apply_rotary_emb_module = getattr(dlinfer_backends_module, module_str)
        func_str = 'apply_rotary_pos_emb'
        apply_rotary_pos_emb_origin = getattr(apply_rotary_emb_module,
                                              func_str)
        apply_rotary_pos_emb_registered = register_custom_op(
            f'lmdeploy::{func_str}',
            impl_abstract_func=apply_rotary_emb_abstract_impl)(
                apply_rotary_pos_emb_origin)
        setattr(apply_rotary_emb_module, func_str,
                apply_rotary_pos_emb_registered)

    def patch_kvcache_static_shape(self):
        import torch._dynamo as dynamo
        from torch.utils._pytree import tree_map
        cache_engine_module = import_module(
            'lmdeploy.pytorch.engine.cache_engine')
        class_str = 'CacheEngine'
        cache_engine_class = getattr(cache_engine_module, class_str)
        func_str = 'allocate_gpu_cache'
        allocate_gpu_cache_origin = getattr(cache_engine_class, func_str)

        def allocate_gpu_cache_mark_static(self):
            gpu_cache = allocate_gpu_cache_origin(self)
            tree_map(lambda x: dynamo.mark_static(x), gpu_cache)
            return gpu_cache

        setattr(cache_engine_class, func_str, allocate_gpu_cache_mark_static)
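`register_custom_op` from dlinfer (used in `patch_kernels_custom_op` above) wraps a kernel as an opaque custom operator and, where needed, attaches an abstract ("fake") implementation so torch.compile can infer output shapes without tracing into the kernel. A minimal standalone sketch of the same idea with stock PyTorch's `torch.library` (assuming PyTorch >= 2.4; the `demo::scaled_silu` op is hypothetical and not part of this commit):

```python
# Sketch of the custom-op + fake-impl pattern that the patching above relies
# on, using stock torch.library instead of dlinfer.graph.custom_op.
# Assumes PyTorch >= 2.4; op name and body are illustrative only.
import torch


@torch.library.custom_op('demo::scaled_silu', mutates_args=())
def scaled_silu(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation: runs eagerly, never traced through by the compiler.
    return torch.nn.functional.silu(x) * scale


@scaled_silu.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Abstract/fake implementation: only shapes and dtypes matter here,
    # mirroring the role of apply_rotary_emb_abstract_impl above.
    return torch.empty_like(x)


compiled = torch.compile(lambda t: scaled_silu(t, 2.0), fullgraph=True)
print(compiled(torch.randn(4, 8)).shape)  # torch.Size([4, 8])
```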

lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py

Lines changed: 74 additions & 25 deletions
@@ -3,6 +3,7 @@
 
 import torch
 
+from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
 from lmdeploy.utils import get_logger
 
 from ..op_backend import DlinferOpsBackend
@@ -12,6 +13,9 @@
 
 class AscendOpsBackend(DlinferOpsBackend):
     """ascend layer backend."""
+    enable_graph = False
+    half_negative_inf = torch.finfo(torch.float16).min
+    total_slots = None
 
     @staticmethod
     def get_name() -> str:
@@ -45,21 +49,23 @@ def get_v_block_shape(
     @classmethod
     def update_step_context(cls, step_context):
         """update step context."""
+
+        def get_total_slots():
+            if cls.total_slots is None:
+                cls.total_slots = torch.arange(
+                    block_num * block_size,
+                    dtype=torch.long,
+                    device=step_context.block_offsets.device)
+                cls.total_slots = cls.total_slots.view(block_num, block_size)
+            return cls.total_slots
+
         kv_start_indices, attention_mask = [], []
         block_num, block_size, _ = step_context.kv_caches[0][0].shape
-        device = step_context.block_offsets.device
-
         is_unpaged_prefill = False
         if not step_context.is_decoding:
             is_unpaged_prefill = \
                 all((step_context.q_seqlens ==
                      step_context.kv_seqlens).tolist())
-
-        total_slots = torch.arange(block_num * block_size,
-                                   dtype=torch.long,
-                                   device=device)
-        total_slots = total_slots.view(block_num, block_size)
-
         q_seqlens_list = step_context.q_seqlens.tolist()
         kv_seqlens_list = step_context.kv_seqlens.tolist()
         max_q_seq_len = max(q_seqlens_list)
@@ -71,9 +77,9 @@ def update_step_context(cls, step_context):
 
             # collect kv start indices.
             history_length = kv_seq_len - q_seq_len
-            slot_tables = total_slots[step_context.block_offsets[i]].flatten()
-            slot_indices = [p for p in range(history_length, kv_seq_len)]
-            slots = slot_tables[slot_indices].reshape((-1, 1))
+            total_slots = get_total_slots()
+            slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
+            slots = slot_tables[history_length:kv_seq_len]
             kv_start_indices.append(slots)
 
             # collect attention mask of paged_prefill attention stage.
@@ -83,19 +89,19 @@
                     torch.ones(q_seq_len,
                                step_context.block_offsets.shape[1] *
                                block_size,
-                               dtype=torch.bool).cuda(),
+                               dtype=torch.bool,
+                               device=step_context.block_offsets.device),
                     diagonal=kv_seq_len - q_seq_len,
                 ))
                 attention_mask.append(single_attention_mask)
 
         kv_start_indices = torch.cat(kv_start_indices)
 
         if step_context.is_decoding:
-            # prepare somae params of paged_decode attention stage.
+            # prepare some params of paged_decode attention stage.
             q_start_loc_cpu, q_seqlens_cpu = None, None
-            kv_seqlens_cpu = step_context.kv_seqlens.cpu()
         elif is_unpaged_prefill:
-            # prepare somae params of unpaged_prefill attention stage.
+            # prepare some params of unpaged_prefill attention stage.
             q_start_loc_cpu, kv_seqlens_cpu = None, None
             q_seqlens_cpu = step_context.q_seqlens.cpu()
             single_attention_mask = torch.logical_not(
@@ -106,24 +112,54 @@ def update_step_context(cls, step_context):
                 ))
             attention_mask.append(single_attention_mask)
         else:
-            # prepare somae params of paged_prefill attention stage.
+            # prepare some params of paged_prefill attention stage.
            q_start_loc_cpu, q_seqlens_cpu = None, None
-            kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(
-                step_context.q_seqlens, 0).cpu()
-            block_offsets_int32 = step_context.block_offsets.to(torch.int32)
-            step_context.block_offsets = block_offsets_int32.repeat_interleave(
-                step_context.q_seqlens, 0)
-            attention_mask = [
-                torch.cat([mask for mask in attention_mask]).unsqueeze(1)
-            ]
+            attention_mask = [torch.cat([mask for mask in attention_mask])]
+
+        if cls.enable_graph:
+            kv_start_indices = kv_start_indices.view(-1).to(torch.int32)
+            import torch._dynamo as dynamo
+            if not is_unpaged_prefill:
+                step_context.block_offsets = step_context.block_offsets.to(
+                    torch.int32)
+                if not step_context.is_decoding:
+                    step_context.block_offsets = step_context.block_offsets\
+                        .repeat_interleave(step_context.q_seqlens, 0)
+                dynamo.mark_dynamic(step_context.block_offsets, [0, 1])
+            kv_seqlens = step_context.kv_seqlens.to(torch.int32)
+            if not step_context.is_decoding:
+                if is_unpaged_prefill:
+                    attention_mask = [mask.half() for mask in attention_mask]
+                else:
+                    attention_mask = [
+                        torch.cat([
+                            mask.half() * cls.half_negative_inf
+                            for mask in attention_mask
+                        ]).unsqueeze(1)
+                    ]
+                    kv_seqlens = kv_seqlens.repeat_interleave(
+                        step_context.q_seqlens, 0)
+        else:
+            if step_context.is_decoding:
+                kv_seqlens_cpu = step_context.kv_seqlens.cpu()
+            elif is_unpaged_prefill:
+                pass
+            else:
+                kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(
+                    step_context.q_seqlens, 0).cpu()
+                block_offsets_int32 = step_context.block_offsets.to(
+                    torch.int32)
+                step_context.block_offsets = block_offsets_int32\
+                    .repeat_interleave(step_context.q_seqlens, 0)
+            kv_seqlens = kv_seqlens_cpu
 
         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
             step_context.block_offsets,
             q_start_loc=q_start_loc_cpu,
             q_seqlens=q_seqlens_cpu,
-            kv_seqlens=kv_seqlens_cpu,
+            kv_seqlens=kv_seqlens,
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=attention_mask,
@@ -134,3 +170,16 @@ def update_step_context(cls, step_context):
 
         step_context.attn_metadata = attn_metadata
         return step_context
+
+    @staticmethod
+    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
+                           cache_config: CacheConfig,
+                           backend_config: BackendConfig,
+                           device: torch.device):
+        """build graph runner."""
+        from .graph_runner import AscendGraphRunner
+        ascend_graph_runner = AscendGraphRunner(model, model_config,
+                                                cache_config, backend_config,
+                                                device)
+        AscendOpsBackend.enable_graph = ascend_graph_runner.enable_graph
+        return ascend_graph_runner
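To make the refactored kv-start-indices logic above concrete, here is a small self-contained sketch (toy numbers, not from the commit) of how `get_total_slots` and the `slot_tables[history_length:kv_seq_len]` slice map a sequence's block offsets to flat KV-cache slot ids:

```python
# Toy illustration of the slot lookup used above: a cache with block_num=4
# blocks of block_size=2 gives 8 flat slots; a sequence assigned blocks [2, 0]
# with history_length=1 and kv_seq_len=3 writes its new tokens to slots [5, 0].
import torch

block_num, block_size = 4, 2
total_slots = torch.arange(block_num * block_size,
                           dtype=torch.long).view(block_num, block_size)
block_offsets = torch.tensor([2, 0])  # blocks owned by one sequence
history_length, kv_seq_len = 1, 3

slot_tables = total_slots[block_offsets].view(-1)  # tensor([4, 5, 0, 1])
slots = slot_tables[history_length:kv_seq_len]     # tensor([5, 0])
print(slot_tables, slots)
```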

lmdeploy/pytorch/backends/dlinfer/op_backend.py

Lines changed: 6 additions & 0 deletions
@@ -28,6 +28,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
         elif layer_type == OpType.ApplyRotaryEmb:
             from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder
             return DlinferApplyRotaryEmbBuilder
+        elif layer_type == OpType.SiluAndMul:
+            from .activation import DlinferSiluAndMulBuilder
+            return DlinferSiluAndMulBuilder
         elif layer_type == OpType.RMSNorm:
             from .norm import DlinferRMSNormBuilder
             return DlinferRMSNormBuilder
@@ -40,6 +43,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
         elif layer_type == OpType.LinearW4A16:
             from .awq_modules import AwqLinearW4A16Builder
             return AwqLinearW4A16Builder
+        elif layer_type == OpType.RotaryEmbedding:
+            from .rotary_embedding import DlinferRotaryEmbeddingBuilder
+            return DlinferRotaryEmbeddingBuilder
         else:
             logger.debug(
                 f'Op {layer_type} fallback to default implementation.')
