7
7
import torch
8
8
import torch ._C
9
9
import torch .serialization
10
+ from mmcv import DictAction
10
11
from mmcv .onnx import register_extra_symbolics
11
12
from mmcv .runner import load_checkpoint
12
13
from torch import nn
13
14
15
+ from mmseg .apis import show_result_pyplot
16
+ from mmseg .apis .inference import LoadImage
17
+ from mmseg .datasets .pipelines import Compose
14
18
from mmseg .models import build_segmentor
15
19
16
20
torch .manual_seed (3 )
@@ -67,25 +71,61 @@ def _demo_mm_inputs(input_shape, num_classes):
67
71
return mm_inputs
68
72
69
73
74
def _prepare_input_img(img_path, test_pipeline, shape=None):
    """Load a real image through the dataset test pipeline for ONNX export.

    Args:
        img_path (str): Path of the input image.
        test_pipeline (list[dict]): Test pipeline config (e.g.
            ``cfg.data.test.pipeline``). The list is deep-copied before
            modification, so the caller's config is never mutated.
        shape (tuple | None): If given, overrides the resize step's
            ``img_scale`` and disables ``keep_ratio`` so the network input
            gets exactly this scale. Default: None.

    Returns:
        dict: ``{'imgs': ..., 'img_metas': ...}`` as produced by the
        pipeline, ready to be consumed by the exporter.
    """
    import copy

    # Work on a copy: the original code mutated the caller's pipeline config
    # in place, which leaks the override back into ``cfg``.
    pipeline_cfg = copy.deepcopy(test_pipeline)
    if shape is not None:
        # NOTE(review): assumes pipeline_cfg[1] is the test-time aug step
        # (e.g. MultiScaleFlipAug) whose first transform is Resize — confirm
        # against the dataset config.
        pipeline_cfg[1]['img_scale'] = shape
        pipeline_cfg[1]['transforms'][0]['keep_ratio'] = False
    # Replace the file-loading step so a raw path can be fed directly.
    pipeline = Compose([LoadImage()] + pipeline_cfg[1:])

    # Run the pipeline on the image path.
    data = pipeline(dict(img=img_path))
    imgs = data['img']
    img_metas = [i.data for i in data['img_metas']]

    return {'imgs': imgs, 'img_metas': img_metas}
