4 changes: 4 additions & 0 deletions mlx/backend/cuda/scaled_dot_product_attention.cpp
@@ -558,7 +558,11 @@ bool ScaledDotProductAttention::use_fallback(
bool do_causal,
bool is_training,
bool output_logsumexp,
int window_size,
Stream s) {
if (window_size > 0) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
37 changes: 36 additions & 1 deletion mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
@@ -14,6 +14,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
constant bool has_window [[function_constant(303)]];

struct MaxOp {
template <typename T>
@@ -234,6 +235,7 @@ template <
}

int kb_lim = params->NK;
int kb_start = 0;
int kb_min_causal = params->NK;

if (do_causal) {
@@ -246,8 +248,19 @@ template <
kb_min_causal = (q_min / BK);
}

if (has_window) {
int q_min = tid.x * BQ + params->qL_off;
int k_min = q_min - params->window_size + 1;
if (k_min > 0) {
kb_start = k_min / BK;
if (kb_start > kb_lim) {
kb_start = kb_lim;
}
}
}

// Loop over KV seq length
for (int kb = 0; kb < kb_lim; kb++) {
for (int kb = kb_start; kb < kb_lim; kb++) {
// Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!align_K && kb == (params->NK_aligned)) {
@@ -325,6 +338,28 @@ template <
}
}

if (has_window && kb < (kb_start + ((BQ + BK - 1) / BK))) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = Limits<selem_t>::finite_min;

STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
const int row_pos =
tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if (row_pos - (col_pos + jj) >= params->window_size) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}

// Other masking as needed
if (has_mask) {
using stile_t = decltype(Stile);
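The window masking above sets out-of-window scores to the finite minimum before the softmax; the causal limit (kb_lim) already prevents attending to future keys. For reference, the predicate the windowed block enforces can be stated independently of the tiling. A minimal standalone sketch, with an illustrative helper name that is not part of this diff:

```cpp
#include <cstdio>

// Reference semantics of the sliding-window mask applied in the kernel above:
// a (query, key) pair is kept only when it is causal and q_pos - k_pos is
// strictly less than window_size; everything else is set to -inf before softmax.
bool keep_score(int q_pos, int k_pos, int window_size) {
  if (k_pos > q_pos) {
    return false; // causal: future keys are never attended (handled via kb_lim)
  }
  if (window_size > 0 && (q_pos - k_pos) >= window_size) {
    return false; // older than the window: masked to -inf in Stile
  }
  return true;
}

int main() {
  // window_size = 4: query position 10 may attend to key positions 7..10 only.
  for (int k = 5; k <= 10; ++k) {
    std::printf("q=10 k=%d keep=%d\n", k, keep_score(10, k, 4));
  }
}
```

A second attention kernel variant below receives the same changes: the has_window function constant (index 303), the kb_start loop start, and per-fragment masking of scores outside the window.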
@@ -17,6 +17,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
constant bool has_window [[function_constant(303)]];

template <typename T>
struct TransformScale {
@@ -174,6 +175,7 @@ template <
}

int kb_lim = params->NK;
int kb_start = 0;
int kb_min_causal = params->NK;

if (do_causal) {
@@ -186,6 +188,17 @@ template <
kb_min_causal = (q_min / BK);
}

if (has_window) {
int q_min = tid.x * BQ + params->qL_off;
int k_min = q_min - params->window_size + 1;
if (k_min > 0) {
kb_start = k_min / BK;
if (kb_start > kb_lim) {
kb_start = kb_lim;
}
}
}

const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
// const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
const bool is_last_q = is_last_bq;
@@ -194,7 +207,7 @@
const short lim_rows_k = params->kL_rem;

// Loop over KV seq length
for (int kb = 0; kb < kb_lim; kb++) {
for (int kb = kb_start; kb < kb_lim; kb++) {
const int is_last_k = (kb == (params->NK_aligned));

// Do S = Q @ K.T
@@ -303,6 +316,33 @@ template <
}
}

if (has_window && kb < (kb_start + ((BQ + BK - 1) / BK))) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;

const int base_row = tid.x * BQ + params->qL_off + tm;
const int base_col = kb * BK;

STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
thread auto& fg = Stile.frag_at(iq, ik);

STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < stile_t::kFragThrRows; ii++) {
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::kFragThrCols; jj++) {
const auto r =
base_row + iq * kU + ii * stile_t::kFragRowsJump + sm;
const auto c = base_col + ik * kU + jj + sn;
const auto loc = ii * stile_t::kFragThrCols + jj;
fg[loc] = ((r - c) >= params->window_size) ? neg_inf : fg[loc];
}
}
}
}
}

// Other masking as needed
if (has_mask) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;
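Besides masking, both kernels skip K/V blocks that lie entirely before the window: the KV loop now starts at kb_start, and only the first few blocks after kb_start can straddle the window boundary, which is why the per-element masking above is applied only when kb < kb_start + (BQ + BK - 1) / BK. A sketch of the kb_start computation with a worked example (illustrative helper, assuming BK is the key tile size used above):

```cpp
#include <algorithm>
#include <cstdio>

// First K/V block that can still contain a valid key for a query tile whose
// first query sits at absolute position q_min (mirrors the kb_start logic above).
int first_kv_block(int q_min, int window_size, int BK, int kb_lim) {
  int k_min = q_min - window_size + 1; // earliest key the first query may see
  if (k_min <= 0) {
    return 0;                          // window still reaches the sequence start
  }
  return std::min(k_min / BK, kb_lim); // never past the causal/sequence limit
}

int main() {
  // window_size = 128, BK = 16, query tile starting at position 512:
  // k_min = 512 - 128 + 1 = 385, so blocks 0..23 are skipped entirely and the
  // KV loop starts at block 385 / 16 = 24.
  std::printf("kb_start = %d\n", first_kv_block(512, 128, 16, 64));
}
```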
2 changes: 2 additions & 0 deletions mlx/backend/metal/kernels/steel/attn/params.h
@@ -30,6 +30,8 @@ struct AttnParams {
int kL_rem; ///< Remainder in last key/value block
int qL_off; ///< Offset in query sequence start

int window_size; ///< Sliding window size (0 = no window)

int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
29 changes: 23 additions & 6 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -24,6 +24,7 @@ void sdpa_full_self_attention_nax(
const float scale,
array& o,
bool do_causal_,
int window_size_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
using namespace mlx::steel;
@@ -48,13 +49,15 @@
const bool has_mask = mask.has_value();
const bool do_causal = do_causal_;
const bool has_sinks = sinks.has_value();
const bool has_window = window_size_ > 0;

metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};
{&has_sinks, MTL::DataType::DataTypeBool, 302},
{&has_window, MTL::DataType::DataTypeBool, 303}};

std::string base_name;
concatenate(
@@ -87,7 +90,9 @@
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));
(has_sinks ? 't' : 'n'),
"_has_window_",
(has_window ? 't' : 'n'));

auto& compute_encoder = metal::get_command_encoder(s);

@@ -133,6 +138,8 @@
/* int kL_rem = */ (kL - NK_aligned * bk),
/* int qL_off = */ (kL - qL),

/* int window_size = */ window_size_,

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
@@ -172,6 +179,7 @@ void sdpa_full_self_attention_metal(
const float scale,
array& o,
bool do_causal_,
int window_size_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
@@ -185,6 +193,7 @@
/* const float scale = */ scale,
/* array& o = */ o,
/* bool do_causal_ = */ do_causal_,
/* int window_size_ = */ window_size_,
/* const std::optional<array>& mask = */ mask,
/* const std::optional<array>& sinks = */ sinks);
}
@@ -211,13 +220,15 @@ void sdpa_full_self_attention_metal(
const bool has_mask = mask.has_value();
const bool do_causal = do_causal_;
const bool has_sinks = sinks.has_value();
const bool has_window = window_size_ > 0;

metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};
{&has_sinks, MTL::DataType::DataTypeBool, 302},
{&has_window, MTL::DataType::DataTypeBool, 303}};

std::string base_name;
concatenate(
@@ -250,7 +261,9 @@
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));
(has_sinks ? 't' : 'n'),
"_has_window_",
(has_window ? 't' : 'n'));

auto& compute_encoder = metal::get_command_encoder(s);

@@ -296,6 +309,8 @@
/* int kL_rem = */ (kL - NK_aligned * bk),
/* int qL_off = */ (kL - qL),

/* int window_size = */ window_size_,

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
@@ -597,6 +612,7 @@ bool ScaledDotProductAttention::use_fallback(
bool do_causal,
bool is_training,
bool output_logsumexp,
int window_size,
Stream s) {
if (is_training) {
// It's faster for training on Metal to use the unfused SDPA for both
@@ -623,7 +639,8 @@
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
(query_head_dim == 64 || query_head_dim == 80 ||
query_head_dim == 128 || (query_head_dim == 256 && window_size > 0));

const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
@@ -782,7 +799,7 @@
: std::nullopt;

sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
s, d, q, k, v, scale_, o, do_causal_, window_size_, mask, sinks);
}

metal::get_command_encoder(s).add_temporaries(std::move(copies));
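On the host side, the window size is threaded through AttnParams and the has_window function constant (index 303), which also becomes part of the kernel name. For spot-checking the fused kernels, a naive single-head reference with the same causal sliding-window semantics can be written in a few lines; the sketch below is illustrative only (plain C++ over flat float buffers, not MLX API; masks and attention sinks are ignored):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Naive single-head attention with a causal sliding window:
// O[i] = sum_j softmax_j(scale * Q[i].K[j]) * V[j], over kept j only.
// qL/kL are query/key lengths, D the head dim; qL_off = kL - qL as in AttnParams.
std::vector<float> sdpa_window_ref(
    const std::vector<float>& Q, // [qL, D]
    const std::vector<float>& K, // [kL, D]
    const std::vector<float>& V, // [kL, D]
    int qL, int kL, int D, float scale, int window_size) {
  std::vector<float> O(static_cast<size_t>(qL) * D, 0.0f);
  const int qL_off = kL - qL;
  for (int i = 0; i < qL; ++i) {
    const int q_pos = i + qL_off;
    std::vector<float> scores(kL, -INFINITY);
    float max_s = -INFINITY;
    for (int j = 0; j <= q_pos && j < kL; ++j) { // causal
      if (window_size > 0 && (q_pos - j) >= window_size) continue; // windowed
      float s = 0.0f;
      for (int d = 0; d < D; ++d) s += Q[i * D + d] * K[j * D + d];
      scores[j] = s * scale;
      max_s = std::max(max_s, scores[j]);
    }
    // Numerically stable softmax over the kept scores, then weighted sum of V.
    float denom = 0.0f;
    for (int j = 0; j < kL; ++j) {
      if (scores[j] != -INFINITY) denom += std::exp(scores[j] - max_s);
    }
    for (int j = 0; j < kL; ++j) {
      if (scores[j] == -INFINITY) continue;
      const float w = std::exp(scores[j] - max_s) / denom;
      for (int d = 0; d < D; ++d) O[i * D + d] += w * V[j * D + d];
    }
  }
  return O;
}
```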
1 change: 1 addition & 0 deletions mlx/backend/no_gpu/primitives.cpp
@@ -32,6 +32,7 @@ bool fast::ScaledDotProductAttention::use_fallback(
bool do_causal,
bool is_training,
bool output_logsumexp,
int window_size,
Stream s) {
return true;
}