Skip to content

Commit 349169b

Browse files
committed
Add support for Caffe global pooling; Caffe Xception -> TensorFlow conversion tested.
1 parent 94fbbac commit 349169b

File tree

2 files changed

+41
-24
lines changed

2 files changed

+41
-24
lines changed

mmdnn/conversion/caffe/graph.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@
105105

106106
LayerType = type('LayerType', (), {t : t for t in LAYER_TYPES})
107107

108-
KernelParameters = namedtuple('KernelParameters', ['k_h', 'k_w', 's_h', 's_w', 'p_h', 'p_w'])
108+
KernelParameters = namedtuple('KernelParameters', ['global_pooling', 'k_h', 'k_w', 's_h', 's_w', 'p_h', 'p_w'])
109109

110110
class NodeKind(LayerType):
111111

@@ -185,45 +185,50 @@ def get_kernel_value(scalar, repeated, idx, default=None):
185185
def kernel_parameters(self):
    """Collect the kernel, stride and padding settings of this layer.

    Only valid for Convolution, Pooling and Deconvolution nodes (asserted).

    Returns:
        KernelParameters: ``(global_pooling, k_h, k_w, s_h, s_w, p_h, p_w)``.
        When ``global_pooling`` is truthy the kernel is reported as 0x0 and
        the stride as 1x1; the kernel/stride fields are not read at all.

    Raises:
        AssertionError: if the node is of any other kind.
    """
    assert self.kind in (NodeKind.Convolution, NodeKind.Pooling, NodeKind.Deconvolution)
    params = self.parameters
    # Parameter objects that lack a `global_pooling` attribute count as
    # non-global (hence the hasattr guard).
    global_pooling = hasattr(params, 'global_pooling') and params.global_pooling
    if global_pooling:
        # Placeholders: no explicit kernel/stride is looked up for global pooling.
        k_h = k_w = 0
        s_h = s_w = 1
    else:
        k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0)
        k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1)
        s_h = self.get_kernel_value(params.stride_h, params.stride, 0, default=1)
        s_w = self.get_kernel_value(params.stride_w, params.stride, 1, default=1)
    p_h = self.get_kernel_value(params.pad_h, params.pad, 0, default=0)
    # Width padding must be read from pad_w; using pad_h here yielded the
    # wrong p_w whenever a layer specified asymmetric pad_h / pad_w.
    p_w = self.get_kernel_value(params.pad_w, params.pad, 1, default=0)
    return KernelParameters(global_pooling, k_h, k_w, s_h, s_w, p_h, p_w)
195200

196201
def __str__(self):
    """Render the node as ``[<kind>] <name>``."""
    kind_and_name = (self.kind, self.name)
    return '[%s] %s' % kind_and_name
198-
203+
199204
def __repr__(self):
    """Return ``<name> (0x<id>)`` where ``<id>`` is ``id(self)`` in hex."""
    identity = id(self)
    return '%s (0x%x)' % (self.name, identity)
201206

202207

203208
class CaffeGraph(object):
204-
209+
205210
def __init__(self, nodes=None, name=None):
    """Set up the graph's node list, name-to-node lookup table and name.

    Args:
        nodes: initial nodes; a falsy value yields a fresh empty list.
        name: graph name, stored as-is.
    """
    self.nodes = nodes if nodes else []
    # Index every initial node by its name; a later duplicate name
    # overwrites an earlier one.
    lookup = {}
    for entry in self.nodes:
        lookup[entry.name] = entry
    self.node_lut = lookup
    self.name = name
    self.prototxt = None
210-
215+
211216
def add_node(self, node):
    """Append ``node`` to ``self.nodes`` and index it by name in ``self.node_lut``."""
    self.nodes.append(node)
    key = node.name
    self.node_lut[key] = node
214-
219+
215220
def get_node(self, name):
    """Look up a node by name in ``self.node_lut``.

    Raises:
        ConversionError: if no node is registered under ``name``.
    """
    try:
        found = self.node_lut[name]
    except KeyError:
        raise ConversionError('Layer not found: %s' % name)
    return found
220-
225+
221226
def get_input_nodes(self):
    """Return the nodes with no parents, in the order they appear in ``self.nodes``."""
    roots = []
    for candidate in self.nodes:
        if len(candidate.parents) == 0:
            roots.append(candidate)
    return roots
223228

224229
def get_output_nodes(self):
    """Return the nodes with no children, in the order they appear in ``self.nodes``."""
    leaves = []
    for candidate in self.nodes:
        if len(candidate.children) == 0:
            leaves.append(candidate)
    return leaves
