Skip to content

[AutoParallel] init sync param #10783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions paddlenlp/transformers/gpt/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def __init__(
config.hidden_size,
)
self.word_embeddings.weight = dist.shard_tensor(
self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Replicate()]
self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
)
self.position_embeddings.weight = dist.shard_tensor(
self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
Expand Down Expand Up @@ -1176,7 +1176,7 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
shape=[config.vocab_size, config.hidden_size],
dtype=paddle.get_default_dtype(),
)
self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])

def forward(self, hidden_states, tensor_parallel_output=None):

Expand All @@ -1187,7 +1187,14 @@ def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output

y = dist.reshard(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
y = dist.reshard(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
# sync_group = paddle.distributed.new_group(ranks=[0,3])
# with paddle.no_grad():
# paddle.distributed.all_reduce(
# y._local_value(),
# op=paddle.distributed.ReduceOp.SUM,
# group=sync_group
# )
logits = paddle.matmul(hidden_states, y, transpose_y=self.transpose_y)
return logits

Expand Down
54 changes: 53 additions & 1 deletion paddlenlp/transformers/gpt/modeling_auto_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,38 @@ def manual_model_split(model, stage_idx, group, mode, pp_degree):

layer_lists = model.layers

shared_params_name = {
"key1": ["embedding_0.w_0.dist", "gptlm_head_auto_0.w_0.dist"],
# "key2": ["linear_0.w_0.dist", "linear_24.w_0.dist"]
}
shared_mp = {}
get_group_from_ranks = {}
cur_rank = paddle.distributed.get_rank()
for key, pair in shared_params_name.items():
assert len(pair) == 2
ori_name = pair[0]
sync_name = pair[1]
ori_param = _get_param_from_name(ori_name, model)
sync_param = _get_param_from_name(sync_name, model)
ori_process_ids = ori_param.process_mesh.process_ids
sync_process_ids = sync_param.process_mesh.process_ids
cur_group = _build_current_sync_commm_group(ori_process_ids, sync_process_ids, get_group_from_ranks)
cur_param = None
if cur_rank in ori_process_ids:
cur_param = ori_param
elif cur_rank in sync_process_ids:
cur_param = sync_param
if cur_param is not None and cur_group is not None:
shared_mp[key] = {
"param": cur_param,
"group": cur_group,
}

# for key, pair in shared_mp.items():
# print("xxx key: ", key)
# print("xxx param: ", pair["param"])
# print("xxx group: ", pair["group"])

def _build_stage(model, stage_idx, group):
new_model = None
if stage_idx == 0:
Expand All @@ -151,7 +183,7 @@ def _build_stage(model, stage_idx, group):
new_model = GPTChunk(
layer_lists[stage_idx * chunk_size : (stage_idx + 1) * chunk_size], is_first=False, is_last=False
)
stage = PipelineStage(new_model, stage_idx, chunk_num, group=group)
stage = PipelineStage(new_model, stage_idx, chunk_num, group=group, shared_map=shared_mp)
return stage

stages = []
Expand All @@ -161,6 +193,26 @@ def _build_stage(model, stage_idx, group):
return stages


def _get_param_from_name(param_name, model):
    """Return the first parameter of ``model`` whose ``name`` equals ``param_name``.

    Args:
        param_name: Name to compare against each parameter's ``name`` attribute.
        model: Object exposing a ``parameters()`` method that yields parameters.

    Returns:
        The first matching parameter in ``model.parameters()`` order, or
        ``None`` when no parameter has that name. Callers that dereference
        the result directly must handle the ``None`` case themselves.
    """
    matches = (candidate for candidate in model.parameters() if candidate.name == param_name)
    # The explicit default keeps the "not found" result as None.
    return next(matches, None)


def _build_current_sync_commm_group(ranks_1, ranks_2, get_group_from_ranks):
    """Pair up ranks position-by-position and return the current rank's group.

    For each index, the rank from ``ranks_1`` and the rank from ``ranks_2``
    form a sorted two-element key. A communication group is created for the
    key via ``dist.new_group`` unless one is already cached.

    Args:
        ranks_1: Sequence of ranks; must have the same length as ``ranks_2``.
        ranks_2: Sequence of ranks paired element-wise with ``ranks_1``.
        get_group_from_ranks: Dict cache mapping a sorted rank tuple to its
            group. It is mutated in place: newly created groups are added.

    Returns:
        The group of the last pair that contains the current rank, or
        ``None`` if the current rank appears in no pair.
    """
    local_rank = paddle.distributed.get_rank()
    assert len(ranks_1) == len(ranks_2)

    matched_group = None
    for first_rank, second_rank in zip(ranks_1, ranks_2):
        pair_key = tuple(sorted((first_rank, second_rank)))
        # Groups are created for every uncached pair, not only the pairs that
        # include this rank.
        # NOTE(review): presumably because group creation has to be executed
        # by all processes in the same order — confirm against Paddle docs.
        if pair_key not in get_group_from_ranks:
            get_group_from_ranks[pair_key] = dist.new_group(ranks=list(pair_key))
        if local_rank in pair_key:
            matched_group = get_group_from_ranks[pair_key]
    return matched_group


def get_gpt_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group):
assert mode in ["VPP", "1F1B", "FThenB"]
stages = manual_model_split(model, group.rank, group, mode, pp_degree)
Expand Down
Loading