Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion integration_tests/odml_coverage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,8 @@ def create_input(self):
'Dense_0/dot_general0_lhs',
# No Dense_0/add0_input0 because add is fused.
'Dense_1/dot_general0_lhs',
# No add0_lhs because it's quantized as Dense_0/dot_general0_lhs.
# add0_lhs is collected now due to isolation.
'add0_lhs',
'add0_rhs',
'final_output0',
}
Expand Down
95 changes: 68 additions & 27 deletions qwix/_src/providers/odml_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,32 @@ class AuxDataKey(str, enum.Enum):
FakeQuantFn = Callable[[jax.Array, qarray.HowToQuantize, str | None], jax.Array]


def _copy_for_isolation(original_array: jax.Array) -> jax.Array:
  """Returns a fresh copy of the array, detached from quantization state.

  The copy intentionally carries no _FQ_RULE / _FQ_ARRAY aux data, so a
  quantization decision made on one consumer branch cannot leak into other
  branches that share the same producer tensor. Structural aux data (the
  activation flag, any fixed range, and the fusion permission) is carried
  over so the copy still behaves like the original for graph bookkeeping.

  Args:
    original_array: The array to duplicate.

  Returns:
    An unquantized copy with only structural metadata preserved.
  """
  isolated = jnp.array(original_array, copy=True)
  # _FQ_RULE and _FQ_ARRAY are deliberately left behind: dropping them is
  # the whole point of the isolation copy.
  if aux_data.get(original_array, AuxDataKey.IS_ACTIVATION, False):
    aux_data.set(isolated, AuxDataKey.IS_ACTIVATION, True)
  range_hint = aux_data.get(original_array, AuxDataKey.FIXED_RANGE, None)
  if range_hint is not None:
    aux_data.set(isolated, AuxDataKey.FIXED_RANGE, range_hint)
  if aux_data.get(original_array, AuxDataKey.ALLOW_FUSION, False):
    aux_data.set(isolated, AuxDataKey.ALLOW_FUSION, True)
  return isolated


class QuantizedOp:
"""A generic quantized op that allows different scales for inputs and output.

Expand Down Expand Up @@ -271,10 +297,15 @@ def _fake_quant_inputs(
"""Fake quantize the inputs of the op."""
args = list(args)
if len(self.input_idx) == 1:
# Guaranteed to be activation because _inputs_have_activations was True.
idx = self.input_idx[0]
args[idx] = self._maybe_fake_quant(args[idx], rule, op_id)
elif len(self.input_idx) == 2:
lhs, rhs = tuple(self.input_idx) # pylint: disable=unbalanced-tuple-unpacking
# Possible combinations at this point (since at least one is activation):
# 1. Activation / Weight
# 2. Activation / Activation
# 3. Activation / Constant (Non-Weight, Non-Activation)
# Binary ops could have non-array args, e.g. x + 1.
if isinstance(args[lhs], jax.Array):
args[lhs] = self._maybe_fake_quant(args[lhs], rule, op_id + '_lhs')
Expand All @@ -294,8 +325,8 @@ def _maybe_fake_quant(
) -> jax.Array:
"""Fake quantize the array based on the given rule.

This function assumes the array is an activation, unless it has weight_name
aux_data, e.g., in jnp.take.
This function assumes the array is a non-weight tensor (activation or
constant), unless it has weight_name aux_data, e.g., in jnp.take.

