import numpy as np
import paddle
+import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
    DygraphShardingOptimizer,
)
@@ -106,17 +107,16 @@ def convert_opt_name_to_tname(tensor_names, opt_names):
class NodeModelState:
-    def __init__(self, mp_rank=None, sharding_rank=None, pp_rank=None):
+    def __init__(self, group):
        self._model_weights = OrderedDict()
        self._opt_state = OrderedDict()
        self._master_weights = OrderedDict()
        self._lr_scheduler = None
-        self.set_node_rank(mp_rank, sharding_rank, pp_rank)
+        self._group = group

-    def set_node_rank(self, mp_rank, sharding_rank, pp_rank):
-        self._mp_rank = mp_rank
-        self._sharding_rank = sharding_rank
-        self._pp_rank = pp_rank
+    @property
+    def group(self):
+        return self._group

    def _add_kv(self, d, k, v):
        assert k not in d
@@ -407,12 +407,13 @@ def split_state(self, split_func):
        return node_model_states

-    def even_distribute(self, group):
+    def even_distribute(self):
        """
        distribute the node state evenly among all workers in the group, and make sure
        that in the (key, rank)=>tensor dicts, items with the same key but different ranks are distributed to the
        same worker
        """
+        group = self.group
        # sharding degree == 1
        if group is None or group.nranks < 2:
            return self
@@ -446,7 +447,7 @@ def distribute(state_dict):
            def filter_func(key):
                assert key[0] in key_to_rank, key
                dst_rank = key_to_rank[key[0]]
-                return dst_rank == group.rank
+                return dst_rank == max(group.rank, 0)

            return _all_gather_state_dict(state_dict, filter_func, group)
@@ -455,10 +456,11 @@ def filter_func(key):
        self._master_weights = distribute(self._master_weights)
        return self

-    def reshard(self, group, filter_func):
+    def reshard(self, filter_func):
        """
        reshard according to the passed-in filter_func
        """
+        group = self.group
        self._model_weights = _all_gather_state_dict(self._model_weights, filter_func, group)
        self._opt_state = _all_gather_state_dict(self._opt_state, filter_func, group)
        self._master_weights = _all_gather_state_dict(self._master_weights, filter_func, group)
@@ -511,6 +513,7 @@ def merge(state, l):
        return self

    def merge_from(self, other, rank=None):
+        assert other.group is self.group
        self.add_weights(other.model_weights, rank)
        self.add_opts(other.opt_state, rank)
        self.add_master_weights(other.master_weights, rank)
@@ -528,6 +531,68 @@ def get_opt_state_dict(self):
        return opt_state_dict

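+# Split a flat model-state dict into per-group buckets keyed by group id, using
+# group_getter.get_group(key) to decide which communication group owns each entry.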
+def split_model_state(model_state, group_getter):
+    res = OrderedDict()
+    for k, v in model_state.items():
+        group = group_getter.get_group(k)
+        if group.id not in res:
+            res[group.id] = OrderedDict()
+        res[group.id][k] = v
+    return res
+
+
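+# Merge per-group model-state buckets back into a single flat dict.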
+def merge_model_state(model_state_map):
+    res = OrderedDict()
+    for gid, model_state in model_state_map.items():
+        res.update(model_state)
+    return res
+
+
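+# Split an optimizer state dict (including master_weights) into per-group buckets;
+# the LR_Scheduler entry is replicated into every bucket.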
+def split_opt_state(opt_state, group_getter):
+    res = OrderedDict()
+    lr_scheduler = opt_state.get("LR_Scheduler", None)
+    for k, v in opt_state.items():
+        if k == "LR_Scheduler":
+            continue
+        elif k == "master_weights":
+            for kk, vv in v.items():
+                group = group_getter.get_group(kk)
+                if group.id not in res:
+                    res[group.id] = {"master_weights": OrderedDict(), "LR_Scheduler": lr_scheduler}
+                res[group.id]["master_weights"][kk] = vv
+        else:
+            assert isinstance(v, paddle.Tensor), type(v)
+            group = group_getter.get_group(k)
+            if group.id not in res:
+                res[group.id] = {"master_weights": OrderedDict(), "LR_Scheduler": lr_scheduler}
+            res[group.id][k] = v
+    return res
+
+
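+# Inverse of split_opt_state: flatten the per-group optimizer-state buckets into one dict.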
+def merge_opt_state(opt_state_map):
+    res = {"LR_Scheduler": None, "master_weights": OrderedDict()}
+    for gid, opt_state in opt_state_map.items():
+        for k, v in opt_state.items():
+            if k == "LR_Scheduler":
+                if v is not None:
+                    res["LR_Scheduler"] = v
+            elif k == "master_weights":
+                res["master_weights"].update(v)
+            else:
+                res[k] = v
+    return res
+
+
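+# Split a structure-name mapping into per-group buckets, mirroring split_model_state.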
+def split_structure_name_mapping(structure_name_mapping, group_getter):
+    res = OrderedDict()
+    for k, v in structure_name_mapping.items():
+        group = group_getter.get_group(k)
+        if group.id not in res:
+            res[group.id] = OrderedDict()
+        res[group.id][k] = v
+    return res
+
+
def all_gather_simple_object(obj, group):
    res = []
    if group.nranks < 2:
@@ -570,7 +635,7 @@ def map_func(weight):
            del state_dict[k]
        else:
            tensor = paddle.to_tensor(np.empty(shape, dtype))
-            logger.info(f"broadcast {k} from {rank}")
+            logger.info(f"broadcast {k} from {rank}, group {group}")
            # broadcast the tensor
            if group.nranks > 1:
                paddle.distributed.broadcast(
@@ -595,3 +660,29 @@ def _all_gather_state_dict(state_dict, filter_func, group):
    for (k, v) in tmp_state_dict.items():
        state_dict[k] = v
    return state_dict
+
+
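+# Return the MoE sharding-parallel group from the hybrid communicate group, or
+# None when the hcg does not expose get_moe_sharding_parallel_group.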
+def get_moe_sharding_group(hcg=None):
+    if hcg is None:
+        hcg = fleet.get_hybrid_communicate_group()
+    if hasattr(hcg, "get_moe_sharding_parallel_group"):
+        return hcg.get_moe_sharding_parallel_group()
+    else:
+        return None
+
+
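+# Pick the sharding group for a parameter: params tagged with a dict-valued
+# "color" attribute may carry their own group (expected to be either the default
+# sharding group or the MoE sharding group); everything else uses the default.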
+def get_param_sharding_group(param, hcg=None):
+    if hcg is None:
+        hcg = fleet.get_hybrid_communicate_group()
+    default_group = hcg.get_sharding_parallel_group()
+    ep_sharding_group = get_moe_sharding_group(hcg)
+
+    if not hasattr(param, "color"):
+        return default_group
+    color = getattr(param, "color")
+    if isinstance(color, dict):
+        group = color.get("group", default_group)
+        assert group is default_group or group is ep_sharding_group, f"unsupported group: {group}"
+        return group
+    else:
+        return default_group