From d3fd1243eb459f443066e508915e40307603f17c Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 01:27:29 -0500 Subject: [PATCH 1/6] CK mha bwd: add sink attention score gradient support --- aiter/ops/mha.py | 18 +- csrc/cpp_itfs/mha_bwd.cu | 2 + csrc/include/mha_bwd.h | 2 + csrc/include/rocm_ops.hpp | 8 +- csrc/include/torch/mha_bwd.h | 4 +- csrc/include/torch/mha_varlen_bwd.h | 4 +- csrc/py_itfs_ck/mha_bwd_kernels.cu | 9 +- csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | 7 +- .../fmha_bwd_pre_post_kernel_generate.py | 1 + op_tests/test_mha_sink_bwd.py | 294 ++++++++++++++++++ 10 files changed, 336 insertions(+), 13 deletions(-) create mode 100644 op_tests/test_mha_sink_bwd.py diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 2424cb7955..02e9bba651 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -624,7 +624,8 @@ def cmdGenFunc_mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ): md_name = "mha_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -775,7 +776,8 @@ def gen_mha_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv) @@ -807,7 +809,8 @@ def mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... @@ -889,7 +892,8 @@ def cmdGenFunc_mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> dict[str, Any]: md_name = "mha_varlen_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -1117,7 +1121,8 @@ def gen_mha_varlen_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return gen_mha_varlen_bwd_fake_tensors_common( q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv @@ -1157,7 +1162,8 @@ def mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 00952b7f06..b96c28ce7e 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -168,6 +168,8 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* dv_ptr */ a.dv_ptr, /* dbias_ptr */ a.dbias_ptr, /* dq_acc_ptr */ a.dq_acc_ptr, + /* sink_ptr */ a.sink_ptr, + /* d_sink_ptr */ a.d_sink_ptr, /* seqstart_q_ptr */ a.seqstart_q_ptr, /* seqstart_k_ptr */ a.seqstart_k_ptr, diff --git a/csrc/include/mha_bwd.h b/csrc/include/mha_bwd.h index 1ec94cb10e..235564aed3 100644 --- a/csrc/include/mha_bwd.h +++ b/csrc/include/mha_bwd.h @@ -47,6 +47,8 @@ struct mha_bwd_args void* dv_ptr; void* dbias_ptr; void* dq_acc_ptr; + const void* sink_ptr = nullptr; // sink scores [batch, nhead] log-space (LSEDataType=float); nullptr disables sink + void* d_sink_ptr = nullptr; // sink gradient accumulator [nhead] (LSEDataType=float); nullptr disables sink grad // Usage notes for sequence length pointer parameters: // // [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 76c03b4331..0b0c032670 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -873,7 +873,9 @@ namespace py = pybind11; py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("rng_state") = std::nullopt, \ - py::arg("gen") = std::nullopt); + py::arg("gen") = std::nullopt, \ + py::arg("sink") = std::nullopt, \ + py::arg("d_sink") = std::nullopt); #define MHA_FWD_ASM_PYBIND \ m.def("fmha_v3_fwd", \ @@ -1005,7 +1007,9 @@ namespace py = pybind11; py::arg("rng_state") = std::nullopt, \ py::arg("gen") = std::nullopt, \ py::arg("cu_seqlens_q_padded") = std::nullopt, \ - py::arg("cu_seqlens_k_padded") = std::nullopt); + py::arg("cu_seqlens_k_padded") = std::nullopt, \ + py::arg("sink") = std::nullopt, \ + py::arg("d_sink") = std::nullopt); #define MOE_CK_2STAGES_PYBIND \ m.def("ck_moe_stage1", \ diff --git a/csrc/include/torch/mha_bwd.h b/csrc/include/torch/mha_bwd.h index 5b1ea2c098..ce8fcf8ca1 100644 --- a/csrc/include/torch/mha_bwd.h +++ b/csrc/include/torch/mha_bwd.h @@ -24,6 +24,8 @@ std::vector mha_bwd(const at::Tensor& dout, // [b, sq, hq, d] std::optional bias_, // [sq, sk] std::optional alibi_slopes, // [hq] or [b, hq] std::optional rng_state, - std::optional gen); + std::optional gen, + std::optional sink, // [b, hq] log-space sink scores (float) + std::optional d_sink); // [hq] sink gradient output (float) } // namespace torch_itfs } // namespace aiter diff --git a/csrc/include/torch/mha_varlen_bwd.h b/csrc/include/torch/mha_varlen_bwd.h index ac78ec2fb3..e5ba6f0754 100644 --- a/csrc/include/torch/mha_varlen_bwd.h +++ b/csrc/include/torch/mha_varlen_bwd.h @@ -30,7 +30,9 @@ mha_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d] std::optional rng_state, std::optional gen, std::optional cu_seqlens_q_padded, // [b+1] - std::optional cu_seqlens_k_padded // [b+1] + std::optional cu_seqlens_k_padded, // [b+1] + std::optional sink, // [b, hq] log-space sink scores (float) + std::optional d_sink // [hq] sink gradient output (float) ); } // namespace torch_itfs } // namespace aiter diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu index d5f05f4e46..aba8755f5a 100644 --- a/csrc/py_itfs_ck/mha_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu @@ -31,7 +31,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] std::optional bias_, // [sq, sk] std::optional alibi_slopes_, // [hq] or [b, hq] std::optional rng_state_, - std::optional gen_) + std::optional gen_, + std::optional sink_, // [b, hq] log-space sink scores (float) + std::optional d_sink_) // [hq] sink gradient output (float) { if (is_causal) { window_size_right = 0; } @@ -198,6 +200,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] hipLaunchKernelGGL( aiter::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, reinterpret_cast(rng_state.data_ptr())); + } else { + // No dropout: allocate a dummy tensor so data_ptr() is always valid. + rng_state = torch::empty({2}, opts.dtype(torch::kInt64)); } if (seqlen_q > 0) { @@ -329,6 +334,8 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] dv_expanded.data_ptr(), dbias_ptr, dq_accum.data_ptr(), + (sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr, // sink_ptr [b, hq] + (d_sink_.has_value() && d_sink_.value().defined()) ? d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq] nullptr, // seqstart_q_ptr nullptr, // seqstart_k_ptr nullptr, // seqlen_q_ptr diff --git a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu index fc1de89635..260e3471f3 100644 --- a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu @@ -36,8 +36,9 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] std::optional rng_state_, std::optional gen_, std::optional cu_seqlens_q_padded, // [b+1] - std::optional cu_seqlens_k_padded // [b+1] - ) + std::optional cu_seqlens_k_padded, // [b+1] + std::optional sink_, // [b, hq] log-space sink scores (float) + std::optional d_sink_) // [hq] sink gradient output (float) { if (is_causal) { window_size_right = 0; } @@ -341,6 +342,8 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] dv_expanded.data_ptr(), nullptr, // dbias dq_accum.data_ptr(), // dq_acc + (sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr, // sink_ptr [b, hq] + (d_sink_.has_value() && d_sink_.value().defined()) ? d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq] seqstart_q_ptr, // seqstart_q_ptr (physical cumulative) seqstart_k_ptr, // seqstart_k_ptr (physical cumulative) nullptr, // seqlen_q_ptr (per-sequence logical) diff --git a/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py b/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py index b09a0ccd77..beef57c5e1 100644 --- a/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py +++ b/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py @@ -94,6 +94,7 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype: str) -> Optional[dict] typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::LSEDataType, /* BlockSize = */ 64, {F_hdim}, {F_mode}, diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py new file mode 100644 index 0000000000..63b0ca765b --- /dev/null +++ b/op_tests/test_mha_sink_bwd.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# Tests for mha_bwd / mha_varlen_bwd with sink gradient support. +# +# The sink_bwd feature adds two arguments to mha_bwd: +# sink : [batch, nhead] float32 – per-batch-per-head log-space sink score +# d_sink : [nhead] float32 – accumulator for the sink gradient (output) +# +# Reference formula (derived from kernel block_fmha_bwd_dot_do_o.hpp): +# D[b, h, q] = sum_j(dout[b, q, h, j] * out[b, q, h, j]) * p_undrop +# P_sink[b, h, q] = exp(sink[b, h] - lse_fwd[b, h, q]) +# d_sink[h] = sum_{b, q} (-P_sink[b, h, q] * D[b, h, q]) + +import pytest +import torch + +import aiter +from aiter import dtypes, mha_bwd, mha_fwd, mha_varlen_bwd + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): + """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" + q = torch.randn(batch, seqlen_q, nhead, hdim, device=device, dtype=dtype).requires_grad_(True) + k = torch.randn(batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype).requires_grad_(True) + v = torch.randn(batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype).requires_grad_(True) + dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) + return q, k, v, dout + + +def run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = mha_fwd( + q, k, v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def reference_d_sink(dout, out, lse, sink, p_undrop=1.0): + """ + Pure-PyTorch reference for d_sink. + + dout : [B, Sq, H, Dv] + out : [B, Sq, H, Dv] + lse : [B, H, Sq] (forward LSE without sink) + sink : [B, H] + returns d_sink : [H] + """ + # D[b, q, h] = sum_j(dout * out) * p_undrop -> shape [B, Sq, H] + D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] + # reorder to [B, H, Sq] to align with lse + D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] + + # P_sink[b, h, q] = exp(sink[b, h] - lse[b, h, q]) + sink_bhs = sink.unsqueeze(-1) # [B, H, 1] + p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] + + # d_sink[h] = sum_{b, q} (-P_sink * D) + d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] + return d_sink.float() + + +# --------------------------------------------------------------------------- +# parametrize +# --------------------------------------------------------------------------- + +DTYPES = [dtypes.fp16, dtypes.bf16] +CAUSALS = [False, True] +CONFIGS = [ + # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) + (2, 128, 128, 4, 4, 64), + (1, 64, 64, 6, 2, 128), +] + + +@pytest.mark.parametrize("causal", CAUSALS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) +def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): + """ + Verify that mha_bwd correctly accumulates d_sink. + + Strategy + -------- + 1. Run mha_fwd to obtain (out, lse). + 2. Create a random sink tensor in log-space [30, 60] and a zero d_sink buffer. + 3. Call mha_bwd with sink/d_sink. + 4. Compare the kernel d_sink with the PyTorch reference. + """ + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim ** -0.5 + + q, k, v, dout = make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + + # --- forward --- + out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + # --- sink tensors --- + # sink: [batch, nhead], uniform in [30, 60] in log-space + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + # --- backward --- + dq, dk, dv, softmax_d = mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + # d_sink must have been written (non-zero for non-trivial inputs) + assert d_sink.abs().max() > 0, "d_sink was not updated by mha_bwd" + + # --- reference --- + d_sink_ref = reference_d_sink(dout, out, lse, sink) + + # Tolerances: fp16/bf16 are noisy; use relatively loose absolute tolerance + # because sink values are large (exp() amplifies small differences) + rtol = 0.02 + atol = 0.5 # absolute tolerance in float units for d_sink + torch.testing.assert_close( + d_sink, d_sink_ref, + rtol=rtol, atol=atol, + msg=f"d_sink mismatch for dtype={dtype}, causal={causal}, " + f"B={batch}, Sq={seqlen_q}, H={nhead}" + ) + + +@pytest.mark.parametrize("causal", CAUSALS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) +def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): + """ + Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs. + + We compare mha_bwd with sink=None (baseline) against mha_bwd with a + near-zero sink (small values so the rescaling is negligible). + The gradients should be numerically close. + """ + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim ** -0.5 + + q, k, v, dout = make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + + # --- forward --- + out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + common_bwd_args = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + ) + + # baseline: no sink + dq_base, dk_base, dv_base, _ = mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, + **common_bwd_args, + ) + + # with sink = very negative values → exp(sink - lse) ≈ 0 → no effect + sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq_sink, dk_sink, dv_sink, _ = mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, + **common_bwd_args, + sink=sink_small, + d_sink=d_sink, + ) + + # With negligible sink, gradients should match the no-sink baseline + rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) + torch.testing.assert_close(dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink") + torch.testing.assert_close(dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink") + torch.testing.assert_close(dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink") + + +@pytest.mark.parametrize("dtype", DTYPES) +def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): + """Passing sink=None must give identical output to omitting sink entirely.""" + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + softmax_scale = hdim ** -0.5 + + q, k, v, dout = make_qkvo(batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device) + out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) + + common = dict( + dropout_p=0.0, softmax_scale=softmax_scale, + is_causal=False, window_size_left=-1, window_size_right=-1, + deterministic=False, + ) + + dq1, dk1, dv1, d1 = mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common) + dq2, dk2, dv2, d2 = mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common, + sink=None, d_sink=None) + + torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") + torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") + torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") + torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") + + +@pytest.mark.parametrize("dtype", DTYPES) +def test_mha_varlen_bwd_sink_dsink(dtype): + """ + Smoke test: mha_varlen_bwd with sink/d_sink produces finite, non-zero d_sink + and doesn't corrupt dQ/dK/dV shapes. + + In group (varlen) mode the CK kernel expects: + lse: [nhead, total_q] (not the batch-mode [batch, nhead, seqlen]) + sink: [batch, nhead] (one log-space score per batch-head pair) + We derive these from a batch-mode forward pass. + """ + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + hdim_v = hdim + softmax_scale = hdim ** -0.5 + + # build equal-length varlen inputs (no padding) + cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) + cu_seqlens_k = cu_seqlens_q.clone() + total_q = seqlen * batch + total_k = seqlen * batch + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + # forward (batch mode) → convert outputs to group-mode shapes + q_b = q.view(batch, seqlen, nhead, hdim) + k_b = k.view(batch, seqlen, nhead, hdim) + v_b = v.view(batch, seqlen, nhead, hdim_v) + out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + + out = out_b.view(total_q, nhead, hdim_v) + + # lse for group mode: [nhead, total_q] + # lse_b is [batch, nhead, seqlen]; permute to [nhead, batch, seqlen] then flatten + lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() + + # sink: [batch, nhead], moderate log-space values + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = mha_varlen_bwd( + dout, q, k, v, out, lse, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink contains non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + assert dq.shape == q.shape + assert dk.shape == k.shape + assert dv.shape == v.shape From d974b3eadabf05d2a327f67f2c362b0585728435 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 02:05:25 -0500 Subject: [PATCH 2/6] test: add varlen sink bwd tests to test_mha_sink_bwd --- op_tests/test_mha_sink_bwd.py | 136 +++++++++++++++++++++++++++++++--- 1 file changed, 126 insertions(+), 10 deletions(-) diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py index 63b0ca765b..f9bdaf956a 100644 --- a/op_tests/test_mha_sink_bwd.py +++ b/op_tests/test_mha_sink_bwd.py @@ -227,23 +227,58 @@ def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") +def reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): + """ + Reference d_sink for varlen mode. + + dout : [total_q, H, Dv] + out : [total_q, H, Dv] + lse_group : [H, total_q] – group-mode LSE (flattened across batches) + sink : [B, H] + seqlens_q : list of per-batch sequence lengths + returns d_sink : [H] + """ + nhead = sink.shape[1] + d_sink = torch.zeros(nhead, device=sink.device, dtype=torch.float32) + + offset = 0 + for b, sq in enumerate(seqlens_q): + dout_b = dout[offset:offset + sq].float() # [sq, H, Dv] + out_b = out[offset:offset + sq].float() # [sq, H, Dv] + lse_b = lse_group[:, offset:offset + sq] # [H, sq] + + # D[q, h] = sum_j(dout[q,h,j] * out[q,h,j]) + D_qh = (dout_b * out_b).sum(dim=-1) # [sq, H] + D_hq = D_qh.permute(1, 0) # [H, sq] + + # P_sink[h, q] = exp(sink[b, h] - lse[h, q]) + p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) # [H, sq] + + d_sink += (-p_sink * D_hq).sum(dim=-1) # [H] + offset += sq + + return d_sink + + @pytest.mark.parametrize("dtype", DTYPES) def test_mha_varlen_bwd_sink_dsink(dtype): """ - Smoke test: mha_varlen_bwd with sink/d_sink produces finite, non-zero d_sink - and doesn't corrupt dQ/dK/dV shapes. + Numerical correctness test: mha_varlen_bwd with sink/d_sink. + + Verifies: + 1. d_sink values match the per-batch reference computation. + 2. Handles equal-length sequences correctly. In group (varlen) mode the CK kernel expects: - lse: [nhead, total_q] (not the batch-mode [batch, nhead, seqlen]) - sink: [batch, nhead] (one log-space score per batch-head pair) - We derive these from a batch-mode forward pass. + lse : [nhead, total_q] (flattened across batches per head) + sink : [batch, nhead] (one log-space score per batch-head pair) """ device = torch.device("cuda") batch, seqlen, nhead, hdim = 2, 64, 4, 64 hdim_v = hdim softmax_scale = hdim ** -0.5 + seqlens_q = [seqlen] * batch - # build equal-length varlen inputs (no padding) cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) cu_seqlens_k = cu_seqlens_q.clone() total_q = seqlen * batch @@ -254,19 +289,16 @@ def test_mha_varlen_bwd_sink_dsink(dtype): v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) - # forward (batch mode) → convert outputs to group-mode shapes + # forward (batch mode) → convert to group-mode shapes q_b = q.view(batch, seqlen, nhead, hdim) k_b = k.view(batch, seqlen, nhead, hdim) v_b = v.view(batch, seqlen, nhead, hdim_v) out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) out = out_b.view(total_q, nhead, hdim_v) - # lse for group mode: [nhead, total_q] - # lse_b is [batch, nhead, seqlen]; permute to [nhead, batch, seqlen] then flatten lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() - # sink: [batch, nhead], moderate log-space values sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) @@ -292,3 +324,87 @@ def test_mha_varlen_bwd_sink_dsink(dtype): assert dq.shape == q.shape assert dk.shape == k.shape assert dv.shape == v.shape + + # numerical correctness vs reference + d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close(d_sink, d_sink_ref, rtol=0.02, atol=0.5, + msg="varlen d_sink mismatch vs reference") + + +@pytest.mark.parametrize("dtype", DTYPES) +def test_mha_varlen_bwd_sink_variable_lengths(dtype): + """ + Varlen sink test with variable-length sequences per batch entry. + + Ensures: + - Kernel correctly uses seqstart_q to determine per-batch sink values. + - d_sink accumulates correctly across batches with different lengths. + """ + device = torch.device("cuda") + nhead, hdim = 4, 64 + hdim_v = hdim + softmax_scale = hdim ** -0.5 + + # variable lengths: batch 0 has 48 tokens, batch 1 has 80 tokens + seqlens_q = [48, 80] + seqlens_k = [48, 80] + batch = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_sq = torch.tensor([0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), + device=device, dtype=torch.int32) + cu_sk = torch.tensor([0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), + device=device, dtype=torch.int32) + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + # forward per batch segment (different seq lengths → can't use batch mode directly) + out_parts, lse_parts = [], [] + offset_q, offset_k = 0, 0 + for sq, sk in zip(seqlens_q, seqlens_k): + q_b = q[offset_q:offset_q+sq].unsqueeze(0) + k_b = k[offset_k:offset_k+sk].unsqueeze(0) + v_b = v[offset_k:offset_k+sk].unsqueeze(0) + out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + out_parts.append(out_b.squeeze(0)) # [sq, H, Dv] + lse_parts.append(lse_b.squeeze(0).permute(1, 0)) # [sq, H] + offset_q += sq + offset_k += sk + + out = torch.cat(out_parts, dim=0) # [total_q, H, Dv] + # group-mode lse: [H, total_q] + lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() # [H, total_q] + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = mha_varlen_bwd( + dout, q, k, v, out, lse, + cu_seqlens_q=cu_sq, + cu_seqlens_k=cu_sk, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink has non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + + # reference + d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close(d_sink, d_sink_ref, rtol=0.02, atol=0.5, + msg="varlen variable-length d_sink mismatch") From b065c60461d96527831149a5f26572097f9288ea Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 15:07:22 +0800 Subject: [PATCH 3/6] Update op_tests/test_mha_sink_bwd.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- op_tests/test_mha_sink_bwd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py index f9bdaf956a..764ddb5a16 100644 --- a/op_tests/test_mha_sink_bwd.py +++ b/op_tests/test_mha_sink_bwd.py @@ -15,7 +15,6 @@ import pytest import torch -import aiter from aiter import dtypes, mha_bwd, mha_fwd, mha_varlen_bwd From a7bf4aeed1d6e37fca955e24fceab7f6a74d4a00 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 02:11:28 -0500 Subject: [PATCH 4/6] style: apply black formatting to test_mha_sink_bwd --- op_tests/test_mha_sink_bwd.py | 237 +++++++++++++++++++++++----------- 1 file changed, 162 insertions(+), 75 deletions(-) diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py index 764ddb5a16..422e35af49 100644 --- a/op_tests/test_mha_sink_bwd.py +++ b/op_tests/test_mha_sink_bwd.py @@ -17,24 +17,32 @@ from aiter import dtypes, mha_bwd, mha_fwd, mha_varlen_bwd - # --------------------------------------------------------------------------- # helpers # --------------------------------------------------------------------------- + def make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" - q = torch.randn(batch, seqlen_q, nhead, hdim, device=device, dtype=dtype).requires_grad_(True) - k = torch.randn(batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype).requires_grad_(True) - v = torch.randn(batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype).requires_grad_(True) - dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) + q = torch.randn( + batch, seqlen_q, nhead, hdim, device=device, dtype=dtype + ).requires_grad_(True) + k = torch.randn( + batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype + ).requires_grad_(True) + v = torch.randn( + batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype + ).requires_grad_(True) + dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) return q, k, v, dout def run_fwd(q, k, v, softmax_scale, causal): """Run mha_fwd and return (out, lse).""" out, lse, _, _ = mha_fwd( - q, k, v, + q, + k, + v, dropout_p=0.0, softmax_scale=softmax_scale, is_causal=causal, @@ -60,14 +68,14 @@ def reference_d_sink(dout, out, lse, sink, p_undrop=1.0): # D[b, q, h] = sum_j(dout * out) * p_undrop -> shape [B, Sq, H] D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] # reorder to [B, H, Sq] to align with lse - D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] + D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] # P_sink[b, h, q] = exp(sink[b, h] - lse[b, h, q]) - sink_bhs = sink.unsqueeze(-1) # [B, H, 1] - p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] + sink_bhs = sink.unsqueeze(-1) # [B, H, 1] + p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] # d_sink[h] = sum_{b, q} (-P_sink * D) - d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] + d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] return d_sink.float() @@ -75,19 +83,21 @@ def reference_d_sink(dout, out, lse, sink, p_undrop=1.0): # parametrize # --------------------------------------------------------------------------- -DTYPES = [dtypes.fp16, dtypes.bf16] -CAUSALS = [False, True] -CONFIGS = [ +DTYPES = [dtypes.fp16, dtypes.bf16] +CAUSALS = [False, True] +CONFIGS = [ # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) (2, 128, 128, 4, 4, 64), - (1, 64, 64, 6, 2, 128), + (1, 64, 64, 6, 2, 128), ] -@pytest.mark.parametrize("causal", CAUSALS) -@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("causal", CAUSALS) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) -def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): +def test_mha_bwd_sink_dsink( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): """ Verify that mha_bwd correctly accumulates d_sink. @@ -100,7 +110,7 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty """ device = torch.device("cuda") hdim_v = hdim - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 q, k, v, dout = make_qkvo( batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device @@ -111,12 +121,19 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty # --- sink tensors --- # sink: [batch, nhead], uniform in [30, 60] in log-space - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) # --- backward --- dq, dk, dv, softmax_d = mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, dropout_p=0.0, softmax_scale=softmax_scale, is_causal=causal, @@ -136,19 +153,23 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty # Tolerances: fp16/bf16 are noisy; use relatively loose absolute tolerance # because sink values are large (exp() amplifies small differences) rtol = 0.02 - atol = 0.5 # absolute tolerance in float units for d_sink + atol = 0.5 # absolute tolerance in float units for d_sink torch.testing.assert_close( - d_sink, d_sink_ref, - rtol=rtol, atol=atol, + d_sink, + d_sink_ref, + rtol=rtol, + atol=atol, msg=f"d_sink mismatch for dtype={dtype}, causal={causal}, " - f"B={batch}, Sq={seqlen_q}, H={nhead}" + f"B={batch}, Sq={seqlen_q}, H={nhead}", ) -@pytest.mark.parametrize("causal", CAUSALS) -@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("causal", CAUSALS) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) -def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): +def test_mha_bwd_with_sink_dq_dk_dv( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): """ Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs. @@ -158,7 +179,7 @@ def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, h """ device = torch.device("cuda") hdim_v = hdim - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 q, k, v, dout = make_qkvo( batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device @@ -178,16 +199,26 @@ def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, h # baseline: no sink dq_base, dk_base, dv_base, _ = mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, **common_bwd_args, ) # with sink = very negative values → exp(sink - lse) ≈ 0 → no effect sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq_sink, dk_sink, dv_sink, _ = mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, **common_bwd_args, sink=sink_small, d_sink=d_sink, @@ -195,9 +226,15 @@ def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, h # With negligible sink, gradients should match the no-sink baseline rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) - torch.testing.assert_close(dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink") - torch.testing.assert_close(dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink") - torch.testing.assert_close(dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink") + torch.testing.assert_close( + dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink" + ) + torch.testing.assert_close( + dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink" + ) + torch.testing.assert_close( + dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink" + ) @pytest.mark.parametrize("dtype", DTYPES) @@ -205,25 +242,43 @@ def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): """Passing sink=None must give identical output to omitting sink entirely.""" device = torch.device("cuda") batch, seqlen, nhead, hdim = 2, 64, 4, 64 - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 - q, k, v, dout = make_qkvo(batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device) + q, k, v, dout = make_qkvo( + batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device + ) out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) common = dict( - dropout_p=0.0, softmax_scale=softmax_scale, - is_causal=False, window_size_left=-1, window_size_right=-1, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=False, + window_size_left=-1, + window_size_right=-1, deterministic=False, ) - dq1, dk1, dv1, d1 = mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common) - dq2, dk2, dv2, d2 = mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common, - sink=None, d_sink=None) + dq1, dk1, dv1, d1 = mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common + ) + dq2, dk2, dv2, d2 = mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common, + sink=None, + d_sink=None, + ) torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") - torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") + torch.testing.assert_close( + d1, d2, msg="softmax_d differs with sink=None vs omitted" + ) def reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): @@ -242,18 +297,18 @@ def reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): offset = 0 for b, sq in enumerate(seqlens_q): - dout_b = dout[offset:offset + sq].float() # [sq, H, Dv] - out_b = out[offset:offset + sq].float() # [sq, H, Dv] - lse_b = lse_group[:, offset:offset + sq] # [H, sq] + dout_b = dout[offset : offset + sq].float() # [sq, H, Dv] + out_b = out[offset : offset + sq].float() # [sq, H, Dv] + lse_b = lse_group[:, offset : offset + sq] # [H, sq] # D[q, h] = sum_j(dout[q,h,j] * out[q,h,j]) - D_qh = (dout_b * out_b).sum(dim=-1) # [sq, H] - D_hq = D_qh.permute(1, 0) # [H, sq] + D_qh = (dout_b * out_b).sum(dim=-1) # [sq, H] + D_hq = D_qh.permute(1, 0) # [H, sq] # P_sink[h, q] = exp(sink[b, h] - lse[h, q]) p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) # [H, sq] - d_sink += (-p_sink * D_hq).sum(dim=-1) # [H] + d_sink += (-p_sink * D_hq).sum(dim=-1) # [H] offset += sq return d_sink @@ -275,17 +330,19 @@ def test_mha_varlen_bwd_sink_dsink(dtype): device = torch.device("cuda") batch, seqlen, nhead, hdim = 2, 64, 4, 64 hdim_v = hdim - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 seqlens_q = [seqlen] * batch - cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) + cu_seqlens_q = torch.tensor( + [0, seqlen, seqlen * 2], device=device, dtype=torch.int32 + ) cu_seqlens_k = cu_seqlens_q.clone() total_q = seqlen * batch total_k = seqlen * batch - q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) - k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) - v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) # forward (batch mode) → convert to group-mode shapes @@ -298,11 +355,18 @@ def test_mha_varlen_bwd_sink_dsink(dtype): # lse for group mode: [nhead, total_q] lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, _ = mha_varlen_bwd( - dout, q, k, v, out, lse, + dout, + q, + k, + v, + out, + lse, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=seqlen, @@ -326,8 +390,13 @@ def test_mha_varlen_bwd_sink_dsink(dtype): # numerical correctness vs reference d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) - torch.testing.assert_close(d_sink, d_sink_ref, rtol=0.02, atol=0.5, - msg="varlen d_sink mismatch vs reference") + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen d_sink mismatch vs reference", + ) @pytest.mark.parametrize("dtype", DTYPES) @@ -342,7 +411,7 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): device = torch.device("cuda") nhead, hdim = 4, 64 hdim_v = hdim - softmax_scale = hdim ** -0.5 + softmax_scale = hdim**-0.5 # variable lengths: batch 0 has 48 tokens, batch 1 has 80 tokens seqlens_q = [48, 80] @@ -353,38 +422,51 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): total_q = sum(seqlens_q) total_k = sum(seqlens_k) - cu_sq = torch.tensor([0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), - device=device, dtype=torch.int32) - cu_sk = torch.tensor([0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), - device=device, dtype=torch.int32) + cu_sq = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_sk = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), + device=device, + dtype=torch.int32, + ) - q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) - k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) - v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) # forward per batch segment (different seq lengths → can't use batch mode directly) out_parts, lse_parts = [], [] offset_q, offset_k = 0, 0 for sq, sk in zip(seqlens_q, seqlens_k): - q_b = q[offset_q:offset_q+sq].unsqueeze(0) - k_b = k[offset_k:offset_k+sk].unsqueeze(0) - v_b = v[offset_k:offset_k+sk].unsqueeze(0) + q_b = q[offset_q : offset_q + sq].unsqueeze(0) + k_b = k[offset_k : offset_k + sk].unsqueeze(0) + v_b = v[offset_k : offset_k + sk].unsqueeze(0) out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) - out_parts.append(out_b.squeeze(0)) # [sq, H, Dv] - lse_parts.append(lse_b.squeeze(0).permute(1, 0)) # [sq, H] + out_parts.append(out_b.squeeze(0)) # [sq, H, Dv] + lse_parts.append(lse_b.squeeze(0).permute(1, 0)) # [sq, H] offset_q += sq offset_k += sk - out = torch.cat(out_parts, dim=0) # [total_q, H, Dv] + out = torch.cat(out_parts, dim=0) # [total_q, H, Dv] # group-mode lse: [H, total_q] lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() # [H, total_q] - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, _ = mha_varlen_bwd( - dout, q, k, v, out, lse, + dout, + q, + k, + v, + out, + lse, cu_seqlens_q=cu_sq, cu_seqlens_k=cu_sk, max_seqlen_q=max_seqlen_q, @@ -405,5 +487,10 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): # reference d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) - torch.testing.assert_close(d_sink, d_sink_ref, rtol=0.02, atol=0.5, - msg="varlen variable-length d_sink mismatch") + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen variable-length d_sink mismatch", + ) From 4ac64dbbd8a091168eb11d7895a5edd4715d52e2 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 02:23:30 -0500 Subject: [PATCH 5/6] test: move sink bwd tests into test_mha.py and test_mha_varlen.py --- op_tests/test_mha.py | 181 +++++++++++++ op_tests/test_mha_sink_bwd.py | 496 ---------------------------------- op_tests/test_mha_varlen.py | 198 ++++++++++++++ 3 files changed, 379 insertions(+), 496 deletions(-) delete mode 100644 op_tests/test_mha_sink_bwd.py diff --git a/op_tests/test_mha.py b/op_tests/test_mha.py index 7c15eff72b..4656a49aed 100644 --- a/op_tests/test_mha.py +++ b/op_tests/test_mha.py @@ -860,3 +860,184 @@ def test_flash_attn_seq_padding( df = pd.DataFrame(collected) aiter.logger.info(f"mha summary:\n{df}") + + +# --------------------------------------------------------------------------- +# Sink backward tests (mha_bwd with sink / d_sink) +# +# Reference formula (derived from kernel block_fmha_bwd_dot_do_o.hpp): +# D[b, h, q] = sum_j(dout[b, q, h, j] * out[b, q, h, j]) * p_undrop +# P_sink[b, h, q] = exp(sink[b, h] - lse_fwd[b, h, q]) +# d_sink[h] = sum_{b, q} (-P_sink[b, h, q] * D[b, h, q]) +# --------------------------------------------------------------------------- + + +def _sink_make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): + """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" + q = torch.randn(batch, seqlen_q, nhead, hdim, device=device, dtype=dtype).requires_grad_(True) + k = torch.randn(batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype).requires_grad_(True) + v = torch.randn(batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype).requires_grad_(True) + dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) + return q, k, v, dout + + +def _sink_run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = aiter.mha_fwd( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def _sink_reference_d_sink(dout, out, lse, sink, p_undrop=1.0): + """ + Pure-PyTorch reference for d_sink. + + dout : [B, Sq, H, Dv] + out : [B, Sq, H, Dv] + lse : [B, H, Sq] (forward LSE without sink) + sink : [B, H] + returns d_sink : [H] + """ + D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] + D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] + sink_bhs = sink.unsqueeze(-1) # [B, H, 1] + p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] + d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] + return d_sink.float() + + +_SINK_DTYPES = [dtypes.fp16, dtypes.bf16] +_SINK_CAUSALS = [False, True] +_SINK_CONFIGS = [ + # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) + (2, 128, 128, 4, 4, 64), + (1, 64, 64, 6, 2, 128), +] + + +@pytest.mark.parametrize("causal", _SINK_CAUSALS) +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) +def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): + """Verify that mha_bwd correctly accumulates d_sink.""" + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, softmax_d = aiter.mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert d_sink.abs().max() > 0, "d_sink was not updated by mha_bwd" + + d_sink_ref = _sink_reference_d_sink(dout, out, lse, sink) + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg=f"d_sink mismatch for dtype={dtype}, causal={causal}, B={batch}, Sq={seqlen_q}, H={nhead}", + ) + + +@pytest.mark.parametrize("causal", _SINK_CAUSALS) +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) +def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): + """Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs.""" + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + common_bwd_args = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + ) + + dq_base, dk_base, dv_base, _ = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common_bwd_args + ) + + sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq_sink, dk_sink, dv_sink, _ = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common_bwd_args, + sink=sink_small, d_sink=d_sink, + ) + + rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) + torch.testing.assert_close(dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink") + torch.testing.assert_close(dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink") + torch.testing.assert_close(dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink") + + +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): + """Passing sink=None must give identical output to omitting sink entirely.""" + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo(batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) + + common = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + ) + + dq1, dk1, dv1, d1 = aiter.mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common) + dq2, dk2, dv2, d2 = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common, sink=None, d_sink=None + ) + + torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") + torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") + torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") + torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") diff --git a/op_tests/test_mha_sink_bwd.py b/op_tests/test_mha_sink_bwd.py deleted file mode 100644 index 422e35af49..0000000000 --- a/op_tests/test_mha_sink_bwd.py +++ /dev/null @@ -1,496 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -# -# Tests for mha_bwd / mha_varlen_bwd with sink gradient support. -# -# The sink_bwd feature adds two arguments to mha_bwd: -# sink : [batch, nhead] float32 – per-batch-per-head log-space sink score -# d_sink : [nhead] float32 – accumulator for the sink gradient (output) -# -# Reference formula (derived from kernel block_fmha_bwd_dot_do_o.hpp): -# D[b, h, q] = sum_j(dout[b, q, h, j] * out[b, q, h, j]) * p_undrop -# P_sink[b, h, q] = exp(sink[b, h] - lse_fwd[b, h, q]) -# d_sink[h] = sum_{b, q} (-P_sink[b, h, q] * D[b, h, q]) - -import pytest -import torch - -from aiter import dtypes, mha_bwd, mha_fwd, mha_varlen_bwd - -# --------------------------------------------------------------------------- -# helpers -# --------------------------------------------------------------------------- - - -def make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): - """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" - q = torch.randn( - batch, seqlen_q, nhead, hdim, device=device, dtype=dtype - ).requires_grad_(True) - k = torch.randn( - batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype - ).requires_grad_(True) - v = torch.randn( - batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype - ).requires_grad_(True) - dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) - return q, k, v, dout - - -def run_fwd(q, k, v, softmax_scale, causal): - """Run mha_fwd and return (out, lse).""" - out, lse, _, _ = mha_fwd( - q, - k, - v, - dropout_p=0.0, - softmax_scale=softmax_scale, - is_causal=causal, - window_size_left=-1, - window_size_right=0 if causal else -1, - sink_size=0, - return_softmax_lse=True, - return_dropout_randval=False, - ) - return out, lse - - -def reference_d_sink(dout, out, lse, sink, p_undrop=1.0): - """ - Pure-PyTorch reference for d_sink. - - dout : [B, Sq, H, Dv] - out : [B, Sq, H, Dv] - lse : [B, H, Sq] (forward LSE without sink) - sink : [B, H] - returns d_sink : [H] - """ - # D[b, q, h] = sum_j(dout * out) * p_undrop -> shape [B, Sq, H] - D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] - # reorder to [B, H, Sq] to align with lse - D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] - - # P_sink[b, h, q] = exp(sink[b, h] - lse[b, h, q]) - sink_bhs = sink.unsqueeze(-1) # [B, H, 1] - p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] - - # d_sink[h] = sum_{b, q} (-P_sink * D) - d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] - return d_sink.float() - - -# --------------------------------------------------------------------------- -# parametrize -# --------------------------------------------------------------------------- - -DTYPES = [dtypes.fp16, dtypes.bf16] -CAUSALS = [False, True] -CONFIGS = [ - # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) - (2, 128, 128, 4, 4, 64), - (1, 64, 64, 6, 2, 128), -] - - -@pytest.mark.parametrize("causal", CAUSALS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) -def test_mha_bwd_sink_dsink( - batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal -): - """ - Verify that mha_bwd correctly accumulates d_sink. - - Strategy - -------- - 1. Run mha_fwd to obtain (out, lse). - 2. Create a random sink tensor in log-space [30, 60] and a zero d_sink buffer. - 3. Call mha_bwd with sink/d_sink. - 4. Compare the kernel d_sink with the PyTorch reference. - """ - device = torch.device("cuda") - hdim_v = hdim - softmax_scale = hdim**-0.5 - - q, k, v, dout = make_qkvo( - batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device - ) - - # --- forward --- - out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) - - # --- sink tensors --- - # sink: [batch, nhead], uniform in [30, 60] in log-space - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( - 30.0, 60.0 - ) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) - - # --- backward --- - dq, dk, dv, softmax_d = mha_bwd( - dout, - q.detach(), - k.detach(), - v.detach(), - out, - lse, - dropout_p=0.0, - softmax_scale=softmax_scale, - is_causal=causal, - window_size_left=-1, - window_size_right=0 if causal else -1, - deterministic=False, - sink=sink, - d_sink=d_sink, - ) - - # d_sink must have been written (non-zero for non-trivial inputs) - assert d_sink.abs().max() > 0, "d_sink was not updated by mha_bwd" - - # --- reference --- - d_sink_ref = reference_d_sink(dout, out, lse, sink) - - # Tolerances: fp16/bf16 are noisy; use relatively loose absolute tolerance - # because sink values are large (exp() amplifies small differences) - rtol = 0.02 - atol = 0.5 # absolute tolerance in float units for d_sink - torch.testing.assert_close( - d_sink, - d_sink_ref, - rtol=rtol, - atol=atol, - msg=f"d_sink mismatch for dtype={dtype}, causal={causal}, " - f"B={batch}, Sq={seqlen_q}, H={nhead}", - ) - - -@pytest.mark.parametrize("causal", CAUSALS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", CONFIGS) -def test_mha_bwd_with_sink_dq_dk_dv( - batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal -): - """ - Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs. - - We compare mha_bwd with sink=None (baseline) against mha_bwd with a - near-zero sink (small values so the rescaling is negligible). - The gradients should be numerically close. - """ - device = torch.device("cuda") - hdim_v = hdim - softmax_scale = hdim**-0.5 - - q, k, v, dout = make_qkvo( - batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device - ) - - # --- forward --- - out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) - - common_bwd_args = dict( - dropout_p=0.0, - softmax_scale=softmax_scale, - is_causal=causal, - window_size_left=-1, - window_size_right=0 if causal else -1, - deterministic=False, - ) - - # baseline: no sink - dq_base, dk_base, dv_base, _ = mha_bwd( - dout, - q.detach(), - k.detach(), - v.detach(), - out, - lse, - **common_bwd_args, - ) - - # with sink = very negative values → exp(sink - lse) ≈ 0 → no effect - sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) - - dq_sink, dk_sink, dv_sink, _ = mha_bwd( - dout, - q.detach(), - k.detach(), - v.detach(), - out, - lse, - **common_bwd_args, - sink=sink_small, - d_sink=d_sink, - ) - - # With negligible sink, gradients should match the no-sink baseline - rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) - torch.testing.assert_close( - dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink" - ) - torch.testing.assert_close( - dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink" - ) - torch.testing.assert_close( - dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink" - ) - - -@pytest.mark.parametrize("dtype", DTYPES) -def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): - """Passing sink=None must give identical output to omitting sink entirely.""" - device = torch.device("cuda") - batch, seqlen, nhead, hdim = 2, 64, 4, 64 - softmax_scale = hdim**-0.5 - - q, k, v, dout = make_qkvo( - batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device - ) - out, lse = run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) - - common = dict( - dropout_p=0.0, - softmax_scale=softmax_scale, - is_causal=False, - window_size_left=-1, - window_size_right=-1, - deterministic=False, - ) - - dq1, dk1, dv1, d1 = mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, **common - ) - dq2, dk2, dv2, d2 = mha_bwd( - dout, - q.detach(), - k.detach(), - v.detach(), - out, - lse, - **common, - sink=None, - d_sink=None, - ) - - torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") - torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") - torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") - torch.testing.assert_close( - d1, d2, msg="softmax_d differs with sink=None vs omitted" - ) - - -def reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): - """ - Reference d_sink for varlen mode. - - dout : [total_q, H, Dv] - out : [total_q, H, Dv] - lse_group : [H, total_q] – group-mode LSE (flattened across batches) - sink : [B, H] - seqlens_q : list of per-batch sequence lengths - returns d_sink : [H] - """ - nhead = sink.shape[1] - d_sink = torch.zeros(nhead, device=sink.device, dtype=torch.float32) - - offset = 0 - for b, sq in enumerate(seqlens_q): - dout_b = dout[offset : offset + sq].float() # [sq, H, Dv] - out_b = out[offset : offset + sq].float() # [sq, H, Dv] - lse_b = lse_group[:, offset : offset + sq] # [H, sq] - - # D[q, h] = sum_j(dout[q,h,j] * out[q,h,j]) - D_qh = (dout_b * out_b).sum(dim=-1) # [sq, H] - D_hq = D_qh.permute(1, 0) # [H, sq] - - # P_sink[h, q] = exp(sink[b, h] - lse[h, q]) - p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) # [H, sq] - - d_sink += (-p_sink * D_hq).sum(dim=-1) # [H] - offset += sq - - return d_sink - - -@pytest.mark.parametrize("dtype", DTYPES) -def test_mha_varlen_bwd_sink_dsink(dtype): - """ - Numerical correctness test: mha_varlen_bwd with sink/d_sink. - - Verifies: - 1. d_sink values match the per-batch reference computation. - 2. Handles equal-length sequences correctly. - - In group (varlen) mode the CK kernel expects: - lse : [nhead, total_q] (flattened across batches per head) - sink : [batch, nhead] (one log-space score per batch-head pair) - """ - device = torch.device("cuda") - batch, seqlen, nhead, hdim = 2, 64, 4, 64 - hdim_v = hdim - softmax_scale = hdim**-0.5 - seqlens_q = [seqlen] * batch - - cu_seqlens_q = torch.tensor( - [0, seqlen, seqlen * 2], device=device, dtype=torch.int32 - ) - cu_seqlens_k = cu_seqlens_q.clone() - total_q = seqlen * batch - total_k = seqlen * batch - - q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) - k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) - v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) - dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) - - # forward (batch mode) → convert to group-mode shapes - q_b = q.view(batch, seqlen, nhead, hdim) - k_b = k.view(batch, seqlen, nhead, hdim) - v_b = v.view(batch, seqlen, nhead, hdim_v) - out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) - - out = out_b.view(total_q, nhead, hdim_v) - # lse for group mode: [nhead, total_q] - lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() - - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( - 30.0, 60.0 - ) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) - - dq, dk, dv, _ = mha_varlen_bwd( - dout, - q, - k, - v, - out, - lse, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=seqlen, - max_seqlen_k=seqlen, - dropout_p=0.0, - softmax_scale=softmax_scale, - zero_tensors=False, - is_causal=False, - window_size_left=-1, - window_size_right=-1, - deterministic=False, - sink=sink, - d_sink=d_sink, - ) - - assert torch.isfinite(d_sink).all(), f"d_sink contains non-finite values: {d_sink}" - assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" - assert dq.shape == q.shape - assert dk.shape == k.shape - assert dv.shape == v.shape - - # numerical correctness vs reference - d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) - torch.testing.assert_close( - d_sink, - d_sink_ref, - rtol=0.02, - atol=0.5, - msg="varlen d_sink mismatch vs reference", - ) - - -@pytest.mark.parametrize("dtype", DTYPES) -def test_mha_varlen_bwd_sink_variable_lengths(dtype): - """ - Varlen sink test with variable-length sequences per batch entry. - - Ensures: - - Kernel correctly uses seqstart_q to determine per-batch sink values. - - d_sink accumulates correctly across batches with different lengths. - """ - device = torch.device("cuda") - nhead, hdim = 4, 64 - hdim_v = hdim - softmax_scale = hdim**-0.5 - - # variable lengths: batch 0 has 48 tokens, batch 1 has 80 tokens - seqlens_q = [48, 80] - seqlens_k = [48, 80] - batch = len(seqlens_q) - max_seqlen_q = max(seqlens_q) - max_seqlen_k = max(seqlens_k) - total_q = sum(seqlens_q) - total_k = sum(seqlens_k) - - cu_sq = torch.tensor( - [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), - device=device, - dtype=torch.int32, - ) - cu_sk = torch.tensor( - [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), - device=device, - dtype=torch.int32, - ) - - q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) - k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) - v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) - dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) - - # forward per batch segment (different seq lengths → can't use batch mode directly) - out_parts, lse_parts = [], [] - offset_q, offset_k = 0, 0 - for sq, sk in zip(seqlens_q, seqlens_k): - q_b = q[offset_q : offset_q + sq].unsqueeze(0) - k_b = k[offset_k : offset_k + sk].unsqueeze(0) - v_b = v[offset_k : offset_k + sk].unsqueeze(0) - out_b, lse_b = run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) - out_parts.append(out_b.squeeze(0)) # [sq, H, Dv] - lse_parts.append(lse_b.squeeze(0).permute(1, 0)) # [sq, H] - offset_q += sq - offset_k += sk - - out = torch.cat(out_parts, dim=0) # [total_q, H, Dv] - # group-mode lse: [H, total_q] - lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() # [H, total_q] - - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( - 30.0, 60.0 - ) - d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) - - dq, dk, dv, _ = mha_varlen_bwd( - dout, - q, - k, - v, - out, - lse, - cu_seqlens_q=cu_sq, - cu_seqlens_k=cu_sk, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=0.0, - softmax_scale=softmax_scale, - zero_tensors=False, - is_causal=False, - window_size_left=-1, - window_size_right=-1, - deterministic=False, - sink=sink, - d_sink=d_sink, - ) - - assert torch.isfinite(d_sink).all(), f"d_sink has non-finite values: {d_sink}" - assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" - - # reference - d_sink_ref = reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) - torch.testing.assert_close( - d_sink, - d_sink_ref, - rtol=0.02, - atol=0.5, - msg="varlen variable-length d_sink mismatch", - ) diff --git a/op_tests/test_mha_varlen.py b/op_tests/test_mha_varlen.py index 5027753822..10d673d606 100644 --- a/op_tests/test_mha_varlen.py +++ b/op_tests/test_mha_varlen.py @@ -1101,3 +1101,201 @@ def varlen_flash_attn_seq_padding_benchmark( df_padding = pd.DataFrame(padding_collected) aiter.logger.info(f"mha_varlen_seq_padding summary:\n{df_padding}") + + +# --------------------------------------------------------------------------- +# Sink backward tests (mha_varlen_bwd with sink / d_sink) +# --------------------------------------------------------------------------- + + +def _vsink_run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = aiter.mha_fwd( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def _vsink_reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): + """ + Reference d_sink for varlen mode. + + dout : [total_q, H, Dv] + out : [total_q, H, Dv] + lse_group : [H, total_q] – group-mode LSE (flattened across batches) + sink : [B, H] + seqlens_q : list of per-batch sequence lengths + returns d_sink : [H] + """ + nhead = sink.shape[1] + d_sink = torch.zeros(nhead, device=sink.device, dtype=torch.float32) + + offset = 0 + for b, sq in enumerate(seqlens_q): + dout_b = dout[offset : offset + sq].float() + out_b = out[offset : offset + sq].float() + lse_b = lse_group[:, offset : offset + sq] + + D_qh = (dout_b * out_b).sum(dim=-1) + D_hq = D_qh.permute(1, 0) + p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) + d_sink += (-p_sink * D_hq).sum(dim=-1) + offset += sq + + return d_sink + + +_VSINK_DTYPES = [dtypes.fp16, dtypes.bf16] + + +@pytest.mark.parametrize("dtype", _VSINK_DTYPES) +def test_mha_varlen_bwd_sink_dsink(dtype): + """Numerical correctness test: mha_varlen_bwd with sink/d_sink (equal-length sequences).""" + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + hdim_v = hdim + softmax_scale = hdim**-0.5 + seqlens_q = [seqlen] * batch + + cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) + cu_seqlens_k = cu_seqlens_q.clone() + total_q = seqlen * batch + total_k = seqlen * batch + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + q_b = q.view(batch, seqlen, nhead, hdim) + k_b = k.view(batch, seqlen, nhead, hdim) + v_b = v.view(batch, seqlen, nhead, hdim_v) + out_b, lse_b = _vsink_run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + + out = out_b.view(total_q, nhead, hdim_v) + lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = aiter.mha_varlen_bwd( + dout, + q, + k, + v, + out, + lse, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink contains non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + assert dq.shape == q.shape + assert dk.shape == k.shape + assert dv.shape == v.shape + + d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close( + d_sink, d_sink_ref, rtol=0.02, atol=0.5, msg="varlen d_sink mismatch vs reference" + ) + + +@pytest.mark.parametrize("dtype", _VSINK_DTYPES) +def test_mha_varlen_bwd_sink_variable_lengths(dtype): + """Varlen sink test with variable-length sequences per batch entry.""" + device = torch.device("cuda") + nhead, hdim = 4, 64 + hdim_v = hdim + softmax_scale = hdim**-0.5 + + seqlens_q = [48, 80] + seqlens_k = [48, 80] + batch = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_sq = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), + device=device, dtype=torch.int32, + ) + cu_sk = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), + device=device, dtype=torch.int32, + ) + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + out_parts, lse_parts = [], [] + offset_q, offset_k = 0, 0 + for sq, sk in zip(seqlens_q, seqlens_k): + q_b = q[offset_q : offset_q + sq].unsqueeze(0) + k_b = k[offset_k : offset_k + sk].unsqueeze(0) + v_b = v[offset_k : offset_k + sk].unsqueeze(0) + out_b, lse_b = _vsink_run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + out_parts.append(out_b.squeeze(0)) + lse_parts.append(lse_b.squeeze(0).permute(1, 0)) + offset_q += sq + offset_k += sk + + out = torch.cat(out_parts, dim=0) + lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = aiter.mha_varlen_bwd( + dout, + q, + k, + v, + out, + lse, + cu_seqlens_q=cu_sq, + cu_seqlens_k=cu_sk, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink has non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + + d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close( + d_sink, d_sink_ref, rtol=0.02, atol=0.5, msg="varlen variable-length d_sink mismatch" + ) From 3690ad19d5784faf2f928fc5a64d916e048c494a Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 18 Mar 2026 02:31:48 -0500 Subject: [PATCH 6/6] style: apply black formatting to sink bwd tests in test_mha and test_mha_varlen --- op_tests/test_mha.py | 73 +++++++++++++++++++++++++++++-------- op_tests/test_mha_varlen.py | 30 +++++++++++---- 2 files changed, 80 insertions(+), 23 deletions(-) diff --git a/op_tests/test_mha.py b/op_tests/test_mha.py index 4656a49aed..47b5342cfe 100644 --- a/op_tests/test_mha.py +++ b/op_tests/test_mha.py @@ -872,11 +872,19 @@ def test_flash_attn_seq_padding( # --------------------------------------------------------------------------- -def _sink_make_qkvo(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device): +def _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device +): """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" - q = torch.randn(batch, seqlen_q, nhead, hdim, device=device, dtype=dtype).requires_grad_(True) - k = torch.randn(batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype).requires_grad_(True) - v = torch.randn(batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype).requires_grad_(True) + q = torch.randn( + batch, seqlen_q, nhead, hdim, device=device, dtype=dtype + ).requires_grad_(True) + k = torch.randn( + batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype + ).requires_grad_(True) + v = torch.randn( + batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype + ).requires_grad_(True) dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) return q, k, v, dout @@ -929,7 +937,9 @@ def _sink_reference_d_sink(dout, out, lse, sink, p_undrop=1.0): @pytest.mark.parametrize("causal", _SINK_CAUSALS) @pytest.mark.parametrize("dtype", _SINK_DTYPES) @pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) -def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): +def test_mha_bwd_sink_dsink( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): """Verify that mha_bwd correctly accumulates d_sink.""" device = torch.device("cuda") hdim_v = hdim @@ -940,7 +950,9 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty ) out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, softmax_d = aiter.mha_bwd( @@ -975,7 +987,9 @@ def test_mha_bwd_sink_dsink(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dty @pytest.mark.parametrize("causal", _SINK_CAUSALS) @pytest.mark.parametrize("dtype", _SINK_DTYPES) @pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) -def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal): +def test_mha_bwd_with_sink_dq_dk_dv( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): """Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs.""" device = torch.device("cuda") hdim_v = hdim @@ -1003,14 +1017,27 @@ def test_mha_bwd_with_sink_dq_dk_dv(batch, seqlen_q, seqlen_k, nhead, nhead_k, h d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq_sink, dk_sink, dv_sink, _ = aiter.mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, **common_bwd_args, - sink=sink_small, d_sink=d_sink, + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common_bwd_args, + sink=sink_small, + d_sink=d_sink, ) rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) - torch.testing.assert_close(dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink") - torch.testing.assert_close(dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink") - torch.testing.assert_close(dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink") + torch.testing.assert_close( + dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink" + ) + torch.testing.assert_close( + dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink" + ) + torch.testing.assert_close( + dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink" + ) @pytest.mark.parametrize("dtype", _SINK_DTYPES) @@ -1020,7 +1047,9 @@ def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): batch, seqlen, nhead, hdim = 2, 64, 4, 64 softmax_scale = hdim**-0.5 - q, k, v, dout = _sink_make_qkvo(batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device) + q, k, v, dout = _sink_make_qkvo( + batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device + ) out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) common = dict( @@ -1032,12 +1061,24 @@ def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): deterministic=False, ) - dq1, dk1, dv1, d1 = aiter.mha_bwd(dout, q.detach(), k.detach(), v.detach(), out, lse, **common) + dq1, dk1, dv1, d1 = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common + ) dq2, dk2, dv2, d2 = aiter.mha_bwd( - dout, q.detach(), k.detach(), v.detach(), out, lse, **common, sink=None, d_sink=None + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common, + sink=None, + d_sink=None, ) torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") - torch.testing.assert_close(d1, d2, msg="softmax_d differs with sink=None vs omitted") + torch.testing.assert_close( + d1, d2, msg="softmax_d differs with sink=None vs omitted" + ) diff --git a/op_tests/test_mha_varlen.py b/op_tests/test_mha_varlen.py index 10d673d606..a523580184 100644 --- a/op_tests/test_mha_varlen.py +++ b/op_tests/test_mha_varlen.py @@ -1167,7 +1167,9 @@ def test_mha_varlen_bwd_sink_dsink(dtype): softmax_scale = hdim**-0.5 seqlens_q = [seqlen] * batch - cu_seqlens_q = torch.tensor([0, seqlen, seqlen * 2], device=device, dtype=torch.int32) + cu_seqlens_q = torch.tensor( + [0, seqlen, seqlen * 2], device=device, dtype=torch.int32 + ) cu_seqlens_k = cu_seqlens_q.clone() total_q = seqlen * batch total_k = seqlen * batch @@ -1185,7 +1187,9 @@ def test_mha_varlen_bwd_sink_dsink(dtype): out = out_b.view(total_q, nhead, hdim_v) lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, _ = aiter.mha_varlen_bwd( @@ -1218,7 +1222,11 @@ def test_mha_varlen_bwd_sink_dsink(dtype): d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) torch.testing.assert_close( - d_sink, d_sink_ref, rtol=0.02, atol=0.5, msg="varlen d_sink mismatch vs reference" + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen d_sink mismatch vs reference", ) @@ -1240,11 +1248,13 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): cu_sq = torch.tensor( [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), - device=device, dtype=torch.int32, + device=device, + dtype=torch.int32, ) cu_sk = torch.tensor( [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), - device=device, dtype=torch.int32, + device=device, + dtype=torch.int32, ) q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) @@ -1267,7 +1277,9 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): out = torch.cat(out_parts, dim=0) lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() - sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_(30.0, 60.0) + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) dq, dk, dv, _ = aiter.mha_varlen_bwd( @@ -1297,5 +1309,9 @@ def test_mha_varlen_bwd_sink_variable_lengths(dtype): d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) torch.testing.assert_close( - d_sink, d_sink_ref, rtol=0.02, atol=0.5, msg="varlen variable-length d_sink mismatch" + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen variable-length d_sink mismatch", )