Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions qwix/_src/providers/odml.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _fake_quant(
calibration = qarray.calibrate(array, how)
if quant_stat_name is not None:
is_fixed_range = how.calibration_method.startswith('fixed')
calibration = self._collect_quant_stat(
calibration = self._update_and_get_quant_stat(
quant_stat_name, calibration, is_fixed_range
)
scale, zero_point = qarray.compute_scale_zero_point(calibration, how.qtype)
Expand All @@ -240,13 +240,13 @@ def _fake_quant(
ste_array = qarray.clip_to_calibration(array, calibration, how.tiled_axes)
return ste_array + jax.lax.stop_gradient(dq_array - ste_array)

def _collect_quant_stat(
def _update_and_get_quant_stat(
self,
name: str,
calibration: averaging.Calibration,
calibration_is_fixed_range: bool,
) -> averaging.Calibration:
"""Collects the quantization statistics."""
"""Updates the running quantization statistics and returns the average."""
# For static-range quantization (SRQ), only a per-tensor scale is supported,
# so the mean below is taken over all axes and act_batch_axes does not need to
# be checked at all.
calibration = jax.tree.map(lambda x: x.mean(keepdims=True), calibration)
Expand Down
2 changes: 1 addition & 1 deletion qwix/_src/providers/odml_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def _maybe_fake_quant(
qtype=effective_rule.act_qtype,
tiled_axes={},
# Use per-channel scales for batch axes, which will be reduced later
# in _collect_quant_stat.
# in _update_and_get_quant_stat.
channelwise_axes=effective_rule.act_batch_axes,
calibration_method=effective_rule.act_calibration_method,
)
Expand Down
10 changes: 5 additions & 5 deletions qwix/_src/providers/qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,13 @@ def get_intercept_map(self):
'jax.lax.ragged_dot': self.ragged_dot,
}

def _collect_quant_stat(
def _update_and_get_quant_stat(
self,
name: str,
batch_axes: tuple[int, ...],
calibration: averaging.Calibration,
) -> averaging.Calibration:
"""Collects the quantization statistics."""
"""Updates the running quantization statistics and returns the average."""
# Calculate the mean over the batch axes.
calibration = jax.tree.map(
lambda x: x.mean(axis=batch_axes, keepdims=True), calibration
Expand Down Expand Up @@ -274,7 +274,7 @@ def _create_conv_general_qt_config(
lhs_collect_quant_stat = None
if rule.act_qtype is not None and rule.act_static_scale:
lhs_collect_quant_stat = functools.partial(
self._collect_quant_stat, f'{op_id}_lhs', rule.act_batch_axes
self._update_and_get_quant_stat, f'{op_id}_lhs', rule.act_batch_axes
)
assert flax_util.find_param(rhs) is not None

Expand Down Expand Up @@ -322,7 +322,7 @@ def _create_dot_general_qt_config(
lhs_calibration_method = rule.act_calibration_method
if rule.act_static_scale:
lhs_collect_quant_stat = functools.partial(
self._collect_quant_stat, f'{op_id}_lhs', rule.act_batch_axes
self._update_and_get_quant_stat, f'{op_id}_lhs', rule.act_batch_axes
)

# RHS configs based on whether it's a weight or an activation.
Expand All @@ -341,7 +341,7 @@ def _create_dot_general_qt_config(
rhs_calibration_method = rule.act_calibration_method
if rule.act_static_scale:
rhs_collect_quant_stat = functools.partial(
self._collect_quant_stat, f'{op_id}_rhs', rule.act_batch_axes
self._update_and_get_quant_stat, f'{op_id}_rhs', rule.act_batch_axes
)

# bwd config, which is only enabled when bwd_qtype is set.
Expand Down