Skip to content

Commit 8dbbdd8

Browse files
authored
[Feature] Add model ensemble tools (open-mmlab#2218)
* [Feature] Add model ensemble tool * [Enhance] Add en and zh_cn instructions for model_ensemble * [Enhance] Add default-value for --out and modify instruction * [Enhance] Add arg-type for --out * [Enhance] Delete redundant code
1 parent 76a5138 commit 8dbbdd8

File tree

4 files changed

+205
-0
lines changed

4 files changed

+205
-0
lines changed

docs/en/useful_tools.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,3 +424,32 @@ result/pred_result.pkl \
424424
result/confusion_matrix \
425425
--show
426426
```
427+
428+
## Model ensemble
429+
430+
To complete the integration of prediction probabilities for multiple models, we provide 'tools/model_ensemble.py'
431+
432+
### Usage
433+
434+
```bash
435+
python tools/model_ensemble.py \
436+
--config ${CONFIG_FILE1} ${CONFIG_FILE2} ... \
437+
--checkpoint ${CHECKPOINT_FILE1} ${CHECKPOINT_FILE2} ...\
438+
--aug-test \
439+
--out ${OUTPUT_DIR}\
440+
--gpus ${GPU_USED}\
441+
```
442+
443+
### Description of all arguments
444+
445+
- `--config`: Path to the config file for the ensemble model
446+
- `--checkpoint`: Path to the checkpoint file for the ensemble model
447+
- `--aug-test`: Whether to use flip and multi-scale test
448+
- `--out`: Save folder for model ensemble results
449+
- `--gpus`: Gpu-id used for model ensemble
450+
451+
### Result of model ensemble
452+
453+
- The model ensemble will generate an unrendered segmentation mask for each input, the input shape is `[H, W]`, the segmentation mask shape is `[H, W]`, and each pixel-value in the segmentation mask represents the pixel category after segmentation at that position.
454+
455+
- The filename of the model ensemble result will be named in the same filename as `Ground Truth`. If the filename of `Ground Truth` is called `1.png`, the model ensemble result file will also be named `1.png` and placed in the folder specified by `--out`.

docs/zh_cn/useful_tools.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,31 @@ configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \
366366
checkpoint/fcn_r50-d8_512x1024_40k_cityscapes_20200604_192608-efe53f0d.pth \
367367
fcn
368368
```
369+
370+
## 模型集成
371+
372+
我们提供了`tools/model_ensemble.py` 完成对多个模型的预测概率进行集成的脚本
373+
374+
### 使用方法
375+
376+
```bash
377+
python tools/model_ensemble.py \
378+
--config ${CONFIG_FILE1} ${CONFIG_FILE2} ... \
379+
--checkpoint ${CHECKPOINT_FILE1} ${CHECKPOINT_FILE2} ...\
380+
--aug-test \
381+
--out ${OUTPUT_DIR}\
382+
--gpus ${GPU_USED}\
383+
```
384+
385+
### 各个参数的描述:
386+
387+
- `--config`: 集成模型的配置文件的路径
388+
- `--checkpoint`: 集成模型的权重文件的路径
389+
- `--aug-test`: 是否使用翻转和多尺度预测
390+
- `--out`: 模型集成结果的保存文件夹路径
391+
- `--gpus`: 模型集成使用的gpu-id
392+
393+
### 模型集成结果
394+
395+
- 模型集成会对每一张输入,形状为`[H, W]`,产生一张未渲染的分割掩膜文件(segmentation mask),形状为`[H, W]`,分割掩膜中的每个像素点的值代表该位置分割后的像素类别.
396+
- 模型集成结果的文件名会采用和`Ground Truth`一致的文件命名,如`Ground Truth`文件名称为`1.png`,则模型集成结果文件也会被命名为`1.png`,并放置在`--out`指定的文件夹中.

mmseg/models/segmentors/encoder_decoder.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,15 @@ def simple_test(self, img, img_meta, rescale=True):
278278
seg_pred = list(seg_pred)
279279
return seg_pred
280280

281+
def simple_test_logits(self, img, img_metas, rescale=True):
282+
"""Test without augmentations.
283+
284+
Return numpy seg_map logits.
285+
"""
286+
seg_logit = self.inference(img[0], img_metas[0], rescale)
287+
seg_logit = seg_logit.cpu().numpy()
288+
return seg_logit
289+
281290
def aug_test(self, imgs, img_metas, rescale=True):
282291
"""Test with augmentations.
283292
@@ -300,3 +309,21 @@ def aug_test(self, imgs, img_metas, rescale=True):
300309
# unravel batch dim
301310
seg_pred = list(seg_pred)
302311
return seg_pred
312+
313+
def aug_test_logits(self, img, img_metas, rescale=True):
314+
"""Test with augmentations.
315+
316+
Return seg_map logits. Only rescale=True is supported.
317+
"""
318+
# aug_test rescale all imgs back to ori_shape for now
319+
assert rescale
320+
321+
imgs = img
322+
seg_logit = self.inference(imgs[0], img_metas[0], rescale)
323+
for i in range(1, len(imgs)):
324+
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
325+
seg_logit += cur_seg_logit
326+
327+
seg_logit /= len(imgs)
328+
seg_logit = seg_logit.cpu().numpy()
329+
return seg_logit

tools/model_ensemble.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import argparse
3+
import os
4+
5+
import mmcv
6+
import numpy as np
7+
import torch
8+
from mmcv.parallel import MMDataParallel
9+
from mmcv.parallel.scatter_gather import scatter_kwargs
10+
from mmcv.runner import load_checkpoint, wrap_fp16_model
11+
from PIL import Image
12+
13+
from mmseg.datasets import build_dataloader, build_dataset
14+
from mmseg.models import build_segmentor
15+
16+
17+
@torch.no_grad()
18+
def main(args):
19+
20+
models = []
21+
gpu_ids = args.gpus
22+
configs = args.config
23+
ckpts = args.checkpoint
24+
25+
cfg = mmcv.Config.fromfile(configs[0])
26+
27+
if args.aug_test:
28+
cfg.data.test.pipeline[1].img_ratios = [
29+
0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0
30+
]
31+
cfg.data.test.pipeline[1].flip = True
32+
else:
33+
cfg.data.test.pipeline[1].img_ratios = [1.0]
34+
cfg.data.test.pipeline[1].flip = False
35+
36+
torch.backends.cudnn.benchmark = True
37+
38+
# build the dataloader
39+
dataset = build_dataset(cfg.data.test)
40+
data_loader = build_dataloader(
41+
dataset,
42+
samples_per_gpu=1,
43+
workers_per_gpu=4,
44+
dist=False,
45+
shuffle=False,
46+
)
47+
48+
for idx, (config, ckpt) in enumerate(zip(configs, ckpts)):
49+
cfg = mmcv.Config.fromfile(config)
50+
cfg.model.pretrained = None
51+
cfg.data.test.test_mode = True
52+
53+
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
54+
if cfg.get('fp16', None):
55+
wrap_fp16_model(model)
56+
load_checkpoint(model, ckpt, map_location='cpu')
57+
torch.cuda.empty_cache()
58+
tmpdir = args.out
59+
mmcv.mkdir_or_exist(tmpdir)
60+
model = MMDataParallel(model, device_ids=[gpu_ids[idx % len(gpu_ids)]])
61+
model.eval()
62+
models.append(model)
63+
64+
dataset = data_loader.dataset
65+
prog_bar = mmcv.ProgressBar(len(dataset))
66+
loader_indices = data_loader.batch_sampler
67+
for batch_indices, data in zip(loader_indices, data_loader):
68+
result = []
69+
70+
for model in models:
71+
x, _ = scatter_kwargs(
72+
inputs=data, kwargs=None, target_gpus=model.device_ids)
73+
if args.aug_test:
74+
logits = model.module.aug_test_logits(**x[0])
75+
else:
76+
logits = model.module.simple_test_logits(**x[0])
77+
result.append(logits)
78+
79+
result_logits = 0
80+
for logit in result:
81+
result_logits += logit
82+
83+
pred = result_logits.argmax(axis=1).squeeze()
84+
img_info = dataset.img_infos[batch_indices[0]]
85+
file_name = os.path.join(
86+
tmpdir, img_info['ann']['seg_map'].split(os.path.sep)[-1])
87+
Image.fromarray(pred.astype(np.uint8)).save(file_name)
88+
prog_bar.update()
89+
90+
91+
def parse_args():
92+
parser = argparse.ArgumentParser(
93+
description='Model Ensemble with logits result')
94+
parser.add_argument(
95+
'--config', type=str, nargs='+', help='ensemble config files path')
96+
parser.add_argument(
97+
'--checkpoint',
98+
type=str,
99+
nargs='+',
100+
help='ensemble checkpoint files path')
101+
parser.add_argument(
102+
'--aug-test',
103+
action='store_true',
104+
help='control ensemble aug-result or single-result (default)')
105+
parser.add_argument(
106+
'--out', type=str, default='results', help='the dir to save result')
107+
parser.add_argument(
108+
'--gpus', type=int, nargs='+', default=[0], help='id of gpu to use')
109+
110+
args = parser.parse_args()
111+
assert len(args.config) == len(args.checkpoint), \
112+
f'len(config) must equal len(checkpoint), ' \
113+
f'but len(config) = {len(args.config)} and' \
114+
f'len(checkpoint) = {len(args.checkpoint)}'
115+
assert args.out, "ensemble result out-dir can't be None"
116+
return args
117+
118+
119+
if __name__ == '__main__':
120+
args = parse_args()
121+
main(args)

0 commit comments

Comments
 (0)