10
10
import contextlib
11
11
from torch .jit import _unique_state_dict
12
12
13
+ class scope_name_workaround (object ):
14
+ def __init__ (self ):
15
+ self .backup = None
16
+
17
+ def __enter__ (self ):
18
+ def _tracing_name (self_ , tracing_state ):
19
+ if not tracing_state ._traced_module_stack :
20
+ return None
21
+ module = tracing_state ._traced_module_stack [- 1 ]
22
+ for name , child in module .named_children ():
23
+ if child is self_ :
24
+ return name
25
+ return None
26
+
27
+ def _slow_forward (self_ , * input , ** kwargs ):
28
+ tracing_state = torch ._C ._get_tracing_state ()
29
+ if not tracing_state or isinstance (self_ .forward , torch ._C .ScriptMethod ):
30
+ return self_ .forward (* input , ** kwargs )
31
+ if not hasattr (tracing_state , '_traced_module_stack' ):
32
+ tracing_state ._traced_module_stack = []
33
+ name = _tracing_name (self_ , tracing_state )
34
+ if name :
35
+ tracing_state .push_scope ('%s[%s]' % (self_ ._get_name (), name ))
36
+ else :
37
+ tracing_state .push_scope (self_ ._get_name ())
38
+ tracing_state ._traced_module_stack .append (self_ )
39
+ try :
40
+ result = self_ .forward (* input , ** kwargs )
41
+ finally :
42
+ tracing_state .pop_scope ()
43
+ tracing_state ._traced_module_stack .pop ()
44
+ return result
45
+
46
+ self .backup = torch .nn .Module ._slow_forward
47
+ setattr (torch .nn .Module , '_slow_forward' , _slow_forward )
13
48
49
+ def __exit__ (self , type , value , tb ):
50
+ setattr (torch .nn .Module , '_slow_forward' , self .backup )
14
51
15
52
class PytorchGraphNode (GraphNode ):
16
53
@@ -71,54 +108,12 @@ def __init__(self, model):
71
108
self .shape_dict = dict ()
72
109
self .layer_weight_map = dict ()
73
110
74
-
75
- @staticmethod
76
- def _optimize_graph (graph , aten , export_raw_ir = False ):
77
- # run dce first to eliminate dead parts of the graph that might have been
78
- # left behind by things like symbolic_override
79
-
80
- torch ._C ._jit_pass_dce (graph )
81
- torch ._C ._jit_pass_lint (graph )
82
-
83
- torch ._C ._jit_pass_peephole (graph )
84
- torch ._C ._jit_pass_lint (graph )
85
- if not export_raw_ir :
86
- graph = torch ._C ._jit_pass_onnx (graph , aten )
87
- torch ._C ._jit_pass_lint (graph )
88
- torch ._C ._jit_pass_onnx_peephole (graph )
89
- torch ._C ._jit_pass_lint (graph )
90
- torch ._C ._jit_pass_dce (graph )
91
- torch ._C ._jit_pass_lint (graph )
92
- graph = torch ._C ._jit_pass_canonicalize (graph )
93
- torch ._C ._jit_pass_lint (graph )
94
- return graph
95
-
96
-
97
111
@staticmethod
98
112
def get_node_id (node ):
99
113
import re
100
114
node_id = re .search (r"[\d]+" , node .__str__ ())
101
115
return node_id .group (0 )
102
116
103
- @contextlib .contextmanager
104
- def set_training (self , model , mode ):
105
- r"""
106
- A context manager to temporarily set the training mode of 'model'
107
- to 'mode', resetting it when we exit the with-block. A no-op if
108
- mode is None.
109
- """
110
- if mode is None :
111
- yield
112
- return
113
- old_mode = model .training
114
- if old_mode != mode :
115
- model .train (mode )
116
- try :
117
- yield
118
- finally :
119
- if old_mode != mode :
120
- model .train (old_mode )
121
-
122
117
123
118
def build (self , shape ):
124
119
"""
@@ -180,6 +175,45 @@ def node_connection(self, graph, node, node_name):
180
175
def CreateGraphNode (self , node ):
181
176
return PytorchGraphNode040 (node )
182
177
178
+ @staticmethod
179
+ def _optimize_graph (graph , aten , export_raw_ir = False ):
180
+ # run dce first to eliminate dead parts of the graph that might have been
181
+ # left behind by things like symbolic_override
182
+
183
+ torch ._C ._jit_pass_dce (graph )
184
+ torch ._C ._jit_pass_lint (graph )
185
+
186
+ torch ._C ._jit_pass_peephole (graph )
187
+ torch ._C ._jit_pass_lint (graph )
188
+ if not export_raw_ir :
189
+ graph = torch ._C ._jit_pass_onnx (graph , aten )
190
+ torch ._C ._jit_pass_lint (graph )
191
+ torch ._C ._jit_pass_onnx_peephole (graph )
192
+ torch ._C ._jit_pass_lint (graph )
193
+ torch ._C ._jit_pass_dce (graph )
194
+ torch ._C ._jit_pass_lint (graph )
195
+ graph = torch ._C ._jit_pass_canonicalize (graph )
196
+ torch ._C ._jit_pass_lint (graph )
197
+ return graph
198
+
199
+ @contextlib .contextmanager
200
+ def set_training (self , model , mode ):
201
+ r"""
202
+ A context manager to temporarily set the training mode of 'model'
203
+ to 'mode', resetting it when we exit the with-block. A no-op if
204
+ mode is None.
205
+ """
206
+ if mode is None :
207
+ yield
208
+ return
209
+ old_mode = model .training
210
+ if old_mode != mode :
211
+ model .train (mode )
212
+ try :
213
+ yield
214
+ finally :
215
+ if old_mode != mode :
216
+ model .train (old_mode )
183
217
184
218
class PytorchGraph151 (PytorchGraph ):
185
219
@@ -188,22 +222,21 @@ def __init__(self, model):
188
222
189
223
def extractgraph (self , dummy_input ):
190
224
import re
191
- import torch .onnx .utils
192
- # connect name and id in nodes with weights
193
- graph , params_dict , torch_out = torch .onnx .utils ._model_to_graph (self .model , dummy_input , _retain_param_name = True )
225
+ from torch .onnx .utils import OperatorExportTypes
226
+ from torch .onnx .utils import _trace
227
+
228
+ self .model .eval ()
229
+ with scope_name_workaround ():
230
+ graph = _trace (self .model , dummy_input , OperatorExportTypes .ONNX )
194
231
nodes = list (graph .nodes ())
232
+
195
233
for node in nodes :
196
234
# print(node.__str__())
197
235
node_id = PytorchGraph .get_node_id (node )
198
236
node_name = 'node' + node_id
199
- node_scope_str = re .findall (r'[^()!]+' , node .__str__ ())[- 2 ]
200
- for x in node_scope_str .split (',' ):
201
- if re .findall (r'%\S+.weight' , x ):
202
- node_scope = '.' .join (re .findall (r'%\S+.weight' , x )[0 ].replace ('%' ,'' ,1 ).split ('.' )[:- 1 ])
203
- self .layer_weight_map [node_name ] = node_scope
204
-
205
- graph , params_dict , torch_out = torch .onnx .utils ._model_to_graph (self .model , dummy_input )
206
- nodes = list (graph .nodes ())
237
+ self .layer_weight_map [node_name ] = '.' .join (
238
+ re .findall (r'\[([\w\d.]+)\]' , node .scopeName ())
239
+ )
207
240
return graph , nodes
208
241
209
242
def rename_nodes (self , node , node_id ):
0 commit comments