Commit 1a2c2eb

tiqint committed
caffe resnet152 to tf converts ok.
1 parent 5ea31c4 commit 1a2c2eb

2 files changed: 30 additions, 39 deletions


mmdnn/conversion/caffe/mapper.py

Lines changed: 22 additions & 33 deletions
@@ -33,20 +33,21 @@ def _convert_output_shape(cls, kwargs, node):
         kwargs['_output_shapes'] = [shape]
 
     @classmethod
-    def get_kernel_params(cls, node):
+    def get_kernel_params(cls, node, input_shape):
         kwargs = {}
-        if node.kernel_parameters.p_h > 0 or node.kernel_parameters.p_w > 0:
-            padding = [0, node.kernel_parameters.p_h, node.kernel_parameters.p_w, 0, 0, node.kernel_parameters.p_h, node.kernel_parameters.p_w, 0]
-        elif node.kernel_parameters.s_h > 1 or node.kernel_parameters.s_w > 1:
-            padding = [0, (node.kernel_parameters.s_h - 1) // 2, (node.kernel_parameters.s_w - 1) // 2, 0, 0, node.kernel_parameters.s_h // 2, node.kernel_parameters.s_w // 2, 0]
-        else:
-            padding = None
-
-        kwargs['auto_pad'] = 'VALID'
+
+        o_h_caffe = node.output_shape.height
+        o_h_tf = (input_shape.height + node.kernel_parameters.p_h * 2 - node.kernel_parameters.k_h + 1) // node.kernel_parameters.s_h
+
+        o_w_caffe = node.output_shape.width
+        o_w_tf = (input_shape.width + node.kernel_parameters.p_w * 2 - node.kernel_parameters.k_w + 1) // node.kernel_parameters.s_w
+
+        kwargs['pads'] = [0, node.kernel_parameters.p_h, node.kernel_parameters.p_w, 0] + \
+            [0, node.kernel_parameters.p_h + o_h_caffe - o_h_tf, node.kernel_parameters.p_w + o_w_caffe - o_w_tf, 0]
         kwargs['strides'] = [1, node.kernel_parameters.s_h, node.kernel_parameters.s_w, 1]
         cls._convert_output_shape(kwargs, node)
 
-        return kwargs, {'pads' : padding, 'mode' : 'constant', 'constant_values' : 0.0}
+        return kwargs
 
 
     @classmethod
@@ -72,30 +73,22 @@ def map_input(cls, node):
 
     @classmethod
     def map_convolution(cls, node):
-        kwargs, padding = cls.get_kernel_params(node)
         parent, _ = node.get_only_parent()
+        kwargs = cls.get_kernel_params(node, parent.output_shape)
         kwargs['kernel_shape'] = [node.kernel_parameters.k_h, node.kernel_parameters.k_w, parent.output_shape.channels, node.parameters.num_output]
-        kwargs['use_bias'] = node.parameters.bias_term
-        group = node.parameters.group
-        if group != 1:
-            kwargs['group'] = group
-
-        if padding['pads'] != None:
-            return [Node.create('Pad', **padding), Node.create('Conv', **kwargs)]
-        else:
-            kwargs['pads'] = [0] * 8
-            return Node.create('Conv', **kwargs)
+        kwargs['use_bias'] = node.parameters.bias_term
+        kwargs['group'] = node.parameters.group
+        return Node.create('Conv', **kwargs)
 
 
     @classmethod
     def map_deconvolution(cls, node):
         raise NotImplementedError()
-        kwargs = cls.get_kernel_params(node)
         parent, _ = node.get_only_parent()
-        kwargs['kernel_shape'] = [node.kernel_parameters.k_h, node.kernel_parameters.k_w, parent.output_shape.channels, node.parameters.num_output]
-        group = node.parameters.group
-        if group != 1:
-            kwargs['group'] = group
+        kwargs = cls.get_kernel_params(node, parent.output_shape)
+
+        kwargs['kernel_shape'] = [node.kernel_parameters.k_h, node.kernel_parameters.k_w, parent.output_shape.channels, node.parameters.num_output]
+        kwargs['group'] = node.parameters.group
         return Node.create('deconv', **kwargs)
 
     @classmethod
@@ -115,7 +108,8 @@ def map_relu(cls, node):
 
     @classmethod
     def map_pooling(cls, node):
-        kwargs, padding = cls.get_kernel_params(node)
+        parent, _ = node.get_only_parent()
+        kwargs = cls.get_kernel_params(node, parent.output_shape)
         if node.parameters.pool == 0:
             kwargs['pooling_type'] = 'MAX'
         elif node.parameters.pool == 1:
@@ -125,12 +119,7 @@ def map_pooling(cls, node):
             raise ConversionError('Unsupported pooling type.')
         kwargs['kernel_shape'] = [1, node.kernel_parameters.k_h, node.kernel_parameters.k_w, 1]
         cls._convert_output_shape(kwargs, node)
-
-        if padding['pads'] != None:
-            return [Node.create('Pad', **padding), Node.create('Pool', **kwargs)]
-        else:
-            kwargs['pads'] = [0] * 8
-            return Node.create('Pool', **kwargs)
+        return Node.create('Pool', **kwargs)
 
 
     @classmethod
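Note on the change above: instead of emitting a separate Pad node and relying on TF-style VALID padding, `get_kernel_params` now computes explicit, possibly asymmetric `pads` so the converted Conv/Pool reproduces Caffe's output shape (Caffe rounds pooling output sizes up, while VALID-style inference rounds down). A minimal standalone sketch of the same arithmetic, using a hypothetical helper name and plain integers in place of `node.kernel_parameters` / `node.output_shape`:

```python
def asymmetric_pads(in_h, in_w, k_h, k_w, s_h, s_w, p_h, p_w, caffe_out_h, caffe_out_w):
    # Output size the TF side would keep from the symmetrically padded input
    # (same integer arithmetic as the diff above).
    tf_out_h = (in_h + 2 * p_h - k_h + 1) // s_h
    tf_out_w = (in_w + 2 * p_w - k_w + 1) // s_w
    # Top/left keep Caffe's pad; bottom/right absorb the extra rows/columns
    # that Caffe's (ceil-rounded) shape inference expects beyond tf_out_*.
    return [0, p_h, p_w, 0,
            0, p_h + caffe_out_h - tf_out_h, p_w + caffe_out_w - tf_out_w, 0]

# Example: ResNet's pool1 (3x3 kernel, stride 2, pad 0) on a 112x112 input.
# Caffe reports 56x56, the formula above gives 55x55, so one extra row and
# column of padding is added on the bottom/right.
print(asymmetric_pads(112, 112, 3, 3, 2, 2, 0, 0, 56, 56))
# [0, 0, 0, 0, 0, 1, 1, 0]
```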

mmdnn/conversion/examples/tensorflow/imagenet_test.py

Lines changed: 8 additions & 6 deletions
@@ -15,7 +15,7 @@ class TestTF(TestKit):
     def __init__(self):
         super(TestTF, self).__init__()
         self.input, self.model = self.MainModel.KitModel(self.args.w)
-        # self.input, self.model, self.testop = KitModel(os.path.abspath('.') + '/kit_imagenet.npy')
+        # self.input, self.model, self.testop = self.MainModel.KitModel(self.args.w)
 
 
     def preprocess(self, image_path):
@@ -24,7 +24,9 @@ def preprocess(self, image_path):
 
 
     def print_result(self):
-        with tf.Session() as sess:
+        with tf.Session() as sess:
+            writer = tf.summary.FileWriter('./graphs', sess.graph)
+            writer.close()
             init = tf.global_variables_initializer()
             sess.run(init)
             predict = sess.run(self.model, feed_dict = {self.input : self.data})
@@ -33,20 +35,20 @@ def print_result(self):
 
 
     def print_intermediate_result(self, layer_name, if_transpose = False):
-        testop = tf.get_default_graph().get_operation_by_name(layer_name)
-        # testop = self.testop
+        # testop = tf.get_default_graph().get_operation_by_name(layer_name)
+        testop = self.testop
         with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            intermediate_output = sess.run(testop, feed_dict = {self.input : self.data})
 
-        super(TestTF, self).predict(intermediate_output, if_transpose)
+        super(TestTF, self).print_intermediate_result(intermediate_output, if_transpose)
 
 
     def inference(self, image_path):
         self.preprocess(image_path)
 
-        # self.print_intermediate_result('conv1_7x7_s2_1', False)
+        # self.print_intermediate_result('conv1_7x7_s2_1', True)
 
         self.print_result()
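The functional addition here is the graph dump in print_result: `tf.summary.FileWriter('./graphs', sess.graph)` writes the converted graph definition so it can be inspected with `tensorboard --logdir ./graphs`. A minimal TF 1.x sketch of that pattern in isolation (the placeholder graph below is purely illustrative, not the converted ResNet):

```python
import tensorflow as tf

# Toy graph standing in for the converted model (illustrative only).
x = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name='input')
y = tf.reduce_mean(x, name='mean_activation')

with tf.Session() as sess:
    # Dump the GraphDef for TensorBoard, mirroring the change in print_result;
    # view it with: tensorboard --logdir ./graphs
    writer = tf.summary.FileWriter('./graphs', sess.graph)
    writer.close()
```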
