
Commit c2c0537

support tpdp-ep sharding reshard
1 parent 51009f9 commit c2c0537

File tree

6 files changed: +339 -85 lines


paddlenlp/trainer/training_args.py

Lines changed: 17 additions & 3 deletions
@@ -1230,6 +1230,9 @@ def __post_init__(self):
 
             if expert_parallel_degree > 1:
                 moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
+                assert (
+                    self.expert_tensor_parallel_degree <= 1
+                ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
             else:
                 moe_sharding_parallel_degree = 1
             moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)

@@ -2164,6 +2167,17 @@ def pipeline_parallel_rank(self):
         else:
             return 0
 
+    @property
+    def expert_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_expert_parallel_rank"):
+                return max(hcg.get_expert_parallel_rank(), 0)
+            else:
+                return 0
+        else:
+            return 0
+
     def _format_name(self, prefix, rank, degree):
         size = 2
         return f"{prefix}{rank:0>{size}d}"

@@ -2178,7 +2192,7 @@ def optimizer_name_suffix(self):
             name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:

@@ -2194,7 +2208,7 @@ def weight_name_suffix(self):
             name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
 

@@ -2220,7 +2234,7 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None, sharding_p
                 shard_id = self.sharding_parallel_rank
             assert isinstance(shard_id, int)
             name.append(self._format_name("shard", shard_id, sharding_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 if moe_id is None:
                     moe_id = self.data_parallel_rank
                 assert isinstance(moe_id, int)
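The effect of the suffix changes above can be checked with a small standalone sketch. This is not the PaddleNLP implementation, only a mirror of the _format_name helper and the new expert_parallel_degree <= 1 guard (pipeline suffix omitted for brevity, rank/degree values made up); presumably the "moe" segment is dropped once expert parallelism owns that axis so reshard name matching stays consistent:

def _format_name(prefix, rank, degree):
    size = 2
    return f"{prefix}{rank:0>{size}d}"


def optimizer_name_suffix(tp_rank, shard_rank, dp_rank,
                          tp_degree, shard_degree, dp_degree,
                          use_expert_parallel, expert_parallel_degree):
    # Mirrors the guarded branch added in this commit: the "moe" segment is
    # only appended while expert_parallel_degree <= 1.
    name = []
    if tp_degree > 1:
        name.append(_format_name("tp", tp_rank, tp_degree))
    if shard_degree > 1:
        name.append(_format_name("shard", shard_rank, shard_degree))
    if use_expert_parallel and expert_parallel_degree <= 1:
        name.append(_format_name("moe", dp_rank, dp_degree))
    return "_".join(name)


print(optimizer_name_suffix(1, 3, 2, 8, 4, 4, True, 1))  # tp01_shard03_moe02
print(optimizer_name_suffix(1, 3, 2, 8, 4, 4, True, 4))  # tp01_shard03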

paddlenlp/trainer/utils/reshard/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,13 @@
     NodeModelState,
     all_gather_state_dict,
     convert_opt_name_to_tname,
+    get_moe_sharding_group,
+    get_param_sharding_group,
     get_sharding_strategy,
     is_sharding_opt,
+    merge_model_state,
+    merge_opt_state,
+    split_model_state,
+    split_opt_state,
+    split_structure_name_mapping,
 )

paddlenlp/trainer/utils/reshard/common.py

Lines changed: 101 additions & 10 deletions
@@ -16,6 +16,7 @@
 
 import numpy as np
 import paddle
+import paddle.distributed.fleet as fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizer,
 )

@@ -106,17 +107,16 @@ def convert_opt_name_to_tname(tensor_names, opt_names):
 
 
 class NodeModelState:
-    def __init__(self, mp_rank=None, sharding_rank=None, pp_rank=None):
+    def __init__(self, group):
         self._model_weights = OrderedDict()
         self._opt_state = OrderedDict()
         self._master_weights = OrderedDict()
         self._lr_scheduler = None
-        self.set_node_rank(mp_rank, sharding_rank, pp_rank)
+        self._group = group
 
-    def set_node_rank(self, mp_rank, sharding_rank, pp_rank):
-        self._mp_rank = mp_rank
-        self._sharding_rank = sharding_rank
-        self._pp_rank = pp_rank
+    @property
+    def group(self):
+        return self._group
 
     def _add_kv(self, d, k, v):
         assert k not in d

@@ -407,12 +407,13 @@ def split_state(self, split_func):
 
         return node_model_states
 
-    def even_distribute(self, group):
+    def even_distribute(self):
         """
         distribute the node state evenly among all workers in group, and make sure
         in the dicts of (key, rank)=>tensor, items keys of the same key but different rank are distributed to the
         same worker
         """
+        group = self.group
         # sharding degree == 1
         if group is None or group.nranks < 2:
             return self

@@ -446,7 +447,7 @@ def distribute(state_dict):
            def filter_func(key):
                assert key[0] in key_to_rank, key
                dst_rank = key_to_rank[key[0]]
-                return dst_rank == group.rank
+                return dst_rank == max(group.rank, 0)
 
            return _all_gather_state_dict(state_dict, filter_func, group)

@@ -455,10 +456,11 @@ def filter_func(key):
         self._master_weights = distribute(self._master_weights)
         return self
 
-    def reshard(self, group, filter_func):
+    def reshard(self, filter_func):
         """
         reshard according to the passed in filter_func
         """
+        group = self.group
         self._model_weights = _all_gather_state_dict(self._model_weights, filter_func, group)
         self._opt_state = _all_gather_state_dict(self._opt_state, filter_func, group)
         self._master_weights = _all_gather_state_dict(self._master_weights, filter_func, group)

@@ -511,6 +513,7 @@ def merge(state, l):
         return self
 
     def merge_from(self, other, rank=None):
+        assert other.group is self.group
         self.add_weights(other.model_weights, rank)
         self.add_opts(other.opt_state, rank)
         self.add_master_weights(other.master_weights, rank)

@@ -528,6 +531,68 @@ def get_opt_state_dict(self):
         return opt_state_dict
 
 
+def split_model_state(model_state, group_getter):
+    res = OrderedDict()
+    for k, v in model_state.items():
+        group = group_getter.get_group(k)
+        if group.id not in res:
+            res[group.id] = OrderedDict()
+        res[group.id][k] = v
+    return res
+
+
+def merge_model_state(model_state_map):
+    res = OrderedDict()
+    for gid, model_state in model_state_map.items():
+        res.update(model_state)
+    return res
+
+
+def split_opt_state(opt_state, group_getter):
+    res = OrderedDict()
+    lr_scheduler = opt_state.get("LR_Scheduler", None)
+    for k, v in opt_state.items():
+        if k == "LR_Scheduler":
+            continue
+        elif k == "master_weights":
+            for kk, vv in v.items():
+                group = group_getter.get_group(kk)
+                if group.id not in res:
+                    res[group.id] = {"master_weights": OrderedDict(), "LR_Scheduler": lr_scheduler}
+                res[group.id]["master_weights"][kk] = vv
+        else:
+            assert isinstance(v, paddle.Tensor), type(v)
+            group = group_getter.get_group(k)
+            if group.id not in res:
+                res[group.id] = {"master_weights": OrderedDict(), "LR_Scheduler": lr_scheduler}
+            res[group.id][k] = v
+    return res
+
+
+def merge_opt_state(opt_state_map):
+    res = {"LR_Scheduler": None, "master_weights": OrderedDict()}
+    for gid, opt_state in opt_state_map.items():
+        for k, v in opt_state.items():
+            if k == "LR_Scheduler":
+                if v is not None:
+                    res["LR_Scheduler"] = v
+            elif k == "master_weights":
+                res["master_weights"].update(v)
+            else:
+                res[k] = v
+    return res
+
+
+def split_structure_name_mapping(structure_name_mapping, group_getter):
+    res = OrderedDict()
+    for k, v in structure_name_mapping.items():
+        group = group_getter.get_group(k)
+        if group.id not in res:
+            res[group.id] = OrderedDict()
+        res[group.id][k] = v
+    return res
+
+
 def all_gather_simple_object(obj, group):
     res = []
     if group.nranks < 2:

@@ -570,7 +635,7 @@ def map_func(weight):
             del state_dict[k]
         else:
             tensor = paddle.to_tensor(np.empty(shape, dtype))
-            logger.info(f"broadcast {k} from {rank}")
+            logger.info(f"broadcast {k} from {rank}, group {group}")
             # broadcast the tensor
             if group.nranks > 1:
                 paddle.distributed.broadcast(

@@ -595,3 +660,29 @@ def _all_gather_state_dict(state_dict, filter_func, group):
     for (k, v) in tmp_state_dict.items():
         state_dict[k] = v
     return state_dict
+
+
+def get_moe_sharding_group(hcg=None):
+    if hcg is None:
+        hcg = fleet.get_hybrid_communicate_group()
+    if hasattr(hcg, "get_moe_sharding_parallel_group"):
+        return hcg.get_moe_sharding_parallel_group()
+    else:
+        return None
+
+
+def get_param_sharding_group(param, hcg=None):
+    if hcg is None:
+        hcg = fleet.get_hybrid_communicate_group()
+    default_group = hcg.get_sharding_parallel_group()
+    ep_sharding_group = get_moe_sharding_group(hcg)
+
+    if not hasattr(param, "color"):
+        return default_group
+    color = getattr(param, "color")
+    if isinstance(color, dict):
+        group = color.get("group", default_group)
+        assert group is default_group or group is ep_sharding_group, f"unsupported group: {group}"
+        return group
+    else:
+        return default_group
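The new split/merge helpers above partition a state dict by the .id of the communication group that owns each entry and then reassemble it. A minimal single-process sketch of that round trip; DummyGroup and DummyGroupGetter are hypothetical stand-ins (only the .id attribute and a get_group(name) method are assumed), not the trainer's real group objects:

import paddle

from paddlenlp.trainer.utils.reshard import merge_opt_state, split_opt_state


class DummyGroup:
    def __init__(self, gid):
        self.id = gid


class DummyGroupGetter:
    """Routes expert parameters to a separate (moe-sharding) group."""

    def __init__(self):
        self.dense_group = DummyGroup(0)
        self.expert_group = DummyGroup(1)

    def get_group(self, name):
        return self.expert_group if "expert" in name else self.dense_group


opt_state = {
    "linear_0.w_0_moment1_0": paddle.zeros([4, 4]),       # dense optimizer moment
    "expert_0.w_0_moment1_0": paddle.zeros([4, 4]),        # expert optimizer moment
    "master_weights": {"linear_0.w_0": paddle.zeros([4, 4])},
    "LR_Scheduler": {"last_epoch": 1},
}

getter = DummyGroupGetter()
per_group = split_opt_state(opt_state, getter)   # {group.id: partial opt state dict}
assert sorted(per_group.keys()) == [0, 1]

restored = merge_opt_state(per_group)            # round-trips back into one dict
assert set(restored["master_weights"]) == {"linear_0.w_0"}
assert restored["LR_Scheduler"] == {"last_epoch": 1}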

paddlenlp/trainer/utils/reshard/sharding_v1.py

Lines changed: 4 additions & 5 deletions
@@ -20,9 +20,8 @@
 from .common import is_sharding_opt
 
 
-def shard(node_model_state, model, optimizer, hcg):
-    group = hcg.get_sharding_parallel_group()
-    cur_rank = group.rank
+def shard(node_model_state, model, optimizer):
+    cur_rank = max(node_model_state.group.rank, 0)
     unwrapped_optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizer)
     if unwrapped_optimizer is not None:
         optimizer = unwrapped_optimizer

@@ -40,10 +39,10 @@ def filter_func(key):
         assert not is_sharding_opt(optimizer)
         filter_func = lambda key: True
 
-    node_model_state.reshard(group, filter_func)
+    node_model_state.reshard(filter_func)
     return node_model_state
 
 
-def restore(node_model_state, model, optimizer, hcg):
+def restore(node_model_state, model, optimizer):
    node_model_state.drop_rank()
    return node_model_state
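With the hcg argument gone from shard()/restore() (here and in sharding_v2.py), the communication group now travels with the NodeModelState itself. A hedged sketch of what a call site might look like after this change; the population of node_model_state and the model/optimizer wiring are elided, and reshard_sharding_v1 is a hypothetical wrapper, not trainer code:

import paddle.distributed.fleet as fleet

from paddlenlp.trainer.utils.reshard import NodeModelState, get_moe_sharding_group
from paddlenlp.trainer.utils.reshard import sharding_v1


def reshard_sharding_v1(model, optimizer, use_moe_sharding_group=False):
    # Pick the group once, at construction time; shard() now reads it from
    # node_model_state.group instead of re-deriving it from hcg.
    hcg = fleet.get_hybrid_communicate_group()
    group = get_moe_sharding_group(hcg) if use_moe_sharding_group else hcg.get_sharding_parallel_group()
    node_model_state = NodeModelState(group)
    # ... populate node_model_state with weights / optimizer state here ...
    # before this commit: sharding_v1.shard(node_model_state, model, optimizer, hcg)
    return sharding_v1.shard(node_model_state, model, optimizer)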

paddlenlp/trainer/utils/reshard/sharding_v2.py

Lines changed: 20 additions & 13 deletions
@@ -34,13 +34,16 @@
 
 from paddle.distributed.communication.reduce import ReduceOp
 
+from .common import get_moe_sharding_group
 
-def shard(node_model_state, model, optimizer, hcg):
+
+def shard(node_model_state, model, optimizer):
     assert DygraphShardingOptimizerV2 is not None
-    group = hcg.get_sharding_parallel_group()
-    cur_rank = group.rank
     split_infos = collect_split_info(optimizer, model)
 
+    group = node_model_state.group
+    cur_rank = max(group.rank, 0)
+
     def split_func(k, v):
         param_name = k[1]
         opt_name = k[-1]

@@ -87,15 +90,14 @@ def filter_func(k):
         return rank == cur_rank
 
     # reshard
-    node_model_state.reshard(group, filter_func)
+    node_model_state.reshard(filter_func)
     node_model_state.drop_rank()
     return node_model_state
 
 
-def restore(node_model_state, model, optimizer, hcg):
-    group = hcg.get_sharding_parallel_group()
+def restore(node_model_state, model, optimizer):
     # evenly distribute param
-    node_model_state.even_distribute(group)
+    node_model_state.even_distribute()
     param_shapes = {k: v.shape for (k, v) in model.state_dict().items()}
 
     def merge_func(k, v):

@@ -175,7 +177,7 @@ def gather_infos(comm_buffer):
     for comm_buffer in optimizer._comm_buffer_list:
         gather_infos(comm_buffer)
 
-    assert len(split_infos)
+    assert len(split_infos) > 0
     return split_infos

@@ -211,11 +213,16 @@ def get_matched_length(name):
     if need_allgather:
         if hcg is None:
             hcg = fleet.get_hybrid_communicate_group()
-        group = hcg.get_sharding_parallel_group()
-        if group is not None and group.nranks > 1:
-            x = paddle.to_tensor([is_matched], dtype=paddle.int32)
-            paddle.distributed.stream.all_reduce(x, op=ReduceOp.MIN, group=group, sync_op=True, use_calc_stream=True)
-            global_is_matched = int(x.numpy()[0])
+        sharding_group = hcg.get_sharding_parallel_group()
+        moe_sharding_group = get_moe_sharding_group(hcg)
+        for group in [sharding_group, moe_sharding_group]:
+            if group is not None and group.nranks > 1:
+                x = paddle.to_tensor([is_matched], dtype=paddle.int32)
+                paddle.distributed.stream.all_reduce(
+                    x, op=ReduceOp.MIN, group=group, sync_op=True, use_calc_stream=True
+                )
+                is_matched = int(x.numpy()[0])
+        global_is_matched = is_matched
     else:
         global_is_matched = is_matched
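In the last hunk, the MIN all-reduce is now chained over the sharding group and then the moe-sharding group; since is_matched is a 0/1 flag, MIN acts as a logical AND, so the chained reduces AND the flag across every rank reachable through either group. A pure-Python stand-in (no paddle.distributed calls, helper names illustrative only) to show the algebra:

def allreduce_min(flags):
    # stand-in for paddle.distributed.stream.all_reduce(..., op=ReduceOp.MIN)
    return min(flags)


def global_match(sharding_group_flags, moe_group_reduced_flags):
    # First reduce this rank's flag over its sharding group ...
    local = allreduce_min(sharding_group_flags)
    # ... then feed the reduced value into the moe-sharding group's reduce,
    # alongside the (already reduced) flags of the other ranks in that group.
    return allreduce_min([local] + moe_group_reduced_flags)


assert global_match([1, 1], [1]) == 1   # matched on every rank
assert global_match([1, 0], [1]) == 0   # mismatch inside the sharding group
assert global_match([1, 1], [0]) == 0   # mismatch seen through the moe-sharding group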
