
Extract trace from prepare_and_convert and remove export_program #10493


Merged: 1 commit merged on May 8, 2025
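For orientation, the end-to-end quantization flow after this change looks roughly like the sketch below, assembled from the call sites in this diff (compiler.py and export_example.py). The toy model, example inputs, and import paths are illustrative assumptions rather than part of the PR.

import torch

from executorch.backends.cadence.aot.compiler import (
    fuse_pt2,
    prepare_and_convert_pt2,
    trace,
)
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

# Hypothetical toy model and example inputs, for illustration only.
model = torch.nn.Linear(6, 12, bias=False)
example_inputs = (torch.randn(2, 6),)

quantizer = CadenceDefaultQuantizer()

# Step 1 (new in this PR): trace the model into an ExportedProgram.
ep = trace(model, example_inputs)

# Step 2: prepare/convert now consumes the ExportedProgram instead of an nn.Module.
converted_gm = prepare_and_convert_pt2(ep, example_inputs, quantizer)

# Step 3: fuse the quantized patterns, as before.
quantized_gm = fuse_pt2(converted_gm, quantizer)

quantize_pt2 still wraps these steps for callers that want a single entry point, as the hunk in compiler.py below shows.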
97 changes: 55 additions & 42 deletions backends/cadence/aot/compiler.py
@@ -39,7 +39,6 @@
from torch._inductor.decomposition import remove_decompositions
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export
from torch.export.exported_program import ExportedProgram

from .passes import get_cadence_passes
@@ -55,27 +54,24 @@
# however useful for unit tests to separate the converted model from the fused
# model, to be able to get reference numerics.
# If this does not apply, please use quantize_and_fuse_pt2 instead.
def prepare_and_convert_pt2(
def trace(
model: torch.nn.Module,
inputs: tuple[object, ...],
quantizer: CadenceQuantizer,
calibration_data: Optional[list[tuple[object, ...]]] = None,
dump_graphs: bool = False,
) -> torch.fx.GraphModule:
) -> ExportedProgram:
"""
Prepare and convert a model using the given quantizer.
The quantizer must be supplied and be the same as the one used to
fuse the model later, if applicable. If you do not expect that behavior,
please use quantize_and_fuse_pt2 instead, which will instantiate a
default quantizer for you if needed.
If calibration data is provided, it will be used to calibrate the model. If
not, the inputs will be used for calibration instead, which is useful for
unit tests but should not be used for end-to-end use cases.
Returns a GraphModule with the converted model.
Trace the model with export_for_training and return an ExportedProgram.
"""

# Make the model inference mode by calling model.eval()
model.eval()

# Prevent mkldnn decompositions
torch._C._set_mkldnn_enabled(False)

# Get default decompositions
decomp_table = torch.export.default_decompositions()

# Select ops to keep
ops_to_keep = [
torch.ops.aten.conv1d.default,
@@ -85,19 +81,46 @@ def prepare_and_convert_pt2(
torch.ops.aten.matmul.default,
torch.ops.aten.rms_norm.default,
]

# Remove decompositions for the ops we want to keep
# pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
remove_decompositions(decomp_table, ops_to_keep)

# Export with dynamo
model_gm = (
torch.export.export_for_training(model, inputs, strict=True)
.run_decompositions(decomp_table)
.module()
)
program = torch.export.export_for_training(
model, inputs, strict=True
).run_decompositions(decomp_table)

if dump_graphs:
logging.info("Graph before quantization:")
logging.info(model_gm.graph.print_tabular())
logging.info(program.module().graph.print_tabular())

return program


def prepare_and_convert_pt2(
program: ExportedProgram,
inputs: tuple[object, ...],
quantizer: CadenceQuantizer,
calibration_data: Optional[list[tuple[object, ...]]] = None,
dump_graphs: bool = False,
) -> torch.fx.GraphModule:
"""
Prepare and convert a model using the given quantizer.
The quantizer must be supplied and be the same as the one used to
fuse the model later, if applicable. If you do not expect that behavior,
please use quantize_and_fuse_pt2 instead, which will instantiate a
default quantizer for you if needed.
If calibration data is provided, it will be used to calibrate the model. If
not, the inputs will be used for calibration instead, which is useful for
unit tests but should not be used for end-to-end use cases.
Returns a GraphModule with the converted model.
"""

# Get the graph module from the ExportedProgram
model_gm = program.module()

assert isinstance(model_gm, torch.fx.GraphModule)

# Prepare
prepared_model = prepare_pt2e(model_gm, quantizer)
@@ -121,10 +144,10 @@ def prepare_and_convert_pt2(


# Note: this is not meant as a primary API since it can create inconsistencies
# if the quantizer here is different from the quantizer used to convert. It is
# however useful for unit tests to separate the converted model from the fused
# model, to be able to get reference numerics.
# If this does not apply, please use quantize_and_fuse_pt2 instead.
# if the quantizer here is different from the quantizer used to prepare/convert.
# It is however useful for unit tests to separate the converted model from the
# fused model, to be able to get reference numerics.
# If this does not apply, please use quantize_pt2 instead.
def fuse_pt2(
converted_graph_module: torch.fx.GraphModule,
quantizer: CadenceQuantizer,
@@ -167,9 +190,15 @@ def quantize_pt2(
if not quantizer:
quantizer = CadenceDefaultQuantizer()

program = trace(model, inputs, dump_graphs=dump_graphs)

if dump_graphs:
logging.info("Graph after trace:")
logging.info(program.graph.print_tabular())

# Get converted graph module
converted_gm = prepare_and_convert_pt2(
model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
)

# Get fused model
@@ -184,22 +213,6 @@ def quantize_pt2(
return program


# Export the model and lower it to an ExportedProgram (in aten IR)
def export_program(
model: torch.nn.Module,
inputs: tuple[object, ...],
) -> ExportedProgram:
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

# Prevent mkldnn decompositions
torch._C._set_mkldnn_enabled(False)

# Export the model and return it.
expo_program = export(model, inputs, strict=True)

return expo_program


def _lower_ep_to_edge(
expo_program: ExportedProgram,
dump_graphs: bool = False,
@@ -248,7 +261,7 @@ def export_to_edge(
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

# Export the model into an ExportedProgram.
expo_program = export_program(model, inputs)
expo_program = trace(model, inputs)

# Lower the model to edge IR.
edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
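The body of prepare_and_convert_pt2 after the prepare_pt2e call is collapsed in the hunk above. Going only by its docstring, the prepare/calibrate/convert sequence presumably proceeds along the lines of this sketch; the helper name and the exact calibration loop are assumptions, not the verbatim source.

from typing import Optional

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


def calibrate_and_convert_sketch(
    model_gm: torch.fx.GraphModule,
    quantizer,
    inputs: tuple[object, ...],
    calibration_data: Optional[list[tuple[object, ...]]] = None,
) -> torch.fx.GraphModule:
    # Attach observers according to the quantizer's annotations.
    prepared_model = prepare_pt2e(model_gm, quantizer)

    # Calibrate with the supplied data if provided, otherwise fall back to
    # `inputs` (per the docstring, the fallback is intended for unit tests).
    for samples in calibration_data if calibration_data is not None else [inputs]:
        prepared_model(*samples)

    # Convert the calibrated model to its quantized reference form.
    return convert_pt2e(prepared_model)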
6 changes: 5 additions & 1 deletion backends/cadence/aot/export_example.py
@@ -18,6 +18,7 @@
export_to_executorch_gen_etrecord,
fuse_pt2,
prepare_and_convert_pt2,
trace,
)

from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
@@ -48,8 +49,11 @@ def export_model(
# Instantiate the quantizer
quantizer = CadenceDefaultQuantizer()

# Trace the model
ep = trace(model, example_inputs)

# Convert the model
converted_model = prepare_and_convert_pt2(model, example_inputs, quantizer)
converted_model = prepare_and_convert_pt2(ep, example_inputs, quantizer)

# Get reference outputs from converted model
ref_outputs = converted_model(*example_inputs)
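The ref_outputs line above relies on prepare_and_convert_pt2 returning an eagerly executable GraphModule. A quick sanity check one could layer on top of it, continuing the snippet above, might read as follows; this is not part of the PR and the tolerance is an arbitrary assumption.

# Illustrative only: compare the converted (fake-quantized) outputs against
# the float model. The atol value is an arbitrary choice.
float_outputs = model(*example_inputs)
assert torch.allclose(float_outputs, ref_outputs, atol=1e-1)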
68 changes: 37 additions & 31 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -16,10 +16,10 @@
import torch.nn.functional as F
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge
from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
from executorch.backends.cadence.aot.graph_builder import GraphBuilder

from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
from executorch.backends.cadence.aot.remove_ops import (
RemoveAliasCopyOpPass,
RemoveBranchedQuantDequant,
@@ -42,9 +42,6 @@
from parameterized.parameterized import parameterized
from pyre_extensions import none_throws

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export_for_training
from torch.fx.passes.infra.pass_base import PassResult


@@ -459,44 +456,53 @@ def forward(self, x, y):
)

def test_remove_nop_quant_dequant(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.linear = torch.nn.Linear(6, 12, bias=False)
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(8, 8))
q0 = builder.call_operator(
op=exir_ops.edge.cadence.quantize_per_tensor.default,
args=(x, 0.01662161760032177, -4, -128, 127, torch.int8),
)
dq0 = builder.call_operator(
op=exir_ops.edge.cadence.dequantize_per_tensor.default,
args=(q0, 0.01662161760032177, -4, -128, 127, torch.int8),
)
q1 = builder.call_operator(
op=exir_ops.edge.cadence.quantize_per_tensor.default,
args=(x, 0.012577153742313385, -9, -128, 127, torch.int8),
)
builder.output([dq0, q1])
graph_module = builder.get_graph_module()

def forward(self, x):
x = self.linear(x)
return x
# Expect the dq op to be removed by the pass
self.assertEqual(
count_node(
graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default
),
1,
)

inp = torch.randn(2, 8, 1, 6)
# Expect 1 quantize op left since it has no matching dequant
self.assertEqual(
count_node(graph_module, exir_ops.edge.cadence.quantize_per_tensor.default),
2,
)

# Run the standard quant/convert steps, but without fusing
# this leaves two redundant quant/dequant pairs to test with
quantizer = CadenceDefaultQuantizer()
model_exp = export_for_training(M(), (inp,), strict=True).module()
prepared_model = prepare_pt2e(model_exp, quantizer)
prepared_model(inp)
converted_model = convert_pt2e(prepared_model)
p = FuseQuantDequantToRequantizePass()

graph_module = (
compiler.export_to_cadence(
converted_model,
(inp,),
)
.exported_program()
.graph_module
)
graph_after_passes = cast(PassResult, p(graph_module)).graph_module

# Expect all quantize ops to be removed by the pass
# Expect the dq op to be removed by the pass
self.assertEqual(
count_node(graph_module, exir_ops.edge.cadence.quantize_per_tensor.default),
count_node(
graph_after_passes, exir_ops.edge.cadence.dequantize_per_tensor.default
),
0,
)

# Expect 1 dequantize op for the weights
# Expect 1 quantize op left since it has no matching dequant
self.assertEqual(
count_node(
graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default
graph_after_passes, exir_ops.edge.cadence.quantize_per_tensor.default
),
1,
)
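For context on what the rewritten test exercises: a quantize followed by a dequantize with identical scale and zero point is numerically a no-op up to rounding, which is why the pass is expected to drop the q0/dq0 pair while leaving the unmatched q1 in place. A standalone illustration with plain PyTorch quantization ops follows; the input range and the error bound are assumptions, and the test itself uses the Cadence edge-dialect ops rather than these.

import torch

# Scale and zero point taken from the q0/dq0 pair in the test above.
scale, zero_point = 0.01662161760032177, -4

# Keep the input inside the representable range so clamping does not dominate.
x = torch.rand(8, 8)

q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
dq = q.dequantize()

# The round trip only introduces rounding error bounded by one quantization step.
assert (x - dq).abs().max() <= scale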