From 3407075b7573ef31342565436c14bba58ffdf299 Mon Sep 17 00:00:00 2001 From: Jiwon Shin Date: Mon, 20 Apr 2026 12:42:57 -0700 Subject: [PATCH] [QwixOdmlQuantizationBoundary] Fix MLIR assertion failure and metadata leakage in U-Net quantization. This change stabilizes the Qwix ODML quantization pipeline by enforcing strict boundaries between quantized and floating-point regions. Key changes: - Refactored metadata isolation in odml_ops.py by introducing `_copy_for_isolation` to explicitly copy arrays and port structural metadata, preventing rule leakage across shared branches. - Hardened `_maybe_fake_quant` to explicitly handle non-weight tensors (activations and constants) and clarified activation rule sharing for constants. - Added detailed inline comments mapping valid input combinations and guarantees for operation inputs. - Enforced that `rule is None` or `rule.act_qtype is None` in `_maybe_fake_quant` means full precision (FP): the input is returned as a copy to avoid sharing metadata. - Updated FinalOutput to preserve intended delayed quantization behavior for model outputs. PiperOrigin-RevId: 902774410 --- integration_tests/odml_coverage_test.py | 3 +- qwix/_src/providers/odml_ops.py | 95 ++++++++++++++++++------- tests/_src/providers/odml_test.py | 35 +++++++++ 3 files changed, 105 insertions(+), 28 deletions(-) diff --git a/integration_tests/odml_coverage_test.py b/integration_tests/odml_coverage_test.py index 646618bf..f0dc050e 100644 --- a/integration_tests/odml_coverage_test.py +++ b/integration_tests/odml_coverage_test.py @@ -253,7 +253,8 @@ def create_input(self): 'Dense_0/dot_general0_lhs', # No Dense_0/add0_input0 because add is fused. 'Dense_1/dot_general0_lhs', - # No add0_lhs because it's quantized as Dense_0/dot_general0_lhs. + # add0_lhs is collected now due to isolation. 
+ 'add0_lhs', 'add0_rhs', 'final_output0', } diff --git a/qwix/_src/providers/odml_ops.py b/qwix/_src/providers/odml_ops.py index f6131e03..ead87038 100644 --- a/qwix/_src/providers/odml_ops.py +++ b/qwix/_src/providers/odml_ops.py @@ -182,6 +182,32 @@ class AuxDataKey(str, enum.Enum): FakeQuantFn = Callable[[jax.Array, qarray.HowToQuantize, str | None], jax.Array] +def _copy_for_isolation(original_array: jax.Array) -> jax.Array: + """Creates a copy of the array to isolate it from other branches. + + This is used to prevent quantization metadata (_FQ_RULE, _FQ_ARRAY) from + leaking across shared branches. It ports essential structural metadata + but leaves the copy in a clean, unquantized state. + + Args: + original_array: The array to copy. + + Returns: + A copy of the array with some metadata preserved. + """ + array_copy = jnp.array(original_array, copy=True) + # We deliberately do NOT copy _FQ_RULE or _FQ_ARRAY to ensure metadata + # isolation and prevent rule leakage across branches. + if aux_data.get(original_array, AuxDataKey.IS_ACTIVATION, False): + aux_data.set(array_copy, AuxDataKey.IS_ACTIVATION, True) + fixed_range = aux_data.get(original_array, AuxDataKey.FIXED_RANGE, None) + if fixed_range is not None: + aux_data.set(array_copy, AuxDataKey.FIXED_RANGE, fixed_range) + if aux_data.get(original_array, AuxDataKey.ALLOW_FUSION, False): + aux_data.set(array_copy, AuxDataKey.ALLOW_FUSION, True) + return array_copy + + class QuantizedOp: """A generic quantized op that allows different scales for inputs and output. @@ -271,10 +297,15 @@ def _fake_quant_inputs( """Fake quantize the inputs of the op.""" args = list(args) if len(self.input_idx) == 1: + # Guaranteed to be activation because _inputs_have_activations was True. 
idx = self.input_idx[0] args[idx] = self._maybe_fake_quant(args[idx], rule, op_id) elif len(self.input_idx) == 2: lhs, rhs = tuple(self.input_idx) # pylint: disable=unbalanced-tuple-unpacking + # Possible combinations at this point (since at least one is activation): + # 1. Activation / Weight + # 2. Activation / Activation + # 3. Activation / Constant (Non-Weight, Non-Activation) # Binary ops could have non-array args, e.g. x + 1. if isinstance(args[lhs], jax.Array): args[lhs] = self._maybe_fake_quant(args[lhs], rule, op_id + '_lhs') @@ -294,8 +325,8 @@ def _maybe_fake_quant( ) -> jax.Array: """Fake quantize the array based on the given rule. - This function assumes the array is an activation, unless it has weight_name - aux_data, e.g., in jnp.take. + This function assumes the array is a non-weight tensor (activation or + constant), unless it has weight_name aux_data, e.g., in jnp.take. Args: array: The array to quantize. @@ -335,44 +366,53 @@ def _maybe_fake_quant( # No rule for weights, return as is. return array - # 2) Handle the Activation case (delayed quantization). - # rule.act_qtype means this op will produce an rule.act_qtype output. - # The producer(Op N-1) sets the rule for the activation. - # The consumer(Op N) uses that rule to quantize the activation. + # 2) Handle the Non-Weight case (Activations and Constants). - # Do not quantize if the rule explicitly disables it. - if rule and rule.act_qtype is None: - return array + # If the current operation does not quantize this input, we return a copy + # to avoid sharing metadata with other branches consuming the same tensor. + if rule is None or rule.act_qtype is None: + return _copy_for_isolation(array) + # Determine the effective rule to use. previous_rule = aux_data.get(array, AuxDataKey.FQ_RULE, None) if previous_rule is not None: - # Delayed Quantization: The Previous Op (Producer) specified how its - # output should be quantized. 
The Current Op (Consumer) now executes - # that rule on this input array. - rule = previous_rule + # Delayed Quantization: use producer's rule. No copy needed because + # all consumers agree on this rule. + effective_rule = previous_rule + needs_copy = False else: - # Immediate Quantiztion: If the input activation has no previous rule - # (e.g. the first layer after ModelInput, or the first layer after an - # excluded layer), we use the current op's rule to quantize it - # immediately. - pass - - # If there is no rule or the rule does not have an activation quantization - # type, return as is. - if rule is None or rule.act_qtype is None: + # Immediate Quantization (Fallback): use current consumer's rule. + # For both activations (that lacked a producer rule) and non-weight/ + # non-activation tensors (like constants), we fall back to using the + # activation quantization rule of the current operation (specifically + # act_qtype, act_batch_axes, and act_calibration_method) to ensure + # compatibility with the operation's execution. + # We need to copy to protect shared tensors from metadata leakage. + effective_rule = rule + needs_copy = True + + # If the effective rule does not have an activation quantization type, + # return as is. + if effective_rule.act_qtype is None: return array - if not rule.act_static_scale: + + # Apply copying if needed for isolation. + if needs_copy: + array = _copy_for_isolation(array) + + # Proceed with quantization. + if not effective_rule.act_static_scale: # DRQ is only supported in DotEinsumConv and they should call # _fake_quant_fn directly. return array how = qarray.HowToQuantize( - qtype=rule.act_qtype, + qtype=effective_rule.act_qtype, tiled_axes={}, # Use per-channel scales for batch axes, which will be reduced later # in _collect_quant_stat. 
- channelwise_axes=rule.act_batch_axes, - calibration_method=rule.act_calibration_method, + channelwise_axes=effective_rule.act_batch_axes, + calibration_method=effective_rule.act_calibration_method, ) fq_array = self._fake_quant_fn(array, how, quant_stat_name) @@ -479,7 +519,8 @@ def __call__(self, x: Any) -> Any: if self.fixed_range_for_output is not None: aux_data.set(x, AuxDataKey.FIXED_RANGE, self.fixed_range_for_output) # Only FQ the output if the previous op wants. - return self._maybe_fake_quant(x, None, op_id) + previous_rule = aux_data.get(x, AuxDataKey.FQ_RULE, None) + return self._maybe_fake_quant(x, previous_rule, op_id) def _forward_metadata(inputs: Any, outputs: Any, is_value_preserving_op: bool): diff --git a/tests/_src/providers/odml_test.py b/tests/_src/providers/odml_test.py index 39f6af4f..9d7e2d4d 100644 --- a/tests/_src/providers/odml_test.py +++ b/tests/_src/providers/odml_test.py @@ -230,6 +230,41 @@ def test_odml_interception_stack(self): interception.PRIMITIVE_BIND_KEY, numerical_interceptor.mapping ) + def test_mixed_tags_at_boundary(self): + class BranchModel(nn.Module): + + @nn.compact + def __call__(self, x): + x1 = nn.Dense(features=8, name='quant_dense')(x) + x2 = nn.Dense(features=8, name='float_dense')(x) + return jnp.multiply(x1, x2) + + model = BranchModel() + rules = [ + qconfig.QuantizationRule( + module_path='.*quant_dense.*', + weight_qtype=jnp.int8, + act_qtype=jnp.int8, + ), + ] + qat_provider = odml.OdmlQatProvider(rules) + qat_model = qwix_model.quantize_model(model, qat_provider) + model_input = jnp.ones((1, 8), dtype=jnp.float32) + qat_vars = qat_model.init(jax.random.key(0), model_input) + + flat_stats = flax.traverse_util.flatten_dict(qat_vars['quant_stats']) + stat_keys = {'/'.join(k[:-1]) for k in flat_stats} + + # quant_dense should have stats collected + self.assertIn('quant_dense/dot_general0_lhs', stat_keys) + + # float_dense should NOT have stats collected + self.assertNotIn('float_dense/dot_general0_lhs', 
stat_keys) + + # multiply should NOT have stats collected because of mixed tags handling + self.assertNotIn('multiply0_lhs', stat_keys) + self.assertNotIn('multiply0_rhs', stat_keys) + if __name__ == '__main__': absltest.main()