@@ -1,4 +1,5 @@
 import argparse
+import math
 import os
 
 import torch
@@ -29,6 +30,7 @@
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
@@ -105,6 +107,8 @@ def transforms(examples):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
     ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
     if args.push_to_hub:
@@ -117,7 +121,7 @@ def transforms(examples):
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
-        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
@@ -146,13 +150,16 @@ def transforms(examples):
                     ema_model.step(model)
                 optimizer.zero_grad()
 
-            progress_bar.update(1)
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-            global_step += 1
         progress_bar.close()
 
         accelerator.wait_for_everyone()
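The change above wires `--gradient_accumulation_steps` into the `Accelerator` and only advances the progress bar and `global_step` when `accelerator.sync_gradients` reports that an optimizer update actually happened. Below is a minimal, self-contained sketch of that pattern using Accelerate's `accumulate()` context manager; the toy linear model, random data, and `grad_accum_steps = 4` are illustrative assumptions, not values from the training script.

```python
import math

import torch
import torch.nn.functional as F
from accelerate import Accelerator

# Illustrative assumptions: a toy model, random data, and 4 accumulation steps.
grad_accum_steps = 4
accelerator = Accelerator(gradient_accumulation_steps=grad_accum_steps)

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# One optimizer update happens every `grad_accum_steps` batches, so the number of
# updates per epoch is the dataloader length divided by the accumulation steps,
# rounded up -- the same formula as `num_update_steps_per_epoch` in the diff.
num_update_steps_per_epoch = math.ceil(len(dataloader) / grad_accum_steps)

global_step = 0
for inputs, targets in dataloader:
    with accelerator.accumulate(model):
        loss = F.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()       # skipped by Accelerate except on accumulation boundaries
        optimizer.zero_grad()
    # sync_gradients is True only on batches where gradients were synchronized
    # and the optimizer step actually took effect, so counters should only advance here.
    if accelerator.sync_gradients:
        global_step += 1

print(global_step, num_update_steps_per_epoch)  # both 4 with these toy sizes
```

With 16 batches and 4 accumulation steps, `global_step` ends at 4, matching `num_update_steps_per_epoch`, which is why the diff sizes the progress bar with that count instead of `len(train_dataloader)`.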