
Commit d56825e

fix: how print training resume logs. (huggingface#5117)

* fix: how print training resume logs.
* propagate changes to text-to-image scripts.
* propagate changes to instructpix2pix.
* propagate changes to dreambooth
* propagate changes to custom diffusion and instructpix2pix
* propagate changes to kandinsky
* propagate changes to textual inv.
* debug
* fix: checkpointing.
* debug
* debug
* debug
* back to the square
* debug
* debug
* change condition order.
* debug
* debug
* debug
* debug
* revert to original
* clean

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent cd1b8d7 commit d56825e

14 files changed: +180 −164 lines changed
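Every diff below applies the same fix: rather than building the bar over range(global_step, args.max_train_steps) and fast-forwarding it by hand inside a batch-skipping loop, the scripts now pass the resumed optimizer step to tqdm through its initial= argument. A minimal sketch of the new pattern, with hypothetical numbers standing in for the parsed checkpoint state:

    from tqdm.auto import tqdm

    max_train_steps = 1000       # hypothetical stand-in for args.max_train_steps
    initial_global_step = 400    # e.g. parsed from a "checkpoint-400" directory

    # The bar is created once and immediately renders as 400/1000; no manual
    # update() calls are needed to catch it up to the resumed position.
    progress_bar = tqdm(range(0, max_train_steps), initial=initial_global_step, desc="Steps")

    global_step = initial_global_step
    while global_step < max_train_steps:
        # ... one optimizer step would happen here in the real scripts ...
        progress_bar.update(1)
        global_step += 1
    progress_bar.close()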

examples/custom_diffusion/train_custom_diffusion.py

Lines changed: 12 additions & 12 deletions
    @@ -1075,30 +1075,30 @@ def main(args):
                     f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                 )
                 args.resume_from_checkpoint = None
    +            initial_global_step = 0
             else:
                 accelerator.print(f"Resuming from checkpoint {path}")
                 accelerator.load_state(os.path.join(args.output_dir, path))
                 global_step = int(path.split("-")[1])

    -            resume_global_step = global_step * args.gradient_accumulation_steps
    +            initial_global_step = global_step
                 first_epoch = global_step // num_update_steps_per_epoch
    -            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
    -
    -    # Only show the progress bar once on each machine.
    -    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    -    progress_bar.set_description("Steps")
    +    else:
    +        initial_global_step = 0
    +
    +    progress_bar = tqdm(
    +        range(0, args.max_train_steps),
    +        initial=initial_global_step,
    +        desc="Steps",
    +        # Only show the progress bar once on each machine.
    +        disable=not accelerator.is_local_main_process,
    +    )

         for epoch in range(first_epoch, args.num_train_epochs):
             unet.train()
             if args.modifier_token is not None:
                 text_encoder.train()
             for step, batch in enumerate(train_dataloader):
    -            # Skip steps until we reach the resumed step
    -            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
    -                if step % args.gradient_accumulation_steps == 0:
    -                    progress_bar.update(1)
    -                continue
    -
                 with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
                     # Convert images to latent space
                     latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
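The resumed step is recovered purely from the checkpoint directory name, as the hunk above shows. A worked example with a hypothetical checkpoint:

    # Hypothetical checkpoint directory, as written by the training scripts.
    path = "checkpoint-1500"
    num_update_steps_per_epoch = 400   # hypothetical optimizer steps per epoch

    global_step = int(path.split("-")[1])                     # -> 1500
    initial_global_step = global_step                         # bar starts at 1500
    first_epoch = global_step // num_update_steps_per_epoch   # -> 3

Note that with the skip loop gone, the first resumed epoch starts again from the top of the dataloader; only the step counter and the progress display carry over from the checkpoint.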

examples/dreambooth/train_dreambooth.py

Lines changed: 12 additions & 12 deletions
    @@ -1178,30 +1178,30 @@ def compute_text_embeddings(prompt):
                     f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                 )
                 args.resume_from_checkpoint = None
    +            initial_global_step = 0
             else:
                 accelerator.print(f"Resuming from checkpoint {path}")
                 accelerator.load_state(os.path.join(args.output_dir, path))
                 global_step = int(path.split("-")[1])

    -            resume_global_step = global_step * args.gradient_accumulation_steps
    +            initial_global_step = global_step
                 first_epoch = global_step // num_update_steps_per_epoch
    -            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
    -
    -    # Only show the progress bar once on each machine.
    -    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    -    progress_bar.set_description("Steps")
    +    else:
    +        initial_global_step = 0
    +
    +    progress_bar = tqdm(
    +        range(0, args.max_train_steps),
    +        initial=initial_global_step,
    +        desc="Steps",
    +        # Only show the progress bar once on each machine.
    +        disable=not accelerator.is_local_main_process,
    +    )

         for epoch in range(first_epoch, args.num_train_epochs):
             unet.train()
             if args.train_text_encoder:
                 text_encoder.train()
             for step, batch in enumerate(train_dataloader):
    -            # Skip steps until we reach the resumed step
    -            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
    -                if step % args.gradient_accumulation_steps == 0:
    -                    progress_bar.update(1)
    -                continue
    -
                 with accelerator.accumulate(unet):
                     pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
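For contrast, the deleted bookkeeping counted in two different units, which is what made the resume logs confusing. With hypothetical numbers:

    # Hypothetical configuration.
    gradient_accumulation_steps = 4
    num_update_steps_per_epoch = 250   # optimizer steps per epoch
    global_step = 600                  # optimizer steps completed before the restart

    # Old code: resume_step is measured in dataloader batches...
    resume_global_step = global_step * gradient_accumulation_steps                                 # 2400
    resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)  # 400

    # ...while the bar was built over range(global_step, max_train_steps), i.e.
    # in optimizer steps, and was then updated again once per accumulation
    # boundary during the skip phase. The new code keeps a single unit
    # (optimizer steps) and lets tqdm handle the offset.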

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 12 additions & 12 deletions
    @@ -1108,30 +1108,30 @@ def compute_text_embeddings(prompt):
                     f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                 )
                 args.resume_from_checkpoint = None
    +            initial_global_step = 0
             else:
                 accelerator.print(f"Resuming from checkpoint {path}")
                 accelerator.load_state(os.path.join(args.output_dir, path))
                 global_step = int(path.split("-")[1])

    -            resume_global_step = global_step * args.gradient_accumulation_steps
    +            initial_global_step = global_step
                 first_epoch = global_step // num_update_steps_per_epoch
    -            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
    -
    -    # Only show the progress bar once on each machine.
    -    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    -    progress_bar.set_description("Steps")
    +    else:
    +        initial_global_step = 0
    +
    +    progress_bar = tqdm(
    +        range(0, args.max_train_steps),
    +        initial=initial_global_step,
    +        desc="Steps",
    +        # Only show the progress bar once on each machine.
    +        disable=not accelerator.is_local_main_process,
    +    )

         for epoch in range(first_epoch, args.num_train_epochs):
             unet.train()
             if args.train_text_encoder:
                 text_encoder.train()
             for step, batch in enumerate(train_dataloader):
    -            # Skip steps until we reach the resumed step
    -            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
    -                if step % args.gradient_accumulation_steps == 0:
    -                    progress_bar.update(1)
    -                continue
    -
                 with accelerator.accumulate(unet):
                     pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
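The parsing in these hunks relies on the naming convention used when the checkpoints are written. A hedged sketch of the save/load round trip with Accelerate (directory names here are hypothetical):

    import os

    from accelerate import Accelerator

    accelerator = Accelerator()
    output_dir = "output"   # hypothetical stand-in for args.output_dir
    global_step = 500

    # Saving side: state lands under "checkpoint-<global_step>", which is
    # exactly what int(path.split("-")[1]) recovers on resume.
    save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
    accelerator.save_state(save_path)

    # Resuming side, mirroring the hunks above.
    path = os.path.basename(save_path)
    accelerator.load_state(os.path.join(output_dir, path))
    global_step = int(path.split("-")[1])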

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 12 additions & 11 deletions
    @@ -1048,31 +1048,32 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                 )
                 args.resume_from_checkpoint = None
    +            initial_global_step = 0
             else:
                 accelerator.print(f"Resuming from checkpoint {path}")
                 accelerator.load_state(os.path.join(args.output_dir, path))
                 global_step = int(path.split("-")[1])

    -            resume_global_step = global_step * args.gradient_accumulation_steps
    +            initial_global_step = global_step
                 first_epoch = global_step // num_update_steps_per_epoch
    -            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    -    # Only show the progress bar once on each machine.
    -    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    -    progress_bar.set_description("Steps")
    +    else:
    +        initial_global_step = 0
    +
    +    progress_bar = tqdm(
    +        range(0, args.max_train_steps),
    +        initial=initial_global_step,
    +        desc="Steps",
    +        # Only show the progress bar once on each machine.
    +        disable=not accelerator.is_local_main_process,
    +    )

         for epoch in range(first_epoch, args.num_train_epochs):
             unet.train()
             if args.train_text_encoder:
                 text_encoder_one.train()
                 text_encoder_two.train()
             for step, batch in enumerate(train_dataloader):
    -            # Skip steps until we reach the resumed step
    -            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
    -                if step % args.gradient_accumulation_steps == 0:
    -                    progress_bar.update(1)
    -                continue
    -
                 with accelerator.accumulate(unet):
                     pixel_values = batch["pixel_values"].to(dtype=vae.dtype)

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

Lines changed: 15 additions & 13 deletions
    @@ -726,6 +726,9 @@ def preprocess_images(examples):
         text_encoder_1.requires_grad_(False)
         text_encoder_2.requires_grad_(False)

    +    # Set UNet to trainable.
    +    unet.train()
    +
         # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
         def encode_prompt(text_encoders, tokenizers, prompt):
             prompt_embeds_list = []
    @@ -933,29 +936,28 @@ def collate_fn(examples):
                     f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                 )
                 args.resume_from_checkpoint = None
    +            initial_global_step = 0
             else:
                 accelerator.print(f"Resuming from checkpoint {path}")
                 accelerator.load_state(os.path.join(args.output_dir, path))
                 global_step = int(path.split("-")[1])

    -            resume_global_step = global_step * args.gradient_accumulation_steps
    +            initial_global_step = global_step
                 first_epoch = global_step // num_update_steps_per_epoch
    -            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
    -
    -    # Only show the progress bar once on each machine.
    -    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    -    progress_bar.set_description("Steps")
    +    else:
    +        initial_global_step = 0
    +
    +    progress_bar = tqdm(
    +        range(0, args.max_train_steps),
    +        initial=initial_global_step,
    +        desc="Steps",
    +        # Only show the progress bar once on each machine.
    +        disable=not accelerator.is_local_main_process,
    +    )

         for epoch in range(first_epoch, args.num_train_epochs):
    -        unet.train()
             train_loss = 0.0
             for step, batch in enumerate(train_dataloader):
    -            # Skip steps until we reach the resumed step
    -            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
    -                if step % args.gradient_accumulation_steps == 0:
    -                    progress_bar.update(1)
    -                continue
    -
                 with accelerator.accumulate(unet):
                     # We want to learn the denoising process w.r.t the edited images which
                     # are conditioned on the original image (which was edited) and the edit instruction.
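Besides the resume fix, this file hoists unet.train() out of the epoch loop (first hunk). That is safe because .train() only sets a flag and nothing switches the UNet back to eval mode between epochs; a small sketch of the behavior, using a stand-in module:

    import torch

    unet = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 1))   # stand-in for the real UNet

    unet.train()
    assert unet.training        # flag affecting dropout, batch norm, etc.

    unet.train()
    assert unet.training        # idempotent: calling it once or every epoch is equivalent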

examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py

Lines changed: 15 additions & 11 deletions
    @@ -512,6 +512,9 @@ def deepspeed_zero_init_disabled_context_manager():
         vae.requires_grad_(False)
         image_encoder.requires_grad_(False)

    +    # Set unet to trainable.
    +    unet.train()
    +
         # Create EMA for the unet.
         if args.use_ema:
             ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
    @@ -727,27 +730,28 @@ def collate_fn(examples):
                     f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                 )
                 args.resume_from_checkpoint = None
    +            initial_global_step = 0
             else:
                 accelerator.print(f"Resuming from checkpoint {path}")
                 accelerator.load_state(os.path.join(args.output_dir, path))
                 global_step = int(path.split("-")[1])

    -            resume_global_step = global_step * args.gradient_accumulation_steps
    +            initial_global_step = global_step
                 first_epoch = global_step // num_update_steps_per_epoch
    -            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
    +    else:
    +        initial_global_step = 0
    +
    +    progress_bar = tqdm(
    +        range(0, args.max_train_steps),
    +        initial=initial_global_step,
    +        desc="Steps",
    +        # Only show the progress bar once on each machine.
    +        disable=not accelerator.is_local_main_process,
    +    )

    -    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    -    progress_bar.set_description("Steps")
         for epoch in range(first_epoch, args.num_train_epochs):
    -        unet.train()
             train_loss = 0.0
             for step, batch in enumerate(train_dataloader):
    -            # Skip steps until we reach the resumed step
    -            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
    -                if step % args.gradient_accumulation_steps == 0:
    -                    progress_bar.update(1)
    -                continue
    -
                 with accelerator.accumulate(unet):
                     # Convert images to latent space
                     images = batch["pixel_values"].to(weight_dtype)
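When --resume_from_checkpoint is set to "latest", these scripts resolve the newest checkpoint just above the hunks shown here. Roughly, based on the surrounding upstream code (output_dir is hypothetical):

    import os

    output_dir = "output"   # hypothetical stand-in for args.output_dir

    # Pick the checkpoint-* directory with the highest step suffix,
    # or None if no checkpoint exists yet.
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda d: int(d.split("-")[1]))
    path = dirs[-1] if len(dirs) > 0 else None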

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py

Lines changed: 11 additions & 11 deletions
    @@ -579,29 +579,29 @@ def collate_fn(examples):
                     f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                 )
                 args.resume_from_checkpoint = None
    +            initial_global_step = 0
             else:
                 accelerator.print(f"Resuming from checkpoint {path}")
                 accelerator.load_state(os.path.join(args.output_dir, path))
                 global_step = int(path.split("-")[1])

    -            resume_global_step = global_step * args.gradient_accumulation_steps
    +            initial_global_step = global_step
                 first_epoch = global_step // num_update_steps_per_epoch
    -            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
    +    else:
    +        initial_global_step = 0

    -    # Only show the progress bar once on each machine.
    -    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    -    progress_bar.set_description("Steps")
    +    progress_bar = tqdm(
    +        range(0, args.max_train_steps),
    +        initial=initial_global_step,
    +        desc="Steps",
    +        # Only show the progress bar once on each machine.
    +        disable=not accelerator.is_local_main_process,
    +    )

         for epoch in range(first_epoch, args.num_train_epochs):
             unet.train()
             train_loss = 0.0
             for step, batch in enumerate(train_dataloader):
    -            # Skip steps until we reach the resumed step
    -            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
    -                if step % args.gradient_accumulation_steps == 0:
    -                    progress_bar.update(1)
    -                continue
    -
                 with accelerator.accumulate(unet):
                     # Convert images to latent space
                     images = batch["pixel_values"].to(weight_dtype)

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py

Lines changed: 14 additions & 11 deletions
    @@ -595,30 +595,33 @@ def collate_fn(examples):
                     f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                 )
                 args.resume_from_checkpoint = None
    +            initial_global_step = 0
             else:
                 accelerator.print(f"Resuming from checkpoint {path}")
                 accelerator.load_state(os.path.join(args.output_dir, path))
                 global_step = int(path.split("-")[1])

    -            resume_global_step = global_step * args.gradient_accumulation_steps
    +            initial_global_step = global_step
                 first_epoch = global_step // num_update_steps_per_epoch
    -            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    -    # Only show the progress bar once on each machine.
    -    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    -    progress_bar.set_description("Steps")
    +    else:
    +        initial_global_step = 0
    +
    +    progress_bar = tqdm(
    +        range(0, args.max_train_steps),
    +        initial=initial_global_step,
    +        desc="Steps",
    +        # Only show the progress bar once on each machine.
    +        disable=not accelerator.is_local_main_process,
    +    )
    +
         clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
         clip_std = clip_std.to(weight_dtype).to(accelerator.device)
    +
         for epoch in range(first_epoch, args.num_train_epochs):
             prior.train()
             train_loss = 0.0
             for step, batch in enumerate(train_dataloader):
    -            # Skip steps until we reach the resumed step
    -            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
    -                if step % args.gradient_accumulation_steps == 0:
    -                    progress_bar.update(1)
    -                continue
    -
                 with accelerator.accumulate(prior):
                     # Convert images to latent space
                     text_input_ids, text_mask, clip_images = (
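All of the resumed quantities in these diffs derive from two numbers: the step parsed from the checkpoint name and the steps-per-epoch figure computed earlier in each script. A compact recap with hypothetical sizes:

    import math

    batches_per_epoch = 1000            # hypothetical len(train_dataloader)
    gradient_accumulation_steps = 4
    global_step = 600                   # parsed from a "checkpoint-600" directory

    # Optimizer steps per epoch, as computed earlier in these scripts.
    num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 250

    initial_global_step = global_step                          # progress bar offset
    first_epoch = global_step // num_update_steps_per_epoch    # resume in epoch 2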
