From 897da9a28fccb701179f352c38e0bec66057a01b Mon Sep 17 00:00:00 2001 From: Jiwon Shin Date: Thu, 16 Apr 2026 14:40:14 -0700 Subject: [PATCH] [QwixOdml] Enhance Qwix ODML docstring and inline comments to improve interpretability. PiperOrigin-RevId: 900917079 --- qwix/_src/providers/odml.py | 79 +++++++++++---- qwix/_src/providers/odml_ops.py | 168 ++++++++++++++++++-------------- qwix/_src/providers/qt.py | 10 +- 3 files changed, 158 insertions(+), 99 deletions(-) diff --git a/qwix/_src/providers/odml.py b/qwix/_src/providers/odml.py index 1e1e2080..225dfa0b 100644 --- a/qwix/_src/providers/odml.py +++ b/qwix/_src/providers/odml.py @@ -77,6 +77,9 @@ def __init__( self._strict = strict self._ops = odml_ops.get_all_ops() + # Only these contraction ops support toggling channelwise weight + # quantization (standard for ODML). For other ops, per-channel weight + # quantization is either not applicable or not supported. for name in [ 'jax.lax.conv_general_dilated', 'jax.lax.dot_general', @@ -115,7 +118,9 @@ def nn_param( # Clear the previous aux_data such as fq_array. aux_data.clear(ret if unbox else ret.unbox()) # weight_name is used to distinguish weights from activations. - aux_data.set(ret if unbox else ret.unbox(), 'weight_name', name) + aux_data.set( + ret if unbox else ret.unbox(), odml_ops.AuxDataKey.WEIGHT_NAME, name + ) return ret def get_interceptors( @@ -145,7 +150,14 @@ def get_interceptors( ] def get_intercept_map(self): - """Used for interception.""" + """Returns a map of function names to their intercepted implementations. + + This method instantiates operator classes from `odml_ops` as functors that + bind to this provider's specific context (e.g., `_fake_quant`). JAX uses + these instances' `__call__` methods to replace the original operations, + allowing them to maintain operator-specific logic while accessing + provider-level state. 
+ """ intercept_map = super().get_intercept_map() intercept_map['flax.linen.Module.param'] = self.nn_param # Add all the ops to the intercept map. @@ -161,8 +173,21 @@ def get_intercept_map(self): def process_model_inputs( self, model: Any, model_args: Any, model_kwargs: Any ) -> tuple[Any, Any, Any]: - """Quantize the input of the model.""" - # Set weight_name for nnx models. Linen models are handled in nn_param. + """Prepares model activations for quantization metadata propagation. + + This method also handles weight tagging for NNX models as a special case. + + Args: + model: The model to process. + model_args: Positional arguments to the model. + model_kwargs: Keyword arguments to the model. + + Returns: + The processed model and arguments with appropriate auxiliary data. + """ + # Weight Handling (NNX only): Eagerly iterate over the graph to clear stale + # metadata and tag parameters with _WEIGHT_NAME. For Flax Linen models, + # weights are handled lazily via `nn_param` interception. if isinstance(model, nnx.Module): for path, node in nnx.iter_graph(model): if isinstance(node, nnx.Module): @@ -171,9 +196,16 @@ def process_model_inputs( # Clear the previous aux_data such as fq_array. aux_data.clear(node.value) # weight_name is used to distinguish weights from activations. - aux_data.set(node.value, 'weight_name', path[-1]) - - # Quantize the model inputs if needed. + aux_data.set(node.value, odml_ops.AuxDataKey.WEIGHT_NAME, path[-1]) + + # Activation Handling: Apply the `ModelInput` operator to all leaves of + # `model_args` and `model_kwargs` (the actual arguments passed to the + # model). + # ModelInput behavior: + # - For non-jax.Array objects (e.g., bool, int), it's a no-op. + # - For jax.Array objects, it clears stale metadata, marks them as + # activations (_IS_ACTIVATION = True), and attaches fixed ranges if set. + # This prepares the inputs as origin points for metadata tracking. 
op = odml_ops.ModelInput( fixed_range_for_output=self._fixed_range_for_inputs, get_rule_and_op_id_fn=self._get_current_rule_and_op_id, @@ -202,23 +234,30 @@ def _fake_quant( how: qarray.HowToQuantize, quant_stat_name: str | None = None, ) -> jax.Array: - """Apply fake quantization to array. + """Numerical operation used by intercepted model ops to fake-quantize tensors. + + This method is the core implementation passed as a callback to intercepted + operators (e.g., in `odml_ops.py`). It is invoked by those operators to + perform the actual numerical quantization tasks for both activations and + weights during the model execution. - This function can be used on both activations and weights. Gradient will be - passed through. + It handles: + 1. Calibration (including fixed-range overrides from `aux_data`). + 2. Quantization statistics collection and moving-average updates. + 3. Scale and zero-point computation. + 4. Gradient pass-through via a straight-through estimator (STE). Args: array: The array to quantize. - how: How to quantize the array. - quant_stat_name: The name for the quantization statistics. If set, the - quantization statistics will be collected and the scale will be computed - from the statistics. + how: Parameters defining how to quantize the array (e.g., qtype). + quant_stat_name: Unique name for collecting and averaging quantization + statistics. If None, statistics are not collected. Returns: The fake quantized array. """ # Check and apply the fixed-range calibration asscociated with the array. 
- fixed_range = aux_data.get(array, 'fixed_range', None) + fixed_range = aux_data.get(array, odml_ops.AuxDataKey.FIXED_RANGE, None) if fixed_range is not None: calibration_method = f'fixed,{fixed_range[0]},{fixed_range[1]}' how = dataclasses.replace(how, calibration_method=calibration_method) @@ -226,7 +265,7 @@ def _fake_quant( calibration = qarray.calibrate(array, how) if quant_stat_name is not None: is_fixed_range = how.calibration_method.startswith('fixed') - calibration = self._collect_quant_stat( + calibration = self._update_and_get_quant_stat( quant_stat_name, calibration, is_fixed_range ) scale, zero_point = qarray.compute_scale_zero_point(calibration, how.qtype) @@ -238,13 +277,13 @@ def _fake_quant( ste_array = qarray.clip_to_calibration(array, calibration, how.tiled_axes) return ste_array + jax.lax.stop_gradient(dq_array - ste_array) - def _collect_quant_stat( + def _update_and_get_quant_stat( self, name: str, calibration: averaging.Calibration, calibration_is_fixed_range: bool, ) -> averaging.Calibration: - """Collects the quantization statistics.""" + """Updates the running quantization statistics and returns the average.""" # For SRQ, only per-tensor scale is supported, so we don't need to check the # act_batch_axes at all. calibration = jax.tree.map(lambda x: x.mean(keepdims=True), calibration) @@ -324,7 +363,7 @@ def _flatten_dot_general(self, *args, _dot_general, **kwargs): # This special handling is needed because tflite doesn't support multiple # quantization_dimensions. if ( - aux_data.get(args[1], 'weight_name', None) is not None + aux_data.get(args[1], odml_ops.AuxDataKey.WEIGHT_NAME, None) is not None and args[1].ndim > 2 and tuple(args[2][0][1]) == (0,) ): @@ -346,7 +385,7 @@ def _fake_quant( # Make the scale and zero point statically computed. with jax.ensure_compile_time_eval(): # Check if the array is a weight or an activation. 
- weight_name = aux_data.get(array, 'weight_name', None) + weight_name = aux_data.get(array, odml_ops.AuxDataKey.WEIGHT_NAME, None) if weight_name is not None: # Weights. assert quant_stat_name is None mdl_path = flax_util.get_current_module_path() diff --git a/qwix/_src/providers/odml_ops.py b/qwix/_src/providers/odml_ops.py index ce9e8763..e273cecb 100644 --- a/qwix/_src/providers/odml_ops.py +++ b/qwix/_src/providers/odml_ops.py @@ -14,6 +14,7 @@ """ODML ops for QAT.""" import dataclasses +import enum import functools import sys from typing import Any, Callable, Sequence @@ -115,43 +116,46 @@ def get_all_ops(): ) -### Possible auxiliary data associated with an array +class AuxDataKey(str, enum.Enum): + """Auxiliary data keys.""" -# Whether an array should be fake quantized by the next op and what rule to use. -# -# For the output of an op, it's not fake-quantized immediately because the next -# op may choose to delay the FQ, e.g. dot_general + add + relu can be fused and -# no FQ should be inserted in between. -_FQ_RULE = 'fq_rule' # QuantizationRule + # Whether an array should be fake quantized by the next op and what rule to + # use. For the output of an op, it's not fake-quantized immediately because + # the next op may choose to delay the FQ, e.g. dot_general + add + relu can be + # fused and no FQ should be inserted in between. + FQ_RULE = 'fq_rule' # QuantizationRule + + # Whether the (unquantized) array is already fake-quantized in another code + # path and what the fake-quantized array is. This avoids the same array being + # fake-quantized multiple times. + FQ_ARRAY = 'fq_array' # array -# Whether the (unquantized) array is already fake-quantized in another code path -# and what the fake-quantized array is. This avoids the same array being -# fake-quantized multiple times. -_FQ_ARRAY = 'fq_array' # array + # Whether the previous op allows to fuse arithmetic ops or batch norm after + # it. 
+ ALLOW_FUSION = 'allow_fusion' # bool -# Whether the previous op allows to fuse arithmetic ops or batch norm after it. -_ALLOW_FUSION = 'allow_fusion' # bool + # Whether the array is an activation. An array can be either an activation, + # a weight, or a constant. + IS_ACTIVATION = 'is_activation' # bool -# Whether the array is an activation. An array can be either an activation, -# a weight, or a constant. -_IS_ACTIVATION = 'is_activation' # bool + # Whether the array is a weight and what is its name. Weights don't need to + # have quantization statistics collected because they are statically + # quantized. The name is useful in the conversion provider to find the + # static weight. + WEIGHT_NAME = 'weight_name' # str -# Whether the array is a weight and what is its name. Weights don't need to have -# quantization statistics collected because they are statically quantized. -# The name is useful in the conversion provider to find the static weight. -_WEIGHT_NAME = 'weight_name' # str + # Fixed range for logistic functions whose output ranges are known, e.g. + # softmax. + FIXED_RANGE = 'fixed_range' # tuple[float, float] -# Fixed range for logistic functions whose output ranges are known, e.g. -# softmax. -_FIXED_RANGE = 'fixed_range' # tuple[float, float] # Metadata keys that depend on the value being preserved. # If the value changes (e.g. add, mul), these keys become invalid. _VALUE_DEPENDENT_METADATA = ( - _WEIGHT_NAME, - _FQ_RULE, - _FIXED_RANGE, - _ALLOW_FUSION, + AuxDataKey.WEIGHT_NAME, + AuxDataKey.FQ_RULE, + AuxDataKey.FIXED_RANGE, + AuxDataKey.ALLOW_FUSION, ) # These ops only change the tensor view or layout, not the values. 
@@ -245,7 +249,7 @@ def _inputs_have_activations(self, args: Sequence[Any]) -> bool: raise ValueError(f'input_idx is not set for op {self._op_name}.') for idx in self.input_idx: if isinstance(args[idx], jax.Array) and aux_data.get( - args[idx], _IS_ACTIVATION, False + args[idx], AuxDataKey.IS_ACTIVATION, False ): return True return False @@ -303,7 +307,7 @@ def _maybe_fake_quant( The fake quantized array. """ # Check if the array is already quantized in another code path. - fq_array = aux_data.get(array, _FQ_ARRAY, None) + fq_array = aux_data.get(array, AuxDataKey.FQ_ARRAY, None) if fq_array is not None: return array if fq_array == 'self' else fq_array @@ -313,7 +317,7 @@ def _maybe_fake_quant( # 1) Handle the Weight case (immediate quantization). # rule.weight_qtype means this op will use rule.weight_qtype as weights. - if aux_data.get(array, _WEIGHT_NAME, None) is not None: + if aux_data.get(array, AuxDataKey.WEIGHT_NAME, None) is not None: if rule and rule.weight_qtype: # If there is a rule for weights, quantize the weights. how = qarray.HowToQuantize( @@ -325,7 +329,7 @@ def _maybe_fake_quant( calibration_method=rule.act_calibration_method, ) fq_array = self._fake_quant_fn(array, how, None) - aux_data.set(array, _FQ_ARRAY, fq_array) + aux_data.set(array, AuxDataKey.FQ_ARRAY, fq_array) return fq_array else: # No rule for weights, return as is. @@ -340,7 +344,7 @@ def _maybe_fake_quant( if rule and rule.act_qtype is None: return array - previous_rule = aux_data.get(array, _FQ_RULE, None) + previous_rule = aux_data.get(array, AuxDataKey.FQ_RULE, None) if previous_rule is not None: # Delayed Quantization: The Previous Op (Producer) specified how its # output should be quantized. The Current Op (Consumer) now executes @@ -366,13 +370,13 @@ def _maybe_fake_quant( qtype=rule.act_qtype, tiled_axes={}, # Use per-channel scales for batch axes, which will be reduced later - # in _collect_quant_stat. + # in _update_and_get_quant_stat. 
channelwise_axes=rule.act_batch_axes, calibration_method=rule.act_calibration_method, ) fq_array = self._fake_quant_fn(array, how, quant_stat_name) - aux_data.set(array, _FQ_ARRAY, fq_array) + aux_data.set(array, AuxDataKey.FQ_ARRAY, fq_array) return fq_array def _fake_quant_output( @@ -382,12 +386,12 @@ def _fake_quant_output( # Handle all leaves of the pytree. for x in jax.tree_util.tree_leaves(outputs): if isinstance(x, jax.Array): - aux_data.set(x, _IS_ACTIVATION, True) + aux_data.set(x, AuxDataKey.IS_ACTIVATION, True) if self.fixed_range_for_output is not None: - aux_data.set(x, _FIXED_RANGE, self.fixed_range_for_output) + aux_data.set(x, AuxDataKey.FIXED_RANGE, self.fixed_range_for_output) # Output is only quantized in SRQ. if rule and rule.act_qtype and rule.act_static_scale: - aux_data.set(x, _FQ_RULE, rule) + aux_data.set(x, AuxDataKey.FQ_RULE, rule) return outputs @@ -405,7 +409,7 @@ def __call__(self, *args, **kwargs): out = self._call_original_op(*args, **kwargs) if rule and rule.act_qtype: # Mark the output as already quantized. - aux_data.set(out, _FQ_ARRAY, 'self') + aux_data.set(out, AuxDataKey.FQ_ARRAY, 'self') return self._fake_quant_output(out, rule) @@ -419,7 +423,7 @@ def __call__(self, *args, **kwargs): return self._call_original_op(*args, **kwargs) rule, _ = self._get_rule_and_op_id_fn(self._op_name) if rule is None or rule.act_qtype is None: - rule = aux_data.get(args[self.input_idx[0]], _FQ_RULE, None) + rule = aux_data.get(args[self.input_idx[0]], AuxDataKey.FQ_RULE, None) # No quantization on the input. 
out = self._call_original_op(*args, **kwargs) return self._fake_quant_output(out, rule) @@ -467,11 +471,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def __call__(self, x: Any) -> Any: - if self.check_activation and not aux_data.get(x, _IS_ACTIVATION, False): + if self.check_activation and not aux_data.get( + x, AuxDataKey.IS_ACTIVATION, False + ): raise NotAnActivationError _, op_id = self._get_rule_and_op_id_fn(self._op_name) if self.fixed_range_for_output is not None: - aux_data.set(x, _FIXED_RANGE, self.fixed_range_for_output) + aux_data.set(x, AuxDataKey.FIXED_RANGE, self.fixed_range_for_output) # Only FQ the output if the previous op wants. return self._maybe_fake_quant(x, None, op_id) @@ -485,18 +491,19 @@ def _forward_metadata(inputs: Any, outputs: Any, is_value_preserving_op: bool): is_value_preserving_op: Whether the op preserves the value. Metadata propagation rules: - 1. _IS_ACTIVATION: Propagated if ANY input is an activation (Union). + 1. AuxDataKey.IS_ACTIVATION: Propagated if ANY input is an activation (Union). This tracks data provenance - if data comes from an activation, it remains an activation regardless of the operation. - 2. _WEIGHT_NAME, _FQ_RULE, _FIXED_RANGE, _ALLOW_FUSION: Propagated ONLY for - value-preserving ops (e.g. reshape, transpose). + 2. AuxDataKey.WEIGHT_NAME, AuxDataKey.FQ_RULE, AuxDataKey.FIXED_RANGE, + AuxDataKey.ALLOW_FUSION: Propagated ONLY for value-preserving ops (e.g. + reshape, transpose). These keys are "value-preserving" because they describe properties of the specific tensor values (e.g. "this tensor is weight 'w'", "this tensor has range X"). If the values change (e.g. add 1), these properties are lost. - 3. _FQ_ARRAY: Propagated if ALL activation inputs share the same FQ array - (Intersection) AND the op is value-preserving. + 3. AuxDataKey.FQ_ARRAY: Propagated if ALL activation inputs share the same + FQ array (Intersection) AND the op is value-preserving. 
This ensures we don't accidentally treat a mixed or modified value as already quantized. """ @@ -509,10 +516,10 @@ def _forward_metadata(inputs: Any, outputs: Any, is_value_preserving_op: bool): continue # Check if at least one arg is activation. - if aux_data.get(arg, _IS_ACTIVATION, False): + if aux_data.get(arg, AuxDataKey.IS_ACTIVATION, False): is_activation = True # Check if every arg is quantized. - if aux_data.get(arg, _FQ_ARRAY, None) != 'self': + if aux_data.get(arg, AuxDataKey.FQ_ARRAY, None) != 'self': all_args_quantized = False # For value-preserving ops, handle _VALUE_DEPENDENT_METADATA. @@ -521,21 +528,21 @@ def _forward_metadata(inputs: Any, outputs: Any, is_value_preserving_op: bool): val = aux_data.get(arg, key, None) if val is None: continue - if key == _WEIGHT_NAME: + if key == AuxDataKey.WEIGHT_NAME: weight_names.add(val) else: # Last wins for value dependent metadata. metadata[key] = val - # Set _IS_ACTIVATION if at least one arg is activation. + # Set IS_ACTIVATION if at least one arg is activation. if is_activation: - metadata[_IS_ACTIVATION] = True - # For value-preserving ops, set _WEIGHT_NAME if purely a single weight op. + metadata[AuxDataKey.IS_ACTIVATION] = True + # For value-preserving ops, set WEIGHT_NAME if purely a single weight op. elif len(weight_names) == 1: - metadata[_WEIGHT_NAME] = next(iter(weight_names)) - # For value-preserving ops, set _FQ_ARRAY if all args are fq and out is act. + metadata[AuxDataKey.WEIGHT_NAME] = next(iter(weight_names)) + # For value-preserving ops, set FQ_ARRAY if all args are fq and out is act. if is_value_preserving_op and is_activation and all_args_quantized: - metadata[_FQ_ARRAY] = 'self' + metadata[AuxDataKey.FQ_ARRAY] = 'self' # Propagate metadata to outputs. 
if metadata: @@ -585,12 +592,12 @@ class BatchNorm(QuantizedOp): """BatchNorm op, which can be fused into previous op completely.""" def __call__(self, norm, x: jax.Array, *args, **kwargs) -> jax.Array: - if not aux_data.get(x, _IS_ACTIVATION, False): + if not aux_data.get(x, AuxDataKey.IS_ACTIVATION, False): return norm(x, *args, **kwargs) - if aux_data.get(x, _ALLOW_FUSION, False): - rule = aux_data.get(x, _FQ_RULE, None) + if aux_data.get(x, AuxDataKey.ALLOW_FUSION, False): + rule = aux_data.get(x, AuxDataKey.FQ_RULE, None) out = norm(x, *args, **kwargs) - aux_data.set(out, _ALLOW_FUSION, True) + aux_data.set(out, AuxDataKey.ALLOW_FUSION, True) else: rule, op_id = self._get_rule_and_op_id_fn('batch_norm_op') x = self._maybe_fake_quant(x, rule, op_id) @@ -635,8 +642,8 @@ def _fake_quant_inputs( """Fake quantize the inputs of the op.""" if ( self._op_name in ('add', 'sub', 'mul', 'truediv') - and aux_data.get(args[1], _ALLOW_FUSION, False) - and not aux_data.get(args[2], _IS_ACTIVATION, False) + and aux_data.get(args[1], AuxDataKey.ALLOW_FUSION, False) + and not aux_data.get(args[2], AuxDataKey.IS_ACTIVATION, False) ): # The previous op allows to fuse adding a constant. self._output_allow_fusion = True @@ -647,7 +654,7 @@ def _fake_quant_output( self, outputs: Any, rule: qconfig.QuantizationRule | None ) -> Any: if self._output_allow_fusion: - aux_data.set(outputs, _ALLOW_FUSION, True) + aux_data.set(outputs, AuxDataKey.ALLOW_FUSION, True) return super()._fake_quant_output(outputs, rule) @@ -656,12 +663,17 @@ class Concatenate(QuantizedOp): def __call__(self, arrays: Sequence[jax.Array], *args, **kwargs) -> jax.Array: """QAT concatenate.""" - if not any(aux_data.get(x, _IS_ACTIVATION, False) for x in arrays): + if not any( + aux_data.get(x, AuxDataKey.IS_ACTIVATION, False) for x in arrays + ): return self._call_original_op(arrays, *args, **kwargs) # Forward the fixed_range if all inputs have the same. 
- fixed_range = aux_data.get(arrays[0], _FIXED_RANGE, None) - if any(aux_data.get(x, _FIXED_RANGE, None) != fixed_range for x in arrays): + fixed_range = aux_data.get(arrays[0], AuxDataKey.FIXED_RANGE, None) + if any( + aux_data.get(x, AuxDataKey.FIXED_RANGE, None) != fixed_range + for x in arrays + ): fixed_range = None # If ourselves is not quantized, fake quantize the inputs if needed. @@ -676,7 +688,7 @@ def __call__(self, arrays: Sequence[jax.Array], *args, **kwargs) -> jax.Array: out = jnp.concatenate(arrays, *args, **kwargs) if fixed_range is not None: - aux_data.set(out, _FIXED_RANGE, fixed_range) + aux_data.set(out, AuxDataKey.FIXED_RANGE, fixed_range) return self._fake_quant_output(out, rule) @@ -702,7 +714,7 @@ def __call__(self, *args, **kwargs) -> jax.Array: out = self._call_original_op(*args, **kwargs) if rule and rule.act_qtype: # Output doesn't need more FQ. - aux_data.set(out, _FQ_ARRAY, 'self') + aux_data.set(out, AuxDataKey.FQ_ARRAY, 'self') return self._fake_quant_output(out, rule) @@ -711,12 +723,12 @@ class Silu(QuantizedOp): def __call__(self, x: jax.Array) -> jax.Array: """QAT silu.""" - if not aux_data.get(x, _IS_ACTIVATION, False): + if not aux_data.get(x, AuxDataKey.IS_ACTIVATION, False): return self._call_original_op(x) rule, op_id = self._get_rule_and_op_id_fn(self._op_name) x = self._maybe_fake_quant(x, rule, op_id) y = jax.nn.sigmoid(x) - aux_data.set(y, _FIXED_RANGE, Softmax.fixed_range_for_output) + aux_data.set(y, AuxDataKey.FIXED_RANGE, Softmax.fixed_range_for_output) y = self._maybe_fake_quant(y, rule, op_id + '_sigmoid') return self._fake_quant_output(x * y, rule) @@ -786,10 +798,18 @@ def __call__(self, *args, **kwargs) -> jax.Array: rule, op_id = self._get_rule_and_op_id_fn(self._op_name) args = list(args) - lhs_is_activation = aux_data.get(args[lhs_idx], _IS_ACTIVATION, False) - lhs_is_weight = aux_data.get(args[lhs_idx], _WEIGHT_NAME, None) is not None - rhs_is_activation = aux_data.get(args[rhs_idx], _IS_ACTIVATION, False) 
- rhs_is_weight = aux_data.get(args[rhs_idx], _WEIGHT_NAME, None) is not None + lhs_is_activation = aux_data.get( + args[lhs_idx], AuxDataKey.IS_ACTIVATION, False + ) + lhs_is_weight = ( + aux_data.get(args[lhs_idx], AuxDataKey.WEIGHT_NAME, None) is not None + ) + rhs_is_activation = aux_data.get( + args[rhs_idx], AuxDataKey.IS_ACTIVATION, False + ) + rhs_is_weight = ( + aux_data.get(args[rhs_idx], AuxDataKey.WEIGHT_NAME, None) is not None + ) assert lhs_is_activation + lhs_is_weight <= 1 assert rhs_is_activation + rhs_is_weight <= 1 @@ -835,7 +855,7 @@ def __call__(self, *args, **kwargs) -> jax.Array: ) out = self._call_original_op(*args, **kwargs) - aux_data.set(out, _ALLOW_FUSION, True) + aux_data.set(out, AuxDataKey.ALLOW_FUSION, True) return self._fake_quant_output(out, rule) diff --git a/qwix/_src/providers/qt.py b/qwix/_src/providers/qt.py index 6c78d4af..1b99482f 100644 --- a/qwix/_src/providers/qt.py +++ b/qwix/_src/providers/qt.py @@ -237,13 +237,13 @@ def get_intercept_map(self): 'jax.lax.ragged_dot': self.ragged_dot, } - def _collect_quant_stat( + def _update_and_get_quant_stat( self, name: str, batch_axes: tuple[int, ...], calibration: averaging.Calibration, ) -> averaging.Calibration: - """Collects the quantization statistics.""" + """Updates the running quantization statistics and returns the average.""" # Calculate the mean over the batch axes. 
calibration = jax.tree.map( lambda x: x.mean(axis=batch_axes, keepdims=True), calibration @@ -274,7 +274,7 @@ def _create_conv_general_qt_config( lhs_collect_quant_stat = None if rule.act_qtype is not None and rule.act_static_scale: lhs_collect_quant_stat = functools.partial( - self._collect_quant_stat, f'{op_id}_lhs', rule.act_batch_axes + self._update_and_get_quant_stat, f'{op_id}_lhs', rule.act_batch_axes ) assert flax_util.find_param(rhs) is not None @@ -322,7 +322,7 @@ def _create_dot_general_qt_config( lhs_calibration_method = rule.act_calibration_method if rule.act_static_scale: lhs_collect_quant_stat = functools.partial( - self._collect_quant_stat, f'{op_id}_lhs', rule.act_batch_axes + self._update_and_get_quant_stat, f'{op_id}_lhs', rule.act_batch_axes ) # RHS configs based on whether it's a weight or an activation. @@ -341,7 +341,7 @@ def _create_dot_general_qt_config( rhs_calibration_method = rule.act_calibration_method if rule.act_static_scale: rhs_collect_quant_stat = functools.partial( - self._collect_quant_stat, f'{op_id}_rhs', rule.act_batch_axes + self._update_and_get_quant_stat, f'{op_id}_rhs', rule.act_batch_axes ) # bwd config, which is only enabled when bwd_qtype is set.