77import torch
88import torch ._C
99import torch .serialization
10+ from mmcv import DictAction
1011from mmcv .onnx import register_extra_symbolics
1112from mmcv .runner import load_checkpoint
1213from torch import nn
1314
15+ from mmseg .apis import show_result_pyplot
16+ from mmseg .apis .inference import LoadImage
17+ from mmseg .datasets .pipelines import Compose
1418from mmseg .models import build_segmentor
1519
1620torch .manual_seed (3 )
@@ -67,25 +71,61 @@ def _demo_mm_inputs(input_shape, num_classes):
6771 return mm_inputs
6872
6973
74+ def _prepare_input_img (img_path , test_pipeline , shape = None ):
75+ # build the data pipeline
76+ if shape is not None :
77+ test_pipeline [1 ]['img_scale' ] = shape
78+ test_pipeline [1 ]['transforms' ][0 ]['keep_ratio' ] = False
79+ test_pipeline = [LoadImage ()] + test_pipeline [1 :]
80+ test_pipeline = Compose (test_pipeline )
81+ # prepare data
82+ data = dict (img = img_path )
83+ data = test_pipeline (data )
84+ imgs = data ['img' ]
85+ img_metas = [i .data for i in data ['img_metas' ]]
86+
87+ mm_inputs = {'imgs' : imgs , 'img_metas' : img_metas }
88+
89+ return mm_inputs
90+
91+
92+ def _update_input_img (img_list , img_meta_list ):
93+ # update img and its meta list
94+ N , C , H , W = img_list [0 ].shape
95+ img_meta = img_meta_list [0 ][0 ]
96+ new_img_meta_list = [[{
97+ 'img_shape' : (H , W , C ),
98+ 'ori_shape' : (H , W , C ),
99+ 'pad_shape' : (H , W , C ),
100+ 'filename' : img_meta ['filename' ],
101+ 'scale_factor' : 1. ,
102+ 'flip' : False ,
103+ } for _ in range (N )]]
104+
105+ return img_list , new_img_meta_list
106+
107+
70108def pytorch2onnx (model ,
71- input_shape ,
109+ mm_inputs ,
72110 opset_version = 11 ,
73111 show = False ,
74112 output_file = 'tmp.onnx' ,
75- verify = False ):
113+ verify = False ,
114+ dynamic_export = False ):
76115 """Export Pytorch model to ONNX model and verify the outputs are same
77116 between Pytorch and ONNX.
78117
79118 Args:
80119 model (nn.Module): Pytorch model we want to export.
81- input_shape (tuple): Use this input shape to construct
82- the corresponding dummy input and execute the model.
120+ mm_inputs (dict): Contain the input tensors and img_metas information.
83121 opset_version (int): The onnx op version. Default: 11.
84122 show (bool): Whether print the computation graph. Default: False.
85123 output_file (string): The path to where we store the output ONNX model.
86124 Default: `tmp.onnx`.
87125 verify (bool): Whether compare the outputs between Pytorch and ONNX.
88126 Default: False.
127+ dynamic_export (bool): Whether to export ONNX with dynamic axis.
128+ Default: False.
89129 """
90130 model .cpu ().eval ()
91131
@@ -94,28 +134,45 @@ def pytorch2onnx(model,
94134 else :
95135 num_classes = model .decode_head .num_classes
96136
97- mm_inputs = _demo_mm_inputs (input_shape , num_classes )
98-
99137 imgs = mm_inputs .pop ('imgs' )
100138 img_metas = mm_inputs .pop ('img_metas' )
139+ ori_shape = img_metas [0 ]['ori_shape' ]
101140
102141 img_list = [img [None , :] for img in imgs ]
103142 img_meta_list = [[img_meta ] for img_meta in img_metas ]
143+ img_list , img_meta_list = _update_input_img (img_list , img_meta_list )
104144
105145 # replace original forward function
106146 origin_forward = model .forward
107147 model .forward = partial (
108148 model .forward , img_metas = img_meta_list , return_loss = False )
149+ dynamic_axes = None
150+ if dynamic_export :
151+ dynamic_axes = {
152+ 'input' : {
153+ 0 : 'batch' ,
154+ 2 : 'height' ,
155+ 3 : 'width'
156+ },
157+ 'output' : {
158+ 1 : 'batch' ,
159+ 2 : 'height' ,
160+ 3 : 'width'
161+ }
162+ }
109163
110164 register_extra_symbolics (opset_version )
111165 with torch .no_grad ():
112166 torch .onnx .export (
113167 model , (img_list , ),
114168 output_file ,
169+ input_names = ['input' ],
170+ output_names = ['output' ],
115171 export_params = True ,
116- keep_initializers_as_inputs = True ,
172+ keep_initializers_as_inputs = False ,
117173 verbose = show ,
118- opset_version = opset_version )
174+ opset_version = opset_version ,
175+ dynamic_axes = dynamic_axes )
119176 print (f'Successfully exported ONNX model: { output_file } ' )
120177 model .forward = origin_forward
121178
@@ -125,9 +182,28 @@ def pytorch2onnx(model,
125182 onnx_model = onnx .load (output_file )
126183 onnx .checker .check_model (onnx_model )
127184
185+ if dynamic_export :
186+ # scale image for dynamic shape test
187+ img_list = [
188+ nn .functional .interpolate (_ , scale_factor = 1.5 )
189+ for _ in img_list
190+ ]
191+ # concate flip image for batch test
192+ flip_img_list = [_ .flip (- 1 ) for _ in img_list ]
193+ img_list = [
194+ torch .cat ((ori_img , flip_img ), 0 )
195+ for ori_img , flip_img in zip (img_list , flip_img_list )
196+ ]
197+
198+ # update img_meta
199+ img_list , img_meta_list = _update_input_img (
200+ img_list , img_meta_list )
201+
128202 # check the numerical value
129203 # get pytorch output
130- pytorch_result = model (img_list , img_meta_list , return_loss = False )[0 ]
204+ with torch .no_grad ():
205+ pytorch_result = model (img_list , img_meta_list , return_loss = False )
206+ pytorch_result = np .stack (pytorch_result , 0 )
131207
132208 # get onnx output
133209 input_all = [node .name for node in onnx_model .graph .input ]
@@ -138,18 +214,55 @@ def pytorch2onnx(model,
138214 assert (len (net_feed_input ) == 1 )
139215 sess = rt .InferenceSession (output_file )
140216 onnx_result = sess .run (
141- None , {net_feed_input [0 ]: img_list [0 ].detach ().numpy ()})[0 ]
142- if not np .allclose (pytorch_result , onnx_result ):
143- raise ValueError (
144- 'The outputs are different between Pytorch and ONNX' )
217+ None , {net_feed_input [0 ]: img_list [0 ].detach ().numpy ()})[0 ][0 ]
218+ # show segmentation results
219+ if show :
220+ import cv2
221+ import os .path as osp
222+ img = img_meta_list [0 ][0 ]['filename' ]
223+ if not osp .exists (img ):
224+ img = imgs [0 ][:3 , ...].permute (1 , 2 , 0 ) * 255
225+ img = img .detach ().numpy ().astype (np .uint8 )
226+ # resize onnx_result to ori_shape
227+ onnx_result_ = cv2 .resize (onnx_result [0 ].astype (np .uint8 ),
228+ (ori_shape [1 ], ori_shape [0 ]))
229+ show_result_pyplot (
230+ model ,
231+ img , (onnx_result_ , ),
232+ palette = model .PALETTE ,
233+ block = False ,
234+ title = 'ONNXRuntime' ,
235+ opacity = 0.5 )
236+
237+ # resize pytorch_result to ori_shape
238+ pytorch_result_ = cv2 .resize (pytorch_result [0 ].astype (np .uint8 ),
239+ (ori_shape [1 ], ori_shape [0 ]))
240+ show_result_pyplot (
241+ model ,
242+ img , (pytorch_result_ , ),
243+ title = 'PyTorch' ,
244+ palette = model .PALETTE ,
245+ opacity = 0.5 )
246+ # compare results
247+ np .testing .assert_allclose (
248+ pytorch_result .astype (np .float32 ) / num_classes ,
249+ onnx_result .astype (np .float32 ) / num_classes ,
250+ rtol = 1e-5 ,
251+ atol = 1e-5 ,
252+ err_msg = 'The outputs are different between Pytorch and ONNX' )
145253 print ('The outputs are same between Pytorch and ONNX' )
146254
147255
148256def parse_args ():
149257 parser = argparse .ArgumentParser (description = 'Convert MMSeg to ONNX' )
150258 parser .add_argument ('config' , help = 'test config file path' )
151259 parser .add_argument ('--checkpoint' , help = 'checkpoint file' , default = None )
152- parser .add_argument ('--show' , action = 'store_true' , help = 'show onnx graph' )
260+ parser .add_argument (
261+ '--input-img' , type = str , help = 'Images for input' , default = None )
262+ parser .add_argument (
263+ '--show' ,
264+ action = 'store_true' ,
265+ help = 'show onnx graph and segmentation results' )
153266 parser .add_argument (
154267 '--verify' , action = 'store_true' , help = 'verify the onnx model' )
155268 parser .add_argument ('--output-file' , type = str , default = 'tmp.onnx' )
@@ -160,6 +273,20 @@ def parse_args():
160273 nargs = '+' ,
161274 default = [256 , 256 ],
162275 help = 'input image size' )
276+ parser .add_argument (
277+ '--cfg-options' ,
278+ nargs = '+' ,
279+ action = DictAction ,
280+ help = 'Override some settings in the used config, the key-value pair '
281+ 'in xxx=yyy format will be merged into config file. If the value to '
282+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
283+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
284+ 'Note that the quotation marks are necessary and that no white space '
285+ 'is allowed.' )
286+ parser .add_argument (
287+ '--dynamic-export' ,
288+ action = 'store_true' ,
289+ help = 'Whether to export onnx with dynamic axis.' )
163290 args = parser .parse_args ()
164291 return args
165292
@@ -178,6 +305,8 @@ def parse_args():
178305 raise ValueError ('invalid input shape' )
179306
180307 cfg = mmcv .Config .fromfile (args .config )
308+ if args .cfg_options is not None :
309+ cfg .merge_from_dict (args .cfg_options )
181310 cfg .model .pretrained = None
182311
183312 # build the model and load checkpoint
@@ -188,13 +317,28 @@ def parse_args():
188317 segmentor = _convert_batchnorm (segmentor )
189318
190319 if args .checkpoint :
191- load_checkpoint (segmentor , args .checkpoint , map_location = 'cpu' )
320+ checkpoint = load_checkpoint (
321+ segmentor , args .checkpoint , map_location = 'cpu' )
322+ segmentor .CLASSES = checkpoint ['meta' ]['CLASSES' ]
323+ segmentor .PALETTE = checkpoint ['meta' ]['PALETTE' ]
324+
325+ # read input or create dummpy input
326+ if args .input_img is not None :
327+ mm_inputs = _prepare_input_img (args .input_img , cfg .data .test .pipeline ,
328+ (input_shape [3 ], input_shape [2 ]))
329+ else :
330+ if isinstance (segmentor .decode_head , nn .ModuleList ):
331+ num_classes = segmentor .decode_head [- 1 ].num_classes
332+ else :
333+ num_classes = segmentor .decode_head .num_classes
334+ mm_inputs = _demo_mm_inputs (input_shape , num_classes )
192335
193- # conver model to onnx file
336+ # convert model to onnx file
194337 pytorch2onnx (
195338 segmentor ,
196- input_shape ,
339+ mm_inputs ,
197340 opset_version = args .opset_version ,
198341 show = args .show ,
199342 output_file = args .output_file ,
200- verify = args .verify )
343+ verify = args .verify ,
344+ dynamic_export = args .dynamic_export )
0 commit comments