Skip to content

Commit bdf5adf

Browse files
authored
[Feature] add onnxruntime test tool (open-mmlab#498)
* add onnxruntime test tool, update pytorch2onnx to support slice export * onnx convert with custom output shape, update test code * update pytorch2onnx, add rescale_shape support, add document * update doc for lint error fixing * remove cpu flag in ort_test.py * change class name, fix cuda error * remote comment * fix bug of torch2onnx * mIOU to mIoU
1 parent d568d06 commit bdf5adf

File tree

3 files changed

+320
-33
lines changed

3 files changed

+320
-33
lines changed

docs/useful_tools.md

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ python tools/pytorch2onnx.py \
5353
--output-file ${ONNX_FILE} \
5454
--input-img ${INPUT_IMG} \
5555
--shape ${INPUT_SHAPE} \
56+
--rescale-shape ${RESCALE_SHAPE} \
5657
--show \
5758
--verify \
5859
--dynamic-export \
@@ -66,14 +67,64 @@ Description of arguments:
6667
- `--checkpoint` : The path of a model checkpoint file.
6768
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
6869
- `--input-img` : The path of an input image for conversion and visualize.
69-
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `256 256`.
70+
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to img_scale of testpipeline.
71+
- `--rescale-shape`: rescale shape of output, set this value to avoid OOM, only work on `slide` mode.
7072
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
7173
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
7274
- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
7375
- `--cfg-options`:Update config options.
7476

7577
**Note**: This tool is still experimental. Some customized operators are not supported for now.
7678

79+
### Evaluate ONNX model with ONNXRuntime
80+
81+
We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.
82+
83+
#### Prerequisite
84+
85+
- Install onnx and onnxruntime-gpu
86+
87+
```shell
88+
pip install onnx onnxruntime-gpu
89+
```
90+
91+
#### Usage
92+
93+
```python
94+
python tools/ort_test.py \
95+
${CONFIG_FILE} \
96+
${ONNX_FILE} \
97+
--out ${OUTPUT_FILE} \
98+
--eval ${EVALUATION_METRICS} \
99+
--show \
100+
--show-dir ${SHOW_DIRECTORY} \
101+
--options ${CFG_OPTIONS} \
102+
--eval-options ${EVALUATION_OPTIONS} \
103+
--opacity ${OPACITY} \
104+
```
105+
106+
Description of all arguments
107+
108+
- `config`: The path of a model config file.
109+
- `model`: The path of a ONNX model file.
110+
- `--out`: The path of output result file in pickle format.
111+
- `--format-only` : Format the output results without perform evaluation. It is useful when you want to format the result to a specific format and submit it to the test server. If not specified, it will be set to `False`. Note that this argument is **mutually exclusive** with `--eval`.
112+
- `--eval`: Evaluation metrics, which depends on the dataset, e.g., "mIoU" for generic datasets, and "cityscapes" for Cityscapes. Note that this argument is **mutually exclusive** with `--format-only`.
113+
- `--show`: Show results flag.
114+
- `--show-dir`: Directory where painted images will be saved
115+
- `--options`: Override some settings in the used config file, the key-value pair in `xxx=yyy` format will be merged into config file.
116+
- `--eval-options`: Custom options for evaluation, the key-value pair in `xxx=yyy` format will be kwargs for `dataset.evaluate()` function
117+
- `--opacity`: Opacity of painted segmentation map. In (0, 1] range.
118+
119+
#### Results and Models
120+
121+
| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime |
122+
| :--------: | :--------------------------------------------: | :--------: | :----: | :-----: | :---------: |
123+
| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 |
124+
| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 |
125+
| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 |
126+
| deeplabv3+ | deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 |
127+
77128
### Convert to TorchScript (experimental)
78129

79130
We also provide a script to convert model to [TorchScript](https://pytorch.org/docs/stable/jit.html) format. You can use the pytorch C++ API [LibTorch](https://pytorch.org/docs/stable/cpp_index.html) inference the trained model. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and TorchScript model.

tools/ort_test.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import argparse
2+
import os
3+
import os.path as osp
4+
import warnings
5+
6+
import mmcv
7+
import numpy as np
8+
import onnxruntime as ort
9+
import torch
10+
from mmcv.parallel import MMDataParallel
11+
from mmcv.runner import get_dist_info
12+
from mmcv.utils import DictAction
13+
14+
from mmseg.apis import single_gpu_test
15+
from mmseg.datasets import build_dataloader, build_dataset
16+
from mmseg.models.segmentors.base import BaseSegmentor
17+
18+
19+
class ONNXRuntimeSegmentor(BaseSegmentor):
20+
21+
def __init__(self, onnx_file, cfg, device_id):
22+
super(ONNXRuntimeSegmentor, self).__init__()
23+
# get the custom op path
24+
ort_custom_op_path = ''
25+
try:
26+
from mmcv.ops import get_onnxruntime_op_path
27+
ort_custom_op_path = get_onnxruntime_op_path()
28+
except (ImportError, ModuleNotFoundError):
29+
warnings.warn('If input model has custom op from mmcv, \
30+
you may have to build mmcv with ONNXRuntime from source.')
31+
session_options = ort.SessionOptions()
32+
# register custom op for onnxruntime
33+
if osp.exists(ort_custom_op_path):
34+
session_options.register_custom_ops_library(ort_custom_op_path)
35+
sess = ort.InferenceSession(onnx_file, session_options)
36+
providers = ['CPUExecutionProvider']
37+
options = [{}]
38+
is_cuda_available = ort.get_device() == 'GPU'
39+
if is_cuda_available:
40+
providers.insert(0, 'CUDAExecutionProvider')
41+
options.insert(0, {'device_id': device_id})
42+
43+
sess.set_providers(providers, options)
44+
45+
self.sess = sess
46+
self.device_id = device_id
47+
self.io_binding = sess.io_binding()
48+
self.output_names = [_.name for _ in sess.get_outputs()]
49+
for name in self.output_names:
50+
self.io_binding.bind_output(name)
51+
self.cfg = cfg
52+
self.test_mode = cfg.model.test_cfg.mode
53+
54+
def extract_feat(self, imgs):
55+
raise NotImplementedError('This method is not implemented.')
56+
57+
def encode_decode(self, img, img_metas):
58+
raise NotImplementedError('This method is not implemented.')
59+
60+
def forward_train(self, imgs, img_metas, **kwargs):
61+
raise NotImplementedError('This method is not implemented.')
62+
63+
def simple_test(self, img, img_meta, **kwargs):
64+
device_type = img.device.type
65+
self.io_binding.bind_input(
66+
name='input',
67+
device_type=device_type,
68+
device_id=self.device_id,
69+
element_type=np.float32,
70+
shape=img.shape,
71+
buffer_ptr=img.data_ptr())
72+
self.sess.run_with_iobinding(self.io_binding)
73+
seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
74+
# whole might support dynamic reshape
75+
ori_shape = img_meta[0]['ori_shape']
76+
if not (ori_shape[0] == seg_pred.shape[-2]
77+
and ori_shape[1] == seg_pred.shape[-1]):
78+
seg_pred = torch.from_numpy(seg_pred).float()
79+
seg_pred = torch.nn.functional.interpolate(
80+
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
81+
seg_pred = seg_pred.long().detach().cpu().numpy()
82+
seg_pred = seg_pred[0]
83+
seg_pred = list(seg_pred)
84+
return seg_pred
85+
86+
def aug_test(self, imgs, img_metas, **kwargs):
87+
raise NotImplementedError('This method is not implemented.')
88+
89+
90+
def parse_args():
91+
parser = argparse.ArgumentParser(
92+
description='mmseg onnxruntime backend test (and eval) a model')
93+
parser.add_argument('config', help='test config file path')
94+
parser.add_argument('model', help='Input model file')
95+
parser.add_argument('--out', help='output result file in pickle format')
96+
parser.add_argument(
97+
'--format-only',
98+
action='store_true',
99+
help='Format the output results without perform evaluation. It is'
100+
'useful when you want to format the result to a specific format and '
101+
'submit it to the test server')
102+
parser.add_argument(
103+
'--eval',
104+
type=str,
105+
nargs='+',
106+
help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
107+
' for generic datasets, and "cityscapes" for Cityscapes')
108+
parser.add_argument('--show', action='store_true', help='show results')
109+
parser.add_argument(
110+
'--show-dir', help='directory where painted images will be saved')
111+
parser.add_argument(
112+
'--options', nargs='+', action=DictAction, help='custom options')
113+
parser.add_argument(
114+
'--eval-options',
115+
nargs='+',
116+
action=DictAction,
117+
help='custom options for evaluation')
118+
parser.add_argument(
119+
'--opacity',
120+
type=float,
121+
default=0.5,
122+
help='Opacity of painted segmentation map. In (0, 1] range.')
123+
parser.add_argument('--local_rank', type=int, default=0)
124+
args = parser.parse_args()
125+
if 'LOCAL_RANK' not in os.environ:
126+
os.environ['LOCAL_RANK'] = str(args.local_rank)
127+
return args
128+
129+
130+
def main():
131+
args = parse_args()
132+
133+
assert args.out or args.eval or args.format_only or args.show \
134+
or args.show_dir, \
135+
('Please specify at least one operation (save/eval/format/show the '
136+
'results / save the results) with the argument "--out", "--eval"'
137+
', "--format-only", "--show" or "--show-dir"')
138+
139+
if args.eval and args.format_only:
140+
raise ValueError('--eval and --format_only cannot be both specified')
141+
142+
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
143+
raise ValueError('The output file must be a pkl file.')
144+
145+
cfg = mmcv.Config.fromfile(args.config)
146+
if args.options is not None:
147+
cfg.merge_from_dict(args.options)
148+
cfg.model.pretrained = None
149+
cfg.data.test.test_mode = True
150+
151+
# init distributed env first, since logger depends on the dist info.
152+
distributed = False
153+
154+
# build the dataloader
155+
# TODO: support multiple images per gpu (only minor changes are needed)
156+
dataset = build_dataset(cfg.data.test)
157+
data_loader = build_dataloader(
158+
dataset,
159+
samples_per_gpu=1,
160+
workers_per_gpu=cfg.data.workers_per_gpu,
161+
dist=distributed,
162+
shuffle=False)
163+
164+
# load onnx config and meta
165+
cfg.model.train_cfg = None
166+
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
167+
model.CLASSES = dataset.CLASSES
168+
model.PALETTE = dataset.PALETTE
169+
170+
efficient_test = False
171+
if args.eval_options is not None:
172+
efficient_test = args.eval_options.get('efficient_test', False)
173+
174+
model = MMDataParallel(model, device_ids=[0])
175+
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
176+
efficient_test, args.opacity)
177+
178+
rank, _ = get_dist_info()
179+
if rank == 0:
180+
if args.out:
181+
print(f'\nwriting results to {args.out}')
182+
mmcv.dump(outputs, args.out)
183+
kwargs = {} if args.eval_options is None else args.eval_options
184+
if args.format_only:
185+
dataset.format_results(outputs, **kwargs)
186+
if args.eval:
187+
dataset.evaluate(outputs, args.eval, **kwargs)
188+
189+
190+
if __name__ == '__main__':
191+
main()

0 commit comments

Comments
 (0)