Skip to content

Commit ffed242

Browse files
authored
Fix distributed initialization being run twice (huggingface#2252)
Fix the ColossalAI DreamBooth example
1 parent 8178c84 commit ffed242

File tree

1 file changed

+14
-28
lines changed

1 file changed

+14
-28
lines changed

examples/research_projects/colossalai/train_dreambooth_colossalai.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,6 @@ def parse_args(input_args=None):
161161
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
162162
)
163163
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
164-
parser.add_argument(
165-
"--gradient_accumulation_steps",
166-
type=int,
167-
default=1,
168-
help="Number of updates steps to accumulate before performing a backward/update pass.",
169-
)
170164
parser.add_argument(
171165
"--gradient_checkpointing",
172166
action="store_true",
@@ -376,10 +370,8 @@ def main(args):
376370
else:
377371
colossalai.launch_from_torch(config={}, seed=args.seed)
378372

379-
colossalai.launch_from_torch(config={})
380-
381-
if args.seed is not None:
382-
gpc.set_seed(args.seed)
373+
local_rank = gpc.get_local_rank(ParallelMode.DATA)
374+
world_size = gpc.get_world_size(ParallelMode.DATA)
383375

384376
if args.with_prior_preservation:
385377
class_images_dir = Path(args.class_data_dir)
@@ -408,7 +400,7 @@ def main(args):
408400
for example in tqdm(
409401
sample_dataloader,
410402
desc="Generating class images",
411-
disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
403+
disable=not local_rank == 0,
412404
):
413405
images = pipeline(example["prompt"]).images
414406

@@ -420,7 +412,7 @@ def main(args):
420412
del pipeline
421413

422414
# Handle the repository creation
423-
if gpc.get_local_rank(ParallelMode.DATA) == 0:
415+
if local_rank == 0:
424416
if args.push_to_hub:
425417
if args.hub_model_id is None:
426418
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
@@ -486,12 +478,7 @@ def main(args):
486478
unet.enable_gradient_checkpointing()
487479

488480
if args.scale_lr:
489-
args.learning_rate = (
490-
args.learning_rate
491-
* args.gradient_accumulation_steps
492-
* args.train_batch_size
493-
* gpc.get_world_size(ParallelMode.DATA)
494-
)
481+
args.learning_rate = args.learning_rate * args.train_batch_size * world_size
495482

496483
unet = gemini_zero_dpp(unet, args.placement)
497484

@@ -547,16 +534,16 @@ def collate_fn(examples):
547534

548535
# Scheduler and math around the number of training steps.
549536
overrode_max_train_steps = False
550-
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
537+
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
551538
if args.max_train_steps is None:
552539
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
553540
overrode_max_train_steps = True
554541

555542
lr_scheduler = get_scheduler(
556543
args.lr_scheduler,
557544
optimizer=optimizer,
558-
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
559-
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
545+
num_warmup_steps=args.lr_warmup_steps,
546+
num_training_steps=args.max_train_steps,
560547
)
561548
weight_dtype = torch.float32
562549
if args.mixed_precision == "fp16":
@@ -571,26 +558,25 @@ def collate_fn(examples):
571558
text_encoder.to(get_current_device(), dtype=weight_dtype)
572559

573560
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
574-
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
561+
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
575562
if overrode_max_train_steps:
576563
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
577564
# Afterwards we recalculate our number of training epochs
578565
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
579566

580567
# Train!
581-
total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
568+
total_batch_size = args.train_batch_size * world_size
582569

583570
logger.info("***** Running training *****", ranks=[0])
584571
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
585572
logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0])
586573
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
587574
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
588575
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
589-
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
590576
logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
591577

592578
# Only show the progress bar once on each machine.
593-
progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
579+
progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
594580
progress_bar.set_description("Steps")
595581
global_step = 0
596582

@@ -607,7 +593,7 @@ def collate_fn(examples):
607593
optimizer.zero_grad()
608594

609595
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
610-
latents = latents * vae.config.scaling_factor
596+
latents = latents * 0.18215
611597

612598
# Sample noise that we'll add to the latents
613599
noise = torch.randn_like(latents)
@@ -667,7 +653,7 @@ def collate_fn(examples):
667653
if global_step % args.save_steps == 0:
668654
torch.cuda.synchronize()
669655
torch_unet = get_static_torch_model(unet)
670-
if gpc.get_local_rank(ParallelMode.DATA) == 0:
656+
if local_rank == 0:
671657
pipeline = DiffusionPipeline.from_pretrained(
672658
args.pretrained_model_name_or_path,
673659
unet=torch_unet,
@@ -682,7 +668,7 @@ def collate_fn(examples):
682668
torch.cuda.synchronize()
683669
unet = get_static_torch_model(unet)
684670

685-
if gpc.get_local_rank(ParallelMode.DATA) == 0:
671+
if local_rank == 0:
686672
pipeline = DiffusionPipeline.from_pretrained(
687673
args.pretrained_model_name_or_path,
688674
unet=unet,

0 commit comments

Comments (0)