Skip to content

Commit 8f37e55

Browse files
authored
add tool pytorch2torchscript (open-mmlab#469)
* add tool pytorch2torchscript * fix the assert message for pytorch version.
1 parent 67eee62 commit 8f37e55

File tree

2 files changed

+196
-0
lines changed

2 files changed

+196
-0
lines changed

docs/useful_tools.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,18 @@ Description of arguments:
7474

7575
**Note**: This tool is still experimental. Some customized operators are not supported for now.
7676

77+
### Convert to TorchScript (experimental)
78+
79+
We also provide a script to convert model to [TorchScript](https://pytorch.org/docs/stable/jit.html) format. You can use the pytorch C++ API [LibTorch](https://pytorch.org/docs/stable/cpp_index.html) inference the trained model. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and TorchScript model.
80+
81+
```shell
82+
python tools/pytorch2torchscript.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
83+
```
84+
85+
**Note**: It's only support PyTorch>=1.8.0 for now.
86+
87+
**Note**: This tool is still experimental. Some customized operators are not supported for now.
88+
7789
## Miscellaneous
7890

7991
### Print the entire config

tools/pytorch2torchscript.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import argparse
2+
3+
import mmcv
4+
import numpy as np
5+
import torch
6+
import torch._C
7+
import torch.serialization
8+
from mmcv.runner import load_checkpoint
9+
from torch import nn
10+
11+
from mmseg.models import build_segmentor
12+
13+
torch.manual_seed(3)
14+
15+
16+
def digit_version(version_str):
17+
digit_version = []
18+
for x in version_str.split('.'):
19+
if x.isdigit():
20+
digit_version.append(int(x))
21+
elif x.find('rc') != -1:
22+
patch_version = x.split('rc')
23+
digit_version.append(int(patch_version[0]) - 1)
24+
digit_version.append(int(patch_version[1]))
25+
return digit_version
26+
27+
28+
def check_torch_version():
29+
torch_minimum_version = '1.8.0'
30+
torch_version = digit_version(torch.__version__)
31+
32+
assert (torch_version >= digit_version(torch_minimum_version)), \
33+
f'Torch=={torch.__version__} is not support for converting to ' \
34+
f'torchscript. Please install pytorch>={torch_minimum_version}.'
35+
36+
37+
def _convert_batchnorm(module):
38+
module_output = module
39+
if isinstance(module, torch.nn.SyncBatchNorm):
40+
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
41+
module.momentum, module.affine,
42+
module.track_running_stats)
43+
if module.affine:
44+
module_output.weight.data = module.weight.data.clone().detach()
45+
module_output.bias.data = module.bias.data.clone().detach()
46+
# keep requires_grad unchanged
47+
module_output.weight.requires_grad = module.weight.requires_grad
48+
module_output.bias.requires_grad = module.bias.requires_grad
49+
module_output.running_mean = module.running_mean
50+
module_output.running_var = module.running_var
51+
module_output.num_batches_tracked = module.num_batches_tracked
52+
for name, child in module.named_children():
53+
module_output.add_module(name, _convert_batchnorm(child))
54+
del module
55+
return module_output
56+
57+
58+
def _demo_mm_inputs(input_shape, num_classes):
59+
"""Create a superset of inputs needed to run test or train batches.
60+
61+
Args:
62+
input_shape (tuple):
63+
input batch dimensions
64+
num_classes (int):
65+
number of semantic classes
66+
"""
67+
(N, C, H, W) = input_shape
68+
rng = np.random.RandomState(0)
69+
imgs = rng.rand(*input_shape)
70+
segs = rng.randint(
71+
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
72+
img_metas = [{
73+
'img_shape': (H, W, C),
74+
'ori_shape': (H, W, C),
75+
'pad_shape': (H, W, C),
76+
'filename': '<demo>.png',
77+
'scale_factor': 1.0,
78+
'flip': False,
79+
} for _ in range(N)]
80+
mm_inputs = {
81+
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
82+
'img_metas': img_metas,
83+
'gt_semantic_seg': torch.LongTensor(segs)
84+
}
85+
return mm_inputs
86+
87+
88+
def pytorch2libtorch(model,
89+
input_shape,
90+
show=False,
91+
output_file='tmp.pt',
92+
verify=False):
93+
"""Export Pytorch model to TorchScript model and verify the outputs are
94+
same between Pytorch and TorchScript.
95+
96+
Args:
97+
model (nn.Module): Pytorch model we want to export.
98+
input_shape (tuple): Use this input shape to construct
99+
the corresponding dummy input and execute the model.
100+
show (bool): Whether print the computation graph. Default: False.
101+
output_file (string): The path to where we store the
102+
output TorchScript model. Default: `tmp.pt`.
103+
verify (bool): Whether compare the outputs between
104+
Pytorch and TorchScript. Default: False.
105+
"""
106+
if isinstance(model.decode_head, nn.ModuleList):
107+
num_classes = model.decode_head[-1].num_classes
108+
else:
109+
num_classes = model.decode_head.num_classes
110+
111+
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
112+
113+
imgs = mm_inputs.pop('imgs')
114+
115+
# replace the orginal forword with forward_dummy
116+
model.forward = model.forward_dummy
117+
model.eval()
118+
traced_model = torch.jit.trace(
119+
model,
120+
example_inputs=imgs,
121+
check_trace=verify,
122+
)
123+
124+
if show:
125+
print(traced_model.graph)
126+
127+
traced_model.save(output_file)
128+
print('Successfully exported TorchScript model: {}'.format(output_file))
129+
130+
131+
def parse_args():
132+
parser = argparse.ArgumentParser(
133+
description='Convert MMSeg to TorchScript')
134+
parser.add_argument('config', help='test config file path')
135+
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
136+
parser.add_argument(
137+
'--show', action='store_true', help='show TorchScript graph')
138+
parser.add_argument(
139+
'--verify', action='store_true', help='verify the TorchScript model')
140+
parser.add_argument('--output-file', type=str, default='tmp.pt')
141+
parser.add_argument(
142+
'--shape',
143+
type=int,
144+
nargs='+',
145+
default=[512, 512],
146+
help='input image size (height, width)')
147+
args = parser.parse_args()
148+
return args
149+
150+
151+
if __name__ == '__main__':
152+
args = parse_args()
153+
check_torch_version()
154+
155+
if len(args.shape) == 1:
156+
input_shape = (1, 3, args.shape[0], args.shape[0])
157+
elif len(args.shape) == 2:
158+
input_shape = (
159+
1,
160+
3,
161+
) + tuple(args.shape)
162+
else:
163+
raise ValueError('invalid input shape')
164+
165+
cfg = mmcv.Config.fromfile(args.config)
166+
cfg.model.pretrained = None
167+
168+
# build the model and load checkpoint
169+
cfg.model.train_cfg = None
170+
segmentor = build_segmentor(
171+
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
172+
# convert SyncBN to BN
173+
segmentor = _convert_batchnorm(segmentor)
174+
175+
if args.checkpoint:
176+
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
177+
178+
# convert the PyTorch model to LibTorch model
179+
pytorch2libtorch(
180+
segmentor,
181+
input_shape,
182+
show=args.show,
183+
output_file=args.output_file,
184+
verify=args.verify)

0 commit comments

Comments
 (0)