Skip to content

Commit 9daec88

Browse files
committed
Implement pytorch model extractor.
1 parent d86e32c commit 9daec88

File tree

3 files changed

+163
-7
lines changed

3 files changed

+163
-7
lines changed

mmdnn/conversion/examples/imagenet_test.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,9 @@
66
from __future__ import absolute_import
77
import argparse
88
import numpy as np
9-
import sys
10-
import os
119
from six import text_type as _text_type
12-
13-
# work for tf 1.4 in windows & linux
1410
from tensorflow.contrib.keras.api.keras.preprocessing import image
1511

16-
# work for tf 1.3 & 1.4 in linux
17-
# from tensorflow.contrib.keras.python.keras.preprocessing import image
18-
1912

2013
class TestKit(object):
2114

@@ -44,6 +37,10 @@ class TestKit(object):
4437
'resnet' : [(21, 0.84012794), (144, 0.097428247), (23, 0.039757393), (146, 0.010432643), (99, 0.0023797606)],
4538
'squeezenet' : [(21, 0.36026478), (128, 0.084114805), (835, 0.07940048), (144, 0.057378717), (749, 0.053491514)],
4639
'inception_bn' : [(21, 0.84332663), (144, 0.041747514), (677, 0.021810319), (973, 0.02054958), (115, 0.008529461)]
40+
},
41+
'pytorch' :{
42+
'resnet152' : [(21, 13.080057), (141, 12.32998), (94, 9.8761454), (146, 9.3761511), (143, 8.9194641)],
43+
'vgg19' : [(821, 8.4734678), (562, 8.3472366), (835, 8.2712851), (749, 7.792901), (807, 6.6604013)],
4744
}
4845
}
4946

@@ -76,6 +73,12 @@ class TestKit(object):
7673
'resnet' : lambda path : TestKit.Identity(path, 224, True),
7774
'squeezenet' : lambda path : TestKit.ZeroCenter(path, 224, False),
7875
'inception_bn' : lambda path : TestKit.Identity(path, 224, False)
76+
},
77+
78+
'pytorch' : {
79+
'vgg19' : lambda path : TestKit.Normalize(path),
80+
'resnet152' : lambda path : TestKit.Normalize(path),
81+
'inception_v3' : lambda path : TestKit.Normalize(path),
7982
}
8083
}
8184

@@ -122,6 +125,17 @@ def ZeroCenter(path, size, BGRTranspose=False):
122125
return x
123126

124127