Args:
array: The array to quantize.
Expand Down Expand Up @@ -335,44 +366,53 @@ def _maybe_fake_quant(
# No rule for weights, return as is.
return array

# 2) Handle the Activation case (delayed quantization).
# rule.act_qtype means this op will produce an rule.act_qtype output.
# The producer (Op N-1) sets the rule for the activation.
# The consumer (Op N) uses that rule to quantize the activation.
# 2) Handle the Non-Weight case (Activations and Constants).

# Do not quantize if the rule explicitly disables it.
if rule and rule.act_qtype is None:
return array
# If the current operation does not quantize this input, we return a copy
# to avoid sharing metadata with other branches consuming the same tensor.
if rule is None or rule.act_qtype is None:
return _copy_for_isolation(array)

# Determine the effective rule to use.
previous_rule = aux_data.get(array, AuxDataKey.FQ_RULE, None)
if previous_rule is not None:
# Delayed Quantization: The Previous Op (Producer) specified how its
# output should be quantized. The Current Op (Consumer) now executes
# that rule on this input array.
rule = previous_rule
# Delayed Quantization: use producer's rule. No copy needed because
# all consumers agree on this rule.
effective_rule = previous_rule
needs_copy = False
else:
# Immediate Quantization: If the input activation has no previous rule
# (e.g. the first layer after ModelInput, or the first layer after an
# excluded layer), we use the current op's rule to quantize it
# immediately.
pass

# If there is no rule or the rule does not have an activation quantization
# type, return as is.
if rule is None or rule.act_qtype is None:
# Immediate Quantization (Fallback): use current consumer's rule.
# For both activations (that lacked a producer rule) and non-weight/
# non-activation tensors (like constants), we fall back to using the
# activation quantization rule of the current operation (specifically
# act_qtype, act_batch_axes, and act_calibration_method) to ensure
# compatibility with the operation's execution.
# We need to copy to protect shared tensors from metadata leakage.
effective_rule = rule
needs_copy = True

# If the effective rule does not have an activation quantization type,
# return as is.
if effective_rule.act_qtype is None:
return array
if not rule.act_static_scale:

# Apply copying if needed for isolation.
if needs_copy:
array = _copy_for_isolation(array)

# Proceed with quantization.
if not effective_rule.act_static_scale:
# DRQ is only supported in DotEinsumConv and they should call
# _fake_quant_fn directly.
return array

how = qarray.HowToQuantize(
qtype=rule.act_qtype,
qtype=effective_rule.act_qtype,
tiled_axes={},
# Use per-channel scales for batch axes, which will be reduced later
# in _collect_quant_stat.
channelwise_axes=rule.act_batch_axes,
calibration_method=rule.act_calibration_method,
channelwise_axes=effective_rule.act_batch_axes,
calibration_method=effective_rule.act_calibration_method,
)

fq_array = self._fake_quant_fn(array, how, quant_stat_name)
Expand Down Expand Up @@ -479,7 +519,8 @@ def __call__(self, x: Any) -> Any:
if self.fixed_range_for_output is not None:
aux_data.set(x, AuxDataKey.FIXED_RANGE, self.fixed_range_for_output)
# Only FQ the output if the previous op wants.
return self._maybe_fake_quant(x, None, op_id)
previous_rule = aux_data.get(x, AuxDataKey.FQ_RULE, None)
return self._maybe_fake_quant(x, previous_rule, op_id)


def _forward_metadata(inputs: Any, outputs: Any, is_value_preserving_op: bool):
Expand Down
35 changes: 35 additions & 0 deletions tests/_src/providers/odml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,41 @@ def test_odml_interception_stack(self):
interception.PRIMITIVE_BIND_KEY, numerical_interceptor.mapping
)

def test_mixed_tags_at_boundary(self):
  """Checks metadata isolation where a quantized and a float branch merge.

  `quant_dense` is matched by a quantization rule while `float_dense` is
  not; both branch outputs feed the same `multiply`. The assertions below
  pin that quant stats are collected only for the quantized branch, and
  that the multiply inputs (mixing a quantized and a float producer)
  collect no stats of their own.
  """

  class BranchModel(nn.Module):
    # Two parallel Dense branches over the same input, merged by multiply.

    @nn.compact
    def __call__(self, x):
      x1 = nn.Dense(features=8, name='quant_dense')(x)
      x2 = nn.Dense(features=8, name='float_dense')(x)
      return jnp.multiply(x1, x2)

  model = BranchModel()
  # Only `quant_dense` gets a rule; `float_dense` stays in float.
  rules = [
      qconfig.QuantizationRule(
          module_path='.*quant_dense.*',
          weight_qtype=jnp.int8,
          act_qtype=jnp.int8,
      ),
  ]
  qat_provider = odml.OdmlQatProvider(rules)
  qat_model = qwix_model.quantize_model(model, qat_provider)
  model_input = jnp.ones((1, 8), dtype=jnp.float32)
  qat_vars = qat_model.init(jax.random.key(0), model_input)

  # Flatten the nested quant_stats tree into 'path/name' keys for lookup.
  flat_stats = flax.traverse_util.flatten_dict(qat_vars['quant_stats'])
  stat_keys = {'/'.join(k[:-1]) for k in flat_stats}

  # quant_dense should have stats collected
  self.assertIn('quant_dense/dot_general0_lhs', stat_keys)

  # float_dense should NOT have stats collected
  self.assertNotIn('float_dense/dot_general0_lhs', stat_keys)

  # multiply should NOT have stats collected because of mixed tags handling
  self.assertNotIn('multiply0_lhs', stat_keys)
  self.assertNotIn('multiply0_rhs', stat_keys)


if __name__ == '__main__':
absltest.main()