@@ -33,20 +33,21 @@ def _convert_output_shape(cls, kwargs, node):
33
33
kwargs ['_output_shapes' ] = [shape ]
34
34
35
35
@classmethod
36
- def get_kernel_params (cls , node ):
36
+ def get_kernel_params (cls , node , input_shape ):
37
37
kwargs = {}
38
- if node .kernel_parameters .p_h > 0 or node .kernel_parameters .p_w > 0 :
39
- padding = [0 , node .kernel_parameters .p_h , node .kernel_parameters .p_w , 0 , 0 , node .kernel_parameters .p_h , node .kernel_parameters .p_w , 0 ]
40
- elif node .kernel_parameters .s_h > 1 or node .kernel_parameters .s_w > 1 :
41
- padding = [0 , (node .kernel_parameters .s_h - 1 ) // 2 , (node .kernel_parameters .s_w - 1 ) // 2 , 0 , 0 , node .kernel_parameters .s_h // 2 , node .kernel_parameters .s_w // 2 , 0 ]
42
- else :
43
- padding = None
44
-
45
- kwargs ['auto_pad' ] = 'VALID'
38
+
39
+ o_h_caffe = node .output_shape .height
40
+ o_h_tf = (input_shape .height + node .kernel_parameters .p_h * 2 - node .kernel_parameters .k_h + 1 ) // node .kernel_parameters .s_h
41
+
42
+ o_w_caffe = node .output_shape .width
43
+ o_w_tf = (input_shape .width + node .kernel_parameters .p_w * 2 - node .kernel_parameters .k_w + 1 ) // node .kernel_parameters .s_w
44
+
45
+ kwargs ['pads' ] = [0 , node .kernel_parameters .p_h , node .kernel_parameters .p_w , 0 ] + \
46
+ [0 , node .kernel_parameters .p_h + o_h_caffe - o_h_tf , node .kernel_parameters .p_w + o_w_caffe - o_w_tf , 0 ]
46
47
kwargs ['strides' ] = [1 , node .kernel_parameters .s_h , node .kernel_parameters .s_w , 1 ]
47
48
cls ._convert_output_shape (kwargs , node )
48
49
49
- return kwargs , { 'pads' : padding , 'mode' : 'constant' , 'constant_values' : 0.0 }
50
+ return kwargs
50
51
51
52
52
53
@classmethod
@@ -72,30 +73,22 @@ def map_input(cls, node):
72
73
73
74
@classmethod
74
75
def map_convolution (cls , node ):
75
- kwargs , padding = cls .get_kernel_params (node )
76
76
parent , _ = node .get_only_parent ()
77
+ kwargs = cls .get_kernel_params (node , parent .output_shape )
77
78
kwargs ['kernel_shape' ] = [node .kernel_parameters .k_h , node .kernel_parameters .k_w , parent .output_shape .channels , node .parameters .num_output ]
78
- kwargs ['use_bias' ] = node .parameters .bias_term
79
- group = node .parameters .group
80
- if group != 1 :
81
- kwargs ['group' ] = group
82
-
83
- if padding ['pads' ] != None :
84
- return [Node .create ('Pad' , ** padding ), Node .create ('Conv' , ** kwargs )]
85
- else :
86
- kwargs ['pads' ] = [0 ] * 8
87
- return Node .create ('Conv' , ** kwargs )
79
+ kwargs ['use_bias' ] = node .parameters .bias_term
80
+ kwargs ['group' ] = node .parameters .group
81
+ return Node .create ('Conv' , ** kwargs )
88
82
89
83
90
84
@classmethod
91
85
def map_deconvolution (cls , node ):
92
86
raise NotImplementedError ()
93
- kwargs = cls .get_kernel_params (node )
94
87
parent , _ = node .get_only_parent ()
95
- kwargs [ 'kernel_shape' ] = [ node . kernel_parameters . k_h , node . kernel_parameters . k_w , parent .output_shape . channels , node . parameters . num_output ]
96
- group = node . parameters . group
97
- if group != 1 :
98
- kwargs ['group' ] = group
88
+ kwargs = cls . get_kernel_params ( node , parent .output_shape )
89
+
90
+ kwargs [ 'kernel_shape' ] = [ node . kernel_parameters . k_h , node . kernel_parameters . k_w , parent . output_shape . channels , node . parameters . num_output ]
91
+ kwargs ['group' ] = node . parameters . group
99
92
return Node .create ('deconv' , ** kwargs )
100
93
101
94
@classmethod
@@ -115,7 +108,8 @@ def map_relu(cls, node):
115
108
116
109
@classmethod
117
110
def map_pooling (cls , node ):
118
- kwargs , padding = cls .get_kernel_params (node )
111
+ parent , _ = node .get_only_parent ()
112
+ kwargs = cls .get_kernel_params (node , parent .output_shape )
119
113
if node .parameters .pool == 0 :
120
114
kwargs ['pooling_type' ] = 'MAX'
121
115
elif node .parameters .pool == 1 :
@@ -125,12 +119,7 @@ def map_pooling(cls, node):
125
119
raise ConversionError ('Unsupported pooling type.' )
126
120
kwargs ['kernel_shape' ] = [1 , node .kernel_parameters .k_h , node .kernel_parameters .k_w , 1 ]
127
121
cls ._convert_output_shape (kwargs , node )
128
-
129
- if padding ['pads' ] != None :
130
- return [Node .create ('Pad' , ** padding ), Node .create ('Pool' , ** kwargs )]
131
- else :
132
- kwargs ['pads' ] = [0 ] * 8
133
- return Node .create ('Pool' , ** kwargs )
122
+ return Node .create ('Pool' , ** kwargs )
134
123
135
124
136
125
@classmethod
0 commit comments