From 7d03ff80757fd500d46640164634a4586bdd21c8 Mon Sep 17 00:00:00 2001 From: Dedalien Date: Thu, 23 Apr 2026 14:25:57 +0200 Subject: [PATCH 1/2] Add TurboQuant3/4 modes to quantized_scaled_dot_product_attention Integrates TurboQuant (arXiv 2504.19874) into the generic quant-SDPA infrastructure from #3026, rather than as a standalone kernel. New modes "turbo3" (3-bit) and "turbo4" (4-bit) use Lloyd-Max codebooks for the N(0,1) distribution of WHT-rotated, norm-normalised key coordinates. Codebooks are compile-time Metal constants; no runtime buffer needed. Memory layout: - K/V: uint32-packed bit indices, same PackReader path as affine - k_scales/v_scales: per-vector float16 norms / sqrt(D) - group_size == head_dim (one scale per full D-element vector) - Queries pre-rotated by caller via WHT Changes: - primitives.h/cpp: TurboQuant3/TurboQuant4 in QuantizationMode enum, string parsing, quantization_mode_to_string explicit cases - fast.cpp: TurboQuant validation branch (no biases, float scales, group_size == head_dim, head_dim in {64,128,256}, GPU-only fallback) - quantized_utils.h: QuantMode enum extension, QuantConfig/Dequant specialisations, Lloyd-Max codebook constants - sdpa_vector.h: static_assert extended for bits=3; QUANT_SDPA_DISPATCH entries for D in {64,128,256} gated on if constexpr (D >= group_size) - scaled_dot_product_attention.cpp: quant_mode_to_int (TurboQuant3=4, TurboQuant4=5) - python/src/fast.cpp: turbo3/turbo4 in mode docstring Tested on Qwen3.6-27B (head_dim=256, 24Q/4KV, GQA=6), 24 GB unified memory Mac. Enables longer context generation on memory-constrained hardware by compressing the KV cache ~5x. --- mlx/backend/metal/kernels/quantized_utils.h | 86 ++++++++- mlx/backend/metal/kernels/sdpa_vector.h | 18 +- .../metal/scaled_dot_product_attention.cpp | 4 + mlx/fast.cpp | 61 ++++++- mlx/primitives.cpp | 18 +- mlx/primitives.h | 9 +- python/src/fast.cpp | 2 +- python/tests/test_quantized.py | 167 ++++++++++++++++++ 8 files changed, 357 insertions(+), 8 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h index 86aa8b75d0..9484391e48 100644 --- a/mlx/backend/metal/kernels/quantized_utils.h +++ b/mlx/backend/metal/kernels/quantized_utils.h @@ -8,7 +8,7 @@ #include "mlx/backend/metal/kernels/fp4.h" #include "mlx/backend/metal/kernels/fp8.h" -enum class QuantMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; +enum class QuantMode { Affine, Mxfp4, Mxfp8, Nvfp4, TurboQuant3, TurboQuant4 }; template struct DecodeValue { @@ -72,6 +72,64 @@ struct QuantConfig { using scale_storage_t = uint8_t; }; +// TurboQuant: codebook-based quantization with per-vector float scales. +// Keys/values are packed bit indices; scales are per-vector L2 norms / sqrt(D). +// Codebooks are Lloyd-Max optimal for N(0,1) (distribution of rotated, +// norm-normalized key coordinates scaled by sqrt(D)). +template <> +struct QuantConfig { + static constant constexpr bool has_bias = false; + + using value_type = void; + using scale_type = void; + + template + using scale_storage_t = T; +}; + +template <> +struct QuantConfig { + static constant constexpr bool has_bias = false; + + using value_type = void; + using scale_type = void; + + template + using scale_storage_t = T; +}; + +// N(0,1) Lloyd-Max 3-bit codebook (8 reconstruction levels). +// Boundaries: 0, ±0.332, ±0.776, ±1.399 (midpoints of adjacent centroids). +constant float turbo3_codebook[8] = { + -1.7481f, + -1.0498f, + -0.5012f, + -0.1624f, + 0.1624f, + 0.5012f, + 1.0498f, + 1.7481f}; + +// N(0,1) equal-probability 4-bit codebook (16 reconstruction levels). +// Computed as E[X | X in (b_i, b_{i+1})] for N(0,1) with equiprobable bins. +constant float turbo4_codebook[16] = { + -1.9672f, + -1.3305f, + -1.0130f, + -0.7811f, + -0.5714f, + -0.4053f, + -0.2382f, + -0.0784f, + 0.0784f, + 0.2382f, + 0.4053f, + 0.5714f, + 0.7811f, + 1.0130f, + 1.3305f, + 1.9672f}; + template struct Dequant { using Cfg = QuantConfig; @@ -98,6 +156,32 @@ struct Dequant { } }; +template +struct Dequant { + [[clang::always_inline]] T raw(uint8_t v) const { + return T(turbo3_codebook[v & 7u]); + } + [[clang::always_inline]] T scale(T s) const { + return s; + } + [[clang::always_inline]] T operator()(uint8_t v, T s, T) const { + return s * raw(v); + } +}; + +template +struct Dequant { + [[clang::always_inline]] T raw(uint8_t v) const { + return T(turbo4_codebook[v & 15u]); + } + [[clang::always_inline]] T scale(T s) const { + return s; + } + [[clang::always_inline]] T operator()(uint8_t v, T s, T) const { + return s * raw(v); + } +}; + // Pack metadata and unpackers for arbitrary bit-widths (wsize fixed at 32 bits) template struct PackInfo { diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index a823915123..5a6d3050ab 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -208,7 +208,9 @@ struct QuantOps { static constant constexpr int granularity = is_fast_path ? 4 : pack_factor; using fast_load_t = metal::conditional_t; static constant constexpr uint32_t fast_mask = (1u << bits) - 1; - static_assert(bits == 4 || bits == 6 || bits == 8, "unsupported quant bits"); + static_assert( + bits == 3 || bits == 4 || bits == 6 || bits == 8, + "unsupported quant bits"); static_assert( !is_fast_path || (group_size % 4) == 0, "group_size must be divisible by 4 for 4/8-bit fast path"); @@ -659,6 +661,20 @@ template QUANT_SDPA_DISPATCH(Mxfp4, 32, 4) QUANT_SDPA_DISPATCH(Nvfp4, 16, 4) QUANT_SDPA_DISPATCH(Mxfp8, 32, 8) + // TurboQuant requires group_size == head_dim (one norm per vector). + // Gate on D to avoid the (D % group_size) == 0 static_assert failing. + if constexpr (D >= 64) { + QUANT_SDPA_DISPATCH(TurboQuant3, 64, 3) + QUANT_SDPA_DISPATCH(TurboQuant4, 64, 4) + } + if constexpr (D >= 128) { + QUANT_SDPA_DISPATCH(TurboQuant3, 128, 3) + QUANT_SDPA_DISPATCH(TurboQuant4, 128, 4) + } + if constexpr (D >= 256) { + QUANT_SDPA_DISPATCH(TurboQuant3, 256, 3) + QUANT_SDPA_DISPATCH(TurboQuant4, 256, 4) + } #undef QUANT_SDPA_DISPATCH } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index fae6188ed5..b79a501342 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -626,6 +626,10 @@ int quant_mode_to_int(QuantizationMode mode) { return 2; case QuantizationMode::Nvfp4: return 3; + case QuantizationMode::TurboQuant3: + return 4; + case QuantizationMode::TurboQuant4: + return 5; default: throw std::invalid_argument( "[quant_sdpa_vector_2pass] Unsupported quantization mode."); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index bddc569b55..2b4caa7499 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -896,6 +896,8 @@ array quantized_scaled_dot_product_attention( // Parse mode and get parameters auto qmode = string_to_quantization_mode(mode, tag); bool is_affine = qmode == QuantizationMode::Affine; + bool is_turbo = qmode == QuantizationMode::TurboQuant3 || + qmode == QuantizationMode::TurboQuant4; auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); @@ -919,6 +921,19 @@ array quantized_scaled_dot_product_attention( "[quantized_scaled_dot_product_attention] Affine mode requires " "key_biases and value_biases."); } + } else if (is_turbo) { + int expected_bits = (qmode == QuantizationMode::TurboQuant3) ? 3 : 4; + if (bits != expected_bits) { + std::ostringstream msg; + msg << "[" << tag << "] Mode '" << mode << "' requires bits " + << expected_bits << "."; + throw std::invalid_argument(msg.str()); + } + if (key_biases.has_value() || value_biases.has_value()) { + throw std::invalid_argument( + "[quantized_scaled_dot_product_attention] Biases are not supported " + "for TurboQuant modes."); + } } else { // FP modes have fixed params - verify if user overrode them incorrectly auto [expected_gs, expected_bits] = @@ -972,12 +987,25 @@ array quantized_scaled_dot_product_attention( "[quantized_scaled_dot_product_attention] Keys and values must be " "uint32."); } - if (!is_affine && + if (!is_affine && !is_turbo && (key_scales.dtype() != uint8 || value_scales.dtype() != uint8)) { throw std::invalid_argument( "[quantized_scaled_dot_product_attention] Scales must be uint8 for fp " "quantization."); } + if (is_turbo) { + auto check_turbo_scale_dtype = [&](const array& s, const char* name) { + if (s.dtype() != float16 && s.dtype() != bfloat16 && + s.dtype() != float32) { + std::ostringstream msg; + msg << "[" << tag << "] TurboQuant " << name + << " scales must be float16, bfloat16, or float32."; + throw std::invalid_argument(msg.str()); + } + }; + check_turbo_scale_dtype(key_scales, "key"); + check_turbo_scale_dtype(value_scales, "value"); + } // Compute and validate dimensions auto key_head_dim = (keys.shape(-1) * 32) / bits; @@ -1015,6 +1043,22 @@ array quantized_scaled_dot_product_attention( << " must be divisible by group_size " << group_size << "."; throw std::invalid_argument(msg.str()); } + if (is_turbo) { + if (group_size != queries.shape(-1)) { + std::ostringstream msg; + msg << "[" << tag << "] TurboQuant requires group_size == head_dim," + << " got group_size=" << group_size + << " head_dim=" << queries.shape(-1) << "."; + throw std::invalid_argument(msg.str()); + } + auto head_dim = queries.shape(-1); + if (head_dim != 64 && head_dim != 128 && head_dim != 256) { + std::ostringstream msg; + msg << "[" << tag << "] TurboQuant only supports head_dim in {64, 128, 256}," + << " got " << head_dim << "."; + throw std::invalid_argument(msg.str()); + } + } // Validate scale/bias shapes auto expected_scale_dim = queries.shape(-1) / group_size; @@ -1063,10 +1107,16 @@ array quantized_scaled_dot_product_attention( has_arr_mask, has_sinks, is_affine, + is_turbo, group_size, bits, mode, s](const std::vector& inputs) { + if (is_turbo) { + throw std::runtime_error( + "[quantized_scaled_dot_product_attention] TurboQuant mode requires " + "a Metal GPU device."); + } auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); int n_repeats = n_q_heads / n_kv_heads; @@ -1181,12 +1231,17 @@ array quantized_scaled_dot_product_attention( Shape full_mask_shape{ queries.shape(0), queries.shape(1), queries.shape(2), keys.shape(-2)}; - std::vector inputs = {q, keys, key_scales}; + // For TurboQuant, scales are per-vector norms stored as floats (not uint8). + // Cast them to final_type so the Metal kernel sees the expected dtype. + auto ks = is_turbo ? astype(key_scales, final_type, stream) : key_scales; + auto vs = is_turbo ? astype(value_scales, final_type, stream) : value_scales; + + std::vector inputs = {q, keys, ks}; if (is_affine) { inputs.push_back(*key_biases); } inputs.push_back(values); - inputs.push_back(value_scales); + inputs.push_back(vs); if (is_affine) { inputs.push_back(*value_biases); } diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 46e5295ca0..51308324d5 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3425,9 +3425,13 @@ std::string quantization_mode_to_string(QuantizationMode mode) { case QuantizationMode::Mxfp8: return "mxfp8"; case QuantizationMode::Nvfp4: - default: return "nvfp4"; + case QuantizationMode::TurboQuant3: + return "turbo3"; + case QuantizationMode::TurboQuant4: + return "turbo4"; } + throw std::runtime_error("Unknown quantization mode"); } QuantizationMode string_to_quantization_mode( @@ -3441,6 +3445,10 @@ QuantizationMode string_to_quantization_mode( return QuantizationMode::Mxfp8; } else if (mode == "nvfp4") { return QuantizationMode::Nvfp4; + } else if (mode == "turbo3") { + return QuantizationMode::TurboQuant3; + } else if (mode == "turbo4") { + return QuantizationMode::TurboQuant4; } std::string msg; if (!tag.empty()) { @@ -3473,6 +3481,14 @@ std::pair quantization_params_from_mode( default_group_size = 32; default_bits = 8; break; + case QuantizationMode::TurboQuant3: + default_group_size = 64; + default_bits = 3; + break; + case QuantizationMode::TurboQuant4: + default_group_size = 64; + default_bits = 4; + break; } return { group_size_.value_or(default_group_size), bits_.value_or(default_bits)}; diff --git a/mlx/primitives.h b/mlx/primitives.h index 2b3a6e4719..d858b42807 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -152,7 +152,14 @@ class MLX_API UnaryPrimitive : public Primitive { UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; }; -enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; +enum class QuantizationMode { + Affine, + Mxfp4, + Mxfp8, + Nvfp4, + TurboQuant3, + TurboQuant4 +}; std::string quantization_mode_to_string(QuantizationMode mode); QuantizationMode string_to_quantization_mode( diff --git a/python/src/fast.cpp b/python/src/fast.cpp index c982f6b3e6..01ad481832 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -388,7 +388,7 @@ void init_fast(nb::module_& parent_module) { sinks (array, optional): An optional array of attention sinks with shape ``[N_q]``. group_size (int, optional): The group size used in the KV quantization. Defaults follow the quantization ``mode``. bits (int, optional): The bits used in the KV quantization. Defaults follow the quantization ``mode``. - mode (str, optional): The quantization mode: ``"mxfp4"``, ``"mxfp8"``, ``"nvfp4"``, or ``"affine"``. + mode (str, optional): The quantization mode: ``"mxfp4"``, ``"mxfp8"``, ``"nvfp4"``, ``"affine"``, ``"turbo3"``, or ``"turbo4"``. TurboQuant modes use a WHT rotation + Lloyd-Max codebook and require ``group_size == head_dim``; ``k_biases`` and ``v_biases`` must be ``None``. causal (bool, optional): Whether to apply lower-right aligned causal masking. Cannot be used together with ``mask``. Returns: diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 8c46e7e9e9..bfe0c06e25 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -1608,6 +1608,173 @@ def test_quantize_strided(self): expected = mx.dequantize(w_q, mx.contiguous(scales), mode=mode) self.assertTrue(mx.allclose(w_hat, expected)) + def test_quantized_sdpa_turbo(self): + """TurboQuant 3-bit and 4-bit SDPA: pack indices, run kernel, compare + against standard SDPA using the dequantized K and V.""" + if mx.default_device() == mx.cpu: + self.skipTest("TurboQuant SDPA requires a Metal GPU.") + + import numpy as np + + # Codebooks must match the constants in quantized_utils.h + CB3 = np.array( + [-1.7481, -1.0498, -0.5012, -0.1624, 0.1624, 0.5012, 1.0498, 1.7481], + dtype=np.float32, + ) + CB4 = np.array( + [ + -1.9672, + -1.3305, + -1.0130, + -0.7811, + -0.5714, + -0.4053, + -0.2382, + -0.0784, + 0.0784, + 0.2382, + 0.4053, + 0.5714, + 0.7811, + 1.0130, + 1.3305, + 1.9672, + ], + dtype=np.float32, + ) + + def pack_3bit(indices_np): + """Pack 3-bit indices to uint32 using PackReader<3> byte layout. + + 8 indices → 3 bytes (24 bits), little-endian, stored as uint32 words. + """ + *batch, D = indices_np.shape + assert D % 8 == 0 + n_packs = D // 8 + flat = indices_np.reshape(-1, D).astype(np.int64) + n_tok = flat.shape[0] + buf = np.zeros((n_tok, n_packs * 3), dtype=np.uint8) + for t in range(n_tok): + for p in range(n_packs): + ix = flat[t, p * 8 : p * 8 + 8] + v = ( + int(ix[0]) + | (int(ix[1]) << 3) + | (int(ix[2]) << 6) + | (int(ix[3]) << 9) + | (int(ix[4]) << 12) + | (int(ix[5]) << 15) + | (int(ix[6]) << 18) + | (int(ix[7]) << 21) + ) + buf[t, p * 3] = v & 0xFF + buf[t, p * 3 + 1] = (v >> 8) & 0xFF + buf[t, p * 3 + 2] = (v >> 16) & 0xFF + n_u32 = D * 3 // 32 + u32 = np.frombuffer(buf.tobytes(), dtype=" Date: Sat, 25 Apr 2026 21:02:48 +0200 Subject: [PATCH 2/2] tests: add test_quantized_sdpa_turbo for turbo3/turbo4 12 sub-cases: D in {64,128,256} x bits in {3,4} x B in {1,2}, each targeting a distinct Metal template instantiation. Explicit bfloat16, error path coverage. --- python/tests/test_quantized.py | 182 ++++++++++++++++++--------------- 1 file changed, 101 insertions(+), 81 deletions(-) diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index bfe0c06e25..3fa0579336 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -1610,42 +1610,32 @@ def test_quantize_strided(self): def test_quantized_sdpa_turbo(self): """TurboQuant 3-bit and 4-bit SDPA: pack indices, run kernel, compare - against standard SDPA using the dequantized K and V.""" + against standard SDPA using the dequantized K and V. + + Tests all supported head_dim values (64, 128, 256) and batch sizes + to exercise each Metal template instantiation independently. + """ if mx.default_device() == mx.cpu: self.skipTest("TurboQuant SDPA requires a Metal GPU.") - + import numpy as np - - # Codebooks must match the constants in quantized_utils.h + CB3 = np.array( [-1.7481, -1.0498, -0.5012, -0.1624, 0.1624, 0.5012, 1.0498, 1.7481], dtype=np.float32, ) CB4 = np.array( [ - -1.9672, - -1.3305, - -1.0130, - -0.7811, - -0.5714, - -0.4053, - -0.2382, - -0.0784, - 0.0784, - 0.2382, - 0.4053, - 0.5714, - 0.7811, - 1.0130, - 1.3305, - 1.9672, + -1.9672, -1.3305, -1.0130, -0.7811, + -0.5714, -0.4053, -0.2382, -0.0784, + 0.0784, 0.2382, 0.4053, 0.5714, + 0.7811, 1.0130, 1.3305, 1.9672, ], dtype=np.float32, ) - + def pack_3bit(indices_np): """Pack 3-bit indices to uint32 using PackReader<3> byte layout. - 8 indices → 3 bytes (24 bits), little-endian, stored as uint32 words. """ *batch, D = indices_np.shape @@ -1658,14 +1648,14 @@ def pack_3bit(indices_np): for p in range(n_packs): ix = flat[t, p * 8 : p * 8 + 8] v = ( - int(ix[0]) - | (int(ix[1]) << 3) - | (int(ix[2]) << 6) - | (int(ix[3]) << 9) - | (int(ix[4]) << 12) - | (int(ix[5]) << 15) - | (int(ix[6]) << 18) - | (int(ix[7]) << 21) + int(ix[0]) + | (int(ix[1]) << 3) + | (int(ix[2]) << 6) + | (int(ix[3]) << 9) + | (int(ix[4]) << 12) + | (int(ix[5]) << 15) + | (int(ix[6]) << 18) + | (int(ix[7]) << 21) ) buf[t, p * 3] = v & 0xFF buf[t, p * 3 + 1] = (v >> 8) & 0xFF @@ -1673,7 +1663,7 @@ def pack_3bit(indices_np): n_u32 = D * 3 // 32 u32 = np.frombuffer(buf.tobytes(), dtype="