diff --git a/aiter/ops/triton/__init__.py b/aiter/ops/triton/__init__.py index bef4ccc506..c8057288bf 100644 --- a/aiter/ops/triton/__init__.py +++ b/aiter/ops/triton/__init__.py @@ -103,6 +103,8 @@ "pa_prefill": "attention.pa_prefill", "pod_attention": "attention.pod_attention", "prefill_attention": "attention.prefill_attention", + "dsv4_indexer": "attention.dsv4_indexer", + "sparse_mqa_sink": "attention.sparse_mqa_sink", "unified_attention_sparse_mla": "attention.unified_attention_sparse_mla", "unified_attention": "attention.unified_attention", # Fusions modules (fusions/) diff --git a/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py b/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py new file mode 100644 index 0000000000..9d0f78625f --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py @@ -0,0 +1,241 @@ +import triton +import triton.language as tl + + +@triton.jit +def _dsv4_indexer_dense_kernel( + out_ptr, # [num_tokens, topk] + positions_ptr, # [num_tokens] + out_stride_t: tl.int64, + out_stride_k: tl.int64, + n_committed: tl.constexpr, + offset: tl.int32, + ratio: tl.constexpr, + BLOCK_K: tl.constexpr, +): + token_id = tl.program_id(0) + offs_k = tl.arange(0, BLOCK_K) + pos = tl.load(positions_ptr + token_id).to(tl.int32) + causal_limit = (pos + 1) // ratio + valid = (offs_k < n_committed) & (offs_k < causal_limit) + out = tl.where(valid, offs_k + offset, -1) + tl.store( + out_ptr + token_id * out_stride_t + offs_k * out_stride_k, + out, + mask=offs_k < n_committed, + ) + + +@triton.jit +def _dsv4_indexer_dense_batched_kernel( + out_ptr, # [num_tokens, topk] + positions_ptr, # [num_tokens] + seq_ids_ptr, # [num_tokens] + kv_lens_ptr, # [num_seqs] + out_stride_t: tl.int64, + out_stride_k: tl.int64, + n_committed: tl.constexpr, + offset: tl.int32, + ratio: tl.constexpr, + BLOCK_K: tl.constexpr, +): + token_id = tl.program_id(0) + offs_k = tl.arange(0, BLOCK_K) + seq_id = tl.load(seq_ids_ptr + token_id).to(tl.int32) + 
kv_len = tl.load(kv_lens_ptr + seq_id).to(tl.int32) + pos = tl.load(positions_ptr + token_id).to(tl.int32) + causal_limit = (pos + 1) // ratio + valid = (offs_k < n_committed) & (offs_k < kv_len) & (offs_k < causal_limit) + out = tl.where(valid, offs_k + offset, -1) + tl.store( + out_ptr + token_id * out_stride_t + offs_k * out_stride_k, + out, + mask=offs_k < n_committed, + ) + + +@triton.jit +def _dsv4_indexer_score_kernel( + score_ptr, # [num_tokens, kv_len], fp32 + q_ptr, # [num_tokens, num_heads, head_dim] + kv_ptr, # [kv_len, head_dim] + weights_ptr, # [num_tokens, num_heads] + positions_ptr, # [num_tokens] + q_stride_t: tl.int64, + q_stride_h: tl.int64, + q_stride_d: tl.int64, + kv_stride_t: tl.int64, + kv_stride_d: tl.int64, + weights_stride_t: tl.int64, + weights_stride_h: tl.int64, + score_stride_t: tl.int64, + score_stride_k: tl.int64, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + kv_len: tl.constexpr, + ratio: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + offs_t = tile_id * BLOCK_T + tl.arange(0, BLOCK_T) + offs_d = tl.arange(0, BLOCK_D) + d_mask = offs_d < head_dim + acc = tl.zeros((BLOCK_T,), dtype=tl.float32) + + for h_start in range(0, num_heads, BLOCK_H): + offs_h = h_start + tl.arange(0, BLOCK_H) + h_mask = offs_h < num_heads + + q = tl.load( + q_ptr + + token_id * q_stride_t + + offs_h[:, None] * q_stride_h + + offs_d[None, :] * q_stride_d, + mask=h_mask[:, None] & d_mask[None, :], + other=0.0, + cache_modifier=".cg", + ) + kv = tl.load( + kv_ptr + offs_t[None, :] * kv_stride_t + offs_d[:, None] * kv_stride_d, + mask=(offs_t[None, :] < kv_len) & d_mask[:, None], + other=0.0, + cache_modifier=".cg", + ) + dots = tl.dot(q, kv) + dots = tl.maximum(dots, 0.0) + weights = tl.load( + weights_ptr + token_id * weights_stride_t + offs_h * weights_stride_h, + mask=h_mask, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + acc += 
tl.sum(dots * weights[:, None], axis=0) + + pos = tl.load(positions_ptr + token_id).to(tl.int32) + causal_limit = (pos + 1) // ratio + valid = (offs_t < kv_len) & (offs_t < causal_limit) + acc = tl.where(valid, acc, float("-inf")) + tl.store( + score_ptr + token_id * score_stride_t + offs_t * score_stride_k, + acc, + mask=offs_t < kv_len, + ) + + +@triton.jit +def _dsv4_indexer_score_batched_kernel( + score_ptr, # [num_tokens, kv_len], fp32 + q_ptr, # [num_tokens, num_heads, head_dim] + kv_ptr, # [num_seqs, kv_len, head_dim] + weights_ptr, # [num_tokens, num_heads] + positions_ptr, # [num_tokens] + seq_ids_ptr, # [num_tokens] + kv_lens_ptr, # [num_seqs] + q_stride_t: tl.int64, + q_stride_h: tl.int64, + q_stride_d: tl.int64, + kv_stride_b: tl.int64, + kv_stride_t: tl.int64, + kv_stride_d: tl.int64, + weights_stride_t: tl.int64, + weights_stride_h: tl.int64, + score_stride_t: tl.int64, + score_stride_k: tl.int64, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + kv_len_max: tl.constexpr, + ratio: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + seq_id = tl.load(seq_ids_ptr + token_id).to(tl.int32) + kv_len = tl.load(kv_lens_ptr + seq_id).to(tl.int32) + offs_t = tile_id * BLOCK_T + tl.arange(0, BLOCK_T) + offs_d = tl.arange(0, BLOCK_D) + d_mask = offs_d < head_dim + acc = tl.zeros((BLOCK_T,), dtype=tl.float32) + + for h_start in range(0, num_heads, BLOCK_H): + offs_h = h_start + tl.arange(0, BLOCK_H) + h_mask = offs_h < num_heads + + q = tl.load( + q_ptr + + token_id * q_stride_t + + offs_h[:, None] * q_stride_h + + offs_d[None, :] * q_stride_d, + mask=h_mask[:, None] & d_mask[None, :], + other=0.0, + cache_modifier=".cg", + ) + kv = tl.load( + kv_ptr + + seq_id * kv_stride_b + + offs_t[None, :] * kv_stride_t + + offs_d[:, None] * kv_stride_d, + mask=(offs_t[None, :] < kv_len) & d_mask[:, None], + other=0.0, + cache_modifier=".cg", + ) + dots = tl.dot(q, kv) 
+ dots = tl.maximum(dots, 0.0) + weights = tl.load( + weights_ptr + token_id * weights_stride_t + offs_h * weights_stride_h, + mask=h_mask, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + acc += tl.sum(dots * weights[:, None], axis=0) + + pos = tl.load(positions_ptr + token_id).to(tl.int32) + causal_limit = (pos + 1) // ratio + valid = (offs_t < kv_len_max) & (offs_t < kv_len) & (offs_t < causal_limit) + acc = tl.where(valid, acc, float("-inf")) + tl.store( + score_ptr + token_id * score_stride_t + offs_t * score_stride_k, + acc, + mask=offs_t < kv_len_max, + ) + + +@triton.jit +def _dsv4_indexer_finalize_kernel( + out_ptr, # [num_tokens, topk], int32 + values_ptr, # [num_tokens, topk], fp32 + indices_ptr, # [num_tokens, topk], int64 from aiter topk + out_stride_t: tl.int64, + out_stride_k: tl.int64, + values_stride_t: tl.int64, + values_stride_k: tl.int64, + indices_stride_t: tl.int64, + indices_stride_k: tl.int64, + offset: tl.int32, + topk: tl.constexpr, + BLOCK_K: tl.constexpr, +): + token_id = tl.program_id(0) + offs_k = tl.arange(0, BLOCK_K) + values = tl.load( + values_ptr + token_id * values_stride_t + offs_k * values_stride_k, + mask=offs_k < topk, + other=float("-inf"), + ) + indices = tl.load( + indices_ptr + token_id * indices_stride_t + offs_k * indices_stride_k, + mask=offs_k < topk, + other=-1, + ).to(tl.int32) + out = tl.where(values > -3.0e38, indices + offset, -1) + tl.store( + out_ptr + token_id * out_stride_t + offs_k * out_stride_k, + out, + mask=offs_k < topk, + ) diff --git a/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py b/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py new file mode 100644 index 0000000000..3ecee6eb42 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py @@ -0,0 +1,156 @@ +import triton +import triton.language as tl + + +@triton.jit +def _find_seq_idx(cu_seqlens_q_ptr, token_idx, num_seqs): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid 
= (left + right) // 2 + val = tl.load(cu_seqlens_q_ptr + mid) + if val <= token_idx: + left = mid + 1 + else: + right = mid + return left - 1 + + +@triton.jit +def _sparse_mqa_sink_kernel( + out_ptr, # [num_tokens, num_heads, head_dim] + q_ptr, # [num_tokens, num_heads, head_dim] + kv_ptr, # [num_blocks, block_size, head_dim] + topk_ptr, # [num_tokens, topk] + attn_sink_ptr, # [num_heads] + block_table_ptr, # [num_seqs, max_blocks_per_seq] + cu_seqlens_q_ptr, # [num_seqs + 1] + seqused_k_ptr, # [num_seqs] + scale, + q_stride_t: tl.int64, + q_stride_h: tl.int64, + q_stride_d: tl.int64, + out_stride_t: tl.int64, + out_stride_h: tl.int64, + out_stride_d: tl.int64, + kv_stride_b: tl.int64, + kv_stride_s: tl.int64, + kv_stride_d: tl.int64, + topk_stride_t: tl.int64, + topk_stride_k: tl.int64, + block_table_stride_b: tl.int64, + block_table_stride_blk: tl.int64, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + topk_count: tl.constexpr, + block_size: tl.constexpr, + num_seqs: tl.int32, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, + SCORE_D: tl.constexpr, + TILE_K: tl.constexpr, +): + """Sparse MQA with DSv4's attention-sink denominator. + + One program handles one query token, BLOCK_H query heads, and one output + dimension tile. KV is MQA: all query heads share the same [topk, head_dim] + key/value rows. 
+ """ + token_id = tl.program_id(0) + head_block = tl.program_id(1) + dim_block = tl.program_id(2) + + seq_idx = _find_seq_idx(cu_seqlens_q_ptr, token_id, num_seqs) + seq_end = tl.load(cu_seqlens_q_ptr + seq_idx + 1) + if token_id >= seq_end: + return + kv_len = tl.load(seqused_k_ptr + seq_idx) + + offs_h = head_block * BLOCK_H + tl.arange(0, BLOCK_H) + offs_d = dim_block * BLOCK_D + tl.arange(0, BLOCK_D) + offs_score_d = tl.arange(0, SCORE_D) + h_mask = offs_h < num_heads + d_mask = offs_d < head_dim + + sink = tl.load(attn_sink_ptr + offs_h, mask=h_mask, other=float("-inf")).to( + tl.float32 + ) + has_sink = sink > -3.0e38 + m_i = tl.where(has_sink, sink, float("-inf")) + l_i = tl.where(has_sink, 1.0, 0.0) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + + for tile_start in range(0, topk_count, TILE_K): + offs_k = tile_start + tl.arange(0, TILE_K) + topk_pos = tl.load( + topk_ptr + token_id * topk_stride_t + offs_k * topk_stride_k, + mask=offs_k < topk_count, + other=-1, + ) + valid_k = (offs_k < topk_count) & (topk_pos >= 0) & (topk_pos < kv_len) + + logical_block = topk_pos // block_size + slot = topk_pos - logical_block * block_size + physical_block = tl.load( + block_table_ptr + + seq_idx * block_table_stride_b + + logical_block * block_table_stride_blk, + mask=valid_k, + other=0, + ) + + scores = tl.zeros((BLOCK_H, TILE_K), dtype=tl.float32) + for d_start in range(0, head_dim, SCORE_D): + score_d = d_start + offs_score_d + score_d_mask = score_d < head_dim + q = tl.load( + q_ptr + + token_id * q_stride_t + + offs_h[:, None] * q_stride_h + + score_d[None, :] * q_stride_d, + mask=h_mask[:, None] & score_d_mask[None, :], + other=0.0, + cache_modifier=".cg", + ) + k = tl.load( + kv_ptr + + physical_block[None, :] * kv_stride_b + + slot[None, :] * kv_stride_s + + score_d[:, None] * kv_stride_d, + mask=score_d_mask[:, None] & valid_k[None, :], + other=0.0, + cache_modifier=".cg", + ) + scores += tl.dot(q, k) + scores *= scale + scores = tl.where(h_mask[:, 
None] & valid_k[None, :], scores, float("-inf")) + + m_new = tl.maximum(m_i, tl.max(scores, axis=1)) + m_new = tl.where(m_new > float("-inf"), m_new, 0.0) + p = tl.exp(scores - m_new[:, None]) + alpha = tl.exp(m_i - m_new) + l_new = l_i * alpha + tl.sum(p, axis=1) + acc = acc * alpha[:, None] + + v = tl.load( + kv_ptr + + physical_block[:, None] * kv_stride_b + + slot[:, None] * kv_stride_s + + offs_d[None, :] * kv_stride_d, + mask=valid_k[:, None] & d_mask[None, :], + other=0.0, + cache_modifier=".cg", + ) + acc += tl.dot(p.to(v.dtype), v) + m_i = m_new + l_i = l_new + + acc = acc * tl.where(l_i[:, None] > 0.0, 1.0 / l_i[:, None], 0.0) + tl.store( + out_ptr + + token_id * out_stride_t + + offs_h[:, None] * out_stride_h + + offs_d[None, :] * out_stride_d, + acc, + mask=h_mask[:, None] & d_mask[None, :], + ) diff --git a/aiter/ops/triton/attention/dsv4_indexer.py b/aiter/ops/triton/attention/dsv4_indexer.py new file mode 100644 index 0000000000..dd4d9d0646 --- /dev/null +++ b/aiter/ops/triton/attention/dsv4_indexer.py @@ -0,0 +1,241 @@ +import torch +import triton + +from aiter.ops.triton._triton_kernels.attention.dsv4_indexer import ( + _dsv4_indexer_dense_batched_kernel, + _dsv4_indexer_dense_kernel, + _dsv4_indexer_finalize_kernel, + _dsv4_indexer_score_batched_kernel, + _dsv4_indexer_score_kernel, +) +from aiter.ops.triton.topk import topk as _aiter_topk + +_DEQUANT_DTYPES = (torch.float16, torch.bfloat16) + + +def dsv4_indexer_topk( + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + positions: torch.Tensor, + index_topk: int, + offset: int, + *, + seq_ids: torch.Tensor | None = None, + kv_lens: torch.Tensor | None = None, + ratio: int = 4, + block_t: int = 64, + block_h: int = 8, +) -> torch.Tensor: + """DeepSeek-V4 Indexer scorer + causal top-k. 
+ + Computes the Indexer's learned sparse compressed-KV selection without + materializing the Torch fallback's [tokens, heads, committed] score tensor: + + score[t, k] = sum_h relu(q[t, h] @ kv[k]) * weights[t, h] + + Args: + q: [num_tokens, 64, 128], dequantized BF16/FP16. + kv: [num_committed, 128] or [num_seqs, max_committed, 128], + dequantized BF16/FP16 compressed Indexer KV. + weights: [num_tokens, 64], FP32/BF16, already includes model scaling. + positions: [num_tokens], absolute token positions. + index_topk: model top-k cap, 512 for V4-Flash or 1024 for V4-Pro. + offset: index offset into the sparse-attention [window || compressed] KV. + seq_ids: optional [num_tokens] int32/int64 sequence IDs. Required when + kv is batched. + kv_lens: optional [num_seqs] int32/int64 committed KV length per sequence. + Required when kv is batched and shorter than max_committed. + ratio: compression ratio. DSv4 CSA Indexer uses 4. + + Returns: + [num_tokens, min(index_topk, max_committed)] int32. Future entries are -1. + + This op does not unpack native DSv4 FP4/FP8 cache layouts or apply their + scale tensors. Callers must pass dequantized BF16/FP16 Q/KV tensors. 
+ """ + assert q.dim() == 3, f"q must be [T, H, D], got {q.shape}" + assert kv.dim() in (2, 3), f"kv must be [N, D] or [B, N, D], got {kv.shape}" + assert weights.dim() == 2, f"weights must be [T, H], got {weights.shape}" + assert positions.dim() == 1, f"positions must be [T], got {positions.shape}" + assert positions.dtype in (torch.int32, torch.int64) + assert q.shape[0] == weights.shape[0] == positions.shape[0] + assert q.shape[1] == weights.shape[1] + assert q.is_cuda and kv.is_cuda, "q and kv must be CUDA tensors" + assert ( + weights.device == q.device + and positions.device == q.device + and kv.device == q.device + ), "q, kv, weights, and positions must be on the same device" + assert q.dtype in _DEQUANT_DTYPES, f"q must be dequantized BF16/FP16, got {q.dtype}" + assert ( + kv.dtype in _DEQUANT_DTYPES + ), f"kv must be dequantized BF16/FP16, got {kv.dtype}" + assert weights.dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + ), f"weights must be FP16/BF16/FP32, got {weights.dtype}" + assert q.shape[2] == kv.shape[-1] + assert index_topk >= 0 + assert ratio > 0 + + num_tokens, num_heads, head_dim = q.shape + is_batched = kv.dim() == 3 + n_committed = kv.shape[1] if is_batched else kv.shape[0] + if is_batched: + assert seq_ids is not None, "seq_ids is required when kv is batched" + assert seq_ids.dim() == 1 and seq_ids.shape[0] == num_tokens + assert seq_ids.device == q.device, "seq_ids must be on the same device as q" + assert seq_ids.dtype in (torch.int32, torch.int64) + if kv_lens is None: + kv_lens = torch.full( + (kv.shape[0],), n_committed, device=kv.device, dtype=torch.int32 + ) + assert kv_lens.dim() == 1 and kv_lens.shape[0] == kv.shape[0] + assert kv_lens.device == q.device, "kv_lens must be on the same device as q" + assert kv_lens.dtype in (torch.int32, torch.int64) + if hasattr(torch, "_assert_async"): + torch._assert_async(((seq_ids >= 0) & (seq_ids < kv.shape[0])).all()) + torch._assert_async(((kv_lens >= 0) & (kv_lens <= 
n_committed)).all()) + else: + assert bool( + ((seq_ids >= 0) & (seq_ids < kv.shape[0])).all() + ), "seq_ids must be in range" + assert bool( + ((kv_lens >= 0) & (kv_lens <= n_committed)).all() + ), "kv_lens must be in range" + else: + assert seq_ids is None, "seq_ids requires batched kv" + assert kv_lens is None, "kv_lens requires batched kv" + actual_topk = min(int(index_topk), n_committed) + if actual_topk <= 0: + return torch.empty((num_tokens, 0), device=q.device, dtype=torch.int32) + if num_tokens == 0: + return torch.empty((0, actual_topk), device=q.device, dtype=torch.int32) + + q = q.contiguous() + kv = kv.contiguous() + weights = weights.contiguous() + positions = positions.contiguous() + if seq_ids is not None: + seq_ids = seq_ids.contiguous() + if kv_lens is not None: + kv_lens = kv_lens.contiguous() + out = torch.empty((num_tokens, actual_topk), device=q.device, dtype=torch.int32) + + # If top-k covers every committed compressed entry, the order does not + # affect downstream sparse attention. Emit dense causal indices and skip the + # expensive learned scorer entirely. This is the common 1k1k DSv4 case + # where n_committed=256 and index_topk is 512/1024. 
+ if actual_topk == n_committed: + block_k = triton.next_power_of_2(max(actual_topk, 1)) + if is_batched: + _dsv4_indexer_dense_batched_kernel[(num_tokens,)]( + out, + positions, + seq_ids, + kv_lens, + out.stride(0), + out.stride(1), + n_committed, + int(offset), + int(ratio), + BLOCK_K=block_k, + num_warps=4, + num_stages=1, + ) + else: + _dsv4_indexer_dense_kernel[(num_tokens,)]( + out, + positions, + out.stride(0), + out.stride(1), + n_committed, + int(offset), + int(ratio), + BLOCK_K=block_k, + num_warps=4, + num_stages=1, + ) + return out + + score = torch.empty((num_tokens, n_committed), device=q.device, dtype=torch.float32) + block_t = min(block_t, triton.next_power_of_2(max(n_committed, 1))) + block_h = min(block_h, triton.next_power_of_2(num_heads)) + block_d = triton.next_power_of_2(head_dim) + grid = (num_tokens, triton.cdiv(n_committed, block_t)) + if is_batched: + _dsv4_indexer_score_batched_kernel[grid]( + score, + q, + kv, + weights, + positions, + seq_ids, + kv_lens, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + weights.stride(0), + weights.stride(1), + score.stride(0), + score.stride(1), + num_heads, + head_dim, + n_committed, + int(ratio), + BLOCK_T=block_t, + BLOCK_H=block_h, + BLOCK_D=block_d, + num_warps=4, + num_stages=1, + ) + else: + _dsv4_indexer_score_kernel[grid]( + score, + q, + kv, + weights, + positions, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + weights.stride(0), + weights.stride(1), + score.stride(0), + score.stride(1), + num_heads, + head_dim, + n_committed, + int(ratio), + BLOCK_T=block_t, + BLOCK_H=block_h, + BLOCK_D=block_d, + num_warps=4, + num_stages=1, + ) + values, indices = _aiter_topk(score, actual_topk, dim=-1) + block_k = triton.next_power_of_2(max(actual_topk, 1)) + _dsv4_indexer_finalize_kernel[(num_tokens,)]( + out, + values, + indices, + out.stride(0), + out.stride(1), + values.stride(0), + values.stride(1), + indices.stride(0), + 
indices.stride(1), + int(offset), + actual_topk, + BLOCK_K=block_k, + num_warps=4, + num_stages=1, + ) + return out diff --git a/aiter/ops/triton/attention/sparse_mqa_sink.py b/aiter/ops/triton/attention/sparse_mqa_sink.py new file mode 100644 index 0000000000..01bc1f40f6 --- /dev/null +++ b/aiter/ops/triton/attention/sparse_mqa_sink.py @@ -0,0 +1,153 @@ +import torch +import triton + +from aiter.ops.triton._triton_kernels.attention.sparse_mqa_sink import ( + _sparse_mqa_sink_kernel, +) + +_DEQUANT_DTYPES = (torch.float16, torch.bfloat16) + + +def sparse_mqa_sink( + q: torch.Tensor, + kv: torch.Tensor, + out: torch.Tensor, + cu_seqlens_q: torch.Tensor, + seqused_k: torch.Tensor, + softmax_scale: float, + topk_indices: torch.Tensor, + block_table: torch.Tensor, + attn_sink: torch.Tensor, + *, + tile_k: int = 64, + block_h: int = 4, + block_d: int = 128, + score_d: int = 64, +) -> torch.Tensor: + """Sparse MQA with DSv4 attention-sink semantics. + + Args: + q: [num_tokens, num_heads, head_dim], dequantized BF16/FP16. + kv: [num_blocks, block_size, head_dim], dequantized BF16/FP16. + out: [num_tokens, num_heads, head_dim], same dtype as q. + cu_seqlens_q: [num_seqs + 1], int32 token offsets. + seqused_k: [num_seqs], int32 logical KV lengths before padding. + softmax_scale: scalar multiplier for q @ k. + topk_indices: [num_tokens, topk], int32 logical KV positions. -1 is invalid. + block_table: [num_seqs, max_blocks_per_seq], int32 logical->physical block IDs. + attn_sink: [num_heads], FP32 sink logits included in the denominator only. + + This op does not unpack native DSv4 FP4/FP8 cache layouts or apply their + scale tensors. Callers must pass dequantized BF16/FP16 Q/KV tensors. 
+ """ + assert q.dim() == 3, f"q must be [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"kv must be [num_blocks, block_size, D], got {kv.shape}" + assert out.shape == q.shape, f"out shape {out.shape} must match q {q.shape}" + assert topk_indices.dim() == 2 and topk_indices.shape[0] == q.shape[0] + assert cu_seqlens_q.dim() == 1 + assert seqused_k.dim() == 1 + assert block_table.dim() == 2 + assert attn_sink.shape == (q.shape[1],) + assert kv.shape[2] == q.shape[2] + assert q.is_cuda and kv.is_cuda, "q and kv must be CUDA tensors" + assert ( + out.device == q.device + and cu_seqlens_q.device == q.device + and seqused_k.device == q.device + and topk_indices.device == q.device + and block_table.device == q.device + and attn_sink.device == q.device + and kv.device == q.device + ), "all inputs must be on the same device" + assert q.dtype in _DEQUANT_DTYPES, f"q must be dequantized BF16/FP16, got {q.dtype}" + assert ( + kv.dtype in _DEQUANT_DTYPES + ), f"kv must be dequantized BF16/FP16, got {kv.dtype}" + assert out.dtype == q.dtype, f"out dtype {out.dtype} must match q dtype {q.dtype}" + assert cu_seqlens_q.dtype == torch.int32 + assert seqused_k.dtype == torch.int32 + assert topk_indices.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert attn_sink.dtype == torch.float32 + + num_tokens, num_heads, head_dim = q.shape + block_size = kv.shape[1] + topk_count = topk_indices.shape[1] + num_seqs = seqused_k.shape[0] + assert cu_seqlens_q.shape[0] == num_seqs + 1, ( + "cu_seqlens_q must have length num_seqs + 1, " + f"got {cu_seqlens_q.shape[0]} vs {num_seqs + 1}" + ) + assert ( + block_table.shape[0] == num_seqs + ), f"block_table rows {block_table.shape[0]} must match num_seqs {num_seqs}" + + if q.numel() == 0: + return out + + assert num_seqs > 0, "non-empty q requires at least one sequence" + # Keep value checks on-device to avoid synchronizing this hot path. 
+ if hasattr(torch, "_assert_async"): + torch._assert_async(cu_seqlens_q[0] == 0) + torch._assert_async(cu_seqlens_q[-1] == num_tokens) + else: + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert ( + cu_seqlens_q[-1] == num_tokens + ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} must equal num_tokens {num_tokens}" + + q = q.contiguous() + kv = kv.contiguous() + topk_indices = topk_indices.contiguous() + cu_seqlens_q = cu_seqlens_q.contiguous() + seqused_k = seqused_k.contiguous() + block_table = block_table.contiguous() + attn_sink = attn_sink.contiguous() + + # Keep the accumulator footprint comparable to the original 8x64 tile + # while halving output-D tiles. That cuts repeated QK score work for + # DSv4's 512-wide value vector from 8x to 4x. + block_h = min(block_h, triton.next_power_of_2(num_heads)) + block_d = min(block_d, triton.next_power_of_2(head_dim)) + score_d = min(score_d, triton.next_power_of_2(head_dim)) + tile_k = min(tile_k, triton.next_power_of_2(max(topk_count, 1))) + head_blocks = triton.cdiv(num_heads, block_h) + dim_blocks = triton.cdiv(head_dim, block_d) + grid = (num_tokens, head_blocks, dim_blocks) + + _sparse_mqa_sink_kernel[grid]( + out, + q, + kv, + topk_indices, + attn_sink, + block_table, + cu_seqlens_q, + seqused_k, + float(softmax_scale), + q.stride(0), + q.stride(1), + q.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + topk_indices.stride(0), + topk_indices.stride(1), + block_table.stride(0), + block_table.stride(1), + num_heads, + head_dim, + topk_count, + block_size, + num_seqs, + BLOCK_H=block_h, + BLOCK_D=block_d, + SCORE_D=score_d, + TILE_K=tile_k, + num_warps=4, + num_stages=1, + ) + return out diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu index 199626e0cb..7ccc60f542 100644 --- a/csrc/kernels/topk_per_row_kernels.cu +++ b/csrc/kernels/topk_per_row_kernels.cu @@ -2515,7 +2515,10 @@ void 
top_k_per_row_prefill(const torch::Tensor& logits, { size_t buf_size = 0; // will be overwritten by the kernel - int kTopK = static_cast(k); + const int kTopK = k > 0 ? static_cast(k) : static_cast(indices.size(1)); + TORCH_CHECK(kTopK > 0, "top_k_per_row_prefill requires k > 0"); + TORCH_CHECK(kTopK <= indices.size(1), + "top_k_per_row_prefill k exceeds indices width"); static constexpr bool is_largest = true; const hipStream_t stream = at::hip::getCurrentHIPStream(); @@ -2641,7 +2644,10 @@ void top_k_per_row_decode(const torch::Tensor& logits, { size_t buf_size = 0; // will be overwritten by the kernel - int kTopK = static_cast(k); + const int kTopK = k > 0 ? static_cast(k) : static_cast(indices.size(1)); + TORCH_CHECK(kTopK > 0, "top_k_per_row_decode requires k > 0"); + TORCH_CHECK(kTopK <= indices.size(1), + "top_k_per_row_decode k exceeds indices width"); static constexpr bool is_largest = true; const hipStream_t stream = at::hip::getCurrentHIPStream(); diff --git a/op_tests/test_dsv4_indexer.py b/op_tests/test_dsv4_indexer.py new file mode 100644 index 0000000000..5dcafe4aaa --- /dev/null +++ b/op_tests/test_dsv4_indexer.py @@ -0,0 +1,156 @@ +import pytest +import torch + +from aiter.ops.triton.attention.dsv4_indexer import dsv4_indexer_topk + + +def _reference( + q, + kv, + weights, + positions, + index_topk, + offset, + ratio=4, + seq_ids=None, + kv_lens=None, +): + qf = q.float() + kvf = kv.float() + wf = weights.float() + if kv.dim() == 3: + assert seq_ids is not None + kvf = kvf[seq_ids.long()] + max_committed = kv.shape[1] + else: + max_committed = kv.shape[0] + if kv.dim() == 3: + scores = torch.einsum("thd,tnd->thn", qf, kvf) + else: + scores = torch.einsum("thd,nd->thn", qf, kvf) + scores = (scores.relu_() * wf.unsqueeze(-1)).sum(dim=1) + valid_limit = (positions.to(torch.long) + 1) // ratio + if kv_lens is not None: + valid_limit = torch.minimum(valid_limit, kv_lens[seq_ids.long()].to(torch.long)) + valid = torch.arange(max_committed, 
device=q.device).unsqueeze( + 0 + ) < valid_limit.unsqueeze(1) + scores = scores.masked_fill(~valid, float("-inf")) + k = min(index_topk, max_committed) + if k == 0: + return torch.empty((q.shape[0], 0), device=q.device, dtype=torch.int32) + values, indices = scores.topk(k, dim=-1) + return torch.where(values > -3.0e38, indices.to(torch.int32) + offset, -1) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_dense_causal_indices(): + torch.manual_seed(0) + tokens, heads, dim, committed = 9, 64, 128, 16 + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(committed, dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) + positions = torch.arange(tokens, device="cuda", dtype=torch.int64) + 3 + + out = dsv4_indexer_topk(q, kv, weights, positions, 64, 128) + expected = ( + torch.arange(committed, device="cuda", dtype=torch.int32).expand(tokens, -1) + + 128 + ) + valid = torch.arange(committed, device="cuda").unsqueeze(0) < ( + (positions + 1) // 4 + ).unsqueeze(1) + expected = torch.where(valid, expected, -1) + torch.testing.assert_close(out, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_scored_topk_matches_torch(): + torch.manual_seed(1) + tokens, heads, dim, committed, k = 7, 64, 128, 80, 12 + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(committed, dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) + positions = torch.arange(tokens, device="cuda", dtype=torch.int64) + committed * 4 + + out = dsv4_indexer_topk(q, kv, weights, positions, k, 128) + ref = _reference(q, kv, weights, positions, k, 128) + torch.testing.assert_close(out, ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def 
test_dsv4_indexer_zero_committed_returns_empty(): + q = torch.empty(3, 64, 128, device="cuda", dtype=torch.bfloat16) + kv = torch.empty(0, 128, device="cuda", dtype=torch.bfloat16) + weights = torch.empty(3, 64, device="cuda", dtype=torch.float32) + positions = torch.arange(3, device="cuda", dtype=torch.int64) + + out = dsv4_indexer_topk(q, kv, weights, positions, 512, 128) + assert out.shape == (3, 0) + assert out.dtype == torch.int32 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_batched_dense_causal_indices(): + torch.manual_seed(2) + tokens, heads, dim, committed = 4, 64, 128, 16 + q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(2, committed, dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(tokens, heads, device="cuda", dtype=torch.float32) + positions = torch.tensor([3, 7, 63, 63], device="cuda", dtype=torch.int64) + seq_ids = torch.tensor([0, 1, 0, 1], device="cuda", dtype=torch.int32) + kv_lens = torch.tensor([5, 9], device="cuda", dtype=torch.int32) + + out = dsv4_indexer_topk( + q, kv, weights, positions, 64, 128, seq_ids=seq_ids, kv_lens=kv_lens + ) + expected = ( + torch.arange(committed, device="cuda", dtype=torch.int32).expand(tokens, -1) + + 128 + ) + valid_limit = torch.minimum((positions + 1) // 4, kv_lens[seq_ids.long()]) + valid = torch.arange(committed, device="cuda").unsqueeze(0) < valid_limit.unsqueeze( + 1 + ) + expected = torch.where(valid, expected, -1) + torch.testing.assert_close(out, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +def test_dsv4_indexer_batched_scored_topk_no_cross_sequence_leakage(): + tokens, heads, dim, committed, k = 4, 64, 128, 32, 4 + q = torch.zeros(tokens, heads, dim, device="cuda", dtype=torch.bfloat16) + q[:, 0, 0] = 1 + kv = torch.zeros(2, committed, dim, device="cuda", dtype=torch.bfloat16) + kv[0, :, 0] = torch.arange(committed, device="cuda", 
import pytest  # re-imported here so this section stands alone; redundant in-file


def _reference(
    q, kv_blocks, topk, attn_sink, scale, cu_seqlens_q, seqused_k, block_table
):
    """Pure-PyTorch reference for paged sparse MQA with a per-head sink logit.

    Each token's top-k KV indices are mapped through ``block_table`` from
    logical positions to physical (block, slot) pairs, attention scores are
    scaled by ``scale``, and ``attn_sink`` contributes one extra logit per
    head to the softmax normalizer before its column is dropped — so it only
    drains probability mass from the real keys. Indices that are negative or
    >= the owning sequence's used KV length are ignored; a token with no
    valid index produces an all-zero output row.

    Returns a tensor with the same shape and dtype as ``q``.
    """
    t, h, d = q.shape
    out = torch.empty_like(q)
    qf = q.float()
    kvf = kv_blocks.float()
    block_size = kv_blocks.shape[1]  # hoisted: slots per physical block
    cu_cpu = cu_seqlens_q.cpu().tolist()
    # cu_seqlens_q is cumulative, so the owning sequence index is monotone in
    # the token index; advance a pointer instead of rescanning the prefix for
    # every token (the previous next(...) scan was O(tokens * seqs)).
    seq_idx = 0
    for i in range(t):
        while cu_cpu[seq_idx + 1] <= i:
            seq_idx += 1
        kv_len = int(seqused_k[seq_idx].item())
        # Negative entries are explicit padding; >= kv_len guards stale slots.
        valid = (topk[i] >= 0) & (topk[i] < kv_len)
        if not bool(valid.any()):
            # All weight would go to the sink, so the kernel emits zeros.
            out[i].zero_()
            continue
        idx = topk[i, valid].long()
        logical_block = idx // block_size
        slot = idx % block_size
        physical_block = block_table[seq_idx, logical_block].long()
        k = kvf[physical_block, slot]
        scores = torch.matmul(qf[i], k.t()) * scale
        # Append the sink logit per head, softmax, then drop the sink column.
        combined = torch.cat([scores, attn_sink.float().view(h, 1)], dim=-1)
        weights = torch.softmax(combined, dim=-1)[..., :-1]
        out[i] = torch.matmul(weights, k).to(out.dtype)
    return out


# Backward-compatible public alias: lets other modules (and star-imports)
# reach the reference implementation without the leading-underscore name.
sparse_mqa_sink_reference = _reference


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
@pytest.mark.parametrize("topk_count", [16, 48])
def test_sparse_mqa_sink_matches_torch(topk_count):
    """Kernel vs. pure-torch reference on a small single-sequence multi-block
    case, including a token whose trailing top-k entries are invalid (-1)."""
    torch.manual_seed(0)
    tokens, heads, dim = 5, 16, 64
    kv_len, block_size = 73, 32
    num_blocks = (kv_len + block_size - 1) // block_size
    padded = num_blocks * block_size  # KV padded to whole paged blocks

    q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16)
    kv_flat = torch.randn(padded, dim, device="cuda", dtype=torch.bfloat16)
    kv_flat[kv_len:].zero_()  # padding slots must not contribute
    kv_blocks = kv_flat.view(num_blocks, block_size, dim)
    topk = torch.randint(
        0, kv_len, (tokens, topk_count), device="cuda", dtype=torch.int32
    )
    topk[0, -3:] = -1  # exercise the invalid-index path
    attn_sink = torch.randn(heads, device="cuda", dtype=torch.float32)
    cu = torch.tensor([0, tokens], device="cuda", dtype=torch.int32)
    seqused = torch.tensor([kv_len], device="cuda", dtype=torch.int32)
    # Identity block table: logical block i -> physical block i.
    block_table = torch.arange(num_blocks, device="cuda", dtype=torch.int32).view(1, -1)
    out = torch.empty_like(q)

    sparse_mqa_sink(
        q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink
    )
    ref = _reference(q, kv_blocks, topk, attn_sink, dim**-0.5, cu, seqused, block_table)
    # bf16 math in the kernel: compare with loosened tolerances.
    torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
@pytest.mark.parametrize(
    ("heads", "topk_count", "kv_len", "tokens"),
    [
        (64, 160, 256, 2),  # HCA 4K-style top-k
        (64, 640, 768, 2),  # V4-Flash CSA
        (128, 1152, 1280, 1),  # V4-Pro CSA
        (64, 2048, 2304, 1),  # HCA long-context smoke
    ],
)
def test_sparse_mqa_sink_dsv4_shapes_match_torch(heads, topk_count, kv_len, tokens):
    """Kernel vs. pure-torch reference across DSV4-representative
    (heads, topk, kv_len) shapes on a single sequence; requires CUDA."""
    torch.manual_seed(1)
    dim, block_size = 512, 256
    num_blocks = (kv_len + block_size - 1) // block_size
    # KV is padded to a whole number of paged blocks.
    padded = num_blocks * block_size

    q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16)
    kv_flat = torch.randn(padded, dim, device="cuda", dtype=torch.bfloat16)
    # Zero the padding tail so out-of-range slots cannot contribute.
    kv_flat[kv_len:].zero_()
    kv_blocks = kv_flat.view(num_blocks, block_size, dim)
    topk = torch.randint(
        0, kv_len, (tokens, topk_count), device="cuda", dtype=torch.int32
    )
    # Mark a handful of the first token's trailing selections invalid (-1).
    topk[0, -min(17, topk_count) :] = -1
    # Spread sink logits over a wide range to stress the softmax normalizer.
    attn_sink = torch.linspace(-8, 8, heads, device="cuda", dtype=torch.float32)
    cu = torch.tensor([0, tokens], device="cuda", dtype=torch.int32)
    seqused = torch.tensor([kv_len], device="cuda", dtype=torch.int32)
    # Identity block table: logical block i -> physical block i.
    block_table = torch.arange(num_blocks, device="cuda", dtype=torch.int32).view(1, -1)
    out = torch.empty_like(q)

    sparse_mqa_sink(
        q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink
    )
    ref = _reference(q, kv_blocks, topk, attn_sink, dim**-0.5, cu, seqused, block_table)
    # bf16 kernel math: loosened tolerances.
    torch.testing.assert_close(out, ref, rtol=3e-2, atol=3e-2)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
def test_sparse_mqa_sink_multi_sequence_block_table_matches_torch():
    """Two sequences share one paged KV pool via a shuffled block table;
    sequence-1's physical blocks are boosted so any cross-sequence read
    would visibly break the kernel-vs-reference comparison."""
    torch.manual_seed(2)
    tokens, heads, dim = 5, 64, 512
    topk_count, block_size = 160, 64
    kv_lens = [130, 177]
    max_blocks = max((length + block_size - 1) // block_size for length in kv_lens)
    total_blocks = 8

    q = torch.randn(tokens, heads, dim, device="cuda", dtype=torch.bfloat16)
    kv_blocks = torch.randn(
        total_blocks, block_size, dim, device="cuda", dtype=torch.bfloat16
    )
    kv_blocks[4:7].add_(8.0)  # make cross-sequence leakage obvious
    # Non-contiguous, shuffled physical blocks per sequence.
    block_table = torch.tensor(
        [[2, 0, 1], [6, 4, 5]], device="cuda", dtype=torch.int32
    )[:, :max_blocks]
    # Token -> sequence split: tokens 0-1 belong to seq 0, tokens 2-4 to seq 1.
    cu = torch.tensor([0, 2, tokens], device="cuda", dtype=torch.int32)
    seqused = torch.tensor(kv_lens, device="cuda", dtype=torch.int32)
    topk = torch.empty(tokens, topk_count, device="cuda", dtype=torch.int32)
    # Draw each token's top-k only from its own sequence's KV length.
    for i, kv_len in enumerate([kv_lens[0]] * 2 + [kv_lens[1]] * 3):
        topk[i] = torch.randint(
            0, kv_len, (topk_count,), device="cuda", dtype=torch.int32
        )
    # Sprinkle invalid (-1) selections on one token of each sequence.
    topk[1, -5:] = -1
    topk[3, -7:] = -1
    attn_sink = torch.randn(heads, device="cuda", dtype=torch.float32)
    out = torch.empty_like(q)

    sparse_mqa_sink(
        q, kv_blocks, out, cu, seqused, dim**-0.5, topk, block_table, attn_sink
    )
    ref = _reference(q, kv_blocks, topk, attn_sink, dim**-0.5, cu, seqused, block_table)
    torch.testing.assert_close(out, ref, rtol=3e-2, atol=3e-2)