Commit bb09883

add lr and total iter in trainer (tusen-ai#263)
1 parent 070fa02 commit bb09883

File tree: 3 files changed, +12 −3 lines

core/detection_module.py
detection_train.py
utils/callback.py

core/detection_module.py

Lines changed: 7 additions & 1 deletion
@@ -23,6 +23,7 @@
 import time
 import logging
 import warnings
+from collections import namedtuple

 import mxnet

@@ -35,7 +36,6 @@
 from mxnet.module.executor_group import DataParallelExecutorGroup
 from mxnet.model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
 from mxnet.model import load_checkpoint
-from mxnet.model import BatchEndParam
 from mxnet.initializer import Uniform, InitDesc
 from mxnet.io import DataDesc
 from mxnet.ndarray import zeros

@@ -44,6 +44,10 @@
 from mxnet.module.module import Module


+BatchEndParam = namedtuple('BatchEndParams',
+                           ['epoch', 'nbatch', 'eval_metric', 'lr', 'iter', 'locals'])
+
+
 class DetModule(BaseModule):
     """Module is a basic module that wrap a `Symbol`. It is functionally the same
     as the `FeedForward` model, except under the module API.

@@ -1025,6 +1029,8 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
             if batch_end_callback is not None:
                 batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                  eval_metric=eval_metric,
+                                                 lr=self._optimizer.lr_scheduler(total_iter),
+                                                 iter=total_iter,
                                                  locals=locals())
                 for callback in _as_list(batch_end_callback):
                     callback(batch_end_params)
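The commit replaces mxnet's built-in BatchEndParam with a local namedtuple that additionally carries lr and iter, where lr is obtained by calling the optimizer's lr_scheduler at the current global iteration. A minimal sketch (dummy values, not code from this repo) of a batch-end callback consuming the two new fields:

from collections import namedtuple

# Same shape as the namedtuple the commit defines; the values below are dummies.
BatchEndParam = namedtuple('BatchEndParams',
                           ['epoch', 'nbatch', 'eval_metric', 'lr', 'iter', 'locals'])

def log_lr_callback(param):
    # param.iter is the global iteration counter; param.lr is the rate the
    # optimizer's lr_scheduler returned for that iteration.
    print("epoch %d, iter %d, lr %.5f" % (param.epoch, param.iter, param.lr))

log_lr_callback(BatchEndParam(epoch=0, nbatch=10, eval_metric=None,
                              lr=0.01, iter=2010, locals=None))
# epoch 0, iter 2010, lr 0.01000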

detection_train.py

Lines changed: 2 additions & 0 deletions
@@ -235,6 +235,7 @@ def train_net(config):
             mode='cosine',
             base_lr=base_lr,
             target_lr=0,
+            offset=pOpt.warmup.iter,
             niters=(iter_per_epoch * (end_epoch - begin_epoch)) - pOpt.warmup.iter
         )
         lr_scheduler = LRSequential([warmup_lr_scheduler, cosine_lr_scheduler])

@@ -248,6 +249,7 @@ def train_net(config):
             mode='cosine',
             base_lr=base_lr,
             target_lr=0,
+            offset=pOpt.warmup.iter,
             niters=iter_per_epoch * (end_epoch - begin_epoch)
         )
     else:
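The cosine scheduler now receives offset=pOpt.warmup.iter so that, when it runs after the warmup scheduler in the LRSequential, its decay is computed from the end of warmup rather than from iteration 0. A hypothetical stand-in (names assumed, not this repo's scheduler API) showing the effect of the offset:

import math

# Warmup owns the first `offset` iterations, so the cosine decay counts
# from `offset` instead of 0 and spans the remaining `niters` iterations.
def cosine_lr(t, base_lr, target_lr, niters, offset=0):
    t = min(max(t - offset, 0), niters)  # clamp t into the cosine window
    return target_lr + (base_lr - target_lr) * 0.5 * (1.0 + math.cos(math.pi * t / niters))

# Immediately after a 500-iteration warmup, the cosine branch starts at base_lr:
print(cosine_lr(500, base_lr=0.01, target_lr=0, niters=9500, offset=500))  # -> 0.01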

utils/callback.py

Lines changed: 3 additions & 2 deletions
@@ -23,12 +23,13 @@ def __call__(self, param):
             speed = self.frequent * self.batch_size / (time.time() - self.tic)
             if param.eval_metric is not None:
                 name, value = param.eval_metric.get()
-                s = "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-" % (param.epoch, count, speed)
+                s = "Epoch[%d] Batch [%d]\tIter: %d\tLr: %.5f\tSpeed: %.2f samples/sec\tTrain-" % \
+                    (param.epoch, count, param.iter, param.lr, speed)
                 for n, v in zip(name, value):
                     s += "%s=%f,\t" % (n, v)
                 logging.info(s)
             else:
-                logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
+                logging.info("Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                              param.epoch, count, speed)
             self.tic = time.time()
         else:
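With lr and iter now carried on the callback's param, the speedometer prints both in the per-batch line, and the second branch's label is corrected from Iter[%d] to Epoch[%d] to match the value actually passed (param.epoch). A quick check of the new format string with dummy values:

# Dummy values through the new format string; \t renders as real tabs.
s = "Epoch[%d] Batch [%d]\tIter: %d\tLr: %.5f\tSpeed: %.2f samples/sec\tTrain-" % \
    (3, 200, 6200, 0.00125, 41.87)
print(s)
# Epoch[3] Batch [200]    Iter: 6200    Lr: 0.00125    Speed: 41.87 samples/sec    Train-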
