Skip to content

Commit 45fae72

Browse files
authored
[Feature] Support calculating FLOPs of segmentors (open-mmlab#2706)
## Motivation fix compute flops problems ## Modification Please briefly describe what modification is made in this PR.
1 parent 6468d31 commit 45fae72

File tree

1 file changed

+86
-22
lines changed

1 file changed

+86
-22
lines changed

tools/analysis_tools/get_flops.py

Lines changed: 86 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import argparse
3+
import tempfile
4+
from pathlib import Path
35

4-
from mmcv.cnn import get_model_complexity_info
5-
from mmengine import Config
6+
import torch
7+
from mmengine import Config, DictAction
8+
from mmengine.logging import MMLogger
9+
from mmengine.model import revert_sync_batchnorm
10+
from mmengine.registry import init_default_scope
611

7-
from mmseg.models import build_segmentor
12+
from mmseg.models import BaseSegmentor
13+
from mmseg.registry import MODELS
14+
from mmseg.structures import SegDataSample
15+
16+
try:
17+
from mmengine.analysis import get_model_complexity_info
18+
from mmengine.analysis.print_helper import _format_size
19+
except ImportError:
20+
raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.')
821

922

1023
def parse_args():
@@ -17,43 +30,94 @@ def parse_args():
1730
nargs='+',
1831
default=[2048, 1024],
1932
help='input image size')
33+
parser.add_argument(
34+
'--cfg-options',
35+
nargs='+',
36+
action=DictAction,
37+
help='override some settings in the used config, the key-value pair '
38+
'in xxx=yyy format will be merged into config file. If the value to '
39+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
40+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
41+
'Note that the quotation marks are necessary and that no white space '
42+
'is allowed.')
2043
args = parser.parse_args()
2144
return args
2245

2346

24-
def main():
47+
def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
48+
config_name = Path(args.config)
2549

26-
args = parse_args()
50+
if not config_name.exists():
51+
logger.error(f'Config file {config_name} does not exist')
52+
53+
cfg: Config = Config.fromfile(config_name)
54+
cfg.work_dir = tempfile.TemporaryDirectory().name
55+
cfg.log_level = 'WARN'
56+
if args.cfg_options is not None:
57+
cfg.merge_from_dict(args.cfg_options)
58+
59+
init_default_scope(cfg.get('scope', 'mmseg'))
2760

2861
if len(args.shape) == 1:
2962
input_shape = (3, args.shape[0], args.shape[0])
3063
elif len(args.shape) == 2:
3164
input_shape = (3, ) + tuple(args.shape)
3265
else:
3366
raise ValueError('invalid input shape')
67+
result = {}
3468

35-
cfg = Config.fromfile(args.config)
36-
cfg.model.pretrained = None
37-
model = build_segmentor(
38-
cfg.model,
39-
train_cfg=cfg.get('train_cfg'),
40-
test_cfg=cfg.get('test_cfg')).cuda()
69+
model: BaseSegmentor = MODELS.build(cfg.model)
70+
if hasattr(model, 'auxiliary_head'):
71+
model.auxiliary_head = None
72+
if torch.cuda.is_available():
73+
model.cuda()
74+
model = revert_sync_batchnorm(model)
75+
result['ori_shape'] = input_shape[-2:]
76+
result['pad_shape'] = input_shape[-2:]
77+
data_batch = {
78+
'inputs': [torch.rand(input_shape)],
79+
'data_samples': [SegDataSample(metainfo=result)]
80+
}
81+
data = model.data_preprocessor(data_batch)
4182
model.eval()
83+
if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']:
84+
# TODO: Support MaskFormer and Mask2Former
85+
raise NotImplementedError('MaskFormer and Mask2Former are not '
86+
'supported yet.')
87+
outputs = get_model_complexity_info(
88+
model,
89+
input_shape,
90+
inputs=data['inputs'],
91+
show_table=False,
92+
show_arch=False)
93+
result['flops'] = _format_size(outputs['flops'])
94+
result['params'] = _format_size(outputs['params'])
95+
result['compute_type'] = 'direct: randomly generate a picture'
96+
return result
4297

43-
if hasattr(model, 'forward_dummy'):
44-
model.forward = model.forward_dummy
45-
else:
46-
raise NotImplementedError(
47-
'FLOPs counter is currently not currently supported with {}'.
48-
format(model.__class__.__name__))
4998

50-
flops, params = get_model_complexity_info(model, input_shape)
99+
def main():
100+
101+
args = parse_args()
102+
logger = MMLogger.get_instance(name='MMLogger')
103+
104+
result = inference(args, logger)
51105
split_line = '=' * 30
52-
print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
53-
split_line, input_shape, flops, params))
106+
ori_shape = result['ori_shape']
107+
pad_shape = result['pad_shape']
108+
flops = result['flops']
109+
params = result['params']
110+
compute_type = result['compute_type']
111+
112+
if pad_shape != ori_shape:
113+
print(f'{split_line}\nUse size divisor set input shape '
114+
f'from {ori_shape} to {pad_shape}')
115+
print(f'{split_line}\nCompute type: {compute_type}\n'
116+
f'Input shape: {pad_shape}\nFlops: {flops}\n'
117+
f'Params: {params}\n{split_line}')
54118
print('!!!Please be cautious if you use the results in papers. '
55-
'You may need to check if all ops are supported and verify that the '
56-
'flops computation is correct.')
119+
'You may need to check if all ops are supported and verify '
120+
'that the flops computation is correct.')
57121

58122

59123
if __name__ == '__main__':

0 commit comments

Comments
 (0)