@@ -73,7 +73,7 @@ def _lm_pred_func(x, y):
 
 
 
-def train_loop(args, pipe, device, train_data_loader, test_data_loader):
+def train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch):
 
     print('training starts......')
 
@@ -108,6 +108,8 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
 
     if get_pipeline_parallel_rank() == 0 and dp_rank == 0:
 
+        print(f"Training steps: total_steps={args.total_steps}, steps_per_epoch={steps_per_epoch}, steps_per_checkpoint={args.checkpoint_steps}")
+
         upload_checkpoints_enabled = args.checkpoint_upload_prefix is not None
         upload_manager = UploadManager(aws_endpoint_url=args.aws_endpoint_url,
                                        aws_access_key_id=args.aws_access_key_id,
@@ -117,10 +119,6 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
                                        event_reporter=event_reporter,
                                        n_stages=args.pipeline_group_size)
 
-        epoch_steps = int(train_data_loader.dataset.get_dataset_example_count() / args.batch_size)
-        if epoch_steps < 1:
-            epoch_steps = 1
-
         if event_reporter is not None:
 
             # Get the number of tokens in the dataset
@@ -162,7 +160,7 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
             labels = input_ids.clone()
             current_iter_time = pipe.sgd_iter(input_ids, labels, loss_func=gpt_loss_func)
 
-            if event_reporter is not None and pipe.global_step % epoch_steps == 0:
+            if event_reporter is not None and (pipe.global_step >= args.total_steps or pipe.global_step % steps_per_epoch == 0):
                 event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,
                                       message=f"Epoch completed, at step {pipe.global_step}",
                                       event_type=EventReporter.EVENT_TYPE_EPOCH_COMPLETE,
@@ -261,14 +259,12 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
         if do_sync_before_save:
             pipe.dp_optim.rollback_parameters()
 
+# Compute the total number of training steps, steps per epoch, and steps per
+# checkpoint.
 def calculate_training_steps(args, train_data_loader) -> int:
-    if args.total_steps is None and args.nepochs is None:
-        return len(train_data_loader)
-
-    if args.total_steps is not None:
-        if args.nepochs is not None:
-            print("WARNING: total_steps ({args.toal_steps}) supercedes nepochs ({args.nepochs}).")
-        return args.total_steps
+    total_steps = 0
+    steps_per_epoch = 0
+    steps_per_checkpoint = 0
 
     token_count = train_data_loader.dataset.get_dataset_token_count()
 
@@ -277,12 +273,54 @@ def calculate_training_steps(args, train_data_loader) -> int:
         print("Missing required arguments for calculating total steps based on epochs.")
         sys.exit(1)
 
-    global_batch_size = int(args.batch_size * args.world_size / args.pipeline_group_size)
-    total_steps = int((args.nepochs * token_count) / (global_batch_size * args.seq_length))
+    global_batch_size = (args.batch_size * args.world_size + args.pipeline_group_size - 1) // args.pipeline_group_size
+    tokens_per_batch = global_batch_size * args.seq_length
+    steps_per_epoch = (token_count + tokens_per_batch - 1) // tokens_per_batch
+
+    if args.total_steps is not None:
+        if args.nepochs is not None:
+            print(f"WARNING: total_steps ({args.total_steps}) supersedes nepochs ({args.nepochs}).")
+        total_steps = args.total_steps
+    elif args.nepochs is not None:
+        total_steps = steps_per_epoch * args.nepochs
+    else:
+        total_steps = len(train_data_loader)
+
+    # Set the minimum number of total steps
     if total_steps < 10:
         total_steps = 10
 
-    return total_steps
+    # Ensure that the steps per epoch are consistent with total steps.
+    # Note: This does not strictly follow the definition of an epoch. It just
+    # approximately distributes the reporting of epochs over the total number of
+    # steps.
+    if args.nepochs is not None:
+        steps_per_epoch = (total_steps + args.nepochs - 1) // args.nepochs
+
+    # Clamp steps_per_epoch to [1, total_steps]
+    if steps_per_epoch > total_steps:
+        steps_per_epoch = total_steps
+    if steps_per_epoch < 1:
+        steps_per_epoch = 1
+
+    # Set the number of steps per checkpoint based on user input.
+    if args.checkpoint_steps is not None and args.checkpoint_steps > 0:
+        steps_per_checkpoint = args.checkpoint_steps
+    elif args.num_checkpoints is not None and args.num_checkpoints > 0:
+        steps_per_checkpoint = (total_steps + args.num_checkpoints - 1) // args.num_checkpoints
+    else:
+        steps_per_checkpoint = total_steps
+
+    # Clamp steps_per_checkpoint to [1, total_steps]
+    if steps_per_checkpoint > total_steps:
+        steps_per_checkpoint = total_steps
+    if steps_per_checkpoint < 1:
+        steps_per_checkpoint = 1
+
+    # Set the args based on what we computed above
+    args.total_steps = total_steps
+    args.checkpoint_steps = steps_per_checkpoint
+    return steps_per_epoch
 
 
 def main():
     parser = argparse.ArgumentParser(description='Gpipe-GPT')
@@ -393,13 +431,7 @@ def main():
         test_data_loader = None
 
     # calculate total steps
-    args.total_steps = calculate_training_steps(args, train_data_loader)
-    if args.checkpoint_steps == 0 and args.num_checkpoints > 0:
-        args.checkpoint_steps = int(args.total_steps / args.num_checkpoints)
-        if args.checkpoint_steps < 1:
-            args.checkpoint_steps = 1
-    print("Total steps:", args.total_steps)
-    print("Checkpoint steps:", args.checkpoint_steps)
+    steps_per_epoch = calculate_training_steps(args, train_data_loader)
 
     use_dp = (args.world_size != args.pipeline_group_size)
     if use_dp:
@@ -416,7 +448,7 @@ def main():
         pipe.optimizer.reload_model_params()
 
     if args.profiling == 'no-profiling':
-        train_loop(args, pipe, device, train_data_loader, test_data_loader)
+        train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch)
     else:
         prefix = './trace_json/gpt3_' + args.pp_mode
         if use_dp:
@@ -426,14 +458,14 @@ def main():
             args.profiling + '_' + args.trace_postfix + '.json'
         if args.profiling == 'tidy_profiling':
             try:
-                train_loop(args, pipe, device, train_data_loader, test_data_loader)
+                train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch)
             except Exception as e:
                 raise e
                 print(get_pipeline_parallel_rank(), e)
             pipe.export_profiling_result(filename=trace_file)
         elif args.profiling == 'pytorch_profiling':
             with profiler.profile(profile_memory=True, use_cuda=args.use_cuda) as prof:
-                train_loop(args, pipe, device, train_data_loader, test_data_loader)
+                train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch)
             print(prof.key_averages().table())
             prof.export_chrome_trace(trace_file)
         else:
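
For readers tracing the new arithmetic, here is a minimal standalone sketch of the step accounting that calculate_training_steps now performs, following the args.total_steps-is-None path with nepochs set. The SimpleNamespace values and the 10M-token dataset size are hypothetical numbers chosen only to make the ceiling divisions and clamping concrete; the real code reads them from args and train_data_loader.

    # Standalone sketch of the step accounting in this commit (illustrative values only).
    from types import SimpleNamespace

    def ceil_div(a, b):
        # Ceiling division, the same (x + y - 1) // y pattern used in the patch.
        return (a + b - 1) // b

    # Hypothetical run: 8-GPU world, 2-stage pipeline (4-way data parallel),
    # micro-batch size 32, sequence length 2048, 3 epochs requested.
    args = SimpleNamespace(batch_size=32, world_size=8, pipeline_group_size=2,
                           seq_length=2048, nepochs=3, total_steps=None,
                           checkpoint_steps=None, num_checkpoints=4)
    token_count = 10_000_000  # assumed dataset size in tokens

    global_batch_size = ceil_div(args.batch_size * args.world_size, args.pipeline_group_size)  # 128
    tokens_per_batch = global_batch_size * args.seq_length                                      # 262144
    steps_per_epoch = ceil_div(token_count, tokens_per_batch)                                   # 39

    total_steps = max(steps_per_epoch * args.nepochs, 10)                                       # 117
    steps_per_epoch = min(max(ceil_div(total_steps, args.nepochs), 1), total_steps)             # 39

    # checkpoint_steps is unset here, so the num_checkpoints branch applies.
    steps_per_checkpoint = min(max(ceil_div(total_steps, args.num_checkpoints), 1), total_steps)  # 30

    print(total_steps, steps_per_epoch, steps_per_checkpoint)  # 117 39 30

Because every quantity is produced by ceiling division and then clamped to [1, total_steps], steps_per_epoch and steps_per_checkpoint can never be zero, which is what lets the modulo checks in train_loop run safely.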