
Commit 8b132d3

Fix compute step calculations (#156)
1 parent 93367c0 commit 8b132d3

File tree

1 file changed: +58 -26 lines changed


training/dist_clm_train.py

+58 -26
@@ -73,7 +73,7 @@ def _lm_pred_func(x, y):
 
 
 
-def train_loop(args, pipe, device, train_data_loader, test_data_loader):
+def train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch):
 
     print('training starts......')
 
@@ -108,6 +108,8 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
 
     if get_pipeline_parallel_rank() == 0 and dp_rank == 0:
 
+        print(f"Training steps: total_steps={args.total_steps}, steps_per_epoch={steps_per_epoch}, steps_per_checkpoint={args.checkpoint_steps}")
+
         upload_checkpoints_enabled = args.checkpoint_upload_prefix is not None
         upload_manager = UploadManager(aws_endpoint_url = args.aws_endpoint_url,
                                        aws_access_key_id = args.aws_access_key_id,
@@ -117,10 +119,6 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
                                        event_reporter = event_reporter,
                                        n_stages = args.pipeline_group_size)
 
-    epoch_steps = int(train_data_loader.dataset.get_dataset_example_count() / args.batch_size)
-    if epoch_steps < 1:
-        epoch_steps = 1
-
     if event_reporter is not None:
 
         # Get the number of tokens in the dataset
@@ -162,7 +160,7 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
         labels = input_ids.clone()
         current_iter_time = pipe.sgd_iter(input_ids, labels, loss_func=gpt_loss_func)
 
-        if event_reporter is not None and pipe.global_step % epoch_steps == 0:
+        if event_reporter is not None and (pipe.global_step >= args.total_steps or pipe.global_step % steps_per_epoch == 0):
             event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,
                                   message=f"Epoch completed, at step {pipe.global_step}",
                                   event_type=EventReporter.EVENT_TYPE_EPOCH_COMPLETE,
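
The added end-of-training guard matters because a run whose total step count is not a multiple of the epoch length would otherwise never report its final, partial epoch. A minimal sketch of the boundary behavior, using hypothetical numbers rather than values from this commit:

    # With steps_per_epoch = 8 and total_steps = 20, the last epoch is partial.
    total_steps, steps_per_epoch = 20, 8

    reported_old = [s for s in range(1, total_steps + 1)
                    if s % steps_per_epoch == 0]
    reported_new = [s for s in range(1, total_steps + 1)
                    if s >= total_steps or s % steps_per_epoch == 0]

    print(reported_old)  # [8, 16]     -- the old modulo-only check misses step 20
    print(reported_new)  # [8, 16, 20] -- the final step now always reports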
@@ -261,14 +259,12 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
         if do_sync_before_save:
             pipe.dp_optim.rollback_parameters()
 
+# Compute the total number of training steps, steps per epoch, and steps per
+# checkpoint.
 def calculate_training_steps(args, train_data_loader) -> int:
-    if args.total_steps is None and args.nepochs is None:
-        return len(train_data_loader)
-
-    if args.total_steps is not None:
-        if args.nepochs is not None:
-            print("WARNING: total_steps ({args.toal_steps}) supercedes nepochs ({args.nepochs}).")
-        return args.total_steps
+    total_steps = 0
+    steps_per_epoch = 0
+    steps_per_checkpoint = 0
 
     token_count = train_data_loader.dataset.get_dataset_token_count()
 
@@ -277,12 +273,54 @@ def calculate_training_steps(args, train_data_loader) -> int:
         print("Missing required arguments for calculating total steps based on epochs.")
         sys.exit(1)
 
-    global_batch_size = int(args.batch_size * args.world_size / args.pipeline_group_size)
-    total_steps = int((args.nepochs * token_count) / (global_batch_size * args.seq_length))
+    global_batch_size = (args.batch_size * args.world_size + args.pipeline_group_size - 1) // args.pipeline_group_size
+    tokens_per_batch = global_batch_size * args.seq_length
+    steps_per_epoch = (token_count + tokens_per_batch - 1) // tokens_per_batch
+
+    if args.total_steps is not None:
+        if args.nepochs is not None:
+            print(f"WARNING: total_steps ({args.total_steps}) supersedes nepochs ({args.nepochs}).")
+        total_steps = args.total_steps
+    elif args.nepochs is not None:
+        total_steps = steps_per_epoch * args.nepochs
+    else:
+        total_steps = len(train_data_loader)
+
+    # Set the minimum number of total steps
     if total_steps < 10:
         total_steps = 10
 
-    return total_steps
+    # Ensure that the steps per epoch are consistent with total steps.
+    # Note: this does not strictly follow the definition of an epoch; it just
+    # approximately distributes the reporting of epochs over the total number
+    # of steps.
+    if args.nepochs is not None:
+        steps_per_epoch = (total_steps + args.nepochs - 1) // args.nepochs
+
+    # Clamp steps_per_epoch to [1, total_steps]
+    if steps_per_epoch > total_steps:
+        steps_per_epoch = total_steps
+    if steps_per_epoch < 1:
+        steps_per_epoch = 1
+
+    # Set the number of steps per checkpoint based on user input.
+    if args.checkpoint_steps is not None and args.checkpoint_steps > 0:
+        steps_per_checkpoint = args.checkpoint_steps
+    elif args.num_checkpoints is not None and args.num_checkpoints > 0:
+        steps_per_checkpoint = (total_steps + args.num_checkpoints - 1) // args.num_checkpoints
+    else:
+        steps_per_checkpoint = total_steps
+
+    # Clamp steps_per_checkpoint to [1, total_steps]
+    if steps_per_checkpoint > total_steps:
+        steps_per_checkpoint = total_steps
+    if steps_per_checkpoint < 1:
+        steps_per_checkpoint = 1
+
+    # Set the args based on what we computed above.
+    args.total_steps = total_steps
+    args.checkpoint_steps = steps_per_checkpoint
+    return steps_per_epoch
 
 def main():
     parser = argparse.ArgumentParser(description='Gpipe-GPT')
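
The arithmetic change in calculate_training_steps is easiest to see with concrete numbers. A worked sketch (all values hypothetical, chosen to expose the floor-versus-ceiling difference):

    batch_size, world_size, pipeline_group_size = 32, 8, 4
    seq_length, token_count, nepochs = 2048, 1_000_000, 3

    # Ceiling division; differs from the old int(...) only when world_size is
    # not a multiple of pipeline_group_size. Here both give 64.
    global_batch_size = (batch_size * world_size + pipeline_group_size - 1) // pipeline_group_size
    tokens_per_batch = global_batch_size * seq_length   # 64 * 2048 = 131072

    # Old: one floor division across all epochs at once.
    old_total = int((nepochs * token_count) / tokens_per_batch)                 # int(22.89) = 22

    # New: round each epoch up, then multiply, so the partial final batch of
    # every epoch is counted.
    steps_per_epoch = (token_count + tokens_per_batch - 1) // tokens_per_batch  # ceil(7.63) = 8
    new_total = steps_per_epoch * nepochs                                       # 24

    # Deriving a checkpoint interval from num_checkpoints uses the same trick.
    num_checkpoints = 5
    steps_per_checkpoint = (new_total + num_checkpoints - 1) // num_checkpoints  # ceil(24/5) = 5

    print(old_total, new_total, steps_per_checkpoint)  # 22 24 5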
@@ -393,13 +431,7 @@ def main():
         test_data_loader = None
 
     # calculate total steps
-    args.total_steps = calculate_training_steps(args, train_data_loader)
-    if args.checkpoint_steps == 0 and args.num_checkpoints > 0:
-        args.checkpoint_steps = int(args.total_steps / args.num_checkpoints)
-        if args.checkpoint_steps < 1:
-            args.checkpoint_steps = 1
-    print("Total steps:", args.total_steps)
-    print("Checkpoint steps:", args.checkpoint_steps)
+    steps_per_epoch = calculate_training_steps(args, train_data_loader)
 
     use_dp = (args.world_size != args.pipeline_group_size)
     if use_dp:
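
After this hunk, calculate_training_steps carries the whole contract: it returns steps_per_epoch and fills in args.total_steps and args.checkpoint_steps as side effects, which is why the inline checkpoint arithmetic disappears from main(). A self-contained sketch of that contract; the stub and every value in it are hypothetical stand-ins, not the real helper:

    from types import SimpleNamespace

    args = SimpleNamespace(batch_size=32, world_size=8, pipeline_group_size=4,
                           seq_length=2048, nepochs=3, total_steps=None,
                           checkpoint_steps=0, num_checkpoints=5)

    def calculate_training_steps_stub(args, token_count):
        # Condensed stand-in: the real helper reads the token count from
        # train_data_loader.dataset.get_dataset_token_count().
        gbs = (args.batch_size * args.world_size
               + args.pipeline_group_size - 1) // args.pipeline_group_size
        tokens_per_batch = gbs * args.seq_length
        steps_per_epoch = (token_count + tokens_per_batch - 1) // tokens_per_batch
        args.total_steps = max(10, steps_per_epoch * args.nepochs)
        args.checkpoint_steps = (args.total_steps
                                 + args.num_checkpoints - 1) // args.num_checkpoints
        return steps_per_epoch

    steps_per_epoch = calculate_training_steps_stub(args, token_count=1_000_000)
    print(steps_per_epoch, args.total_steps, args.checkpoint_steps)  # 8 24 5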
@@ -416,7 +448,7 @@ def main():
     pipe.optimizer.reload_model_params()
 
     if args.profiling == 'no-profiling':
-        train_loop(args, pipe, device, train_data_loader, test_data_loader)
+        train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch)
     else:
         prefix = './trace_json/gpt3_' + args.pp_mode
@@ -426,14 +458,14 @@ def main():
                      args.profiling + '_' + args.trace_postfix + '.json'
         if args.profiling == 'tidy_profiling':
             try:
-                train_loop(args, pipe, device, train_data_loader, test_data_loader)
+                train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch)
             except Exception as e:
                 raise e
                 print(get_pipeline_parallel_rank(), e)
             pipe.export_profiling_result(filename=trace_file)
         elif args.profiling == 'pytorch_profiling':
             with profiler.profile(profile_memory=True, use_cuda=args.use_cuda) as prof:
-                train_loop(args, pipe, device, train_data_loader, test_data_loader)
+                train_loop(args, pipe, device, train_data_loader, test_data_loader, steps_per_epoch)
             print(prof.key_averages().table())
             prof.export_chrome_trace(trace_file)
         else:
