From e59496647449f86482ebedbee22a1bec380ea65e Mon Sep 17 00:00:00 2001 From: Dangyi Liu Date: Thu, 23 Oct 2025 17:11:16 -0700 Subject: [PATCH] Allow gradients on QArray This patch introduces a new approach to associate gradients with QArray. The outcome is that we can define vjp rules for `quantize` and `dot_general` separately while ensuring a correct backward pass. PiperOrigin-RevId: 823246141 --- qwix/_src/core/conv_general_qt.py | 5 +- qwix/_src/core/dot_general_qt.py | 142 +++++++++++++++--------------- qwix/_src/core/qarray_qt.py | 79 +++++++++++++++++ qwix/_src/providers/qt.py | 9 +- tests/core/qarray_qt_test.py | 40 +++++++++ tests/providers/qt_test.py | 24 ++--- 6 files changed, 206 insertions(+), 93 deletions(-) create mode 100644 qwix/_src/core/qarray_qt.py create mode 100644 tests/core/qarray_qt_test.py diff --git a/qwix/_src/core/conv_general_qt.py b/qwix/_src/core/conv_general_qt.py index b2e8147d..77ba7396 100644 --- a/qwix/_src/core/conv_general_qt.py +++ b/qwix/_src/core/conv_general_qt.py @@ -47,7 +47,6 @@ class ConvGeneralQtConfig: # Misc. disable_channelwise_axes: bool = False - bwd_use_original_residuals: bool = False # Swaps the first two dimension indices of a specification. 
@@ -187,11 +186,8 @@ def _quantize_operand( operand, qtype, scale, zero_point ) - residuals = (lhs, rhs) lhs = _quantize_operand(lhs, for_lhs=True) rhs = _quantize_operand(rhs, for_lhs=False) - if not config.bwd_use_original_residuals: - residuals = (lhs, rhs) primal_out = conv_general.conv_general_dilated( lhs, @@ -204,6 +200,7 @@ def _quantize_operand( feature_group_count, batch_group_count, ) + residuals = (lhs, rhs) return primal_out, residuals diff --git a/qwix/_src/core/dot_general_qt.py b/qwix/_src/core/dot_general_qt.py index 0469e318..600390cb 100644 --- a/qwix/_src/core/dot_general_qt.py +++ b/qwix/_src/core/dot_general_qt.py @@ -24,6 +24,7 @@ from qwix._src.core import dot_general from qwix._src.core import numerics from qwix._src.core import qarray +from qwix._src.core import qarray_qt @dataclasses.dataclass(slots=True, frozen=True, kw_only=True) @@ -43,19 +44,16 @@ class DotGeneralQtConfig: dlhs_grad_qtype: jax.typing.DTypeLike | None = None # incoming gradient dlhs_grad_calibration_method: str = 'absmax' dlhs_tile_size: int | float | None = None + dlhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None # Backward pass (drhs). drhs_grad_qtype: jax.typing.DTypeLike | None = None # incoming gradient drhs_grad_calibration_method: str = 'absmax' drhs_tile_size: int | float | None = None + drhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None # Misc. disable_channelwise_axes: bool = False - bwd_use_original_residuals: bool = False # what to use as residuals - - # Configs for stochastic rounding. - dlhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None - drhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None def _ranges_like(*xs): @@ -124,64 +122,23 @@ def _apply_rhs_scale_to_lhs(lhs, rhs_scale, dnums): # disable interceptions for dot_general_qt_fwd. 
@interception.disable_interceptions def dot_general_qt_fwd( - lhs: jax.Array, - rhs: jax.Array, + lhs: jax.Array | qarray_qt.QArrayWithGradient, + rhs: jax.Array | qarray_qt.QArrayWithGradient, dimension_numbers: jax.lax.DotDimensionNumbers, config: DotGeneralQtConfig, ): """Forward pass for dot_general_qt custom VJP.""" - ndims = (lhs.ndim, rhs.ndim) - - def _quantize_operand(operand: jax.Array, is_lhs: bool) -> qarray.MaybeQArray: - """Quantizes a single operand for the forward pass if configured to do so.""" - if is_lhs: - qtype = config.lhs_qtype - calibration_method = config.lhs_calibration_method - collect_quant_stat = config.lhs_collect_quant_stat - else: - qtype = config.rhs_qtype - calibration_method = config.rhs_calibration_method - collect_quant_stat = config.rhs_collect_quant_stat - - if not (qtype and numerics.should_quantize(operand.dtype)): - return operand - - how = dot_general.get_how_to_quantize( - dimension_numbers=dimension_numbers, - ndims=ndims, - for_lhs=is_lhs, - qtype=qtype, - tile_size=config.tile_size, - calibration_method=calibration_method, - ) - if config.disable_channelwise_axes: - how = dataclasses.replace(how, channelwise_axes=[]) - - calibration = qarray.calibrate(operand, how) - if collect_quant_stat: - calibration = collect_quant_stat(calibration) - scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype) - return qarray.quantize_with_scale_zero_point( - operand, how.qtype, scale, zero_point - ) - - qlhs = _quantize_operand(lhs, is_lhs=True) - qrhs = _quantize_operand(rhs, is_lhs=False) - - primal_out = dot_general.dot_general(qlhs, qrhs, dimension_numbers) - - if config.bwd_use_original_residuals: - residuals = (lhs, rhs) - else: - residuals = (qlhs, qrhs) - - return primal_out, residuals + del config + return dot_general.dot_general(lhs, rhs, dimension_numbers), (lhs, rhs) def dot_general_qt_bwd( fwd_dimension_numbers: jax.lax.DotDimensionNumbers, config: DotGeneralQtConfig, - residuals: tuple[qarray.MaybeQArray, 
qarray.MaybeQArray], + residuals: tuple[ + jax.Array | qarray_qt.QArrayWithGradient, + jax.Array | qarray_qt.QArrayWithGradient, + ], g: jax.Array, ): """Backward pass for dot_general_qt custom VJP.""" @@ -189,8 +146,8 @@ def dot_general_qt_bwd( def _compute_gradient_for_operand( g: jax.Array, y: qarray.MaybeQArray, *, for_dlhs: bool - ): - """Compute dot_general for gradient and other_fwd_operand.""" + ) -> jax.Array | qarray_qt.QArrayWithGradient: + """Compute dx from g and y.""" bwd_dnums, transpose_axes = _update_dimension_numbers_for_backward( fwd_dimension_numbers, (lhs.ndim, rhs.ndim), for_dlhs=for_dlhs ) @@ -198,10 +155,14 @@ def _compute_gradient_for_operand( g_qtype = config.dlhs_grad_qtype g_tile_size = config.dlhs_tile_size g_calibration_method = config.dlhs_grad_calibration_method + g_noise_fn = config.dlhs_stochastic_rounding_noise_fn + result_type = lhs # the result gradient must match this type. else: g_qtype = config.drhs_grad_qtype g_tile_size = config.drhs_tile_size g_calibration_method = config.drhs_grad_calibration_method + g_noise_fn = config.drhs_stochastic_rounding_noise_fn + result_type = rhs # the result gradient must match this type. 
if g_qtype and numerics.should_quantize(g.dtype): if isinstance(y, qarray.QArray) and not qarray.get_tiled_axes(y): @@ -219,23 +180,20 @@ def _compute_gradient_for_operand( tile_size=g_tile_size, calibration_method=g_calibration_method, ) + g_how = dataclasses.replace(g_how, noise_fn=g_noise_fn) if config.disable_channelwise_axes: g_how = dataclasses.replace(g_how, channelwise_axes=[]) - if for_dlhs and config.dlhs_stochastic_rounding_noise_fn: - g_how = dataclasses.replace( - g_how, - noise_fn=config.dlhs_stochastic_rounding_noise_fn, - ) - if not for_dlhs and config.drhs_stochastic_rounding_noise_fn: - g_how = dataclasses.replace( - g_how, - noise_fn=config.drhs_stochastic_rounding_noise_fn, - ) g = qarray.quantize(g, g_how) grad_res = dot_general.dot_general(g, y, bwd_dnums) - return jax.lax.transpose(grad_res, transpose_axes) + grad_res = jax.lax.transpose(grad_res, transpose_axes) + if isinstance(result_type, qarray_qt.QArrayWithGradient): + return dataclasses.replace( + result_type, qvalue=None, scale=None, zero_point=None, _grad=grad_res + ) + else: + return grad_res dlhs = _compute_gradient_for_operand(g, rhs, for_dlhs=True) drhs = _compute_gradient_for_operand(g, lhs, for_dlhs=False) @@ -244,6 +202,20 @@ def _compute_gradient_for_operand( @functools.partial(jax.custom_vjp, nondiff_argnums=(2, 3)) +def dot_general_fwd_bwd( + lhs: jax.Array | qarray_qt.QArrayWithGradient, + rhs: jax.Array | qarray_qt.QArrayWithGradient, + dimension_numbers: jax.lax.DotDimensionNumbers, + config: DotGeneralQtConfig, +) -> jax.Array: + """Quantized dot_general with backpropagation support.""" + del config + return dot_general.dot_general(lhs, rhs, dimension_numbers) + + +dot_general_fwd_bwd.defvjp(dot_general_qt_fwd, dot_general_qt_bwd) + + def dot_general_qt( lhs: jax.Array, rhs: jax.Array, @@ -251,8 +223,38 @@ def dot_general_qt( config: DotGeneralQtConfig, ) -> jax.Array: """Quantized dot_general with backpropagation support.""" - result, _ = dot_general_qt_fwd(lhs, rhs, 
dimension_numbers, config) - return result + if config.lhs_qtype and numerics.should_quantize(lhs.dtype): + how = dot_general.get_how_to_quantize( + dimension_numbers=dimension_numbers, + ndims=(lhs.ndim, rhs.ndim), + for_lhs=True, + qtype=config.lhs_qtype, + tile_size=config.tile_size, + calibration_method=config.lhs_calibration_method, + ) + if config.disable_channelwise_axes: + how = dataclasses.replace(how, channelwise_axes=[]) + + calibration = qarray.calibrate(lhs, how) + if config.lhs_collect_quant_stat: + calibration = config.lhs_collect_quant_stat(calibration) + lhs = qarray_qt.quantize_with_calibration(lhs, how.qtype, calibration) + + if config.rhs_qtype and numerics.should_quantize(rhs.dtype): + how = dot_general.get_how_to_quantize( + dimension_numbers=dimension_numbers, + ndims=(lhs.ndim, rhs.ndim), + for_lhs=False, + qtype=config.rhs_qtype, + tile_size=config.tile_size, + calibration_method=config.rhs_calibration_method, + ) + if config.disable_channelwise_axes: + how = dataclasses.replace(how, channelwise_axes=[]) + calibration = qarray.calibrate(rhs, how) + if config.rhs_collect_quant_stat: + calibration = config.rhs_collect_quant_stat(calibration) + rhs = qarray_qt.quantize_with_calibration(rhs, how.qtype, calibration) -dot_general_qt.defvjp(dot_general_qt_fwd, dot_general_qt_bwd) + return dot_general_fwd_bwd(lhs, rhs, dimension_numbers, config) diff --git a/qwix/_src/core/qarray_qt.py b/qwix/_src/core/qarray_qt.py new file mode 100644 index 00000000..e8910bc1 --- /dev/null +++ b/qwix/_src/core/qarray_qt.py @@ -0,0 +1,79 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""QArray with gradient for custom VJP.""" + +import dataclasses +from typing import Mapping +import flax.struct +import jax +from qwix._src.core import qarray + + +@flax.struct.dataclass +class QArrayWithGradient(qarray.QArray): + """QArray with gradient. + + This dataclass allows us to associate a gradient with the QArray. It's + achieved by defining an extra attribute `_grad` on the QArray, which has the + same dtype and the same shape as the unquantized array. In forward pass, the + `_grad` does nothing and should never be consumed. In backward pass, the + `_grad` carries the gradient of the whole QArray. + + This approach overcomes the Jax limitation on the gradients, i.e., the + gradient of a qvalue of int8[128,128] has to be float0[128,128], while the + gradient of a scale of float32[1,1] has to be float32[1,1]. An alternative + is to define the QArray as a new Hijax type, which is more complex. + """ + + _grad: jax.Array = flax.struct.field(kw_only=True) + + +def quantize_with_calibration( + array: jax.Array, + qtype: jax.typing.DTypeLike, + calibration: Mapping[str, jax.Array], + clip_gradient: bool = False, +) -> QArrayWithGradient: + """Quantizes an array with calibration with backpropagation support. + + Args: + array: The array to quantize. + qtype: The quantized type. + calibration: The calibration of the array. + clip_gradient: Whether to clip the straight-through estimator to the + calibration range, i.e., the gradient outside the calibration range is 0. + + Returns: + The quantized array with backpropagation support. 
+ """ + scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype) + res = qarray.quantize_with_scale_zero_point(array, qtype, scale, zero_point) + if clip_gradient: + array = qarray.clip_to_calibration( + array, calibration, qarray.get_tiled_axes(res) + ) + # Do not allow gradients on the quantized array to flow back to the input. + res = jax.lax.stop_gradient(res) + return QArrayWithGradient(**dataclasses.asdict(res), _grad=array) + + +@jax.custom_jvp +def dequantize(array: QArrayWithGradient) -> jax.Array: + """Dequantizes an array.""" + return qarray.dequantize(array) + + +@dequantize.defjvp +def _dequantize_jvp(primals, tangents): + return dequantize(*primals), tangents[0]._grad # pylint: disable=protected-access diff --git a/qwix/_src/providers/qt.py b/qwix/_src/providers/qt.py index 1f0d0cd3..d03e2741 100644 --- a/qwix/_src/providers/qt.py +++ b/qwix/_src/providers/qt.py @@ -34,7 +34,7 @@ class QtRule(qconfig.QuantizationRule): # In backward pass, quantize the gradients to the given type. This doesn't # affect the residuals as the residuals will reuse the quantization in the - # forward pass, unless bwd_use_original_residuals is set. + # forward pass. bwd_qtype: jax.typing.DTypeLike | None = None # In backward pass, calibrate the gradients using the given method. @@ -48,11 +48,6 @@ class QtRule(qconfig.QuantizationRule): # If True, disable channelwise axes for both forward and backward passes. disable_channelwise_axes: bool = False - # If True, use the original values instead of the quantized values as the - # residuals for backward pass. Enabling this prevents using low-precision - # matmuls during bwd pass and has a negative impact on performance. - bwd_use_original_residuals: bool = False - # Use stochastic rounding for the gradients. (Only 'uniform' is supported.) bwd_stochastic_rounding: str | None = None @@ -293,7 +288,6 @@ def _create_conv_general_qt_config( drhs_grad_calibration_method=rule.bwd_calibration_method, # misc. 
disable_channelwise_axes=rule.disable_channelwise_axes, - bwd_use_original_residuals=rule.bwd_use_original_residuals, ) def _create_dot_general_qt_config( @@ -392,7 +386,6 @@ def _create_dot_general_qt_config( drhs_grad_calibration_method=rule.bwd_calibration_method, # misc. disable_channelwise_axes=rule.disable_channelwise_axes, - bwd_use_original_residuals=rule.bwd_use_original_residuals, dlhs_stochastic_rounding_noise_fn=dlhs_stochastic_rounding_noise_fn, drhs_stochastic_rounding_noise_fn=drhs_stochastic_rounding_noise_fn, ) diff --git a/tests/core/qarray_qt_test.py b/tests/core/qarray_qt_test.py new file mode 100644 index 00000000..87979372 --- /dev/null +++ b/tests/core/qarray_qt_test.py @@ -0,0 +1,40 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import numpy as jnp +from qwix._src.core import qarray +from qwix._src.core import qarray_qt + + +class QArrayQtTest(parameterized.TestCase): + + def test_qarray_with_gradient(self): + x = jnp.ones((3, 3), jnp.float32) + + def fake_quant_sum(x): + how = qarray.HowToQuantize(qtype=jnp.int8) + x = qarray_qt.quantize_with_calibration( + x, how.qtype, qarray.calibrate(x, how) + ) + x = qarray_qt.dequantize(x) + return jnp.sum(x) + + self.assertTrue((jax.grad(fake_quant_sum)(x) == x).all()) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/providers/qt_test.py b/tests/providers/qt_test.py index 94adca3f..df3ced3a 100644 --- a/tests/providers/qt_test.py +++ b/tests/providers/qt_test.py @@ -65,19 +65,21 @@ def loss_fn(params): self.assertEqual(quant_stats["dot_general0_lhs"]["count"], 1) def test_srq_jit_grad_nnx(self): - """Test SRQ on NNX module.""" - linear = nnx.Linear(12, 10, rngs=nnx.Rngs(0), param_dtype=jnp.bfloat16) - qt_provider = qt.QtProvider([ - qconfig.QuantizationRule( - module_path=".*", - weight_qtype=jnp.int8, - act_qtype=jnp.int8, - act_static_scale=True, - ), - ]) + """Test creating and training an SRQ NNX model inside jit.""" + + def create_srq_nnx_model(model_input): + linear = nnx.Linear(12, 10, rngs=nnx.Rngs(0), param_dtype=jnp.bfloat16) + qt_provider = qt.QtProvider([ + qconfig.QuantizationRule( + weight_qtype=jnp.int8, + act_qtype=jnp.int8, + act_static_scale=True, + ), + ]) + return qwix_model.quantize_model(linear, qt_provider, model_input) model_input = jnp.ones((10, 12), dtype=jnp.float32) - qt_linear = qwix_model.quantize_model(linear, qt_provider, model_input) + qt_linear = nnx.jit(create_srq_nnx_model)(model_input) quant_stats = nnx.variables(qt_linear, flax_util.QuantStat) # quant_stats should be initialized but empty.