Skip to content

Commit 46314e2

Browse files
author
shenhao
committed
update_timeline
1 parent f8c0a61 commit 46314e2

File tree

3 files changed

+242
-77
lines changed

3 files changed

+242
-77
lines changed

lmdeploy/pytorch/backends/cuda/moe.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def forward(self,
395395
down_weights, down_scale)
396396
out_states = self.token_dispatcher.combine(out_states)
397397
return out_states
398-
398+
399399
def forward_yield(self,
400400
hidden_states: torch.Tensor,
401401
topk_weights: torch.Tensor,
@@ -405,27 +405,51 @@ def forward_yield(self,
405405
down_weights: torch.Tensor,
406406
down_scale: torch.Tensor,
407407
expert_list: List[int] = None,
408-
tag: Any = None):
408+
tag: Any = None,
409+
shared_experts: Any = None):
409410
"""forward_yield."""
410411
topk_weights = _renormalize(topk_weights, self.renormalize)
411412

413+
if shared_experts is not None:
414+
if self.token_dispatcher.get_shared_experts() is None:
415+
self.token_dispatcher.set_shared_experts(shared_experts)
416+
if self.token_dispatcher.get_shared_experts() is None:
417+
self.token_dispatcher_for2mb.set_shared_experts(shared_experts)
418+
419+
assert tag is not None and len(tag) >= 1
412420
_token_dispatcher = self.token_dispatcher
413421
if tag is not None and tag[0] == "0":
414422
_token_dispatcher = self.token_dispatcher
415423
if tag is not None and tag[0] == "1":
416424
_token_dispatcher = self.token_dispatcher_for2mb
417-
recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = (
425+
is_decoding = False
426+
is_prefill = False
427+
if tag is not None and len(tag) > 1 and tag[1].upper() == "P":
428+
is_prefill = True
429+
if tag is not None and len(tag) > 1 and tag[1].upper() == "D":
430+
is_decoding = True
431+
432+
_token_dispatcher.set_shared_experts(shared_experts)
433+
# yield for attn1, dis (+share), dis_wait, moe
434+
recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert, shared_states_indispatch = (
418435
yield from _token_dispatcher.dispatch_yield(
419436
hidden_states,
420437
topk_ids.to(torch.int64),
421438
topk_weights.to(torch.float32),
422439
self.num_experts,
440+
is_prefill,
441+
is_decoding
423442
)
424443
)
425444
out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, gate_up_weights, gate_up_scale,
426445
down_weights, down_scale)
427-
out_states = yield from _token_dispatcher.combine_yield(out_states)
428-
return out_states
446+
# yield for moe, comb, (+share) comb_wait, (+share) attn0
447+
out_states, shared_states_incomb = yield from _token_dispatcher.combine_yield(out_states,
448+
hidden_states,
449+
is_prefill,
450+
is_decoding)
451+
shared_states = shared_states_indispatch if shared_states_indispatch is not None else shared_states_incomb
452+
return out_states, shared_states
429453

430454
class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
431455
"""triton fused moe blocked f8 builder."""

lmdeploy/pytorch/backends/cuda/token_dispatcher.py

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
use_deepep = False
77

88
import os
9-
from typing import Optional, Tuple
9+
from typing import Optional, Tuple, Any
1010

