Skip to content

Commit f933603

Browse files
谢昕辰xvjiarui
谢昕辰
andauthored
Use MMCV's EvalHook in MMSegmentation (open-mmlab#438)
* mmcv eval hook * mmcv evalhook compatible * add warnings * inherit from base class * fix unitest * adapt to mmcv 1.3.1 * fixed unittest * set by_epoch=False * fixed efficient test * update docstring Co-authored-by: Jiarui XU <[email protected]>
1 parent e16e0e3 commit f933603

File tree

3 files changed

+57
-55
lines changed

3 files changed

+57
-55
lines changed

mmseg/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from .version import __version__, version_info
44

5-
MMCV_MIN = '1.1.4'
5+
MMCV_MIN = '1.3.1'
66
MMCV_MAX = '1.4.0'
77

88

mmseg/core/evaluation/eval_hooks.py

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,81 @@
11
import os.path as osp
22

3-
from mmcv.runner import Hook
4-
from torch.utils.data import DataLoader
3+
from mmcv.runner import DistEvalHook as _DistEvalHook
4+
from mmcv.runner import EvalHook as _EvalHook
55

66

7-
class EvalHook(Hook):
8-
"""Evaluation hook.
7+
class EvalHook(_EvalHook):
8+
"""Single GPU EvalHook, with efficient test support.
99
10-
Attributes:
11-
dataloader (DataLoader): A PyTorch dataloader.
12-
interval (int): Evaluation interval (by epochs). Default: 1.
10+
Args:
11+
by_epoch (bool): Determine perform evaluation by epoch or by iteration.
12+
If set to True, it will perform by epoch. Otherwise, by iteration.
13+
Default: False.
14+
efficient_test (bool): Whether save the results as local numpy files to
15+
save CPU memory during evaluation. Default: False.
16+
Returns:
17+
list: The prediction results.
1318
"""
1419

15-
def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs):
16-
if not isinstance(dataloader, DataLoader):
17-
raise TypeError('dataloader must be a pytorch DataLoader, but got '
18-
f'{type(dataloader)}')
19-
self.dataloader = dataloader
20-
self.interval = interval
21-
self.by_epoch = by_epoch
22-
self.eval_kwargs = eval_kwargs
20+
greater_keys = ['mIoU', 'mAcc', 'aAcc']
21+
22+
def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
23+
super().__init__(*args, by_epoch=by_epoch, **kwargs)
24+
self.efficient_test = efficient_test
2325

2426
def after_train_iter(self, runner):
25-
"""After train epoch hook."""
27+
"""After train epoch hook.
28+
29+
Override default ``single_gpu_test``.
30+
"""
2631
if self.by_epoch or not self.every_n_iters(runner, self.interval):
2732
return
2833
from mmseg.apis import single_gpu_test
2934
runner.log_buffer.clear()
30-
results = single_gpu_test(runner.model, self.dataloader, show=False)
35+
results = single_gpu_test(
36+
runner.model,
37+
self.dataloader,
38+
show=False,
39+
efficient_test=self.efficient_test)
3140
self.evaluate(runner, results)
3241

3342
def after_train_epoch(self, runner):
34-
"""After train epoch hook."""
43+
"""After train epoch hook.
44+
45+
Override default ``single_gpu_test``.
46+
"""
3547
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
3648
return
3749
from mmseg.apis import single_gpu_test
3850
runner.log_buffer.clear()
3951
results = single_gpu_test(runner.model, self.dataloader, show=False)
4052
self.evaluate(runner, results)
4153

42-
def evaluate(self, runner, results):
43-
"""Call evaluate function of dataset."""
44-
eval_res = self.dataloader.dataset.evaluate(
45-
results, logger=runner.logger, **self.eval_kwargs)
46-
for name, val in eval_res.items():
47-
runner.log_buffer.output[name] = val
48-
runner.log_buffer.ready = True
49-
5054

51-
class DistEvalHook(EvalHook):
52-
"""Distributed evaluation hook.
55+
class DistEvalHook(_DistEvalHook):
56+
"""Distributed EvalHook, with efficient test support.
5357
54-
Attributes:
55-
dataloader (DataLoader): A PyTorch dataloader.
56-
interval (int): Evaluation interval (by epochs). Default: 1.
57-
tmpdir (str | None): Temporary directory to save the results of all
58-
processes. Default: None.
59-
gpu_collect (bool): Whether to use gpu or cpu to collect results.
58+
Args:
59+
by_epoch (bool): Determine perform evaluation by epoch or by iteration.
60+
If set to True, it will perform by epoch. Otherwise, by iteration.
6061
Default: False.
62+
efficient_test (bool): Whether save the results as local numpy files to
63+
save CPU memory during evaluation. Default: False.
64+
Returns:
65+
list: The prediction results.
6166
"""
6267

63-
def __init__(self,
64-
dataloader,
65-
interval=1,
66-
gpu_collect=False,
67-
by_epoch=False,
68-
**eval_kwargs):
69-
if not isinstance(dataloader, DataLoader):
70-
raise TypeError(
71-
'dataloader must be a pytorch DataLoader, but got {}'.format(
72-
type(dataloader)))
73-
self.dataloader = dataloader
74-
self.interval = interval
75-
self.gpu_collect = gpu_collect
76-
self.by_epoch = by_epoch
77-
self.eval_kwargs = eval_kwargs
68+
greater_keys = ['mIoU', 'mAcc', 'aAcc']
69+
70+
def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
71+
super().__init__(*args, by_epoch=by_epoch, **kwargs)
72+
self.efficient_test = efficient_test
7873

7974
def after_train_iter(self, runner):
80-
"""After train epoch hook."""
75+
"""After train epoch hook.
76+
77+
Override default ``multi_gpu_test``.
78+
"""
8179
if self.by_epoch or not self.every_n_iters(runner, self.interval):
8280
return
8381
from mmseg.apis import multi_gpu_test
@@ -86,13 +84,17 @@ def after_train_iter(self, runner):
8684
runner.model,
8785
self.dataloader,
8886
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
89-
gpu_collect=self.gpu_collect)
87+
gpu_collect=self.gpu_collect,
88+
efficient_test=self.efficient_test)
9089
if runner.rank == 0:
9190
print('\n')
9291
self.evaluate(runner, results)
9392

9493
def after_train_epoch(self, runner):
95-
"""After train epoch hook."""
94+
"""After train epoch hook.
95+
96+
Override default ``multi_gpu_test``.
97+
"""
9698
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
9799
return
98100
from mmseg.apis import multi_gpu_test

tests/test_eval_hook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_iter_eval_hook():
6363

6464
# test EvalHook
6565
with tempfile.TemporaryDirectory() as tmpdir:
66-
eval_hook = EvalHook(data_loader)
66+
eval_hook = EvalHook(data_loader, by_epoch=False)
6767
runner = mmcv.runner.IterBasedRunner(
6868
model=model,
6969
optimizer=optimizer,
@@ -143,7 +143,7 @@ def test_dist_eval_hook():
143143

144144
# test DistEvalHook
145145
with tempfile.TemporaryDirectory() as tmpdir:
146-
eval_hook = DistEvalHook(data_loader)
146+
eval_hook = DistEvalHook(data_loader, by_epoch=False)
147147
runner = mmcv.runner.IterBasedRunner(
148148
model=model,
149149
optimizer=optimizer,

0 commit comments

Comments
 (0)