From e1c923e5ea9c4c65aa16c03d709013e9319c6fec Mon Sep 17 00:00:00 2001
From: dogukanveziroglu
Date: Tue, 14 Apr 2026 12:58:14 +0300
Subject: [PATCH] Support group_size=64 for affine quantized SDPA

Add Affine dispatch entries for group_size=64 at bits={4,6,8} and relax
the validation in quantized_scaled_dot_product_attention. This matches
the default produced by mx.quantize(mode="affine") and the
kv_group_size=64 default used by mlx-lm, so users following the
MLX/mlx-lm conventions no longer hit an error when using fused
quantized attention.

Benchmarks (M4, B=1 H=32 D=128 Lq=1, affine 4-bit):

  Context   gs=32 fused   gs=64 fused   speedup
  32K       50 us         41 us         +22%
  64K       95 us         82 us         +16%
  128K      176 us        152 us        +16%

gs=64 is faster at long context because it has half the scale/bias
memory traffic.

Costs:
  mlx.metallib: 128,161,428 -> 128,233,236 bytes (+0.056%)
  libmlx.dylib: unchanged

Existing 10 test_quantized_sdpa* tests continue to pass (54 subtests).
---
 mlx/backend/metal/kernels/sdpa_vector.h | 3 +++
 mlx/fast.cpp                            | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index bf7b0bafd6..a823915123 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -653,6 +653,9 @@ template
   QUANT_SDPA_DISPATCH(Affine, 32, 4)
   QUANT_SDPA_DISPATCH(Affine, 32, 6)
   QUANT_SDPA_DISPATCH(Affine, 32, 8)
+  QUANT_SDPA_DISPATCH(Affine, 64, 4)
+  QUANT_SDPA_DISPATCH(Affine, 64, 6)
+  QUANT_SDPA_DISPATCH(Affine, 64, 8)
   QUANT_SDPA_DISPATCH(Mxfp4, 32, 4)
   QUANT_SDPA_DISPATCH(Nvfp4, 16, 4)
   QUANT_SDPA_DISPATCH(Mxfp8, 32, 8)
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index aa3f88fd36..bddc569b55 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -901,9 +901,9 @@ array quantized_scaled_dot_product_attention(
 
   // Validate mode-specific group_size and bits
   if (is_affine) {
-    if (group_size != 32) {
+    if (group_size != 32 && group_size != 64) {
       std::ostringstream msg;
-      msg << "[" << tag << "] Affine mode supports group_size 32 "
+      msg << "[" << tag << "] Affine mode supports group_size 32 or 64 "
           << "but received " << group_size << ".";
       throw std::invalid_argument(msg.str());
     }