23 | 23 | import time |
24 | 24 | import logging |
25 | 25 | import warnings |
| 26 | +from collections import namedtuple |
26 | 27 |
27 | 28 | import mxnet |
28 | 29 |
35 | 36 | from mxnet.module.executor_group import DataParallelExecutorGroup |
36 | 37 | from mxnet.model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore |
37 | 38 | from mxnet.model import load_checkpoint |
38 | | -from mxnet.model import BatchEndParam |
39 | 39 | from mxnet.initializer import Uniform, InitDesc |
40 | 40 | from mxnet.io import DataDesc |
41 | 41 | from mxnet.ndarray import zeros |
44 | 44 | from mxnet.module.module import Module |
45 | 45 |
46 | 46 |
| 47 | +BatchEndParam = namedtuple('BatchEndParams', |
| 48 | + ['epoch', 'nbatch', 'eval_metric', 'lr', 'iter', 'locals']) |
| 49 | + |
| 50 | + |
47 | 51 | class DetModule(BaseModule): |
48 | 52 | """Module is a basic module that wrap a `Symbol`. It is functionally the same |
49 | 53 | as the `FeedForward` model, except under the module API. |
@@ -1025,6 +1029,8 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', |
1025 | 1029 | if batch_end_callback is not None: |
1026 | 1030 | batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, |
1027 | 1031 | eval_metric=eval_metric, |
| 1032 | + lr=self._optimizer.lr_scheduler(total_iter), |
| 1033 | + iter=total_iter, |
1028 | 1034 | locals=locals()) |
1029 | 1035 | for callback in _as_list(batch_end_callback): |
1030 | 1036 | callback(batch_end_params) |
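Taken together, the patch replaces MXNet's stock BatchEndParam with a local namedtuple carrying two extra fields, `lr` and `iter`, so batch-end callbacks can observe the scheduled learning rate and the global iteration count. A minimal sketch of a callback consuming the new fields follows; the `lr_logging_callback` name, its `log_interval` argument, and the log format are illustrative assumptions, not part of this patch:

# Sketch: a batch-end callback that reads the extended BatchEndParam fields.
import logging

def lr_logging_callback(log_interval=20):
    def _callback(param):
        # `param` is the BatchEndParam built in `fit`: `param.iter` is the
        # global iteration count and `param.lr` is the value returned by the
        # optimizer's lr_scheduler for that iteration.
        if param.nbatch % log_interval == 0:
            logging.info('Epoch[%d] Batch[%d] Iter[%d] lr=%.3e',
                         param.epoch, param.nbatch, param.iter, param.lr)
    return _callback

Such a callback would be passed in through `fit`'s existing `batch_end_callback` argument, alongside any speedometer-style callbacks already in use.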