
feat(fast): add sliding-window SDPA kernel path #3552

Open

rabbitson87 wants to merge 1 commit into ml-explore:main from rabbitson87:sliding-window-attention-kernel

Conversation

@rabbitson87

feat(fast::sdpa): sliding-window attention via has_window function constant + kb_start truncation

Summary

Adds optional window_size parameter to mlx::fast::scaled_dot_product_attention enabling sliding-window attention (each Q position attends only to the last window_size K positions, combined with causal masking). The implementation mirrors the existing do_causal upper-bound (kb_lim) pattern with a symmetric lower-bound (kb_start) in the steel attention kernel, skipping K-blocks below the window's lower edge wholesale.

Motivation

Sliding-window attention is used by Gemma 2, Gemma 3, Gemma 4, Mistral 7B, and several long-context Qwen variants. Currently mlx callers must construct an explicit [L_q, L_kv] bool mask Array to express the window pattern, which:

  1. Materializes the mask tensor (1 MB at L=1024 → 64 MB at L=8192 for a 1-byte bool mask).
  2. Forces the matmul fallback for head_dim ∉ {64, 80, 128} since the steel kernel's sdpa_full_supported_head_dim doesn't include 256/512.
  3. Materializes the full [B, H, L, L] scores tensor in the fallback path (1 GB bf16 at L=8192, B=1, H=8) → significant DRAM pressure.
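
For reference, a back-of-the-envelope check of the memory arithmetic behind points 1 and 3 (a standalone sketch, not code from this PR; it assumes a 1-byte bool mask and 2-byte bf16 scores):

#include <cstdio>

int main() {
  // Point 1: explicit [L_q, L_kv] bool mask, 1 byte per element.
  for (long L : {1024L, 8192L}) {
    double mask_mb = double(L) * L / (1024.0 * 1024.0);
    std::printf("L=%ld  mask ~ %.0f MB\n", L, mask_mb);   // 1 MB, 64 MB
  }
  // Point 3: full [B, H, L, L] bf16 scores tensor in the fallback path.
  const long B = 1, H = 8, Lmax = 8192;
  double scores_gb =
      double(B) * H * Lmax * Lmax * 2 / (1024.0 * 1024.0 * 1024.0);
  std::printf("scores ~ %.1f GB\n", scores_gb);           // ~1.0 GB
}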

With explicit window_size, the steel kernel can:

  • Compute valid K-block range [kb_start, kb_lim) per Q-tile and skip the rest entirely.
  • Apply per-element left-edge mask only for the first ceil(BQ/BK) blocks past kb_start (symmetric with the existing causal upper-edge mask).
  • Avoid mask Array allocation + scan entirely.

Performance (Gemma 4 26B-A4B, M3 Max, macOS 26.3.1)

Methodology: 3 timed trials per cell, alternating ON/OFF within context to
control thermal drift, 1 untimed warmup per cell, 30s cooldown between
trials, 60s between configs, 90s between contexts. STEPS=32 decode,
WARMUP=4 (bench-internal). Window size W=1024.

Prefill (single-pass, sliding-window layers exercised):

Context   Baseline (array-mask path)   Windowed kernel        Δ        K-block skip ratio
2K        986.5 ±  8.3 tok/s           1039.4 ± 11.5 tok/s    +5.4%    50.0%
4K        863.7 ±  7.5 tok/s            997.6 ±  9.0 tok/s    +15.5%   75.0%
8K        678.8 ±  0.6 tok/s            914.7 ±  1.1 tok/s    +34.8%   87.5%

Decode (32 steps each, post-prefill autoregressive):

Context   Baseline             Windowed             Δ
2K        49.0 ± 15.5 tok/s    49.9 ± 19.6 tok/s    neutral (small-ctx timing noise)
4K        47.1 ± 10.1 tok/s    46.8 ±  9.7 tok/s    neutral
8K        50.9 ±  1.2 tok/s    50.4 ±  1.5 tok/s    neutral (tight std)

The prefill speedup tracks the theoretical K-block skip ratio
(L − W) / L: at 8K, 87.5% of the K-blocks fall below the window's lower
edge and are skipped wholesale, yielding a 34.8% end-to-end prefill
speedup once the unchanged full-attention layers, MoE experts, and dense
layers are amortized in. The 8K measurement is essentially noise-free
(baseline σ = 0.6 tok/s, windowed σ = 1.1 tok/s over 3 trials), making
the +34.8% margin statistically unambiguous.
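
A quick sanity check of the (L − W) / L figures in the skip-ratio column above (a sketch; rounding to BK-sized block granularity is ignored):

#include <cstdio>

int main() {
  const double W = 1024.0;  // window size used in the benchmark
  for (double L : {2048.0, 4096.0, 8192.0}) {
    // Fraction of the KV range below the window's lower edge for queries
    // near the end of the context: (L - W) / L.
    double skip = (L - W) / L;
    std::printf("L=%5.0f  skip = %.1f%%\n", L, 100.0 * skip);  // 50.0, 75.0, 87.5
  }
}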

Decode is neutral at all contexts. The 2K/4K decode std is high because
total decode time at small ctx is short and timing jitter dominates per
trial; the 8K decode measurement (σ ≈ 1.3 tok/s on both arms) is the
canonical reading for decode neutrality.

Raw measurements: see "Reproducibility" section below.

Implementation

Kernel (steel_attention.h / steel_attention_nax.h)

constant bool has_window [[function_constant(303)]];

// Lower-bound K-block truncation, mirrors do_causal upper-bound:
int kb_start = 0;
if (has_window) {
    int q_min = tid.x * BQ + params->qL_off;
    int k_min = q_min - params->window_size + 1;
    if (k_min > 0) kb_start = min(kb_lim, k_min / BK);
}
for (int kb = kb_start; kb < kb_lim; kb++) { ... }

// Per-element left-edge mask (only first ceil(BQ/BK) blocks past kb_start):
if (has_window && kb < kb_start + ((BQ + BK - 1) / BK)) {
    // row_pos - col_pos >= W → mask = -inf
}

The window check uses K-relative coordinates (qL_off = kL - qL, which equals kv_offset - cache_first_held_pos for both rotated and non-rotated caches), so chunked prefill with a rotating sliding-window cache works automatically.
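
To make the coordinate convention concrete, here is a scalar (non-tiled) restatement of the combined causal + window predicate; the function and variable names are illustrative, not kernel identifiers:

// Scalar reference for the causal + sliding-window check in K-relative
// coordinates. qL_off shifts a query row into K index space, so the same
// test covers rotated and non-rotated caches.
inline bool attends(int q_row, int k_col, int qL_off, int window_size) {
  int row_pos = q_row + qL_off;                         // query position, K coords
  int col_pos = k_col;                                  // key position
  bool below_diag = col_pos <= row_pos;                 // causal upper edge (kb_lim)
  bool in_window = (row_pos - col_pos) < window_size;   // window lower edge (kb_start)
  return below_diag && in_window;
}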

Public API

// mlx/fast.h
MLX_API array scaled_dot_product_attention(
    const array& queries,
    const array& keys,
    const array& values,
    const float scale,
    const std::string& mask_mode = "",
    std::optional<array> mask_arr = {},
    const std::optional<array>& sinks = {},
    int window_size = 0,    // NEW (default 0 = no window)
    StreamOrDevice s = {});
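
For orientation, a minimal caller-side sketch against the proposed signature (a sketch only: it assumes upstream's mlx::core::fast namespacing and that window_size composes with mask_mode = "causal" as described in the summary):

#include <cmath>
#include <optional>
#include "mlx/mlx.h"

using namespace mlx::core;

// q, k, v are arrays shaped [B, H, L, D]; head_dim = D.
array sliding_window_sdpa(
    const array& q, const array& k, const array& v, int head_dim) {
  float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  // window_size > 0 selects the windowed kernel path; the default of 0
  // preserves existing behavior for all current callers.
  return fast::scaled_dot_product_attention(
      q, k, v, scale,
      /* mask_mode   */ "causal",
      /* mask_arr    */ std::nullopt,
      /* sinks       */ std::nullopt,
      /* window_size */ 1024);
}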

use_fallback

window_size > 0 allows the BD=256 path even when the LUMEN_GEMMA4_PREFILL_FAST_BD256 env gate is not set — the window-aware kernel is faster than the fallback on non-NAX hardware (the K-block skip, up to 87.5% at 8K, dominates the lack of NAX tensor-unit fragments).

Backward compatibility

  • Default window_size = 0 preserves all existing behavior. Function constant has_window = false → kernel takes original code path with kb_start = 0.
  • Python: mx.fast.scaled_dot_product_attention(...) gets a new optional window_size kwarg. No breakage to existing callers.
  • C++: the change is a trailing parameter with a default value, so existing call sites compile unchanged (source-compatible; the mangled symbol does change, so dependents need a rebuild).
  • C ABI: mlx_fast_scaled_dot_product_attention passes window_size=0 internally; no signature change required for C consumers.

Test coverage

The implementation has been validated against Gemma 4 26B-A4B on M3 Max:

  • Functional smoke (chat completions produce coherent output).
  • Multi-trial perf A/B confirming the speedup at 2K/4K/8K.
  • Chunked-prefill rotation case (chunk_size=1024, prompt_len=4096, 4 chunks) — kernel produces valid outputs via K-relative coordinate math.

Outstanding before merge (would appreciate maintainer feedback on test expectations):

  • Bit-identical output vs. the manually constructed sliding-mask array path. In our measurements, bf16 accumulation order differs between the kernel and the array-mask path, so greedy argmax can flip when the top-K logits are close; both are mathematically valid sliding attention. Maintainers may want tolerance-based tests; we did not include explicit unit tests for this.
  • D ∈ {64, 80, 128, 256} smoke (only D=256 directly tested; the other BD instantiations should behave the same — same kernel code, different specialization).
  • NAX path verification (kernel was updated symmetrically but tested only on non-NAX M3 Max).
  • VJP support — the windowed forward is supported; the backward (ScaledDotProductAttentionVJP) was not extended. It probably needs the same treatment, but that is out of scope for inference-only use.

Notes for reviewers

  • The is_equivalent extension to compare window_size_ is required for correctness — primitives with different windows must not be deduplicated by mlx's graph optimizer (a sketch of the extended check follows this list). An independent clean A/B (3 trials, alternating, 8K) confirms decode throughput is unchanged with the windowed kernel ON (50.9 → 50.4 tok/s, σ ≈ 1.3 on both arms); an initial -22% decode reading turned out to be measurement variance at low trial count.
  • Two minor host-side helpers (hash_name extended with _has_window_, fallback lambda extended for sliding mask construction) — included for completeness.
  • A BD=512 instantiation was independently attempted (BQ=16, BK=8, WM=2, conditional padQ=0) to unlock head_dim=512 full-attention paths. It builds and runs functionally, but A/B showed a -25% prefill regression on M3 Max (non-NAX) — the WM=2 / padQ=0 trade-offs are net negative against the matmul fallback's tuned gemm dispatch. Reverted and not included in this PR; may be worth revisiting once NAX-capable M5+ hardware is available.
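
A minimal sketch of the extended equivalence check referenced in the first note above (member names other than window_size_ are assumptions, not lifted from this diff):

// Hypothetical sketch: SDPA primitives with different windows must compare
// unequal so the graph optimizer cannot deduplicate them.
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
  const auto& o = static_cast<const ScaledDotProductAttention&>(other);
  return scale_ == o.scale_ &&
         do_causal_ == o.do_causal_ &&    // assumed existing members
         window_size_ == o.window_size_;  // NEW: windows must match
}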

Files changed

mlx/backend/metal/kernels/steel/attn/params.h                    +2
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h   +47
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +46
mlx/backend/metal/scaled_dot_product_attention.cpp               +~30  (only sliding-window related)
mlx/backend/cuda/scaled_dot_product_attention.cpp                +4    (fallback for unsupported windowed path)
mlx/backend/no_gpu/primitives.cpp                                +1    (signature sync)
mlx/fast.h                                                       +9
mlx/fast.cpp                                                     +~65  (only sliding-window related)
mlx/fast_primitives.h                                            +15
python/src/fast.cpp                                              +~30  (Python kwarg binding/docs)

Total ~200 lines of sliding-window-only changes.

Reproducibility

The performance numbers above were collected with the following protocol on
a quiet M3 Max (no other GPU consumers, mains power):

hardware:    M3 Max (40-core GPU)
os:          macOS 26.3.1
model:       Gemma 4 26B-A4B Q4 (lumen-rs native backend, mlx ~main)
window_size: 1024 (Gemma-4 sliding layers)
bench:       lumen-rs `bench_gemma4_native_e2e`
decode:      STEPS=32, WARMUP=4
arms:        windowed ON (default)   vs   windowed OFF (LUMEN_GEMMA4_SDPA_WINDOWED=0)
trials/cell: 3 (timed) + 1 (untimed warmup)
order:       alternating ON / OFF within each context (thermal-drift control)
cooldowns:   30s between trials, 60s between configs, 90s between contexts

Raw per-trial readings (prefill tok/s | decode tok/s):

ctx=2K  ON  T1: 1032.8 | 64.6     OFF T1:  998.2 | 27.2
ctx=2K  ON  T2: 1055.6 | 22.1     OFF T2:  981.6 | 58.1
ctx=2K  ON  T3: 1029.9 | 62.9     OFF T3:  979.6 | 61.7
ctx=4K  ON  T1:  991.1 | 53.4     OFF T1:  857.4 | 53.8
ctx=4K  ON  T2:  991.4 | 53.8     OFF T2:  859.6 | 54.7
ctx=4K  ON  T3: 1010.4 | 33.1     OFF T3:  874.2 | 32.8
ctx=8K  ON  T1:  915.9 | 48.2     OFF T1:  678.6 | 49.2
ctx=8K  ON  T2:  915.0 | 51.5     OFF T2:  678.1 | 51.8
ctx=8K  ON  T3:  913.2 | 51.4     OFF T3:  679.6 | 51.6

The 8K prefill measurement is the tightest signal in the dataset
(σ = 0.6 / 1.1 tok/s on the two arms) and is the load-bearing evidence
for the kernel's prefill speedup. The smaller-context numbers carry more
timing noise but show consistent directionality and align with the
predicted K-block skip ratio (50% / 75% / 87.5% at 2K / 4K / 8K).

@angeloskath requested a review from jagrit06 on May 16, 2026.
