Skip to content

Commit 7cd0670

Browse files
authored
Update pytorch_graph.py
Add output shape check
1 parent 9199dcb commit 7cd0670

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

mmdnn/conversion/pytorch/pytorch_graph.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,12 @@ def build(self, shape):
129129
for node in nodes:
130130
node_id = PytorchGraph.get_node_id(node)
131131
node_name = self.rename_nodes(node, node_id)
132-
output_shape_str = re.findall(r'[^()!]+', node.__str__())[1]
133-
if '%' in output_shape_str:
134-
out_put_shape = None
132+
output_str = node.__str__().split('=')[0]
133+
output_shape_str = re.findall(r'[^()!]+', output_str)
134+
if len(output_shape_str) > 1:
135+
output_shape = [int(x.replace('!', '')) for x in output_shape_str[1].split(',')]
135136
else:
136-
output_shape = [int(x.replace('!', '')) for x in output_shape_str.split(',')]
137+
output_shape = None
137138
self.shape_dict[node_name] = output_shape
138139
self.layer_map[node_name] = self.CreateGraphNode(node)
139140
self.layer_name_map[node_name] = node_name

0 commit comments

Comments
 (0)