feat(fast): add sliding-window SDPA kernel path #3552
feat(fast::sdpa): sliding-window attention via `has_window` function constant + `kb_start` truncation
Summary
Adds an optional `window_size` parameter to `mlx::fast::scaled_dot_product_attention`, enabling sliding-window attention (each Q position attends only to the last `window_size` K positions, combined with causal masking). The implementation mirrors the existing `do_causal` upper-bound (`kb_lim`) pattern with a symmetric lower bound (`kb_start`) in the steel attention kernel, skipping K-blocks below the window's lower edge wholesale.
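A minimal, self-contained model of that bound computation (the tile sizes `BQ`/`BK`, the helper name, and the example numbers are illustrative assumptions, not the PR's exact code):

```cpp
#include <algorithm>
#include <cstdio>

constexpr int BQ = 32, BK = 32; // assumed Q/K tile sizes

// Per-Q-tile K-block range: kb_lim is the existing causal upper bound,
// kb_start is the symmetric lower bound this PR adds.
void k_block_range(int q0 /*first query row of the tile*/, int W,
                   int& kb_start, int& kb_lim) {
  kb_lim = (q0 + BQ + BK - 1) / BK;        // ceil((q0 + BQ) / BK), causal edge
  int lowest_k = std::max(0, q0 - W + 1);  // oldest key visible to row q0
  kb_start = lowest_k / BK;                // whole K-blocks below are skipped
}

int main() {
  int kb_start, kb_lim;
  // Last Q-tile of an 8K prefill with W = 1024: only blocks [223, 256)
  // of 256 total are touched; everything below kb_start is never loaded.
  k_block_range(8192 - BQ, 1024, kb_start, kb_lim);
  std::printf("kb_start=%d kb_lim=%d\n", kb_start, kb_lim);
}
```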
Motivation
Sliding-window attention is used by Gemma 2, Gemma 3, Gemma 4, Mistral 7B, and several long-context Qwen variants. Currently mlx callers must construct an explicit `[L_q, L_kv]` bool mask array to express the window pattern, which:
- forces the fallback path for `head_dim ∉ {64, 80, 128}`, since the steel kernel's `sdpa_full_supported_head_dim` doesn't include 256/512.
- materializes a full `[B, H, L, L]` scores tensor in the fallback path (1 GB bf16 at L=8192, B=1, H=8: 8192² × 8 × 2 bytes = 2³⁰ bytes) → significant DRAM pressure.
With an explicit `window_size`, the steel kernel can:
- compute the valid K-block range `[kb_start, kb_lim)` per Q-tile and skip the rest entirely.
- apply per-element masking only to the `ceil(BQ/BK)` boundary blocks past `kb_start` (symmetric with the existing causal upper-edge mask).
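The element-level visibility rule those boundary blocks enforce reduces to a two-term predicate (a sketch of the masking rule, which the kernel fuses with the existing causal test; not the PR's code):

```cpp
// True iff key position k is visible to query position q under a
// causal sliding window of width W.
inline bool visible(int q, int k, int W) {
  return k <= q        // causal: no future keys
      && q - k < W;    // recency: within the last W key positions
}
```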
Performance (Gemma 4 26B-A4B, M3 Max, macOS 26.3.1)
Methodology: 3 timed trials per cell, alternating ON/OFF within a context to control thermal drift; 1 untimed warmup per cell; 30 s cooldown between trials, 60 s between configs, 90 s between contexts. STEPS=32 decode, WARMUP=4 (bench-internal). Window size W=1024.
Prefill (single-pass, sliding-window layers exercised):
Decode (32 steps each, post-prefill autoregressive):
The prefill speedup tracks the theoretical K-block skip ratio (L − W) / L — at 8K, (8192 − 1024) / 8192 = 87.5% of K-blocks lie below the window's lower edge and are skipped wholesale, yielding a 34.8% end-to-end prefill speedup once the unchanged full-attention layers, MoE experts, and dense layers are amortized in. The 8K measurement is essentially noise-free (baseline σ = 0.6 tok/s, windowed σ = 1.1 tok/s over 3 trials), making the +34.8% margin statistically unambiguous.
Decode is neutral at all contexts. The 2K/4K decode standard deviation is high because total decode time at small contexts is short and timing jitter dominates each trial; the 8K decode measurement (σ ≈ 1.3 tok/s on both arms) is the canonical reading for decode neutrality.
Raw measurements: see "Reproducibility" section below.
Implementation
Kernel (`steel_attention.h` / `steel_attention_nax.h`): the window check uses K-relative coordinates (`qL_off = kL - qL` equals `kv_offset - cache_first_held_pos` for both rotated and non-rotated caches), so chunked prefill with a rotating sliding-window cache works automatically.
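A sketch of why the K-relative form is layout-independent: with a rotating cache a key's physical slot is not its absolute position, but reconstructing absolute positions makes the window test identical for both layouts. Variable names follow the description above; the exact index arithmetic is an assumption:

```cpp
// Illustrative only: the window test expressed over absolute positions,
// so rotated and non-rotated cache layouts take the same path.
inline bool in_window(int q_row, int k_slot, int kv_offset,
                      int cache_first_held_pos, int window_size) {
  int qL = kv_offset + q_row;                // absolute query position
  int kL = cache_first_held_pos + k_slot;    // absolute key position
  return kL <= qL && qL - kL < window_size;  // same predicate as above
}
```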
Public API
- `use_fallback`: `window_size > 0` allows the BD=256 path even when the `LUMEN_GEMMA4_PREFILL_FAST_BD256` env gate is not set — the window-aware kernel is faster than the fallback on non-NAX hardware (the 87.5% K-block skip dominates the lack of NAX tensor-unit fragments).
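A self-contained model of that gating decision (the helper below and its exact interaction with `use_fallback` are assumptions; only the env-gate name comes from this PR):

```cpp
#include <cstdlib>

// Whether the BD=256 steel path may be taken (sketch, not the PR's code).
bool bd256_allowed(int window_size) {
  // Previously the BD=256 path required an explicit opt-in env gate...
  bool env_gate = std::getenv("LUMEN_GEMMA4_PREFILL_FAST_BD256") != nullptr;
  // ...now a requested window also qualifies, since the K-block skip
  // outweighs the missing NAX tensor-unit fragments on non-NAX hardware.
  return env_gate || window_size > 0;
}
```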
Backward compatibility
- `window_size = 0` preserves all existing behavior: function constant `has_window = false` → the kernel takes the original code path with `kb_start = 0`.
- `mx.fast.scaled_dot_product_attention(...)` gets a new optional `window_size` kwarg. No breakage to existing callers.
- `mlx_fast_scaled_dot_product_attention` passes `window_size=0` internally; no signature change required for C consumers.
Test coverage
The implementation has been validated end-to-end against Gemma 4 26B-A4B on M3 Max (see the performance and reproducibility sections).
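One check worth asserting in the unit tests is windowed-kernel vs. explicit-mask parity. A hypothetical sketch of the reference mask using mlx's C++ ops (the helper name and its use in a test harness are assumptions):

```cpp
#include "mlx/mlx.h"
using namespace mlx::core;

// Reference [L, L] boolean mask equivalent to causal + sliding window W;
// the windowed kernel's output should match SDPA run under this mask.
array window_mask(int L, int W) {
  auto q = reshape(arange(L), {L, 1});           // query positions (column)
  auto k = reshape(arange(L), {1, L});           // key positions (row)
  auto causal = greater_equal(q, k);             // k <= q
  auto recent = less(subtract(q, k), array(W));  // q - k < W
  return logical_and(causal, recent);
}
```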
Outstanding before merge (would appreciate maintainer feedback on test expectations):
- The gradient primitive (`ScaledDotProductAttentionVJP`) was not extended. It probably needs the same treatment, but that's out of scope for inference-only use.
Notes for reviewers
- The `is_equivalent` extension to compare `window_size_` is required for correctness — primitives with different windows must not be deduplicated by mlx's graph optimizer.
- An independent clean A/B (3 trials, alternating, 8K) confirms decode throughput is unchanged with the windowed kernel ON (50.9 → 50.4 tok/s, σ ≈ 1.3 on both arms). An initial −22% decode reading turned out to be measurement variance at low trial count.
- Ancillary plumbing (`hash_name` extended with `_has_window_`, fallback lambda extended for sliding-mask construction) is included for completeness.
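A self-contained sketch of the shape that equivalence rule takes (mlx's real override lives on the SDPA primitive class; the struct and field names here are assumptions):

```cpp
// State the graph optimizer compares before deduplicating two SDPA nodes.
struct SDPAState {
  float scale;
  bool do_causal;
  int window_size;
};

bool is_equivalent(const SDPAState& a, const SDPAState& b) {
  return a.scale == b.scale && a.do_causal == b.do_causal &&
         a.window_size == b.window_size;  // new: windows must match too
}
```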
Files changed
Total ~200 lines of sliding-window-only changes.
Reproducibility
The performance numbers above were collected on a quiet M3 Max (no other GPU consumers, mains power) with the protocol summarized under Methodology above.
Raw per-trial readings (prefill tok/s | decode tok/s):
The 8K prefill measurement is the tightest signal in the dataset (σ = 0.6 / 1.1 tok/s on the two arms) and is the load-bearing evidence for the kernel's prefill speedup. The smaller-context numbers carry more timing noise but show consistent directionality and align with the predicted K-block skip ratio (L − W) / L: 50% / 75% / 87.5% at 2K / 4K / 8K.