@@ -3,6 +3,7 @@
 
 import torch
 
+from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
 from lmdeploy.utils import get_logger
 
 from ..op_backend import DlinferOpsBackend
@@ -12,6 +13,9 @@
 
 class AscendOpsBackend(DlinferOpsBackend):
     """ascend layer backend."""
+    enable_graph = False
+    half_negative_inf = torch.finfo(torch.float16).min
+    total_slots = None
 
     @staticmethod
     def get_name() -> str:
@@ -45,21 +49,23 @@ def get_v_block_shape(
     @classmethod
     def update_step_context(cls, step_context):
         """update step context."""
+
+        def get_total_slots():
+            if cls.total_slots is None:
+                cls.total_slots = torch.arange(
+                    block_num * block_size,
+                    dtype=torch.long,
+                    device=step_context.block_offsets.device)
+                cls.total_slots = cls.total_slots.view(block_num, block_size)
+            return cls.total_slots
+
         kv_start_indices, attention_mask = [], []
         block_num, block_size, _ = step_context.kv_caches[0][0].shape
-        device = step_context.block_offsets.device
-
         is_unpaged_prefill = False
         if not step_context.is_decoding:
             is_unpaged_prefill = \
                 all((step_context.q_seqlens ==
                      step_context.kv_seqlens).tolist())
-
-        total_slots = torch.arange(block_num * block_size,
-                                   dtype=torch.long,
-                                   device=device)
-        total_slots = total_slots.view(block_num, block_size)
-
         q_seqlens_list = step_context.q_seqlens.tolist()
         kv_seqlens_list = step_context.kv_seqlens.tolist()
         max_q_seq_len = max(q_seqlens_list)
@@ -71,9 +77,9 @@ def update_step_context(cls, step_context):
 
             # collect kv start indices.
             history_length = kv_seq_len - q_seq_len
-            slot_tables = total_slots[step_context.block_offsets[i]].flatten()
-            slot_indices = [p for p in range(history_length, kv_seq_len)]
-            slots = slot_tables[slot_indices].reshape((-1, 1))
+            total_slots = get_total_slots()
+            slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
+            slots = slot_tables[history_length:kv_seq_len]
             kv_start_indices.append(slots)
 
             # collect attention mask of paged_prefill attention stage.
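Note (not part of the patch): the hunks above replace the per-call `total_slots` allocation with a class-level table cached by `get_total_slots()`, and gather the current step's KV-cache slots by slicing a flattened view instead of building a Python index list. Below is a minimal standalone sketch of that lookup; the sizes, block offsets, and lengths are made-up stand-ins for the real `step_context` fields.

```python
import torch

# Toy stand-ins for the real step_context fields (illustrative only).
block_num, block_size = 4, 2

# Cached once on the class in the patch: flat slot ids 0..block_num*block_size-1,
# viewed as one row of slot ids per physical block.
total_slots = torch.arange(block_num * block_size,
                           dtype=torch.long).view(block_num, block_size)

# One sequence owning physical blocks 2 and 0, with 1 cached token and
# 3 new tokens this step (kv_seq_len = 4).
block_offsets_i = torch.tensor([2, 0])
history_length, kv_seq_len = 1, 4

slot_tables = total_slots[block_offsets_i].view(-1)  # tensor([4, 5, 0, 1])
slots = slot_tables[history_length:kv_seq_len]       # tensor([5, 0, 1])
print(slots)  # flat cache slots where this step's new keys/values are written
```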
@@ -83,19 +89,19 @@ def update_step_context(cls, step_context):
                         torch.ones(q_seq_len,
                                    step_context.block_offsets.shape[1] *
                                    block_size,
-                                   dtype=torch.bool).cuda(),
+                                   dtype=torch.bool,
+                                   device=step_context.block_offsets.device),
                         diagonal=kv_seq_len - q_seq_len,
                     ))
                 attention_mask.append(single_attention_mask)
 
         kv_start_indices = torch.cat(kv_start_indices)
 
         if step_context.is_decoding:
-            # prepare somae params of paged_decode attention stage.
+            # prepare some params of paged_decode attention stage.
             q_start_loc_cpu, q_seqlens_cpu = None, None
-            kv_seqlens_cpu = step_context.kv_seqlens.cpu()
         elif is_unpaged_prefill:
-            # prepare somae params of unpaged_prefill attention stage.
+            # prepare some params of unpaged_prefill attention stage.
             q_start_loc_cpu, kv_seqlens_cpu = None, None
             q_seqlens_cpu = step_context.q_seqlens.cpu()
             single_attention_mask = torch.logical_not(
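Note (not part of the patch): the mask built in this region is a boolean causal mask whose diagonal is shifted by the number of already-cached tokens, so each new query token may attend to all history plus itself. A toy sketch of the resulting pattern, assuming the `torch.logical_not(torch.tril(...))` construction suggested by the fragments visible in these hunks (the wrapping calls sit partly outside the shown context):

```python
import torch

# Toy lengths: 3 new query tokens on top of 2 already-cached tokens.
q_seq_len, kv_seq_len = 3, 5

single_attention_mask = torch.logical_not(
    torch.tril(
        torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool),
        diagonal=kv_seq_len - q_seq_len,
    ))
print(single_attention_mask)
# True marks the future positions each query token must not attend to:
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])
```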
@@ -106,24 +112,54 @@ def update_step_context(cls, step_context):
                 ))
             attention_mask.append(single_attention_mask)
         else:
-            # prepare somae params of paged_prefill attention stage.
+            # prepare some params of paged_prefill attention stage.
             q_start_loc_cpu, q_seqlens_cpu = None, None
-            kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(
-                step_context.q_seqlens, 0).cpu()
-            block_offsets_int32 = step_context.block_offsets.to(torch.int32)
-            step_context.block_offsets = block_offsets_int32.repeat_interleave(
-                step_context.q_seqlens, 0)
-            attention_mask = [
-                torch.cat([mask for mask in attention_mask]).unsqueeze(1)
-            ]
+            attention_mask = [torch.cat([mask for mask in attention_mask])]
+
+        if cls.enable_graph:
+            kv_start_indices = kv_start_indices.view(-1).to(torch.int32)
+            import torch._dynamo as dynamo
+            if not is_unpaged_prefill:
+                step_context.block_offsets = step_context.block_offsets.to(
+                    torch.int32)
+                if not step_context.is_decoding:
+                    step_context.block_offsets = step_context.block_offsets \
+                        .repeat_interleave(step_context.q_seqlens, 0)
+                dynamo.mark_dynamic(step_context.block_offsets, [0, 1])
+            kv_seqlens = step_context.kv_seqlens.to(torch.int32)
+            if not step_context.is_decoding:
+                if is_unpaged_prefill:
+                    attention_mask = [mask.half() for mask in attention_mask]
+                else:
+                    attention_mask = [
+                        torch.cat([
+                            mask.half() * cls.half_negative_inf
+                            for mask in attention_mask
+                        ]).unsqueeze(1)
+                    ]
+                    kv_seqlens = kv_seqlens.repeat_interleave(
+                        step_context.q_seqlens, 0)
+        else:
+            if step_context.is_decoding:
+                kv_seqlens_cpu = step_context.kv_seqlens.cpu()
+            elif is_unpaged_prefill:
+                pass
+            else:
+                kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(
+                    step_context.q_seqlens, 0).cpu()
+                block_offsets_int32 = step_context.block_offsets.to(
+                    torch.int32)
+                step_context.block_offsets = block_offsets_int32 \
+                    .repeat_interleave(step_context.q_seqlens, 0)
+            kv_seqlens = kv_seqlens_cpu
 
         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
             step_context.block_offsets,
             q_start_loc=q_start_loc_cpu,
             q_seqlens=q_seqlens_cpu,
-            kv_seqlens=kv_seqlens_cpu,
+            kv_seqlens=kv_seqlens,
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=attention_mask,
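Note (not part of the patch): when the graph path is enabled, the boolean masks are turned into additive fp16 masks by scaling with `half_negative_inf` (the most negative float16 value), as in the hunk above. A minimal illustration with a toy mask; the values are stand-ins, not real sequence lengths:

```python
import torch

half_negative_inf = torch.finfo(torch.float16).min  # most negative fp16 value

# Toy boolean mask: True = position must not be attended to.
bool_mask = torch.tensor([[False, True, True],
                          [False, False, True]])

# Allowed positions become 0.0, blocked positions become ~-65504.0, so adding
# this mask to the attention scores drives blocked logits to (effectively)
# -inf before softmax.
additive_mask = bool_mask.half() * half_negative_inf
print(additive_mask)
```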
@@ -134,3 +170,16 @@ def update_step_context(cls, step_context):
 
         step_context.attn_metadata = attn_metadata
         return step_context
+
+    @staticmethod
+    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
+                           cache_config: CacheConfig,
+                           backend_config: BackendConfig,
+                           device: torch.device):
+        """build graph runner."""
+        from .graph_runner import AscendGraphRunner
+        ascend_graph_runner = AscendGraphRunner(model, model_config,
+                                                cache_config, backend_config,
+                                                device)
+        AscendOpsBackend.enable_graph = ascend_graph_runner.enable_graph
+        return ascend_graph_runner
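Note (not part of the patch): `build_graph_runner` flips the class-level `enable_graph` flag that `update_step_context` checks above, and that graph path calls `dynamo.mark_dynamic` on `block_offsets` so the captured graph does not re-specialize on every new shape. Below is a small self-contained sketch of `torch._dynamo.mark_dynamic` with toy shapes; it is unrelated to `AscendGraphRunner` and only illustrates the general mechanism.

```python
import torch
import torch._dynamo as dynamo

# Toy stand-in for step_context.block_offsets; shapes are made up.
block_offsets = torch.zeros(4, 8, dtype=torch.int32)

# Tell the compiler that dims 0 (batch) and 1 (blocks per sequence) may change
# between steps, so the compiled graph keeps them symbolic instead of
# recompiling for every new shape.
dynamo.mark_dynamic(block_offsets, [0, 1])

@torch.compile
def touch(offsets: torch.Tensor) -> torch.Tensor:
    return offsets.sum(dim=-1)

print(touch(block_offsets).shape)  # torch.Size([4])
```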