@@ -174,6 +174,16 @@ def parse_args():
     parser.add_argument(
         "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
     )
+    parser.add_argument(
+        "--logger",
+        type=str,
+        default="tensorboard",
+        choices=["tensorboard", "wandb"],
+        help=(
+            "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
+            " for experiment tracking and logging of model metrics and model checkpoints"
+        ),
+    )
     parser.add_argument(
         "--logging_dir",
         type=str,
@@ -195,7 +205,6 @@ def parse_args():
195205 "and an Nvidia Ampere GPU."
196206 ),
197207 )
198-
199208 parser .add_argument (
200209 "--prediction_type" ,
201210 type = str ,
@@ -206,6 +215,24 @@ def parse_args():

     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
     parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -233,7 +260,7 @@ def main(args):
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.logger,
         logging_dir=logging_dir,
     )

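With `log_with` now driven by the new `--logger` flag, metric logging can stay tracker-agnostic by going through Accelerate's `log` API rather than touching a specific tracker. A minimal standalone sketch of that pattern, not part of this commit; the project name, directory, and logged values are purely illustrative:

import math
from accelerate import Accelerator

# "tensorboard" or "wandb", mirroring the --logger choices added above.
# Note: newer Accelerate releases name this directory argument project_dir instead of logging_dir.
accelerator = Accelerator(log_with="tensorboard", logging_dir="logs")
accelerator.init_trackers("ddpm-example")  # illustrative project/run name

for step in range(3):
    # accelerator.log forwards the metrics dict to whichever tracker was configured.
    accelerator.log({"loss": math.exp(-step)}, step=step)

accelerator.end_training()  # flushes and closes the configured trackers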
@@ -321,6 +348,7 @@ def transforms(examples):
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, lr_scheduler
     )
+    accelerator.register_for_checkpointing(lr_scheduler)

     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

@@ -353,11 +381,34 @@ def transforms(examples):
         accelerator.init_trackers(run)

     global_step = 0
-    for epoch in range(args.num_epochs):
+    first_epoch = 0
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1]
+        accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
+        global_step = int(path.split("-")[1])
+        resume_global_step = global_step * args.gradient_accumulation_steps
+        first_epoch = resume_global_step // num_update_steps_per_epoch
+        resume_step = resume_global_step % num_update_steps_per_epoch
+
+    for epoch in range(first_epoch, args.num_epochs):
         model.train()
         progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
+            # Skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                if step % args.gradient_accumulation_steps == 0:
+                    progress_bar.update(1)
+                continue
+
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
             noise = torch.randn(clean_images.shape).to(clean_images.device)
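The `"latest"` branch above simply scans `output_dir` for `checkpoint-<step>` folders and resumes from the highest step. The same convention, pulled out as a self-contained helper purely for illustration (`latest_checkpoint` is a hypothetical name, not something the script defines):

import os


def latest_checkpoint(output_dir):
    """Return the newest "checkpoint-<step>" directory in output_dir, or None if there is none."""
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    if not dirs:
        return None
    # Sort numerically on the step suffix so checkpoint-1000 ranks above checkpoint-999.
    dirs.sort(key=lambda d: int(d.split("-")[1]))
    return os.path.join(output_dir, dirs[-1])


# e.g. latest_checkpoint("ddpm-model") might return "ddpm-model/checkpoint-1500" (hypothetical layout)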
@@ -404,6 +455,12 @@ def transforms(examples):
             progress_bar.update(1)
             global_step += 1

+            if global_step % args.checkpointing_steps == 0:
+                if accelerator.is_main_process:
+                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                    accelerator.save_state(save_path)
+                    logger.info(f"Saved state to {save_path}")
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
@@ -431,9 +488,11 @@ def transforms(examples):

                 # denormalize the images and save to tensorboard
                 images_processed = (images * 255).round().astype("uint8")
-                accelerator.trackers[0].writer.add_images(
-                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
-                )
+
+                if args.logger == "tensorboard":
+                    accelerator.get_tracker("tensorboard").add_images(
+                        "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
+                    )

             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model