
Commit 847daf2

Prathik Rao
authored
update train_unconditional_ort.py (huggingface#1775)
* reflect changes
* run make style

Co-authored-by: Prathik Rao <[email protected]>
Co-authored-by: Prathik Rao <[email protected]@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
1 parent 9f8c915 commit 847daf2

File tree

1 file changed: +65, -6 lines


examples/unconditional_image_generation/train_unconditional_ort.py

Lines changed: 65 additions & 6 deletions
@@ -174,6 +174,16 @@ def parse_args():
     parser.add_argument(
         "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
     )
+    parser.add_argument(
+        "--logger",
+        type=str,
+        default="tensorboard",
+        choices=["tensorboard", "wandb"],
+        help=(
+            "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
+            " for experiment tracking and logging of model metrics and model checkpoints"
+        ),
+    )
     parser.add_argument(
         "--logging_dir",
         type=str,
@@ -195,7 +205,6 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
-
     parser.add_argument(
         "--prediction_type",
         type=str,
@@ -206,6 +215,24 @@ def parse_args():
 
     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
     parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -233,7 +260,7 @@ def main(args):
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.logger,
         logging_dir=logging_dir,
     )
 
@@ -321,6 +348,7 @@ def transforms(examples):
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, lr_scheduler
     )
+    accelerator.register_for_checkpointing(lr_scheduler)
 
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
@@ -353,11 +381,34 @@ def transforms(examples):
         accelerator.init_trackers(run)
 
     global_step = 0
-    for epoch in range(args.num_epochs):
+    first_epoch = 0
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1]
+        accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
+        global_step = int(path.split("-")[1])
+        resume_global_step = global_step * args.gradient_accumulation_steps
+        first_epoch = resume_global_step // num_update_steps_per_epoch
+        resume_step = resume_global_step % num_update_steps_per_epoch
+
+    for epoch in range(first_epoch, args.num_epochs):
         model.train()
         progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
+            # Skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                if step % args.gradient_accumulation_steps == 0:
+                    progress_bar.update(1)
+                continue
+
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
             noise = torch.randn(clean_images.shape).to(clean_images.device)
@@ -404,6 +455,12 @@ def transforms(examples):
                 progress_bar.update(1)
                 global_step += 1
 
+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
@@ -431,9 +488,11 @@ def transforms(examples):
 
                 # denormalize the images and save to tensorboard
                 images_processed = (images * 255).round().astype("uint8")
-                accelerator.trackers[0].writer.add_images(
-                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
-                )
+
+                if args.logger == "tensorboard":
+                    accelerator.get_tracker("tensorboard").add_images(
+                        "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
+                    )
 
             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model

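Not part of the commit itself, but as a rough illustration of how the pieces added above fit together, here is a minimal standalone sketch of the resume bookkeeping: it picks the most recent "checkpoint-<global_step>" directory written by accelerator.save_state() and recovers global_step, first_epoch, and resume_step with the same arithmetic as the new training-loop code. The output directory name and the sample numbers are assumptions for illustration only.

# Illustrative sketch only (not from the commit): mirrors the "latest"
# checkpoint selection and the resume arithmetic added in this change.
import os


def find_resume_point(output_dir, gradient_accumulation_steps, num_update_steps_per_epoch):
    # Collect the "checkpoint-<step>" folders written by accelerator.save_state()
    if not os.path.isdir(output_dir):
        return None
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    if not dirs:
        return None
    # Most recent checkpoint, selected the same way as the "latest" branch above
    path = sorted(dirs, key=lambda x: int(x.split("-")[1]))[-1]

    # Recover the step counter from the folder name, then derive first_epoch
    # and resume_step exactly as the training loop above does
    global_step = int(path.split("-")[1])
    resume_global_step = global_step * gradient_accumulation_steps
    first_epoch = resume_global_step // num_update_steps_per_epoch
    resume_step = resume_global_step % num_update_steps_per_epoch
    return path, global_step, first_epoch, resume_step


if __name__ == "__main__":
    # Hypothetical values: output dir "ddpm-model", 1 accumulation step, 250 updates per epoch
    print(find_resume_point("ddpm-model", 1, 250))

With --resume_from_checkpoint latest, the script applies this same selection to args.output_dir before calling accelerator.load_state().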
0 commit comments
