18 changes: 12 additions & 6 deletions aiter/ops/mha.py
@@ -624,7 +624,8 @@ def cmdGenFunc_mha_bwd(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
-    sink_ptr: Optional[Tensor] = None,
+    sink: Optional[Tensor] = None,
+    d_sink: Optional[Tensor] = None,
 ):
     md_name = "mha_bwd"
     filter1 = "*" # get_bwd_dot_do_o_blobs()
@@ -775,7 +776,8 @@ def gen_mha_bwd_fake_tensors(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
-    sink_ptr: Optional[Tensor] = None,
+    sink: Optional[Tensor] = None,
+    d_sink: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv)

@@ -807,7 +809,8 @@ def mha_bwd(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
-    sink_ptr: Optional[Tensor] = None,
+    sink: Optional[Tensor] = None,
+    d_sink: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


@@ -889,7 +892,8 @@ def cmdGenFunc_mha_varlen_bwd(
     gen: Optional[Generator] = None,
     cu_seqlens_q_padded: Optional[Tensor] = None,
     cu_seqlens_k_padded: Optional[Tensor] = None,
-    sink_ptr: Optional[Tensor] = None,
+    sink: Optional[Tensor] = None,
+    d_sink: Optional[Tensor] = None,
 ) -> dict[str, Any]:
     md_name = "mha_varlen_bwd"
     filter1 = "*" # get_bwd_dot_do_o_blobs()
@@ -1117,7 +1121,8 @@ def gen_mha_varlen_bwd_fake_tensors(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
-    sink_ptr: Optional[Tensor] = None,
+    sink: Optional[Tensor] = None,
+    d_sink: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     return gen_mha_varlen_bwd_fake_tensors_common(
         q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv
@@ -1157,7 +1162,8 @@ def mha_varlen_bwd(
     gen: Optional[Generator] = None,
     cu_seqlens_q_padded: Optional[Tensor] = None,
     cu_seqlens_k_padded: Optional[Tensor] = None,
-    sink_ptr: Optional[Tensor] = None,
+    sink: Optional[Tensor] = None,
+    d_sink: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...
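The renamed arguments are opt-in: both default to None, so existing call sites keep their old behavior. A minimal caller-side sketch follows, with shapes and dtypes taken from the comments added in csrc/include/mha_bwd.h and everything else (device, the remaining mha_bwd arguments) assumed rather than taken from this PR:

```python
import torch

b, hq = 2, 8  # batch size and number of query heads (illustrative values)

# sink: per-(batch, head) sink scores, kept in log space as float32
# (the kernel's LSEDataType). Passing None, the default, disables the path.
sink = torch.zeros(b, hq, device="cuda", dtype=torch.float32)

# d_sink: per-head float32 buffer the backward kernel writes the sink
# gradient into.
d_sink = torch.zeros(hq, device="cuda", dtype=torch.float32)

# Both are then forwarded alongside the usual arguments, e.g.
# mha_bwd(..., sink=sink, d_sink=d_sink).
```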
2 changes: 2 additions & 0 deletions csrc/cpp_itfs/mha_bwd.cu
@@ -168,6 +168,8 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
 /* dv_ptr */ a.dv_ptr,
 /* dbias_ptr */ a.dbias_ptr,
 /* dq_acc_ptr */ a.dq_acc_ptr,
+/* sink_ptr */ a.sink_ptr,
+/* d_sink_ptr */ a.d_sink_ptr,

 /* seqstart_q_ptr */ a.seqstart_q_ptr,
 /* seqstart_k_ptr */ a.seqstart_k_ptr,
2 changes: 2 additions & 0 deletions csrc/include/mha_bwd.h
@@ -47,6 +47,8 @@ struct mha_bwd_args
     void* dv_ptr;
     void* dbias_ptr;
     void* dq_acc_ptr;
+    const void* sink_ptr = nullptr; // sink scores [batch, nhead] log-space (LSEDataType=float); nullptr disables sink
+    void* d_sink_ptr = nullptr;     // sink gradient accumulator [nhead] (LSEDataType=float); nullptr disables sink grad
     // Usage notes for sequence length pointer parameters:
     //
     // [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
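To make the two new fields concrete: an attention sink behaves like one virtual extra key per (batch, head) whose logit is the log-space score stored at sink_ptr; it absorbs softmax mass but contributes nothing to the output. Below is a self-contained PyTorch sketch of that mechanism, an illustration of the math rather than the CK kernel's actual algorithm, with the batch reduction for d_sink inferred from the [nhead] shape above:

```python
import torch

def sink_attention_ref(q, k, v, sink, scale):
    """q: [b, sq, h, d]; k, v: [b, sk, h, d]; sink: [b, h] log-space scores."""
    logits = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale          # [b, h, sq, sk]
    sink_col = sink[:, :, None, None].expand(*logits.shape[:3], 1)  # [b, h, sq, 1]
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
    # The sink column soaks up probability mass but attends to no value.
    return torch.einsum("bhqk,bkhd->bqhd", probs[..., :-1], v)

b, sq, sk, h, d = 2, 16, 16, 4, 32
q = torch.randn(b, sq, h, d, requires_grad=True)
k = torch.randn(b, sk, h, d, requires_grad=True)
v = torch.randn(b, sk, h, d, requires_grad=True)
sink = torch.zeros(b, h, requires_grad=True)

sink_attention_ref(q, k, v, sink, d**-0.5).sum().backward()
d_sink = sink.grad.sum(dim=0)  # reduced over batch -> [h], matching d_sink_ptr
```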
8 changes: 6 additions & 2 deletions csrc/include/rocm_ops.hpp
@@ -873,7 +873,9 @@ namespace py = pybind11;
 py::arg("bias") = std::nullopt, \
 py::arg("alibi_slopes") = std::nullopt, \
 py::arg("rng_state") = std::nullopt, \
-py::arg("gen") = std::nullopt);
+py::arg("gen") = std::nullopt, \
+py::arg("sink") = std::nullopt, \
+py::arg("d_sink") = std::nullopt);

 #define MHA_FWD_ASM_PYBIND \
 m.def("fmha_v3_fwd", \
@@ -1005,7 +1007,9 @@ namespace py = pybind11;
 py::arg("rng_state") = std::nullopt, \
 py::arg("gen") = std::nullopt, \
 py::arg("cu_seqlens_q_padded") = std::nullopt, \
-py::arg("cu_seqlens_k_padded") = std::nullopt);
+py::arg("cu_seqlens_k_padded") = std::nullopt, \
+py::arg("sink") = std::nullopt, \
+py::arg("d_sink") = std::nullopt);

 #define MOE_CK_2STAGES_PYBIND \
 m.def("ck_moe_stage1", \
4 changes: 3 additions & 1 deletion csrc/include/torch/mha_bwd.h
@@ -24,6 +24,8 @@ std::vector<at::Tensor> mha_bwd(const at::Tensor& dout, // [b, sq, hq, d]
 std::optional<const at::Tensor> bias_, // [sq, sk]
 std::optional<const at::Tensor> alibi_slopes, // [hq] or [b, hq]
 std::optional<const at::Tensor> rng_state,
-std::optional<at::Generator> gen);
+std::optional<at::Generator> gen,
+std::optional<const at::Tensor> sink, // [b, hq] log-space sink scores (float)
+std::optional<at::Tensor> d_sink);    // [hq] sink gradient output (float)
 } // namespace torch_itfs
 } // namespace aiter
4 changes: 3 additions & 1 deletion csrc/include/torch/mha_varlen_bwd.h
@@ -30,7 +30,9 @@ mha_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d]
 std::optional<const at::Tensor> rng_state,
 std::optional<at::Generator> gen,
 std::optional<const at::Tensor> cu_seqlens_q_padded, // [b+1]
-std::optional<const at::Tensor> cu_seqlens_k_padded // [b+1]
+std::optional<const at::Tensor> cu_seqlens_k_padded, // [b+1]
+std::optional<const at::Tensor> sink, // [b, hq] log-space sink scores (float)
+std::optional<at::Tensor> d_sink      // [hq] sink gradient output (float)
 );
 } // namespace torch_itfs
 } // namespace aiter
9 changes: 8 additions & 1 deletion csrc/py_itfs_ck/mha_bwd_kernels.cu
@@ -31,7 +31,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
 std::optional<const at::Tensor> bias_, // [sq, sk]
 std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
 std::optional<const at::Tensor> rng_state_,
-std::optional<at::Generator> gen_)
+std::optional<at::Generator> gen_,
+std::optional<const at::Tensor> sink_, // [b, hq] log-space sink scores (float)
+std::optional<at::Tensor> d_sink_)     // [hq] sink gradient output (float)
 {
 if (is_causal) { window_size_right = 0; }

@@ -198,6 +200,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
 hipLaunchKernelGGL(
     aiter::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
     philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
+} else {
+    // No dropout: allocate a dummy tensor so data_ptr() is always valid.
+    rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
 }

 if (seqlen_q > 0) {
@@ -329,6 +334,8 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
 dv_expanded.data_ptr(),
 dbias_ptr,
 dq_accum.data_ptr(),
+(sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr,     // sink_ptr [b, hq]
+(d_sink_.has_value() && d_sink_.value().defined()) ? d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq]
 nullptr, // seqstart_q_ptr
 nullptr, // seqstart_k_ptr
 nullptr, // seqlen_q_ptr
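Two caller-facing consequences of the plumbing above: an omitted (or undefined) optional tensor becomes nullptr, which the kernel reads as "sink disabled"; and d_sink is described as an accumulator, which suggests the kernel adds into it rather than overwriting it. Under that second assumption (worth verifying against the CK kernel), the buffer should be zeroed before each call:

```python
import torch

hq = 8  # number of query heads (illustrative)

# Assumption: the kernel accumulates into d_sink, so clear it per call.
d_sink = torch.zeros(hq, dtype=torch.float32)

# Omitting sink/d_sink entirely maps to nullptr on the C++ side and
# reproduces the pre-PR behavior.
```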
7 changes: 5 additions & 2 deletions csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu
@@ -36,8 +36,9 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v]
 std::optional<const at::Tensor> rng_state_,
 std::optional<at::Generator> gen_,
 std::optional<const at::Tensor> cu_seqlens_q_padded, // [b+1]
-std::optional<const at::Tensor> cu_seqlens_k_padded // [b+1]
-)
+std::optional<const at::Tensor> cu_seqlens_k_padded, // [b+1]
+std::optional<const at::Tensor> sink_, // [b, hq] log-space sink scores (float)
+std::optional<at::Tensor> d_sink_)     // [hq] sink gradient output (float)
 {
 if (is_causal) { window_size_right = 0; }

@@ -341,6 +342,8 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v]
 dv_expanded.data_ptr(),
 nullptr, // dbias
 dq_accum.data_ptr(), // dq_acc
+(sink_.has_value() && sink_.value().defined()) ? sink_.value().data_ptr() : nullptr,     // sink_ptr [b, hq]
+(d_sink_.has_value() && d_sink_.value().defined()) ? d_sink_.value().data_ptr() : nullptr, // d_sink_ptr [hq]
 seqstart_q_ptr, // seqstart_q_ptr (physical cumulative)
 seqstart_k_ptr, // seqstart_k_ptr (physical cumulative)
 nullptr, // seqlen_q_ptr (per-sequence logical)
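In the varlen path, q/k/v are packed along a single total_q/total_k dimension and addressed through cumulative offsets, while the sink tensors keep their dense shapes: one score row per sequence in the batch and one shared per-head gradient buffer. A small bookkeeping sketch (argument names from the signature above, values illustrative):

```python
import torch
import torch.nn.functional as F

seqlens_q = torch.tensor([37, 91])                       # per-sequence lengths
cu_seqlens_q = F.pad(seqlens_q.cumsum(0), (1, 0)).int()  # [b+1] = [0, 37, 128]

b, hq = seqlens_q.numel(), 8
sink = torch.zeros(b, hq, dtype=torch.float32)   # [b, hq]: one row per sequence
d_sink = torch.zeros(hq, dtype=torch.float32)    # [hq]: shared accumulator
```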
1 change: 1 addition & 0 deletions csrc/py_itfs_cu/fmha_bwd_pre_post_kernel_generate.py
@@ -94,6 +94,7 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype: str) -> Optional[dict]
 typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
 typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
 typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
+typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
 /* BlockSize = */ 64,
 {F_hdim},
 {F_mode},