1111
import torch
1212
import torch.distributed as dist
@@ -156,6 +156,8 @@ def __init__(
156156
self.token_probs = None
157157
# Handle used for combine operation
158158
self.handle = None
159+
# shared experts
160+
self.shared_experts = None
159161

160162
# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
161163
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
@@ -212,30 +214,44 @@ def dispatch_yield(
212214
num_experts: int,
213215
previous_event=None,
214216
num_max_dispatch_tokens_per_rank: int = 128,
217+
is_prefill: bool = False,
218+
is_decoding: bool = False
215219
):
216220
self.hidden_shape = hidden_states.shape
221+
# yield for attn1, dis (+share)
222+
yield
223+
previous_event = self.buffer_normal.capture()
217224
(
218-
hidden_states,
219-
topk_idx,
220-
topk_weights,
225+
recv_hidden_states,
226+
recv_topk_idx,
227+
recv_topk_weights,
221228
num_recv_tokens_per_expert_list,
222229
handle,
223230
event,
224-
) = yield from self.dispatch_normal_yield(
225-
hidden_states, topk_idx, topk_weights, num_experts, previous_event
231+
) = self.dispatch_normal_async(
232+
hidden_states, topk_idx, topk_weights, num_experts, previous_event, True
226233
)
234+
if is_decoding and self.shared_experts is not None:
235+
shared_states = self.shared_experts(hidden_states)
236+
else:
237+
shared_states = None
238+
# yield for dis (+share), dis_wait
239+
yield
240+
event.current_stream_wait()
241+
# yield for dis_wait, moe
242+
yield
227243
self.tokens_per_expert = torch.tensor(
228244
num_recv_tokens_per_expert_list,
229245
device=hidden_states.device,
230246
dtype=torch.int64,
231247
)
232248
tokens_per_expert = self.get_number_of_tokens_per_expert()
233249
self.handle = handle
234-
self.topk_idx = topk_idx
235-
self.topk_weights = topk_weights
236-
if hidden_states.shape[0] > 0:
237-
hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
238-
return hidden_states, topk_idx, topk_weights, tokens_per_expert
250+
self.topk_idx = recv_topk_idx
251+
self.topk_weights = recv_topk_weights
252+
if recv_hidden_states.shape[0] > 0:
253+
recv_hidden_states = self.get_permuted_hidden_states_by_experts(recv_hidden_states)
254+
return recv_hidden_states, recv_topk_idx, recv_topk_weights, tokens_per_expert, shared_states
239255

240256
def dispatch_normal(
241257
self,
@@ -288,7 +304,7 @@ def dispatch_normal(
288304
event,
289305
)
290306

291-
def dispatch_normal_yield(
307+
def dispatch_normal_async(
292308
self,
293309
x: torch.Tensor,
294310
topk_idx: torch.Tensor,
@@ -297,8 +313,6 @@ def dispatch_normal_yield(
297313
previous_event=None,
298314
async_finish=True
299315
):
300-
yield
301-
previous_event = self.buffer_normal.capture() if async_finish else None
302316
(
303317
num_tokens_per_rank,
304318
num_tokens_per_rdma_rank,
@@ -333,9 +347,6 @@ def dispatch_normal_yield(
333347
allocate_on_comm_stream=previous_event is not None and async_finish,
334348
)
335349

336-
yield
337-
if async_finish:
338-
event.current_stream_wait()
339350
return (
340351
recv_x,
341352
recv_topk_idx,
@@ -357,15 +368,34 @@ def combine(
357368
return hidden_states.view(self.hidden_shape)
358369

359370
def combine_yield(
    self,
    out_states: torch.Tensor,
    hidden_states: torch.Tensor,
    is_prefill: bool = False,
    is_decoding: bool = False
):
    """Yield-staged combine of expert outputs.

    Un-permutes `out_states` back to token order, launches the async
    all-to-all combine on the normal buffer, optionally runs the shared
    experts on `hidden_states` while the combine is in flight (prefill
    only), then waits on the communication event. The three bare `yield`s
    hand control back to the driving scheduler between phases —
    presumably so another micro-batch can overlap compute with comm
    (TODO confirm against the caller in moe.py).

    Args:
        out_states: expert outputs in permuted (per-expert) order.
        hidden_states: the pre-dispatch activations, fed to the shared
            experts during the prefill overlap window.
        is_prefill: run shared experts between combine launch and wait.
        is_decoding: unused here; kept for signature symmetry with
            dispatch_yield (shared experts ran at dispatch time there).

    Returns:
        (combined states reshaped to the dispatch-time hidden shape,
         shared-expert output or None when not computed here).
    """
    if out_states.shape[0] > 0:
        out_states = self.get_restored_hidden_states_by_experts(
            out_states
        )
    # yield for moe, comb
    yield
    previous_event = self.buffer_normal.capture()
    out_states, event = self.combine_normal_async(out_states,
                                                  self.handle,
                                                  previous_event=previous_event,
                                                  async_finish=True)
    # yield for comb, (+share) comb_wait,
    yield
    if is_prefill and self.shared_experts is not None:
        shared_states = self.shared_experts(hidden_states)
    else:
        shared_states = None
    # Block the current stream until the async combine has landed.
    event.current_stream_wait()
    # yield for (+share) comb_wait, (+share) attn0
    yield
    # The dispatch handle is single-use; clear it for the next round.
    self.handle = None
    return out_states.view(self.hidden_shape), shared_states
369399

370400
def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
371401
combined_x, _, event = self.buffer_normal.combine(
@@ -377,20 +407,14 @@ def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
377407
)
378408
return combined_x, event
379409

380-
def combine_normal_yield(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
381-
yield
382-
previous_event = self.buffer_normal.capture() if async_finish else None
410+
def combine_normal_async(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
    """Launch a combine on the normal buffer and return (combined, event).

    Thin wrapper over `buffer_normal.combine`; with `async_finish=True`
    the caller is responsible for waiting on the returned event.
    """
    # Allocate on the comm stream only when chaining off a prior event
    # in async mode (same condition the buffer API expects).
    on_comm_stream = previous_event is not None and async_finish
    combined, _, comm_event = self.buffer_normal.combine(
        x,
        handle,
        async_finish=async_finish,
        previous_event=previous_event,
        allocate_on_comm_stream=on_comm_stream,
    )
    return combined, comm_event
395419

396420
def _indices_to_multihot(self, indices, probs):
@@ -456,3 +480,11 @@ def get_restored_hidden_states_by_experts(
456480
fused=self.permute_fusion,
457481
)
458482
return hidden_states.to(input_dtype)
483+
484+
def set_shared_experts(self, shared_experts: Any = None):
    """Register the shared-experts callable on this dispatcher (set-once).

    BUG FIX: the original guard was `if self.shared_experts is not None:`,
    which — since the attribute is initialized to None — made the setter a
    permanent no-op and shared experts could never be installed. Callers
    probe with `get_shared_experts() is None` before calling, so the
    intended semantics is "fill the slot only if it is still empty".

    Args:
        shared_experts: callable run on the raw hidden states during the
            dispatch/combine overlap window; None is a harmless no-op.

    Returns:
        The currently registered shared-experts object (new or pre-existing).
    """
    if self.shared_experts is None:
        self.shared_experts = shared_experts
    return self.shared_experts
488+
489+
def get_shared_experts(self):
    """Return the shared-experts callable registered via set_shared_experts, or None if unset."""
    return self.shared_experts

0 commit comments

Comments
 (0)