
Commit e235c1a

clownrat6 and xvjiarui authored
[Refactor] Support progressive test with fewer memory cost (open-mmlab#709)
* Support progressive test with fewer memory cost.
* Temp code
* Use processor to refactor evaluation workflow.
* Refactor eval hook.
* Fix progress bar.
* Fix middle save argument.
* Modify some variable names of the dataset evaluate api.
* Modify some variable names of the eval hook.
* Fix some priority bugs of the eval hook.
* Deprecate efficient_test.
* Fix training progress blocked by eval hook.
* Deprecate old test api.
* Fix test api error.
* Modify outer api.
* Build a sampler test api.
* TODO: Refactor format_results.
* Modify variable names.
* Fix num_classes bug.
* Fix sampler index bug.
* Fix grammar bug.
* Support batch sampler.
* More readable test api.
* Remove some command args and fix eval hook bug.
* Support format-only arg.
* Modify format_results of datasets.
* Modify tools which use the test apis.
* Support cityscapes eval.
* Fix cityscapes.
* 1. Add comments for batch_sampler; 2. Keep eval hook api the same and add a deprecation warning; 3. Add docstring for dataset.pre_eval.
* Add efficient_test docstring.
* Modify test tool to be compatible with the old version.
* Modify eval hook to be compatible with the old version.
* Modify test api to be compatible with the old version api.
* Sampler explanation.
* Update warning.
* Modify deploy_test.py.
* Compatible with old output; add efficient test back.
* Clarify the exclusivity logic.
* Warn about efficient_test.
* Modify format_results save folder.
* Fix bugs of format_results.
* Modify deploy_test.py.
* Update doc.
* Fix deploy test bugs.
* Fix custom dataset unit tests.
* Fix dataset unit tests.
* Fix eval hook unit tests.
* Fix some incompatibilities.
* Add pre_eval argument for eval hooks.
* Update eval hook docstring.
* Make pre_eval False by default.
* Add unit tests for dataset format_results.
* Fix some comments and a bc-breaking bug.
* Fix pre_eval setting the cfg field.
* Remove redundant code.

Co-authored-by: Jiarui XU <[email protected]>
1 parent 99d8376 commit e235c1a

22 files changed, +652 -191 lines

configs/_base_/schedules/schedule_160k.py (1 addition, 1 deletion)

@@ -6,4 +6,4 @@
 # runtime settings
 runner = dict(type='IterBasedRunner', max_iters=160000)
 checkpoint_config = dict(by_epoch=False, interval=16000)
-evaluation = dict(interval=16000, metric='mIoU')
+evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)

configs/_base_/schedules/schedule_20k.py (1 addition, 1 deletion)

@@ -6,4 +6,4 @@
 # runtime settings
 runner = dict(type='IterBasedRunner', max_iters=20000)
 checkpoint_config = dict(by_epoch=False, interval=2000)
-evaluation = dict(interval=2000, metric='mIoU')
+evaluation = dict(interval=2000, metric='mIoU', pre_eval=True)

configs/_base_/schedules/schedule_40k.py (1 addition, 1 deletion)

@@ -6,4 +6,4 @@
 # runtime settings
 runner = dict(type='IterBasedRunner', max_iters=40000)
 checkpoint_config = dict(by_epoch=False, interval=4000)
-evaluation = dict(interval=4000, metric='mIoU')
+evaluation = dict(interval=4000, metric='mIoU', pre_eval=True)

configs/_base_/schedules/schedule_80k.py (1 addition, 1 deletion)

@@ -6,4 +6,4 @@
 # runtime settings
 runner = dict(type='IterBasedRunner', max_iters=80000)
 checkpoint_config = dict(by_epoch=False, interval=8000)
-evaluation = dict(interval=8000, metric='mIoU')
+evaluation = dict(interval=8000, metric='mIoU', pre_eval=True)
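All four schedule configs now enable progressive evaluation through the `evaluation` dict. For context, here is a minimal sketch (not the exact mmseg training code) of how such a config field typically reaches the hook in mmcv-style training; `val_dataloader` and `runner` are assumed to already exist:

```python
# Hedged sketch: mmcv-style wiring of the ``evaluation`` config field.
# ``val_dataloader`` and ``runner`` are assumed to be built elsewhere.
from mmseg.core.evaluation import EvalHook

eval_cfg = dict(interval=16000, metric='mIoU', pre_eval=True)
# ``interval`` and ``pre_eval`` are consumed by the hook itself, while
# leftover kwargs such as ``metric`` are later forwarded to
# ``dataset.evaluate()``.
runner.register_hook(EvalHook(val_dataloader, **eval_cfg), priority='LOW')
```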

docs/inference.md (3 additions, 3 deletions)

@@ -21,11 +21,11 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [-
 
 Optional arguments:
 
-- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
+- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file. (After mmseg v0.17, the output results are pre-evaluation results or format-result paths.)
 - `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `mIoU` is available for all datasets. Cityscapes can be evaluated with `cityscapes` as well as the standard `mIoU` metric.
 - `--show`: If specified, segmentation results will be plotted on the images and shown in a new window. It is only applicable to single GPU testing and used for debugging and visualization. Please make sure that a GUI is available in your environment, otherwise you may encounter an error like `cannot connect to X server`.
 - `--show-dir`: If specified, segmentation results will be plotted on the images and saved to the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You do NOT need a GUI in your environment to use this option.
-- `--eval-options`: Optional parameters during evaluation. When `efficient_test=True`, it will save intermediate results to local files to save CPU memory. Make sure that you have enough local storage space (more than 20GB).
+- `--eval-options`: Optional parameters for `dataset.format_results` and `dataset.evaluate` during evaluation. When `efficient_test=True`, it will save intermediate results to local files to save CPU memory. Make sure that you have enough local storage space (more than 20GB). (The `efficient_test` argument has no effect after mmseg v0.17; a progressive mode is used to evaluate and format results, which largely saves memory cost and evaluation time.)
 
 Examples:
 
@@ -98,4 +98,4 @@ Assume that you have already downloaded the checkpoints to the directory `checkp
   --eval mIoU
 ```
 
-Using `pmap` to view the CPU memory footprint, it used 2.25GB CPU memory with `efficient_test=True` and 11.06GB CPU memory with `efficient_test=False`. This optional parameter can save a lot of memory.
+Using `pmap` to view the CPU memory footprint, it used 2.25GB CPU memory with `efficient_test=True` and 11.06GB CPU memory with `efficient_test=False`. This optional parameter can save a lot of memory. (After mmseg v0.17, `efficient_test` has no effect; a progressive mode is used by default to evaluate and format results efficiently.)
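To make the updated notes concrete, here is a minimal sketch of the progressive mode they refer to, assuming `model` and `data_loader` have already been built from a config:

```python
# Hedged sketch of progressive-mode testing (mmseg >= v0.17); ``model``
# and ``data_loader`` are assumed to be constructed elsewhere.
from mmseg.apis import single_gpu_test

# With pre_eval=True every batch is immediately reduced to small
# per-class statistics, so full segmentation maps never pile up in RAM.
results = single_gpu_test(model, data_loader, pre_eval=True)
print(data_loader.dataset.evaluate(results, metric='mIoU'))
```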

docs_zh-CN/inference.md (6 additions, 6 deletions)

@@ -20,11 +20,11 @@ python tools/test.py ${配置文件} ${检查点文件} [--out ${结果文件}]
 
 Optional arguments:
 
-- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
-- `EVAL_METRICS`: Metrics to be evaluated on the results, which mainly depend on the dataset. `mIoU` is available for all datasets; a dataset like Cityscapes can also be evaluated with the `cityscapes` option, in the same way as the standard `mIoU`.
-- `--show`: If specified, the segmentation results will be drawn on the images and shown in a new window. It is only for debugging and visualization, and only for single-GPU testing. Please make sure a GUI is available in your environment, otherwise you may encounter the error `cannot connect to X server`.
-- `--show-dir`: If specified, the segmentation results will be drawn on the images and saved to the specified directory. It is only for debugging and visualization, and only for single-GPU testing. No GUI is required in your environment when using this option.
-- `--eval-options`: Optional parameters during evaluation. When `efficient_test=True` is set, intermediate results will be saved to local files to save CPU memory. Please make sure your local disk has enough storage space (more than 20GB).
+- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file. (After MMseg v0.17, args.out only saves the intermediate results of evaluation or the save paths of the segmentation maps.)
+- `EVAL_METRICS`: Metrics to be evaluated on the results, which mainly depend on the dataset. `mIoU` is available for all datasets; a dataset like Cityscapes can also be evaluated with the `cityscapes` option, in the same way as the standard `mIoU`.
+- `--show`: If specified, the segmentation results will be drawn on the images and shown in a new window. It is only for debugging and visualization, and only for single-GPU testing. Please make sure a GUI is available in your environment, otherwise you may encounter the error `cannot connect to X server`.
+- `--show-dir`: If specified, the segmentation results will be drawn on the images and saved to the specified directory. It is only for debugging and visualization, and only for single-GPU testing. No GUI is required in your environment when using this option.
+- `--eval-options`: Optional parameters during evaluation. When `efficient_test=True` is set, intermediate results will be saved to local files to save CPU memory. Please make sure your local disk has enough storage space (more than 20GB). (After MMseg v0.17, `efficient_test` no longer takes effect; we refactored the test API and use a progressive mode to evaluate and save results more efficiently.)
 
 Examples:
 
@@ -96,4 +96,4 @@ python tools/test.py ${配置文件} ${检查点文件} [--out ${结果文件}]
   --eval mIoU
 ```
 
-Use `pmap` to check the CPU memory footprint: `efficient_test=True` uses about 2.25GB of CPU memory and `efficient_test=False` about 11.06GB. This optional parameter can save a lot of CPU memory.
+Use `pmap` to check the CPU memory footprint: `efficient_test=True` uses about 2.25GB of CPU memory and `efficient_test=False` about 11.06GB. This optional parameter can save a lot of CPU memory. (After MMseg v0.17, the `efficient_test` parameter no longer takes effect; we use a progressive mode to evaluate and save results more efficiently.)

mmseg/apis/test.py (101 additions, 35 deletions)

@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
 import tempfile
+import warnings
 
 import mmcv
 import numpy as np
@@ -19,7 +20,6 @@ def np2tmp(array, temp_file_name=None, tmpdir=None):
             function will generate a file name with tempfile.NamedTemporaryFile
             to save ndarray. Default: None.
         tmpdir (str): Temporary directory to save Ndarray files. Default: None.
-
     Returns:
         str: The numpy file name.
     """
@@ -36,8 +36,11 @@ def single_gpu_test(model,
                     show=False,
                     out_dir=None,
                     efficient_test=False,
-                    opacity=0.5):
-    """Test with single GPU.
+                    opacity=0.5,
+                    pre_eval=False,
+                    format_only=False,
+                    format_args={}):
+    """Test with single GPU by progressive mode.
 
     Args:
         model (nn.Module): Model to be tested.
@@ -46,24 +49,60 @@ def single_gpu_test(model,
         out_dir (str, optional): If specified, the results will be dumped into
             the directory to save output results.
         efficient_test (bool): Whether to save the results as local numpy
-            files to save CPU memory during evaluation. Default: False.
+            files to save CPU memory during evaluation. Mutually exclusive
+            with pre_eval and format_only. Default: False.
         opacity (float): Opacity of painted segmentation map.
             Default: 0.5. Must be in (0, 1] range.
+        pre_eval (bool): Use dataset.pre_eval() function to generate
+            pre-results for metric evaluation. Mutually exclusive with
+            efficient_test and format_only. Default: False.
+        format_only (bool): Only format results, e.g. for a test-server
+            submission. Mutually exclusive with pre_eval and efficient_test.
+            Default: False.
+        format_args (dict): The args for format_results. Default: {}.
     Returns:
-        list: The prediction results.
+        list: list of evaluation pre-results or list of save file names.
     """
+    if efficient_test:
+        warnings.warn(
+            'DeprecationWarning: ``efficient_test`` will be deprecated, the '
+            'evaluation is CPU memory friendly with pre_eval=True')
+        mmcv.mkdir_or_exist('.efficient_test')
+    # When none of them is set true, segmentation results are returned as
+    # a list of np.array.
+    assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
+        '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
+        'exclusive, only one of them can be true.'
 
     model.eval()
     results = []
     dataset = data_loader.dataset
     prog_bar = mmcv.ProgressBar(len(dataset))
-    if efficient_test:
-        mmcv.mkdir_or_exist('.efficient_test')
-    for i, data in enumerate(data_loader):
+    # How the data_loader retrieves samples from the dataset:
+    # sampler -> batch_sampler -> indices
+    # The indices are passed to the dataset_fetcher to get data:
+    # data_fetcher -> collate_fn(dataset[index]) -> data_sample
+    # We use batch_sampler to get the correct data indices.
+    loader_indices = data_loader.batch_sampler
+
+    for batch_indices, data in zip(loader_indices, data_loader):
         with torch.no_grad():
             result = model(return_loss=False, **data)
 
+        if efficient_test:
+            result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]
+
+        if format_only:
+            result = dataset.format_results(
+                result, indices=batch_indices, **format_args)
+        if pre_eval:
+            # TODO: adapt samples_per_gpu > 1.
+            # Only samples_per_gpu=1 is valid for now.
+            result = dataset.pre_eval(result, indices=batch_indices)
+
+        results.extend(result)
+
         if show or out_dir:
             img_tensor = data['img'][0]
             img_metas = data['img_metas'][0].data[0]
@@ -90,27 +129,22 @@ def single_gpu_test(model,
                     out_file=out_file,
                     opacity=opacity)
 
-        if isinstance(result, list):
-            if efficient_test:
-                result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]
-            results.extend(result)
-        else:
-            if efficient_test:
-                result = np2tmp(result, tmpdir='.efficient_test')
-            results.append(result)
-
         batch_size = len(result)
         for _ in range(batch_size):
             prog_bar.update()
+
     return results
 
 
 def multi_gpu_test(model,
                    data_loader,
                    tmpdir=None,
                    gpu_collect=False,
-                   efficient_test=False):
-    """Test model with multiple gpus.
+                   efficient_test=False,
+                   pre_eval=False,
+                   format_only=False,
+                   format_args={}):
+    """Test model with multiple gpus by progressive mode.
 
     This method tests model with multiple gpus and collects the results
     under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
@@ -123,39 +157,71 @@ def multi_gpu_test(model,
         data_loader (utils.data.Dataloader): Pytorch data loader.
         tmpdir (str): Path of directory to save the temporary results from
             different gpus under cpu mode. The same path is used for efficient
-            test.
+            test. Default: None.
         gpu_collect (bool): Option to use either gpu or cpu to collect results.
+            Default: False.
         efficient_test (bool): Whether to save the results as local numpy
-            files to save CPU memory during evaluation. Default: False.
+            files to save CPU memory during evaluation. Mutually exclusive
+            with pre_eval and format_only. Default: False.
+        pre_eval (bool): Use dataset.pre_eval() function to generate
+            pre-results for metric evaluation. Mutually exclusive with
+            efficient_test and format_only. Default: False.
+        format_only (bool): Only format results, e.g. for a test-server
+            submission. Mutually exclusive with pre_eval and efficient_test.
+            Default: False.
+        format_args (dict): The args for format_results. Default: {}.
 
     Returns:
-        list: The prediction results.
+        list: list of evaluation pre-results or list of save file names.
     """
+    if efficient_test:
+        warnings.warn(
+            'DeprecationWarning: ``efficient_test`` will be deprecated, the '
+            'evaluation is CPU memory friendly with pre_eval=True')
+        mmcv.mkdir_or_exist('.efficient_test')
+    # When none of them is set true, segmentation results are returned as
+    # a list of np.array.
+    assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
+        '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
+        'exclusive, only one of them can be true.'
 
     model.eval()
     results = []
     dataset = data_loader.dataset
+    # How the data_loader retrieves samples from the dataset:
+    # sampler -> batch_sampler -> indices
+    # The indices are passed to the dataset_fetcher to get data:
+    # data_fetcher -> collate_fn(dataset[index]) -> data_sample
+    # We use batch_sampler to get the correct data indices.
+
+    # The batch_sampler is based on DistributedSampler, so the indices only
+    # point to data samples of the current machine.
+    loader_indices = data_loader.batch_sampler
+
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
-    if efficient_test:
-        mmcv.mkdir_or_exist('.efficient_test')
-    for i, data in enumerate(data_loader):
+
+    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
 
-        if isinstance(result, list):
-            if efficient_test:
-                result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]
-            results.extend(result)
-        else:
-            if efficient_test:
-                result = np2tmp(result, tmpdir='.efficient_test')
-            results.append(result)
+        if efficient_test:
+            result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]
+
+        if format_only:
+            result = dataset.format_results(
+                result, indices=batch_indices, **format_args)
+        if pre_eval:
+            # TODO: adapt samples_per_gpu > 1.
+            # Only samples_per_gpu=1 is valid for now.
+            result = dataset.pre_eval(result, indices=batch_indices)
+
+        results.extend(result)
 
         if rank == 0:
-            batch_size = len(result)
-            for _ in range(batch_size * world_size):
+            batch_size = len(result) * world_size
+            for _ in range(batch_size):
                prog_bar.update()
 
     # collect results from all ranks
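The `batch_sampler` trick above is the heart of the refactor: iterating the loader's `batch_sampler` in lockstep with the loader itself recovers, for each batch, the dataset indices that produced it. A self-contained sketch in plain PyTorch (toy data, sequential sampling as in testing):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset standing in for a segmentation dataset.
dataset = TensorDataset(torch.arange(10).float())
loader = DataLoader(dataset, batch_size=3, shuffle=False)

# loader.batch_sampler yields the index lists used to build each batch;
# with a deterministic (non-shuffled) sampler the two iterators stay
# aligned, which is exactly what single_gpu_test relies on.
for batch_indices, (batch,) in zip(loader.batch_sampler, loader):
    print(batch_indices, batch.tolist())  # e.g. [0, 1, 2] [0.0, 1.0, 2.0]
```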

mmseg/core/evaluation/__init__.py (4 additions, 2 deletions)

@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
+from .metrics import (eval_metrics, intersect_and_union, mean_dice,
+                      mean_fscore, mean_iou, pre_eval_to_metrics)
 
 __all__ = [
     'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
-    'eval_metrics', 'get_classes', 'get_palette'
+    'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics',
+    'intersect_and_union'
 ]
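A hedged sketch of how the two newly exported helpers fit together, assuming the signatures `intersect_and_union(pred, label, num_classes, ignore_index)` and `pre_eval_to_metrics(results, metrics)` introduced by this PR:

```python
import numpy as np
from mmseg.core.evaluation import intersect_and_union, pre_eval_to_metrics

num_classes = 3
pred = np.random.randint(0, num_classes, (4, 4))   # fake prediction
label = np.random.randint(0, num_classes, (4, 4))  # fake ground truth

# One per-sample tuple of (intersect, union, pred_area, label_area);
# dataset.pre_eval() produces exactly these counts during testing.
pre_result = intersect_and_union(pred, label, num_classes, ignore_index=255)

# Aggregate a list of such tuples into the final metric table.
print(pre_eval_to_metrics([pre_result], metrics=['mIoU']))
```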

mmseg/core/evaluation/eval_hooks.py (37 additions, 9 deletions)

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
+import warnings
 
 import torch.distributed as dist
 from mmcv.runner import DistEvalHook as _DistEvalHook
@@ -16,15 +17,28 @@ class EvalHook(_EvalHook):
             Default: False.
         efficient_test (bool): Whether to save the results as local numpy
             files to save CPU memory during evaluation. Default: False.
+        pre_eval (bool): Whether to use progressive mode to evaluate the
+            model. Default: False.
     Returns:
         list: The prediction results.
     """
 
     greater_keys = ['mIoU', 'mAcc', 'aAcc']
 
-    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+    def __init__(self,
+                 *args,
+                 by_epoch=False,
+                 efficient_test=False,
+                 pre_eval=False,
+                 **kwargs):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
-        self.efficient_test = efficient_test
+        self.pre_eval = pre_eval
+        if efficient_test:
+            warnings.warn(
+                'DeprecationWarning: ``efficient_test`` for evaluation hook '
+                'is deprecated, the evaluation hook is CPU memory friendly '
+                'with ``pre_eval=True`` as argument for ``single_gpu_test()`` '
+                'function')
 
     def _do_evaluate(self, runner):
         """perform evaluation and save ckpt."""
@@ -33,10 +47,8 @@ def _do_evaluate(self, runner):
 
         from mmseg.apis import single_gpu_test
         results = single_gpu_test(
-            runner.model,
-            self.dataloader,
-            show=False,
-            efficient_test=self.efficient_test)
+            runner.model, self.dataloader, show=False, pre_eval=self.pre_eval)
+        runner.log_buffer.clear()
         runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
         key_score = self.evaluate(runner, results)
         if self.save_best:
@@ -52,15 +64,28 @@ class DistEvalHook(_DistEvalHook):
             Default: False.
         efficient_test (bool): Whether to save the results as local numpy
             files to save CPU memory during evaluation. Default: False.
+        pre_eval (bool): Whether to use progressive mode to evaluate the
+            model. Default: False.
     Returns:
         list: The prediction results.
     """
 
     greater_keys = ['mIoU', 'mAcc', 'aAcc']
 
-    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+    def __init__(self,
+                 *args,
+                 by_epoch=False,
+                 efficient_test=False,
+                 pre_eval=False,
+                 **kwargs):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
-        self.efficient_test = efficient_test
+        self.pre_eval = pre_eval
+        if efficient_test:
+            warnings.warn(
+                'DeprecationWarning: ``efficient_test`` for evaluation hook '
+                'is deprecated, the evaluation hook is CPU memory friendly '
+                'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` '
+                'function')
 
     def _do_evaluate(self, runner):
         """perform evaluation and save ckpt."""
@@ -90,7 +115,10 @@ def _do_evaluate(self, runner):
             self.dataloader,
             tmpdir=tmpdir,
             gpu_collect=self.gpu_collect,
-            efficient_test=self.efficient_test)
+            pre_eval=self.pre_eval)
+
+        runner.log_buffer.clear()
+
         if runner.rank == 0:
             print('\n')
             runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
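For completeness, the refactored hooks can also be constructed directly; a minimal sketch of the distributed variant, assuming `val_loader` and a distributed `runner` already exist:

```python
# Hedged sketch: DistEvalHook forwards pre_eval to multi_gpu_test(),
# which collects lightweight pre-eval tuples across ranks instead of
# full segmentation maps. ``val_loader`` and ``runner`` are assumed.
from mmseg.core.evaluation import DistEvalHook

hook = DistEvalHook(
    val_loader,
    interval=4000,
    by_epoch=False,
    gpu_collect=True,   # collect results on GPU instead of via tmpdir
    metric='mIoU',      # forwarded to dataset.evaluate()
    pre_eval=True)      # progressive mode replacing efficient_test
runner.register_hook(hook, priority='LOW')
```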
