Fix kernel dispatch naming in quant_sdpa_vector_2pass #1
Open
andershansson wants to merge 20 commits into CC-Yeh:quantized_sdpa from
Conversation
The dispatch in quant_sdpa_vector_2pass builds the kernel name using
v.shape(-1), but for quantized inputs this returns the packed uint32
dimension (16 for 4-bit packing of head_dim=128, 24 for 6-bit, 32 for
8-bit) rather than the logical head_dim that the kernels are
instantiated with.
The kernels in scaled_dot_product_attention.metal are instantiated with
logical dims (64_64, 128_128, 256_256), so dispatch fails with:
[metal::Device] Unable to load function
quant_sdpa_vector_2pass_1_float_128_16
Function quant_sdpa_vector_2pass_1_float_128_16 was not found
in the library
The function already has 'int bits' as a parameter, so the logical dim
is recoverable: logical_dim = packed_dim * 32 / bits. This formula works
cleanly for all currently-supported affine bit widths (4, 6, 8):
bits=4: packed=16 -> logical=16*32/4=128
bits=6: packed=24 -> logical=24*32/6=128
bits=8: packed=32 -> logical=32*32/8=128
Verified by running the existing test_quantized_sdpa_affine test cases,
which were failing on all three bit widths before the fix and now pass.
All 10 quantized SDPA test methods (54 subtests) in test_quantized.py
pass after this one-line fix.
Reproduction and verification performed on M4 Max, macOS 26,
Python 3.12, built from source against commit 6291e80 with the
downloaded Metal Toolchain 17E188.
Summary
Fixes the kernel-dispatch naming bug @Thump604 reported in #3026 ([metal::Device] Unable to load function quant_sdpa_vector_2pass_1_float_128_16). One-line change in mlx/backend/metal/scaled_dot_product_attention.cpp, plus the already-existing test_quantized_sdpa_affine tests all start passing.
The bug
In quant_sdpa_vector_2pass (around line 660) the dispatch builds the kernel name from v.shape(-1):

    kname += std::to_string(v.shape(-1));

For a quantized v, v.shape(-1) returns the packed uint32 dimension, not the logical head_dim. The kernels in scaled_dot_product_attention.metal are instantiated with the logical dims (64_64, 128_128, 256_256), so the dispatch looks up names that don't exist:

    quant_sdpa_vector_2pass_1_float_128_16 ❌
    quant_sdpa_vector_2pass_1_float_128_24 ❌
    quant_sdpa_vector_2pass_1_float_128_32 ❌

(The second pass at line 776 is fine — it uses out.shape(-1), which is the unquantized output shape.)
The fix
The function already has int bits as a parameter, so the logical dim is recoverable from the packed dim as packed * 32 / bits. The formula works cleanly for all currently-supported affine bit widths:

    16 * 32 / 4 = 128 ✓
    24 * 32 / 6 = 128 ✓
    32 * 32 / 8 = 128 ✓

Verification
Built from source on M4 Max, macOS 26, Python 3.12, with the downloaded Metal Toolchain 17E188, against commit 6291e80.

Before the fix, test_quantized_sdpa_affine in python/tests/test_quantized.py failed on all three bit widths; after the fix it passes.

All 10 quantized SDPA test methods (test_quantized_sdpa, test_quantized_sdpa_affine, test_quantized_sdpa_masked, test_quantized_sdpa_affine_masked, test_quantized_sdpa_sinks, test_quantized_sdpa_masked_with_sinks, test_quantized_sdpa_affine_masked_with_sinks, test_quantized_sdpa_causal, test_quantized_sdpa_affine_causal, test_quantized_sdpa_causal_with_array_mask_error) now pass — 54 subtests total.

Also independently verified numerical correctness against a dequantize-then-fp16-SDPA reference path on realistic decode shapes (B=1, n_h=32, n_kv=8, head_dim ∈ {64, 128, 256}, bits ∈ {4, 8}, context lengths 2048 and 8192) — cosine similarity 1.000000 across every configuration.
Context
We're maintaining turboquant-mlx, a TurboQuant-style KV cache compression library built on top of MLX, and we're tracking ml-explore#3026 closely. Happy to help move this PR forward with additional M4 Max benchmarks or cross-architecture correctness tests (Qwen2.5 with biased k_proj, Phi-3.5 with head_dim=96, gpt-oss alternating sliding/full attention) once the main PR is unblocked.
Test plan
python -m pytest python/tests/test_quantized.py -k "sdpa"— all 10 tests, 54 subtests pass on M4 Max