@@ -39,7 +39,6 @@
 from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from torch.export import export
 from torch.export.exported_program import ExportedProgram
 
 from .passes import get_cadence_passes
@@ -55,27 +54,24 @@
 # however useful for unit tests to separate the converted model from the fused
 # model, to be able to get reference numerics.
 # If this does not apply, please use quantize_and_fuse_pt2 instead.
-def prepare_and_convert_pt2(
+def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
-    Prepare and convert a model using the given quantizer.
-    The quantizer must be supplied and be the same as the one used to
-    fuse the model later, if applicable. If you do not expect that behavior,
-    please use quantize_and_fuse_pt2 instead, which will instantiate a
-    default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Trace the model with export_for_training and return an ExportedProgram.
     """
 
+    # Make the model inference mode by calling model.eval()
+    model.eval()
+
+    # Prevent mkldnn decompositions
+    torch._C._set_mkldnn_enabled(False)
+
     # Get default decompositions
     decomp_table = torch.export.default_decompositions()
+
     # Select ops to keep
     ops_to_keep = [
         torch.ops.aten.conv1d.default,
@@ -85,19 +81,46 @@ def prepare_and_convert_pt2(
         torch.ops.aten.matmul.default,
         torch.ops.aten.rms_norm.default,
     ]
+
     # Remove decompositions for the ops we want to keep
     # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
     remove_decompositions(decomp_table, ops_to_keep)
+
     # Export with dynamo
-    model_gm = (
-        torch.export.export_for_training(model, inputs, strict=True)
-        .run_decompositions(decomp_table)
-        .module()
-    )
+    program = torch.export.export_for_training(
+        model, inputs, strict=True
+    ).run_decompositions(decomp_table)
 
     if dump_graphs:
         logging.info("Graph before quantization:")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info(program.module().graph.print_tabular())
+
+    return program
+
+
+def prepare_and_convert_pt2(
+    program: ExportedProgram,
+    inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Prepare and convert a model using the given quantizer.
+    The quantizer must be supplied and be the same as the one used to
+    fuse the model later, if applicable. If you do not expect that behavior,
+    please use quantize_and_fuse_pt2 instead, which will instantiate a
+    default quantizer for you if needed.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the converted model.
+    """
+
+    # Get the graph module from the ExportedProgram
+    model_gm = program.module()
+
+    assert isinstance(model_gm, torch.fx.GraphModule)
 
     # Prepare
     prepared_model = prepare_pt2e(model_gm, quantizer)
@@ -121,10 +144,10 @@ def prepare_and_convert_pt2(
 
 
 # Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_and_fuse_pt2 instead.
+# if the quantizer here is different from the quantizer used to prepare/convert.
+# It is however useful for unit tests to separate the converted model from the
+# fused model, to be able to get reference numerics.
+# If this does not apply, please use quantize_pt2 instead.
 def fuse_pt2(
     converted_graph_module: torch.fx.GraphModule,
     quantizer: CadenceQuantizer,
@@ -167,9 +190,15 @@ def quantize_pt2(
     if not quantizer:
         quantizer = CadenceDefaultQuantizer()
 
+    program = trace(model, inputs, dump_graphs=dump_graphs)
+
+    if dump_graphs:
+        logging.info("Graph after trace:")
+        logging.info(program.graph.print_tabular())
+
     # Get converted graph module
     converted_gm = prepare_and_convert_pt2(
-        model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
+        program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
     )
 
     # Get fused model
@@ -184,22 +213,6 @@ def quantize_pt2(
     return program
 
 
-# Export the model and lower it to an ExportedProgram (in aten IR)
-def export_program(
-    model: torch.nn.Module,
-    inputs: tuple[object, ...],
-) -> ExportedProgram:
-    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
-
-    # Prevent mkldnn decompositions
-    torch._C._set_mkldnn_enabled(False)
-
-    # Export the model and return it.
-    expo_program = export(model, inputs, strict=True)
-
-    return expo_program
-
-
 def _lower_ep_to_edge(
     expo_program: ExportedProgram,
     dump_graphs: bool = False,
@@ -248,7 +261,7 @@ def export_to_edge(
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs)
+    expo_program = trace(model, inputs)
 
     # Lower the model to edge IR.
     edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
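
For readers wiring this up, here is a minimal sketch of the split flow implied by the hunks above: trace() now owns eval(), the mkldnn guard, and export_for_training, while prepare_and_convert_pt2() consumes the resulting ExportedProgram. The function names and signatures come from this diff; the import paths are assumptions (they are not shown in any hunk) and should be adjusted to wherever these helpers live in your tree.

import torch

# Assumed import locations -- not shown in the hunks above.
from executorch.backends.cadence.aot.compiler import (
    fuse_pt2,
    prepare_and_convert_pt2,
    trace,
)
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceDefaultQuantizer,
)


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x)


model = TinyModel()
inputs = (torch.randn(1, 16),)

# 1. trace() handles eval(), the mkldnn guard, and export_for_training,
#    and returns an ExportedProgram.
program = trace(model, inputs)

# 2. prepare_and_convert_pt2() now takes the ExportedProgram rather than
#    the nn.Module. With no calibration_data, the inputs are reused for
#    calibration (fine for unit tests, per the docstring).
quantizer = CadenceDefaultQuantizer()
converted_gm = prepare_and_convert_pt2(program, inputs, quantizer)

# 3. Fuse with the same quantizer that was used to prepare/convert.
fused_gm = fuse_pt2(converted_gm, quantizer)

Callers that previously used export_program() should switch to trace(), as export_to_edge() itself does in the last hunk.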