We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9a58888 commit 09d0b5c (Copy full SHA for 09d0b5c)
train.py
@@ -385,22 +385,6 @@ def train(self):
385
for epoch in range(self.n_epochs):
386
self.epoch = epoch
387
self.train_one_epoch()
388
- # Save after each epoch
389
- print('Epoch completed. Saving..')
390
- state = {
391
- 'net': {key: self.model[key].state_dict() for key in self.model},
392
- 'optimizer': self.optimizer.state_dict(),
393
- 'scheduler': self.optimizer.scheduler_state_dict(),
394
- 'iters': self.iters,
395
- 'epoch': self.epoch,
396
- }
397
- save_path = os.path.join(
398
- self.log_dir,
399
- f'DiT_epoch_{self.epoch:05d}_step_{self.iters:05d}.pth'
400
- )
401
- torch.save(state, save_path)
402
- print(f"Checkpoint saved at {save_path}")
403
-
404
if self.iters >= self.max_steps:
405
break
406
0 commit comments