Skip to content

Commit 4e05f1c

Browse files
committed
1. Add tf slim resnet_v1_101 model into examples.
2. Implement pytorch model saver.
1 parent 52c0386 commit 4e05f1c

File tree

3 files changed

+32
-16
lines changed

3 files changed

+32
-16
lines changed

mmdnn/conversion/examples/imagenet_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class TestKit(object):
2121
'tensorflow' : {
2222
'vgg19' : [(21, 11.285443), (144, 10.240093), (23, 9.1792336), (22, 8.1113129), (128, 8.1065922)],
2323
'resnet' : [(22, 11.756789), (147, 8.5718527), (24, 6.1751032), (88, 4.3121386), (141, 4.1778097)],
24+
'resnet_v1_101' : [(21, 14.384739), (23, 14.262486), (144, 14.068737), (94, 12.17205), (134, 12.064575)],
2425
'inception_v3' : [(22, 9.4921198), (24, 4.0932288), (25, 3.700398), (23, 3.3715961), (147, 3.3620636)],
2526
'mobilenet' : [(22, 16.223597), (24, 14.54775), (147, 13.173758), (145, 11.36431), (728, 11.083847)]
2627
},
@@ -56,6 +57,7 @@ class TestKit(object):
5657
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, False),
5758
'inception_v3' : lambda path : TestKit.Standard(path, 299),
5859
'resnet' : lambda path : TestKit.Standard(path, 299),
60+
'resnet_v1_101' : lambda path : TestKit.ZeroCenter(path, 224, False),
5961
'resnet152' : lambda path : TestKit.Standard(path, 299),
6062
'mobilenet' : lambda path : TestKit.Standard(path, 224)
6163
},
@@ -92,11 +94,11 @@ def __init__(self):
9294
parser.add_argument('-n', type=_text_type, default='kit_imagenet',
9395
help='Network structure file name.')
9496

95-
parser.add_argument('-s', type = _text_type, help = 'Source Framework Type',
96-
choices = ["caffe", "tensorflow", "keras", "cntk", "mxnet"])
97+
parser.add_argument('-s', type=_text_type, help='Source Framework Type',
98+
choices=self.truth.keys())
9799

98-
parser.add_argument('-w',
99-
type = _text_type, help = 'Network weights file name', required = True)
100+
parser.add_argument('-w', type=_text_type, required=True,
101+
help='Network weights file name')
100102

101103
parser.add_argument('--image', '-i',
102104
type = _text_type,

mmdnn/conversion/examples/pytorch/imagenet_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,16 @@ def inference(self, image_path):
5151
self.test_truth()
5252

5353

54+
def dump(self, path=None):
55+
if path is None: path = self.args.dump
56+
torch.save(self.model, path)
57+
print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
58+
path, self.args.n, self.args.w))
59+
60+
5461
if __name__=='__main__':
5562
tester = TestTorch()
56-
tester.inference(tester.args.image)
63+
if tester.args.dump:
64+
tester.dump()
65+
else:
66+
tester.inference(tester.args.image)

mmdnn/conversion/examples/tensorflow/extract_model.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,23 @@
88
import tensorflow as tf
99
from tensorflow.contrib.slim.python.slim.nets import vgg
1010
from tensorflow.contrib.slim.python.slim.nets import inception
11+
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
1112
from tensorflow.contrib.slim.python.slim.nets import resnet_v2
1213
from mmdnn.conversion.examples.imagenet_test import TestKit
1314

1415
slim = tf.contrib.slim
1516

1617
input_layer_map = {
17-
'vgg16' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 224, 224, 3]),
18-
'vgg19' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 224, 224, 3]),
19-
'inception_v1' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 224, 224, 3]),
20-
'inception_v2' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
21-
'inception_v3' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
22-
'resnet50' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
23-
'resnet101' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
24-
'resnet152' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
25-
'resnet200' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
18+
'vgg16' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
19+
'vgg19' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
20+
'inception_v1' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
21+
'inception_v2' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
22+
'inception_v3' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
23+
'resnet50' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
24+
'resnet_v1_101' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
25+
'resnet101' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
26+
'resnet152' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
27+
'resnet200' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
2628
}
2729

2830
arg_scopes_map = {
@@ -32,6 +34,7 @@
3234
'inception_v2' : inception.inception_v3_arg_scope,
3335
'inception_v3' : inception.inception_v3_arg_scope,
3436
'resnet50' : resnet_v2.resnet_arg_scope,
37+
'resnet_v1_101' : resnet_v2.resnet_arg_scope,
3538
'resnet101' : resnet_v2.resnet_arg_scope,
3639
'resnet152' : resnet_v2.resnet_arg_scope,
3740
'resnet200' : resnet_v2.resnet_arg_scope,
@@ -44,6 +47,7 @@
4447
'inception_v1' : lambda : inception.inception_v1,
4548
'inception_v2' : lambda : inception.inception_v2,
4649
'inception_v3' : lambda : inception.inception_v3,
50+
'resnet_v1_101' : lambda : resnet_v1.resnet_v1_101,
4751
'resnet50' : lambda : resnet_v2.resnet_v2_50,
4852
'resnet101' : lambda : resnet_v2.resnet_v2_101,
4953
'resnet152' : lambda : resnet_v2.resnet_v2_152,
@@ -65,11 +69,11 @@ def _main():
6569

6670
args = parser.parse_args()
6771

68-
num_classes = 1000 if args.network in ('vgg16', 'vgg19') else 1001
72+
num_classes = 1000 if args.network in ('vgg16', 'vgg19', 'resnet_v1_101') else 1001
6973

7074
with slim.arg_scope(arg_scopes_map[args.network]()):
7175
data_input = input_layer_map[args.network]()
72-
logits, endpoints = networks_map[args.network]()(data_input, num_classes = num_classes, is_training = False)
76+
logits, endpoints = networks_map[args.network]()(data_input, num_classes=num_classes, is_training=False)
7377
labels = tf.squeeze(logits)
7478

7579
init = tf.global_variables_initializer()

0 commit comments

Comments
 (0)