2 files changed: +10 -5 lines changed
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,7 @@ def train(self):
227
227
228
228
train_prev_logged_time = datetime .now ()
229
229
for epoch in range (self .train_params ['start_epoch_idx' ], num_epochs ):
230
- self .train_params ['start_epoch_idx' ] = epoch
230
+ self .train_params ['current_epoch_idx' ] = epoch
231
231
logger .info (f"Start epoch: { epoch } training." )
232
232
233
233
epoch_start_time = datetime .now ()
Original file line number Diff line number Diff line change @@ -77,10 +77,15 @@ def save_checkpoint(self, tag):
77
77
tag = str (tag ).zfill (8 )
78
78
logger .warning ('Saving checkpoint...' )
79
79
80
- self .train_params ['current_batch_idx' ] = self .train_params ['current_batch_idx' ] + 1
81
- if self .train_params ['current_batch_idx' ] == self .train_params ['updates_per_epoch' ]:
82
- self .train_params ['resume_batch_idx' ] = 0
83
- self .train_params ['resume_epoch_idx' ] += 1
80
+ resume_epoch_idx = self .train_params ['current_epoch_idx' ]
81
+ resume_batch_idx = self .train_params ['current_epoch_idx' ] + 1
82
+
83
+ if resume_batch_idx == self .train_params ['updates_per_epoch' ]:
84
+ self .train_params ['start_batch_idx' ] = 0
85
+ self .train_params ['start_epoch_idx' ] = resume_epoch_idx + 1
86
+ else :
87
+ self .train_params ['start_batch_idx' ] = resume_batch_idx
88
+ self .train_params ['start_epoch_idx' ] = resume_epoch_idx
84
89
85
90
save_dir = os .path .join (self .save_folder , tag )
86
91
You can’t perform that action at this time.
0 commit comments