Skip to content

Commit e3dbf30

Browse files
XiaoXYelvyufeng
andauthored
Pytorch1.5 update extract method (#876)
* change trace function for pytorch 1.5 (#870) Co-authored-by: XiaoXYe <[email protected]> Co-authored-by: nate.river <[email protected]>
1 parent f0a9798 commit e3dbf30

File tree

1 file changed

+86
-53
lines changed

1 file changed

+86
-53
lines changed

mmdnn/conversion/pytorch/pytorch_graph.py

Lines changed: 86 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,44 @@
1010
import contextlib
1111
from torch.jit import _unique_state_dict
1212

13+
class scope_name_workaround(object):
14+
def __init__(self):
15+
self.backup = None
16+
17+
def __enter__(self):
18+
def _tracing_name(self_, tracing_state):
19+
if not tracing_state._traced_module_stack:
20+
return None
21+
module = tracing_state._traced_module_stack[-1]
22+
for name, child in module.named_children():
23+
if child is self_:
24+
return name
25+
return None
26+
27+
def _slow_forward(self_, *input, **kwargs):
28+
tracing_state = torch._C._get_tracing_state()
29+
if not tracing_state or isinstance(self_.forward, torch._C.ScriptMethod):
30+
return self_.forward(*input, **kwargs)
31+
if not hasattr(tracing_state, '_traced_module_stack'):
32+
tracing_state._traced_module_stack = []
33+
name = _tracing_name(self_, tracing_state)
34+
if name:
35+
tracing_state.push_scope('%s[%s]' % (self_._get_name(), name))
36+
else:
37+
tracing_state.push_scope(self_._get_name())
38+
tracing_state._traced_module_stack.append(self_)
39+
try:
40+
result = self_.forward(*input, **kwargs)
41+
finally:
42+
tracing_state.pop_scope()
43+
tracing_state._traced_module_stack.pop()
44+
return result
45+
46+
self.backup = torch.nn.Module._slow_forward
47+
setattr(torch.nn.Module, '_slow_forward', _slow_forward)
1348

49+
def __exit__(self, type, value, tb):
50+
setattr(torch.nn.Module, '_slow_forward', self.backup)
1451

1552
class PytorchGraphNode(GraphNode):
1653

@@ -71,54 +108,12 @@ def __init__(self, model):
71108
self.shape_dict = dict()
72109
self.layer_weight_map = dict()
73110

74-
75-
@staticmethod
76-
def _optimize_graph(graph, aten, export_raw_ir=False):
77-
# run dce first to eliminate dead parts of the graph that might have been
78-
# left behind by things like symbolic_override
79-
80-
torch._C._jit_pass_dce(graph)
81-
torch._C._jit_pass_lint(graph)
82-
83-
torch._C._jit_pass_peephole(graph)
84-
torch._C._jit_pass_lint(graph)
85-
if not export_raw_ir:
86-
graph = torch._C._jit_pass_onnx(graph, aten)
87-
torch._C._jit_pass_lint(graph)
88-
torch._C._jit_pass_onnx_peephole(graph)
89-
torch._C._jit_pass_lint(graph)
90-
torch._C._jit_pass_dce(graph)
91-
torch._C._jit_pass_lint(graph)
92-
graph = torch._C._jit_pass_canonicalize(graph)
93-
torch._C._jit_pass_lint(graph)
94-
return graph
95-
96-
97111
@staticmethod
98112
def get_node_id(node):
99113
import re
100114
node_id = re.search(r"[\d]+", node.__str__())
101115
return node_id.group(0)
102116

103-
@contextlib.contextmanager
104-
def set_training(self, model, mode):
105-
r"""
106-
A context manager to temporarily set the training mode of 'model'
107-
to 'mode', resetting it when we exit the with-block. A no-op if
108-
mode is None.
109-
"""
110-
if mode is None:
111-
yield
112-
return
113-
old_mode = model.training
114-
if old_mode != mode:
115-
model.train(mode)
116-
try:
117-
yield
118-
finally:
119-
if old_mode != mode:
120-
model.train(old_mode)
121-
122117

123118
def build(self, shape):
124119
"""
@@ -180,6 +175,45 @@ def node_connection(self, graph, node, node_name):
180175
def CreateGraphNode(self, node):
181176
return PytorchGraphNode040(node)
182177

178+
@staticmethod
179+
def _optimize_graph(graph, aten, export_raw_ir=False):
180+
# run dce first to eliminate dead parts of the graph that might have been
181+
# left behind by things like symbolic_override
182+
183+
torch._C._jit_pass_dce(graph)
184+
torch._C._jit_pass_lint(graph)
185+
186+
torch._C._jit_pass_peephole(graph)
187+
torch._C._jit_pass_lint(graph)
188+
if not export_raw_ir:
189+
graph = torch._C._jit_pass_onnx(graph, aten)
190+
torch._C._jit_pass_lint(graph)
191+
torch._C._jit_pass_onnx_peephole(graph)
192+
torch._C._jit_pass_lint(graph)
193+
torch._C._jit_pass_dce(graph)
194+
torch._C._jit_pass_lint(graph)
195+
graph = torch._C._jit_pass_canonicalize(graph)
196+
torch._C._jit_pass_lint(graph)
197+
return graph
198+
199+
@contextlib.contextmanager
200+
def set_training(self, model, mode):
201+
r"""
202+
A context manager to temporarily set the training mode of 'model'
203+
to 'mode', resetting it when we exit the with-block. A no-op if
204+
mode is None.
205+
"""
206+
if mode is None:
207+
yield
208+
return
209+
old_mode = model.training
210+
if old_mode != mode:
211+
model.train(mode)
212+
try:
213+
yield
214+
finally:
215+
if old_mode != mode:
216+
model.train(old_mode)
183217

184218
class PytorchGraph151(PytorchGraph):
185219

@@ -188,22 +222,21 @@ def __init__(self, model):
188222

189223
def extractgraph(self, dummy_input):
190224
import re
191-
import torch.onnx.utils
192-
# connect name and id in nodes with weights
193-
graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(self.model, dummy_input, _retain_param_name=True)
225+
from torch.onnx.utils import OperatorExportTypes
226+
from torch.onnx.utils import _trace
227+
228+
self.model.eval()
229+
with scope_name_workaround():
230+
graph = _trace(self.model, dummy_input, OperatorExportTypes.ONNX)
194231
nodes = list(graph.nodes())
232+
195233
for node in nodes:
196234
# print(node.__str__())
197235
node_id = PytorchGraph.get_node_id(node)
198236
node_name = 'node' + node_id
199-
node_scope_str = re.findall(r'[^()!]+', node.__str__())[-2]
200-
for x in node_scope_str.split(','):
201-
if re.findall(r'%\S+.weight', x):
202-
node_scope = '.'.join(re.findall(r'%\S+.weight', x)[0].replace('%','',1).split('.')[:-1])
203-
self.layer_weight_map[node_name] = node_scope
204-
205-
graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(self.model, dummy_input)
206-
nodes = list(graph.nodes())
237+
self.layer_weight_map[node_name] = '.'.join(
238+
re.findall(r'\[([\w\d.]+)\]', node.scopeName())
239+
)
207240
return graph, nodes
208241

209242
def rename_nodes(self, node, node_id):

0 commit comments

Comments
 (0)