Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions qwix/_src/core/conv_general_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class ConvGeneralQtConfig:

# Misc.
disable_channelwise_axes: bool = False
bwd_use_original_residuals: bool = False


# Swaps the first two dimension indices of a specification.
Expand Down Expand Up @@ -187,11 +186,8 @@ def _quantize_operand(
operand, qtype, scale, zero_point
)

residuals = (lhs, rhs)
lhs = _quantize_operand(lhs, for_lhs=True)
rhs = _quantize_operand(rhs, for_lhs=False)
if not config.bwd_use_original_residuals:
residuals = (lhs, rhs)

primal_out = conv_general.conv_general_dilated(
lhs,
Expand All @@ -204,6 +200,7 @@ def _quantize_operand(
feature_group_count,
batch_group_count,
)
residuals = (lhs, rhs)

return primal_out, residuals

Expand Down
142 changes: 72 additions & 70 deletions qwix/_src/core/dot_general_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from qwix._src.core import dot_general
from qwix._src.core import numerics
from qwix._src.core import qarray
from qwix._src.core import qarray_qt


@dataclasses.dataclass(slots=True, frozen=True, kw_only=True)
Expand All @@ -43,19 +44,16 @@ class DotGeneralQtConfig:
dlhs_grad_qtype: jax.typing.DTypeLike | None = None # incoming gradient
dlhs_grad_calibration_method: str = 'absmax'
dlhs_tile_size: int | float | None = None
dlhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None

# Backward pass (drhs).
drhs_grad_qtype: jax.typing.DTypeLike | None = None # incoming gradient
drhs_grad_calibration_method: str = 'absmax'
drhs_tile_size: int | float | None = None
drhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None

# Misc.
disable_channelwise_axes: bool = False
bwd_use_original_residuals: bool = False # what to use as residuals

# Configs for stochastic rounding.
dlhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None
drhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None


def _ranges_like(*xs):
Expand Down Expand Up @@ -124,84 +122,47 @@ def _apply_rhs_scale_to_lhs(lhs, rhs_scale, dnums):
# disable interceptions for dot_general_qt_fwd.
@interception.disable_interceptions
def dot_general_qt_fwd(
lhs: jax.Array,
rhs: jax.Array,
lhs: jax.Array | qarray_qt.QArrayWithGradient,
rhs: jax.Array | qarray_qt.QArrayWithGradient,
dimension_numbers: jax.lax.DotDimensionNumbers,
config: DotGeneralQtConfig,
):
"""Forward pass for dot_general_qt custom VJP."""
ndims = (lhs.ndim, rhs.ndim)

def _quantize_operand(operand: jax.Array, is_lhs: bool) -> qarray.MaybeQArray:
"""Quantizes a single operand for the forward pass if configured to do so."""
if is_lhs:
qtype = config.lhs_qtype
calibration_method = config.lhs_calibration_method
collect_quant_stat = config.lhs_collect_quant_stat
else:
qtype = config.rhs_qtype
calibration_method = config.rhs_calibration_method
collect_quant_stat = config.rhs_collect_quant_stat

if not (qtype and numerics.should_quantize(operand.dtype)):
return operand

how = dot_general.get_how_to_quantize(
dimension_numbers=dimension_numbers,
ndims=ndims,
for_lhs=is_lhs,
qtype=qtype,
tile_size=config.tile_size,
calibration_method=calibration_method,
)
if config.disable_channelwise_axes:
how = dataclasses.replace(how, channelwise_axes=[])

calibration = qarray.calibrate(operand, how)
if collect_quant_stat:
calibration = collect_quant_stat(calibration)
scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype)
return qarray.quantize_with_scale_zero_point(
operand, how.qtype, scale, zero_point
)

qlhs = _quantize_operand(lhs, is_lhs=True)
qrhs = _quantize_operand(rhs, is_lhs=False)

primal_out = dot_general.dot_general(qlhs, qrhs, dimension_numbers)

if config.bwd_use_original_residuals:
residuals = (lhs, rhs)
else:
residuals = (qlhs, qrhs)

return primal_out, residuals
del config
return dot_general.dot_general(lhs, rhs, dimension_numbers), (lhs, rhs)


def dot_general_qt_bwd(
fwd_dimension_numbers: jax.lax.DotDimensionNumbers,
config: DotGeneralQtConfig,
residuals: tuple[qarray.MaybeQArray, qarray.MaybeQArray],
residuals: tuple[
jax.Array | qarray_qt.QArrayWithGradient,
jax.Array | qarray_qt.QArrayWithGradient,
],
g: jax.Array,
):
"""Backward pass for dot_general_qt custom VJP."""
lhs, rhs = residuals

