
Commit f8c0a61

Author: shenhao (committed)
add_twomicrobatch
1 parent 2af5a71 commit f8c0a61

File tree

4 files changed: +595 −3 lines changed


lmdeploy/pytorch/backends/cuda/moe.py

Lines changed: 42 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.

-from typing import List
+from typing import List, Any

 import torch
 import torch.distributed as dist
@@ -361,6 +361,15 @@ def __init__(self,
             hidden_size=hidden_dim,
             params_dtype=out_dtype,
         )
+        self.token_dispatcher_for2mb = DeepEPDispatcher(
+            group=ep_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_experts // ep_size,
+            hidden_size=hidden_dim,
+            params_dtype=out_dtype,
+        )
         self.experts = DeepEPMoE(num_experts, ep_size, [block_size,block_size])

     def forward(self,
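The second dispatcher instance matters because a DeepEPDispatcher carries per-call state between dispatch and combine (`self.handle`, `self.topk_idx`, `self.topk_weights`, `self.tokens_per_expert`; see token_dispatcher.py below). With two microbatches in flight at once, a shared dispatcher would let the second dispatch overwrite the handle the first microbatch still needs for its combine, so each microbatch gets its own instance.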
@@ -377,7 +386,7 @@ def forward(self,
         recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = (
             self.token_dispatcher.dispatch(
                 hidden_states,
-                topk_ids.to(torch.int32),
+                topk_ids.to(torch.int64),
                 topk_weights.to(torch.float32),
                 self.num_experts,
             )
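The cast here changes from int32 to int64 because the matching internal conversion (`topk_idx = topk_idx.to(torch.int64)`) is removed from `DeepEPDispatcher.dispatch` in token_dispatcher.py below; callers now hand over indices already in the int64 dtype that the dispatch layout computation consumes.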
@@ -386,6 +395,37 @@
                                           down_weights, down_scale)
         out_states = self.token_dispatcher.combine(out_states)
         return out_states
+
+    def forward_yield(self,
+                      hidden_states: torch.Tensor,
+                      topk_weights: torch.Tensor,
+                      topk_ids: torch.LongTensor,
+                      gate_up_weights: torch.Tensor,
+                      gate_up_scale: torch.Tensor,
+                      down_weights: torch.Tensor,
+                      down_scale: torch.Tensor,
+                      expert_list: List[int] = None,
+                      tag: Any = None):
+        """forward_yield."""
+        topk_weights = _renormalize(topk_weights, self.renormalize)
+
+        _token_dispatcher = self.token_dispatcher
+        if tag is not None and tag[0] == "0":
+            _token_dispatcher = self.token_dispatcher
+        if tag is not None and tag[0] == "1":
+            _token_dispatcher = self.token_dispatcher_for2mb
+        recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = (
+            yield from _token_dispatcher.dispatch_yield(
+                hidden_states,
+                topk_ids.to(torch.int64),
+                topk_weights.to(torch.float32),
+                self.num_experts,
+            )
+        )
+        out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, gate_up_weights, gate_up_scale,
+                                          down_weights, down_scale)
+        out_states = yield from _token_dispatcher.combine_yield(out_states)
+        return out_states

 class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
     """triton fused moe blocked f8 builder."""

lmdeploy/pytorch/backends/cuda/token_dispatcher.py

Lines changed: 116 additions & 1 deletion
@@ -181,7 +181,6 @@ def dispatch(
         num_max_dispatch_tokens_per_rank: int = 128,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         self.hidden_shape = hidden_states.shape
-        topk_idx = topk_idx.to(torch.int64)
         (
             hidden_states,
             topk_idx,
@@ -205,6 +204,39 @@ def dispatch(
             hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
         return hidden_states, topk_idx, topk_weights, tokens_per_expert

+    def dispatch_yield(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        previous_event=None,
+        num_max_dispatch_tokens_per_rank: int = 128,
+    ):
+        self.hidden_shape = hidden_states.shape
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            num_recv_tokens_per_expert_list,
+            handle,
+            event,
+        ) = yield from self.dispatch_normal_yield(
+            hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        )
+        self.tokens_per_expert = torch.tensor(
+            num_recv_tokens_per_expert_list,
+            device=hidden_states.device,
+            dtype=torch.int64,
+        )
+        tokens_per_expert = self.get_number_of_tokens_per_expert()
+        self.handle = handle
+        self.topk_idx = topk_idx
+        self.topk_weights = topk_weights
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
+        return hidden_states, topk_idx, topk_weights, tokens_per_expert
+
     def dispatch_normal(
         self,
         x: torch.Tensor,
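`dispatch_yield` is built on `yield from` delegation: the inner `dispatch_normal_yield` generator's suspension points become suspension points of `dispatch_yield` itself, and the tuple it returns becomes the value of the `yield from` expression. A standalone sketch of that mechanism, with illustrative names that are not from this diff:

def inner():
    yield 'launch phase'        # first suspension point
    yield 'wait phase'          # second suspension point
    return 'recv buffers'       # becomes the value of `yield from inner()`

def outer():
    result = yield from inner() # re-yields inner's two suspension points
    return f'combined({result})'

gen = outer()
assert next(gen) == 'launch phase'
assert next(gen) == 'wait phase'
try:
    next(gen)
except StopIteration as stop:
    assert stop.value == 'combined(recv buffers)'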
@@ -256,6 +288,62 @@ def dispatch_normal(
             event,
         )

+    def dispatch_normal_yield(
+        self,
+        x: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        previous_event=None,
+        async_finish=True
+    ):
+        yield
+        previous_event = self.buffer_normal.capture() if async_finish else None
+        (
+            num_tokens_per_rank,
+            num_tokens_per_rdma_rank,
+            num_tokens_per_expert,
+            is_token_in_rank,
+            previous_event,
+        ) = self.buffer_normal.get_dispatch_layout(
+            topk_idx,
+            num_experts,
+            previous_event=previous_event,
+            async_finish=async_finish,
+            allocate_on_comm_stream=previous_event is not None and async_finish,
+        )
+
+        (
+            recv_x,
+            recv_topk_idx,
+            recv_topk_weights,
+            num_recv_tokens_per_expert_list,
+            handle,
+            event,
+        ) = self.buffer_normal.dispatch(
+            x,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            is_token_in_rank=is_token_in_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+            previous_event=previous_event,
+            async_finish=async_finish,
+            allocate_on_comm_stream=previous_event is not None and async_finish,
+        )
+
+        yield
+        if async_finish:
+            event.current_stream_wait()
+        return (
+            recv_x,
+            recv_topk_idx,
+            recv_topk_weights,
+            num_recv_tokens_per_expert_list,
+            handle,
+            event,
+        )

     def combine(
         self, hidden_states: torch.Tensor
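`dispatch_normal_yield` is `dispatch_normal` split into phases around two bare `yield` statements: the first suspends before capturing a comm-stream event and launching `get_dispatch_layout`/`dispatch`, and the second suspends after the asynchronous dispatch has been issued but before `event.current_stream_wait()` blocks on it. A scheduler that resumes the other microbatch's generator between those points can run its expert compute while this microbatch's all-to-all is in flight.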
@@ -268,6 +356,17 @@ def combine(
         self.handle = None
         return hidden_states.view(self.hidden_shape)

+    def combine_yield(
+        self, hidden_states: torch.Tensor
+    ):
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.get_restored_hidden_states_by_experts(
+                hidden_states
+            )
+        hidden_states, event = yield from self.combine_normal_yield(hidden_states, self.handle)
+        self.handle = None
+        return hidden_states.view(self.hidden_shape)
+
     def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
         combined_x, _, event = self.buffer_normal.combine(
             x,
@@ -278,6 +377,22 @@ def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
         )
         return combined_x, event

+    def combine_normal_yield(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
+        yield
+        previous_event = self.buffer_normal.capture() if async_finish else None
+        combined_x, _, event = self.buffer_normal.combine(
+            x,
+            handle,
+            async_finish=async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=previous_event is not None and async_finish,
+        )
+
+        yield
+        if async_finish:
+            event.current_stream_wait()
+        return combined_x, event
+
     def _indices_to_multihot(self, indices, probs):
         batch_size = indices.shape[0]
         multihot_routing_map = torch.zeros(
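`combine_yield`/`combine_normal_yield` mirror the dispatch pair: the same two-suspension structure, with the wait on the combine event deferred to the final resume. Together the four `*_yield` methods mean a single `forward_yield` generator suspends four times per MoE layer (twice for dispatch, twice for combine), which is what the round-robin driver sketched after the moe.py diff above exploits.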

lmdeploy/pytorch/models/deepseek_v2.py

Lines changed: 2 additions & 0 deletions
@@ -894,3 +894,5 @@ def __skip_nextn(name, nextn_keys):
             else:
                 param = params_dict[name]
             load_weight(param, loaded_weight)
+
+import lmdeploy.pytorch.models.utils.microbatch
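The bare `import` at the bottom of deepseek_v2.py activates the commit's fourth file, lmdeploy/pytorch/models/utils/microbatch (not shown on this page), purely through its import-time side effects. A generic, hypothetical sketch of that pattern follows; none of the names below are claimed to match the real module:

# Illustrative import-for-side-effect module (hypothetical names throughout).
def _patched_forward(self, *args, **kwargs):
    # A real patch would split the batch into two microbatches and drive
    # the forward_yield generators; this body is a placeholder.
    return self._orig_forward(*args, **kwargs)

def _install(cls):
    """Swap in the patched forward exactly once."""
    if not hasattr(cls, '_orig_forward'):
        cls._orig_forward = cls.forward
        cls.forward = _patched_forward

# Module-level code runs when the module is imported,
# e.g. _install(SomeDecoderLayer)  # hypothetical target class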