128+
@staticmethod
129+
def Normalize(path, size=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
130+
img = image.load_img(path, target_size=(size, size))
131+
x = image.img_to_array(img)
132+
x /= 255.0
133+
for i in range(0, 3):
134+
x[..., i] -= mean[i]
135+
x[..., i] /= std[i]
136+
return x
137+
138+
125139
@staticmethod
126140
def Standard(path, size):
127141
img = image.load_img(path, target_size = (size, size))
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#----------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
#----------------------------------------------------------------------------------------------
5+
6+
import argparse
7+
import os
8+
from six import text_type as _text_type
9+
from mmdnn.conversion.examples.imagenet_test import TestKit
10+
import torch
11+
import torchvision.models as models
12+
13+
14+
NETWORKS_MAP = {
15+
'inception_v3' : lambda : models.inception_v3(pretrained=True),
16+
'vgg16' : lambda : models.vgg16(pretrained=True),
17+
'vgg19' : lambda : models.vgg19(pretrained=True),
18+
'resnet152' : lambda : models.resnet152(pretrained=True),
19+
'densenet' : lambda : models.densenet201(pretrained=True),
20+
'squeezenet' : lambda : models.squeezenet1_1(pretrained=True)
21+
}
22+
23+
24+
def _main():
25+
parser = argparse.ArgumentParser()
26+
27+
parser.add_argument('-n', '--network',
28+
type=_text_type, help='Model Type', required=True,
29+
choices=NETWORKS_MAP.keys())
30+
31+
parser.add_argument('-i', '--image', type=_text_type, help='Test Image Path')
32+
33+
args = parser.parse_args()
34+
35+
file_name = "imagenet_{}.pt".format(args.network)
36+
if not os.path.exists(file_name):
37+
model = NETWORKS_MAP.get(args.network)
38+
model = model()
39+
torch.save(model, file_name)
40+
print("PyTorch pretrained model is saved as [{}].".format(file_name))
41+
else:
42+
print("File [{}] existed!".format(file_name))
43+
model = torch.load(file_name)
44+
45+
if args.image:
46+
import numpy as np
47+
func = TestKit.preprocess_func['pytorch'][args.network]
48+
img = func(args.image)
49+
img = np.transpose(img, (2, 0, 1))
50+
img = np.expand_dims(img, 0).copy()
51+
data = torch.from_numpy(img)
52+
data = torch.autograd.Variable(data, requires_grad=False)
53+
54+
model.eval()
55+
predict = model(data).data.numpy()
56+
predict = np.squeeze(predict)
57+
top_indices = predict.argsort()[-5:][::-1]
58+
result = [(i, predict[i]) for i in top_indices]
59+
print(result)
60+
61+
# layer_name = 'block2_pool'
62+
# intermediate_layer_model = keras.Model(inputs=model.input,
63+
# outputs=model.get_layer(layer_name).output)
64+
# intermediate_output = intermediate_layer_model.predict(img)
65+
# print (intermediate_output)
66+
# print (intermediate_output.shape)
67+
# print ("%.30f" % np.sum(intermediate_output))
68+
69+
70+
if __name__ == '__main__':
71+
_main()

mmdnn/conversion/pytorch/inference.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
# Licensed under the MIT license. See LICENSE.md file in the project root
4+
# for full license information.
5+
# ==============================================================================
6+
7+
import argparse
8+
import numpy as np
9+
import sys
10+
import os
11+
import caffe
12+
13+
def inference(args):
14+
net = caffe.Net(args.network, args.weights, caffe.TEST)
15+
16+
from tensorflow.contrib.keras.python.keras.preprocessing import image
17+
image_path = 'mmdnn/conversion/examples/data/seagull.jpg'
18+
19+
if args.preprocess == 'vgg':
20+
img = image.load_img(image_path, target_size=(224, 224))
21+
x = image.img_to_array(img)
22+
# Zero-center by mean pixel
23+
x = x[..., ::-1]
24+
x[..., 0] -= 103.939
25+
x[..., 1] -= 116.779
26+
x[..., 2] -= 123.68
27+
28+
x = np.transpose(x, [2, 0, 1])
29+
30+
elif args.preprocess == 'resnet' or args.preprocess == 'inception':
31+
img = image.load_img(image_path, target_size=(224, 224))
32+
x = image.img_to_array(img)
33+
x /= 255.0
34+
x -= 0.5
35+
x *= 2.0
36+
x = np.transpose(x, [2, 0, 1])
37+
38+
else:
39+
assert False
40+
41+
x = np.expand_dims(x, 0)
42+
net.blobs['data'].data[...] = x
43+
predict = np.squeeze(net.forward()['prob'][0])
44+
45+
test = 'pool1/norm1'
46+
immediate_data = net.blobs[test].data[0]
47+
print (immediate_data)
48+
print (immediate_data.shape)
49+
print ("%.30f" % np.sum(np.array(immediate_data)))
50+
51+
top_indices = predict.argsort()[-5:][::-1]
52+
result = [(i, predict[i]) for i in top_indices]
53+
print (result)
54+
55+
56+
if __name__=='__main__':
57+
parser = argparse.ArgumentParser()
58+
59+
parser.add_argument('-p', '--preprocess',
60+
type=str, choices = ["vgg", "resnet", "inception"], help='Model Preprocess Type', required=False, default='vgg')
61+
62+
parser.add_argument('-n', '--network',
63+
type=str, required=True)
64+
65+
parser.add_argument('-w', '--weights',
66+
type=str, required=True)
67+
68+
69+
args = parser.parse_args()
70+
71+
inference(args)

0 commit comments

Comments
 (0)