90
+
91
+
92
+ def _update_input_img (img_list , img_meta_list ):
93
+ # update img and its meta list
94
+ N , C , H , W = img_list [0 ].shape
95
+ img_meta = img_meta_list [0 ][0 ]
96
+ new_img_meta_list = [[{
97
+ 'img_shape' : (H , W , C ),
98
+ 'ori_shape' : (H , W , C ),
99
+ 'pad_shape' : (H , W , C ),
100
+ 'filename' : img_meta ['filename' ],
101
+ 'scale_factor' : 1. ,
102
+ 'flip' : False ,
103
+ } for _ in range (N )]]
104
+
105
+ return img_list , new_img_meta_list
106
+
107
+
70
108
def pytorch2onnx (model ,
71
- input_shape ,
109
+ mm_inputs ,
72
110
opset_version = 11 ,
73
111
show = False ,
74
112
output_file = 'tmp.onnx' ,
75
- verify = False ):
113
+ verify = False ,
114
+ dynamic_export = False ):
76
115
"""Export Pytorch model to ONNX model and verify the outputs are same
77
116
between Pytorch and ONNX.
78
117
79
118
Args:
80
119
model (nn.Module): Pytorch model we want to export.
81
- input_shape (tuple): Use this input shape to construct
82
- the corresponding dummy input and execute the model.
120
+ mm_inputs (dict): Contain the input tensors and img_metas information.
83
121
opset_version (int): The onnx op version. Default: 11.
84
122
show (bool): Whether print the computation graph. Default: False.
85
123
output_file (string): The path to where we store the output ONNX model.
86
124
Default: `tmp.onnx`.
87
125
verify (bool): Whether compare the outputs between Pytorch and ONNX.
88
126
Default: False.
127
+ dynamic_export (bool): Whether to export ONNX with dynamic axis.
128
+ Default: False.
89
129
"""
90
130
model .cpu ().eval ()
91
131
@@ -94,28 +134,45 @@ def pytorch2onnx(model,
94
134
else :
95
135
num_classes = model .decode_head .num_classes
96
136
97
- mm_inputs = _demo_mm_inputs (input_shape , num_classes )
98
-
99
137
imgs = mm_inputs .pop ('imgs' )
100
138
img_metas = mm_inputs .pop ('img_metas' )
139
+ ori_shape = img_metas [0 ]['ori_shape' ]
101
140
102
141
img_list = [img [None , :] for img in imgs ]
103
142
img_meta_list = [[img_meta ] for img_meta in img_metas ]
143
+ img_list , img_meta_list = _update_input_img (img_list , img_meta_list )
104
144
105
145
# replace original forward function
106
146
origin_forward = model .forward
107
147
model .forward = partial (
108
148
model .forward , img_metas = img_meta_list , return_loss = False )
149
+ dynamic_axes = None
150
+ if dynamic_export :
151
+ dynamic_axes = {
152
+ 'input' : {
153
+ 0 : 'batch' ,
154
+ 2 : 'height' ,
155
+ 3 : 'width'
156
+ },
157
+ 'output' : {
158
+ 1 : 'batch' ,
159
+ 2 : 'height' ,
160
+ 3 : 'width'
161
+ }
162
+ }
109
163
110
164
register_extra_symbolics (opset_version )
111
165
with torch .no_grad ():
112
166
torch .onnx .export (
113
167
model , (img_list , ),
114
168
output_file ,
169
+ input_names = ['input' ],
170
+ output_names = ['output' ],
115
171
export_params = True ,
116
- keep_initializers_as_inputs = True ,
172
+ keep_initializers_as_inputs = False ,
117
173
verbose = show ,
118
- opset_version = opset_version )
174
+ opset_version = opset_version ,
175
+ dynamic_axes = dynamic_axes )
119
176
print (f'Successfully exported ONNX model: { output_file } ' )
120
177
model .forward = origin_forward
121
178
@@ -125,9 +182,28 @@ def pytorch2onnx(model,
125
182
onnx_model = onnx .load (output_file )
126
183
onnx .checker .check_model (onnx_model )
127
184
185
+ if dynamic_export :
186
+ # scale image for dynamic shape test
187
+ img_list = [
188
+ nn .functional .interpolate (_ , scale_factor = 1.5 )
189
+ for _ in img_list
190
+ ]
191
+ # concate flip image for batch test
192
+ flip_img_list = [_ .flip (- 1 ) for _ in img_list ]
193
+ img_list = [
194
+ torch .cat ((ori_img , flip_img ), 0 )
195
+ for ori_img , flip_img in zip (img_list , flip_img_list )
196
+ ]
197
+
198
+ # update img_meta
199
+ img_list , img_meta_list = _update_input_img (
200
+ img_list , img_meta_list )
201
+
128
202
# check the numerical value
129
203
# get pytorch output
130
- pytorch_result = model (img_list , img_meta_list , return_loss = False )[0 ]
204
+ with torch .no_grad ():
205
+ pytorch_result = model (img_list , img_meta_list , return_loss = False )
206
+ pytorch_result = np .stack (pytorch_result , 0 )
131
207
132
208
# get onnx output
133
209
input_all = [node .name for node in onnx_model .graph .input ]
@@ -138,18 +214,55 @@ def pytorch2onnx(model,
138
214
assert (len (net_feed_input ) == 1 )
139
215
sess = rt .InferenceSession (output_file )
140
216
onnx_result = sess .run (
141
- None , {net_feed_input [0 ]: img_list [0 ].detach ().numpy ()})[0 ]
142
- if not np .allclose (pytorch_result , onnx_result ):
143
- raise ValueError (
144
- 'The outputs are different between Pytorch and ONNX' )
217
+ None , {net_feed_input [0 ]: img_list [0 ].detach ().numpy ()})[0 ][0 ]
218
+ # show segmentation results
219
+ if show :
220
+ import cv2
221
+ import os .path as osp
222
+ img = img_meta_list [0 ][0 ]['filename' ]
223
+ if not osp .exists (img ):
224
+ img = imgs [0 ][:3 , ...].permute (1 , 2 , 0 ) * 255
225
+ img = img .detach ().numpy ().astype (np .uint8 )
226
+ # resize onnx_result to ori_shape
227
+ onnx_result_ = cv2 .resize (onnx_result [0 ].astype (np .uint8 ),
228
+ (ori_shape [1 ], ori_shape [0 ]))
229
+ show_result_pyplot (
230
+ model ,
231
+ img , (onnx_result_ , ),
232
+ palette = model .PALETTE ,
233
+ block = False ,
234
+ title = 'ONNXRuntime' ,
235
+ opacity = 0.5 )
236
+
237
+ # resize pytorch_result to ori_shape
238
+ pytorch_result_ = cv2 .resize (pytorch_result [0 ].astype (np .uint8 ),
239
+ (ori_shape [1 ], ori_shape [0 ]))
240
+ show_result_pyplot (
241
+ model ,
242
+ img , (pytorch_result_ , ),
243
+ title = 'PyTorch' ,
244
+ palette = model .PALETTE ,
245
+ opacity = 0.5 )
246
+ # compare results
247
+ np .testing .assert_allclose (
248
+ pytorch_result .astype (np .float32 ) / num_classes ,
249
+ onnx_result .astype (np .float32 ) / num_classes ,
250
+ rtol = 1e-5 ,
251
+ atol = 1e-5 ,
252
+ err_msg = 'The outputs are different between Pytorch and ONNX' )
145
253
print ('The outputs are same between Pytorch and ONNX' )
146
254
147
255
148
256
def parse_args ():
149
257
parser = argparse .ArgumentParser (description = 'Convert MMSeg to ONNX' )
150
258
parser .add_argument ('config' , help = 'test config file path' )
151
259
parser .add_argument ('--checkpoint' , help = 'checkpoint file' , default = None )
152
- parser .add_argument ('--show' , action = 'store_true' , help = 'show onnx graph' )
260
+ parser .add_argument (
261
+ '--input-img' , type = str , help = 'Images for input' , default = None )
262
+ parser .add_argument (
263
+ '--show' ,
264
+ action = 'store_true' ,
265
+ help = 'show onnx graph and segmentation results' )
153
266
parser .add_argument (
154
267
'--verify' , action = 'store_true' , help = 'verify the onnx model' )
155
268
parser .add_argument ('--output-file' , type = str , default = 'tmp.onnx' )
@@ -160,6 +273,20 @@ def parse_args():
160
273
nargs = '+' ,
161
274
default = [256 , 256 ],
162
275
help = 'input image size' )
276
+ parser .add_argument (
277
+ '--cfg-options' ,
278
+ nargs = '+' ,
279
+ action = DictAction ,
280
+ help = 'Override some settings in the used config, the key-value pair '
281
+ 'in xxx=yyy format will be merged into config file. If the value to '
282
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
283
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
284
+ 'Note that the quotation marks are necessary and that no white space '
285
+ 'is allowed.' )
286
+ parser .add_argument (
287
+ '--dynamic-export' ,
288
+ action = 'store_true' ,
289
+ help = 'Whether to export onnx with dynamic axis.' )
163
290
args = parser .parse_args ()
164
291
return args
165
292
@@ -178,6 +305,8 @@ def parse_args():
178
305
raise ValueError ('invalid input shape' )
179
306
180
307
cfg = mmcv .Config .fromfile (args .config )
308
+ if args .cfg_options is not None :
309
+ cfg .merge_from_dict (args .cfg_options )
181
310
cfg .model .pretrained = None
182
311
183
312
# build the model and load checkpoint
@@ -188,13 +317,28 @@ def parse_args():
188
317
segmentor = _convert_batchnorm (segmentor )
189
318
190
319
if args .checkpoint :
191
- load_checkpoint (segmentor , args .checkpoint , map_location = 'cpu' )
320
+ checkpoint = load_checkpoint (
321
+ segmentor , args .checkpoint , map_location = 'cpu' )
322
+ segmentor .CLASSES = checkpoint ['meta' ]['CLASSES' ]
323
+ segmentor .PALETTE = checkpoint ['meta' ]['PALETTE' ]
324
+
325
+ # read input or create dummpy input
326
+ if args .input_img is not None :
327
+ mm_inputs = _prepare_input_img (args .input_img , cfg .data .test .pipeline ,
328
+ (input_shape [3 ], input_shape [2 ]))
329
+ else :
330
+ if isinstance (segmentor .decode_head , nn .ModuleList ):
331
+ num_classes = segmentor .decode_head [- 1 ].num_classes
332
+ else :
333
+ num_classes = segmentor .decode_head .num_classes
334
+ mm_inputs = _demo_mm_inputs (input_shape , num_classes )
192
335
193
- # conver model to onnx file
336
+ # convert model to onnx file
194
337
pytorch2onnx (
195
338
segmentor ,
196
- input_shape ,
339
+ mm_inputs ,
197
340
opset_version = args .opset_version ,
198
341
show = args .show ,
199
342
output_file = args .output_file ,
200
- verify = args .verify )
343
+ verify = args .verify ,
344
+ dynamic_export = args .dynamic_export )
0 commit comments