Skip to content

Commit f0a9798

Browse files
authored
Add Nonetype check
Add nonetype check in rename_FullyConnected Add some remind when convert unsupported op
1 parent d6ee015 commit f0a9798

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

mmdnn/conversion/pytorch/pytorch_parser.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def gen_IR(self):
9797
for layer in self.src_graph.topological_sort:
9898
current_node = self.src_graph.get_node(layer)
9999
onnx_node_type = current_node.type
100+
if onnx_node_type not in PytorchParser.layer_map.keys():
101+
print("PyTorch parser has not supported operator [%s]. IR network strucuture may lost info."
102+
% (onnx_node_type))
103+
return
100104
node_type = PytorchParser.layer_map[onnx_node_type]
101105

102106

@@ -398,17 +402,18 @@ def rename_FullyConnected(self, source_node):
398402
# weight: N x M -> C x H x W x M -> H x W x C x M -> N x M
399403
if self.weight_loaded:
400404
parent = self.src_graph.get_parent(source_node.name, [0])
401-
while parent.type == 'onnx::Flatten' or parent.type == 'onnx::Dropout':
402-
parent = self.src_graph.get_parent(parent.name, [0])
403-
if len(self.shape_dict[parent.name]) == 4:
404-
#
405-
original_shape = W.shape
406-
channel_first_list = self.shape_dict[parent.name][1:]
407-
dim = len(channel_first_list) + 1
408-
weight = W.reshape(channel_first_list + [original_shape[1]])
409-
assert dim > 2
410-
weight = weight.transpose(list(range(1, dim-1)) + [0, dim-1])
411-
W = weight.reshape(original_shape)
405+
if parent:
406+
while parent.type == 'onnx::Flatten' or parent.type == 'onnx::Dropout':
407+
parent = self.src_graph.get_parent(parent.name, [0])
408+
if len(self.shape_dict[parent.name]) == 4:
409+
#
410+
original_shape = W.shape
411+
channel_first_list = self.shape_dict[parent.name][1:]
412+
dim = len(channel_first_list) + 1
413+
weight = W.reshape(channel_first_list + [original_shape[1]])
414+
assert dim > 2
415+
weight = weight.transpose(list(range(1, dim-1)) + [0, dim-1])
416+
W = weight.reshape(original_shape)
412417

413418
# weights
414419
self.set_weight(source_node.name, 'weights', W )

0 commit comments

Comments
 (0)