@@ -97,7 +97,13 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
     ).to(device)
 
     do_sync_before_save = (args.dp_mode in ['local'] and use_dp)
-
+
+    # Get the total number of model parameters across all pipeline stages
+    param_count = torch.zeros(1, dtype=torch.int64).to(device)
+    local_param_count = sum(p.numel() for p in pipe.model.parameters())
+    param_count.data[:] = local_param_count
+    pp_comm.reduce(param_count, 0)
+
     if get_pipeline_parallel_rank() == 0 and dp_rank == 0:
 
         epoch_steps = int(train_data_loader.dataset.get_dataset_example_count() / args.batch_size)
@@ -109,14 +115,11 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
         # Get the number of tokens in the dataset
         token_count = train_data_loader.dataset.get_dataset_token_count()
 
-        # Get the number of model parameters
-        param_count = sum(p.numel() for p in pipe.model.parameters())
-
         # Report training start
         event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,
                               message=f"Training started for model {args.model_name}",
                               event_type=EventReporter.EVENT_TYPE_TRAINING_START,
-                              param_count=param_count,
+                              param_count=param_count.item(),
                               token_count=token_count,
                               requires_is_enabled=False)
 
@@ -150,7 +153,7 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
 
         if event_reporter is not None and pipe.global_step % epoch_steps == 0:
             event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,
-                                  message=f"Epoch competed for step {pipe.global_step}",
+                                  message=f"Epoch completed at step {pipe.global_step}",
                                   event_type=EventReporter.EVENT_TYPE_EPOCH_COMPLETE,
                                   requires_is_enabled=False)
 
@@ -163,11 +166,6 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
         if dp_rank == 0:
             save_checkpoint(pipe, args)
 
-            if event_reporter is not None:
-                event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,
-                                      message=f"checkpoint saved for step {pipe.global_step}",
-                                      event_type=EventReporter.EVENT_TYPE_CHECKPOINT_SAVE,
-                                      requires_is_enabled=False)
         if do_sync_before_save:
             pipe.dp_optim.rollback_parameters()
 
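For context on the first hunk: under pipeline parallelism each stage holds only its own shard of the model, so a local `sum(p.numel() ...)` undercounts the full model. The new code therefore SUM-reduces the per-rank counts onto rank 0 before reporting, and the report call switches to `param_count.item()` to turn the one-element result tensor back into a plain Python int. Below is a minimal standalone sketch of the same pattern written against plain `torch.distributed` rather than this repo's `pp_comm` wrapper; the helper name and setup are illustrative assumptions, not the repo's API.

import torch
import torch.distributed as dist

def global_param_count(model: torch.nn.Module, device: torch.device) -> int:
    """SUM-reduce per-rank parameter counts onto rank 0.

    Illustrative sketch: assumes dist.init_process_group(...) has
    already been called by the training launcher.
    """
    # Count only the parameters this pipeline stage actually holds.
    local_count = sum(p.numel() for p in model.parameters())
    count = torch.tensor([local_count], dtype=torch.int64, device=device)
    # Collective call: every rank in the group must participate;
    # only dst=0 ends up holding the summed total.
    dist.reduce(count, dst=0, op=dist.ReduceOp.SUM)
    return count.item()  # meaningful on rank 0 only

The placement in the diff mirrors the collective's requirement: `pp_comm.reduce(param_count, 0)` runs on every rank, before the `get_pipeline_parallel_rank() == 0` branch, so no rank blocks waiting on a reduce that the other ranks never enter.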