@@ -161,12 +161,6 @@ def parse_args(input_args=None):
161161 help = "Total number of training steps to perform. If provided, overrides num_train_epochs." ,
162162 )
163163 parser .add_argument ("--save_steps" , type = int , default = 500 , help = "Save checkpoint every X updates steps." )
164- parser .add_argument (
165- "--gradient_accumulation_steps" ,
166- type = int ,
167- default = 1 ,
168- help = "Number of updates steps to accumulate before performing a backward/update pass." ,
169- )
170164 parser .add_argument (
171165 "--gradient_checkpointing" ,
172166 action = "store_true" ,
@@ -376,10 +370,8 @@ def main(args):
376370 else :
377371 colossalai .launch_from_torch (config = {}, seed = args .seed )
378372
379- colossalai .launch_from_torch (config = {})
380-
381- if args .seed is not None :
382- gpc .set_seed (args .seed )
373+ local_rank = gpc .get_local_rank (ParallelMode .DATA )
374+ world_size = gpc .get_world_size (ParallelMode .DATA )
383375
384376 if args .with_prior_preservation :
385377 class_images_dir = Path (args .class_data_dir )
@@ -408,7 +400,7 @@ def main(args):
408400 for example in tqdm (
409401 sample_dataloader ,
410402 desc = "Generating class images" ,
411- disable = not gpc . get_local_rank ( ParallelMode . DATA ) == 0 ,
403+ disable = not local_rank == 0 ,
412404 ):
413405 images = pipeline (example ["prompt" ]).images
414406
@@ -420,7 +412,7 @@ def main(args):
420412 del pipeline
421413
422414 # Handle the repository creation
423- if gpc . get_local_rank ( ParallelMode . DATA ) == 0 :
415+ if local_rank == 0 :
424416 if args .push_to_hub :
425417 if args .hub_model_id is None :
426418 repo_name = get_full_repo_name (Path (args .output_dir ).name , token = args .hub_token )
@@ -486,12 +478,7 @@ def main(args):
486478 unet .enable_gradient_checkpointing ()
487479
488480 if args .scale_lr :
489- args .learning_rate = (
490- args .learning_rate
491- * args .gradient_accumulation_steps
492- * args .train_batch_size
493- * gpc .get_world_size (ParallelMode .DATA )
494- )
481+ args .learning_rate = args .learning_rate * args .train_batch_size * world_size
495482
496483 unet = gemini_zero_dpp (unet , args .placement )
497484
@@ -547,16 +534,16 @@ def collate_fn(examples):
547534
548535 # Scheduler and math around the number of training steps.
549536 overrode_max_train_steps = False
550- num_update_steps_per_epoch = math .ceil (len (train_dataloader ) / args . gradient_accumulation_steps )
537+ num_update_steps_per_epoch = math .ceil (len (train_dataloader ))
551538 if args .max_train_steps is None :
552539 args .max_train_steps = args .num_train_epochs * num_update_steps_per_epoch
553540 overrode_max_train_steps = True
554541
555542 lr_scheduler = get_scheduler (
556543 args .lr_scheduler ,
557544 optimizer = optimizer ,
558- num_warmup_steps = args .lr_warmup_steps * args . gradient_accumulation_steps ,
559- num_training_steps = args .max_train_steps * args . gradient_accumulation_steps ,
545+ num_warmup_steps = args .lr_warmup_steps ,
546+ num_training_steps = args .max_train_steps ,
560547 )
561548 weight_dtype = torch .float32
562549 if args .mixed_precision == "fp16" :
@@ -571,26 +558,25 @@ def collate_fn(examples):
571558 text_encoder .to (get_current_device (), dtype = weight_dtype )
572559
573560 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
574- num_update_steps_per_epoch = math .ceil (len (train_dataloader ) / args . gradient_accumulation_steps )
561+ num_update_steps_per_epoch = math .ceil (len (train_dataloader ))
575562 if overrode_max_train_steps :
576563 args .max_train_steps = args .num_train_epochs * num_update_steps_per_epoch
577564 # Afterwards we recalculate our number of training epochs
578565 args .num_train_epochs = math .ceil (args .max_train_steps / num_update_steps_per_epoch )
579566
580567 # Train!
581- total_batch_size = args .train_batch_size * gpc . get_world_size ( ParallelMode . DATA ) * args . gradient_accumulation_steps
568+ total_batch_size = args .train_batch_size * world_size
582569
583570 logger .info ("***** Running training *****" , ranks = [0 ])
584571 logger .info (f" Num examples = { len (train_dataset )} " , ranks = [0 ])
585572 logger .info (f" Num batches each epoch = { len (train_dataloader )} " , ranks = [0 ])
586573 logger .info (f" Num Epochs = { args .num_train_epochs } " , ranks = [0 ])
587574 logger .info (f" Instantaneous batch size per device = { args .train_batch_size } " , ranks = [0 ])
588575 logger .info (f" Total train batch size (w. parallel, distributed & accumulation) = { total_batch_size } " , ranks = [0 ])
589- logger .info (f" Gradient Accumulation steps = { args .gradient_accumulation_steps } " , ranks = [0 ])
590576 logger .info (f" Total optimization steps = { args .max_train_steps } " , ranks = [0 ])
591577
592578 # Only show the progress bar once on each machine.
593- progress_bar = tqdm (range (args .max_train_steps ), disable = not gpc . get_local_rank ( ParallelMode . DATA ) == 0 )
579+ progress_bar = tqdm (range (args .max_train_steps ), disable = not local_rank == 0 )
594580 progress_bar .set_description ("Steps" )
595581 global_step = 0
596582
@@ -607,7 +593,7 @@ def collate_fn(examples):
607593 optimizer .zero_grad ()
608594
609595 latents = vae .encode (batch ["pixel_values" ].to (dtype = weight_dtype )).latent_dist .sample ()
610- latents = latents * vae . config . scaling_factor
596+ latents = latents * 0.18215
611597
612598 # Sample noise that we'll add to the latents
613599 noise = torch .randn_like (latents )
@@ -667,7 +653,7 @@ def collate_fn(examples):
667653 if global_step % args .save_steps == 0 :
668654 torch .cuda .synchronize ()
669655 torch_unet = get_static_torch_model (unet )
670- if gpc . get_local_rank ( ParallelMode . DATA ) == 0 :
656+ if local_rank == 0 :
671657 pipeline = DiffusionPipeline .from_pretrained (
672658 args .pretrained_model_name_or_path ,
673659 unet = torch_unet ,
@@ -682,7 +668,7 @@ def collate_fn(examples):
682668 torch .cuda .synchronize ()
683669 unet = get_static_torch_model (unet )
684670
685- if gpc . get_local_rank ( ParallelMode . DATA ) == 0 :
671+ if local_rank == 0 :
686672 pipeline = DiffusionPipeline .from_pretrained (
687673 args .pretrained_model_name_or_path ,
688674 unet = unet ,
0 commit comments