From efc195b6f1de0c9115f81567d9f7e6562c1c0176 Mon Sep 17 00:00:00 2001 From: Hee Sung Son Date: Sat, 16 May 2026 06:23:07 +0900 Subject: [PATCH] feat(fast): add sliding-window SDPA kernel path --- .../cuda/scaled_dot_product_attention.cpp | 4 ++ .../steel/attn/kernels/steel_attention.h | 37 ++++++++++- .../steel/attn/kernels/steel_attention_nax.h | 42 +++++++++++- mlx/backend/metal/kernels/steel/attn/params.h | 2 + .../metal/scaled_dot_product_attention.cpp | 29 +++++++-- mlx/backend/no_gpu/primitives.cpp | 1 + mlx/fast.cpp | 64 +++++++++++++++---- mlx/fast.h | 8 ++- mlx/fast_primitives.h | 15 ++++- python/src/fast.cpp | 29 +++++++-- 10 files changed, 204 insertions(+), 27 deletions(-) diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index ca411e91c6..4b502e26eb 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -558,7 +558,11 @@ bool ScaledDotProductAttention::use_fallback( bool do_causal, bool is_training, bool output_logsumexp, + int window_size, Stream s) { + if (window_size > 0) { + return true; + } if (s.device == Device::cpu) { return true; } diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 0d9628e834..8615bb31fb 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -14,6 +14,7 @@ constant bool align_K [[function_constant(201)]]; constant bool has_mask [[function_constant(300)]]; constant bool do_causal [[function_constant(301)]]; constant bool has_sinks [[function_constant(302)]]; +constant bool has_window [[function_constant(303)]]; struct MaxOp { template @@ -234,6 +235,7 @@ template < } int kb_lim = params->NK; + int kb_start = 0; int kb_min_causal = params->NK; if (do_causal) { @@ -246,8 +248,19 @@ template < kb_min_causal = (q_min / BK); } + if (has_window) { + int q_min = tid.x * BQ + params->qL_off; + int k_min = q_min - params->window_size + 1; + if (k_min > 0) { + kb_start = k_min / BK; + if (kb_start > kb_lim) { + kb_start = kb_lim; + } + } + } + // Loop over KV seq length - for (int kb = 0; kb < kb_lim; kb++) { + for (int kb = kb_start; kb < kb_lim; kb++) { // Load K block and apply scale threadgroup_barrier(mem_flags::mem_threadgroup); if (!align_K && kb == (params->NK_aligned)) { @@ -325,6 +338,28 @@ template < } } + if (has_window && kb < (kb_start + ((BQ + BK - 1) / BK))) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = Limits::finite_min; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = + tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if (row_pos - (col_pos + jj) >= params->window_size) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + // Other masking as needed if (has_mask) { using stile_t = decltype(Stile); diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h index adc9a42798..9ef85eb479 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +++ 
b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -17,6 +17,7 @@ constant bool align_K [[function_constant(201)]]; constant bool has_mask [[function_constant(300)]]; constant bool do_causal [[function_constant(301)]]; constant bool has_sinks [[function_constant(302)]]; +constant bool has_window [[function_constant(303)]]; template struct TransformScale { @@ -174,6 +175,7 @@ template < } int kb_lim = params->NK; + int kb_start = 0; int kb_min_causal = params->NK; if (do_causal) { @@ -186,6 +188,17 @@ template < kb_min_causal = (q_min / BK); } + if (has_window) { + int q_min = tid.x * BQ + params->qL_off; + int k_min = q_min - params->window_size + 1; + if (k_min > 0) { + kb_start = k_min / BK; + if (kb_start > kb_lim) { + kb_start = kb_lim; + } + } + } + const bool is_last_bq = int(tid.x) == (params->NQ_aligned); // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ); const bool is_last_q = is_last_bq; @@ -194,7 +207,7 @@ template < const short lim_rows_k = params->kL_rem; // Loop over KV seq length - for (int kb = 0; kb < kb_lim; kb++) { + for (int kb = kb_start; kb < kb_lim; kb++) { const int is_last_k = (kb == (params->NK_aligned)); // Do S = Q @ K.T @@ -303,6 +316,33 @@ template < } } + if (has_window && kb < (kb_start + ((BQ + BK - 1) / BK))) { + constexpr auto neg_inf = Limits::finite_min; + + const int base_row = tid.x * BQ + params->qL_off + tm; + const int base_col = kb * BK; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + thread auto& fg = Stile.frag_at(iq, ik); + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < stile_t::kFragThrRows; ii++) { + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::kFragThrCols; jj++) { + const auto r = + base_row + iq * kU + ii * stile_t::kFragRowsJump + sm; + const auto c = base_col + ik * kU + jj + sn; + const auto loc = ii * stile_t::kFragThrCols + jj; + fg[loc] = ((r - c) >= params->window_size) ? 
neg_inf : fg[loc]; + } + } + } + } + } + // Other masking as needed if (has_mask) { constexpr auto neg_inf = Limits::finite_min; diff --git a/mlx/backend/metal/kernels/steel/attn/params.h b/mlx/backend/metal/kernels/steel/attn/params.h index f1cf09fada..571a82a64c 100644 --- a/mlx/backend/metal/kernels/steel/attn/params.h +++ b/mlx/backend/metal/kernels/steel/attn/params.h @@ -30,6 +30,8 @@ struct AttnParams { int kL_rem; ///< Remainder in last key/value block int qL_off; ///< Offset in query sequence start + int window_size; ///< Sliding window size (0 = no window) + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index d387a5c08c..ba8b36c933 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -24,6 +24,7 @@ void sdpa_full_self_attention_nax( const float scale, array& o, bool do_causal_, + int window_size_, const std::optional& mask, const std::optional& sinks) { using namespace mlx::steel; @@ -48,13 +49,15 @@ void sdpa_full_self_attention_nax( const bool has_mask = mask.has_value(); const bool do_causal = do_causal_; const bool has_sinks = sinks.has_value(); + const bool has_window = window_size_ > 0; metal::MTLFCList func_consts = { {&align_Q, MTL::DataType::DataTypeBool, 200}, {&align_K, MTL::DataType::DataTypeBool, 201}, {&has_mask, MTL::DataType::DataTypeBool, 300}, {&do_causal, MTL::DataType::DataTypeBool, 301}, - {&has_sinks, MTL::DataType::DataTypeBool, 302}}; + {&has_sinks, MTL::DataType::DataTypeBool, 302}, + {&has_window, MTL::DataType::DataTypeBool, 303}}; std::string base_name; concatenate( @@ -87,7 +90,9 @@ void sdpa_full_self_attention_nax( "_do_causal_", (do_causal ? 't' : 'n'), "_has_sinks_", - (has_sinks ? 't' : 'n')); + (has_sinks ? 't' : 'n'), + "_has_window_", + (has_window ? 
't' : 'n')); auto& compute_encoder = metal::get_command_encoder(s); @@ -133,6 +138,8 @@ void sdpa_full_self_attention_nax( /* int kL_rem = */ (kL - NK_aligned * bk), /* int qL_off = */ (kL - qL), + /* int window_size = */ window_size_, + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, @@ -172,6 +179,7 @@ void sdpa_full_self_attention_metal( const float scale, array& o, bool do_causal_, + int window_size_, const std::optional& mask, const std::optional& sinks) { if (metal::is_nax_available() && q.shape(3) != 80 && @@ -185,6 +193,7 @@ void sdpa_full_self_attention_metal( /* const float scale = */ scale, /* array& o = */ o, /* bool do_causal_ = */ do_causal_, + /* int window_size_ = */ window_size_, /* const std::optional& mask = */ mask, /* const std::optional& sinks = */ sinks); } @@ -211,13 +220,15 @@ void sdpa_full_self_attention_metal( const bool has_mask = mask.has_value(); const bool do_causal = do_causal_; const bool has_sinks = sinks.has_value(); + const bool has_window = window_size_ > 0; metal::MTLFCList func_consts = { {&align_Q, MTL::DataType::DataTypeBool, 200}, {&align_K, MTL::DataType::DataTypeBool, 201}, {&has_mask, MTL::DataType::DataTypeBool, 300}, {&do_causal, MTL::DataType::DataTypeBool, 301}, - {&has_sinks, MTL::DataType::DataTypeBool, 302}}; + {&has_sinks, MTL::DataType::DataTypeBool, 302}, + {&has_window, MTL::DataType::DataTypeBool, 303}}; std::string base_name; concatenate( @@ -250,7 +261,9 @@ void sdpa_full_self_attention_metal( "_do_causal_", (do_causal ? 't' : 'n'), "_has_sinks_", - (has_sinks ? 't' : 'n')); + (has_sinks ? 't' : 'n'), + "_has_window_", + (has_window ? 
't' : 'n')); auto& compute_encoder = metal::get_command_encoder(s); @@ -296,6 +309,8 @@ void sdpa_full_self_attention_metal( /* int kL_rem = */ (kL - NK_aligned * bk), /* int qL_off = */ (kL - qL), + /* int window_size = */ window_size_, + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, @@ -597,6 +612,7 @@ bool ScaledDotProductAttention::use_fallback( bool do_causal, bool is_training, bool output_logsumexp, + int window_size, Stream s) { if (is_training) { // It's faster for training on Metal to use the unfused SDPA for both @@ -623,7 +639,8 @@ bool ScaledDotProductAttention::use_fallback( (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || query_head_dim == 256); const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); + (query_head_dim == 64 || query_head_dim == 80 || + query_head_dim == 128 || (query_head_dim == 256 && window_size > 0)); const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); @@ -782,7 +799,7 @@ void ScaledDotProductAttention::eval_gpu( : std::nullopt; sdpa_full_self_attention_metal( - s, d, q, k, v, scale_, o, do_causal_, mask, sinks); + s, d, q, k, v, scale_, o, do_causal_, window_size_, mask, sinks); } metal::get_command_encoder(s).add_temporaries(std::move(copies)); diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 4819ed2724..1ade32634b 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -32,6 +32,7 @@ bool fast::ScaledDotProductAttention::use_fallback( bool do_causal, bool is_training, bool output_logsumexp, + int window_size, Stream s) { return true; } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index a668fe9abd..e1af53c4af 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -618,6 +618,7 @@ array scaled_dot_product_attention( const std::string& mask_mode /* = "" */, std::optional mask_arr /* = {} */, const std::optional& sinks /* = {} */, + int window_size /* = 0 */, StreamOrDevice s /* = {}*/) { for (const auto& tensor : {queries, keys, values}) { if (tensor.ndim() != 4) { @@ -662,6 +663,12 @@ array scaled_dot_product_attention( << mask_arr->shape() << " expected to have at most rank 4."; throw std::invalid_argument(msg.str()); } + if (window_size < 0) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] window_size must be non-negative; " + << "received " << window_size << "."; + throw std::invalid_argument(msg.str()); + } const size_t batch_dim = queries.shape(0); for (const auto& tensor : {keys, values}) { @@ -718,6 +725,7 @@ array scaled_dot_product_attention( n_q_heads, n_kv_heads, do_causal, + window_size, has_sinks, has_arr_mask, s](const std::vector& inputs) { @@ -731,19 +739,45 @@ array scaled_dot_product_attention( v = expand_dims(v, 2, s); } auto scores = matmul(q, swapaxes(k, -1, -2, s), s); - if (has_arr_mask || do_causal) { + if (has_arr_mask || do_causal || window_size > 0) { // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv] auto make_or_fetch_mask = [&]() { + int kL = k.shape(-2); + int qL = q.shape(-2); + int offset = kL - qL; + auto q_idx = arange(offset, qL + offset, s); + auto k_idx = arange(0, kL, s); + q_idx = expand_dims(q_idx, 1, s); + k_idx = expand_dims(k_idx, 0, s); + if 
(do_causal) { - int kL = k.shape(-2); - int qL = q.shape(-2); - int offset = kL - qL; - auto q_idx = arange(offset, qL + offset, s); - auto k_idx = arange(0, kL, s); - q_idx = expand_dims(q_idx, 1, s); - k_idx = expand_dims(k_idx, 0, s); - return greater_equal(q_idx, k_idx, s); + auto causal_mask = greater_equal(q_idx, k_idx, s); + if (window_size > 0) { + auto window_mask = + less(q_idx, add(k_idx, array(window_size, k_idx.dtype()), s), s); + return logical_and(causal_mask, window_mask, s); + } + return causal_mask; } + + if (window_size > 0) { + auto window_mask = + less(q_idx, add(k_idx, array(window_size, k_idx.dtype()), s), s); + if (!has_arr_mask) { + return window_mask; + } + auto mask = inputs[3]; + if (mask.dtype() == bool_) { + return logical_and(mask, window_mask, s); + } + auto additive_window_mask = where( + window_mask, + full_like(mask, 0, mask.dtype(), s), + full_like(mask, finfo(mask.dtype()).min, mask.dtype(), s), + s); + return add(mask, additive_window_mask, s); + } + return inputs[3]; }; auto mask = make_or_fetch_mask(); @@ -834,6 +868,7 @@ array scaled_dot_product_attention( do_causal, is_training, output_logsumexp, + window_size, stream)) { if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) { // Convert bool mask to additive mask. @@ -846,7 +881,13 @@ array scaled_dot_product_attention( } Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; auto primitive = std::make_shared( - stream, fallback, scale, do_causal, has_sinks, output_logsumexp); + stream, + fallback, + scale, + do_causal, + has_sinks, + output_logsumexp, + window_size); if (output_logsumexp) { return array::make_arrays( {std::move(out_shape), Shape{q.shape(0), q.shape(1), q.shape(2), 1}}, @@ -912,7 +953,8 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { static_cast(other); return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ && has_sinks_ == a_other.has_sinks_ && - output_logsumexp_ == a_other.output_logsumexp_; + output_logsumexp_ == a_other.output_logsumexp_ && + window_size_ == a_other.window_size_; } bool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const { diff --git a/mlx/fast.h b/mlx/fast.h index 1183aba8fe..48dd115933 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -43,7 +43,12 @@ MLX_API array rope( const std::optional& freqs = std::nullopt, StreamOrDevice s = {}); -/** Computes: O = softmax(Q @ K.T) @ V **/ +/** Computes: O = softmax(Q @ K.T) @ V + * + * `window_size` (>0) enables a sliding-window attention pattern: query at + * absolute position q only attends to keys in [max(0, q - window_size + 1), q]. + * When 0 (default), the standard full-attention behavior is used. 
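 *
 * For example, with window_size = 3 the query at position q = 10 attends to
 * the keys at positions 8, 9, and 10; queries with q < window_size - 1 attend
 * to all keys from position 0 up to q.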
+ **/ MLX_API array scaled_dot_product_attention( const array& queries, const array& keys, @@ -52,6 +57,7 @@ MLX_API array scaled_dot_product_attention( const std::string& mask_mode = "", std::optional mask_arr = {}, const std::optional& sinks = {}, + int window_size = 0, StreamOrDevice s = {}); using TemplateArg = std::variant; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..d80568be0e 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -211,12 +211,14 @@ class ScaledDotProductAttention : public Custom { float scale, bool do_causal, bool has_sinks, - bool output_logsumexp) + bool output_logsumexp, + int window_size = 0) : Custom(stream, std::move(fallback)), scale_(scale), do_causal_(do_causal), has_sinks_(has_sinks), - output_logsumexp_(output_logsumexp) {} + output_logsumexp_(output_logsumexp), + window_size_(window_size) {} static bool use_fallback( const array& q, @@ -227,6 +229,7 @@ class ScaledDotProductAttention : public Custom { bool do_causal, bool is_training, bool output_logsumexp, + int window_size, Stream s); static bool supports_bool_mask(); @@ -250,7 +253,12 @@ class ScaledDotProductAttention : public Custom { DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return std::make_tuple( - nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_); + nullptr, + scale_, + do_causal_, + has_sinks_, + output_logsumexp_, + window_size_); } private: @@ -258,6 +266,7 @@ class ScaledDotProductAttention : public Custom { bool do_causal_; bool has_sinks_; bool output_logsumexp_; + int window_size_; }; class ScaledDotProductAttentionVJP : public Custom { diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a43d89d9b..afa026fb71 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -197,6 +197,7 @@ void init_fast(nb::module_& parent_module) { const float scale, const std::variant& mask, const std::optional& sinks, + int window_size, mx::StreamOrDevice s) { bool has_mask = !std::holds_alternative(mask); bool has_str_mask = @@ -213,16 +214,32 @@ void init_fast(nb::module_& parent_module) { throw std::invalid_argument(msg.str()); } return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, mask_str, std::nullopt, sinks, s); + queries, + keys, + values, + scale, + mask_str, + std::nullopt, + sinks, + window_size, + s); } else { auto mask_arr = std::get(mask); return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", mask_arr, sinks, s); + queries, + keys, + values, + scale, + "", + mask_arr, + sinks, + window_size, + s); } } else { return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {}, sinks, s); + queries, keys, values, scale, "", {}, sinks, window_size, s); } }, "q"_a, @@ -232,9 +249,10 @@ void init_fast(nb::module_& parent_module) { "scale"_a, "mask"_a = nb::none(), "sinks"_a = nb::none(), + "window_size"_a = 0, "stream"_a = nb::none(), nb::sig( - "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, sinks: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, sinks: Optional[array] = None, window_size: int = 0, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``. 
@@ -276,6 +294,9 @@ void init_fast(nb::module_& parent_module) {
              last query aligns with the last key.
            sinks (array, optional): An optional array of attention sinks.
              Default: ``None``.
+           window_size (int, optional): A sliding-window size. When greater
+             than zero, each query attends only to keys in the last
+             ``window_size`` positions. Default: ``0``.
 
        Returns:
            array: The output array.
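
A minimal usage sketch of the new keyword (shapes, the scale, and the window
value are illustrative, not part of the patch):

    import mlx.core as mx

    q = mx.random.normal(shape=(1, 8, 128, 64))
    k = mx.random.normal(shape=(1, 8, 128, 64))
    v = mx.random.normal(shape=(1, 8, 128, 64))

    # Causal attention where each query sees at most the current key position
    # and the 31 positions before it.
    o = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=64**-0.5, mask="causal", window_size=32)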