Dsv4 sparse indexer#2998
Open
Oseltamivir wants to merge 14 commits intoROCm:mainfrom
Open
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Contributor
There was a problem hiding this comment.
Pull request overview
Adds Triton implementations for DeepSeek-V4 (DSv4) sparse attention and Indexer top-k selection to replace slow Torch fallbacks in ATOM/serving paths.
Changes:
- Introduce
sparse_mqa_sinkTriton op implementing DSv4 sparse MQA forward with attention-sink denominator semantics. - Introduce
dsv4_indexer_topkTriton op implementing DSv4 Indexer scoring + causal top-k, including a dense causal fast path. - Add unit tests for both new ops and register the modules in Triton backward-compat import map.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/test_sparse_mqa_sink.py | Adds correctness test comparing sparse_mqa_sink vs a Torch reference. |
| op_tests/test_dsv4_indexer.py | Adds tests for Indexer dense-causal fast path and scored top-k vs Torch reference. |
| aiter/ops/triton/attention/sparse_mqa_sink.py | Python wrapper for launching the sparse MQA sink Triton kernel. |
| aiter/ops/triton/attention/dsv4_indexer.py | Python wrapper for Indexer scoring + top-k, including dense fast path. |
| aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py | Triton kernel for sparse MQA sink with per-token top-k gather and sink denominator. |
| aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py | Triton kernels for dense causal indices, scoring, and finalizing offset indices. |
| aiter/ops/triton/init.py | Registers dsv4_indexer and sparse_mqa_sink for backward-compatible imports. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
This was referenced May 2, 2026
# Conflicts: # aiter/ops/topk.py # csrc/include/rocm_ops.hpp # csrc/include/topk_per_row.h # csrc/kernels/topk_per_row_kernels.cu # op_tests/test_topk_per_row.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
DSv4 uses a sparse attention path where each query gathers a small top-k set of compressed KV entries, plus an Indexer path that scores compressed KV entries to produce those top-k indices.
The current ATOM DSv4 integration has correctness-first Torch fallbacks for both paths. Those fallbacks materialize large intermediate tensors and are too slow for serving, especially at
conc > 1. This PR adds AITER Triton kernels for the DSv4 sparse MQA attention sink path and the DSv4 indexer scorer/top-k path so ATOM can avoid the Torch fallback.Technical Details
This PR adds:
sparse_mqa_sink: DSv4 sparse MQA forward with attention-sink denominator semantics.dsv4_indexer_topk: DSv4 Indexer scorer and causal top-k selection without materializing the Torch fallback’s[tokens, heads, committed_kv]score tensor.actual_topk == n_committed, which is common forshort-context DSv4 serving.
The sparse attention kernel supports DSv4’s MQA layout:
q:[num_tokens, num_heads, head_dim]kv:[num_blocks, block_size, head_dim]topk_indices:[num_tokens, topk]attn_sink:[num_heads]The Indexer kernel computes:
score[t, k] = sum_h relu(q[t, h] @ kv[k]) * weights[t, h]then applies the DSv4 causal compressed-token mask and returns offset top-k indices for the downstream sparse attention gather.Relevant downstream integration target: ROCm/ATOM DeepSeek-V4 PR650.
Test Plan
Test Result
Local syntax/import validation passed with:
The branch is clean against current ROCm/aiter:main and contains only the DSv4 sparse/indexer kernel additions plus tests.
Was tested and is being used at SemiAnalysisAI/InferenceX #1229, with runs: https://github.com/SemiAnalysisAI/InferenceX/actions/runs/25193385172
op_tests results: https://github.com/SemiAnalysisAI/InferenceX/actions/runs/25221896798
Submission Checklist
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.