Commit 575c5c4

mcremon-meta authored and facebook-github-bot committed

Extract trace from prepare_and_convert and remove export_program (#10493)

Summary: As titled. Will be used in later changes to fix some inconsistencies.

Reviewed By: zonglinpeng

Differential Revision: D73440517
1 parent 072403b commit 575c5c4
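
After this change, the Cadence AOT quantization flow is two explicit steps: trace the model to an ExportedProgram, then prepare/convert (and optionally fuse) that program. A minimal sketch of the new call sequence, assuming the functions live in executorch.backends.cadence.aot.compiler (the file modified below); the model and inputs are placeholders:

    import torch
    # Import path assumed from the modified file; adjust if the package layout differs.
    from executorch.backends.cadence.aot.compiler import (
        fuse_pt2,
        prepare_and_convert_pt2,
        trace,
    )
    from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

    model = torch.nn.Linear(8, 8)          # placeholder model
    example_inputs = (torch.randn(1, 8),)  # placeholder inputs
    quantizer = CadenceDefaultQuantizer()

    ep = trace(model, example_inputs)                                   # ExportedProgram
    converted = prepare_and_convert_pt2(ep, example_inputs, quantizer)  # torch.fx.GraphModule
    fused = fuse_pt2(converted, quantizer)                              # quantized and fused module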

File tree

2 files changed: +60 -43 lines

backends/cadence/aot/compiler.py (+55 -42)

@@ -39,7 +39,6 @@
 from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

-from torch.export import export
 from torch.export.exported_program import ExportedProgram

 from .passes import get_cadence_passes
@@ -55,27 +54,24 @@
 # however useful for unit tests to separate the converted model from the fused
 # model, to be able to get reference numerics.
 # If this does not apply, please use quantize_and_fuse_pt2 instead.
-def prepare_and_convert_pt2(
+def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
-    Prepare and convert a model using the given quantizer.
-    The quantizer must be supplied and be the same as the one used to
-    fuse the model later, if applicable. If you do not expect that behavior,
-    please use quantize_and_fuse_pt2 instead, which will instantiate a
-    default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Trace the model with export_for_training and return an ExportedProgram.
     """

+    # Make the model inference mode by calling model.eval()
+    model.eval()
+
+    # Prevent mkldnn decompositions
+    torch._C._set_mkldnn_enabled(False)
+
     # Get default decompositions
     decomp_table = torch.export.default_decompositions()
+
     # Select ops to keep
     ops_to_keep = [
         torch.ops.aten.conv1d.default,
@@ -85,19 +81,46 @@ def prepare_and_convert_pt2(
         torch.ops.aten.matmul.default,
         torch.ops.aten.rms_norm.default,
     ]
+
     # Remove decompositions for the ops we want to keep
     # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
     remove_decompositions(decomp_table, ops_to_keep)
+
     # Export with dynamo
-    model_gm = (
-        torch.export.export_for_training(model, inputs, strict=True)
-        .run_decompositions(decomp_table)
-        .module()
-    )
+    program = torch.export.export_for_training(
+        model, inputs, strict=True
+    ).run_decompositions(decomp_table)

     if dump_graphs:
         logging.info("Graph before quantization:")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info(program.module().graph.print_tabular())
+
+    return program
+
+
+def prepare_and_convert_pt2(
+    program: ExportedProgram,
+    inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Prepare and convert a model using the given quantizer.
+    The quantizer must be supplied and be the same as the one used to
+    fuse the model later, if applicable. If you do not expect that behavior,
+    please use quantize_and_fuse_pt2 instead, which will instantiate a
+    default quantizer for you if needed.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the converted model.
+    """
+
+    # Get the graph module from the ExportedProgram
+    model_gm = program.module()
+
+    assert isinstance(model_gm, torch.fx.GraphModule)

     # Prepare
     prepared_model = prepare_pt2e(model_gm, quantizer)
@@ -121,10 +144,10 @@ def prepare_and_convert_pt2(


 # Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_and_fuse_pt2 instead.
+# if the quantizer here is different from the quantizer used to prepare/convert.
+# It is however useful for unit tests to separate the converted model from the
+# fused model, to be able to get reference numerics.
+# If this does not apply, please use quantize_pt2 instead.
 def fuse_pt2(
     converted_graph_module: torch.fx.GraphModule,
     quantizer: CadenceQuantizer,
@@ -166,9 +189,15 @@ def quantize_pt2(
     if not quantizer:
         quantizer = CadenceDefaultQuantizer()

+    program = trace(model, inputs, dump_graphs=dump_graphs)
+
+    if dump_graphs:
+        logging.info("Graph after trace:")
+        logging.info(program.graph.print_tabular())
+
     # Get converted graph module
     converted_gm = prepare_and_convert_pt2(
-        model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
+        program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
     )

     # Get fused model
@@ -181,22 +210,6 @@ def quantize_pt2(
     return fused_gm


-# Export the model and lower it to an ExportedProgram (in aten IR)
-def export_program(
-    model: torch.nn.Module,
-    inputs: tuple[object, ...],
-) -> ExportedProgram:
-    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
-
-    # Prevent mkldnn decompositions
-    torch._C._set_mkldnn_enabled(False)
-
-    # Export the model and return it.
-    expo_program = export(model, inputs, strict=True)
-
-    return expo_program
-
-
 def lower_ep_to_edge(
     expo_program: ExportedProgram,
     dump_graphs: bool = False,
@@ -245,7 +258,7 @@ def export_to_edge(
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs)
+    expo_program = trace(model, inputs)

     # Lower the model to edge IR.
     edge_prog_manager = lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
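
With export_program removed, callers that need an ExportedProgram in ATen IR go through trace instead, as the updated export_to_edge shows. A minimal sketch of that lowering path, reusing the placeholder model and example_inputs from the sketch above; the import path and the constant_methods=None default are assumptions, since neither is fully visible in this diff:

    # Import path assumed; see the file header above.
    from executorch.backends.cadence.aot.compiler import lower_ep_to_edge, trace

    # Trace to an ExportedProgram (ATen IR with the Cadence-preserved ops), then lower to edge IR.
    expo_program = trace(model, example_inputs)
    edge_prog_manager = lower_ep_to_edge(expo_program, dump_graphs=False, constant_methods=None)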

backends/cadence/aot/export_example.py (+5 -1)

@@ -18,6 +18,7 @@
     export_to_executorch_gen_etrecord,
     fuse_pt2,
     prepare_and_convert_pt2,
+    trace,
 )

 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
@@ -48,8 +49,11 @@ def export_model(
     # Instantiate the quantizer
     quantizer = CadenceDefaultQuantizer()

+    # Trace the model
+    ep = trace(model, example_inputs)
+
     # Convert the model
-    converted_model = prepare_and_convert_pt2(model, example_inputs, quantizer)
+    converted_model = prepare_and_convert_pt2(ep, example_inputs, quantizer)

     # Get reference outputs from converted model
     ref_outputs = converted_model(*example_inputs)
