From 4d998057dc93f16ebc9dabd7606f0a26cff63ddf Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 15:25:18 -0700 Subject: [PATCH 01/10] Add DSv4 sparse attention and indexer Triton ops --- aiter/ops/triton/__init__.py | 2 + .../_triton_kernels/attention/dsv4_indexer.py | 133 +++++++++++++++ .../attention/sparse_mqa_sink.py | 152 ++++++++++++++++++ aiter/ops/triton/attention/dsv4_indexer.py | 132 +++++++++++++++ aiter/ops/triton/attention/sparse_mqa_sink.py | 102 ++++++++++++ op_tests/test_dsv4_indexer.py | 53 ++++++ op_tests/test_sparse_mqa_sink.py | 49 ++++++ 7 files changed, 623 insertions(+) create mode 100644 aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py create mode 100644 aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py create mode 100644 aiter/ops/triton/attention/dsv4_indexer.py create mode 100644 aiter/ops/triton/attention/sparse_mqa_sink.py create mode 100644 op_tests/test_dsv4_indexer.py create mode 100644 op_tests/test_sparse_mqa_sink.py diff --git a/aiter/ops/triton/__init__.py b/aiter/ops/triton/__init__.py index bef4ccc506..c8057288bf 100644 --- a/aiter/ops/triton/__init__.py +++ b/aiter/ops/triton/__init__.py @@ -103,6 +103,8 @@ "pa_prefill": "attention.pa_prefill", "pod_attention": "attention.pod_attention", "prefill_attention": "attention.prefill_attention", + "dsv4_indexer": "attention.dsv4_indexer", + "sparse_mqa_sink": "attention.sparse_mqa_sink", "unified_attention_sparse_mla": "attention.unified_attention_sparse_mla", "unified_attention": "attention.unified_attention", # Fusions modules (fusions/) diff --git a/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py b/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py new file mode 100644 index 0000000000..cef46edd17 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py @@ -0,0 +1,133 @@ +import triton +import triton.language as tl + + +@triton.jit +def _dsv4_indexer_dense_kernel( + out_ptr, # [num_tokens, topk] + positions_ptr, # [num_tokens] + out_stride_t: tl.int64, + out_stride_k: tl.int64, + n_committed: tl.constexpr, + offset: tl.int32, + ratio: tl.constexpr, + BLOCK_K: tl.constexpr, +): + token_id = tl.program_id(0) + offs_k = tl.arange(0, BLOCK_K) + pos = tl.load(positions_ptr + token_id).to(tl.int32) + causal_limit = (pos + 1) // ratio + valid = (offs_k < n_committed) & (offs_k < causal_limit) + out = tl.where(valid, offs_k + offset, -1) + tl.store( + out_ptr + token_id * out_stride_t + offs_k * out_stride_k, + out, + mask=offs_k < n_committed, + ) + + +@triton.jit +def _dsv4_indexer_score_kernel( + score_ptr, # [num_tokens, kv_len], fp32 + q_ptr, # [num_tokens, num_heads, head_dim] + kv_ptr, # [kv_len, head_dim] + weights_ptr, # [num_tokens, num_heads] + positions_ptr, # [num_tokens] + q_stride_t: tl.int64, + q_stride_h: tl.int64, + q_stride_d: tl.int64, + kv_stride_t: tl.int64, + kv_stride_d: tl.int64, + weights_stride_t: tl.int64, + weights_stride_h: tl.int64, + score_stride_t: tl.int64, + score_stride_k: tl.int64, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + kv_len: tl.constexpr, + ratio: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + offs_t = tile_id * BLOCK_T + tl.arange(0, BLOCK_T) + offs_d = tl.arange(0, BLOCK_D) + d_mask = offs_d < head_dim + acc = tl.zeros((BLOCK_T,), dtype=tl.float32) + + for h_start in range(0, num_heads, BLOCK_H): + offs_h = h_start + tl.arange(0, BLOCK_H) + h_mask = offs_h < num_heads + + q = tl.load( + q_ptr + + token_id * q_stride_t + + offs_h[:, None] * q_stride_h + + offs_d[None, :] * q_stride_d, + mask=h_mask[:, None] & d_mask[None, :], + other=0.0, + cache_modifier=".cg", + ) + kv = tl.load( + kv_ptr + offs_t[None, :] * kv_stride_t + offs_d[:, None] * kv_stride_d, + mask=(offs_t[None, :] < kv_len) & d_mask[:, None], + other=0.0, + cache_modifier=".cg", + ) + dots = tl.dot(q, kv) + dots = tl.maximum(dots, 0.0) + weights = tl.load( + weights_ptr + token_id * weights_stride_t + offs_h * weights_stride_h, + mask=h_mask, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + acc += tl.sum(dots * weights[:, None], axis=0) + + pos = tl.load(positions_ptr + token_id).to(tl.int32) + causal_limit = (pos + 1) // ratio + valid = (offs_t < kv_len) & (offs_t < causal_limit) + acc = tl.where(valid, acc, float("-inf")) + tl.store( + score_ptr + token_id * score_stride_t + offs_t * score_stride_k, + acc, + mask=offs_t < kv_len, + ) + + +@triton.jit +def _dsv4_indexer_finalize_kernel( + out_ptr, # [num_tokens, topk], int32 + values_ptr, # [num_tokens, topk], fp32 + indices_ptr, # [num_tokens, topk], int64 from aiter topk + out_stride_t: tl.int64, + out_stride_k: tl.int64, + values_stride_t: tl.int64, + values_stride_k: tl.int64, + indices_stride_t: tl.int64, + indices_stride_k: tl.int64, + offset: tl.int32, + topk: tl.constexpr, + BLOCK_K: tl.constexpr, +): + token_id = tl.program_id(0) + offs_k = tl.arange(0, BLOCK_K) + values = tl.load( + values_ptr + token_id * values_stride_t + offs_k * values_stride_k, + mask=offs_k < topk, + other=float("-inf"), + ) + indices = tl.load( + indices_ptr + token_id * indices_stride_t + offs_k * indices_stride_k, + mask=offs_k < topk, + other=-1, + ).to(tl.int32) + out = tl.where(values > -3.0e38, indices + offset, -1) + tl.store( + out_ptr + token_id * out_stride_t + offs_k * out_stride_k, + out, + mask=offs_k < topk, + ) diff --git a/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py b/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py new file mode 100644 index 0000000000..9d94c041d8 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py @@ -0,0 +1,152 @@ +import triton +import triton.language as tl + + +@triton.jit +def _find_seq_idx(cu_seqlens_q_ptr, token_idx, num_seqs): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(cu_seqlens_q_ptr + mid) + if val <= token_idx: + left = mid + 1 + else: + right = mid + return left - 1 + + +@triton.jit +def _sparse_mqa_sink_kernel( + out_ptr, # [num_tokens, num_heads, head_dim] + q_ptr, # [num_tokens, num_heads, head_dim] + kv_ptr, # [num_blocks, block_size, head_dim] + topk_ptr, # [num_tokens, topk] + attn_sink_ptr, # [num_heads] + block_table_ptr, # [num_seqs, max_blocks_per_seq] + cu_seqlens_q_ptr, # [num_seqs + 1] + seqused_k_ptr, # [num_seqs] + scale, + q_stride_t: tl.int64, + q_stride_h: tl.int64, + q_stride_d: tl.int64, + out_stride_t: tl.int64, + out_stride_h: tl.int64, + out_stride_d: tl.int64, + kv_stride_b: tl.int64, + kv_stride_s: tl.int64, + kv_stride_d: tl.int64, + topk_stride_t: tl.int64, + topk_stride_k: tl.int64, + block_table_stride_b: tl.int64, + block_table_stride_blk: tl.int64, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + topk_count: tl.constexpr, + block_size: tl.constexpr, + num_seqs: tl.int32, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, + TILE_K: tl.constexpr, +): + """Sparse MQA with DSv4's attention-sink denominator. + + One program handles one query token and BLOCK_H query heads. KV is MQA: + all query heads share the same [topk, head_dim] key/value rows. + """ + head_blocks: tl.constexpr = (num_heads + BLOCK_H - 1) // BLOCK_H + program_id = tl.program_id(0) + token_id = program_id // head_blocks + head_block = program_id % head_blocks + + seq_idx = _find_seq_idx(cu_seqlens_q_ptr, token_id, num_seqs) + seq_start = tl.load(cu_seqlens_q_ptr + seq_idx) + seq_end = tl.load(cu_seqlens_q_ptr + seq_idx + 1) + if token_id >= seq_end: + return + local_token = token_id - seq_start + kv_len = tl.load(seqused_k_ptr + seq_idx) + + offs_h = head_block * BLOCK_H + tl.arange(0, BLOCK_H) + offs_d = tl.arange(0, BLOCK_D) + h_mask = offs_h < num_heads + d_mask = offs_d < head_dim + + q = tl.load( + q_ptr + + token_id * q_stride_t + + offs_h[:, None] * q_stride_h + + offs_d[None, :] * q_stride_d, + mask=h_mask[:, None] & d_mask[None, :], + other=0.0, + cache_modifier=".cg", + ) + + sink = tl.load(attn_sink_ptr + offs_h, mask=h_mask, other=float("-inf")).to( + tl.float32 + ) + has_sink = sink > -3.0e38 + m_i = tl.where(has_sink, sink, float("-inf")) + l_i = tl.where(has_sink, 1.0, 0.0) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + + for tile_start in range(0, topk_count, TILE_K): + offs_k = tile_start + tl.arange(0, TILE_K) + topk_pos = tl.load( + topk_ptr + token_id * topk_stride_t + offs_k * topk_stride_k, + mask=offs_k < topk_count, + other=-1, + ) + valid_k = (offs_k < topk_count) & (topk_pos >= 0) & (topk_pos < kv_len) + + logical_block = topk_pos // block_size + slot = topk_pos - logical_block * block_size + physical_block = tl.load( + block_table_ptr + + seq_idx * block_table_stride_b + + logical_block * block_table_stride_blk, + mask=valid_k, + other=0, + ) + + k = tl.load( + kv_ptr + + physical_block[None, :] * kv_stride_b + + slot[None, :] * kv_stride_s + + offs_d[:, None] * kv_stride_d, + mask=d_mask[:, None] & valid_k[None, :], + other=0.0, + cache_modifier=".cg", + ) + scores = tl.dot(q, k) * scale + scores = tl.where(h_mask[:, None] & valid_k[None, :], scores, float("-inf")) + + m_new = tl.maximum(m_i, tl.max(scores, axis=1)) + m_new = tl.where(m_new > float("-inf"), m_new, 0.0) + p = tl.exp(scores - m_new[:, None]) + alpha = tl.exp(m_i - m_new) + l_new = l_i * alpha + tl.sum(p, axis=1) + acc = acc * alpha[:, None] + + v = tl.load( + kv_ptr + + physical_block[:, None] * kv_stride_b + + slot[:, None] * kv_stride_s + + offs_d[None, :] * kv_stride_d, + mask=valid_k[:, None] & d_mask[None, :], + other=0.0, + cache_modifier=".cg", + ) + acc += tl.dot(p.to(v.dtype), v) + m_i = m_new + l_i = l_new + + acc = acc * tl.where(l_i[:, None] > 0.0, 1.0 / l_i[:, None], 0.0) + tl.store( + out_ptr + + token_id * out_stride_t + + offs_h[:, None] * out_stride_h + + offs_d[None, :] * out_stride_d, + acc, + mask=h_mask[:, None] & d_mask[None, :], + ) diff --git a/aiter/ops/triton/attention/dsv4_indexer.py b/aiter/ops/triton/attention/dsv4_indexer.py new file mode 100644 index 0000000000..245f7025ac --- /dev/null +++ b/aiter/ops/triton/attention/dsv4_indexer.py @@ -0,0 +1,132 @@ +import torch +import triton + +from aiter.ops.triton._triton_kernels.attention.dsv4_indexer import ( + _dsv4_indexer_dense_kernel, + _dsv4_indexer_finalize_kernel, + _dsv4_indexer_score_kernel, +) +from aiter.ops.triton.topk import topk as _aiter_topk + + +def dsv4_indexer_topk( + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + positions: torch.Tensor, + index_topk: int, + offset: int, + *, + ratio: int = 4, + block_t: int = 64, + block_h: int = 8, +) -> torch.Tensor: + """DeepSeek-V4 Indexer scorer + causal top-k. + + Computes the Indexer's learned sparse compressed-KV selection without + materializing the Torch fallback's [tokens, heads, committed] score tensor: + + score[t, k] = sum_h relu(q[t, h] @ kv[k]) * weights[t, h] + + Args: + q: [num_tokens, 64, 128], BF16/FP16/FP8-like storage accepted by Triton. + kv: [num_committed, 128], compressed Indexer KV. + weights: [num_tokens, 64], FP32/BF16, already includes model scaling. + positions: [num_tokens], absolute token positions. + index_topk: model top-k cap, 512 for V4-Flash or 1024 for V4-Pro. + offset: index offset into the sparse-attention [window || compressed] KV. + ratio: compression ratio. DSv4 CSA Indexer uses 4. + + Returns: + [num_tokens, min(index_topk, num_committed)] int32. Future entries are -1. + """ + assert q.dim() == 3, f"q must be [T, H, D], got {q.shape}" + assert kv.dim() == 2, f"kv must be [N, D], got {kv.shape}" + assert weights.dim() == 2, f"weights must be [T, H], got {weights.shape}" + assert positions.dim() == 1, f"positions must be [T], got {positions.shape}" + assert q.shape[0] == weights.shape[0] == positions.shape[0] + assert q.shape[1] == weights.shape[1] + assert q.shape[2] == kv.shape[1] + assert ratio > 0 + + num_tokens, num_heads, head_dim = q.shape + n_committed = kv.shape[0] + actual_topk = min(int(index_topk), n_committed) + if actual_topk <= 0: + return torch.empty((num_tokens, 0), device=q.device, dtype=torch.int32) + + q = q.contiguous() + kv = kv.contiguous() + weights = weights.contiguous() + positions = positions.contiguous() + out = torch.empty((num_tokens, actual_topk), device=q.device, dtype=torch.int32) + + # If top-k covers every committed compressed entry, the order does not + # affect downstream sparse attention. Emit dense causal indices and skip the + # expensive learned scorer entirely. This is the common 1k1k DSv4 case + # where n_committed=256 and index_topk is 512/1024. + if actual_topk == n_committed: + block_k = triton.next_power_of_2(max(actual_topk, 1)) + _dsv4_indexer_dense_kernel[(num_tokens,)]( + out, + positions, + out.stride(0), + out.stride(1), + n_committed, + int(offset), + int(ratio), + BLOCK_K=block_k, + num_warps=4, + num_stages=1, + ) + return out + + score = torch.empty((num_tokens, n_committed), device=q.device, dtype=torch.float32) + block_t = min(block_t, triton.next_power_of_2(max(n_committed, 1))) + block_h = min(block_h, triton.next_power_of_2(num_heads)) + block_d = triton.next_power_of_2(head_dim) + grid = (num_tokens, triton.cdiv(n_committed, block_t)) + _dsv4_indexer_score_kernel[grid]( + score, + q, + kv, + weights, + positions, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + weights.stride(0), + weights.stride(1), + score.stride(0), + score.stride(1), + num_heads, + head_dim, + n_committed, + int(ratio), + BLOCK_T=block_t, + BLOCK_H=block_h, + BLOCK_D=block_d, + num_warps=4, + num_stages=1, + ) + values, indices = _aiter_topk(score, actual_topk, dim=-1) + block_k = triton.next_power_of_2(max(actual_topk, 1)) + _dsv4_indexer_finalize_kernel[(num_tokens,)]( + out, + values, + indices, + out.stride(0), + out.stride(1), + values.stride(0), + values.stride(1), + indices.stride(0), + indices.stride(1), + int(offset), + actual_topk, + BLOCK_K=block_k, + num_warps=4, + num_stages=1, + ) + return out diff --git a/aiter/ops/triton/attention/sparse_mqa_sink.py b/aiter/ops/triton/attention/sparse_mqa_sink.py new file mode 100644 index 0000000000..681d8ccea2 --- /dev/null +++ b/aiter/ops/triton/attention/sparse_mqa_sink.py @@ -0,0 +1,102 @@ +import torch +import triton + +from aiter.ops.triton._triton_kernels.attention.sparse_mqa_sink import ( + _sparse_mqa_sink_kernel, +) + + +def sparse_mqa_sink( + q: torch.Tensor, + kv: torch.Tensor, + out: torch.Tensor, + cu_seqlens_q: torch.Tensor, + seqused_k: torch.Tensor, + softmax_scale: float, + topk_indices: torch.Tensor, + block_table: torch.Tensor, + attn_sink: torch.Tensor, + *, + tile_k: int = 64, + block_h: int = 8, +) -> torch.Tensor: + """Sparse MQA with DSv4 attention-sink semantics. + + Args: + q: [num_tokens, num_heads, head_dim], BF16/FP16. + kv: [num_blocks, block_size, head_dim], BF16/FP16. + out: [num_tokens, num_heads, head_dim], same dtype as q. + cu_seqlens_q: [num_seqs + 1], int32 token offsets. + seqused_k: [num_seqs], int32 logical KV lengths before padding. + softmax_scale: scalar multiplier for q @ k. + topk_indices: [num_tokens, topk], int32 logical KV positions. -1 is invalid. + block_table: [num_seqs, max_blocks_per_seq], int32 logical->physical block IDs. + attn_sink: [num_heads], FP32 sink logits included in the denominator only. + """ + assert q.dim() == 3, f"q must be [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"kv must be [num_blocks, block_size, D], got {kv.shape}" + assert out.shape == q.shape, f"out shape {out.shape} must match q {q.shape}" + assert topk_indices.dim() == 2 and topk_indices.shape[0] == q.shape[0] + assert cu_seqlens_q.dim() == 1 + assert seqused_k.dim() == 1 + assert block_table.dim() == 2 + assert attn_sink.shape == (q.shape[1],) + assert kv.shape[2] == q.shape[2] + + if q.numel() == 0: + return out + + q = q.contiguous() + kv = kv.contiguous() + topk_indices = topk_indices.contiguous() + cu_seqlens_q = cu_seqlens_q.contiguous() + seqused_k = seqused_k.contiguous() + block_table = block_table.contiguous() + attn_sink = attn_sink.contiguous() + + num_tokens, num_heads, head_dim = q.shape + block_size = kv.shape[1] + topk_count = topk_indices.shape[1] + num_seqs = seqused_k.shape[0] + + block_h = min(block_h, triton.next_power_of_2(num_heads)) + block_d = triton.next_power_of_2(head_dim) + tile_k = min(tile_k, triton.next_power_of_2(max(topk_count, 1))) + head_blocks = triton.cdiv(num_heads, block_h) + grid = (num_tokens * head_blocks,) + + _sparse_mqa_sink_kernel[grid]( + out, + q, + kv, + topk_indices, + attn_sink, + block_table, + cu_seqlens_q, + seqused_k, + float(softmax_scale), + q.stride(0), + q.stride(1), + q.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + topk_indices.stride(0), + topk_indices.stride(1), + block_table.stride(0), + block_table.stride(1), + num_heads, + head_dim, + topk_count, + block_size, + num_seqs, + BLOCK_H=block_h, + BLOCK_D=block_d, + TILE_K=tile_k, + num_warps=4, + num_stages=1, + ) + return out diff --git a/op_tests/test_dsv4_indexer.py b/op_tests/test_dsv4_indexer.py new file mode 100644 index 0000000000..79427b4335 --- /dev/null +++ b/op_tests/test_dsv4_indexer.py @@ -0,0 +1,53 @@ +import pytest +import torch + +from aiter.ops.triton.attention.dsv4_indexer import dsv4_indexer_topk + + +def _reference(q, kv, weights, positions, index_topk, offset, ratio=4): + qf = q.float() + kvf = kv.float() + wf = weights.float() + scores = torch.einsum("thd,nd->thn", qf, kvf) + scores = (scores.relu_() * wf.unsqueeze(-1)).sum(dim=1) + valid = torch.arange(kv.shape[0], device=q.device).unsqueeze(0) < ( + (positions.to(torch.long) + 1) // ratio + ).unsqueeze(1) + scores = scores.masked_fill(~valid, float("-inf")) + k = min(index_topk, kv.shape[0]) + if k == 0: + return torch.empty((q.shape[0], 0), device=q.device, dtype=torch.int32) + values, indices = scores.topk(k, dim=-1) + return torch.where(values > -3.0e38, indices.to(torch.int32) + offset, -1) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_dense_causal_indices(): + torch.manual_seed(0) + tokens, heads, dim, committed = 9, 64, 128, 16 + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(committed, dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) + positions = torch.arange(tokens, device="cuda", dtype=torch.int64) + 3 + + out = dsv4_indexer_topk(q, kv, weights, positions, 64, 128) + expected = torch.arange(committed, device="cuda", dtype=torch.int32).expand(tokens, -1) + 128 + valid = torch.arange(committed, device="cuda").unsqueeze(0) < ( + (positions + 1) // 4 + ).unsqueeze(1) + expected = torch.where(valid, expected, -1) + torch.testing.assert_close(out, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_scored_topk_matches_torch(): + torch.manual_seed(1) + tokens, heads, dim, committed, k = 7, 64, 128, 80, 12 + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(committed, dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) + positions = torch.arange(tokens, device="cuda", dtype=torch.int64) + committed * 4 + + out = dsv4_indexer_topk(q, kv, weights, positions, k, 128) + ref = _reference(q, kv, weights, positions, k, 128) + torch.testing.assert_close(out, ref) diff --git a/op_tests/test_sparse_mqa_sink.py b/op_tests/test_sparse_mqa_sink.py new file mode 100644 index 0000000000..a8013ea683 --- /dev/null +++ b/op_tests/test_sparse_mqa_sink.py @@ -0,0 +1,49 @@ +import pytest +import torch + +from aiter.ops.triton.attention.sparse_mqa_sink import sparse_mqa_sink + + +def _reference(q, kv, topk, attn_sink, scale): + t, h, d = q.shape + out = torch.empty_like(q) + qf = q.float() + kvf = kv.float() + for i in range(t): + valid = topk[i] >= 0 + if not bool(valid.any()): + out[i].zero_() + continue + idx = topk[i, valid].long() + k = kvf[idx] + scores = torch.matmul(qf[i], k.t()) * scale + combined = torch.cat([scores, attn_sink.float().view(h, 1)], dim=-1) + weights = torch.softmax(combined, dim=-1)[..., :-1] + out[i] = torch.matmul(weights, k).to(out.dtype) + return out + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize("topk_count", [16, 48]) +def test_sparse_mqa_sink_matches_torch(topk_count): + torch.manual_seed(0) + tokens, heads, dim = 5, 16, 64 + kv_len, block_size = 73, 32 + num_blocks = (kv_len + block_size - 1) // block_size + padded = num_blocks * block_size + + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv_flat = torch.randn(padded, dim, device="cuda", dtype=torch.bfloat16) + kv_flat[kv_len:].zero_() + kv_blocks = kv_flat.view(num_blocks, block_size, dim) + topk = torch.randint(0, kv_len, (tokens, topk_count), device="cuda", dtype=torch.int32) + topk[0, -3:] = -1 + attn_sink = torch.randn(heads, device="cuda", dtype=torch.float32) + cu = torch.tensor([0, tokens], device="cuda", dtype=torch.int32) + seqused = torch.tensor([kv_len], device="cuda", dtype=torch.int32) + block_table = torch.arange(num_blocks, device="cuda", dtype=torch.int32).view(1, -1) + out = torch.empty_like(q) + + sparse_mqa_sink(q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink) + ref = _reference(q, kv_flat[:kv_len], topk, attn_sink, dim**-0.5) + torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2) From c2eafc5b6cdea49c7719259874081de3346f97a5 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 15:51:49 -0700 Subject: [PATCH 02/10] Tile DSv4 sparse MQA output dimension --- .../attention/sparse_mqa_sink.py | 59 +++++++++++-------- aiter/ops/triton/attention/sparse_mqa_sink.py | 9 ++- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py b/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py index 9d94c041d8..e7872c8b1b 100644 --- a/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py +++ b/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py @@ -47,17 +47,19 @@ def _sparse_mqa_sink_kernel( num_seqs: tl.int32, BLOCK_H: tl.constexpr, BLOCK_D: tl.constexpr, + SCORE_D: tl.constexpr, TILE_K: tl.constexpr, ): """Sparse MQA with DSv4's attention-sink denominator. - One program handles one query token and BLOCK_H query heads. KV is MQA: - all query heads share the same [topk, head_dim] key/value rows. + One program handles one query token, BLOCK_H query heads, and one output + dimension tile. KV is MQA: all query heads share the same [topk, head_dim] + key/value rows. """ head_blocks: tl.constexpr = (num_heads + BLOCK_H - 1) // BLOCK_H - program_id = tl.program_id(0) - token_id = program_id // head_blocks - head_block = program_id % head_blocks + token_id = tl.program_id(0) + head_block = tl.program_id(1) + dim_block = tl.program_id(2) seq_idx = _find_seq_idx(cu_seqlens_q_ptr, token_id, num_seqs) seq_start = tl.load(cu_seqlens_q_ptr + seq_idx) @@ -68,20 +70,11 @@ def _sparse_mqa_sink_kernel( kv_len = tl.load(seqused_k_ptr + seq_idx) offs_h = head_block * BLOCK_H + tl.arange(0, BLOCK_H) - offs_d = tl.arange(0, BLOCK_D) + offs_d = dim_block * BLOCK_D + tl.arange(0, BLOCK_D) + offs_score_d = tl.arange(0, SCORE_D) h_mask = offs_h < num_heads d_mask = offs_d < head_dim - q = tl.load( - q_ptr - + token_id * q_stride_t - + offs_h[:, None] * q_stride_h - + offs_d[None, :] * q_stride_d, - mask=h_mask[:, None] & d_mask[None, :], - other=0.0, - cache_modifier=".cg", - ) - sink = tl.load(attn_sink_ptr + offs_h, mask=h_mask, other=float("-inf")).to( tl.float32 ) @@ -109,16 +102,30 @@ def _sparse_mqa_sink_kernel( other=0, ) - k = tl.load( - kv_ptr - + physical_block[None, :] * kv_stride_b - + slot[None, :] * kv_stride_s - + offs_d[:, None] * kv_stride_d, - mask=d_mask[:, None] & valid_k[None, :], - other=0.0, - cache_modifier=".cg", - ) - scores = tl.dot(q, k) * scale + scores = tl.zeros((BLOCK_H, TILE_K), dtype=tl.float32) + for d_start in range(0, head_dim, SCORE_D): + score_d = d_start + offs_score_d + score_d_mask = score_d < head_dim + q = tl.load( + q_ptr + + token_id * q_stride_t + + offs_h[:, None] * q_stride_h + + score_d[None, :] * q_stride_d, + mask=h_mask[:, None] & score_d_mask[None, :], + other=0.0, + cache_modifier=".cg", + ) + k = tl.load( + kv_ptr + + physical_block[None, :] * kv_stride_b + + slot[None, :] * kv_stride_s + + score_d[:, None] * kv_stride_d, + mask=score_d_mask[:, None] & valid_k[None, :], + other=0.0, + cache_modifier=".cg", + ) + scores += tl.dot(q, k) + scores *= scale scores = tl.where(h_mask[:, None] & valid_k[None, :], scores, float("-inf")) m_new = tl.maximum(m_i, tl.max(scores, axis=1)) diff --git a/aiter/ops/triton/attention/sparse_mqa_sink.py b/aiter/ops/triton/attention/sparse_mqa_sink.py index 681d8ccea2..9748f9d815 100644 --- a/aiter/ops/triton/attention/sparse_mqa_sink.py +++ b/aiter/ops/triton/attention/sparse_mqa_sink.py @@ -19,6 +19,8 @@ def sparse_mqa_sink( *, tile_k: int = 64, block_h: int = 8, + block_d: int = 64, + score_d: int = 64, ) -> torch.Tensor: """Sparse MQA with DSv4 attention-sink semantics. @@ -60,10 +62,12 @@ def sparse_mqa_sink( num_seqs = seqused_k.shape[0] block_h = min(block_h, triton.next_power_of_2(num_heads)) - block_d = triton.next_power_of_2(head_dim) + block_d = min(block_d, triton.next_power_of_2(head_dim)) + score_d = min(score_d, triton.next_power_of_2(head_dim)) tile_k = min(tile_k, triton.next_power_of_2(max(topk_count, 1))) head_blocks = triton.cdiv(num_heads, block_h) - grid = (num_tokens * head_blocks,) + dim_blocks = triton.cdiv(head_dim, block_d) + grid = (num_tokens, head_blocks, dim_blocks) _sparse_mqa_sink_kernel[grid]( out, @@ -95,6 +99,7 @@ def sparse_mqa_sink( num_seqs, BLOCK_H=block_h, BLOCK_D=block_d, + SCORE_D=score_d, TILE_K=tile_k, num_warps=4, num_stages=1, From 13b6b39c3f41fabc20c087073004defcd9f5f61f Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 18:19:32 -0700 Subject: [PATCH 03/10] Retile DSv4 sparse MQA sink --- aiter/ops/triton/attention/sparse_mqa_sink.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aiter/ops/triton/attention/sparse_mqa_sink.py b/aiter/ops/triton/attention/sparse_mqa_sink.py index 9748f9d815..a2b4dd9c54 100644 --- a/aiter/ops/triton/attention/sparse_mqa_sink.py +++ b/aiter/ops/triton/attention/sparse_mqa_sink.py @@ -18,8 +18,8 @@ def sparse_mqa_sink( attn_sink: torch.Tensor, *, tile_k: int = 64, - block_h: int = 8, - block_d: int = 64, + block_h: int = 4, + block_d: int = 128, score_d: int = 64, ) -> torch.Tensor: """Sparse MQA with DSv4 attention-sink semantics. @@ -61,6 +61,9 @@ def sparse_mqa_sink( topk_count = topk_indices.shape[1] num_seqs = seqused_k.shape[0] + # Keep the accumulator footprint comparable to the original 8x64 tile + # while halving output-D tiles. That cuts repeated QK score work for + # DSv4's 512-wide value vector from 8x to 4x. block_h = min(block_h, triton.next_power_of_2(num_heads)) block_d = min(block_d, triton.next_power_of_2(head_dim)) score_d = min(score_d, triton.next_power_of_2(head_dim)) From 2616197cfc4901537405a4e2041d1a1b46600c2f Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 22:05:13 -0700 Subject: [PATCH 04/10] rm tests --- op_tests/test_dsv4_indexer.py | 53 -------------------------------- op_tests/test_sparse_mqa_sink.py | 49 ----------------------------- 2 files changed, 102 deletions(-) delete mode 100644 op_tests/test_dsv4_indexer.py delete mode 100644 op_tests/test_sparse_mqa_sink.py diff --git a/op_tests/test_dsv4_indexer.py b/op_tests/test_dsv4_indexer.py deleted file mode 100644 index 79427b4335..0000000000 --- a/op_tests/test_dsv4_indexer.py +++ /dev/null @@ -1,53 +0,0 @@ -import pytest -import torch - -from aiter.ops.triton.attention.dsv4_indexer import dsv4_indexer_topk - - -def _reference(q, kv, weights, positions, index_topk, offset, ratio=4): - qf = q.float() - kvf = kv.float() - wf = weights.float() - scores = torch.einsum("thd,nd->thn", qf, kvf) - scores = (scores.relu_() * wf.unsqueeze(-1)).sum(dim=1) - valid = torch.arange(kv.shape[0], device=q.device).unsqueeze(0) < ( - (positions.to(torch.long) + 1) // ratio - ).unsqueeze(1) - scores = scores.masked_fill(~valid, float("-inf")) - k = min(index_topk, kv.shape[0]) - if k == 0: - return torch.empty((q.shape[0], 0), device=q.device, dtype=torch.int32) - values, indices = scores.topk(k, dim=-1) - return torch.where(values > -3.0e38, indices.to(torch.int32) + offset, -1) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") -def test_dsv4_indexer_dense_causal_indices(): - torch.manual_seed(0) - tokens, heads, dim, committed = 9, 64, 128, 16 - q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) - kv = torch.randn(committed, dim, device="cuda", dtype=torch.bfloat16) - weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) - positions = torch.arange(tokens, device="cuda", dtype=torch.int64) + 3 - - out = dsv4_indexer_topk(q, kv, weights, positions, 64, 128) - expected = torch.arange(committed, device="cuda", dtype=torch.int32).expand(tokens, -1) + 128 - valid = torch.arange(committed, device="cuda").unsqueeze(0) < ( - (positions + 1) // 4 - ).unsqueeze(1) - expected = torch.where(valid, expected, -1) - torch.testing.assert_close(out, expected) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") -def test_dsv4_indexer_scored_topk_matches_torch(): - torch.manual_seed(1) - tokens, heads, dim, committed, k = 7, 64, 128, 80, 12 - q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) - kv = torch.randn(committed, dim, device="cuda", dtype=torch.bfloat16) - weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) - positions = torch.arange(tokens, device="cuda", dtype=torch.int64) + committed * 4 - - out = dsv4_indexer_topk(q, kv, weights, positions, k, 128) - ref = _reference(q, kv, weights, positions, k, 128) - torch.testing.assert_close(out, ref) diff --git a/op_tests/test_sparse_mqa_sink.py b/op_tests/test_sparse_mqa_sink.py deleted file mode 100644 index a8013ea683..0000000000 --- a/op_tests/test_sparse_mqa_sink.py +++ /dev/null @@ -1,49 +0,0 @@ -import pytest -import torch - -from aiter.ops.triton.attention.sparse_mqa_sink import sparse_mqa_sink - - -def _reference(q, kv, topk, attn_sink, scale): - t, h, d = q.shape - out = torch.empty_like(q) - qf = q.float() - kvf = kv.float() - for i in range(t): - valid = topk[i] >= 0 - if not bool(valid.any()): - out[i].zero_() - continue - idx = topk[i, valid].long() - k = kvf[idx] - scores = torch.matmul(qf[i], k.t()) * scale - combined = torch.cat([scores, attn_sink.float().view(h, 1)], dim=-1) - weights = torch.softmax(combined, dim=-1)[..., :-1] - out[i] = torch.matmul(weights, k).to(out.dtype) - return out - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") -@pytest.mark.parametrize("topk_count", [16, 48]) -def test_sparse_mqa_sink_matches_torch(topk_count): - torch.manual_seed(0) - tokens, heads, dim = 5, 16, 64 - kv_len, block_size = 73, 32 - num_blocks = (kv_len + block_size - 1) // block_size - padded = num_blocks * block_size - - q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) - kv_flat = torch.randn(padded, dim, device="cuda", dtype=torch.bfloat16) - kv_flat[kv_len:].zero_() - kv_blocks = kv_flat.view(num_blocks, block_size, dim) - topk = torch.randint(0, kv_len, (tokens, topk_count), device="cuda", dtype=torch.int32) - topk[0, -3:] = -1 - attn_sink = torch.randn(heads, device="cuda", dtype=torch.float32) - cu = torch.tensor([0, tokens], device="cuda", dtype=torch.int32) - seqused = torch.tensor([kv_len], device="cuda", dtype=torch.int32) - block_table = torch.arange(num_blocks, device="cuda", dtype=torch.int32).view(1, -1) - out = torch.empty_like(q) - - sparse_mqa_sink(q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink) - ref = _reference(q, kv_flat[:kv_len], topk, attn_sink, dim**-0.5) - torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2) From 7c5fd02d67425cdcc7de7f3fe97c79adfdf40d12 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Fri, 1 May 2026 08:21:38 -0700 Subject: [PATCH 05/10] Revert "rm tests" This reverts commit 2616197cfc4901537405a4e2041d1a1b46600c2f. --- op_tests/test_dsv4_indexer.py | 53 ++++++++++++++++++++++++++++++++ op_tests/test_sparse_mqa_sink.py | 49 +++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 op_tests/test_dsv4_indexer.py create mode 100644 op_tests/test_sparse_mqa_sink.py diff --git a/op_tests/test_dsv4_indexer.py b/op_tests/test_dsv4_indexer.py new file mode 100644 index 0000000000..79427b4335 --- /dev/null +++ b/op_tests/test_dsv4_indexer.py @@ -0,0 +1,53 @@ +import pytest +import torch + +from aiter.ops.triton.attention.dsv4_indexer import dsv4_indexer_topk + + +def _reference(q, kv, weights, positions, index_topk, offset, ratio=4): + qf = q.float() + kvf = kv.float() + wf = weights.float() + scores = torch.einsum("thd,nd->thn", qf, kvf) + scores = (scores.relu_() * wf.unsqueeze(-1)).sum(dim=1) + valid = torch.arange(kv.shape[0], device=q.device).unsqueeze(0) < ( + (positions.to(torch.long) + 1) // ratio + ).unsqueeze(1) + scores = scores.masked_fill(~valid, float("-inf")) + k = min(index_topk, kv.shape[0]) + if k == 0: + return torch.empty((q.shape[0], 0), device=q.device, dtype=torch.int32) + values, indices = scores.topk(k, dim=-1) + return torch.where(values > -3.0e38, indices.to(torch.int32) + offset, -1) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_dense_causal_indices(): + torch.manual_seed(0) + tokens, heads, dim, committed = 9, 64, 128, 16 + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(committed, dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) + positions = torch.arange(tokens, device="cuda", dtype=torch.int64) + 3 + + out = dsv4_indexer_topk(q, kv, weights, positions, 64, 128) + expected = torch.arange(committed, device="cuda", dtype=torch.int32).expand(tokens, -1) + 128 + valid = torch.arange(committed, device="cuda").unsqueeze(0) < ( + (positions + 1) // 4 + ).unsqueeze(1) + expected = torch.where(valid, expected, -1) + torch.testing.assert_close(out, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_scored_topk_matches_torch(): + torch.manual_seed(1) + tokens, heads, dim, committed, k = 7, 64, 128, 80, 12 + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(committed, dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) + positions = torch.arange(tokens, device="cuda", dtype=torch.int64) + committed * 4 + + out = dsv4_indexer_topk(q, kv, weights, positions, k, 128) + ref = _reference(q, kv, weights, positions, k, 128) + torch.testing.assert_close(out, ref) diff --git a/op_tests/test_sparse_mqa_sink.py b/op_tests/test_sparse_mqa_sink.py new file mode 100644 index 0000000000..a8013ea683 --- /dev/null +++ b/op_tests/test_sparse_mqa_sink.py @@ -0,0 +1,49 @@ +import pytest +import torch + +from aiter.ops.triton.attention.sparse_mqa_sink import sparse_mqa_sink + + +def _reference(q, kv, topk, attn_sink, scale): + t, h, d = q.shape + out = torch.empty_like(q) + qf = q.float() + kvf = kv.float() + for i in range(t): + valid = topk[i] >= 0 + if not bool(valid.any()): + out[i].zero_() + continue + idx = topk[i, valid].long() + k = kvf[idx] + scores = torch.matmul(qf[i], k.t()) * scale + combined = torch.cat([scores, attn_sink.float().view(h, 1)], dim=-1) + weights = torch.softmax(combined, dim=-1)[..., :-1] + out[i] = torch.matmul(weights, k).to(out.dtype) + return out + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize("topk_count", [16, 48]) +def test_sparse_mqa_sink_matches_torch(topk_count): + torch.manual_seed(0) + tokens, heads, dim = 5, 16, 64 + kv_len, block_size = 73, 32 + num_blocks = (kv_len + block_size - 1) // block_size + padded = num_blocks * block_size + + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv_flat = torch.randn(padded, dim, device="cuda", dtype=torch.bfloat16) + kv_flat[kv_len:].zero_() + kv_blocks = kv_flat.view(num_blocks, block_size, dim) + topk = torch.randint(0, kv_len, (tokens, topk_count), device="cuda", dtype=torch.int32) + topk[0, -3:] = -1 + attn_sink = torch.randn(heads, device="cuda", dtype=torch.float32) + cu = torch.tensor([0, tokens], device="cuda", dtype=torch.int32) + seqused = torch.tensor([kv_len], device="cuda", dtype=torch.int32) + block_table = torch.arange(num_blocks, device="cuda", dtype=torch.int32).view(1, -1) + out = torch.empty_like(q) + + sparse_mqa_sink(q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink) + ref = _reference(q, kv_flat[:kv_len], topk, attn_sink, dim**-0.5) + torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2) From 19d07bd936315ef89825d07a0c1740db63021b40 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Fri, 1 May 2026 08:24:01 -0700 Subject: [PATCH 06/10] Address sparse MQA review comments --- .../attention/sparse_mqa_sink.py | 3 --- aiter/ops/triton/attention/sparse_mqa_sink.py | 25 +++++++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py b/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py index e7872c8b1b..3ecee6eb42 100644 --- a/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py +++ b/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py @@ -56,17 +56,14 @@ def _sparse_mqa_sink_kernel( dimension tile. KV is MQA: all query heads share the same [topk, head_dim] key/value rows. """ - head_blocks: tl.constexpr = (num_heads + BLOCK_H - 1) // BLOCK_H token_id = tl.program_id(0) head_block = tl.program_id(1) dim_block = tl.program_id(2) seq_idx = _find_seq_idx(cu_seqlens_q_ptr, token_id, num_seqs) - seq_start = tl.load(cu_seqlens_q_ptr + seq_idx) seq_end = tl.load(cu_seqlens_q_ptr + seq_idx + 1) if token_id >= seq_end: return - local_token = token_id - seq_start kv_len = tl.load(seqused_k_ptr + seq_idx) offs_h = head_block * BLOCK_H + tl.arange(0, BLOCK_H) diff --git a/aiter/ops/triton/attention/sparse_mqa_sink.py b/aiter/ops/triton/attention/sparse_mqa_sink.py index a2b4dd9c54..67eae672dc 100644 --- a/aiter/ops/triton/attention/sparse_mqa_sink.py +++ b/aiter/ops/triton/attention/sparse_mqa_sink.py @@ -45,9 +45,29 @@ def sparse_mqa_sink( assert attn_sink.shape == (q.shape[1],) assert kv.shape[2] == q.shape[2] + num_tokens, num_heads, head_dim = q.shape + block_size = kv.shape[1] + topk_count = topk_indices.shape[1] + num_seqs = seqused_k.shape[0] + assert ( + cu_seqlens_q.shape[0] == num_seqs + 1 + ), ( + "cu_seqlens_q must have length num_seqs + 1, " + f"got {cu_seqlens_q.shape[0]} vs {num_seqs + 1}" + ) + assert ( + block_table.shape[0] == num_seqs + ), f"block_table rows {block_table.shape[0]} must match num_seqs {num_seqs}" + if q.numel() == 0: return out + assert num_seqs > 0, "non-empty q requires at least one sequence" + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert ( + cu_seqlens_q[-1] == num_tokens + ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} must equal num_tokens {num_tokens}" + q = q.contiguous() kv = kv.contiguous() topk_indices = topk_indices.contiguous() @@ -56,11 +76,6 @@ def sparse_mqa_sink( block_table = block_table.contiguous() attn_sink = attn_sink.contiguous() - num_tokens, num_heads, head_dim = q.shape - block_size = kv.shape[1] - topk_count = topk_indices.shape[1] - num_seqs = seqused_k.shape[0] - # Keep the accumulator footprint comparable to the original 8x64 tile # while halving output-D tiles. That cuts repeated QK score work for # DSv4's 512-wide value vector from 8x to 4x. From 914786bfdd820a5c305e342feee41333291e0275 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Fri, 1 May 2026 08:26:05 -0700 Subject: [PATCH 07/10] Avoid sync in sparse MQA input checks --- aiter/ops/triton/attention/sparse_mqa_sink.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/aiter/ops/triton/attention/sparse_mqa_sink.py b/aiter/ops/triton/attention/sparse_mqa_sink.py index 67eae672dc..eb59b60f08 100644 --- a/aiter/ops/triton/attention/sparse_mqa_sink.py +++ b/aiter/ops/triton/attention/sparse_mqa_sink.py @@ -63,10 +63,15 @@ def sparse_mqa_sink( return out assert num_seqs > 0, "non-empty q requires at least one sequence" - assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" - assert ( - cu_seqlens_q[-1] == num_tokens - ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} must equal num_tokens {num_tokens}" + # Keep value checks on-device to avoid synchronizing this hot path. + if hasattr(torch, "_assert_async"): + torch._assert_async(cu_seqlens_q[0] == 0) + torch._assert_async(cu_seqlens_q[-1] == num_tokens) + else: + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert ( + cu_seqlens_q[-1] == num_tokens + ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} must equal num_tokens {num_tokens}" q = q.contiguous() kv = kv.contiguous() From 0923d27163ae5b722be27ea980e447fe6c3c7308 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Fri, 1 May 2026 08:59:01 -0700 Subject: [PATCH 08/10] Add batched DSv4 indexer coverage --- .../_triton_kernels/attention/dsv4_indexer.py | 108 ++++++++++ aiter/ops/triton/attention/dsv4_indexer.py | 195 ++++++++++++++---- aiter/ops/triton/attention/sparse_mqa_sink.py | 29 ++- op_tests/test_dsv4_indexer.py | 119 ++++++++++- op_tests/test_sparse_mqa_sink.py | 99 ++++++++- 5 files changed, 491 insertions(+), 59 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py b/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py index cef46edd17..9d0f78625f 100644 --- a/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py +++ b/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py @@ -26,6 +26,34 @@ def _dsv4_indexer_dense_kernel( ) +@triton.jit +def _dsv4_indexer_dense_batched_kernel( + out_ptr, # [num_tokens, topk] + positions_ptr, # [num_tokens] + seq_ids_ptr, # [num_tokens] + kv_lens_ptr, # [num_seqs] + out_stride_t: tl.int64, + out_stride_k: tl.int64, + n_committed: tl.constexpr, + offset: tl.int32, + ratio: tl.constexpr, + BLOCK_K: tl.constexpr, +): + token_id = tl.program_id(0) + offs_k = tl.arange(0, BLOCK_K) + seq_id = tl.load(seq_ids_ptr + token_id).to(tl.int32) + kv_len = tl.load(kv_lens_ptr + seq_id).to(tl.int32) + pos = tl.load(positions_ptr + token_id).to(tl.int32) + causal_limit = (pos + 1) // ratio + valid = (offs_k < n_committed) & (offs_k < kv_len) & (offs_k < causal_limit) + out = tl.where(valid, offs_k + offset, -1) + tl.store( + out_ptr + token_id * out_stride_t + offs_k * out_stride_k, + out, + mask=offs_k < n_committed, + ) + + @triton.jit def _dsv4_indexer_score_kernel( score_ptr, # [num_tokens, kv_len], fp32 @@ -98,6 +126,86 @@ def _dsv4_indexer_score_kernel( ) +@triton.jit +def _dsv4_indexer_score_batched_kernel( + score_ptr, # [num_tokens, kv_len], fp32 + q_ptr, # [num_tokens, num_heads, head_dim] + kv_ptr, # [num_seqs, kv_len, head_dim] + weights_ptr, # [num_tokens, num_heads] + positions_ptr, # [num_tokens] + seq_ids_ptr, # [num_tokens] + kv_lens_ptr, # [num_seqs] + q_stride_t: tl.int64, + q_stride_h: tl.int64, + q_stride_d: tl.int64, + kv_stride_b: tl.int64, + kv_stride_t: tl.int64, + kv_stride_d: tl.int64, + weights_stride_t: tl.int64, + weights_stride_h: tl.int64, + score_stride_t: tl.int64, + score_stride_k: tl.int64, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + kv_len_max: tl.constexpr, + ratio: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + seq_id = tl.load(seq_ids_ptr + token_id).to(tl.int32) + kv_len = tl.load(kv_lens_ptr + seq_id).to(tl.int32) + offs_t = tile_id * BLOCK_T + tl.arange(0, BLOCK_T) + offs_d = tl.arange(0, BLOCK_D) + d_mask = offs_d < head_dim + acc = tl.zeros((BLOCK_T,), dtype=tl.float32) + + for h_start in range(0, num_heads, BLOCK_H): + offs_h = h_start + tl.arange(0, BLOCK_H) + h_mask = offs_h < num_heads + + q = tl.load( + q_ptr + + token_id * q_stride_t + + offs_h[:, None] * q_stride_h + + offs_d[None, :] * q_stride_d, + mask=h_mask[:, None] & d_mask[None, :], + other=0.0, + cache_modifier=".cg", + ) + kv = tl.load( + kv_ptr + + seq_id * kv_stride_b + + offs_t[None, :] * kv_stride_t + + offs_d[:, None] * kv_stride_d, + mask=(offs_t[None, :] < kv_len) & d_mask[:, None], + other=0.0, + cache_modifier=".cg", + ) + dots = tl.dot(q, kv) + dots = tl.maximum(dots, 0.0) + weights = tl.load( + weights_ptr + token_id * weights_stride_t + offs_h * weights_stride_h, + mask=h_mask, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + acc += tl.sum(dots * weights[:, None], axis=0) + + pos = tl.load(positions_ptr + token_id).to(tl.int32) + causal_limit = (pos + 1) // ratio + valid = (offs_t < kv_len_max) & (offs_t < kv_len) & (offs_t < causal_limit) + acc = tl.where(valid, acc, float("-inf")) + tl.store( + score_ptr + token_id * score_stride_t + offs_t * score_stride_k, + acc, + mask=offs_t < kv_len_max, + ) + + @triton.jit def _dsv4_indexer_finalize_kernel( out_ptr, # [num_tokens, topk], int32 diff --git a/aiter/ops/triton/attention/dsv4_indexer.py b/aiter/ops/triton/attention/dsv4_indexer.py index 245f7025ac..dd4d9d0646 100644 --- a/aiter/ops/triton/attention/dsv4_indexer.py +++ b/aiter/ops/triton/attention/dsv4_indexer.py @@ -2,12 +2,16 @@ import triton from aiter.ops.triton._triton_kernels.attention.dsv4_indexer import ( + _dsv4_indexer_dense_batched_kernel, _dsv4_indexer_dense_kernel, _dsv4_indexer_finalize_kernel, + _dsv4_indexer_score_batched_kernel, _dsv4_indexer_score_kernel, ) from aiter.ops.triton.topk import topk as _aiter_topk +_DEQUANT_DTYPES = (torch.float16, torch.bfloat16) + def dsv4_indexer_topk( q: torch.Tensor, @@ -17,6 +21,8 @@ def dsv4_indexer_topk( index_topk: int, offset: int, *, + seq_ids: torch.Tensor | None = None, + kv_lens: torch.Tensor | None = None, ratio: int = 4, block_t: int = 64, block_h: int = 8, @@ -29,36 +35,93 @@ def dsv4_indexer_topk( score[t, k] = sum_h relu(q[t, h] @ kv[k]) * weights[t, h] Args: - q: [num_tokens, 64, 128], BF16/FP16/FP8-like storage accepted by Triton. - kv: [num_committed, 128], compressed Indexer KV. + q: [num_tokens, 64, 128], dequantized BF16/FP16. + kv: [num_committed, 128] or [num_seqs, max_committed, 128], + dequantized BF16/FP16 compressed Indexer KV. weights: [num_tokens, 64], FP32/BF16, already includes model scaling. positions: [num_tokens], absolute token positions. index_topk: model top-k cap, 512 for V4-Flash or 1024 for V4-Pro. offset: index offset into the sparse-attention [window || compressed] KV. + seq_ids: optional [num_tokens] int32/int64 sequence IDs. Required when + kv is batched. + kv_lens: optional [num_seqs] int32/int64 committed KV length per sequence. + Required when kv is batched and shorter than max_committed. ratio: compression ratio. DSv4 CSA Indexer uses 4. Returns: - [num_tokens, min(index_topk, num_committed)] int32. Future entries are -1. + [num_tokens, min(index_topk, max_committed)] int32. Future entries are -1. + + This op does not unpack native DSv4 FP4/FP8 cache layouts or apply their + scale tensors. Callers must pass dequantized BF16/FP16 Q/KV tensors. """ assert q.dim() == 3, f"q must be [T, H, D], got {q.shape}" - assert kv.dim() == 2, f"kv must be [N, D], got {kv.shape}" + assert kv.dim() in (2, 3), f"kv must be [N, D] or [B, N, D], got {kv.shape}" assert weights.dim() == 2, f"weights must be [T, H], got {weights.shape}" assert positions.dim() == 1, f"positions must be [T], got {positions.shape}" + assert positions.dtype in (torch.int32, torch.int64) assert q.shape[0] == weights.shape[0] == positions.shape[0] assert q.shape[1] == weights.shape[1] - assert q.shape[2] == kv.shape[1] + assert q.is_cuda and kv.is_cuda, "q and kv must be CUDA tensors" + assert ( + weights.device == q.device + and positions.device == q.device + and kv.device == q.device + ), "q, kv, weights, and positions must be on the same device" + assert q.dtype in _DEQUANT_DTYPES, f"q must be dequantized BF16/FP16, got {q.dtype}" + assert ( + kv.dtype in _DEQUANT_DTYPES + ), f"kv must be dequantized BF16/FP16, got {kv.dtype}" + assert weights.dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + ), f"weights must be FP16/BF16/FP32, got {weights.dtype}" + assert q.shape[2] == kv.shape[-1] + assert index_topk >= 0 assert ratio > 0 num_tokens, num_heads, head_dim = q.shape - n_committed = kv.shape[0] + is_batched = kv.dim() == 3 + n_committed = kv.shape[1] if is_batched else kv.shape[0] + if is_batched: + assert seq_ids is not None, "seq_ids is required when kv is batched" + assert seq_ids.dim() == 1 and seq_ids.shape[0] == num_tokens + assert seq_ids.device == q.device, "seq_ids must be on the same device as q" + assert seq_ids.dtype in (torch.int32, torch.int64) + if kv_lens is None: + kv_lens = torch.full( + (kv.shape[0],), n_committed, device=kv.device, dtype=torch.int32 + ) + assert kv_lens.dim() == 1 and kv_lens.shape[0] == kv.shape[0] + assert kv_lens.device == q.device, "kv_lens must be on the same device as q" + assert kv_lens.dtype in (torch.int32, torch.int64) + if hasattr(torch, "_assert_async"): + torch._assert_async(((seq_ids >= 0) & (seq_ids < kv.shape[0])).all()) + torch._assert_async(((kv_lens >= 0) & (kv_lens <= n_committed)).all()) + else: + assert bool( + ((seq_ids >= 0) & (seq_ids < kv.shape[0])).all() + ), "seq_ids must be in range" + assert bool( + ((kv_lens >= 0) & (kv_lens <= n_committed)).all() + ), "kv_lens must be in range" + else: + assert seq_ids is None, "seq_ids requires batched kv" + assert kv_lens is None, "kv_lens requires batched kv" actual_topk = min(int(index_topk), n_committed) if actual_topk <= 0: return torch.empty((num_tokens, 0), device=q.device, dtype=torch.int32) + if num_tokens == 0: + return torch.empty((0, actual_topk), device=q.device, dtype=torch.int32) q = q.contiguous() kv = kv.contiguous() weights = weights.contiguous() positions = positions.contiguous() + if seq_ids is not None: + seq_ids = seq_ids.contiguous() + if kv_lens is not None: + kv_lens = kv_lens.contiguous() out = torch.empty((num_tokens, actual_topk), device=q.device, dtype=torch.int32) # If top-k covers every committed compressed entry, the order does not @@ -67,18 +130,34 @@ def dsv4_indexer_topk( # where n_committed=256 and index_topk is 512/1024. if actual_topk == n_committed: block_k = triton.next_power_of_2(max(actual_topk, 1)) - _dsv4_indexer_dense_kernel[(num_tokens,)]( - out, - positions, - out.stride(0), - out.stride(1), - n_committed, - int(offset), - int(ratio), - BLOCK_K=block_k, - num_warps=4, - num_stages=1, - ) + if is_batched: + _dsv4_indexer_dense_batched_kernel[(num_tokens,)]( + out, + positions, + seq_ids, + kv_lens, + out.stride(0), + out.stride(1), + n_committed, + int(offset), + int(ratio), + BLOCK_K=block_k, + num_warps=4, + num_stages=1, + ) + else: + _dsv4_indexer_dense_kernel[(num_tokens,)]( + out, + positions, + out.stride(0), + out.stride(1), + n_committed, + int(offset), + int(ratio), + BLOCK_K=block_k, + num_warps=4, + num_stages=1, + ) return out score = torch.empty((num_tokens, n_committed), device=q.device, dtype=torch.float32) @@ -86,31 +165,61 @@ def dsv4_indexer_topk( block_h = min(block_h, triton.next_power_of_2(num_heads)) block_d = triton.next_power_of_2(head_dim) grid = (num_tokens, triton.cdiv(n_committed, block_t)) - _dsv4_indexer_score_kernel[grid]( - score, - q, - kv, - weights, - positions, - q.stride(0), - q.stride(1), - q.stride(2), - kv.stride(0), - kv.stride(1), - weights.stride(0), - weights.stride(1), - score.stride(0), - score.stride(1), - num_heads, - head_dim, - n_committed, - int(ratio), - BLOCK_T=block_t, - BLOCK_H=block_h, - BLOCK_D=block_d, - num_warps=4, - num_stages=1, - ) + if is_batched: + _dsv4_indexer_score_batched_kernel[grid]( + score, + q, + kv, + weights, + positions, + seq_ids, + kv_lens, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + weights.stride(0), + weights.stride(1), + score.stride(0), + score.stride(1), + num_heads, + head_dim, + n_committed, + int(ratio), + BLOCK_T=block_t, + BLOCK_H=block_h, + BLOCK_D=block_d, + num_warps=4, + num_stages=1, + ) + else: + _dsv4_indexer_score_kernel[grid]( + score, + q, + kv, + weights, + positions, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + weights.stride(0), + weights.stride(1), + score.stride(0), + score.stride(1), + num_heads, + head_dim, + n_committed, + int(ratio), + BLOCK_T=block_t, + BLOCK_H=block_h, + BLOCK_D=block_d, + num_warps=4, + num_stages=1, + ) values, indices = _aiter_topk(score, actual_topk, dim=-1) block_k = triton.next_power_of_2(max(actual_topk, 1)) _dsv4_indexer_finalize_kernel[(num_tokens,)]( diff --git a/aiter/ops/triton/attention/sparse_mqa_sink.py b/aiter/ops/triton/attention/sparse_mqa_sink.py index eb59b60f08..a54f6c11b3 100644 --- a/aiter/ops/triton/attention/sparse_mqa_sink.py +++ b/aiter/ops/triton/attention/sparse_mqa_sink.py @@ -5,6 +5,8 @@ _sparse_mqa_sink_kernel, ) +_DEQUANT_DTYPES = (torch.float16, torch.bfloat16) + def sparse_mqa_sink( q: torch.Tensor, @@ -25,8 +27,8 @@ def sparse_mqa_sink( """Sparse MQA with DSv4 attention-sink semantics. Args: - q: [num_tokens, num_heads, head_dim], BF16/FP16. - kv: [num_blocks, block_size, head_dim], BF16/FP16. + q: [num_tokens, num_heads, head_dim], dequantized BF16/FP16. + kv: [num_blocks, block_size, head_dim], dequantized BF16/FP16. out: [num_tokens, num_heads, head_dim], same dtype as q. cu_seqlens_q: [num_seqs + 1], int32 token offsets. seqused_k: [num_seqs], int32 logical KV lengths before padding. @@ -34,6 +36,9 @@ def sparse_mqa_sink( topk_indices: [num_tokens, topk], int32 logical KV positions. -1 is invalid. block_table: [num_seqs, max_blocks_per_seq], int32 logical->physical block IDs. attn_sink: [num_heads], FP32 sink logits included in the denominator only. + + This op does not unpack native DSv4 FP4/FP8 cache layouts or apply their + scale tensors. Callers must pass dequantized BF16/FP16 Q/KV tensors. """ assert q.dim() == 3, f"q must be [T, H, D], got {q.shape}" assert kv.dim() == 3, f"kv must be [num_blocks, block_size, D], got {kv.shape}" @@ -44,6 +49,26 @@ def sparse_mqa_sink( assert block_table.dim() == 2 assert attn_sink.shape == (q.shape[1],) assert kv.shape[2] == q.shape[2] + assert q.is_cuda and kv.is_cuda, "q and kv must be CUDA tensors" + assert ( + out.device == q.device + and cu_seqlens_q.device == q.device + and seqused_k.device == q.device + and topk_indices.device == q.device + and block_table.device == q.device + and attn_sink.device == q.device + and kv.device == q.device + ), "all inputs must be on the same device" + assert q.dtype in _DEQUANT_DTYPES, f"q must be dequantized BF16/FP16, got {q.dtype}" + assert ( + kv.dtype in _DEQUANT_DTYPES + ), f"kv must be dequantized BF16/FP16, got {kv.dtype}" + assert out.dtype == q.dtype, f"out dtype {out.dtype} must match q dtype {q.dtype}" + assert cu_seqlens_q.dtype == torch.int32 + assert seqused_k.dtype == torch.int32 + assert topk_indices.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert attn_sink.dtype == torch.float32 num_tokens, num_heads, head_dim = q.shape block_size = kv.shape[1] diff --git a/op_tests/test_dsv4_indexer.py b/op_tests/test_dsv4_indexer.py index 79427b4335..91ca35af32 100644 --- a/op_tests/test_dsv4_indexer.py +++ b/op_tests/test_dsv4_indexer.py @@ -4,17 +4,41 @@ from aiter.ops.triton.attention.dsv4_indexer import dsv4_indexer_topk -def _reference(q, kv, weights, positions, index_topk, offset, ratio=4): +def _reference( + q, + kv, + weights, + positions, + index_topk, + offset, + ratio=4, + seq_ids=None, + kv_lens=None, +): qf = q.float() kvf = kv.float() wf = weights.float() - scores = torch.einsum("thd,nd->thn", qf, kvf) + if kv.dim() == 3: + assert seq_ids is not None + kvf = kvf[seq_ids.long()] + max_committed = kv.shape[1] + else: + max_committed = kv.shape[0] + if kv.dim() == 3: + scores = torch.einsum("thd,tnd->thn", qf, kvf) + else: + scores = torch.einsum("thd,nd->thn", qf, kvf) scores = (scores.relu_() * wf.unsqueeze(-1)).sum(dim=1) - valid = torch.arange(kv.shape[0], device=q.device).unsqueeze(0) < ( - (positions.to(torch.long) + 1) // ratio - ).unsqueeze(1) + valid_limit = (positions.to(torch.long) + 1) // ratio + if kv_lens is not None: + valid_limit = torch.minimum( + valid_limit, kv_lens[seq_ids.long()].to(torch.long) + ) + valid = torch.arange(max_committed, device=q.device).unsqueeze( + 0 + ) < valid_limit.unsqueeze(1) scores = scores.masked_fill(~valid, float("-inf")) - k = min(index_topk, kv.shape[0]) + k = min(index_topk, max_committed) if k == 0: return torch.empty((q.shape[0], 0), device=q.device, dtype=torch.int32) values, indices = scores.topk(k, dim=-1) @@ -31,7 +55,10 @@ def test_dsv4_indexer_dense_causal_indices(): positions = torch.arange(tokens, device="cuda", dtype=torch.int64) + 3 out = dsv4_indexer_topk(q, kv, weights, positions, 64, 128) - expected = torch.arange(committed, device="cuda", dtype=torch.int32).expand(tokens, -1) + 128 + expected = ( + torch.arange(committed, device="cuda", dtype=torch.int32).expand(tokens, -1) + + 128 + ) valid = torch.arange(committed, device="cuda").unsqueeze(0) < ( (positions + 1) // 4 ).unsqueeze(1) @@ -51,3 +78,81 @@ def test_dsv4_indexer_scored_topk_matches_torch(): out = dsv4_indexer_topk(q, kv, weights, positions, k, 128) ref = _reference(q, kv, weights, positions, k, 128) torch.testing.assert_close(out, ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_zero_committed_returns_empty(): + q = torch.empty(3, 64, 128, device="cuda", dtype=torch.bfloat16) + kv = torch.empty(0, 128, device="cuda", dtype=torch.bfloat16) + weights = torch.empty(3, 64, device="cuda", dtype=torch.float32) + positions = torch.arange(3, device="cuda", dtype=torch.int64) + + out = dsv4_indexer_topk(q, kv, weights, positions, 512, 128) + assert out.shape == (3, 0) + assert out.dtype == torch.int32 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_batched_dense_causal_indices(): + torch.manual_seed(2) + tokens, heads, dim, committed = 4, 64, 128, 16 + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(2, committed, dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) + positions = torch.tensor([3, 7, 63, 63], device="cuda", dtype=torch.int64) + seq_ids = torch.tensor([0, 1, 0, 1], device="cuda", dtype=torch.int32) + kv_lens = torch.tensor([5, 9], device="cuda", dtype=torch.int32) + + out = dsv4_indexer_topk( + q, kv, weights, positions, 64, 128, seq_ids=seq_ids, kv_lens=kv_lens + ) + expected = ( + torch.arange(committed, device="cuda", dtype=torch.int32).expand(tokens, -1) + + 128 + ) + valid_limit = torch.minimum((positions + 1) // 4, kv_lens[seq_ids.long()]) + valid = torch.arange(committed, device="cuda").unsqueeze(0) < valid_limit.unsqueeze( + 1 + ) + expected = torch.where(valid, expected, -1) + torch.testing.assert_close(out, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_batched_scored_topk_no_cross_sequence_leakage(): + tokens, heads, dim, committed, k = 4, 64, 128, 32, 4 + q = torch.zeros(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + q[:, 0, 0] = 1 + kv = torch.zeros(2, committed, dim, device="cuda", dtype=torch.bfloat16) + kv[0, :, 0] = torch.arange(committed, device="cuda", dtype=torch.float32) + kv[1, :, 0] = torch.arange(committed, 0, -1, device="cuda", dtype=torch.float32) + weights = torch.zeros(tokens, heads, device="cuda", dtype=torch.float32) + weights[:, 0] = 1 + positions = torch.full((tokens,), committed * 4, device="cuda", dtype=torch.int64) + seq_ids = torch.tensor([0, 1, 0, 1], device="cuda", dtype=torch.int32) + kv_lens = torch.full((2,), committed, device="cuda", dtype=torch.int32) + + out = dsv4_indexer_topk( + q, kv, weights, positions, k, 128, seq_ids=seq_ids, kv_lens=kv_lens + ) + ref = _reference( + q, kv, weights, positions, k, 128, seq_ids=seq_ids, kv_lens=kv_lens + ) + torch.testing.assert_close(out, ref) + assert int(out[0, 0]) == 128 + committed - 1 + assert int(out[1, 0]) == 128 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize(("committed", "k"), [(2048, 512), (4096, 1024)]) +def test_dsv4_indexer_large_row_topk_matches_torch(committed, k): + torch.manual_seed(3) + tokens, heads, dim = 1, 64, 128 + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(committed, dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) + positions = torch.full((tokens,), committed * 4, device="cuda", dtype=torch.int64) + + out = dsv4_indexer_topk(q, kv, weights, positions, k, 128) + ref = _reference(q, kv, weights, positions, k, 128) + torch.testing.assert_close(out, ref) diff --git a/op_tests/test_sparse_mqa_sink.py b/op_tests/test_sparse_mqa_sink.py index a8013ea683..859382ccf4 100644 --- a/op_tests/test_sparse_mqa_sink.py +++ b/op_tests/test_sparse_mqa_sink.py @@ -4,18 +4,26 @@ from aiter.ops.triton.attention.sparse_mqa_sink import sparse_mqa_sink -def _reference(q, kv, topk, attn_sink, scale): +def _reference( + q, kv_blocks, topk, attn_sink, scale, cu_seqlens_q, seqused_k, block_table +): t, h, d = q.shape out = torch.empty_like(q) qf = q.float() - kvf = kv.float() + kvf = kv_blocks.float() + cu_cpu = cu_seqlens_q.cpu().tolist() for i in range(t): - valid = topk[i] >= 0 + seq_idx = next(seq for seq in range(len(cu_cpu) - 1) if cu_cpu[seq + 1] > i) + kv_len = int(seqused_k[seq_idx].item()) + valid = (topk[i] >= 0) & (topk[i] < kv_len) if not bool(valid.any()): out[i].zero_() continue idx = topk[i, valid].long() - k = kvf[idx] + logical_block = idx // kv_blocks.shape[1] + slot = idx % kv_blocks.shape[1] + physical_block = block_table[seq_idx, logical_block].long() + k = kvf[physical_block, slot] scores = torch.matmul(qf[i], k.t()) * scale combined = torch.cat([scores, attn_sink.float().view(h, 1)], dim=-1) weights = torch.softmax(combined, dim=-1)[..., :-1] @@ -36,7 +44,9 @@ def test_sparse_mqa_sink_matches_torch(topk_count): kv_flat = torch.randn(padded, dim, device="cuda", dtype=torch.bfloat16) kv_flat[kv_len:].zero_() kv_blocks = kv_flat.view(num_blocks, block_size, dim) - topk = torch.randint(0, kv_len, (tokens, topk_count), device="cuda", dtype=torch.int32) + topk = torch.randint( + 0, kv_len, (tokens, topk_count), device="cuda", dtype=torch.int32 + ) topk[0, -3:] = -1 attn_sink = torch.randn(heads, device="cuda", dtype=torch.float32) cu = torch.tensor([0, tokens], device="cuda", dtype=torch.int32) @@ -44,6 +54,81 @@ def test_sparse_mqa_sink_matches_torch(topk_count): block_table = torch.arange(num_blocks, device="cuda", dtype=torch.int32).view(1, -1) out = torch.empty_like(q) - sparse_mqa_sink(q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink) - ref = _reference(q, kv_flat[:kv_len], topk, attn_sink, dim**-0.5) + sparse_mqa_sink( + q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink + ) + ref = _reference(q, kv_blocks, topk, attn_sink, dim**-0.5, cu, seqused, block_table) torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize( + ("heads", "topk_count", "kv_len", "tokens"), + [ + (64, 160, 256, 2), # HCA 4K-style top-k + (64, 640, 768, 2), # V4-Flash CSA + (128, 1152, 1280, 1), # V4-Pro CSA + (64, 2048, 2304, 1), # HCA long-context smoke + ], +) +def test_sparse_mqa_sink_dsv4_shapes_match_torch(heads, topk_count, kv_len, tokens): + torch.manual_seed(1) + dim, block_size = 512, 256 + num_blocks = (kv_len + block_size - 1) // block_size + padded = num_blocks * block_size + + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv_flat = torch.randn(padded, dim, device="cuda", dtype=torch.bfloat16) + kv_flat[kv_len:].zero_() + kv_blocks = kv_flat.view(num_blocks, block_size, dim) + topk = torch.randint( + 0, kv_len, (tokens, topk_count), device="cuda", dtype=torch.int32 + ) + topk[0, -min(17, topk_count) :] = -1 + attn_sink = torch.linspace(-8, 8, heads, device="cuda", dtype=torch.float32) + cu = torch.tensor([0, tokens], device="cuda", dtype=torch.int32) + seqused = torch.tensor([kv_len], device="cuda", dtype=torch.int32) + block_table = torch.arange(num_blocks, device="cuda", dtype=torch.int32).view(1, -1) + out = torch.empty_like(q) + + sparse_mqa_sink( + q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink + ) + ref = _reference(q, kv_blocks, topk, attn_sink, dim**-0.5, cu, seqused, block_table) + torch.testing.assert_close(out, ref, rtol=3e-2, atol=3e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_sparse_mqa_sink_multi_sequence_block_table_matches_torch(): + torch.manual_seed(2) + tokens, heads, dim = 5, 64, 512 + topk_count, block_size = 160, 64 + kv_lens = [130, 177] + max_blocks = max((length + block_size - 1) // block_size for length in kv_lens) + total_blocks = 8 + + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv_blocks = torch.randn( + total_blocks, block_size, dim, device="cuda", dtype=torch.bfloat16 + ) + kv_blocks[4:7].add_(8.0) # make cross-sequence leakage obvious + block_table = torch.tensor( + [[2, 0, 1], [6, 4, 5]], device="cuda", dtype=torch.int32 + )[:, :max_blocks] + cu = torch.tensor([0, 2, tokens], device="cuda", dtype=torch.int32) + seqused = torch.tensor(kv_lens, device="cuda", dtype=torch.int32) + topk = torch.empty(tokens, topk_count, device="cuda", dtype=torch.int32) + for i, kv_len in enumerate([kv_lens[0]] * 2 + [kv_lens[1]] * 3): + topk[i] = torch.randint( + 0, kv_len, (topk_count,), device="cuda", dtype=torch.int32 + ) + topk[1, -5:] = -1 + topk[3, -7:] = -1 + attn_sink = torch.randn(heads, device="cuda", dtype=torch.float32) + out = torch.empty_like(q) + + sparse_mqa_sink( + q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink + ) + ref = _reference(q, kv_blocks, topk, attn_sink, dim**-0.5, cu, seqused, block_table) + torch.testing.assert_close(out, ref, rtol=3e-2, atol=3e-2) From 220bd4d9b17683ac7653eef10bbd997b9877c6ae Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Sun, 3 May 2026 10:38:23 -0700 Subject: [PATCH 09/10] Format DSv4 sparse indexer files --- aiter/ops/triton/attention/sparse_mqa_sink.py | 4 +--- op_tests/test_dsv4_indexer.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/aiter/ops/triton/attention/sparse_mqa_sink.py b/aiter/ops/triton/attention/sparse_mqa_sink.py index a54f6c11b3..01bc1f40f6 100644 --- a/aiter/ops/triton/attention/sparse_mqa_sink.py +++ b/aiter/ops/triton/attention/sparse_mqa_sink.py @@ -74,9 +74,7 @@ def sparse_mqa_sink( block_size = kv.shape[1] topk_count = topk_indices.shape[1] num_seqs = seqused_k.shape[0] - assert ( - cu_seqlens_q.shape[0] == num_seqs + 1 - ), ( + assert cu_seqlens_q.shape[0] == num_seqs + 1, ( "cu_seqlens_q must have length num_seqs + 1, " f"got {cu_seqlens_q.shape[0]} vs {num_seqs + 1}" ) diff --git a/op_tests/test_dsv4_indexer.py b/op_tests/test_dsv4_indexer.py index 91ca35af32..5dcafe4aaa 100644 --- a/op_tests/test_dsv4_indexer.py +++ b/op_tests/test_dsv4_indexer.py @@ -31,9 +31,7 @@ def _reference( scores = (scores.relu_() * wf.unsqueeze(-1)).sum(dim=1) valid_limit = (positions.to(torch.long) + 1) // ratio if kv_lens is not None: - valid_limit = torch.minimum( - valid_limit, kv_lens[seq_ids.long()].to(torch.long) - ) + valid_limit = torch.minimum(valid_limit, kv_lens[seq_ids.long()].to(torch.long)) valid = torch.arange(max_committed, device=q.device).unsqueeze( 0 ) < valid_limit.unsqueeze(1) From 883ddb70fd6b0cf0f03c1c51f0f6d8b2cb63c171 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Sun, 3 May 2026 16:14:33 -0700 Subject: [PATCH 10/10] fix: make topk per row width configurable --- aiter/ops/topk.py | 2 ++ csrc/include/rocm_ops.hpp | 6 ++++-- csrc/include/topk_per_row.h | 6 ++++-- csrc/kernels/topk_per_row_kernels.cu | 16 ++++++++++++---- op_tests/test_topk_per_row.py | 2 ++ op_tests/test_topk_row_prefill.py | 1 + 6 files changed, 25 insertions(+), 8 deletions(-) diff --git a/aiter/ops/topk.py b/aiter/ops/topk.py index 809a23c08a..1e55fe063f 100755 --- a/aiter/ops/topk.py +++ b/aiter/ops/topk.py @@ -207,6 +207,7 @@ def top_k_per_row_prefill( numRows: int, stride0: int, stride1: int, + k: int = -1, ) -> None: ... @@ -232,6 +233,7 @@ def top_k_per_row_decode( numRows: int, stride0: int, stride1: int, + k: int = -1, ) -> None: ... diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index f43072b1c1..9da0c77e52 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1656,7 +1656,8 @@ namespace py = pybind11; py::arg("values"), \ py::arg("numRows"), \ py::arg("stride0"), \ - py::arg("stride1")); \ + py::arg("stride1"), \ + py::arg("k") = -1); \ m.def("top_k_per_row_decode", \ &top_k_per_row_decode, \ py::arg("logits"), \ @@ -1665,7 +1666,8 @@ namespace py = pybind11; py::arg("indices"), \ py::arg("numRows"), \ py::arg("stride0"), \ - py::arg("stride1")); + py::arg("stride1"), \ + py::arg("k") = -1); #define MLA_METADATA_PYBIND \ m.def("get_mla_metadata_v1", \ diff --git a/csrc/include/topk_per_row.h b/csrc/include/topk_per_row.h index e3bae1887d..c18e3e6b5e 100644 --- a/csrc/include/topk_per_row.h +++ b/csrc/include/topk_per_row.h @@ -9,7 +9,8 @@ void top_k_per_row_prefill(const torch::Tensor& logits, std::optional values, int64_t numRows, int64_t stride0, - int64_t stride1); + int64_t stride1, + int64_t k = -1); void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, @@ -17,4 +18,5 @@ void top_k_per_row_decode(const torch::Tensor& logits, torch::Tensor& indices, int64_t numRows, int64_t stride0, - int64_t stride1); + int64_t stride1, + int64_t k = -1); diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu index 6edf377ca8..f7d8e3b409 100644 --- a/csrc/kernels/topk_per_row_kernels.cu +++ b/csrc/kernels/topk_per_row_kernels.cu @@ -2506,11 +2506,15 @@ void top_k_per_row_prefill(const torch::Tensor& logits, std::optional values, int64_t numRows, int64_t stride0, - int64_t stride1) + int64_t stride1, + int64_t k) { size_t buf_size = 0; // will be overwritten by the kernel - static constexpr int kTopK = 2048; + const int kTopK = k > 0 ? static_cast(k) : static_cast(indices.size(1)); + TORCH_CHECK(kTopK > 0, "top_k_per_row_prefill requires k > 0"); + TORCH_CHECK(kTopK <= indices.size(1), + "top_k_per_row_prefill k exceeds indices width"); static constexpr bool is_largest = true; const hipStream_t stream = at::hip::getCurrentHIPStream(); @@ -2630,11 +2634,15 @@ void top_k_per_row_decode(const torch::Tensor& logits, torch::Tensor& indices, int64_t numRows, int64_t stride0, - int64_t stride1) + int64_t stride1, + int64_t k) { size_t buf_size = 0; // will be overwritten by the kernel - static constexpr int kTopK = 2048; + const int kTopK = k > 0 ? static_cast(k) : static_cast(indices.size(1)); + TORCH_CHECK(kTopK > 0, "top_k_per_row_decode requires k > 0"); + TORCH_CHECK(kTopK <= indices.size(1), + "top_k_per_row_decode k exceeds indices width"); static constexpr bool is_largest = true; const hipStream_t stream = at::hip::getCurrentHIPStream(); diff --git a/op_tests/test_topk_per_row.py b/op_tests/test_topk_per_row.py index 34055eb1a6..5ced4ac583 100755 --- a/op_tests/test_topk_per_row.py +++ b/op_tests/test_topk_per_row.py @@ -146,6 +146,7 @@ def run_top_k_per_row_prefill( num_rows, stride_row, stride_col, + k=indices.size(1), ) @@ -182,6 +183,7 @@ def run_top_k_per_row_decode( numRows, stride0, stride1, + k=indices.size(1), ) diff --git a/op_tests/test_topk_row_prefill.py b/op_tests/test_topk_row_prefill.py index 2d47ced573..1fa4178bc2 100644 --- a/op_tests/test_topk_row_prefill.py +++ b/op_tests/test_topk_row_prefill.py @@ -286,6 +286,7 @@ def run_top_k_per_row_prefill( num_rows, stride_row, stride_col, + k=indices.size(1), )