From 3c17a1d53d79d966805f2a67fcecbfea4f79f0db Mon Sep 17 00:00:00 2001
From: Ethan Ng
Date: Thu, 14 May 2026 11:33:43 -0700
Subject: [PATCH] Gate FuseQATConvBN behind is_qat=True; opt in from QAT
 deployments
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
The FuseQATConvBN pass added in D104497938 ran unconditionally inside
`apply_pre_edge_transform_passes`. Its `_prep_conv_biases` step delegates to
the shared `_quantize_fused_conv_bias` helper, which iterates over every conv
in the graph and asserts that each conv input is `dequantize_per_tensor`, an
invariant that only holds inside the conv-BN simulation chain that
`prepare_qat_pt2e` inserts. PTQ graphs trip the assert (T271158088).

Two failure modes seen in the wild:
- `test_quantized_w8a32_conv1d_out_2` uses `CadenceW8A32MixedQuantizer`, so
  activations stay float32; the conv input is the placeholder, not a dequant.
- `test_conv2d_out_7` runs with `channel_last=True`, so the conv input is
  `aten.permute`, not a dequant; the helper only unwraps `unsqueeze` variants.

Add an `is_qat: bool = False` parameter to `apply_pre_edge_transform_passes`
and only include `FuseQATConvBN` when it is True. Plumb the flag through
`quantize_pt2`/`get_fake_quant_model` and forward it from the modai recipe
lambda so the `ar_*_qat_et_recipe` factories actually opt in; a caller sketch
follows below.

QAT-trained models lowered via blobgen need a way to reach the QAT recipe.
Add `is_qat: bool` to `Packaging` and have `Rt700Hifi4Deployment` pass
`train=self.packaging.is_qat` to `get_recipe_with_custom_settings`. Models
like `activity_classification_artemis` should set `"is_qat": True` in their
`defs.bzl` packaging block; a hedged sketch of that block also follows.
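For illustration, a minimal sketch of the new opt-in from a caller's side.
Only `is_qat=True` is introduced by this diff; `qat_trained_model`,
`example_inputs`, and the quantizer value are hypothetical stand-ins, not
code from this change:

    # Hypothetical QAT call site (Python). Argument names other than is_qat
    # match the existing quantize_pt2 signature; the values are placeholders.
    program = quantize_pt2(
        qat_trained_model,  # assumed: a model prepared via prepare_qat_pt2e
        example_inputs,     # assumed: tuple of example input tensors
        quantizer,          # assumed: a CadenceQuantizer instance
        is_qat=True,        # new in this diff; opts in to FuseQATConvBN
    )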
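Likewise, a hedged sketch of the packaging opt-in. The real `Packaging`
schema is not part of this diff, so everything except the `"is_qat"` key
below is an assumed illustration:

    # defs.bzl (hypothetical shape): only the "is_qat" key is prescribed by
    # this change; the surrounding dict structure is assumed.
    packaging = {
        # ... existing packaging fields for the model ...
        "is_qat": True,  # Rt700Hifi4Deployment forwards this as train= to
                         # get_recipe_with_custom_settings
    }
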
Differential Revision: D105061752
---
 backends/cadence/aot/compiler.py | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 5c66c9eb62b..ee1da44f391 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -152,6 +152,7 @@ def convert_pt2(
 def apply_pre_edge_transform_passes(
     converted_program: ExportedProgram,
     quantizer: CadenceQuantizer,
+    is_qat: bool = False,
 ) -> ExportedProgram:
     """
     Apply pre-edge transform passes including QuantFusion and torch ops passes.
@@ -166,12 +167,11 @@ def apply_pre_edge_transform_passes(
     """
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
-    PassManager(
-        [
-            FuseQATConvBN(converted_program),
-            QuantFusion(patterns),
-        ]
-    )(converted_program.graph_module)
+    passes = []
+    if is_qat:
+        passes.append(FuseQATConvBN(converted_program))
+    passes.append(QuantFusion(patterns))
+    PassManager(passes)(converted_program.graph_module)
 
     # Apply torch ops passes (e.g., ReplaceMulTensorWithMulAndFullOpsPass)
     fused_program = apply_torch_ops_passes(converted_program)
@@ -187,19 +187,24 @@ def get_fake_quant_model(
     quantizer: CadenceQuantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
+    is_qat: bool = False,
 ) -> torch.fx.GraphModule:
     # Make the model inference mode by calling model.eval()
     model.eval()
 
     ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
-    program = trace(model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep)
+    program = trace(
+        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep, is_qat=is_qat
+    )
 
     if dump_graphs:
         logging.info("Graph after trace:")
         logging.info(program.graph.print_tabular())
 
     # Get prepared graph module
-    prepared_gm = prepare_traced_pt2(program, quantizer, dump_graphs=dump_graphs)
+    prepared_gm = prepare_traced_pt2(
+        program, quantizer, dump_graphs=dump_graphs, is_qat=is_qat
+    )
 
     # Calibrate
     # If no calibration data is provided, use the inputs
@@ -221,6 +226,7 @@ def quantize_pt2(
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
     quant_input_args: Optional[list[str]] = None,
+    is_qat: bool = False,
 ) -> ExportedProgram:
     """
     Trace, prepare, convert and fuse the model using the given quantizer.
@@ -242,6 +248,7 @@ def quantize_pt2(
         quantizer=quantizer,
         calibration_data=calibration_data,
         dump_graphs=dump_graphs,
+        is_qat=is_qat,
     )
     # Wrap the model to handle quantized inputs if provided
     if quant_input_args is not None:
@@ -254,7 +261,7 @@ def quantize_pt2(
     if quant_input_args is not None:
         QuantizedInputWrapper.sink_dequants(program)
 
-    fused_program = apply_pre_edge_transform_passes(program, quantizer)
+    fused_program = apply_pre_edge_transform_passes(program, quantizer, is_qat=is_qat)
 
     if dump_graphs:
         logging.info("Graph after quantization and fusion:")