Skip to content

Commit dc5d53b

Browse files
author
q.yao
authored
[Feature] Update deploy test tools (open-mmlab#553)
* add trt test tool * create deploy_test, update document * fix with isort * move import inside __init__ * remove comment, fix doc * update document
1 parent 66b0525 commit dc5d53b

File tree

2 files changed

+86
-17
lines changed

2 files changed

+86
-17
lines changed

docs/useful_tools.md

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ Description of arguments:
7676

7777
**Note**: This tool is still experimental. Some customized operators are not supported for now.
7878

79-
### Evaluate ONNX model with ONNXRuntime
79+
### Evaluate ONNX model
8080

81-
We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.
81+
We provide `tools/deploy_test.py` to evaluate ONNX model with different backend.
8282

8383
#### Prerequisite
8484

@@ -88,12 +88,15 @@ We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.
8888
pip install onnx onnxruntime-gpu
8989
```
9090

91+
- Install TensorRT following [how-to-build-tensorrt-plugins-in-mmcv](https://mmcv.readthedocs.io/en/latest/tensorrt_plugin.html#how-to-build-tensorrt-plugins-in-mmcv)(optional)
92+
9193
#### Usage
9294

9395
```bash
94-
python tools/ort_test.py \
96+
python tools/deploy_test.py \
9597
${CONFIG_FILE} \
96-
${ONNX_FILE} \
98+
${MODEL_FILE} \
99+
${BACKEND} \
97100
--out ${OUTPUT_FILE} \
98101
--eval ${EVALUATION_METRICS} \
99102
--show \
@@ -106,7 +109,8 @@ python tools/ort_test.py \
106109
Description of all arguments
107110

108111
- `config`: The path of a model config file.
109-
- `model`: The path of a ONNX model file.
112+
- `model`: The path of a converted model file.
113+
- `backend`: Backend of the inference, options: `onnxruntime`, `tensorrt`.
110114
- `--out`: The path of output result file in pickle format.
111115
- `--format-only` : Format the output results without perform evaluation. It is useful when you want to format the result to a specific format and submit it to the test server. If not specified, it will be set to `False`. Note that this argument is **mutually exclusive** with `--eval`.
112116
- `--eval`: Evaluation metrics, which depends on the dataset, e.g., "mIoU" for generic datasets, and "cityscapes" for Cityscapes. Note that this argument is **mutually exclusive** with `--format-only`.
@@ -118,12 +122,17 @@ Description of all arguments
118122

119123
#### Results and Models
120124

121-
| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime |
122-
| :--------: | :--------------------------------------------: | :--------: | :----: | :-----: | :---------: |
123-
| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 |
124-
| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 |
125-
| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 |
126-
| deeplabv3+ | deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 |
125+
| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime | TensorRT-fp32 | TensorRT-fp16 |
126+
| :--------: | :---------------------------------------------: | :--------: | :----: | :-----: | :---------: | :-----------: | :-----------: |
127+
| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 | 72.2 | 72.2 |
128+
| PSPNet | pspnet_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 77.8 | 77.8 | 77.8 | 77.8 |
129+
| deeplabv3 | deeplabv3_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 79.0 | 79.0 | 79.0 | 79.0 |
130+
| deeplabv3+ | deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 79.6 | 79.5 | 79.5 | 79.5 |
131+
| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 | | |
132+
| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 | | |
133+
| deeplabv3+ | deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 | | |
134+
135+
**Note**: TensorRT is only available on configs with `whole mode`.
127136

128137
### Convert to TorchScript (experimental)
129138

tools/ort_test.py renamed to tools/deploy_test.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import os
33
import os.path as osp
44
import warnings
5+
from typing import Any, Iterable
56

67
import mmcv
78
import numpy as np
8-
import onnxruntime as ort
99
import torch
1010
from mmcv.parallel import MMDataParallel
1111
from mmcv.runner import get_dist_info
@@ -18,8 +18,10 @@
1818

1919
class ONNXRuntimeSegmentor(BaseSegmentor):
2020

21-
def __init__(self, onnx_file, cfg, device_id):
21+
def __init__(self, onnx_file: str, cfg: Any, device_id: int):
2222
super(ONNXRuntimeSegmentor, self).__init__()
23+
import onnxruntime as ort
24+
2325
# get the custom op path
2426
ort_custom_op_path = ''
2527
try:
@@ -60,7 +62,8 @@ def encode_decode(self, img, img_metas):
6062
def forward_train(self, imgs, img_metas, **kwargs):
6163
raise NotImplementedError('This method is not implemented.')
6264

63-
def simple_test(self, img, img_meta, **kwargs):
65+
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
66+
**kwargs) -> list:
6467
device_type = img.device.type
6568
self.io_binding.bind_input(
6669
name='input',
@@ -87,11 +90,63 @@ def aug_test(self, imgs, img_metas, **kwargs):
8790
raise NotImplementedError('This method is not implemented.')
8891

8992

90-
def parse_args():
93+
class TensorRTSegmentor(BaseSegmentor):
94+
95+
def __init__(self, trt_file: str, cfg: Any, device_id: int):
96+
super(TensorRTSegmentor, self).__init__()
97+
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
98+
try:
99+
load_tensorrt_plugin()
100+
except (ImportError, ModuleNotFoundError):
101+
warnings.warn('If input model has custom op from mmcv, \
102+
you may have to build mmcv with TensorRT from source.')
103+
model = TRTWraper(
104+
trt_file, input_names=['input'], output_names=['output'])
105+
106+
self.model = model
107+
self.device_id = device_id
108+
self.cfg = cfg
109+
self.test_mode = cfg.model.test_cfg.mode
110+
111+
def extract_feat(self, imgs):
112+
raise NotImplementedError('This method is not implemented.')
113+
114+
def encode_decode(self, img, img_metas):
115+
raise NotImplementedError('This method is not implemented.')
116+
117+
def forward_train(self, imgs, img_metas, **kwargs):
118+
raise NotImplementedError('This method is not implemented.')
119+
120+
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
121+
**kwargs) -> list:
122+
with torch.cuda.device(self.device_id), torch.no_grad():
123+
seg_pred = self.model({'input': img})['output']
124+
seg_pred = seg_pred.detach().cpu().numpy()
125+
# whole might support dynamic reshape
126+
ori_shape = img_meta[0]['ori_shape']
127+
if not (ori_shape[0] == seg_pred.shape[-2]
128+
and ori_shape[1] == seg_pred.shape[-1]):
129+
seg_pred = torch.from_numpy(seg_pred).float()
130+
seg_pred = torch.nn.functional.interpolate(
131+
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
132+
seg_pred = seg_pred.long().detach().cpu().numpy()
133+
seg_pred = seg_pred[0]
134+
seg_pred = list(seg_pred)
135+
return seg_pred
136+
137+
def aug_test(self, imgs, img_metas, **kwargs):
138+
raise NotImplementedError('This method is not implemented.')
139+
140+
141+
def parse_args() -> argparse.Namespace:
91142
parser = argparse.ArgumentParser(
92-
description='mmseg onnxruntime backend test (and eval) a model')
143+
description='mmseg backend test (and eval)')
93144
parser.add_argument('config', help='test config file path')
94145
parser.add_argument('model', help='Input model file')
146+
parser.add_argument(
147+
'--backend',
148+
help='Backend of the model.',
149+
choices=['onnxruntime', 'tensorrt'])
95150
parser.add_argument('--out', help='output result file in pickle format')
96151
parser.add_argument(
97152
'--format-only',
@@ -163,7 +218,12 @@ def main():
163218

164219
# load onnx config and meta
165220
cfg.model.train_cfg = None
166-
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
221+
222+
if args.backend == 'onnxruntime':
223+
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
224+
elif args.backend == 'tensorrt':
225+
model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)
226+
167227
model.CLASSES = dataset.CLASSES
168228
model.PALETTE = dataset.PALETTE
169229

0 commit comments

Comments
 (0)