Commit 9dd3e15

[feature]: Able to use save_best option (open-mmlab#575)
* Add save_best option in eval_hook
* Update meta to fix the bug that the best model could not be tested
* Refactor with _do_evaluate
* Remove redundant code
* Add meta

Co-authored-by: Jiarui XU <[email protected]>
1 parent a95f6d8 commit 9dd3e15
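
For reference, a minimal sketch of how the option is typically switched on from a config: the keyword arguments of the evaluation dict are forwarded to the eval hook, so adding save_best is enough (the interval and the 'mIoU' metric below are illustrative values, not part of this commit):

    evaluation = dict(interval=4000, metric='mIoU', save_best='mIoU')

When the tracked metric improves, the hook's _save_ckpt writes an extra "best" checkpoint alongside the regular ones.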

File tree: 3 files changed, 46 additions and 51 deletions

mmseg/core/evaluation/eval_hooks.py

Lines changed: 34 additions & 49 deletions
@@ -1,7 +1,9 @@
 import os.path as osp
 
+import torch.distributed as dist
 from mmcv.runner import DistEvalHook as _DistEvalHook
 from mmcv.runner import EvalHook as _EvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
 
 
 class EvalHook(_EvalHook):
@@ -23,33 +25,17 @@ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
-        results = single_gpu_test(
-            runner.model,
-            self.dataloader,
-            show=False,
-            efficient_test=self.efficient_test)
-        self.evaluate(runner, results)
 
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
         from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
         results = single_gpu_test(runner.model, self.dataloader, show=False)
-        self.evaluate(runner, results)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        if self.save_best:
+            self._save_ckpt(runner, key_score)
 
 
 class DistEvalHook(_DistEvalHook):
@@ -71,39 +57,38 @@ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``multi_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
-        results = multi_gpu_test(
-            runner.model,
-            self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
-            gpu_collect=self.gpu_collect,
-            efficient_test=self.efficient_test)
-        if runner.rank == 0:
-            print('\n')
-            self.evaluate(runner, results)
 
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
 
-        Override default ``multi_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
         from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
         results = multi_gpu_test(
             runner.model,
             self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            tmpdir=tmpdir,
             gpu_collect=self.gpu_collect)
         if runner.rank == 0:
             print('\n')
-            self.evaluate(runner, results)
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+
+            if self.save_best:
+                self._save_ckpt(runner, key_score)
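
Both hooks now inherit the evaluation plumbing from the newer mmcv EvalHook/DistEvalHook base classes, which provide _should_evaluate, _save_ckpt and the save_best bookkeeping; the only override point left is _do_evaluate. A simplified sketch of how the base class dispatches to it (not the verbatim mmcv code):

    def after_train_iter(self, runner):
        # iteration-based training: evaluate only when by_epoch is False
        if not self.by_epoch:
            self._do_evaluate(runner)

    def after_train_epoch(self, runner):
        # epoch-based training: evaluate only when by_epoch is True
        if self.by_epoch:
            self._do_evaluate(runner)

The interval checks that the removed after_train_iter/after_train_epoch overrides used to perform now live in _should_evaluate, which is why _do_evaluate starts by calling it.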

tools/test.py

Lines changed: 10 additions & 2 deletions
@@ -122,8 +122,16 @@ def main():
     if fp16_cfg is not None:
         wrap_fp16_model(model)
     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
-    model.CLASSES = checkpoint['meta']['CLASSES']
-    model.PALETTE = checkpoint['meta']['PALETTE']
+    if 'CLASSES' in checkpoint.get('meta', {}):
+        model.CLASSES = checkpoint['meta']['CLASSES']
+    else:
+        print('"CLASSES" not found in meta, use dataset.CLASSES instead')
+        model.CLASSES = dataset.CLASSES
+    if 'PALETTE' in checkpoint.get('meta', {}):
+        model.PALETTE = checkpoint['meta']['PALETTE']
+    else:
+        print('"PALETTE" not found in meta, use dataset.PALETTE instead')
+        model.PALETTE = dataset.PALETTE
 
     efficient_test = False
     if args.eval_options is not None:
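
The fallback matters because best checkpoints written by the eval hook before the tools/train.py change below did not carry CLASSES/PALETTE in their meta, so the old direct indexing raised a KeyError. A quick way to inspect what a given checkpoint stores (the path below is hypothetical):

    import torch

    ckpt = torch.load('work_dirs/example/best_mIoU_iter_8000.pth', map_location='cpu')
    meta = ckpt.get('meta', {})
    print('CLASSES' in meta, 'PALETTE' in meta)

If either key is missing, tools/test.py now falls back to the dataset's own definitions and prints a notice instead of failing.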

tools/train.py

Lines changed: 2 additions & 0 deletions
@@ -149,6 +149,8 @@ def main():
             PALETTE=datasets[0].PALETTE)
     # add an attribute for visualization convenience
     model.CLASSES = datasets[0].CLASSES
+    # passing checkpoint meta for saving best checkpoint
+    meta.update(cfg.checkpoint_config.meta)
     train_segmentor(
         model,
         datasets,
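
For context, cfg.checkpoint_config.meta is assembled a few lines earlier in tools/train.py (simplified sketch below; the real dict also records the mmseg version and the config text), and meta is the dict handed to train_segmentor. Merging the two means the best-checkpoint saves triggered by save_best carry the same CLASSES/PALETTE information as the regular checkpoints:

    cfg.checkpoint_config.meta = dict(
        CLASSES=datasets[0].CLASSES,
        PALETTE=datasets[0].PALETTE)
    # passing checkpoint meta for saving best checkpoint
    meta.update(cfg.checkpoint_config.meta)
    # later: train_segmentor(model, datasets, cfg, ..., meta=meta)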
