Skip to content

Commit c2fd74d

Browse files
authored
Fix the param_count to sum across pipeline_parallel_comm (#151)
* Fix the param_count to sum across pipeline_parallel_comm
* Address review comments
1 parent 1e7dccf commit c2fd74d

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

training/dist_clm_train.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,13 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
9797
).to(device)
9898

9999
do_sync_before_save = (args.dp_mode in ['local'] and use_dp)
100-
100+
101+
# Get the number of model parameters for the model
102+
param_count = torch.zeros(1, dtype=torch.int64).to(device)
103+
local_param_count = sum(p.numel() for p in pipe.model.parameters())
104+
param_count.data[:] = local_param_count
105+
pp_comm.reduce(param_count, 0)
106+
101107
if get_pipeline_parallel_rank() == 0 and dp_rank == 0:
102108

103109
epoch_steps = int(train_data_loader.dataset.get_dataset_example_count() / args.batch_size)
@@ -109,14 +115,11 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
109115
# Get the number of tokens in the dataset
110116
token_count = train_data_loader.dataset.get_dataset_token_count()
111117

112-
# Get the number of model parameters
113-
param_count = sum(p.numel() for p in pipe.model.parameters())
114-
115118
# Report training start
116119
event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,
117120
message=f"Training started for model {args.model_name}",
118121
event_type=EventReporter.EVENT_TYPE_TRAINING_START,
119-
param_count=param_count,
122+
param_count=param_count.item(),
120123
token_count=token_count,
121124
requires_is_enabled=False)
122125

@@ -150,7 +153,7 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
150153

151154
if event_reporter is not None and pipe.global_step % epoch_steps == 0:
152155
event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,
153-
message=f"Epoch competed for step {pipe.global_step}",
156+
message=f"Epoch completed, at step {pipe.global_step}",
154157
event_type=EventReporter.EVENT_TYPE_EPOCH_COMPLETE,
155158
requires_is_enabled=False)
156159

@@ -163,11 +166,6 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
163166
if dp_rank == 0:
164167
save_checkpoint(pipe, args)
165168

166-
if event_reporter is not None:
167-
event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,
168-
message=f"checkpoint saved for step {pipe.global_step}",
169-
event_type=EventReporter.EVENT_TYPE_CHECKPOINT_SAVE,
170-
requires_is_enabled=False)
171169
if do_sync_before_save:
172170
pipe.dp_optim.rollback_parameters()
173171

0 commit comments

Comments (0)