Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 59 additions & 20 deletions qwix/_src/providers/odml.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def __init__(
self._strict = strict
self._ops = odml_ops.get_all_ops()

# Only these contraction ops support toggling channelwise weight
# quantization (standard for ODML). For other ops, per-channel weight
# quantization is either not applicable or not supported.
for name in [
'jax.lax.conv_general_dilated',
'jax.lax.dot_general',
Expand Down Expand Up @@ -115,7 +118,9 @@ def nn_param(
# Clear the previous aux_data such as fq_array.
aux_data.clear(ret if unbox else ret.unbox())
# weight_name is used to distinguish weights from activations.
aux_data.set(ret if unbox else ret.unbox(), 'weight_name', name)
aux_data.set(
ret if unbox else ret.unbox(), odml_ops.AuxDataKey.WEIGHT_NAME, name
)
return ret

def get_interceptors(
Expand Down Expand Up @@ -145,7 +150,14 @@ def get_interceptors(
]

def get_intercept_map(self):
"""Used for interception."""
"""Returns a map of function names to their intercepted implementations.

This method instantiates operator classes from `odml_ops` as functors that
bind to this provider's specific context (e.g., `_fake_quant`). JAX uses
these instances' `__call__` methods to replace the original operations,
allowing them to maintain operator-specific logic while accessing
provider-level state.
"""
intercept_map = super().get_intercept_map()
intercept_map['flax.linen.Module.param'] = self.nn_param
# Add all the ops to the intercept map.
Expand All @@ -161,8 +173,21 @@ def get_intercept_map(self):
def process_model_inputs(
self, model: Any, model_args: Any, model_kwargs: Any
) -> tuple[Any, Any, Any]:
"""Quantize the input of the model."""
# Set weight_name for nnx models. Linen models are handled in nn_param.
"""Prepares model activations for quantization metadata propagation.

This method also handles weight tagging for NNX models as a special case.

Args:
model: The model to process.
model_args: Positional arguments to the model.
model_kwargs: Keyword arguments to the model.

Returns:
The processed model and arguments with appropriate auxiliary data.
"""
# Weight Handling (NNX only): Eagerly iterate over the graph to clear stale
# metadata and tag parameters with _WEIGHT_NAME. For Flax Linen models,
# weights are handled lazily via `nn_param` interception.
if isinstance(model, nnx.Module):
for path, node in nnx.iter_graph(model):
if isinstance(node, nnx.Module):
Expand All @@ -171,9 +196,16 @@ def process_model_inputs(
# Clear the previous aux_data such as fq_array.
aux_data.clear(node.value)
# weight_name is used to distinguish weights from activations.
aux_data.set(node.value, 'weight_name', path[-1])

# Quantize the model inputs if needed.
aux_data.set(node.value, odml_ops.AuxDataKey.WEIGHT_NAME, path[-1])

# Activation Handling: Apply the `ModelInput` operator to all leaves of
# `model_args` and `model_kwargs` (the actual arguments passed to the
# model).
# ModelInput behavior:
# - For non-jax.Array objects (e.g., bool, int), it's a no-op.
# - For jax.Array objects, it clears stale metadata, marks them as
# activations (_IS_ACTIVATION = True), and attaches fixed ranges if set.
# This prepares the inputs as origin points for metadata tracking.
op = odml_ops.ModelInput(
fixed_range_for_output=self._fixed_range_for_inputs,
get_rule_and_op_id_fn=self._get_current_rule_and_op_id,
Expand Down Expand Up @@ -202,31 +234,38 @@ def _fake_quant(
how: qarray.HowToQuantize,
quant_stat_name: str | None = None,
) -> jax.Array:
"""Apply fake quantization to array.
"""Numerical operation used by intercepted model ops to fake-quantize tensors.

This method is the core implementation passed as a callback to intercepted
operators (e.g., in `odml_ops.py`). It is invoked by those operators to
perform the actual numerical quantization tasks for both activations and
weights during the model execution.

This function can be used on both activations and weights. Gradient will be
passed through.
It handles:
1. Calibration (including fixed-range overrides from `aux_data`).
2. Quantization statistics collection and moving-average updates.
3. Scale and zero-point computation.
4. Gradient pass-through via a straight-through estimator (STE).

Args:
array: The array to quantize.
how: How to quantize the array.
quant_stat_name: The name for the quantization statistics. If set, the
quantization statistics will be collected and the scale will be computed
from the statistics.
how: Parameters defining how to quantize the array (e.g., qtype).
quant_stat_name: Unique name for collecting and averaging quantization
statistics. If None, statistics are not collected.

Returns:
The fake quantized array.
"""
# Check and apply the fixed-range calibration associated with the array.
fixed_range = aux_data.get(array, 'fixed_range', None)
fixed_range = aux_data.get(array, odml_ops.AuxDataKey.FIXED_RANGE, None)
if fixed_range is not None:
calibration_method = f'fixed,{fixed_range[0]},{fixed_range[1]}'
how = dataclasses.replace(how, calibration_method=calibration_method)

calibration = qarray.calibrate(array, how)
if quant_stat_name is not None:
is_fixed_range = how.calibration_method.startswith('fixed')
calibration = self._collect_quant_stat(
calibration = self._update_and_get_quant_stat(
quant_stat_name, calibration, is_fixed_range
)
scale, zero_point = qarray.compute_scale_zero_point(calibration, how.qtype)
Expand All @@ -238,13 +277,13 @@ def _fake_quant(
ste_array = qarray.clip_to_calibration(array, calibration, how.tiled_axes)
return ste_array + jax.lax.stop_gradient(dq_array - ste_array)

def _collect_quant_stat(
def _update_and_get_quant_stat(
self,
name: str,
calibration: averaging.Calibration,
calibration_is_fixed_range: bool,
) -> averaging.Calibration:
"""Collects the quantization statistics."""
"""Updates the running quantization statistics and returns the average."""
# For SRQ, only per-tensor scale is supported, so we don't need to check the
# act_batch_axes at all.
calibration = jax.tree.map(lambda x: x.mean(keepdims=True), calibration)
Expand Down Expand Up @@ -324,7 +363,7 @@ def _flatten_dot_general(self, *args, _dot_general, **kwargs):
# This special handling is needed because tflite doesn't support multiple
# quantization_dimensions.
if (
aux_data.get(args[1], 'weight_name', None) is not None
aux_data.get(args[1], odml_ops.AuxDataKey.WEIGHT_NAME, None) is not None
and args[1].ndim > 2
and tuple(args[2][0][1]) == (0,)
):
Expand All @@ -346,7 +385,7 @@ def _fake_quant(
# Make the scale and zero point statically computed.
with jax.ensure_compile_time_eval():
# Check if the array is a weight or an activation.
weight_name = aux_data.get(array, 'weight_name', None)
weight_name = aux_data.get(array, odml_ops.AuxDataKey.WEIGHT_NAME, None)
if weight_name is not None: # Weights.
assert quant_stat_name is None
mdl_path = flax_util.get_current_module_path()
Expand Down
Loading