@@ -181,7 +181,6 @@ def dispatch(
         num_max_dispatch_tokens_per_rank: int = 128,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         self.hidden_shape = hidden_states.shape
-        topk_idx = topk_idx.to(torch.int64)
         (
             hidden_states,
             topk_idx,
@@ -205,6 +204,39 @@ def dispatch(
             hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
         return hidden_states, topk_idx, topk_weights, tokens_per_expert
 
+    def dispatch_yield(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        previous_event=None,
+        num_max_dispatch_tokens_per_rank: int = 128,
+    ):
+        self.hidden_shape = hidden_states.shape
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            num_recv_tokens_per_expert_list,
+            handle,
+            event,
+        ) = yield from self.dispatch_normal_yield(
+            hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        )
+        self.tokens_per_expert = torch.tensor(
+            num_recv_tokens_per_expert_list,
+            device=hidden_states.device,
+            dtype=torch.int64,
+        )
+        tokens_per_expert = self.get_number_of_tokens_per_expert()
+        self.handle = handle
+        self.topk_idx = topk_idx
+        self.topk_weights = topk_weights
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
+        return hidden_states, topk_idx, topk_weights, tokens_per_expert
+
     def dispatch_normal(
         self,
         x: torch.Tensor,
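
# Usage sketch (not part of the diff): dispatch_yield is a generator, so a
# caller must step it with next() and read the final tuple out of
# StopIteration.value. The `dispatcher` instance and the input tensors below
# are assumed to already exist.
def drive(gen):
    while True:
        try:
            next(gen)  # run to the next bare `yield` (a suspension point)
        except StopIteration as stop:
            return stop.value  # the generator's `return` value

hidden_states, topk_idx, topk_weights, tokens_per_expert = drive(
    dispatcher.dispatch_yield(hidden_states, topk_idx, topk_weights, num_experts)
)
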
@@ -256,6 +288,62 @@ def dispatch_normal(
             event,
         )
 
+    def dispatch_normal_yield(
+        self,
+        x: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        previous_event=None,
+        async_finish=True
+    ):
+        yield
+        previous_event = self.buffer_normal.capture() if async_finish else None
+        (
+            num_tokens_per_rank,
+            num_tokens_per_rdma_rank,
+            num_tokens_per_expert,
+            is_token_in_rank,
+            previous_event,
+        ) = self.buffer_normal.get_dispatch_layout(
+            topk_idx,
+            num_experts,
+            previous_event=previous_event,
+            async_finish=async_finish,
+            allocate_on_comm_stream=previous_event is not None and async_finish,
+        )
+
+        (
+            recv_x,
+            recv_topk_idx,
+            recv_topk_weights,
+            num_recv_tokens_per_expert_list,
+            handle,
+            event,
+        ) = self.buffer_normal.dispatch(
+            x,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            is_token_in_rank=is_token_in_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+            previous_event=previous_event,
+            async_finish=async_finish,
+            allocate_on_comm_stream=previous_event is not None and async_finish,
+        )
+
+        yield
+        if async_finish:
+            event.current_stream_wait()
+        return (
+            recv_x,
+            recv_topk_idx,
+            recv_topk_weights,
+            num_recv_tokens_per_expert_list,
+            handle,
+            event,
+        )
 
     def combine(
         self, hidden_states: torch.Tensor
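
# Why dispatch_normal_yield has two bare `yield`s (sketch, assuming one
# dispatcher per micro-batch; `disp_a`/`disp_b` and their inputs are
# hypothetical): the first suspension sits before any communication is
# issued, the second sits between launching the async dispatch and waiting
# on its event, so other work can run while the transfer is in flight.
gen_a = disp_a.dispatch_yield(hs_a, idx_a, w_a, num_experts)
gen_b = disp_b.dispatch_yield(hs_b, idx_b, w_b, num_experts)
next(gen_a)  # gen_a at first yield: nothing launched yet
next(gen_a)  # gen_a's layout + dispatch issued with async_finish=True
next(gen_b)  # while gen_a's transfer is in flight...
next(gen_b)  # ...gen_b launches its own dispatch
# exhausting each generator then calls event.current_stream_wait()
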
@@ -268,6 +356,17 @@ def combine(
         self.handle = None
         return hidden_states.view(self.hidden_shape)
 
+    def combine_yield(
+        self, hidden_states: torch.Tensor
+    ):
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.get_restored_hidden_states_by_experts(
+                hidden_states
+            )
+        hidden_states, event = yield from self.combine_normal_yield(hidden_states, self.handle)
+        self.handle = None
+        return hidden_states.view(self.hidden_shape)
+
     def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
         combined_x, _, event = self.buffer_normal.combine(
             x,
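
# Sketch of the combine side (names assumed as in the dispatch sketches):
# combine_yield forwards the two suspension points of combine_normal_yield
# through `yield from`, so the caller drives it the same way.
gen = dispatcher.combine_yield(expert_output)
next(gen)  # suspend before the combine kernel is launched
next(gen)  # combine now in flight on the communication stream
try:
    next(gen)  # final step waits on the event...
except StopIteration as stop:
    hidden_states = stop.value  # ...and returns the combined tensor
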
@@ -278,6 +377,22 @@ def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
         )
         return combined_x, event
 
+    def combine_normal_yield(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
+        yield
+        previous_event = self.buffer_normal.capture() if async_finish else None
+        combined_x, _, event = self.buffer_normal.combine(
+            x,
+            handle,
+            async_finish=async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=previous_event is not None and async_finish,
+        )
+
+        yield
+        if async_finish:
+            event.current_stream_wait()
+        return combined_x, event
+
     def _indices_to_multihot(self, indices, probs):
         batch_size = indices.shape[0]
         multihot_routing_map = torch.zeros(
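
# End-to-end scheduling sketch under the same assumptions (two dispatchers,
# a local `experts` function, and the `drive` helper from the first sketch):
# stepping the generators by hand lets one micro-batch's expert compute
# overlap the other's dispatch/combine traffic.
d_a = disp_a.dispatch_yield(hs_a, idx_a, w_a, num_experts)
d_b = disp_b.dispatch_yield(hs_b, idx_b, w_b, num_experts)
next(d_a)
next(d_a)                              # a's dispatch in flight
next(d_b)
next(d_b)                              # b's dispatch in flight
hs_a, _, _, per_expert_a = drive(d_a)  # wait for a's tokens
out_a = experts(hs_a)                  # compute a while b's dispatch flies
hs_b, _, _, per_expert_b = drive(d_b)
c_a = disp_a.combine_yield(out_a)
next(c_a)
next(c_a)                              # a's combine in flight
out_b = experts(hs_b)                  # compute b while a's combine flies
y_a = drive(c_a)
y_b = drive(disp_b.combine_yield(out_b))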