226-
231+
227232
def topologically_sorted(self):
228233
visited = set()
229234
sorted_nodes = []
@@ -263,11 +268,11 @@ def compute_output_shapes(self, model):
263268
else:
264269
for node in sorted_nodes:
265270
node.output_shape = TensorShape(*NodeKind.compute_output_shape(node))
266-
271+
267272
# consider rewrite this function to Network.py
268273
def replaced(self, new_nodes):
    """Build a new CaffeGraph holding ``new_nodes`` under this graph's name."""
    graph_name = self.name
    return CaffeGraph(name=graph_name, nodes=new_nodes)
270-
275+
271276
def transformed(self, transformers):
272277
graph = self
273278
for transformer in transformers:
@@ -316,7 +321,7 @@ def load(self):
316321
text_format.Merge(f.read(), self.model)
317322
if self.is_train_proto:
318323
self.process_train_proto()
319-
324+
320325
def process_train_proto(self):
321326
layers = self.model.layer or self.model.layers
322327
delete_layer = set()
@@ -359,7 +364,7 @@ def process_train_proto(self):
359364
elif kind == NodeKind.SigmoidCrossEntropyLoss:
360365
pred.type = NodeKind.Sigmoid if self.model.layer else 19
361366
layers.remove(last_layer)
362-
367+
363368
def filter_layers(self, layers):
364369
phase_map = {0: 'train', 1: 'test'}
365370
filtered_layer_names = set()
@@ -388,7 +393,7 @@ def filter_layers(self, layers):
388393
return filtered_layers
389394

390395
def make_node(self, layer):
391-
kind = NodeKind.map_raw_kind(layer.type)
396+
kind = NodeKind.map_raw_kind(layer.type)
392397
if kind is None:
393398
# TODO: raise error
394399
pass

mmdnn/conversion/caffe/mapper.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,27 @@ def _convert_output_shape(cls, kwargs, node):
3636
def get_kernel_params(cls, node, input_shape):
3737
kwargs = {}
3838

39-
o_h_caffe = node.output_shape.height
40-
o_h_tf = (input_shape.height + node.kernel_parameters.p_h * 2 - node.kernel_parameters.k_h + 1) // node.kernel_parameters.s_h
39+
if node.kernel_parameters.global_pooling:
40+
kwargs['kernel_shape'] = [1, input_shape.height, input_shape.width, 1]
41+
kwargs['pads'] = [0] * 8
4142

42-
o_w_caffe = node.output_shape.width
43-
o_w_tf = (input_shape.width + node.kernel_parameters.p_w * 2 - node.kernel_parameters.k_w + 1) // node.kernel_parameters.s_w
43+
else:
44+
from mmdnn.conversion.caffe.graph import NodeKind
45+
if node.kind == NodeKind.Pooling:
46+
kwargs['kernel_shape'] = [1, node.kernel_parameters.k_h, node.kernel_parameters.k_w, 1]
47+
elif node.kind == NodeKind.Convolution:
48+
pass
49+
else:
50+
raise ValueError
51+
52+
o_h_caffe = node.output_shape.height
53+
o_h_tf = (input_shape.height + node.kernel_parameters.p_h * 2 - node.kernel_parameters.k_h + 1) // node.kernel_parameters.s_h
54+
o_w_caffe = node.output_shape.width
55+
o_w_tf = (input_shape.width + node.kernel_parameters.p_w * 2 - node.kernel_parameters.k_w + 1) // node.kernel_parameters.s_w
56+
57+
kwargs['pads'] = [0, node.kernel_parameters.p_h, node.kernel_parameters.p_w, 0] + \
58+
[0, node.kernel_parameters.p_h + o_h_caffe - o_h_tf, node.kernel_parameters.p_w + o_w_caffe - o_w_tf, 0]
4459

45-
kwargs['pads'] = [0, node.kernel_parameters.p_h, node.kernel_parameters.p_w, 0] + \
46-
[0, node.kernel_parameters.p_h + o_h_caffe - o_h_tf, node.kernel_parameters.p_w + o_w_caffe - o_w_tf, 0]
4760
kwargs['strides'] = [1, node.kernel_parameters.s_h, node.kernel_parameters.s_w, 1]
4861
cls._convert_output_shape(kwargs, node)
4962

@@ -117,7 +130,6 @@ def map_pooling(cls, node):
117130
else:
118131
# Stochastic pooling, for instance.
119132
raise ConversionError('Unsupported pooling type.')
120-
kwargs['kernel_shape'] = [1, node.kernel_parameters.k_h, node.kernel_parameters.k_w, 1]
121133
cls._convert_output_shape(kwargs, node)
122134
return Node.create('Pool', **kwargs)
123135

0 commit comments

Comments
 (0)