File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -129,11 +129,12 @@ def build(self, shape):
129
129
for node in nodes :
130
130
node_id = PytorchGraph .get_node_id (node )
131
131
node_name = self .rename_nodes (node , node_id )
132
- output_shape_str = re .findall (r'[^()!]+' , node .__str__ ())[1 ]
133
- if '%' in output_shape_str :
134
- out_put_shape = None
132
+ output_str = node .__str__ ().split ('=' )[0 ]
133
+ output_shape_str = re .findall (r'[^()!]+' , output_str )
134
+ if len (output_shape_str ) > 1 :
135
+ output_shape = [int (x .replace ('!' , '' )) for x in output_shape_str [1 ].split (',' )]
135
136
else :
136
- output_shape = [ int ( x . replace ( '!' , '' )) for x in output_shape_str . split ( ',' )]
137
+ output_shape = None
137
138
self .shape_dict [node_name ] = output_shape
138
139
self .layer_map [node_name ] = self .CreateGraphNode (node )
139
140
self .layer_name_map [node_name ] = node_name
You can’t perform that action at this time.
0 commit comments