Skip to content

Commit ed7b831

Browse files
committed
fix ckpt resume bug
1 parent 39d0c9b commit ed7b831

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

trainer/default_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,7 @@ def train(self):
227 227

228 228
train_prev_logged_time = datetime.now()
229 229
for epoch in range(self.train_params['start_epoch_idx'], num_epochs):
230-
self.train_params['start_epoch_idx'] = epoch
230+
self.train_params['current_epoch_idx'] = epoch
231 231
logger.info(f"Start epoch: {epoch} training.")
232 232

233 233
epoch_start_time = datetime.now()

trainer/utils_trainer.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -77,10 +77,15 @@ def save_checkpoint(self, tag):
77 77
tag = str(tag).zfill(8)
78 78
logger.warning('Saving checkpoint...')
79 79

80-
self.train_params['current_batch_idx'] = self.train_params['current_batch_idx'] + 1
81-
if self.train_params['current_batch_idx'] == self.train_params['updates_per_epoch']:
82-
self.train_params['resume_batch_idx'] = 0
83-
self.train_params['resume_epoch_idx'] += 1
80+
resume_epoch_idx = self.train_params['current_epoch_idx']
81+
resume_batch_idx = self.train_params['current_epoch_idx'] + 1
82+
83+
if resume_batch_idx == self.train_params['updates_per_epoch']:
84+
self.train_params['start_batch_idx'] = 0
85+
self.train_params['start_epoch_idx'] = resume_epoch_idx + 1
86+
else:
87+
self.train_params['start_batch_idx'] = resume_batch_idx
88+
self.train_params['start_epoch_idx'] = resume_epoch_idx
84 89

85 90
save_dir = os.path.join(self.save_folder, tag)
86 91

0 commit comments

Comments
 (0)