Add fused quantized MLA SDPA kernel for Mistral Small 4 #3373

Closed

ProducerGuy wants to merge 5 commits into ml-explore:main from ProducerGuy:phase3c-v3-kernel-opt

Add fused quantized MLA SDPA kernel for Mistral Small 4#3373
ProducerGuy wants to merge 5 commits into
ml-explore:mainfrom
ProducerGuy:phase3c-v3-kernel-opt

Conversation

@ProducerGuy ProducerGuy commented Apr 4, 2026

Title: Add fused quantized MLA SDPA kernel for Mistral Small 4

Summary

Adds a fused Metal kernel for Multi-head Latent Attention (MLA) decode,
replacing 5+ separate kernel dispatches with one. Designed for Mistral
Small 4's absorbed MLA architecture. Companion to ml-explore/mlx-lm#1037.

What this kernel does

Single dispatch that fuses:

  1. INT4 affine dequant of quantized latent cache
  2. Split nope/rope attention scoring via simd_sum
  3. Online softmax (running max + sum_exp)
  4. Value accumulation (reuses dequanted latent — no re-read)
  5. Cross-simdgroup reduction
  6. New token quantization + direct cache append

Based on sdpa_vector.h patterns. Decode-only (L==1).
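The online softmax in step 3 can be sketched as a single-pass reference in NumPy. This is a simplified illustration of the running-max/running-sum trick the kernel uses, not the Metal code; the function name and scalar loop are illustrative only:

```python
import numpy as np

def online_softmax_weighted_sum(scores, values):
    """One pass over (score, value) pairs, keeping only a running max,
    a running sum of exponentials, and a running weighted output.
    Avoids the second pass a naive softmax would need."""
    max_score = -np.inf
    sum_exp = 0.0
    out = np.zeros_like(values[0], dtype=np.float64)
    for s, v in zip(scores, values):
        new_max = max(max_score, s)
        correction = np.exp(max_score - new_max)  # rescale previous state
        sum_exp = sum_exp * correction + np.exp(s - new_max)
        out = out * correction + np.exp(s - new_max) * v
        max_score = new_max
    return out / sum_exp
```

Because each step rescales the accumulated state by `exp(old_max - new_max)`, the result matches a conventional two-pass softmax-weighted sum.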

Key design decisions

  • Scale applied at query load time (per sdpa_vector.h pattern)
  • New token quantized to threadgroup memory, NOT reread from device cache
    after writing (avoids device memory barrier)
  • Cache append via copy_shared_buffer aliasing — zero-copy, eliminates
    SliceUpdate full-cache copy overhead
  • Threadgroup memory padding (+1 stride) avoids bank conflicts in
    cross-simdgroup transpose reduction
  • All intermediate values stay in registers/threadgroup memory — no
    device memory round-trips between stages

Files

File                                                 Purpose
mlx/backend/metal/kernels/mla_fused_sdpa.metal       Metal kernel (v1 + v2)
mlx/backend/metal/mla_fused_sdpa.cpp                 GPU dispatch
mlx/fast_primitives.h                                Primitive classes
mlx/fast.h                                           API declarations
mlx/fast.cpp                                         Op functions with fallbacks
python/src/fast.cpp                                  Python bindings
mlx/backend/metal/kernels/mla_quantize_store.metal   INT4 quantize kernel
mlx/backend/metal/mla_quantize_store.cpp             Quantize dispatch

API

# Fused SDPA + cache update
mx.fast.mla_fused_sdpa_v2(
    q_nope, q_pe,                    # [B, H, 256], [B, H, 64]
    cache_packed, cache_scales,      # [B, S_alloc, 32], [B, S_alloc, 4]
    cache_biases, cache_kpe,         # [B, S_alloc, 4], [B, S_alloc, 64]
    new_latent, new_kpe,             # [B, 1, 256], [B, 1, 64]
    scale, seq_offset)
# Returns: (sdpa_out, updated_packed, updated_scales, updated_biases, updated_kpe)
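The math the fused op computes (after dequantization) can be written as a compact NumPy reference for one batch element. This is a semantic sketch, not the kernel: the latent cache is assumed already dequantized to float, and in absorbed MLA the latent serves as both the nope-part key and the value, per the "value accumulation reuses dequanted latent" note above. Function and argument names are illustrative:

```python
import numpy as np

def mla_decode_reference(q_nope, q_pe, latent, kpe, scale):
    """Reference decode-step attention for absorbed MLA.
    Shapes follow the API: q_nope [H, 256], q_pe [H, 64],
    latent [S, 256], kpe [S, 64]. Returns [H, 256] in latent space."""
    # Split nope/rope scoring: latent scores both halves of the query.
    scores = scale * (q_nope @ latent.T + q_pe @ kpe.T)   # [H, S]
    scores -= scores.max(axis=1, keepdims=True)           # numerical stability
    w = np.exp(scores)
    w /= w.sum(axis=1, keepdims=True)
    return w @ latent                                      # values == latent
```

The up-projection out of latent space happens outside the kernel, which is what makes reusing the dequanted latent as the value matrix possible.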

Hardware

Tested on Apple M5 Max (128GB, 40-core GPU). Uses simd_sum,
simd_shuffle_xor, simd_max — requires Apple Silicon with
simdgroup support (M1+).

Test plan

  • Staged correctness tests (S=1,4,8,16,32,64)
  • Max relative error ≤ 0.01% vs reference
  • Aliasing hazard tests (sequential calls, interleaved layers)
  • Full model integration verified
  • Perplexity: 4.606 ± 0.064 (no measurable degradation from the INT4 cache)

Producer Guy added 5 commits April 1, 2026 20:49

Based on sdpa_vector.h: online softmax, BN=32 simdgroups, in-kernel
INT4 dequant, split nope/rope scoring, value accumulation reusing
dequanted latent. Single dispatch replaces 5+ separate ops.

Kernel time: 0.180 ms (matches fp16 SDPA 0.177 ms).
Correctness: all staged tests pass, max rel error 0.0096%.

Files:
  mla_fused_sdpa.metal — Metal kernel
  mla_fused_sdpa.cpp — C++ eval_gpu dispatch
  fast_primitives.h — MLAFusedSDPA class
  fast.h / fast.cpp — op function + fallback
  python/src/fast.cpp — Python binding

Added scale as constant float& buffer(10). Applied at query load time:
q_n[i] = static_cast<U>(scale) * static_cast<U>(q_nope_ptr[...])

Same pattern as sdpa_vector.h. Eliminates separate Python scale dispatch.

Single-dispatch INT4 affine quantization replacing mx.quantize
for MLA latent cache. Per-group min/max via simd_shuffle_xor,
quantize + pack in registers, one simdgroup per 256-dim vector.

No measurable tok/s improvement — dispatch overhead is absorbed
by MLX lazy eval pipeline. Kernel works correctly.
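The per-group affine quantization described above can be sketched in NumPy under the layout implied by the API shapes (256-dim latent, 4 scale/bias pairs, so groups of 64; 32 uint32 words of 8 nibbles each, low nibble first). The exact nibble order is defined by the kernel; this sketch only shows the math, and does the min/max reduction with plain array ops rather than simd_shuffle_xor:

```python
import numpy as np

def quantize_latent_row(x):
    """INT4 affine quantization of a 256-dim vector: 4 groups of 64,
    per-group min/max -> (scale, bias), codes packed 8 per uint32.
    Assumes each group has hi > lo (non-constant input)."""
    x = x.reshape(4, 64).astype(np.float32)
    lo = x.min(axis=1, keepdims=True)
    hi = x.max(axis=1, keepdims=True)
    scales = ((hi - lo) / 15).squeeze(1)
    biases = lo.squeeze(1)
    q = np.clip(np.round((x - lo) / (hi - lo) * 15), 0, 15).astype(np.uint32)
    q = q.reshape(32, 8)                       # 8 nibbles per word
    shifts = np.arange(8, dtype=np.uint32) * 4  # low nibble first
    packed = np.bitwise_or.reduce(q << shifts, axis=1)
    return packed, scales, biases
```

Dequantization inverts this as `scale * code + bias` per group, bounding the per-element error by half a quantization step.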

mla_fused_sdpa_v2 kernel: quantizes new token to threadgroup memory,
writes to cache for persistence, SDPA reads new token from threadgroup
memory (not device cache). copy_shared_buffer aliasing eliminates
SliceUpdate full-cache copies. MLAFusedSDPAWithCacheUpdate primitive
returns 5 outputs (SDPA result + 4 aliased cache arrays).

Added #pragma clang loop unroll(full) on all fixed-iteration loops.
Added +1 padding to tg_out stride (33 vs 32) for bank-conflict-free
cross-simdgroup transpose reduction.

Attribution testing inconclusive — improvements within run-to-run noise
at short context (S=22-30). Research predicts this: kernel is latency-bound
at short S, micro-optimizations improve the variable part which is near-zero.
Kept because no regression detected. Not a strategic driver.
@zcbenz
Collaborator

zcbenz commented Apr 8, 2026

Thanks for the pull request! Unfortunately, due to maintenance burden we do not plan to add built-in GPU kernels for every possible op; for cases like this we suggest using custom extensions and custom kernels instead.

@zcbenz zcbenz closed this Apr 8, 2026