From b9837a35343458c4dee0e023e93564c96f7872c2 Mon Sep 17 00:00:00 2001 From: James Chapman Date: Fri, 10 Apr 2026 14:24:46 -0700 Subject: [PATCH] Add initial structure for HiJAX integration in qwix. This change introduces the `qwix/contrib/hijax` directory, which will house an experimental implementation of QArrays using HiJAX. The README explains the motivation for using HiJAX to gain finer control at the jaxpr level. The CL includes: - A new BUILD file. - `hiqarray_common.py`: Defines `QuantizationMetadata` to manage shapes and axes for quantization. - `hiquant_utils.py`: Provides utility functions for handling dtypes in quantization. - `hiqarray_common_test.py`: Tests for the shape and axis computation in `QuantizationMetadata`. - A detailed README.md file with motivation and implementation notes. PiperOrigin-RevId: 897887809 --- qwix/contrib/hijax/README.md | 55 +++++ qwix/contrib/hijax/hiqarray_common.py | 224 ++++++++++++++++++++ qwix/contrib/hijax/hiquant_utils.py | 71 +++++++ tests/contrib/hijax/hiqarray_common_test.py | 106 +++++++++ 4 files changed, 456 insertions(+) create mode 100644 qwix/contrib/hijax/README.md create mode 100644 qwix/contrib/hijax/hiqarray_common.py create mode 100644 qwix/contrib/hijax/hiquant_utils.py create mode 100644 tests/contrib/hijax/hiqarray_common_test.py diff --git a/qwix/contrib/hijax/README.md b/qwix/contrib/hijax/README.md new file mode 100644 index 0000000..b859e34 --- /dev/null +++ b/qwix/contrib/hijax/README.md @@ -0,0 +1,55 @@ +# HiJAX + +This folder contains code update to `QArray`s to use HiJAX. +HiJAX is a relatively new jax feature that allows the user to create +types which persist in a `jaxpr`. This gives finer control over how jax +deals with `QArray`s. + +Note that this implementation is still in an experimental phase and +everything in this folder is subject to change without notice. 
After hitting
+feature parity and implementing all desired functionality, we will slowly
+migrate this implementation into the `qwix/_src/core` directory.
+
+## Motivation
+
+The overall goal is to integrate `QArray`s more closely with jax and provide
+better support at the `jaxpr` level. This gives finer-grained control over
+how jax deals with `QArray`s.
+
+Let `Low`, `Hi` denote low- and high-precision types, respectively. We
+typically think of `Low` as a type where differentiation doesn't make sense
+(e.g. integers). The cotangent type of `Low` is `Float0` (a trivial type).
+
+The current `QArray` type isn't Array-like in jax.
+
+- `QArray` is a pytree that looks roughly like
+  `tuple[Array[Low], Array[Hi], Array[Hi] | None]`
+- The cotangent type of this representation is
+  `tuple[Array[Float0], Array[Hi], Array[Hi]]`.
+  Refer to
+  [jax.dtypes.float0](https://docs.jax.dev/en/latest/_autosummary/jax.dtypes.float0.html)
+  for more information.
+- Techniques like the straight-through estimator assume that the cotangents are
+  stored as `Array`s or `QArray`s where the data is non-trivial.
+- Using HiJAX, we can define the cotangent types of `QArray`s.
+- Goal: Reduce reliance on `custom_vjp` for large functions.
+
+More motivation to come!
+We seek to simplify the following:
+
+- Kernel integration
+- Autograd semantics
+- Integration with advanced jax features
+
+## Some Implementation Notes
+
+The current implementation attempts to minimize crossover between the core
+library and the hijax implementation. This way we can add features without
+impacting current qwix users.
+
+This implementation uses the naming `HiQArray` for the hijax implementation of
+a `QArray`.
+
+- `hiqarray_common`: Common jax-based functions and dataclasses for `HiQArray`.
+- `hiquant_utils`: Common jax utilities for use in `HiQArray`.
+- `hiqarray`: Coming soon.
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The code in this file implements functionality common to all HiQArrays.

This comes below the QArray abstraction and operates at the Array level.
"""

import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class QuantizationMetadata:
  """This class contains information used to quantize and dequantize an array.

  If zero_point is not None, then we assume that the scale and zero_point have
  the same shape.
  """

  # Axes along which the array is quantized and the group size used on each
  # axis. A group size of -1 (or equal to the dimension) means the whole axis
  # forms a single group.
  quant_axes: tuple[int, ...]
  group_sizes: tuple[int, ...]

  # Shape of the original array and of the quantization parameters.
  data_shape: tuple[int, ...]
  quant_shape: tuple[int, ...]

  # Intermediate shapes for performing reductions: the data reshaped so every
  # tiled group gets its own trailing axis, and the broadcast-compatible shape
  # of the quantization parameters.
  data_compatible_shape: tuple[int, ...]
  quant_compatible_shape: tuple[int, ...]

  # Axes for reductions. Used for internal calculations.
  _tiled_reduction_axes: tuple[int, ...]
  _full_reduction_axes: tuple[int, ...]

  # Dtypes of the original and the quantized array.
  dtype: jnp.dtype
  qtype: jnp.dtype

  @classmethod
  def init(
      cls,
      data_shape: tuple[int, ...],
      quant_info: dict[int, int],
      original_dtype: jnp.dtype,
      quantized_dtype: jnp.dtype,
  ) -> "QuantizationMetadata":
    """Initializes the quantization metadata for an array.

    Args:
      data_shape: The shape of the original array.
      quant_info: A dictionary of quantization axes to group sizes. Axes may
        be negative, in which case they are interpreted relative to the end of
        `data_shape` (numpy convention).
      original_dtype: The dtype of the original array.
      quantized_dtype: The dtype of the quantized array.

    Returns:
      The quantization metadata for the array.

    Raises:
      ValueError: If the quantization axes are duplicated or out of range.
    """
    # Sort the quant_axes and make them non-negative. Negative axes are
    # normalized against the rank of the data array. (Previously this used
    # len(quant_info), which mis-mapped negative axes whenever the number of
    # quantized axes differed from the array rank.)
    ndim = len(data_shape)
    sorted_quant_info = sorted(
        (q if q >= 0 else q + ndim, g) for q, g in quant_info.items()
    )
    quant_axes = tuple(x[0] for x in sorted_quant_info)
    group_sizes = tuple(x[1] for x in sorted_quant_info)

    if len(quant_axes) != len(set(quant_axes)):
      raise ValueError("Quantization axes must be unique and non-negative.")
    if any(not 0 <= q < ndim for q in quant_axes):
      raise ValueError(
          f"Quantization axes {tuple(quant_info)} are out of range for an"
          f" array of rank {ndim}."
      )

    # Construct information about the unquantized array.
    (
        data_compatible_shape,
        quant_compatible_shape,
        tiled_reduction_axes,
        full_reduction_axes,
    ) = QuantizationMetadata._get_reduction_shape_and_axes(
        data_shape, quant_axes, group_sizes
    )
    quant_shape = QuantizationMetadata._get_quant_shape(
        data_compatible_shape, tiled_reduction_axes, full_reduction_axes
    )
    return cls(
        quant_axes,
        group_sizes,
        data_shape,
        quant_shape,
        data_compatible_shape,
        quant_compatible_shape,
        tiled_reduction_axes,
        full_reduction_axes,
        original_dtype,
        quantized_dtype,
    )

  @staticmethod
  def _get_reduction_shape_and_axes(
      original_shape: tuple[int, ...],
      quant_axes: tuple[int, ...],
      group_sizes: tuple[int, ...],
  ) -> tuple[
      tuple[int, ...],
      tuple[int, ...],
      tuple[int, ...],
      tuple[int, ...],
  ]:
    """Returns the intermediate shape needed for performing reductions as well as the axes along which to reduce.

    Assumes:
      - quant_axes is sorted and non-negative.
      - group_sizes has the same length as quant_axes.

    Args:
      original_shape: The shape of the original array.
      quant_axes: A tuple of axes to quantize.
      group_sizes: A tuple of group sizes for each quantization axis.

    Returns:
      A tuple containing:
      - data_compatible_shape: The shape of the data array compatible with
        the reduction operations.
      - quant_compatible_shape: The shape of the quantization parameters
        compatible with broadcasting against the data.
      - tiled_reduction_axes: Axes in `data_compatible_shape` that correspond
        to tiled groups and should be reduced.
      - full_reduction_axes: Axes whose whole dimension forms a single group
        and should be reduced to size 1. These are indices into
        `original_shape` (equivalently, into the shape obtained after the
        tiled axes are removed from `data_compatible_shape`), NOT into
        `data_compatible_shape` itself.

    Example:
      original_shape = (16, 32, 64)
      quant_axes = (0, 1)
      group_sizes = (2, -1)
      data_compatible_shape = (8, 2, 32, 64)
      quant_compatible_shape = (8, 1, 1, 64)
      tiled_reduction_axes = (1,)
      full_reduction_axes = (1,)
    """
    # intermediate shape used for later reductions
    data_compatible_shape = []
    quant_compatible_shape = []

    # axes which have been tiled and should be reduced away
    tiled_reduction_axes = []

    # An axis is a "full" reduction when its group covers the entire
    # dimension, i.e. the group size is -1 or equals the dimension size.
    # Note: original_shape must be indexed by the axis q itself, not by the
    # position of q within quant_axes (the previous code used the position,
    # which missed full reductions on non-leading axes).
    full_reduction_axes = tuple(
        q
        for q, g in zip(quant_axes, group_sizes)
        if g == -1 or g == original_shape[q]
    )

    qi = 0  # index of the next quantization axis to handle
    for i, xi in enumerate(original_shape):
      if qi >= len(quant_axes):
        # All quantization axes handled; remaining dims pass through.
        data_compatible_shape.append(xi)
        quant_compatible_shape.append(xi)
        continue
      if i == quant_axes[qi]:
        gs = group_sizes[qi]
        if gs == -1 or gs == xi:
          # Whole-axis group: the quantization parameter broadcasts (size 1).
          data_compatible_shape.append(xi)
          quant_compatible_shape.append(1)
        else:
          assert xi % gs == 0, "Group size must divide dimension size"
          # Split the axis into (num_groups, group_size); the trailing
          # group-size axis is reduced away later.
          data_compatible_shape.append(xi // gs)
          data_compatible_shape.append(gs)
          quant_compatible_shape.append(xi // gs)
          quant_compatible_shape.append(1)
          tiled_reduction_axes.append(len(data_compatible_shape) - 1)
        qi += 1
      else:
        data_compatible_shape.append(xi)
        quant_compatible_shape.append(xi)

    return (
        tuple(data_compatible_shape),
        tuple(quant_compatible_shape),
        tuple(tiled_reduction_axes),
        full_reduction_axes,
    )

  @staticmethod
  def _get_quant_shape(
      intermediate_shape: tuple[int, ...],
      tiled_reduction_axes: tuple[int, ...],
      full_reduction_axes: tuple[int, ...],
  ) -> tuple[int, ...]:
    """Returns the shape of the tensor after quantization.

    First drops the tiled group axes from `intermediate_shape` (restoring the
    original array's rank, so `full_reduction_axes` — which are in
    original-shape coordinates — index it correctly), then collapses every
    full-reduction axis to size 1.
    """
    # Remove the tiled group axes; the result has the original array's rank.
    tmp_shape = [
        xi
        for i, xi in enumerate(intermediate_shape)
        if i not in tiled_reduction_axes
    ]
    # Whole-axis groups reduce to a single quantization parameter.
    return tuple(
        1 if i in full_reduction_axes else xi for i, xi in enumerate(tmp_shape)
    )

  def __repr__(self):
    quant_axes = self.quant_axes
    group_sizes = self.group_sizes
    orig_dtype = self.dtype
    quant_dtype = self.qtype
    out = (
        f"QuantizationMetadata({quant_axes=}, {group_sizes=}, {orig_dtype=},"
        f" {quant_dtype=})"
    )
    return out
+"""Utility functions for quantization.""" + +import jax.numpy as jnp + +_INTEGER_DTYPES = set([ + jnp.int8, + jnp.int16, + jnp.int32, + jnp.int64, + jnp.uint8, + jnp.uint16, + jnp.uint32, + jnp.uint64, + jnp.dtype("int8"), + jnp.dtype("int16"), + jnp.dtype("int32"), + jnp.dtype("int64"), + jnp.dtype("uint8"), + jnp.dtype("uint16"), + jnp.dtype("uint32"), + jnp.dtype("uint64"), +]) +_FLOAT_DTYPES = set([ + jnp.bfloat16, + jnp.float16, + jnp.float32, + jnp.float64, + jnp.dtype("bfloat16"), + jnp.dtype("float16"), + jnp.dtype("float32"), + jnp.dtype("float64"), +]) + + +def get_bitwidth(dtype: jnp.dtype) -> int: + if dtype in _INTEGER_DTYPES: + return jnp.iinfo(dtype).bits + elif dtype in _FLOAT_DTYPES: + return jnp.finfo(dtype).bits + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def get_accumulation_dtype(dtype: jnp.dtype) -> jnp.dtype: + if dtype in _INTEGER_DTYPES: + return jnp.int32 + elif dtype in _FLOAT_DTYPES: + return jnp.float32 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def is_integer_dtype(dtype: jnp.dtype) -> bool: + return dtype in _INTEGER_DTYPES + + +def is_float_dtype(dtype: jnp.dtype) -> bool: + return dtype in _FLOAT_DTYPES diff --git a/tests/contrib/hijax/hiqarray_common_test.py b/tests/contrib/hijax/hiqarray_common_test.py new file mode 100644 index 0000000..b78f777 --- /dev/null +++ b/tests/contrib/hijax/hiqarray_common_test.py @@ -0,0 +1,106 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
import jax.numpy as jnp
from qwix.contrib.hijax import hiqarray_common as hqc


class QuantizationMetadataTest(absltest.TestCase):

  def test_get_shape_and_axes_1(self):
    # Group size -1 on every axis: nothing is tiled, every axis fully reduces.
    arr = jnp.ones((16, 32, 64))
    data_shape, scale_shape, tiled_axes, full_axes = (
        hqc.QuantizationMetadata._get_reduction_shape_and_axes(
            arr.shape, (0, 1, 2), (-1, -1, -1)
        )
    )
    self.assertEqual(data_shape, (16, 32, 64))
    self.assertEqual(scale_shape, (1, 1, 1))
    self.assertEmpty(tiled_axes)
    self.assertEqual(full_axes, (0, 1, 2))

  def test_get_shape_and_axes_2(self):
    # Every axis tiled: each axis splits into (num_groups, group_size).
    arr = jnp.ones((16, 32, 64))
    data_shape, scale_shape, tiled_axes, full_axes = (
        hqc.QuantizationMetadata._get_reduction_shape_and_axes(
            arr.shape, (0, 1, 2), (2, 4, 8)
        )
    )
    self.assertEqual(data_shape, (8, 2, 8, 4, 8, 8))
    self.assertEqual(scale_shape, (8, 1, 8, 1, 8, 1))
    self.assertEqual(tiled_axes, (1, 3, 5))
    self.assertEmpty(full_axes)

  def test_get_shape_and_axes_3(self):
    # Mixed: axis 0 fully reduced, axis 1 tiled with group size 4.
    shape = (16, 32, 64)
    data_shape, scale_shape, tiled_axes, full_axes = (
        hqc.QuantizationMetadata._get_reduction_shape_and_axes(
            shape, (0, 1), (-1, 4)
        )
    )
    self.assertEqual(data_shape, (16, 8, 4, 64))
    self.assertEqual(scale_shape, (1, 8, 1, 64))
    self.assertEqual(tiled_axes, (2,))
    self.assertEqual(full_axes, (0,))

    metadata = hqc.QuantizationMetadata.init(
        shape, {0: -1, 1: 4}, jnp.float32, jnp.float32
    )

    # The quant-compatible shape must broadcast against the data-compatible
    # shape: every dim either matches or is 1.
    q_shape = metadata.quant_compatible_shape
    self.assertEqual(len(data_shape), len(q_shape))
    self.assertTrue(
        all(ti == qi or qi == 1 for ti, qi in zip(data_shape, q_shape))
    )

  def test_get_shape_and_axes_4(self):
    # Verify the example in the _get_reduction_shape_and_axes docstring.
    qmd = hqc.QuantizationMetadata.init(
        (16, 32, 64), {0: 2, 1: -1}, jnp.float32, jnp.float32
    )
    self.assertEqual(qmd.data_compatible_shape, (8, 2, 32, 64))
    self.assertEqual(qmd.quant_compatible_shape, (8, 1, 1, 64))
    self.assertEqual(qmd._tiled_reduction_axes, (1,))
    self.assertEqual(qmd._full_reduction_axes, (1,))

  def test_get_quant_shape(self):
    # Full reduction on every axis collapses the quant shape to all ones.
    quant_shape = hqc.QuantizationMetadata._get_quant_shape(
        (16, 32, 64), (), (0, 1, 2)
    )
    self.assertEqual(quant_shape, (1, 1, 1))


if __name__ == "__main__":
  absltest.main()