6
6
use_deepep = False
7
7
8
8
import os
9
- from typing import Optional , Tuple
9
+ from typing import Optional , Tuple , Any
10
10
11
11
import torch
12
12
import torch .distributed as dist
@@ -156,6 +156,8 @@ def __init__(
156
156
self .token_probs = None
157
157
# Handle used for combine operation
158
158
self .handle = None
159
+ # shared experts
160
+ self .shared_experts = None
159
161
160
162
# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
161
163
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
@@ -212,30 +214,44 @@ def dispatch_yield(
212
214
num_experts : int ,
213
215
previous_event = None ,
214
216
num_max_dispatch_tokens_per_rank : int = 128 ,
217
+ is_prefill : bool = False ,
218
+ is_decoding : bool = False
215
219
):
216
220
self .hidden_shape = hidden_states .shape
221
+ # yield for attn1, dis (+share)
222
+ yield
223
+ previous_event = self .buffer_normal .capture ()
217
224
(
218
- hidden_states ,
219
- topk_idx ,
220
- topk_weights ,
225
+ recv_hidden_states ,
226
+ recv_topk_idx ,
227
+ recv_topk_weights ,
221
228
num_recv_tokens_per_expert_list ,
222
229
handle ,
223
230
event ,
224
- ) = yield from self .dispatch_normal_yield (
225
- hidden_states , topk_idx , topk_weights , num_experts , previous_event
231
+ ) = self .dispatch_normal_async (
232
+ hidden_states , topk_idx , topk_weights , num_experts , previous_event , True
226
233
)
234
+ if is_decoding and self .shared_experts is not None :
235
+ shared_states = self .shared_experts (hidden_states )
236
+ else :
237
+ shared_states = None
238
+ # yield for dis (+share), dis_wait
239
+ yield
240
+ event .current_stream_wait ()
241
+ # yield for dis_wait, moe
242
+ yield
227
243
self .tokens_per_expert = torch .tensor (
228
244
num_recv_tokens_per_expert_list ,
229
245
device = hidden_states .device ,
230
246
dtype = torch .int64 ,
231
247
)
232
248
tokens_per_expert = self .get_number_of_tokens_per_expert ()
233
249
self .handle = handle
234
- self .topk_idx = topk_idx
235
- self .topk_weights = topk_weights
236
- if hidden_states .shape [0 ] > 0 :
237
- hidden_states = self .get_permuted_hidden_states_by_experts (hidden_states )
238
- return hidden_states , topk_idx , topk_weights , tokens_per_expert
250
+ self .topk_idx = recv_topk_idx
251
+ self .topk_weights = recv_topk_weights
252
+ if recv_hidden_states .shape [0 ] > 0 :
253
+ recv_hidden_states = self .get_permuted_hidden_states_by_experts (recv_hidden_states )
254
+ return recv_hidden_states , recv_topk_idx , recv_topk_weights , tokens_per_expert , shared_states
239
255
240
256
def dispatch_normal (
241
257
self ,
@@ -288,7 +304,7 @@ def dispatch_normal(
288
304
event ,
289
305
)
290
306
291
- def dispatch_normal_yield (
307
+ def dispatch_normal_async (
292
308
self ,
293
309
x : torch .Tensor ,
294
310
topk_idx : torch .Tensor ,
@@ -297,8 +313,6 @@ def dispatch_normal_yield(
297
313
previous_event = None ,
298
314
async_finish = True
299
315
):
300
- yield
301
- previous_event = self .buffer_normal .capture () if async_finish else None
302
316
(
303
317
num_tokens_per_rank ,
304
318
num_tokens_per_rdma_rank ,
@@ -333,9 +347,6 @@ def dispatch_normal_yield(
333
347
allocate_on_comm_stream = previous_event is not None and async_finish ,
334
348
)
335
349
336
- yield
337
- if async_finish :
338
- event .current_stream_wait ()
339
350
return (
340
351
recv_x ,
341
352
recv_topk_idx ,
@@ -357,15 +368,34 @@ def combine(
357
368
return hidden_states .view (self .hidden_shape )
358
369
359
370
def combine_yield(
    self,
    out_states: torch.Tensor,
    hidden_states: torch.Tensor,
    is_prefill: bool = False,
    is_decoding: bool = False
):
    """Generator form of the combine step, suspended at three points.

    The caller drives this generator with ``next()``; each ``yield`` hands
    control back so other work can be interleaved (presumably the other
    micro-batch of an overlapped schedule -- confirm against the caller).

    Args:
        out_states: Expert outputs to be combined. If it has at least one
            row it is first passed through
            ``get_restored_hidden_states_by_experts``.
        hidden_states: Input handed to ``self.shared_experts`` when the
            shared-expert branch is taken.
        is_prefill: When true and ``self.shared_experts`` is set, the shared
            experts are evaluated after the second resume point.
        is_decoding: Accepted for signature symmetry; not read in this method.

    Returns:
        (via ``StopIteration.value``) a tuple of ``out_states`` viewed as
        ``self.hidden_shape`` and the shared-expert output, or ``None`` for
        the latter when the shared branch was not taken.
    """
    has_tokens = out_states.shape[0] > 0
    if has_tokens:
        out_states = self.get_restored_hidden_states_by_experts(out_states)

    # yield for moe, comb
    yield
    # Capture an event only after being resumed, so it reflects the work
    # issued up to this point.
    captured_event = self.buffer_normal.capture()
    out_states, event = self.combine_normal_async(
        out_states,
        self.handle,
        previous_event=captured_event,
        async_finish=True,
    )

    # yield for comb, (+share) comb_wait
    yield
    # Shared experts are launched before waiting on the combine event.
    shared_states = None
    if is_prefill and self.shared_experts is not None:
        shared_states = self.shared_experts(hidden_states)
    event.current_stream_wait()

    # yield for (+share) comb_wait, (+share) attn0
    yield
    # The dispatch handle is consumed by this combine; drop it.
    self.handle = None
    restored = out_states.view(self.hidden_shape)
    return restored, shared_states
369
399
370
400
def combine_normal (self , x : torch .Tensor , handle : Tuple , previous_event = None ):
371
401
combined_x , _ , event = self .buffer_normal .combine (
@@ -377,20 +407,14 @@ def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
377
407
)
378
408
return combined_x , event
379
409
380
- def combine_normal_yield (self , x : torch .Tensor , handle : Tuple , previous_event = None , async_finish = True ):
381
- yield
382
- previous_event = self .buffer_normal .capture () if async_finish else None
410
def combine_normal_async(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
    """Issue ``buffer_normal.combine`` and return without waiting on its event.

    Args:
        x: Tensor forwarded as the first argument of ``combine``.
        handle: Handle forwarded as the second argument of ``combine``.
        previous_event: Event forwarded to ``combine``; also decides whether
            ``allocate_on_comm_stream`` is requested.
        async_finish: Forwarded to ``combine`` unchanged.

    Returns:
        ``(combined_x, event)`` -- the first and third items of the tuple
        returned by ``combine``. The caller is responsible for waiting on
        ``event``; this method does not.
    """
    # Allocation on the communication stream is requested only when there is
    # an event to chain on and the call is asynchronous.
    use_comm_stream = previous_event is not None and async_finish
    result = self.buffer_normal.combine(
        x,
        handle,
        async_finish=async_finish,
        previous_event=previous_event,
        allocate_on_comm_stream=use_comm_stream,
    )
    combined_x, _, event = result
    return combined_x, event
395
419
396
420
def _indices_to_multihot (self , indices , probs ):
@@ -456,3 +480,11 @@ def get_restored_hidden_states_by_experts(
456
480
fused = self .permute_fusion ,
457
481
)
458
482
return hidden_states .to (input_dtype )
483
+
484
def set_shared_experts(self, shared_experts: Any = None):
    """Register the shared-experts callable and return the one now in effect.

    The guard checks the *argument*, not the stored attribute: the attribute
    starts out as ``None``, so guarding on it would make the first
    registration impossible.

    Args:
        shared_experts: Object to store on ``self.shared_experts``. When
            ``None`` (the default) the stored value is left untouched, so a
            bare call simply reports the current value.

    Returns:
        The value of ``self.shared_experts`` after the call.
    """
    if shared_experts is not None:
        self.shared_experts = shared_experts
    return self.shared_experts
488
+
489
def get_shared_experts(self):
    """Return whatever is currently stored in ``self.shared_experts``."""
    current = self.shared_experts
    return current
0 commit comments