
Commit 11de71d

mcremon-meta authored and facebook-github-bot committed
Extract trace from prepare_and_convert and remove export_program (#10493)
Summary: Rewrite a test to use the graph builder and avoid API issues. As titled. Will be used in later changes to fix some inconsistencies.

Reviewed By: zonglinpeng

Differential Revision: D73440517
1 parent 1656854 commit 11de71d
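In practice, callers now trace first and pass the resulting ExportedProgram to prepare_and_convert_pt2, as the export_example.py diff below shows. A minimal sketch of the new call sequence, assuming the functions are imported from the compiler module touched in this commit (the toy model and example inputs are illustrative only):

    import torch

    from executorch.backends.cadence.aot.compiler import (
        fuse_pt2,
        prepare_and_convert_pt2,
        trace,
    )
    from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

    # Illustrative toy model and example inputs.
    model = torch.nn.Linear(6, 12, bias=False)
    example_inputs = (torch.randn(2, 6),)

    quantizer = CadenceDefaultQuantizer()

    # Step 1 (new): trace the model into an ExportedProgram.
    ep = trace(model, example_inputs)

    # Step 2: prepare/convert now takes the ExportedProgram instead of the nn.Module.
    converted_gm = prepare_and_convert_pt2(ep, example_inputs, quantizer)

    # Step 3: fuse with the same quantizer, as before.
    fused_gm = fuse_pt2(converted_gm, quantizer)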

File tree

backends/cadence/aot/compiler.py
backends/cadence/aot/export_example.py
backends/cadence/aot/tests/test_remove_ops_passes.py

3 files changed: +92 −74 lines changed

backends/cadence/aot/compiler.py (+55 −42)

@@ -39,7 +39,6 @@
 from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from torch.export import export
 from torch.export.exported_program import ExportedProgram
 
 from .passes import get_cadence_passes
@@ -55,27 +54,24 @@
 # however useful for unit tests to separate the converted model from the fused
 # model, to be able to get reference numerics.
 # If this does not apply, please use quantize_and_fuse_pt2 instead.
-def prepare_and_convert_pt2(
+def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
-    Prepare and convert a model using the given quantizer.
-    The quantizer must be supplied and be the same as the one used to
-    fuse the model later, if applicable. If you do not expect that behavior,
-    please use quantize_and_fuse_pt2 instead, which will instantiate a
-    default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Trace the model with export_for_training and return an ExportedProgram.
     """
 
+    # Make the model inference mode by calling model.eval()
+    model.eval()
+
+    # Prevent mkldnn decompositions
+    torch._C._set_mkldnn_enabled(False)
+
     # Get default decompositions
     decomp_table = torch.export.default_decompositions()
+
     # Select ops to keep
     ops_to_keep = [
         torch.ops.aten.conv1d.default,
@@ -85,19 +81,46 @@ def prepare_and_convert_pt2(
         torch.ops.aten.matmul.default,
         torch.ops.aten.rms_norm.default,
     ]
+
     # Remove decompositions for the ops we want to keep
     # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
     remove_decompositions(decomp_table, ops_to_keep)
+
     # Export with dynamo
-    model_gm = (
-        torch.export.export_for_training(model, inputs, strict=True)
-        .run_decompositions(decomp_table)
-        .module()
-    )
+    program = torch.export.export_for_training(
+        model, inputs, strict=True
+    ).run_decompositions(decomp_table)
 
     if dump_graphs:
         logging.info("Graph before quantization:")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info(program.module().graph.print_tabular())
+
+    return program
+
+
+def prepare_and_convert_pt2(
+    program: ExportedProgram,
+    inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Prepare and convert a model using the given quantizer.
+    The quantizer must be supplied and be the same as the one used to
+    fuse the model later, if applicable. If you do not expect that behavior,
+    please use quantize_and_fuse_pt2 instead, which will instantiate a
+    default quantizer for you if needed.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the converted model.
+    """
+
+    # Get the graph module from the ExportedProgram
+    model_gm = program.module()
+
+    assert isinstance(model_gm, torch.fx.GraphModule)
 
     # Prepare
     prepared_model = prepare_pt2e(model_gm, quantizer)
@@ -121,10 +144,10 @@ def prepare_and_convert_pt2(
 
 
 # Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_and_fuse_pt2 instead.
+# if the quantizer here is different from the quantizer used to prepare/convert.
+# It is however useful for unit tests to separate the converted model from the
+# fused model, to be able to get reference numerics.
+# If this does not apply, please use quantize_pt2 instead.
 def fuse_pt2(
     converted_graph_module: torch.fx.GraphModule,
     quantizer: CadenceQuantizer,
@@ -167,9 +190,15 @@ def quantize_pt2(
     if not quantizer:
         quantizer = CadenceDefaultQuantizer()
 
+    program = trace(model, inputs, dump_graphs=dump_graphs)
+
+    if dump_graphs:
+        logging.info("Graph after trace:")
+        logging.info(program.graph.print_tabular())
+
     # Get converted graph module
     converted_gm = prepare_and_convert_pt2(
-        model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
+        program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
     )
 
     # Get fused model
@@ -184,22 +213,6 @@ def quantize_pt2(
     return program
 
 
-# Export the model and lower it to an ExportedProgram (in aten IR)
-def export_program(
-    model: torch.nn.Module,
-    inputs: tuple[object, ...],
-) -> ExportedProgram:
-    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
-
-    # Prevent mkldnn decompositions
-    torch._C._set_mkldnn_enabled(False)
-
-    # Export the model and return it.
-    expo_program = export(model, inputs, strict=True)
-
-    return expo_program
-
-
 def _lower_ep_to_edge(
     expo_program: ExportedProgram,
     dump_graphs: bool = False,
@@ -248,7 +261,7 @@ def export_to_edge(
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs)
+    expo_program = trace(model, inputs)
 
     # Lower the model to edge IR.
     edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
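With export_program removed, export_to_edge (and any caller that used export_program directly) goes through trace instead, which additionally calls model.eval(), disables mkldnn decompositions, and keeps the retained ops (conv1d, matmul, rms_norm, etc.) out of the decomposition table. A minimal migration sketch, with an illustrative model and inputs:

    import torch

    from executorch.backends.cadence.aot.compiler import trace

    model = torch.nn.Sequential(torch.nn.Conv1d(3, 8, kernel_size=3), torch.nn.ReLU())
    inputs = (torch.randn(1, 3, 16),)

    # Before this commit: expo_program = export_program(model, inputs)
    # After: trace returns an ExportedProgram with run_decompositions() already applied.
    expo_program = trace(model, inputs)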

backends/cadence/aot/export_example.py (+5 −1)

@@ -18,6 +18,7 @@
     export_to_executorch_gen_etrecord,
     fuse_pt2,
     prepare_and_convert_pt2,
+    trace,
 )
 
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
@@ -48,8 +49,11 @@ def export_model(
     # Instantiate the quantizer
    quantizer = CadenceDefaultQuantizer()
 
+    # Trace the model
+    ep = trace(model, example_inputs)
+
     # Convert the model
-    converted_model = prepare_and_convert_pt2(model, example_inputs, quantizer)
+    converted_model = prepare_and_convert_pt2(ep, example_inputs, quantizer)
 
     # Get reference outputs from converted model
     ref_outputs = converted_model(*example_inputs)

backends/cadence/aot/tests/test_remove_ops_passes.py (+32 −31)

@@ -16,10 +16,10 @@
 import torch.nn.functional as F
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.compiler import export_to_edge
+from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
 
 from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.aot.remove_ops import (
     RemoveAliasCopyOpPass,
     RemoveBranchedQuantDequant,
@@ -42,9 +42,6 @@
 from parameterized.parameterized import parameterized
 from pyre_extensions import none_throws
 
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
-
-from torch.export import export_for_training
 from torch.fx.passes.infra.pass_base import PassResult
 
 
@@ -459,44 +456,48 @@ def forward(self, x, y):
         )
 
     def test_remove_nop_quant_dequant(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 8))
+        q0 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantize_per_tensor.default, args=(x, 0.01662161760032177, -4, -128, 127, torch.int8)
+        )
+        dq0 = builder.call_operator(
+            op=exir_ops.edge.cadence.dequantize_per_tensor.default, args=(q0, 0.01662161760032177, -4, -128, 127, torch.int8)
+        )
+        q1 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantize_per_tensor.default, args=(x, 0.012577153742313385, -9, -128, 127, torch.int8)
+        )
+        builder.output([dq0, q1])
+        graph_module = builder.get_graph_module()
 
-            def forward(self, x):
-                x = self.linear(x)
-                return x
+        # Expect the dq op to be removed by the pass
+        self.assertEqual(
+            count_node(graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default),
+            1,
+        )
 
-        inp = torch.randn(2, 8, 1, 6)
+        # Expect 1 quantize op left since it has no matching dequant
+        self.assertEqual(
+            count_node(
+                graph_module, exir_ops.edge.cadence.quantize_per_tensor.default
+            ),
+            2,
+        )
 
-        # Run the standard quant/convert steps, but without fusing
-        # this leaves two redundant quant/dequant pairs to test with
-        quantizer = CadenceDefaultQuantizer()
-        model_exp = export_for_training(M(), (inp,), strict=True).module()
-        prepared_model = prepare_pt2e(model_exp, quantizer)
-        prepared_model(inp)
-        converted_model = convert_pt2e(prepared_model)
+        p = FuseQuantDequantToRequantizePass()
 
-        graph_module = (
-            compiler.export_to_cadence(
-                converted_model,
-                (inp,),
-            )
-            .exported_program()
-            .graph_module
-        )
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
 
-        # Expect all quantize ops to be removed by the pass
+        # Expect the dq op to be removed by the pass
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.cadence.quantize_per_tensor.default),
+            count_node(graph_after_passes, exir_ops.edge.cadence.dequantize_per_tensor.default),
             0,
         )
 
-        # Expect 1 dequantize op for the weights
+        # Expect 1 quantize op left since it has no matching dequant
         self.assertEqual(
             count_node(
-                graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default
+                graph_after_passes, exir_ops.edge.cadence.quantize_per_tensor.default
             ),
             1,
         )
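The rewritten test illustrates the pattern this commit moves toward: build a small edge-dialect graph directly with GraphBuilder, run a single pass on it, and count the remaining ops, instead of going through the quantizer and export APIs. A condensed sketch of that pattern outside a TestCase, mirroring the assertions above (the exir_ops import path and the numeric quantization parameters here are assumptions for illustration):

    from typing import cast

    import torch

    from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
    from executorch.backends.cadence.aot.graph_builder import GraphBuilder
    from executorch.backends.cadence.aot.pass_utils import count_node
    from executorch.exir.dialects._ops import ops as exir_ops
    from torch.fx.passes.infra.pass_base import PassResult

    # Build a tiny graph containing one nop quantize/dequantize pair
    # (same scale, zero point, and dtype on both ops).
    builder = GraphBuilder()
    x = builder.placeholder("x", torch.randn(4, 4))
    q = builder.call_operator(
        op=exir_ops.edge.cadence.quantize_per_tensor.default,
        args=(x, 0.1, 0, -128, 127, torch.int8),
    )
    dq = builder.call_operator(
        op=exir_ops.edge.cadence.dequantize_per_tensor.default,
        args=(q, 0.1, 0, -128, 127, torch.int8),
    )
    builder.output([dq])
    graph_module = builder.get_graph_module()

    # Run the pass and check that the dequantize of the matching pair is gone,
    # as in the rewritten test above.
    result = cast(PassResult, FuseQuantDequantToRequantizePass()(graph_module))
    graph_after_passes = result.graph_module
    assert count_node(graph_after_passes, exir_ops.edge.cadence.dequantize_per_tensor.default) == 0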
