@@ -181,7 +181,6 @@ def dispatch(
181
181
num_max_dispatch_tokens_per_rank : int = 128 ,
182
182
) -> Tuple [torch .Tensor , torch .Tensor ]:
183
183
self .hidden_shape = hidden_states .shape
184
- topk_idx = topk_idx .to (torch .int64 )
185
184
(
186
185
hidden_states ,
187
186
topk_idx ,
@@ -205,6 +204,39 @@ def dispatch(
205
204
hidden_states = self .get_permuted_hidden_states_by_experts (hidden_states )
206
205
return hidden_states , topk_idx , topk_weights , tokens_per_expert
207
206
207
    def dispatch_yield(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
        previous_event=None,
        num_max_dispatch_tokens_per_rank: int = 128,
    ):
        """Generator variant of ``dispatch``: routes tokens to experts via
        ``dispatch_normal_yield`` and yields while communication is in flight
        so the caller can overlap independent work (drive with ``yield from``
        or manual ``send``).

        Returns ``(hidden_states, topk_idx, topk_weights, tokens_per_expert)``
        once communication has completed and (for a non-empty receive buffer)
        the tokens have been permuted into per-expert order.

        NOTE(review): ``num_max_dispatch_tokens_per_rank`` is accepted but
        never used in this body — confirm whether that mirrors ``dispatch``
        intentionally or is an oversight.
        """
        # Remember the pre-dispatch shape; combine() later restores it via view().
        self.hidden_shape = hidden_states.shape
        (
            hidden_states,
            topk_idx,
            topk_weights,
            num_recv_tokens_per_expert_list,
            handle,
            event,
        ) = yield from self.dispatch_normal_yield(
            hidden_states, topk_idx, topk_weights, num_experts, previous_event
        )
        # Per-local-expert receive counts, materialized on-device for the
        # subsequent permutation/grouping steps.
        self.tokens_per_expert = torch.tensor(
            num_recv_tokens_per_expert_list,
            device=hidden_states.device,
            dtype=torch.int64,
        )
        tokens_per_expert = self.get_number_of_tokens_per_expert()
        # Stash the communication handle and routing info; the matching
        # combine step consumes self.handle.
        self.handle = handle
        self.topk_idx = topk_idx
        self.topk_weights = topk_weights
        if hidden_states.shape[0] > 0:
            # Group received tokens by expert so expert computation sees
            # contiguous per-expert inputs. Skipped when nothing was received.
            hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
        return hidden_states, topk_idx, topk_weights, tokens_per_expert
208
240
def dispatch_normal (
209
241
self ,
210
242
x : torch .Tensor ,
@@ -256,6 +288,61 @@ def dispatch_normal(
256
288
event ,
257
289
)
258
290
291
    def dispatch_normal_yield(
        self,
        x: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
        previous_event=None,
        async_finish=True
    ):
        """Generator variant of ``dispatch_normal``.

        Computes the dispatch layout and launches ``buffer_normal.dispatch``,
        then ``yield``s exactly once so the caller can overlap independent
        work while the (asynchronous) communication runs. After resumption it
        waits on the communication event (when ``async_finish``) and returns
        the received tensors plus the handle/event.

        NOTE(review): the ``previous_event`` parameter is unconditionally
        overwritten by ``buffer_normal.capture()`` on the first line, so any
        caller-supplied value (e.g. from ``dispatch_yield``) is silently
        ignored — confirm this is intentional.
        """
        # Capture the current stream state so the comm kernels can be ordered
        # after already-enqueued work; disabled in synchronous mode.
        previous_event = self.buffer_normal.capture() if async_finish else None
        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            previous_event,
        ) = self.buffer_normal.get_dispatch_layout(
            topk_idx,
            num_experts,
            previous_event=previous_event,
            async_finish=async_finish,
            allocate_on_comm_stream=previous_event is not None and async_finish,
        )

        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            handle,
            event,
        ) = self.buffer_normal.dispatch(
            x,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            previous_event=previous_event,
            async_finish=async_finish,
            allocate_on_comm_stream=previous_event is not None and async_finish,
        )

        # Yield point: communication is in flight; the caller may run
        # overlapping work before resuming this generator.
        yield
        if async_finish:
            # Block the current stream until the dispatch kernels finish, so
            # downstream consumers of recv_x see completed data.
            event.current_stream_wait()
        return (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            handle,
            event,
        )
259
346
260
347
def combine (
261
348
self , hidden_states : torch .Tensor
@@ -268,6 +355,17 @@ def combine(
268
355
self .handle = None
269
356
return hidden_states .view (self .hidden_shape )
270
357
358
    def combine_yield(
        self, hidden_states: torch.Tensor
    ):
        """Generator variant of ``combine``: undoes the per-expert permutation
        and runs ``combine_normal_yield`` (which performs the combine
        communication), yielding while that communication is in flight.

        Consumes the ``self.handle`` stashed by the preceding dispatch and
        clears it afterwards — so dispatch and combine must be strictly
        paired. Returns the combined tensor viewed back to the original
        pre-dispatch ``self.hidden_shape``.
        """
        if hidden_states.shape[0] > 0:
            # Undo the by-expert grouping applied during dispatch; skipped
            # when this rank received no tokens.
            hidden_states = self.get_restored_hidden_states_by_experts(
                hidden_states
            )
        hidden_states, event = yield from self.combine_normal_yield(hidden_states, self.handle)
        # The handle is single-use: drop it once the combine has consumed it.
        self.handle = None
        return hidden_states.view(self.hidden_shape)
271
369
def combine_normal (self , x : torch .Tensor , handle : Tuple , previous_event = None ):
272
370
combined_x , _ , event = self .buffer_normal .combine (
273
371
x ,
@@ -278,6 +376,22 @@ def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
278
376
)
279
377
return combined_x , event
280
378
379
    def combine_normal_yield(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):
        """Generator variant of ``combine_normal`` with two yield points.

        The first ``yield`` occurs before anything is launched — presumably so
        the caller can enqueue work that must precede the stream capture
        (TODO confirm against the driver loop). The combine is then issued
        asynchronously and the second ``yield`` lets the caller overlap work
        while it runs; after resumption the event is waited on (when
        ``async_finish``) and ``(combined_x, event)`` is returned.

        NOTE(review): like ``dispatch_normal_yield``, the ``previous_event``
        parameter is overwritten by ``buffer_normal.capture()`` before use, so
        the caller's argument is ignored — confirm intentional.
        """
        yield
        # Capture the current stream state so the combine kernels are ordered
        # after work enqueued up to this point; disabled in synchronous mode.
        previous_event = self.buffer_normal.capture() if async_finish else None
        combined_x, _, event = self.buffer_normal.combine(
            x,
            handle,
            async_finish=async_finish,
            previous_event=previous_event,
            allocate_on_comm_stream=previous_event is not None and async_finish,
        )

        # Communication in flight; caller may overlap independent work here.
        yield
        if async_finish:
            # Block the current stream until the combine kernels finish.
            event.current_stream_wait()
        return combined_x, event
281
395
def _indices_to_multihot (self , indices , probs ):
282
396
batch_size = indices .shape [0 ]
283
397
multihot_routing_map = torch .zeros (
0 commit comments