diff --git a/qwix/_src/providers/odml.py b/qwix/_src/providers/odml.py index 4f8f19b..db4d2c8 100644 --- a/qwix/_src/providers/odml.py +++ b/qwix/_src/providers/odml.py @@ -228,7 +228,7 @@ def _fake_quant( calibration = qarray.calibrate(array, how) if quant_stat_name is not None: is_fixed_range = how.calibration_method.startswith('fixed') - calibration = self._collect_quant_stat( + calibration = self._update_and_get_quant_stat( quant_stat_name, calibration, is_fixed_range ) scale, zero_point = qarray.compute_scale_zero_point(calibration, how.qtype) @@ -240,13 +240,13 @@ def _fake_quant( ste_array = qarray.clip_to_calibration(array, calibration, how.tiled_axes) return ste_array + jax.lax.stop_gradient(dq_array - ste_array) - def _collect_quant_stat( + def _update_and_get_quant_stat( self, name: str, calibration: averaging.Calibration, calibration_is_fixed_range: bool, ) -> averaging.Calibration: - """Collects the quantization statistics.""" + """Updates the running quantization statistics and returns the average.""" # For SRQ, only per-tensor scale is supported, so we don't need to check the # act_batch_axes at all. calibration = jax.tree.map(lambda x: x.mean(keepdims=True), calibration) diff --git a/qwix/_src/providers/odml_ops.py b/qwix/_src/providers/odml_ops.py index ead8703..7b53769 100644 --- a/qwix/_src/providers/odml_ops.py +++ b/qwix/_src/providers/odml_ops.py @@ -410,7 +410,7 @@ def _maybe_fake_quant( qtype=effective_rule.act_qtype, tiled_axes={}, # Use per-channel scales for batch axes, which will be reduced later - # in _collect_quant_stat. + # in _update_and_get_quant_stat. channelwise_axes=effective_rule.act_batch_axes, calibration_method=effective_rule.act_calibration_method, ) diff --git a/qwix/_src/providers/qt.py b/qwix/_src/providers/qt.py index 6c78d4a..1b99482 100644 --- a/qwix/_src/providers/qt.py +++ b/qwix/_src/providers/qt.py @@ -237,13 +237,13 @@ def get_intercept_map(self): 'jax.lax.ragged_dot': self.ragged_dot, } - def _collect_quant_stat( + def _update_and_get_quant_stat( self, name: str, batch_axes: tuple[int, ...], calibration: averaging.Calibration, ) -> averaging.Calibration: - """Collects the quantization statistics.""" + """Updates the running quantization statistics and returns the average.""" # Calculate the mean over the batch axes. calibration = jax.tree.map( lambda x: x.mean(axis=batch_axes, keepdims=True), calibration @@ -274,7 +274,7 @@ def _create_conv_general_qt_config( lhs_collect_quant_stat = None if rule.act_qtype is not None and rule.act_static_scale: lhs_collect_quant_stat = functools.partial( - self._collect_quant_stat, f'{op_id}_lhs', rule.act_batch_axes + self._update_and_get_quant_stat, f'{op_id}_lhs', rule.act_batch_axes ) assert flax_util.find_param(rhs) is not None @@ -322,7 +322,7 @@ def _create_dot_general_qt_config( lhs_calibration_method = rule.act_calibration_method if rule.act_static_scale: lhs_collect_quant_stat = functools.partial( - self._collect_quant_stat, f'{op_id}_lhs', rule.act_batch_axes + self._update_and_get_quant_stat, f'{op_id}_lhs', rule.act_batch_axes ) # RHS configs based on whether it's a weight or an activation. @@ -341,7 +341,7 @@ def _create_dot_general_qt_config( rhs_calibration_method = rule.act_calibration_method if rule.act_static_scale: rhs_collect_quant_stat = functools.partial( - self._collect_quant_stat, f'{op_id}_rhs', rule.act_batch_axes + self._update_and_get_quant_stat, f'{op_id}_rhs', rule.act_batch_axes ) # bwd config, which is only enabled when bwd_qtype is set.