diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 5c66c9eb62b..ee1da44f391 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -152,6 +152,7 @@ def convert_pt2(
 def apply_pre_edge_transform_passes(
     converted_program: ExportedProgram,
     quantizer: CadenceQuantizer,
+    is_qat: bool = False,
 ) -> ExportedProgram:
     """
     Apply pre-edge transform passes including QuantFusion and torch ops passes.
@@ -166,12 +167,11 @@ def apply_pre_edge_transform_passes(
     """
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
-    PassManager(
-        [
-            FuseQATConvBN(converted_program),
-            QuantFusion(patterns),
-        ]
-    )(converted_program.graph_module)
+    passes = []
+    if is_qat:
+        passes.append(FuseQATConvBN(converted_program))
+    passes.append(QuantFusion(patterns))
+    PassManager(passes)(converted_program.graph_module)
 
     # Apply torch ops passes (e.g., ReplaceMulTensorWithMulAndFullOpsPass)
     fused_program = apply_torch_ops_passes(converted_program)
@@ -187,19 +187,24 @@ def get_fake_quant_model(
     quantizer: CadenceQuantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
+    is_qat: bool = False,
 ) -> torch.fx.GraphModule:
     # Make the model inference mode by calling model.eval()
     model.eval()
 
     ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
 
-    program = trace(model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep)
+    program = trace(
+        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep, is_qat=is_qat
+    )
 
     if dump_graphs:
         logging.info("Graph after trace:")
         logging.info(program.graph.print_tabular())
     # Get prepared graph module
-    prepared_gm = prepare_traced_pt2(program, quantizer, dump_graphs=dump_graphs)
+    prepared_gm = prepare_traced_pt2(
+        program, quantizer, dump_graphs=dump_graphs, is_qat=is_qat
+    )
 
     # Calibrate
     # If no calibration data is provided, use the inputs
@@ -221,6 +226,7 @@ def quantize_pt2(
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
     quant_input_args: Optional[list[str]] = None,
+    is_qat: bool = False,
 ) -> ExportedProgram:
     """
     Trace, prepare, convert and fuse the model using the given quantizer.
@@ -242,6 +248,7 @@ def quantize_pt2(
         quantizer=quantizer,
         calibration_data=calibration_data,
         dump_graphs=dump_graphs,
+        is_qat=is_qat,
     )
     # Wrap the model to handle quantized inputs if provided
     if quant_input_args is not None:
@@ -254,7 +261,7 @@ def quantize_pt2(
     if quant_input_args is not None:
         QuantizedInputWrapper.sink_dequants(program)
 
-    fused_program = apply_pre_edge_transform_passes(program, quantizer)
+    fused_program = apply_pre_edge_transform_passes(program, quantizer, is_qat=is_qat)
 
     if dump_graphs:
         logging.info("Graph after quantization and fusion:")