Skip to content

Commit e384ef5

Browse files
authored
Add runner type (open-mmlab#118)
* Add runner_type option * pre-commit * Fix max_iters * Add by_epoch to EvalHook * Add test_eval_hook for epoch runner * Remove runner-type arg from tools/train * Add missing every_n_iters check for epoch mode * Bump mmcv min version * Use build_runner * Use interval in tests * Update test_eval_hook.py * Use every_n_epochs instead of every_n_iters. Update DistEvalHook * Add test_dist_eval_hook_epoch * Fix tests * Add DeprecationWarning * Update docs * Replace DeprecationWarning with UserWarning
1 parent 3bdc276 commit e384ef5

File tree

9 files changed

+134
-20
lines changed

9 files changed

+134
-20
lines changed

configs/_base_/schedules/schedule_160k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
# learning policy
55
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
66
# runtime settings
7-
total_iters = 160000
7+
runner = dict(type='IterBasedRunner', max_iters=160000)
88
checkpoint_config = dict(by_epoch=False, interval=16000)
99
evaluation = dict(interval=16000, metric='mIoU')

configs/_base_/schedules/schedule_20k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
# learning policy
55
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
66
# runtime settings
7-
total_iters = 20000
7+
runner = dict(type='IterBasedRunner', max_iters=20000)
88
checkpoint_config = dict(by_epoch=False, interval=2000)
99
evaluation = dict(interval=2000, metric='mIoU')

configs/_base_/schedules/schedule_40k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
# learning policy
55
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
66
# runtime settings
7-
total_iters = 40000
7+
runner = dict(type='IterBasedRunner', max_iters=40000)
88
checkpoint_config = dict(by_epoch=False, interval=4000)
99
evaluation = dict(interval=4000, metric='mIoU')

configs/_base_/schedules/schedule_80k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
# learning policy
55
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
66
# runtime settings
7-
total_iters = 80000
7+
runner = dict(type='IterBasedRunner', max_iters=80000)
88
checkpoint_config = dict(by_epoch=False, interval=8000)
99
evaluation = dict(interval=8000, metric='mIoU')

docs/config.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ dist_params = dict(backend='nccl') # Parameters to setup distributed training,
226226
log_level = 'INFO' # The level of logging.
227227
load_from = None # load models as a pre-trained model from a given path. This will not resume training.
228228
resume_from = None # Resume checkpoints from a given path; the training will be resumed from the iteration at which the checkpoint was saved.
229-
workflow = [('train', 1)] # Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. The workflow trains the model by 40000 iterations according to the total_iters.
229+
workflow = [('train', 1)] # Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. The workflow trains the model by 40000 iterations according to the `runner.max_iters`.
230230
cudnn_benchmark = True # Whether use cudnn_benchmark to speed up, which is fast for fixed input size.
231231
optimizer = dict( # Config used to build optimizer, support all the optimizers in PyTorch whose arguments are also the same as those in PyTorch
232232
type='SGD', # Type of optimizers, refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/optimizer/default_constructor.py#L13 for more details
@@ -239,7 +239,9 @@ lr_config = dict(
239239
power=0.9, # The power of polynomial decay.
240240
min_lr=0.0001, # The minimum learning rate to stable the training.
241241
by_epoch=False) # Whether to count by epoch or not.
242-
total_iters = 40000 # Total number of iterations.
242+
runner = dict(
243+
type='IterBasedRunner', # Type of runner to use (i.e. IterBasedRunner or EpochBasedRunner)
244+
max_iters=40000) # Total number of iterations. For EpochBasedRunner use `max_epochs`
243245
checkpoint_config = dict( # Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation.
244246
by_epoch=False, # Whether to count by epoch or not.
245247
interval=4000) # The save interval.

mmseg/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from .version import __version__, version_info
44

5-
MMCV_MIN = '1.1.2'
5+
MMCV_MIN = '1.1.4'
66
MMCV_MAX = '1.2.0'
77

88

mmseg/apis/train.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import random
2+
import warnings
23

34
import numpy as np
45
import torch
56
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
6-
from mmcv.runner import IterBasedRunner, build_optimizer
7+
from mmcv.runner import build_optimizer, build_runner
78

89
from mmseg.core import DistEvalHook, EvalHook
910
from mmseg.datasets import build_dataloader, build_dataset
@@ -70,13 +71,21 @@ def train_segmentor(model,
7071
# build runner
7172
optimizer = build_optimizer(model, cfg.optimizer)
7273

73-
runner = IterBasedRunner(
74-
model=model,
75-
batch_processor=None,
76-
optimizer=optimizer,
77-
work_dir=cfg.work_dir,
78-
logger=logger,
79-
meta=meta)
74+
if cfg.get('runner') is None:
75+
cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
76+
warnings.warn(
77+
'config is now expected to have a `runner` section, '
78+
'please set `runner` in your config.', UserWarning)
79+
80+
runner = build_runner(
81+
cfg.runner,
82+
default_args=dict(
83+
model=model,
84+
batch_processor=None,
85+
optimizer=optimizer,
86+
work_dir=cfg.work_dir,
87+
logger=logger,
88+
meta=meta))
8089

8190
# register hooks
8291
runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
@@ -96,11 +105,12 @@ def train_segmentor(model,
96105
dist=distributed,
97106
shuffle=False)
98107
eval_cfg = cfg.get('evaluation', {})
108+
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
99109
eval_hook = DistEvalHook if distributed else EvalHook
100110
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
101111

102112
if cfg.resume_from:
103113
runner.resume(cfg.resume_from)
104114
elif cfg.load_from:
105115
runner.load_checkpoint(cfg.load_from)
106-
runner.run(data_loaders, cfg.workflow, cfg.total_iters)
116+
runner.run(data_loaders, cfg.workflow)

mmseg/core/evaluation/eval_hooks.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,27 @@ class EvalHook(Hook):
1212
interval (int): Evaluation interval (by epochs). Default: 1.
1313
"""
1414

15-
def __init__(self, dataloader, interval=1, **eval_kwargs):
15+
def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs):
1616
if not isinstance(dataloader, DataLoader):
1717
raise TypeError('dataloader must be a pytorch DataLoader, but got '
1818
f'{type(dataloader)}')
1919
self.dataloader = dataloader
2020
self.interval = interval
21+
self.by_epoch = by_epoch
2122
self.eval_kwargs = eval_kwargs
2223

2324
def after_train_iter(self, runner):
2425
"""After train epoch hook."""
25-
if not self.every_n_iters(runner, self.interval):
26+
if self.by_epoch or not self.every_n_iters(runner, self.interval):
27+
return
28+
from mmseg.apis import single_gpu_test
29+
runner.log_buffer.clear()
30+
results = single_gpu_test(runner.model, self.dataloader, show=False)
31+
self.evaluate(runner, results)
32+
33+
def after_train_epoch(self, runner):
34+
"""After train epoch hook."""
35+
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
2636
return
2737
from mmseg.apis import single_gpu_test
2838
runner.log_buffer.clear()
@@ -54,6 +64,7 @@ def __init__(self,
5464
dataloader,
5565
interval=1,
5666
gpu_collect=False,
67+
by_epoch=False,
5768
**eval_kwargs):
5869
if not isinstance(dataloader, DataLoader):
5970
raise TypeError(
@@ -62,11 +73,27 @@ def __init__(self,
6273
self.dataloader = dataloader
6374
self.interval = interval
6475
self.gpu_collect = gpu_collect
76+
self.by_epoch = by_epoch
6577
self.eval_kwargs = eval_kwargs
6678

6779
def after_train_iter(self, runner):
6880
"""After train epoch hook."""
69-
if not self.every_n_iters(runner, self.interval):
81+
if self.by_epoch or not self.every_n_iters(runner, self.interval):
82+
return
83+
from mmseg.apis import multi_gpu_test
84+
runner.log_buffer.clear()
85+
results = multi_gpu_test(
86+
runner.model,
87+
self.dataloader,
88+
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
89+
gpu_collect=self.gpu_collect)
90+
if runner.rank == 0:
91+
print('\n')
92+
self.evaluate(runner, results)
93+
94+
def after_train_epoch(self, runner):
95+
"""After train epoch hook."""
96+
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
7097
return
7198
from mmseg.apis import multi_gpu_test
7299
runner.log_buffer.clear()

tests/test_eval_hook.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def train_step(self, data_batch, optimizer):
3838
return dict(loss=loss)
3939

4040

41-
def test_eval_hook():
41+
def test_iter_eval_hook():
4242
with pytest.raises(TypeError):
4343
test_dataset = ExampleModel()
4444
data_loader = [
@@ -75,6 +75,43 @@ def test_eval_hook():
7575
logger=runner.logger)
7676

7777

78+
def test_epoch_eval_hook():
79+
with pytest.raises(TypeError):
80+
test_dataset = ExampleModel()
81+
data_loader = [
82+
DataLoader(
83+
test_dataset,
84+
batch_size=1,
85+
sampler=None,
86+
num_worker=0,
87+
shuffle=False)
88+
]
89+
EvalHook(data_loader, by_epoch=True)
90+
91+
test_dataset = ExampleDataset()
92+
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
93+
loader = DataLoader(test_dataset, batch_size=1)
94+
model = ExampleModel()
95+
data_loader = DataLoader(
96+
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
97+
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
98+
optimizer = obj_from_dict(optim_cfg, torch.optim,
99+
dict(params=model.parameters()))
100+
101+
# test EvalHook with interval
102+
with tempfile.TemporaryDirectory() as tmpdir:
103+
eval_hook = EvalHook(data_loader, by_epoch=True, interval=2)
104+
runner = mmcv.runner.EpochBasedRunner(
105+
model=model,
106+
optimizer=optimizer,
107+
work_dir=tmpdir,
108+
logger=logging.getLogger())
109+
runner.register_hook(eval_hook)
110+
runner.run([loader], [('train', 1)], 2)
111+
test_dataset.evaluate.assert_called_once_with([torch.tensor([1])],
112+
logger=runner.logger)
113+
114+
78115
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
79116
results = single_gpu_test(model, data_loader)
80117
return results
@@ -116,3 +153,41 @@ def test_dist_eval_hook():
116153
runner.run([loader], [('train', 1)], 1)
117154
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
118155
logger=runner.logger)
156+
157+
158+
@patch('mmseg.apis.multi_gpu_test', multi_gpu_test)
159+
def test_dist_eval_hook_epoch():
160+
with pytest.raises(TypeError):
161+
test_dataset = ExampleModel()
162+
data_loader = [
163+
DataLoader(
164+
test_dataset,
165+
batch_size=1,
166+
sampler=None,
167+
num_worker=0,
168+
shuffle=False)
169+
]
170+
DistEvalHook(data_loader)
171+
172+
test_dataset = ExampleDataset()
173+
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
174+
loader = DataLoader(test_dataset, batch_size=1)
175+
model = ExampleModel()
176+
data_loader = DataLoader(
177+
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
178+
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
179+
optimizer = obj_from_dict(optim_cfg, torch.optim,
180+
dict(params=model.parameters()))
181+
182+
# test DistEvalHook
183+
with tempfile.TemporaryDirectory() as tmpdir:
184+
eval_hook = DistEvalHook(data_loader, by_epoch=True, interval=2)
185+
runner = mmcv.runner.EpochBasedRunner(
186+
model=model,
187+
optimizer=optimizer,
188+
work_dir=tmpdir,
189+
logger=logging.getLogger())
190+
runner.register_hook(eval_hook)
191+
runner.run([loader], [('train', 1)], 2)
192+
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
193+
logger=runner.logger)

0 commit comments

Comments
 (0)