def _compute_gradient_for_operand(
g: jax.Array, y: qarray.MaybeQArray, *, for_dlhs: bool
):
"""Compute dot_general for gradient and other_fwd_operand."""
) -> jax.Array | qarray_qt.QArrayWithGradient:
"""Compute dx from g and y."""
bwd_dnums, transpose_axes = _update_dimension_numbers_for_backward(
fwd_dimension_numbers, (lhs.ndim, rhs.ndim), for_dlhs=for_dlhs
)
if for_dlhs:
g_qtype = config.dlhs_grad_qtype
g_tile_size = config.dlhs_tile_size
g_calibration_method = config.dlhs_grad_calibration_method
g_noise_fn = config.dlhs_stochastic_rounding_noise_fn
result_type = lhs # the result gradient must match this type.
else:
g_qtype = config.drhs_grad_qtype
g_tile_size = config.drhs_tile_size
g_calibration_method = config.drhs_grad_calibration_method
g_noise_fn = config.drhs_stochastic_rounding_noise_fn
result_type = rhs # the result gradient must match this type.

if g_qtype and numerics.should_quantize(g.dtype):
if isinstance(y, qarray.QArray) and not qarray.get_tiled_axes(y):
Expand All @@ -219,23 +180,20 @@ def _compute_gradient_for_operand(
tile_size=g_tile_size,
calibration_method=g_calibration_method,
)
g_how = dataclasses.replace(g_how, noise_fn=g_noise_fn)
if config.disable_channelwise_axes:
g_how = dataclasses.replace(g_how, channelwise_axes=[])

if for_dlhs and config.dlhs_stochastic_rounding_noise_fn:
g_how = dataclasses.replace(
g_how,
noise_fn=config.dlhs_stochastic_rounding_noise_fn,
)
if not for_dlhs and config.drhs_stochastic_rounding_noise_fn:
g_how = dataclasses.replace(
g_how,
noise_fn=config.drhs_stochastic_rounding_noise_fn,
)
g = qarray.quantize(g, g_how)

grad_res = dot_general.dot_general(g, y, bwd_dnums)
return jax.lax.transpose(grad_res, transpose_axes)
grad_res = jax.lax.transpose(grad_res, transpose_axes)
if isinstance(result_type, qarray_qt.QArrayWithGradient):
return dataclasses.replace(
result_type, qvalue=None, scale=None, zero_point=None, _grad=grad_res
)
else:
return grad_res

