@@ -1,10 +1,23 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import tempfile
+from pathlib import Path
 
-from mmcv.cnn import get_model_complexity_info
-from mmengine import Config
+import torch
+from mmengine import Config, DictAction
+from mmengine.logging import MMLogger
+from mmengine.model import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
 
-from mmseg.models import build_segmentor
+from mmseg.models import BaseSegmentor
+from mmseg.registry import MODELS
+from mmseg.structures import SegDataSample
+
+try:
+    from mmengine.analysis import get_model_complexity_info
+    from mmengine.analysis.print_helper import _format_size
+except ImportError:
+    raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.')
 
 
 def parse_args():
@@ -17,43 +30,94 @@ def parse_args():
         nargs='+',
         default=[2048, 1024],
         help='input image size')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
     args = parser.parse_args()
     return args
 
 
-def main():
+def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
+    config_name = Path(args.config)
 
-    args = parse_args()
+    if not config_name.exists():
+        logger.error(f'Config file {config_name} does not exist')
+
+    cfg: Config = Config.fromfile(config_name)
+    cfg.work_dir = tempfile.TemporaryDirectory().name
+    cfg.log_level = 'WARN'
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    init_default_scope(cfg.get('scope', 'mmseg'))
 
     if len(args.shape) == 1:
         input_shape = (3, args.shape[0], args.shape[0])
     elif len(args.shape) == 2:
         input_shape = (3, ) + tuple(args.shape)
     else:
         raise ValueError('invalid input shape')
+    result = {}
 
-    cfg = Config.fromfile(args.config)
-    cfg.model.pretrained = None
-    model = build_segmentor(
-        cfg.model,
-        train_cfg=cfg.get('train_cfg'),
-        test_cfg=cfg.get('test_cfg')).cuda()
+    model: BaseSegmentor = MODELS.build(cfg.model)
+    if hasattr(model, 'auxiliary_head'):
+        model.auxiliary_head = None
+    if torch.cuda.is_available():
+        model.cuda()
+    model = revert_sync_batchnorm(model)
+    result['ori_shape'] = input_shape[-2:]
+    result['pad_shape'] = input_shape[-2:]
+    data_batch = {
+        'inputs': [torch.rand(input_shape)],
+        'data_samples': [SegDataSample(metainfo=result)]
+    }
+    data = model.data_preprocessor(data_batch)
     model.eval()
+    if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']:
+        # TODO: Support MaskFormer and Mask2Former
+        raise NotImplementedError('MaskFormer and Mask2Former are not '
+                                  'supported yet.')
+    outputs = get_model_complexity_info(
+        model,
+        input_shape,
+        inputs=data['inputs'],
+        show_table=False,
+        show_arch=False)
+    result['flops'] = _format_size(outputs['flops'])
+    result['params'] = _format_size(outputs['params'])
+    result['compute_type'] = 'direct: randomly generate a picture'
+    return result
 
-    if hasattr(model, 'forward_dummy'):
-        model.forward = model.forward_dummy
-    else:
-        raise NotImplementedError(
-            'FLOPs counter is currently not currently supported with {}'.
-            format(model.__class__.__name__))
 
-    flops, params = get_model_complexity_info(model, input_shape)
+def main():
+
+    args = parse_args()
+    logger = MMLogger.get_instance(name='MMLogger')
+
+    result = inference(args, logger)
     split_line = '=' * 30
-    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
-        split_line, input_shape, flops, params))
+    ori_shape = result['ori_shape']
+    pad_shape = result['pad_shape']
+    flops = result['flops']
+    params = result['params']
+    compute_type = result['compute_type']
+
+    if pad_shape != ori_shape:
+        print(f'{split_line}\nUse size divisor set input shape '
+              f'from {ori_shape} to {pad_shape}')
+    print(f'{split_line}\nCompute type: {compute_type}\n'
+          f'Input shape: {pad_shape}\nFlops: {flops}\n'
+          f'Params: {params}\n{split_line}')
     print('!!!Please be cautious if you use the results in papers. '
-          'You may need to check if all ops are supported and verify that the '
-          'flops computation is correct.')
+          'You may need to check if all ops are supported and verify '
+          'that the flops computation is correct.')
 
 
 if __name__ == '__main__':
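
For reference, here is a minimal, self-contained sketch of the `mmengine.analysis` calls this revision relies on. It is not part of the commit: it assumes mmengine >= 0.6.0 and uses a toy convolutional module purely for illustration, whereas the script builds the real segmentor from the config via `MODELS.build(cfg.model)`.

```python
# Illustrative sketch only: mirrors the complexity-counting calls used above.
# Assumes mmengine >= 0.6.0; the toy module stands in for the real segmentor.
import torch
from mmengine.analysis import get_model_complexity_info
from mmengine.analysis.print_helper import _format_size

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 8, kernel_size=3, padding=1),
)
model.eval()

# Same call pattern as the script: the returned dict carries raw 'flops' and
# 'params' counts, which _format_size turns into human-readable strings.
outputs = get_model_complexity_info(
    model, input_shape=(3, 512, 512), show_table=False, show_arch=False)
print('Flops:', _format_size(outputs['flops']))
print('Params:', _format_size(outputs['params']))
```

When run directly, the script itself takes the config path as a positional argument plus the optional `--shape` and `--cfg-options` flags defined in `parse_args()` above.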