Skip to content

Commit 789d1a1

Browse files
grimoireRunningLeonxvjiarui
authored
add dynamic export and visualize to pytorch2onnx (open-mmlab#463)
* add dynamic export and visualize to pytorch2onnx * update document * fix lint * fix dynamic error and add visualization * fix lint * update docstring * update doc * Update help info for --show Co-authored-by: Jerry Jiarui XU <[email protected]> * fix lint Co-authored-by: maningsheng <[email protected]> Co-authored-by: Jerry Jiarui XU <[email protected]>
1 parent e0e985f commit 789d1a1

File tree

5 files changed

+202
-26
lines changed

5 files changed

+202
-26
lines changed

docs/useful_tools.md

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,32 @@ The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pt
4646

4747
We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.
4848

49-
```shell
50-
python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
49+
```bash
50+
python tools/pytorch2onnx.py \
51+
${CONFIG_FILE} \
52+
--checkpoint ${CHECKPOINT_FILE} \
53+
--output-file ${ONNX_FILE} \
54+
--input-img ${INPUT_IMG} \
55+
--shape ${INPUT_SHAPE} \
56+
--show \
57+
--verify \
58+
--dynamic-export \
59+
--cfg-options \
60+
model.test_cfg.mode="whole"
5161
```
5262

63+
Description of arguments:
64+
65+
- `config` : The path of a model config file.
66+
- `--checkpoint` : The path of a model checkpoint file.
67+
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
68+
- `--input-img` : The path of an input image for conversion and visualize.
69+
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `256 256`.
70+
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
71+
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
72+
- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
73+
- `--cfg-options`:Update config options.
74+
5375
**Note**: This tool is still experimental. Some customized operators are not supported for now.
5476

5577
## Miscellaneous

mmseg/apis/inference.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def show_result_pyplot(model,
103103
result,
104104
palette=None,
105105
fig_size=(15, 10),
106-
opacity=0.5):
106+
opacity=0.5,
107+
title='',
108+
block=True):
107109
"""Visualize the segmentation results on the image.
108110
109111
Args:
@@ -117,11 +119,17 @@ def show_result_pyplot(model,
117119
opacity(float): Opacity of painted segmentation map.
118120
Default 0.5.
119121
Must be in (0, 1] range.
122+
title (str): The title of pyplot figure.
123+
Default is ''.
124+
block (bool): Whether to block the pyplot figure.
125+
Default is True.
120126
"""
121127
if hasattr(model, 'module'):
122128
model = model.module
123129
img = model.show_result(
124130
img, result, palette=palette, show=False, opacity=opacity)
125131
plt.figure(figsize=fig_size)
126132
plt.imshow(mmcv.bgr2rgb(img))
127-
plt.show()
133+
plt.title(title)
134+
plt.tight_layout()
135+
plt.show(block=block)

mmseg/models/segmentors/encoder_decoder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,14 @@ def whole_inference(self, img, img_meta, rescale):
216216

217217
seg_logit = self.encode_decode(img, img_meta)
218218
if rescale:
219+
# support dynamic shape for onnx
220+
if torch.onnx.is_in_onnx_export():
221+
size = img.shape[2:]
222+
else:
223+
size = img_meta[0]['ori_shape'][:2]
219224
seg_logit = resize(
220225
seg_logit,
221-
size=img_meta[0]['ori_shape'][:2],
226+
size=size,
222227
mode='bilinear',
223228
align_corners=self.align_corners,
224229
warning=False)

mmseg/ops/wrappers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import warnings
22

3-
import torch
43
import torch.nn as nn
54
import torch.nn.functional as F
65

@@ -24,8 +23,6 @@ def resize(input,
2423
'the output would more aligned if '
2524
f'input size {(input_h, input_w)} is `x+1` and '
2625
f'out size {(output_h, output_w)} is `nx+1`')
27-
if isinstance(size, torch.Size):
28-
size = tuple(int(x) for x in size)
2926
return F.interpolate(input, size, scale_factor, mode, align_corners)
3027

3128

tools/pytorch2onnx.py

Lines changed: 162 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
import torch
88
import torch._C
99
import torch.serialization
10+
from mmcv import DictAction
1011
from mmcv.onnx import register_extra_symbolics
1112
from mmcv.runner import load_checkpoint
1213
from torch import nn
1314

15+
from mmseg.apis import show_result_pyplot
16+
from mmseg.apis.inference import LoadImage
17+
from mmseg.datasets.pipelines import Compose
1418
from mmseg.models import build_segmentor
1519

1620
torch.manual_seed(3)
@@ -67,25 +71,61 @@ def _demo_mm_inputs(input_shape, num_classes):
6771
return mm_inputs
6872

6973

74+
def _prepare_input_img(img_path, test_pipeline, shape=None):
75+
# build the data pipeline
76+
if shape is not None:
77+
test_pipeline[1]['img_scale'] = shape
78+
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
79+
test_pipeline = [LoadImage()] + test_pipeline[1:]
80+
test_pipeline = Compose(test_pipeline)
81+
# prepare data
82+
data = dict(img=img_path)
83+
data = test_pipeline(data)
84+
imgs = data['img']
85+
img_metas = [i.data for i in data['img_metas']]
86+
87+
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
88+
89+
return mm_inputs
90+
91+
92+
def _update_input_img(img_list, img_meta_list):
93+
# update img and its meta list
94+
N, C, H, W = img_list[0].shape
95+
img_meta = img_meta_list[0][0]
96+
new_img_meta_list = [[{
97+
'img_shape': (H, W, C),
98+
'ori_shape': (H, W, C),
99+
'pad_shape': (H, W, C),
100+
'filename': img_meta['filename'],
101+
'scale_factor': 1.,
102+
'flip': False,
103+
} for _ in range(N)]]
104+
105+
return img_list, new_img_meta_list
106+
107+
70108
def pytorch2onnx(model,
71-
input_shape,
109+
mm_inputs,
72110
opset_version=11,
73111
show=False,
74112
output_file='tmp.onnx',
75-
verify=False):
113+
verify=False,
114+
dynamic_export=False):
76115
"""Export Pytorch model to ONNX model and verify the outputs are same
77116
between Pytorch and ONNX.
78117
79118
Args:
80119
model (nn.Module): Pytorch model we want to export.
81-
input_shape (tuple): Use this input shape to construct
82-
the corresponding dummy input and execute the model.
120+
mm_inputs (dict): Contain the input tensors and img_metas information.
83121
opset_version (int): The onnx op version. Default: 11.
84122
show (bool): Whether print the computation graph. Default: False.
85123
output_file (string): The path to where we store the output ONNX model.
86124
Default: `tmp.onnx`.
87125
verify (bool): Whether compare the outputs between Pytorch and ONNX.
88126
Default: False.
127+
dynamic_export (bool): Whether to export ONNX with dynamic axis.
128+
Default: False.
89129
"""
90130
model.cpu().eval()
91131

@@ -94,28 +134,45 @@ def pytorch2onnx(model,
94134
else:
95135
num_classes = model.decode_head.num_classes
96136

97-
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
98-
99137
imgs = mm_inputs.pop('imgs')
100138
img_metas = mm_inputs.pop('img_metas')
139+
ori_shape = img_metas[0]['ori_shape']
101140

102141
img_list = [img[None, :] for img in imgs]
103142
img_meta_list = [[img_meta] for img_meta in img_metas]
143+
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
104144

105145
# replace original forward function
106146
origin_forward = model.forward
107147
model.forward = partial(
108148
model.forward, img_metas=img_meta_list, return_loss=False)
149+
dynamic_axes = None
150+
if dynamic_export:
151+
dynamic_axes = {
152+
'input': {
153+
0: 'batch',
154+
2: 'height',
155+
3: 'width'
156+
},
157+
'output': {
158+
1: 'batch',
159+
2: 'height',
160+
3: 'width'
161+
}
162+
}
109163

110164
register_extra_symbolics(opset_version)
111165
with torch.no_grad():
112166
torch.onnx.export(
113167
model, (img_list, ),
114168
output_file,
169+
input_names=['input'],
170+
output_names=['output'],
115171
export_params=True,
116-
keep_initializers_as_inputs=True,
172+
keep_initializers_as_inputs=False,
117173
verbose=show,
118-
opset_version=opset_version)
174+
opset_version=opset_version,
175+
dynamic_axes=dynamic_axes)
119176
print(f'Successfully exported ONNX model: {output_file}')
120177
model.forward = origin_forward
121178

@@ -125,9 +182,28 @@ def pytorch2onnx(model,
125182
onnx_model = onnx.load(output_file)
126183
onnx.checker.check_model(onnx_model)
127184

185+
if dynamic_export:
186+
# scale image for dynamic shape test
187+
img_list = [
188+
nn.functional.interpolate(_, scale_factor=1.5)
189+
for _ in img_list
190+
]
191+
# concate flip image for batch test
192+
flip_img_list = [_.flip(-1) for _ in img_list]
193+
img_list = [
194+
torch.cat((ori_img, flip_img), 0)
195+
for ori_img, flip_img in zip(img_list, flip_img_list)
196+
]
197+
198+
# update img_meta
199+
img_list, img_meta_list = _update_input_img(
200+
img_list, img_meta_list)
201+
128202
# check the numerical value
129203
# get pytorch output
130-
pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]
204+
with torch.no_grad():
205+
pytorch_result = model(img_list, img_meta_list, return_loss=False)
206+
pytorch_result = np.stack(pytorch_result, 0)
131207

132208
# get onnx output
133209
input_all = [node.name for node in onnx_model.graph.input]
@@ -138,18 +214,55 @@ def pytorch2onnx(model,
138214
assert (len(net_feed_input) == 1)
139215
sess = rt.InferenceSession(output_file)
140216
onnx_result = sess.run(
141-
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
142-
if not np.allclose(pytorch_result, onnx_result):
143-
raise ValueError(
144-
'The outputs are different between Pytorch and ONNX')
217+
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
218+
# show segmentation results
219+
if show:
220+
import cv2
221+
import os.path as osp
222+
img = img_meta_list[0][0]['filename']
223+
if not osp.exists(img):
224+
img = imgs[0][:3, ...].permute(1, 2, 0) * 255
225+
img = img.detach().numpy().astype(np.uint8)
226+
# resize onnx_result to ori_shape
227+
onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
228+
(ori_shape[1], ori_shape[0]))
229+
show_result_pyplot(
230+
model,
231+
img, (onnx_result_, ),
232+
palette=model.PALETTE,
233+
block=False,
234+
title='ONNXRuntime',
235+
opacity=0.5)
236+
237+
# resize pytorch_result to ori_shape
238+
pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
239+
(ori_shape[1], ori_shape[0]))
240+
show_result_pyplot(
241+
model,
242+
img, (pytorch_result_, ),
243+
title='PyTorch',
244+
palette=model.PALETTE,
245+
opacity=0.5)
246+
# compare results
247+
np.testing.assert_allclose(
248+
pytorch_result.astype(np.float32) / num_classes,
249+
onnx_result.astype(np.float32) / num_classes,
250+
rtol=1e-5,
251+
atol=1e-5,
252+
err_msg='The outputs are different between Pytorch and ONNX')
145253
print('The outputs are same between Pytorch and ONNX')
146254

147255

148256
def parse_args():
149257
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
150258
parser.add_argument('config', help='test config file path')
151259
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
152-
parser.add_argument('--show', action='store_true', help='show onnx graph')
260+
parser.add_argument(
261+
'--input-img', type=str, help='Images for input', default=None)
262+
parser.add_argument(
263+
'--show',
264+
action='store_true',
265+
help='show onnx graph and segmentation results')
153266
parser.add_argument(
154267
'--verify', action='store_true', help='verify the onnx model')
155268
parser.add_argument('--output-file', type=str, default='tmp.onnx')
@@ -160,6 +273,20 @@ def parse_args():
160273
nargs='+',
161274
default=[256, 256],
162275
help='input image size')
276+
parser.add_argument(
277+
'--cfg-options',
278+
nargs='+',
279+
action=DictAction,
280+
help='Override some settings in the used config, the key-value pair '
281+
'in xxx=yyy format will be merged into config file. If the value to '
282+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
283+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
284+
'Note that the quotation marks are necessary and that no white space '
285+
'is allowed.')
286+
parser.add_argument(
287+
'--dynamic-export',
288+
action='store_true',
289+
help='Whether to export onnx with dynamic axis.')
163290
args = parser.parse_args()
164291
return args
165292

@@ -178,6 +305,8 @@ def parse_args():
178305
raise ValueError('invalid input shape')
179306

180307
cfg = mmcv.Config.fromfile(args.config)
308+
if args.cfg_options is not None:
309+
cfg.merge_from_dict(args.cfg_options)
181310
cfg.model.pretrained = None
182311

183312
# build the model and load checkpoint
@@ -188,13 +317,28 @@ def parse_args():
188317
segmentor = _convert_batchnorm(segmentor)
189318

190319
if args.checkpoint:
191-
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
320+
checkpoint = load_checkpoint(
321+
segmentor, args.checkpoint, map_location='cpu')
322+
segmentor.CLASSES = checkpoint['meta']['CLASSES']
323+
segmentor.PALETTE = checkpoint['meta']['PALETTE']
324+
325+
# read input or create dummpy input
326+
if args.input_img is not None:
327+
mm_inputs = _prepare_input_img(args.input_img, cfg.data.test.pipeline,
328+
(input_shape[3], input_shape[2]))
329+
else:
330+
if isinstance(segmentor.decode_head, nn.ModuleList):
331+
num_classes = segmentor.decode_head[-1].num_classes
332+
else:
333+
num_classes = segmentor.decode_head.num_classes
334+
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
192335

193-
# conver model to onnx file
336+
# convert model to onnx file
194337
pytorch2onnx(
195338
segmentor,
196-
input_shape,
339+
mm_inputs,
197340
opset_version=args.opset_version,
198341
show=args.show,
199342
output_file=args.output_file,
200-
verify=args.verify)
343+
verify=args.verify,
344+
dynamic_export=args.dynamic_export)

0 commit comments

Comments
 (0)