
CK mha bwd: add sink attention score gradient support #2321

Draft
LJ-underdog wants to merge 6 commits into main from lj_ck_sink_bwd_v2


LJ-underdog (Contributor) commented Mar 18, 2026

Motivation

This PR extends the CK-backed MHA backward paths (mha_bwd / mha_varlen_bwd) to accept sink attention log-scores and optionally accumulate a sink gradient (d_sink), and adds Python tests to validate d_sink correctness.

Technical Details

  • Plumbs sink / d_sink through the Torch C++ interfaces, pybind args, and CK kernel argument structs.
  • Updates CK kernel launch argument packing to pass sink pointers into backward kernels (batch + varlen).
  • Adds new GPU tests for mha_bwd and mha_varlen_bwd d_sink accumulation vs a PyTorch reference.
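For readers unfamiliar with the quantity being tested, the following is a minimal pure-Python sketch of what a d_sink reference check can look like. It is not the PR's test code and does not call any aiter or CK API. It assumes the common sink formulation in which a single per-head sink logit is added to the softmax denominator and contributes no value vector; under that assumption the gradient is d_sink = -Σ_i p_sink,i · (dO_i · O_i), summed over query rows. All function names here are hypothetical, and the kernel's actual layout, scaling, and accumulation order are not stated in this PR.

```python
import math


def attn_row_with_sink(scores, sink):
    """Softmax over one query row with an extra sink logit in the denominator.

    Returns (probabilities over keys, probability mass assigned to the sink).
    Assumption: the sink has no value vector, so it only absorbs mass.
    """
    m = max(max(scores), sink)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    e_sink = math.exp(sink - m)
    z = sum(exps) + e_sink
    return [e / z for e in exps], e_sink / z


def forward(scores, values, sink):
    """scores: [q][k], values: [k][d] -> attention output [q][d]."""
    dim = len(values[0])
    out = []
    for row in scores:
        p, _ = attn_row_with_sink(row, sink)
        out.append([sum(p[j] * values[j][c] for j in range(len(values)))
                    for c in range(dim)])
    return out


def d_sink_analytic(scores, values, sink, d_out):
    """Closed-form sink gradient for one head, accumulated over query rows."""
    out = forward(scores, values, sink)
    total = 0.0
    for i, row in enumerate(scores):
        _, p_sink = attn_row_with_sink(row, sink)
        delta = sum(g * o for g, o in zip(d_out[i], out[i]))  # dO_i . O_i
        total -= p_sink * delta
    return total


def d_sink_numeric(scores, values, sink, d_out, eps=1e-6):
    """Central finite difference of loss = sum(dO * O) with respect to sink."""
    def loss(s):
        out = forward(scores, values, s)
        return sum(g * o
                   for g_row, o_row in zip(d_out, out)
                   for g, o in zip(g_row, o_row))
    return (loss(sink + eps) - loss(sink - eps)) / (2 * eps)


# Tiny hand-made example: 2 queries, 3 keys, head dimension 2.
example_scores = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
example_values = [[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0]]
example_d_out = [[0.3, -0.2], [1.0, 0.5]]
example_sink = 0.25

analytic = d_sink_analytic(example_scores, example_values, example_sink, example_d_out)
numeric = d_sink_numeric(example_scores, example_values, example_sink, example_d_out)
print("analytic vs numeric difference:", abs(analytic - numeric))
```

A GPU test would follow the same pattern at larger shapes, comparing the kernel's accumulated d_sink per head against a reference computed with autograd, with a tolerance suited to the kernel's data type.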

Test Plan

Added tests in test_mha_bwd&varlen_bwd.py.

Test Result

Local tests passed.

Submission Checklist

@github-actions (Contributor)

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or `gh pr edit 2321 --add-label <label>`.

LJ-underdog and others added 2 commits March 18, 2026 02:05
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Copilot AI (Contributor) left a comment


Pull request overview

This PR extends the CK-backed MHA backward paths (mha_bwd / mha_varlen_bwd) to accept sink attention log-scores and optionally accumulate a sink gradient (d_sink), and adds Python tests to validate d_sink correctness.

Changes:

  • Plumbs sink / d_sink through the Torch C++ interfaces, pybind args, and CK kernel argument structs.
  • Updates CK kernel launch argument packing to pass sink pointers into backward kernels (batch + varlen).
  • Adds new GPU tests for mha_bwd and mha_varlen_bwd d_sink accumulation vs a PyTorch reference.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| op_tests/test_mha_sink_bwd.py | New tests validating d_sink accumulation for batch and varlen backward kernels. |
| aiter/ops/mha.py | Updates Python-exposed mha_bwd / mha_varlen_bwd signatures to accept sink / d_sink. |
| csrc/include/torch/mha_bwd.h | Extends the Torch C++ API for mha_bwd to accept sink / d_sink. |
| csrc/include/torch/mha_varlen_bwd.h | Extends the Torch C++ API for mha_varlen_bwd to accept sink / d_sink. |
| csrc/include/rocm_ops.hpp | Adds sink / d_sink parameters to the pybind signatures for backward ops. |
| csrc/include/mha_bwd.h | Extends mha_bwd_args with sink pointer fields. |
| csrc/cpp_itfs/mha_bwd.cu | Passes sink pointers into the CK fmha_bwd_args used by the non-asm path. |
| csrc/py_itfs_ck/mha_bwd_kernels.cu | Adds optional sink/d_sink plumbing to CK batch-mode backward wrapper. |
| csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | Adds optional sink/d_sink plumbing to CK varlen backward wrapper. |
| csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py | Updates codegen template to include LSEDataType in pipeline problem typing. |

