diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 3d43ca2956e..594c4189b3a 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -39,7 +39,6 @@
 from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from torch.export import export
 from torch.export.exported_program import ExportedProgram
 
 from .passes import get_cadence_passes
@@ -55,27 +54,24 @@
 # however useful for unit tests to separate the converted model from the fused
 # model, to be able to get reference numerics.
 # If this does not apply, please use quantize_and_fuse_pt2 instead.
-def prepare_and_convert_pt2(
+def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
-    Prepare and convert a model using the given quantizer.
-    The quantizer must be supplied and be the same as the one used to
-    fuse the model later, if applicable. If you do not expect that behavior,
-    please use quantize_and_fuse_pt2 instead, which will instantiate a
-    default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Trace the model with export_for_training and return an ExportedProgram.
     """
 
+    # Put the model in inference mode by calling model.eval()
+    model.eval()
+
+    # Prevent mkldnn decompositions
+    torch._C._set_mkldnn_enabled(False)
+
     # Get default decompositions
     decomp_table = torch.export.default_decompositions()
+
     # Select ops to keep
     ops_to_keep = [
         torch.ops.aten.conv1d.default,
@@ -85,19 +81,46 @@ def prepare_and_convert_pt2(
         torch.ops.aten.matmul.default,
         torch.ops.aten.rms_norm.default,
     ]
+
     # Remove decompositions for the ops we want to keep
     # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
     remove_decompositions(decomp_table, ops_to_keep)
+
     # Export with dynamo
-    model_gm = (
-        torch.export.export_for_training(model, inputs, strict=True)
-        .run_decompositions(decomp_table)
-        .module()
-    )
+    program = torch.export.export_for_training(
+        model, inputs, strict=True
+    ).run_decompositions(decomp_table)
 
     if dump_graphs:
         logging.info("Graph before quantization:")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info(program.module().graph.print_tabular())
+
+    return program
+
+
+def prepare_and_convert_pt2(
+    program: ExportedProgram,
+    inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Prepare and convert a model using the given quantizer.
+    The quantizer must be supplied and be the same as the one used to
+    fuse the model later, if applicable. If you do not expect that behavior,
+    please use quantize_and_fuse_pt2 instead, which will instantiate a
+    default quantizer for you if needed.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the converted model.
+ """ + + # Get the graph module from the ExportedProgram + model_gm = program.module() + + assert isinstance(model_gm, torch.fx.GraphModule) # Prepare prepared_model = prepare_pt2e(model_gm, quantizer) @@ -121,10 +144,10 @@ def prepare_and_convert_pt2( # Note: this is not meant as a primary API since it can create inconsistencies -# if the quantizer here is different from the quantizer used to convert. It is -# however useful for unit tests to separate the converted model from the fused -# model, to be able to get reference numerics. -# If this does not apply, please use quantize_and_fuse_pt2 instead. +# if the quantizer here is different from the quantizer used to prepare/convert. +# It is however useful for unit tests to separate the converted model from the +# fused model, to be able to get reference numerics. +# If this does not apply, please use quantize_pt2 instead. def fuse_pt2( converted_graph_module: torch.fx.GraphModule, quantizer: CadenceQuantizer, @@ -167,9 +190,15 @@ def quantize_pt2( if not quantizer: quantizer = CadenceDefaultQuantizer() + program = trace(model, inputs, dump_graphs=dump_graphs) + + if dump_graphs: + logging.info("Graph after trace:") + logging.info(program.graph.print_tabular()) + # Get converted graph module converted_gm = prepare_and_convert_pt2( - model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs + program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs ) # Get fused model @@ -184,22 +213,6 @@ def quantize_pt2( return program -# Export the model and lower it to an ExportedProgram (in aten IR) -def export_program( - model: torch.nn.Module, - inputs: tuple[object, ...], -) -> ExportedProgram: - assert isinstance(model, torch.nn.Module), "model should be an nn.Module" - - # Prevent mkldnn decompositions - torch._C._set_mkldnn_enabled(False) - - # Export the model and return it. - expo_program = export(model, inputs, strict=True) - - return expo_program - - def _lower_ep_to_edge( expo_program: ExportedProgram, dump_graphs: bool = False, @@ -248,7 +261,7 @@ def export_to_edge( assert isinstance(model, torch.nn.Module), "model should be an nn.Module" # Export the model into an ExportedProgram. - expo_program = export_program(model, inputs) + expo_program = trace(model, inputs) # Lower the model to edge IR. 
     edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py
index d2148870e53..6eaead7105e 100644
--- a/backends/cadence/aot/export_example.py
+++ b/backends/cadence/aot/export_example.py
@@ -18,6 +18,7 @@
     export_to_executorch_gen_etrecord,
     fuse_pt2,
     prepare_and_convert_pt2,
+    trace,
 )
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 
@@ -48,8 +49,11 @@ def export_model(
     # Instantiate the quantizer
     quantizer = CadenceDefaultQuantizer()
 
+    # Trace the model
+    ep = trace(model, example_inputs)
+
     # Convert the model
-    converted_model = prepare_and_convert_pt2(model, example_inputs, quantizer)
+    converted_model = prepare_and_convert_pt2(ep, example_inputs, quantizer)
 
     # Get reference outputs from converted model
     ref_outputs = converted_model(*example_inputs)
diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py
index 8caba7799b5..74c39ae3ee3 100644
--- a/backends/cadence/aot/tests/test_remove_ops_passes.py
+++ b/backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -16,10 +16,10 @@
 import torch.nn.functional as F
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.compiler import export_to_edge
+from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
 
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
 from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.aot.remove_ops import (
     RemoveAliasCopyOpPass,
     RemoveBranchedQuantDequant,
@@ -42,9 +42,6 @@
 from parameterized.parameterized import parameterized
 from pyre_extensions import none_throws
 
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
-
-from torch.export import export_for_training
 from torch.fx.passes.infra.pass_base import PassResult
 
 
@@ -459,44 +456,53 @@ def forward(self, x, y):
         )
 
     def test_remove_nop_quant_dequant(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.linear = torch.nn.Linear(6, 12, bias=False)
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 8))
+        q0 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantize_per_tensor.default,
+            args=(x, 0.01662161760032177, -4, -128, 127, torch.int8),
+        )
+        dq0 = builder.call_operator(
+            op=exir_ops.edge.cadence.dequantize_per_tensor.default,
+            args=(q0, 0.01662161760032177, -4, -128, 127, torch.int8),
+        )
+        q1 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantize_per_tensor.default,
+            args=(x, 0.012577153742313385, -9, -128, 127, torch.int8),
+        )
+        builder.output([dq0, q1])
+        graph_module = builder.get_graph_module()
 
-            def forward(self, x):
-                x = self.linear(x)
-                return x
+        # Before the pass, the nop quant/dequant pair contributes 1 dequantize op
+        self.assertEqual(
+            count_node(
+                graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default
+            ),
+            1,
+        )
 
-        inp = torch.randn(2, 8, 1, 6)
+        # Before the pass, both quantize ops are still present
+        self.assertEqual(
+            count_node(graph_module, exir_ops.edge.cadence.quantize_per_tensor.default),
+            2,
+        )
 
-        # Run the standard quant/convert steps, but without fusing
-        # this leaves two redundant quant/dequant pairs to test with
-        quantizer = CadenceDefaultQuantizer()
-        model_exp = export_for_training(M(), (inp,), strict=True).module()
-        prepared_model = prepare_pt2e(model_exp, quantizer)
-        prepared_model(inp)
-        converted_model = convert_pt2e(prepared_model)
+        p = FuseQuantDequantToRequantizePass()
 
-        graph_module = (
-            compiler.export_to_cadence(
-                converted_model,
-                (inp,),
-            )
-            .exported_program()
-            .graph_module
-        )
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
 
-        # Expect all quantize ops to be removed by the pass
+        # Expect the dq op to be removed by the pass
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.cadence.quantize_per_tensor.default),
+            count_node(
+                graph_after_passes, exir_ops.edge.cadence.dequantize_per_tensor.default
+            ),
             0,
         )
 
-        # Expect 1 dequantize op for the weights
+        # Expect 1 quantize op left since it has no matching dequant
         self.assertEqual(
             count_node(
-                graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default
+                graph_after_passes, exir_ops.edge.cadence.quantize_per_tensor.default
             ),
             1,
         )
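
Usage sketch (illustrative, not part of the patch): with this change, tracing is decoupled from quantization, so callers chain trace, prepare_and_convert_pt2, and fuse_pt2 explicitly, as export_example.py and quantize_pt2 do above. The sketch below only reuses the signatures introduced in the diff; the Linear model and input shapes are placeholders, and calibration_data is omitted so the example inputs double as calibration data, as the new docstring describes.

import torch

from executorch.backends.cadence.aot.compiler import (
    fuse_pt2,
    prepare_and_convert_pt2,
    trace,
)
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

# Placeholder model and inputs, for illustration only.
model = torch.nn.Linear(16, 32)
example_inputs = (torch.randn(1, 16),)

quantizer = CadenceDefaultQuantizer()

# 1. Trace into an ExportedProgram (aten IR, with the selected ops kept un-decomposed).
ep = trace(model, example_inputs)

# 2. Prepare, calibrate (here with the example inputs), and convert to a quantized GraphModule.
converted_gm = prepare_and_convert_pt2(ep, example_inputs, quantizer)

# 3. Fuse the quantized patterns, using the same quantizer that was used to convert.
fused_gm = fuse_pt2(converted_gm, quantizer)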