Skip to content

Commit 2b8f2d4

Browse files
committed
update LR before optimize
1 parent 666a2cd commit 2b8f2d4

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

train.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,15 @@ def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
2828
# main loop
2929
tic = time.time()
3030
for i in range(cfg.TRAIN.epoch_iters):
31+
# load a batch of data
3132
batch_data = next(iterator)
3233
data_time.update(time.time() - tic)
3334
segmentation_module.zero_grad()
3435

36+
# adjust learning rate
37+
cur_iter = i + (epoch - 1) * cfg.TRAIN.epoch_iters
38+
adjust_learning_rate(optimizers, cur_iter, cfg)
39+
3540
# forward pass
3641
loss, acc = segmentation_module(batch_data)
3742
loss = loss.mean()
@@ -65,10 +70,6 @@ def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
6570
history['train']['loss'].append(loss.data.item())
6671
history['train']['acc'].append(acc.data.item())
6772

68-
# adjust learning rate
69-
cur_iter = i + (epoch - 1) * cfg.TRAIN.epoch_iters
70-
adjust_learning_rate(optimizers, cur_iter, cfg)
71-
7273

7374
def checkpoint(nets, history, cfg, epoch):
7475
print('Saving checkpoints...')

0 commit comments

Comments
 (0)