Add fused quantized MLA SDPA kernel for Mistral Small 4 #3373
Closed
ProducerGuy wants to merge 5 commits into
Conversation
added 5 commits
April 1, 2026 20:49
Based on sdpa_vector.h: online softmax, BN=32 simdgroups, in-kernel INT4 dequant, split nope/rope scoring, value accumulation reusing the dequantized latent. A single dispatch replaces 5+ separate ops. Kernel time: 0.180 ms (matches fp16 SDPA at 0.177 ms). Correctness: all staged tests pass, max relative error 0.0096%.
Files:
- mla_fused_sdpa.metal — Metal kernel
- mla_fused_sdpa.cpp — C++ eval_gpu dispatch
- fast_primitives.h — MLAFusedSDPA class
- fast.h / fast.cpp — op function + fallback
- python/src/fast.cpp — Python binding
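For context, a minimal sketch of the online-softmax loop structure this follows (sdpa_vector.h style); compute_score, dequant_value, lane_id, n_keys, and the *_PER_THREAD constants are illustrative placeholders, not code from this PR:

```metal
// Illustrative fragment, not the actual kernel. Each thread streams over
// the key positions it owns, keeping a running max / sum / output so the
// softmax never needs a second pass over the scores.
constexpr int ELEM_PER_THREAD = 4; // placeholder

float max_score = -INFINITY;
float sum_exp = 0.0f;
float o[ELEM_PER_THREAD] = {0.0f};

for (int k = lane_id; k < n_keys; k += THREADS_PER_SIMDGROUP) {
  // In the real kernel the score is dot(q_nope, k_nope) + dot(q_rope, k_rope),
  // with k_nope dequantized from INT4 in registers.
  float score = compute_score(k);
  float new_max = max(max_score, score);
  float factor = fast::exp(max_score - new_max); // rescale earlier partials
  float exp_s = fast::exp(score - new_max);
  sum_exp = sum_exp * factor + exp_s;
  for (int i = 0; i < ELEM_PER_THREAD; ++i) {
    // Value accumulation reuses the latent already dequantized for scoring.
    o[i] = o[i] * factor + exp_s * dequant_value(k, i);
  }
  max_score = new_max;
}
// o[i] / sum_exp is formed after the cross-simdgroup (simd_sum) reduction.
```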
Added scale as a constant float& at buffer(10). Applied at query load time: q_n[i] = static_cast<U>(scale) * static_cast<U>(q_nope_ptr[...]). Same pattern as sdpa_vector.h. Eliminates the separate Python scale dispatch.
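Roughly what the added parameter looks like in the kernel signature; this is an illustrative sketch, and the surrounding buffer layout and QK_PER_THREAD are placeholders, not the PR's actual parameter list:

```metal
template <typename T, typename U>
[[kernel]] void mla_fused_sdpa(
    const device T* q_nope [[buffer(0)]],
    /* ... quantized latent cache, rope buffers, output ... */
    const constant float& scale [[buffer(10)]],
    uint lane [[thread_index_in_simdgroup]]) {
  constexpr int QK_PER_THREAD = 8; // placeholder
  thread U q_n[QK_PER_THREAD];
  for (int i = 0; i < QK_PER_THREAD; ++i) {
    // Pre-scaling the query once is cheaper than scaling every score,
    // and removes the separate Python-side scale dispatch.
    q_n[i] = static_cast<U>(scale) * static_cast<U>(q_nope[lane * QK_PER_THREAD + i]);
  }
}
```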
Single-dispatch INT4 affine quantization replacing mx.quantize for MLA latent cache. Per-group min/max via simd_shuffle_xor, quantize + pack in registers, one simdgroup per 256-dim vector. No measurable tok/s improvement — dispatch overhead is absorbed by MLX lazy eval pipeline. Kernel works correctly.
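A minimal sketch of the min/max butterfly and nibble packing described above, assuming one element per lane and group_size = 32 for clarity (the real kernel covers a 256-dim vector per simdgroup, so each lane handles several elements); latent, out_words, out_scales, out_biases, and lane are placeholders:

```metal
float v = latent[lane];
float lo = v, hi = v;
// Butterfly reduction: after 5 exchanges every lane holds the group min/max.
for (ushort off = 16; off > 0; off >>= 1) {
  lo = min(lo, simd_shuffle_xor(lo, off));
  hi = max(hi, simd_shuffle_xor(hi, off));
}
float scale = (hi - lo) / 15.0f; // the real kernel guards against hi == lo
float bias = lo;
uint q = uint(round((v - bias) / scale)) & 0xF;

// OR-combine nibbles so every 8th lane holds one packed 32-bit word.
uint word = q << ((lane % 8) * 4);
for (ushort off = 4; off > 0; off >>= 1) {
  word |= simd_shuffle_xor(word, off);
}
if (lane % 8 == 0) {
  out_words[lane / 8] = word;
}
if (lane == 0) {
  out_scales[0] = scale;
  out_biases[0] = bias;
}
```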
mla_fused_sdpa_v2 kernel: quantizes the new token into threadgroup memory, writes it to the cache for persistence, and SDPA reads the new token from threadgroup memory (not the device cache). copy_shared_buffer aliasing eliminates SliceUpdate full-cache copies. The MLAFusedSDPAWithCacheUpdate primitive returns 5 outputs (SDPA result + 4 aliased cache arrays).
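The flow the v2 kernel follows, as an illustrative fragment; quantize_token and write_cache_entry are hypothetical helpers used here only to show the ordering, not functions added by this PR:

```metal
// The new token's quantized form lives in threadgroup memory, so the SDPA
// stage never reads it back from the device cache and no device-scope
// barrier is needed after the cache write.
threadgroup uint tg_k_words[WORDS_PER_TOKEN]; // placeholder size
threadgroup float tg_k_scale, tg_k_bias;

// 1) Quantize the incoming token's latent into threadgroup memory.
quantize_token(new_latent, tg_k_words, tg_k_scale, tg_k_bias);

// 2) Persist it to the device cache for later decode steps (write-only here).
write_cache_entry(cache_words, cache_scales, cache_biases, pos,
                  tg_k_words, tg_k_scale, tg_k_bias);
threadgroup_barrier(mem_flags::mem_threadgroup);

// 3) SDPA: past positions come from the device cache, the new position from
//    tg_k_words / tg_k_scale / tg_k_bias.
```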
Added #pragma clang loop unroll(full) to all fixed-iteration loops. Added +1 padding to the tg_out stride (33 vs 32) for a bank-conflict-free cross-simdgroup transpose reduction. Attribution testing was inconclusive — improvements are within run-to-run noise at short context (S=22-30). Research predicts this: the kernel is latency-bound at short S, and micro-optimizations improve the variable part, which is near zero. Kept because no regression was detected. Not a strategic driver.
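What the two micro-optimizations look like in kernel code, sketched with placeholder dimensions (BN, ELEM_PER_THREAD, simd_gid, and o are illustrative):

```metal
// With 32 columns, the +1 padding makes each row's stride 33 instead of 32,
// so reading a column during the cross-simdgroup transpose spreads accesses
// across banks instead of hitting the same bank every row.
threadgroup float tg_out[BN][ELEM_PER_THREAD + 1];

#pragma clang loop unroll(full) // fixed trip count, fully unrolled
for (int i = 0; i < ELEM_PER_THREAD; ++i) {
  tg_out[simd_gid][i] = o[i];
}
```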
Collaborator
Thanks for the pull request! Unfortunately, due to maintenance burden we do not plan to add built-in GPU kernels for every possible op; for cases like this we suggest using custom extensions and custom kernels instead.
Title: Add fused quantized MLA SDPA kernel for Mistral Small 4
Summary
Adds a fused Metal kernel for Multi-head Latent Attention (MLA) decode,
replacing 5+ separate kernel dispatches with one. Designed for Mistral
Small 4's absorbed MLA architecture. Companion to ml-explore/mlx-lm#1037.
What this kernel does
Single dispatch that fuses:
- in-kernel INT4 dequant of the latent cache
- split nope/rope scoring
- online softmax
- value accumulation reusing the dequantized latent
- cross-simdgroup reduction via simd_sum

Based on sdpa_vector.h patterns. Decode-only (L==1).
Key design decisions
- Online softmax (sdpa_vector.h pattern)
- New token quantized to threadgroup memory and read from there after writing (avoids a device memory barrier)
- copy_shared_buffer aliasing for cache updates — zero-copy, eliminates SliceUpdate full-cache copy overhead
- Bank-conflict-free cross-simdgroup transpose reduction
- No device memory round-trips between stages
Files
- mlx/backend/metal/kernels/mla_fused_sdpa.metal
- mlx/backend/metal/mla_fused_sdpa.cpp
- mlx/fast_primitives.h
- mlx/fast.h
- mlx/fast.cpp
- python/src/fast.cpp
- mlx/backend/metal/kernels/mla_quantize_store.metal
- mlx/backend/metal/mla_quantize_store.cpp

API
Hardware
Tested on Apple M5 Max (128GB, 40-core GPU). Uses simd_sum, simd_shuffle_xor, simd_max — requires Apple Silicon with simdgroup support (M1+).
Test plan