diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 2424cb7955..02e9bba651 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -624,7 +624,8 @@ def cmdGenFunc_mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ): md_name = "mha_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -775,7 +776,8 @@ def gen_mha_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv) @@ -807,7 +809,8 @@ def mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... @@ -889,7 +892,8 @@ def cmdGenFunc_mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> dict[str, Any]: md_name = "mha_varlen_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -1117,7 +1121,8 @@ def gen_mha_varlen_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return gen_mha_varlen_bwd_fake_tensors_common( q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv @@ -1157,7 +1162,8 @@ def mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 00952b7f06..b96c28ce7e 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -168,6 +168,8 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* dv_ptr */ a.dv_ptr, /* dbias_ptr */ a.dbias_ptr, /* dq_acc_ptr */ a.dq_acc_ptr, + /* sink_ptr */ a.sink_ptr, + /* d_sink_ptr */ a.d_sink_ptr, /* seqstart_q_ptr */ a.seqstart_q_ptr, /* seqstart_k_ptr */ a.seqstart_k_ptr, diff --git a/csrc/include/mha_bwd.h b/csrc/include/mha_bwd.h index 1ec94cb10e..235564aed3 100644 --- a/csrc/include/mha_bwd.h +++ b/csrc/include/mha_bwd.h @@ -47,6 +47,8 @@ struct mha_bwd_args void* dv_ptr; void* dbias_ptr; void* dq_acc_ptr; + const void* sink_ptr = nullptr; // sink scores [batch, nhead] log-space (LSEDataType=float); nullptr disables sink + void* d_sink_ptr = nullptr; // sink gradient accumulator [nhead] (LSEDataType=float); nullptr disables sink grad // Usage notes for sequence length pointer parameters: // // [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 76c03b4331..0b0c032670 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -873,7 +873,9 @@ namespace py = pybind11; py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("rng_state") = std::nullopt, \ - py::arg("gen") = std::nullopt); + py::arg("gen") = std::nullopt, \ + py::arg("sink") = std::nullopt, \ + py::arg("d_sink") = std::nullopt); #define MHA_FWD_ASM_PYBIND \ m.def("fmha_v3_fwd", \ @@ -1005,7 +1007,9 @@ namespace py = pybind11; py::arg("rng_state") = std::nullopt, \ py::arg("gen") = std::nullopt, \ py::arg("cu_seqlens_q_padded") = std::nullopt, \ - py::arg("cu_seqlens_k_padded") = std::nullopt); + py::arg("cu_seqlens_k_padded") = std::nullopt, \ + py::arg("sink") = std::nullopt, \ + py::arg("d_sink") = std::nullopt); #define MOE_CK_2STAGES_PYBIND \ m.def("ck_moe_stage1", \ diff --git a/csrc/include/torch/mha_bwd.h b/csrc/include/torch/mha_bwd.h index 5b1ea2c098..ce8fcf8ca1 100644 --- a/csrc/include/torch/mha_bwd.h +++ b/csrc/include/torch/mha_bwd.h @@ -24,6 +24,8 @@ std::vector mha_bwd(const at::Tensor& dout, // [b, sq, hq, d] std::optional bias_, // [sq, sk] std::optional alibi_slopes, // [hq] or [b, hq] std::optional rng_state, - std::optional gen); + std::optional gen, + std::optional sink, // [b, hq] log-space sink scores (float) + std::optional d_sink); // [hq] sink gradient output (float) } // namespace torch_itfs } // namespace aiter diff --git a/csrc/include/torch/mha_varlen_bwd.h b/csrc/include/torch/mha_varlen_bwd.h index ac78ec2fb3..e5ba6f0754 100644 --- a/csrc/include/torch/mha_varlen_bwd.h +++ b/csrc/include/torch/mha_varlen_bwd.h @@ -30,7 +30,9 @@ mha_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d] std::optional rng_state, std::optional gen, std::optional cu_seqlens_q_padded, // [b+1] - std::optional cu_seqlens_k_padded // [b+1] + std::optional cu_seqlens_k_padded, // [b+1] + std::optional sink, // [b, hq] log-space sink scores (float) + std::optional d_sink // [hq] sink gradient output (float) ); } // namespace torch_itfs } // namespace aiter diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu index d5f05f4e46..aba8755f5a 100644 --- a/csrc/py_itfs_ck/mha_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu @@ -31,7 +31,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] std::optional bias_, // [sq, sk] std::optional alibi_slopes_, // [hq] or [b, hq] std::optional rng_state_, - std::optional gen_) + std::optional gen_, + std::optional sink_, // [b, hq] log-space sink scores (float) + std::optional d_sink_) // [hq] sink gradient output (float) { if (is_causal) { window_size_right = 0; } @@ -198,6 +200,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] hipLaunchKernelGGL( aiter::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, reinterpret_cast(rng_state.data_ptr())); + } else { + // No dropout: allocate a dummy tensor so data_ptr() is always valid. + rng_state = torch::empty({2}, opts.dtype(torch::kInt64)); } if (seqlen_q > 0) { @@ -329,6 +334,8 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] dv_expanded.data_ptr(), dbias_ptr, dq_accum.data_ptr(), + (sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr, // sink_ptr [b, hq] + (d_sink_.has_value() && d_sink_.value().defined()) ? d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq] nullptr, // seqstart_q_ptr nullptr, // seqstart_k_ptr nullptr, // seqlen_q_ptr diff --git a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu index fc1de89635..260e3471f3 100644 --- a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu @@ -36,8 +36,9 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] std::optional rng_state_, std::optional gen_, std::optional cu_seqlens_q_padded, // [b+1] - std::optional cu_seqlens_k_padded // [b+1] - ) + std::optional cu_seqlens_k_padded, // [b+1] + std::optional sink_, // [b, hq] log-space sink scores (float) + std::optional d_sink_) // [hq] sink gradient output (float) { if (is_causal) { window_size_right = 0; } @@ -341,6 +342,8 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] dv_expanded.data_ptr(), nullptr, // dbias dq_accum.data_ptr(), // dq_acc + (sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr, // sink_ptr [b, hq] + (d_sink_.has_value() && d_sink_.value().defined()) ? d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq] seqstart_q_ptr, // seqstart_q_ptr (physical cumulative) seqstart_k_ptr, // seqstart_k_ptr (physical cumulative) nullptr, // seqlen_q_ptr (per-sequence logical) diff --git a/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py b/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py index b09a0ccd77..beef57c5e1 100644 --- a/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py +++ b/csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py @@ -94,6 +94,7 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype: str) -> Optional[dict] typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::LSEDataType, /* BlockSize = */ 64, {F_hdim}, {F_mode}, diff --git a/op_tests/test_mha.py b/op_tests/test_mha.py index 7c15eff72b..47b5342cfe 100644 --- a/op_tests/test_mha.py +++ b/op_tests/test_mha.py @@ -860,3 +860,225 @@ def test_flash_attn_seq_padding( df = pd.DataFrame(collected) aiter.logger.info(f"mha summary:\n{df}") + + +# --------------------------------------------------------------------------- +# Sink backward tests (mha_bwd with sink / d_sink) +# +# Reference formula (derived from kernel block_fmha_bwd_dot_do_o.hpp): +# D[b, h, q] = sum_j(dout[b, q, h, j] * out[b, q, h, j]) * p_undrop +# P_sink[b, h, q] = exp(sink[b, h] - lse_fwd[b, h, q]) +# d_sink[h] = sum_{b, q} (-P_sink[b, h, q] * D[b, h, q]) +# --------------------------------------------------------------------------- + + +def _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device +): + """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" + q = torch.randn( + batch, seqlen_q, nhead, hdim, device=device, dtype=dtype + ).requires_grad_(True) + k = torch.randn( + batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype + ).requires_grad_(True) + v = torch.randn( + batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype + ).requires_grad_(True) + dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) + return q, k, v, dout + + +def _sink_run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = aiter.mha_fwd( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def _sink_reference_d_sink(dout, out, lse, sink, p_undrop=1.0): + """ + Pure-PyTorch reference for d_sink. + + dout : [B, Sq, H, Dv] + out : [B, Sq, H, Dv] + lse : [B, H, Sq] (forward LSE without sink) + sink : [B, H] + returns d_sink : [H] + """ + D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] + D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] + sink_bhs = sink.unsqueeze(-1) # [B, H, 1] + p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] + d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] + return d_sink.float() + + +_SINK_DTYPES = [dtypes.fp16, dtypes.bf16] +_SINK_CAUSALS = [False, True] +_SINK_CONFIGS = [ + # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) + (2, 128, 128, 4, 4, 64), + (1, 64, 64, 6, 2, 128), +] + + +@pytest.mark.parametrize("causal", _SINK_CAUSALS) +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) +def test_mha_bwd_sink_dsink( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): + """Verify that mha_bwd correctly accumulates d_sink.""" + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, softmax_d = aiter.mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert d_sink.abs().max() > 0, "d_sink was not updated by mha_bwd" + + d_sink_ref = _sink_reference_d_sink(dout, out, lse, sink) + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg=f"d_sink mismatch for dtype={dtype}, causal={causal}, B={batch}, Sq={seqlen_q}, H={nhead}", + ) + + +@pytest.mark.parametrize("causal", _SINK_CAUSALS) +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) +def test_mha_bwd_with_sink_dq_dk_dv( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): + """Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs.""" + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + common_bwd_args = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + ) + + dq_base, dk_base, dv_base, _ = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common_bwd_args + ) + + sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq_sink, dk_sink, dv_sink, _ = aiter.mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common_bwd_args, + sink=sink_small, + d_sink=d_sink, + ) + + rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) + torch.testing.assert_close( + dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink" + ) + torch.testing.assert_close( + dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink" + ) + torch.testing.assert_close( + dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink" + ) + + +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): + """Passing sink=None must give identical output to omitting sink entirely.""" + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) + + common = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + ) + + dq1, dk1, dv1, d1 = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common + ) + dq2, dk2, dv2, d2 = aiter.mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common, + sink=None, + d_sink=None, + ) + + torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") + torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") + torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") + torch.testing.assert_close( + d1, d2, msg="softmax_d differs with sink=None vs omitted" + ) diff --git a/op_tests/test_mha_varlen.py b/op_tests/test_mha_varlen.py index 5027753822..a523580184 100644 --- a/op_tests/test_mha_varlen.py +++ b/op_tests/test_mha_varlen.py @@ -1101,3 +1101,217 @@ def varlen_flash_attn_seq_padding_benchmark( df_padding = pd.DataFrame(padding_collected) aiter.logger.info(f"mha_varlen_seq_padding summary:\n{df_padding}") + + +# --------------------------------------------------------------------------- +# Sink backward tests (mha_varlen_bwd with sink / d_sink) +# --------------------------------------------------------------------------- + + +def _vsink_run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = aiter.mha_fwd( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def _vsink_reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): + """ + Reference d_sink for varlen mode. + + dout : [total_q, H, Dv] + out : [total_q, H, Dv] + lse_group : [H, total_q] – group-mode LSE (flattened across batches) + sink : [B, H] + seqlens_q : list of per-batch sequence lengths + returns d_sink : [H] + """ + nhead = sink.shape[1] + d_sink = torch.zeros(nhead, device=sink.device, dtype=torch.float32) + + offset = 0 + for b, sq in enumerate(seqlens_q): + dout_b = dout[offset : offset + sq].float() + out_b = out[offset : offset + sq].float() + lse_b = lse_group[:, offset : offset + sq] + + D_qh = (dout_b * out_b).sum(dim=-1) + D_hq = D_qh.permute(1, 0) + p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) + d_sink += (-p_sink * D_hq).sum(dim=-1) + offset += sq + + return d_sink + + +_VSINK_DTYPES = [dtypes.fp16, dtypes.bf16] + + +@pytest.mark.parametrize("dtype", _VSINK_DTYPES) +def test_mha_varlen_bwd_sink_dsink(dtype): + """Numerical correctness test: mha_varlen_bwd with sink/d_sink (equal-length sequences).""" + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + hdim_v = hdim + softmax_scale = hdim**-0.5 + seqlens_q = [seqlen] * batch + + cu_seqlens_q = torch.tensor( + [0, seqlen, seqlen * 2], device=device, dtype=torch.int32 + ) + cu_seqlens_k = cu_seqlens_q.clone() + total_q = seqlen * batch + total_k = seqlen * batch + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + q_b = q.view(batch, seqlen, nhead, hdim) + k_b = k.view(batch, seqlen, nhead, hdim) + v_b = v.view(batch, seqlen, nhead, hdim_v) + out_b, lse_b = _vsink_run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + + out = out_b.view(total_q, nhead, hdim_v) + lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = aiter.mha_varlen_bwd( + dout, + q, + k, + v, + out, + lse, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink contains non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + assert dq.shape == q.shape + assert dk.shape == k.shape + assert dv.shape == v.shape + + d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen d_sink mismatch vs reference", + ) + + +@pytest.mark.parametrize("dtype", _VSINK_DTYPES) +def test_mha_varlen_bwd_sink_variable_lengths(dtype): + """Varlen sink test with variable-length sequences per batch entry.""" + device = torch.device("cuda") + nhead, hdim = 4, 64 + hdim_v = hdim + softmax_scale = hdim**-0.5 + + seqlens_q = [48, 80] + seqlens_k = [48, 80] + batch = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_sq = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_sk = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), + device=device, + dtype=torch.int32, + ) + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + out_parts, lse_parts = [], [] + offset_q, offset_k = 0, 0 + for sq, sk in zip(seqlens_q, seqlens_k): + q_b = q[offset_q : offset_q + sq].unsqueeze(0) + k_b = k[offset_k : offset_k + sk].unsqueeze(0) + v_b = v[offset_k : offset_k + sk].unsqueeze(0) + out_b, lse_b = _vsink_run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + out_parts.append(out_b.squeeze(0)) + lse_parts.append(lse_b.squeeze(0).permute(1, 0)) + offset_q += sq + offset_k += sk + + out = torch.cat(out_parts, dim=0) + lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = aiter.mha_varlen_bwd( + dout, + q, + k, + v, + out, + lse, + cu_seqlens_q=cu_sq, + cu_seqlens_k=cu_sk, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink has non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + + d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen variable-length d_sink mismatch", + )