dlhs = _compute_gradient_for_operand(g, rhs, for_dlhs=True)
drhs = _compute_gradient_for_operand(g, lhs, for_dlhs=False)
Expand All @@ -244,15 +202,59 @@ def _compute_gradient_for_operand(


@functools.partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def dot_general_fwd_bwd(
lhs: jax.Array | qarray_qt.QArrayWithGradient,
rhs: jax.Array | qarray_qt.QArrayWithGradient,
dimension_numbers: jax.lax.DotDimensionNumbers,
config: DotGeneralQtConfig,
) -> jax.Array:
"""Quantized dot_general with backpropagation support."""
del config
return dot_general.dot_general(lhs, rhs, dimension_numbers)


dot_general_fwd_bwd.defvjp(dot_general_qt_fwd, dot_general_qt_bwd)


def dot_general_qt(
lhs: jax.Array,
rhs: jax.Array,
dimension_numbers: jax.lax.DotDimensionNumbers,
config: DotGeneralQtConfig,
) -> jax.Array:
"""Quantized dot_general with backpropagation support."""
result, _ = dot_general_qt_fwd(lhs, rhs, dimension_numbers, config)
return result
if config.lhs_qtype and numerics.should_quantize(lhs.dtype):
how = dot_general.get_how_to_quantize(
dimension_numbers=dimension_numbers,
ndims=(lhs.ndim, rhs.ndim),
for_lhs=True,
qtype=config.lhs_qtype,
tile_size=config.tile_size,
calibration_method=config.lhs_calibration_method,
)
if config.disable_channelwise_axes:
how = dataclasses.replace(how, channelwise_axes=[])

calibration = qarray.calibrate(lhs, how)
if config.lhs_collect_quant_stat:
calibration = config.lhs_collect_quant_stat(calibration)
lhs = qarray_qt.quantize_with_calibration(lhs, how.qtype, calibration)

if config.rhs_qtype and numerics.should_quantize(rhs.dtype):
how = dot_general.get_how_to_quantize(
dimension_numbers=dimension_numbers,
ndims=(lhs.ndim, rhs.ndim),
for_lhs=False,
qtype=config.rhs_qtype,
tile_size=config.tile_size,
calibration_method=config.rhs_calibration_method,
)
if config.disable_channelwise_axes:
how = dataclasses.replace(how, channelwise_axes=[])

calibration = qarray.calibrate(rhs, how)
if config.rhs_collect_quant_stat:
calibration = config.rhs_collect_quant_stat(calibration)
rhs = qarray_qt.quantize_with_calibration(rhs, how.qtype, calibration)

dot_general_qt.defvjp(dot_general_qt_fwd, dot_general_qt_bwd)
return dot_general_fwd_bwd(lhs, rhs, dimension_numbers, config)
79 changes: 79 additions & 0 deletions qwix/_src/core/qarray_qt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""QArray with gradient for custom VJP."""

import dataclasses
from typing import Mapping
import flax.struct
import jax
from qwix._src.core import qarray


@flax.struct.dataclass
class QArrayWithGradient(qarray.QArray):
"""QArray with gradient.

This dataclass allows us to associate a gradient with the QArray. It's
achieved by defining an extra attribute `_grad` on the QArray, which has the
same dtype and the same shape as the unquantized array. In forward pass, the
`_grad` does nothing and should never be consumed. In backward pass, the
`_grad` carries the gradient of the whole QArray.

This approach overcomes the Jax limitation on the gradients, i.e., the
gradient of a qvalue of int8[128,128] has to be float0[128,128], while the
gradient of a scale of float32[1,1] has to be float32[1,1]. An alternative
is to define the QArray as a new Hijax type, which is more complex.
"""

_grad: jax.Array = flax.struct.field(kw_only=True)


def quantize_with_calibration(
array: jax.Array,
qtype: jax.typing.DTypeLike,
calibration: Mapping[str, jax.Array],
clip_gradient: bool = False,
) -> QArrayWithGradient:
"""Quantizes an array with calibration with backpropagation support.

Args:
array: The array to quantize.
qtype: The quantized type.
calibration: The calibration of the array.
clip_gradient: Whether to clip the straight-through estimator to the
calibration range, i.e., the gradient outside the calibration range is 0.

Returns:
The quantized array with backpropagation support.
"""
scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype)
res = qarray.quantize_with_scale_zero_point(array, qtype, scale, zero_point)
if clip_gradient:
array = qarray.clip_to_calibration(
array, calibration, qarray.get_tiled_axes(res)
)
# Do not allow gradients on the quantized array to flow back to the input.
res = jax.lax.stop_gradient(res)
return QArrayWithGradient(**dataclasses.asdict(res), _grad=array)


@jax.custom_jvp
def dequantize(array: QArrayWithGradient) -> jax.Array:
"""Dequantizes an array."""
return qarray.dequantize(array)


@dequantize.defjvp
def _dequantize_jvp(primals, tangents):
return dequantize(*primals), tangents[0]._grad # pylint: disable=protected-access
9 changes: 1 addition & 8 deletions qwix/_src/providers/qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class QtRule(qconfig.QuantizationRule):

# In backward pass, quantize the gradients to the given type. This doesn't
# affect the residuals as the residuals will reuse the quantization in the
# forward pass, unless bwd_use_original_residuals is set.
# forward pass.
bwd_qtype: jax.typing.DTypeLike | None = None

# In backward pass, calibrate the gradients using the given method.
Expand All @@ -48,11 +48,6 @@ class QtRule(qconfig.QuantizationRule):
# If True, disable channelwise axes for both forward and backward passes.
disable_channelwise_axes: bool = False

# If True, use the original values instead of the quantized values as the
# residuals for backward pass. Enabling this prevents using low-precision
# matmuls during bwd pass and has a negative impact on performance.
bwd_use_original_residuals: bool = False

# Use stochastic rounding for the gradients. (Only 'uniform' is supported.)
bwd_stochastic_rounding: str | None = None

Expand Down Expand Up @@ -293,7 +288,6 @@ def _create_conv_general_qt_config(
drhs_grad_calibration_method=rule.bwd_calibration_method,
# misc.
disable_channelwise_axes=rule.disable_channelwise_axes,
bwd_use_original_residuals=rule.bwd_use_original_residuals,
)

def _create_dot_general_qt_config(
Expand Down Expand Up @@ -392,7 +386,6 @@ def _create_dot_general_qt_config(
drhs_grad_calibration_method=rule.bwd_calibration_method,
# misc.
disable_channelwise_axes=rule.disable_channelwise_axes,
bwd_use_original_residuals=rule.bwd_use_original_residuals,
dlhs_stochastic_rounding_noise_fn=dlhs_stochastic_rounding_noise_fn,
drhs_stochastic_rounding_noise_fn=drhs_stochastic_rounding_noise_fn,
)
Expand Down
Loading