From 6b8b32f83ab02f909189a10bd8e26b0163dd1273 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Sat, 18 Apr 2026 05:38:19 +0800 Subject: [PATCH] =?UTF-8?q?feat(ascend):=20op-cache-attn=20group=20?= =?UTF-8?q?=E2=80=94=20ReshapeAndCache,=20FlashAttention,=20PagedAttention?= =?UTF-8?q?,=20TopkToppSampling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four KV-cache and attention operators: | op | impl | |---|---| | ReshapeAndCache | 3 impls: aclnnInplaceIndexCopy (kernel.h); custom AscendC (kernel_v2.h); ATB `ReshapeAndCacheParam` (kernel_atb.h, int64 `slot_mapping` handled via cached async `aclnnCast`) | | FlashAttention | `aclnnFusedInferAttentionScoreV4` (prefill + paged decode). Supports both the native `(window_left, window_right)` pair and a new `std::optional sliding_window` entry (additive, vLLM-style) | | PagedAttention | ATB `PagedAttentionParam` with optional CPU-pinned host tensors (`seq_lens_host` / `block_table_host`) that make the op NPUGraph-capturable | | TopkToppSampling | ATB `TopkToppSamplingParam` | Includes vLLM API alignment commits: - `perf(reshape_and_cache)`: int64 slot_mapping routed through cached async `aclnnCast` (no D2H sync, NPUGraph-compatible) - `feat(flash_attention)`: add `sliding_window` entry, additive - `docs(paged_attention)`: base class comment explains the CPU-host tensor contract New `src/base/.h`: paged_attention, topk_topp_sampling. Modified: reshape_and_cache, flash_attention. --- src/ascend/flash_attention/kernel.h | 375 +++++++++++++ src/ascend/paged_attention/kernel_atb.h | 283 ++++++++++ src/ascend/reshape_and_cache/kernel.h | 109 ++++ src/ascend/reshape_and_cache/kernel_atb.h | 257 +++++++++ src/ascend/reshape_and_cache/kernel_v2.h | 123 +++++ src/ascend/topk_topp_sampling/kernel_atb.h | 185 +++++++ src/base/flash_attention.h | 67 ++- src/base/paged_attention.h | 129 +++++ src/base/reshape_and_cache.h | 9 - src/base/topk_topp_sampling.h | 62 +++ tests/test_flash_attention.py | 597 +++++++++++++++++++++ tests/test_paged_attention.py | 554 +++++++++++++++++++ tests/test_reshape_and_cache.py | 273 ++++++++++ 13 files changed, 2993 insertions(+), 30 deletions(-) create mode 100644 src/ascend/flash_attention/kernel.h create mode 100644 src/ascend/paged_attention/kernel_atb.h create mode 100644 src/ascend/reshape_and_cache/kernel.h create mode 100644 src/ascend/reshape_and_cache/kernel_atb.h create mode 100644 src/ascend/reshape_and_cache/kernel_v2.h create mode 100644 src/ascend/topk_topp_sampling/kernel_atb.h create mode 100644 src/base/paged_attention.h create mode 100644 src/base/topk_topp_sampling.h create mode 100644 tests/test_flash_attention.py create mode 100644 tests/test_paged_attention.py create mode 100644 tests/test_reshape_and_cache.py diff --git a/src/ascend/flash_attention/kernel.h b/src/ascend/flash_attention/kernel.h new file mode 100644 index 00000000..e1ef5b00 --- /dev/null +++ b/src/ascend/flash_attention/kernel.h @@ -0,0 +1,375 @@ +#ifndef INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ +#define INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_fused_infer_attention_score_v4.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/flash_attention.h" +#include "operator.h" + +namespace infini::ops { + +namespace detail { + +// Extract cu_seqlens differences to a host aclIntArray. +// cu_seqlens = [0, s1, s1+s2, ...] -> per_seq_lens = [s1, s2, ...]. 
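+// Example (illustrative): cu_seqlens = [0, 4, 9, 12] -> per_seq_lens = [4, 5, 3].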
+// Used by paged decode (actualSeqLengthsKv = per-sequence KV lengths). +// +// When cu_seqlens is a CPU tensor (device type kCpu), the data pointer is +// already on the host and can be read directly — no D2H sync needed. +inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto n = cu_seqlens.numel(); + + const int64_t* cu_host_ptr = nullptr; + std::vector cu_host_buf; + + if (cu_seqlens.device().type() == Device::Type::kCpu) { + cu_host_ptr = static_cast(cu_seqlens.data()); + } else { + cu_host_buf.resize(n); + aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), cu_seqlens.data(), + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + cu_host_ptr = cu_host_buf.data(); + } + + std::vector lengths(n - 1); + for (size_t i = 0; i < lengths.size(); ++i) { + lengths[i] = cu_host_ptr[i + 1] - cu_host_ptr[i]; + } + + return aclCreateIntArray(lengths.data(), + static_cast(lengths.size())); +} + +// Extract cumulative end positions from cu_seqlens to a host aclIntArray. +// cu_seqlens = [0, s1, s1+s2, ...] -> cum_lens = [s1, s1+s2, ...]. +// FIA V4 TND varlen uses cumulative end positions, matching the vllm-ascend +// convention for npu_fused_infer_attention_score actual_seq_lengths. +// +// When cu_seqlens is a CPU tensor, reads directly from host memory. +inline aclIntArray* cumSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto n = cu_seqlens.numel(); + + const int64_t* cu_host_ptr = nullptr; + std::vector cu_host_buf; + + if (cu_seqlens.device().type() == Device::Type::kCpu) { + cu_host_ptr = static_cast(cu_seqlens.data()); + } else { + cu_host_buf.resize(n); + aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), cu_seqlens.data(), + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + cu_host_ptr = cu_host_buf.data(); + } + + // Skip the leading 0; return [s1, s1+s2, ...]. + return aclCreateIntArray(cu_host_ptr + 1, static_cast(n - 1)); +} + +// Allocate a 2048x2048 lower-triangular UINT8 causal mask on device. +// Required for `sparseMode` >= 2. +inline aclTensor* makeCausalMask(void** mask_buf, aclrtStream stream) { + constexpr int64_t kMaskDim = 2048; + const int64_t mask_elems = kMaskDim * kMaskDim; + const size_t mask_bytes = static_cast(mask_elems); // uint8_t + + aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + std::vector host_mask(mask_elems); + for (int64_t r = 0; r < kMaskDim; ++r) { + for (int64_t c = 0; c < kMaskDim; ++c) { + // 1 = masked out (upper triangle); 0 = attend (lower triangle). + host_mask[r * kMaskDim + c] = (c > r) ? 
1 : 0; + } + } + aclrtMemcpyAsync(*mask_buf, mask_bytes, host_mask.data(), mask_bytes, + ACL_MEMCPY_HOST_TO_DEVICE, stream); + aclrtSynchronizeStream(stream); + + std::vector mask_shape = {kMaskDim, kMaskDim}; + std::vector mask_strides = {kMaskDim, 1}; + std::vector mask_storage = {mask_elems}; + return aclCreateTensor(mask_shape.data(), 2, ACL_UINT8, mask_strides.data(), + 0, ACL_FORMAT_ND, mask_storage.data(), 1, *mask_buf); +} + +} // namespace detail + +template <> +class Operator : public FlashAttention { + public: + Operator(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, bool causal, + int64_t window_left, int64_t window_right, int64_t block_size, + Tensor output, std::optional sliding_window = std::nullopt) + : FlashAttention(query, key, value, cu_seqlens_q, cu_seqlens_kv, + block_table, num_heads, num_kv_heads, head_size, scale, + causal, window_left, window_right, block_size, output, + sliding_window) { + paged_ = block_table.has_value() && block_size > 0; + aclDataType acl_dt = ascend::ToAclDtype(query.dtype()); + + if (!paged_) { + // Prefill: cache Q and output (TND layout). + prefill_q_cache_ = ascend::AclTensorCache(query); + prefill_out_cache_ = ascend::AclTensorCache(output); + + // Pre-compute causal mask once (sparse_mode >= 2). Read the + // resolved pair from base-class members so `sliding_window` + // normalization is honored at cache-key construction. + if (causal) { + int64_t sm = (window_left_ >= 0) ? 4 : 3; + if (sm >= 2) { + causal_mask_ = detail::makeCausalMask(&causal_mask_buf_, nullptr); + } + } + } else { + // Decode: cache Q/output (BNSD), block_table. + const int64_t N = query.size(1); + const int64_t D = query.size(2); + const int64_t B = query.size(0); + + decode_q_cache_ = ascend::AclTensorCache({B, N, 1, D}, acl_dt, + const_cast(query.data())); + decode_out_cache_ = + ascend::AclTensorCache({B, N, 1, D}, acl_dt, output.data()); + block_table_cache_ = ascend::AclTensorCache(block_table.value()); + + // Pre-compute KV reshape metadata. + const int64_t nb = key.size(0); + const int64_t bsz = key.size(1); + const int64_t NkvD = key.size(2) * key.size(3); + kv_shape_ = {nb, bsz, NkvD}; + kv_strides_ = {bsz * NkvD, NkvD, 1}; + kv_storage_shape_ = {nb * bsz * NkvD}; + kv_acl_dt_ = acl_dt; + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + if (causal_mask_) aclDestroyTensor(causal_mask_); + if (causal_mask_buf_) aclrtFree(causal_mask_buf_); + } + + void operator()(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + bool causal, int64_t window_left, int64_t window_right, + int64_t block_size, Tensor output, + std::optional sliding_window) const override { + auto stream = static_cast(stream_); + const bool paged = paged_; + + // The base class stored the resolved window pair in `window_left_` / + // `window_right_` at construction; prefer those over the call-site + // args so that `sliding_window` is honored here as well. 
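+    // Illustrative: a call with sliding_window = 128 arrives here with the
+    // resolved pair (127, 0), which selects sparse_mode = 4 and
+    // pre_tokens = 127 in the causal branch below.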
+ int64_t wl = window_left_; + int64_t wr = window_right_; + (void)window_left; + (void)window_right; + (void)sliding_window; + + int64_t sparse_mode; + int64_t pre_tokens = 2147483647; + int64_t next_tokens = 2147483647; + if (causal) { + if (wl >= 0) { + sparse_mode = 4; + pre_tokens = wl; + next_tokens = 0; + } else { + sparse_mode = 3; + next_tokens = 0; + } + } else { + sparse_mode = 0; + if (wl >= 0) pre_tokens = wl; + if (wr >= 0) next_tokens = wr; + } + + if (!paged) { + // --- Prefill --- + int64_t T = query.size(0); + + // cumSeqLengths / extractSeqLengths automatically skip D2H when + // cu_seqlens is a CPU tensor (see detail:: helpers above). + aclIntArray* seq_q = + cu_seqlens_q.has_value() + ? detail::cumSeqLengths(cu_seqlens_q.value(), stream) + : aclCreateIntArray(&T, 1); + aclIntArray* seq_kv = + cu_seqlens_kv.has_value() + ? detail::cumSeqLengths(cu_seqlens_kv.value(), stream) + : aclCreateIntArray(&T, 1); + + aclTensor* t_q = prefill_q_cache_.get(const_cast(query.data())); + // K/V descriptors go into TensorList which takes ownership — must be + // per-call (cannot cache). + aclTensor* t_k = ascend::BuildAclTensor(key); + aclTensor* t_v = ascend::BuildAclTensor(value); + aclTensor* t_out = prefill_out_cache_.get(output.data()); + + const aclTensor* k_arr[] = {t_k}; + const aclTensor* v_arr[] = {t_v}; + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_q, key_list, val_list, + nullptr, // pseShift + causal_mask_, // attenMask (pre-computed, or nullptr) + seq_q, // actualSeqLengths + seq_kv, // actualSeqLengthsKv + nullptr, nullptr, nullptr, nullptr, + nullptr, // deqScale1..quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // queryPaddingSize, kvPaddingSize + nullptr, nullptr, nullptr, + nullptr, // key/value antiquant scale/offset + nullptr, nullptr, + nullptr, // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen + nullptr, nullptr, + nullptr, // queryRope, keyRope, keyRopeAntiquantScale + nullptr, nullptr, // dequantScaleQuery, learnableSink + num_heads, scale, pre_tokens, next_tokens, const_cast("TND"), + num_kv_heads, sparse_mode, + 0, // innerPrecise + 0, // blockSize (unused for prefill) + 0, false, // antiquantMode, softmaxLseFlag + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode + t_out, nullptr, &ws_needed, &executor); + assert( + gws == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (prefill)"); + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed); + aclError ret = aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, + executor, stream); + assert(ret == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4 failed (prefill)"); + + // t_q and t_out are owned by caches — do NOT destroy. + // t_k and t_v are owned by TensorLists. + aclDestroyTensorList(key_list); + aclDestroyTensorList(val_list); + aclDestroyIntArray(seq_q); + aclDestroyIntArray(seq_kv); + return; + } + + // --- Paged decode --- + assert(cu_seqlens_kv.has_value() && + "`FlashAttention` paged decode requires `cu_seqlens_kv`"); + + aclTensor* t_query = decode_q_cache_.get(const_cast(query.data())); + aclTensor* t_output = decode_out_cache_.get(output.data()); + + // K/V descriptors go into TensorList which takes ownership — must be + // per-call. 
Use pre-computed metadata to avoid heap allocs. + aclTensor* t_key = aclCreateTensor( + kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dt_, + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(key.data())); + aclTensor* t_value = aclCreateTensor( + kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dt_, + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(value.data())); + + // extractSeqLengths skips D2H when cu_seqlens_kv is a CPU tensor. + aclIntArray* seq_kv = + detail::extractSeqLengths(cu_seqlens_kv.value(), stream); + aclTensor* t_block_table = + block_table_cache_.get(const_cast(block_table.value().data())); + + const aclTensor* k_arr[] = {t_key}; + const aclTensor* v_arr[] = {t_value}; + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_query, key_list, val_list, + nullptr, // pseShift + nullptr, // attenMask (sparseMode ignored for Q_S=1) + nullptr, // actualSeqLengths (ignored for Q_S=1) + seq_kv, // actualSeqLengthsKv (mandatory for paged) + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + t_block_table, // blockTable + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, num_heads, scale, + static_cast(2147483647), static_cast(2147483647), + const_cast("BNSD"), num_kv_heads, + 0, // sparseMode=0 (ignored for Q_S=1) + 0, // innerPrecise + block_size, // blockSize + 0, false, // antiquantMode, softmaxLseFlag + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode + t_output, nullptr, &ws_needed, &executor); + assert(gws == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (decode)"); + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed); + aclError ret = + aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, executor, stream); + assert(ret == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4 failed (decode)"); + + // t_query, t_output, t_block_table owned by caches — do NOT destroy. + // t_key, t_value owned by TensorLists. 
+ aclDestroyTensorList(key_list); + aclDestroyTensorList(val_list); + aclDestroyIntArray(seq_kv); + } + + private: + bool paged_ = false; + + mutable ascend::AclTensorCache prefill_q_cache_; + + mutable ascend::AclTensorCache prefill_out_cache_; + + mutable ascend::AclTensorCache decode_q_cache_; + + mutable ascend::AclTensorCache decode_out_cache_; + + mutable ascend::AclTensorCache block_table_cache_; + + aclTensor* causal_mask_ = nullptr; + + void* causal_mask_buf_ = nullptr; + + std::vector kv_shape_; + + std::vector kv_strides_; + + std::vector kv_storage_shape_; + + aclDataType kv_acl_dt_ = ACL_DT_UNDEFINED; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/paged_attention/kernel_atb.h b/src/ascend/paged_attention/kernel_atb.h new file mode 100644 index 00000000..adf5f36b --- /dev/null +++ b/src/ascend/paged_attention/kernel_atb.h @@ -0,0 +1,283 @@ +#ifndef INFINI_OPS_ASCEND_PAGED_ATTENTION_KERNEL_ATB_H_ +#define INFINI_OPS_ASCEND_PAGED_ATTENTION_KERNEL_ATB_H_ + +#ifdef INFINI_HAS_ATB + +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "ascend/atb_common_.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "atb/context.h" +#include "atb/infer_op_params.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "base/paged_attention.h" +#include "operator.h" + +namespace infini::ops { + +// ATB-based paged decode attention (implementation index 0). +// +// Wraps ATB `PagedAttentionParam` with the default `inputLayout` +// (`TYPE_BSND`). For decode (single token per request) the S +// dimension is implicitly 1, so query and output use 3D shape +// [batch, num_heads, head_size] matching vLLM's convention. +// +// ATB internally constructs `aclIntArray*` from the `hostData` field +// of `block_table` and `context_lens` tensors. By default the operator +// performs synchronous D2H copies for these two small tensors each call. +// When the caller provides `seq_lens_host` and `block_table_host` (CPU +// pinned tensors), the D2H copies are skipped entirely — enabling full +// NPUGraph capture of the decode attention path. +// +// ATB VariantPack layout (BSND with S=1): +// inTensors[0] = query [B, N, D] +// inTensors[1] = key_cache [num_blocks, block_size, Nkv, D] +// inTensors[2] = value_cache [num_blocks, block_size, Nkv, D] +// inTensors[3] = block_table [B, max_num_blocks] (device + host) +// inTensors[4] = context_lens [B] (int32) (device + host) +// outTensors[0] = output [B, N, D] +template <> +class Operator + : public PagedAttention { + public: + Operator(const Tensor query, const Tensor key_cache, const Tensor value_cache, + const Tensor seq_lens, const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output, + std::optional seq_lens_host = std::nullopt, + std::optional block_table_host = std::nullopt) + : PagedAttention(query, key_cache, value_cache, seq_lens, block_table, + num_heads, num_kv_heads, head_size, scale, block_size, + output, seq_lens_host, block_table_host) { + int64_t B = static_cast(batch_size_); + int64_t N = num_heads_; + int64_t Nkv = num_kv_heads_; + int64_t D = head_size_; + + // Query/output shapes: 3D [B, N, D] (BSND with S=1 for decode). + query_tnd_shape_ = {B, N, D}; + output_tnd_shape_ = {B, N, D}; + + // KV cache shapes. 
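+    // Illustrative sizing: a key_cache of [1024, 128, 8, 128] in fp16
+    // occupies 1024 * 128 * 8 * 128 * 2 B = 256 MiB per cache.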
+ int64_t num_blocks = static_cast(key_cache.size(0)); + int64_t bs = static_cast(key_cache.size(1)); + kv_cache_shape_ = {num_blocks, bs, Nkv, D}; + + // Block table and context lens shapes. + int64_t max_blocks = static_cast(block_table.size(1)); + block_table_shape_ = {B, max_blocks}; + context_lens_shape_ = {B}; + + // ACL data types. + acl_dt_ = ascend::ToAclDtype(query.dtype()); + bt_dt_ = ascend::ToAclDtype(block_table.dtype()); + sl_dt_ = ascend::ToAclDtype(seq_lens.dtype()); + + // Element sizes for `dataSize` computation. + elem_size_ = query.element_size(); + bt_elem_size_ = block_table.element_size(); + sl_elem_size_ = seq_lens.element_size(); + + // Pre-allocate pinned host buffers for D2H copies. + // ATB PA reads `hostData` from block_table and context_lens to + // construct internal `aclIntArray*` parameters. + // When caller provides host tensors, skip allocation — the caller's + // pinned buffers will be used directly in `operator()`. + bt_host_bytes_ = static_cast(B * max_blocks) * bt_elem_size_; + sl_host_bytes_ = static_cast(B) * sl_elem_size_; + + if (!has_block_table_host_) { + bt_host_ = std::malloc(bt_host_bytes_); + assert(bt_host_ && "Host buffer allocation for `block_table` failed"); + } + + if (!has_seq_lens_host_) { + sl_host_ = std::malloc(sl_host_bytes_); + assert(sl_host_ && "Host buffer allocation for `seq_lens` failed"); + } + + // Create the ATB operation (reused across calls). + atb::infer::PagedAttentionParam param; + param.headNum = static_cast(N); + param.kvHeadNum = static_cast(Nkv); + param.qkScale = static_cast(scale_); + + atb::Status s = atb::CreateOperation(param, &op_); + assert(s == atb::NO_ERROR && "atb::CreateOperation(PagedAttention) failed"); + } + + ~Operator() { + // Host memory is always safe to free. + if (!has_block_table_host_) { + std::free(bt_host_); + } + + if (!has_seq_lens_host_) { + std::free(sl_host_); + } + + if (!ascend::IsAclRuntimeAlive()) return; + + if (op_) { + atb::DestroyOperation(op_); + } + } + + Operator(const Operator&) = delete; + + Operator& operator=(const Operator&) = delete; + + void operator()(const Tensor query, const Tensor key_cache, + const Tensor value_cache, const Tensor seq_lens, + const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output, + std::optional seq_lens_host, + std::optional block_table_host) const override { + auto stream = static_cast(stream_); + atb::Context* ctx = ascend::GetAtbContext(stream); + + // Use caller-provided host data or perform synchronous D2H copy. + // ATB reads `hostData` to construct internal `aclIntArray*`. + void* bt_host_ptr = bt_host_; + void* sl_host_ptr = sl_host_; + + if (block_table_host.has_value()) { + bt_host_ptr = const_cast(block_table_host.value().data()); + } else { + aclrtMemcpy(bt_host_, bt_host_bytes_, block_table.data(), bt_host_bytes_, + ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (seq_lens_host.has_value()) { + sl_host_ptr = const_cast(seq_lens_host.value().data()); + } else { + aclrtMemcpy(sl_host_, sl_host_bytes_, seq_lens.data(), sl_host_bytes_, + ACL_MEMCPY_DEVICE_TO_HOST); + } + + atb::VariantPack vp = buildVariantPack( + const_cast(query.data()), const_cast(key_cache.data()), + const_cast(value_cache.data()), + const_cast(block_table.data()), + const_cast(seq_lens.data()), output.data(), bt_host_ptr, + sl_host_ptr); + + // Setup computes workspace requirements and binds tensor descriptors. 
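+    // Setup is re-run on every call (rather than once in the constructor)
+    // because the block_table / seq_lens contents and hostData pointers can
+    // change between decode steps.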
+ uint64_t ws_size = 0; + atb::Status s = op_->Setup(vp, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Setup(PagedAttention) failed"); + + // Allocate workspace via the shared pool. + uint8_t* ws_ptr = nullptr; + + if (ws_size > 0) { + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size); + ws_ptr = static_cast(arena.buf); + } + + s = op_->Execute(vp, ws_ptr, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Execute(PagedAttention) failed"); + } + + private: + // Build the ATB VariantPack. + // + // Query and output are 3D [B, N, D] (BSND with S=1 for decode). + // Block table and context lens carry both `deviceData` and + // `hostData` because ATB reads the host copy to build internal + // `aclIntArray*` parameters. + atb::VariantPack buildVariantPack(void* query_data, void* key_cache_data, + void* value_cache_data, + void* block_table_data, void* seq_lens_data, + void* output_data, void* bt_host_ptr, + void* sl_host_ptr) const { + int64_t B = query_tnd_shape_[0]; + int64_t N = query_tnd_shape_[1]; + int64_t D = query_tnd_shape_[2]; + + // Query [B, N, D] — 3D (BSND with S=1). + uint64_t q_bytes = static_cast(B * N * D) * elem_size_; + atb::Tensor t_query = + ascend::ToAtbTensor(query_tnd_shape_, acl_dt_, query_data, q_bytes); + + // KV caches [num_blocks, block_size, Nkv, D]. + int64_t nb = kv_cache_shape_[0]; + int64_t bs = kv_cache_shape_[1]; + int64_t Nkv = kv_cache_shape_[2]; + uint64_t kv_bytes = static_cast(nb * bs * Nkv * D) * elem_size_; + atb::Tensor t_key_cache = + ascend::ToAtbTensor(kv_cache_shape_, acl_dt_, key_cache_data, kv_bytes); + atb::Tensor t_value_cache = ascend::ToAtbTensor(kv_cache_shape_, acl_dt_, + value_cache_data, kv_bytes); + + // Block table [B, max_blocks] — with hostData for `aclIntArray*`. + atb::Tensor t_block_table = ascend::ToAtbTensor( + block_table_shape_, bt_dt_, block_table_data, bt_host_bytes_); + t_block_table.hostData = bt_host_ptr; + + // Context lens [B] — with hostData for `aclIntArray*`. + atb::Tensor t_context_lens = ascend::ToAtbTensor( + context_lens_shape_, sl_dt_, seq_lens_data, sl_host_bytes_); + t_context_lens.hostData = sl_host_ptr; + + // Output [B, N, D] — 3D (BSND with S=1). + atb::Tensor t_output = + ascend::ToAtbTensor(output_tnd_shape_, acl_dt_, output_data, q_bytes); + + atb::VariantPack vp; + vp.inTensors = {t_query, t_key_cache, t_value_cache, t_block_table, + t_context_lens}; + vp.outTensors = {t_output}; + + return vp; + } + + atb::Operation* op_ = nullptr; + + std::vector query_tnd_shape_; + + std::vector output_tnd_shape_; + + std::vector kv_cache_shape_; + + std::vector block_table_shape_; + + std::vector context_lens_shape_; + + aclDataType acl_dt_ = ACL_DT_UNDEFINED; + + aclDataType bt_dt_ = ACL_DT_UNDEFINED; + + aclDataType sl_dt_ = ACL_DT_UNDEFINED; + + uint64_t elem_size_ = 0; + + uint64_t bt_elem_size_ = 0; + + uint64_t sl_elem_size_ = 0; + + // Host-side buffers for ATB's internal `aclIntArray*` construction. 
+ void* bt_host_ = nullptr; + + void* sl_host_ = nullptr; + + uint64_t bt_host_bytes_ = 0; + + uint64_t sl_host_bytes_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_PAGED_ATTENTION_KERNEL_ATB_H_ diff --git a/src/ascend/reshape_and_cache/kernel.h b/src/ascend/reshape_and_cache/kernel.h new file mode 100644 index 00000000..2d91b8e2 --- /dev/null +++ b/src/ascend/reshape_and_cache/kernel.h @@ -0,0 +1,109 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_index_copy.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/reshape_and_cache.h" +#include "operator.h" + +namespace infini::ops { + +// Device-side scatter via aclnnInplaceIndexCopy. +// +// The previous implementation copied slot_mapping D2H (aclrtSynchronizeStream), +// then issued per-token D2D memcpy in a host loop. For batch=256, this meant +// ~100 us sync + ~500 us host loop overhead. aclnnInplaceIndexCopy performs +// the scatter entirely on the NPU with two ACLNN calls (one for K, one for V), +// eliminating all D2H synchronisation and host-side loops. +// +// Requirement: slot_mapping must contain only non-negative values. Padding +// tokens (slot < 0) must be filtered by the caller before invoking this +// operator. +template <> +class Operator + : public ReshapeAndCache { + public: + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out), + key_cache_(key), + value_cache_(value), + slot_cache_(slot_mapping) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t total_slots = num_blocks * bs; + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + + aclDataType acl_dt = ascend::ToAclDtype(key.dtype()); + + // Flattened K cache view: [total_slots, num_kv_heads, head_size]. + // K cache is kv_cache_out[0], starting at offset 0. + kv_k_cache_ = ascend::AclTensorCache({total_slots, nkv, hs}, acl_dt, + kv_cache_out.data()); + + // V cache is kv_cache_out[1], offset by stride(0) elements. + v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + kv_v_cache_ = ascend::AclTensorCache( + {total_slots, nkv, hs}, acl_dt, + static_cast(kv_cache_out.data()) + v_offset_bytes_); + } + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + void* kv_k_data = kv_cache_out.data(); + void* kv_v_data = static_cast(kv_cache_out.data()) + v_offset_bytes_; + + auto t_kv_k = kv_k_cache_.get(kv_k_data); + auto t_kv_v = kv_v_cache_.get(kv_v_data); + auto t_key = key_cache_.get(const_cast(key.data())); + auto t_value = value_cache_.get(const_cast(value.data())); + auto t_slot = slot_cache_.get(const_cast(slot_mapping.data())); + + // K cache scatter: kv_k[slot_mapping[i]] = key[i] along dim 0. + // Executor caching is not used here because aclnnInplaceIndexCopy is an + // inplace operation where self is both input and output; the executor + // reuse via aclSetInputTensorAddr does not update the output reference. 
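+    // Illustrative: with slot_mapping = [7, 3], token 0's K/V vectors are
+    // written to flat slot 7 and token 1's to flat slot 3 of the caches.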
+ uint64_t k_ws = 0; + aclOpExecutor* k_exec = nullptr; + aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_k, 0, t_slot, t_key, &k_ws, + &k_exec); + auto& k_arena = ascend::GetWorkspacePool().Ensure(stream, k_ws); + aclnnInplaceIndexCopy(k_arena.buf, k_ws, k_exec, stream); + + // V cache scatter: kv_v[slot_mapping[i]] = value[i] along dim 0. + uint64_t v_ws = 0; + aclOpExecutor* v_exec = nullptr; + aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_v, 0, t_slot, t_value, &v_ws, + &v_exec); + auto& v_arena = ascend::GetWorkspacePool().Ensure(stream, v_ws); + aclnnInplaceIndexCopy(v_arena.buf, v_ws, v_exec, stream); + } + + private: + mutable ascend::AclTensorCache kv_k_cache_; + + mutable ascend::AclTensorCache kv_v_cache_; + + mutable ascend::AclTensorCache key_cache_; + + mutable ascend::AclTensorCache value_cache_; + + mutable ascend::AclTensorCache slot_cache_; + + size_t v_offset_bytes_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/reshape_and_cache/kernel_atb.h b/src/ascend/reshape_and_cache/kernel_atb.h new file mode 100644 index 00000000..222b61ba --- /dev/null +++ b/src/ascend/reshape_and_cache/kernel_atb.h @@ -0,0 +1,257 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_ATB_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_ATB_H_ + +#ifdef INFINI_HAS_ATB + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "ascend/atb_common_.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "atb/context.h" +#include "atb/infer_op_params.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "base/reshape_and_cache.h" +#include "operator.h" + +namespace infini::ops { + +// ATB-based KV cache scatter via `atb::infer::ReshapeAndCacheParam` +// (implementation index 2). +// +// Handles both K and V in a single fused operation. Profiled at ~9.5 us/call +// on Ascend 910B (256 tokens, fp16) — 3.7x faster than the +// `aclnnInplaceIndexCopy` path (index 0, ~35 us). +// +// The ATB operation is created once in the constructor. Setup is called +// before each Execute to bind the VariantPack. +// +// NOTE: `ReshapeAndCacheParam` requires int32 `slot_mapping`. When the +// caller passes int64 (the PyTorch / vLLM default), this operator issues an +// async `aclnnCast` to a pre-allocated int32 device buffer. The cast +// executor is cached across calls and the whole step stays on the stream +// with no D2H/H2D round-trip, so the int64 path is NPUGraph-capturable and +// roughly on par with the int32 fast path. +// +// Input layout: +// key, value : [num_tokens, num_kv_heads, head_size] +// slot_mapping: [num_tokens] (int32 or int64) +// +// KV cache layout: +// kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size] +// Output key_cache = kv_cache[0], value_cache = kv_cache[1], each with +// shape [num_blocks, block_size, num_kv_heads, head_size]. +template <> +class Operator + : public ReshapeAndCache { + public: + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + int64_t T = static_cast(num_tokens_); + + // Cache shapes for rebuilding VariantPack on each call. 
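+    // Illustrative: key [256, 8, 128] with kv_cache [2, 1024, 128, 8, 128]
+    // gives kv_shape_ = {1024, 128, 8, 128}, key_shape_ = {256, 8, 128},
+    // slot_shape_ = {256}.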
+ kv_shape_ = {num_blocks, bs, nkv, hs}; + key_shape_ = {T, nkv, hs}; + slot_shape_ = {T}; + acl_dt_ = ascend::ToAclDtype(key.dtype()); + + // Compute V-cache byte offset (kv_cache_out[1]). + v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + + // Element sizes for dataSize computation. + elem_size_ = key.element_size(); + + // Pre-allocate int32 device buffer for `slot_mapping`. + // `ReshapeAndCacheParam` requires int32; int64 is silently ignored + // (writes nothing). + slot32_bytes_ = static_cast(T) * sizeof(int32_t); + aclrtMalloc(&slot32_buf_, slot32_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(slot32_buf_ && "aclrtMalloc for slot32_buf_ failed"); + + slot_is_int32_ = (slot_mapping.element_size() == sizeof(int32_t)); + + // Prepare aclnnCast descriptors for the int64 → int32 path. Source + // descriptor's data pointer is refreshed per call; destination is the + // pre-allocated `slot32_buf_`. + if (!slot_is_int32_) { + slot_i64_cache_ = ascend::AclTensorCache( + {T}, ACL_INT64, const_cast(slot_mapping.data())); + slot_i32_cache_ = ascend::AclTensorCache({T}, ACL_INT32, slot32_buf_); + } + + // Create the ATB operation (reused across calls). + atb::infer::ReshapeAndCacheParam param; + atb::Status s = atb::CreateOperation(param, &op_); + assert(s == atb::NO_ERROR && + "atb::CreateOperation(ReshapeAndCache) failed"); + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + if (op_) atb::DestroyOperation(op_); + slot_i64_cache_.release(); + slot_i32_cache_.release(); + if (slot32_buf_) aclrtFree(slot32_buf_); + } + + Operator(const Operator&) = delete; + + Operator& operator=(const Operator&) = delete; + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + // `ReshapeAndCacheParam` requires int32 `slot_mapping`. When the + // caller provides int64 (the PyTorch/vLLM default), issue an async + // `aclnnCast` to the pre-allocated int32 device buffer — keeps the + // whole step on-stream and NPUGraph-capturable. + void* slot32_ptr; + + if (slot_is_int32_) { + // Already int32 — pass through directly. + slot32_ptr = const_cast(slot_mapping.data()); + } else { + auto t_src = slot_i64_cache_.get(const_cast(slot_mapping.data())); + auto t_dst = slot_i32_cache_.get(slot32_buf_); + + if (!cast_exec_) { + aclnnCastGetWorkspaceSize(t_src, ACL_INT32, t_dst, &cast_ws_, + &cast_exec_); + aclSetAclOpExecutorRepeatable(cast_exec_); + } else { + aclSetInputTensorAddr(cast_exec_, 0, t_src, + const_cast(slot_mapping.data())); + aclSetOutputTensorAddr(cast_exec_, 0, t_dst, slot32_buf_); + } + + auto& cast_arena = ascend::GetWorkspacePool().Ensure(stream, cast_ws_); + aclnnCast(cast_arena.buf, cast_ws_, cast_exec_, stream); + slot32_ptr = slot32_buf_; + } + + atb::Context* ctx = ascend::GetAtbContext(stream); + + atb::VariantPack vp = buildVariantPack(const_cast(key.data()), + const_cast(value.data()), + kv_cache_out.data(), slot32_ptr); + + // Setup binds the VariantPack and computes workspace requirements. + uint64_t ws_size = 0; + atb::Status s = op_->Setup(vp, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Setup(ReshapeAndCache) failed"); + + // Allocate workspace via the shared pool. 
+ uint8_t* ws_ptr = nullptr; + + if (ws_size > 0) { + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size); + ws_ptr = static_cast(arena.buf); + } + + s = op_->Execute(vp, ws_ptr, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Execute(ReshapeAndCache) failed"); + } + + private: + // Build the ATB VariantPack for this operation. + // + // ATB `ReshapeAndCache` expects 5 inputs and 2 outputs: + // inTensors[0] = key [num_tokens, num_kv_heads, head_size] + // inTensors[1] = value [num_tokens, num_kv_heads, head_size] + // inTensors[2] = key_cache [num_blocks, block_size, num_kv_heads, + // head_size] inTensors[3] = value_cache [num_blocks, block_size, + // num_kv_heads, head_size] inTensors[4] = slot_mapping [num_tokens] (int32) + // outTensors[0] = key_cache (same buffer, in-place) + // outTensors[1] = value_cache (same buffer, in-place) + atb::VariantPack buildVariantPack(void* key_data, void* value_data, + void* kv_out_data, + void* slot32_data) const { + int64_t num_tokens = key_shape_[0]; + int64_t nkv = key_shape_[1]; + int64_t hs = key_shape_[2]; + uint64_t kv_bytes = + static_cast(num_tokens * nkv * hs) * elem_size_; + + int64_t nb = kv_shape_[0]; + int64_t bs = kv_shape_[1]; + uint64_t cache_bytes = + static_cast(nb * bs * nkv * hs) * elem_size_; + + void* v_out_data = static_cast(kv_out_data) + v_offset_bytes_; + + atb::Tensor t_key = + ascend::ToAtbTensor(key_shape_, acl_dt_, key_data, kv_bytes); + + atb::Tensor t_value = + ascend::ToAtbTensor(key_shape_, acl_dt_, value_data, kv_bytes); + + atb::Tensor t_kv_k = + ascend::ToAtbTensor(kv_shape_, acl_dt_, kv_out_data, cache_bytes); + + atb::Tensor t_kv_v = + ascend::ToAtbTensor(kv_shape_, acl_dt_, v_out_data, cache_bytes); + + // Always int32 — the caller's `operator()` has already cast to int32. + atb::Tensor t_slot = + ascend::ToAtbTensor(slot_shape_, ACL_INT32, slot32_data, slot32_bytes_); + + atb::VariantPack vp; + vp.inTensors = {t_key, t_value, t_kv_k, t_kv_v, t_slot}; + vp.outTensors = {t_kv_k, t_kv_v}; + + return vp; + } + + atb::Operation* op_ = nullptr; + + std::vector kv_shape_; + + std::vector key_shape_; + + std::vector slot_shape_; + + aclDataType acl_dt_ = ACL_DT_UNDEFINED; + + size_t v_offset_bytes_ = 0; + + uint64_t elem_size_ = 0; + + // Pre-allocated int32 device buffer for `slot_mapping`. + void* slot32_buf_ = nullptr; + + size_t slot32_bytes_ = 0; + + // True if the caller already provides int32 `slot_mapping`. + bool slot_is_int32_ = false; + + // Cached aclnnCast descriptors (int64 slot_mapping → int32 buffer). + mutable ascend::AclTensorCache slot_i64_cache_; + + mutable ascend::AclTensorCache slot_i32_cache_; + + mutable aclOpExecutor* cast_exec_ = nullptr; + + mutable uint64_t cast_ws_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_ATB_H_ diff --git a/src/ascend/reshape_and_cache/kernel_v2.h b/src/ascend/reshape_and_cache/kernel_v2.h new file mode 100644 index 00000000..524e0a8d --- /dev/null +++ b/src/ascend/reshape_and_cache/kernel_v2.h @@ -0,0 +1,123 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_V2_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_V2_H_ + +// WARNING: This implementation is experimental and has strict hardware limits. +// +// Limitations: +// 1. Requires CANN 8.5.1+ (`aclnnScatterPaKvCache` API). +// 2. Only supported on Atlas A5 hardware (SoC 260). NOT supported on +// A2 (Ascend 910B, SoC 220-225) or A3 (SoC 250-255). +// 3. Not yet validated in production workloads. 
+// +// On unsupported hardware this file compiles to nothing (guarded by +// `__has_include`). Use `implementation_index=0` (the default +// `aclnnInplaceIndexCopy` path) for general-purpose deployment. + +#if __has_include("aclnnop/aclnn_scatter_pa_kv_cache.h") + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_scatter_pa_kv_cache.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/reshape_and_cache.h" +#include "operator.h" + +namespace infini::ops { + +// Fused KV cache scatter via `aclnnScatterPaKvCache` (implementation index 1). +// +// Handles both K and V scatter in a single CANN kernel launch, replacing two +// separate `aclnnInplaceIndexCopy` calls (index 0). The fused API is +// purpose-built for paged KV cache and avoids the internal decomposition to +// `ScatterElementsV2`. +// +// Requirements: +// - CANN 8.5.1+ (`aclnnop/aclnn_scatter_pa_kv_cache.h`). +// - Atlas A5 hardware (SoC 260). The API is NOT supported on A2 (910B, +// SoC 220-225) or A3 (SoC 250-255). +// +// Select via `implementation_index=1` in Python: +// infini.ops.reshape_and_cache(..., implementation_index=1, stream=s) +template <> +class Operator + : public ReshapeAndCache { + public: + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out), + key_cache_(key), + value_cache_(value), + slot_cache_(slot_mapping) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + + aclDataType acl_dt = ascend::ToAclDtype(key.dtype()); + + // 4D K cache view: [num_blocks, block_size, num_kv_heads, head_size]. + // K cache is kv_cache_out[0], starting at offset 0. + kv_k_cache_ = ascend::AclTensorCache({num_blocks, bs, nkv, hs}, acl_dt, + kv_cache_out.data()); + + // V cache is kv_cache_out[1], offset by stride(0) elements. + v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + kv_v_cache_ = ascend::AclTensorCache( + {num_blocks, bs, nkv, hs}, acl_dt, + static_cast(kv_cache_out.data()) + v_offset_bytes_); + } + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + void* kv_k_data = kv_cache_out.data(); + void* kv_v_data = static_cast(kv_cache_out.data()) + v_offset_bytes_; + + auto t_key = key_cache_.get(const_cast(key.data())); + auto t_value = value_cache_.get(const_cast(value.data())); + auto t_slot = slot_cache_.get(const_cast(slot_mapping.data())); + auto t_kv_k = kv_k_cache_.get(kv_k_data); + auto t_kv_v = kv_v_cache_.get(kv_v_data); + + // Single fused scatter for both K and V caches. 
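+    // Semantics (illustrative): token i's K and V vectors go to flat slot
+    // slot_mapping[i] of the paged caches, in one kernel launch instead of
+    // the two index-copy launches used by implementation 0.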
+ uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnScatterPaKvCacheGetWorkspaceSize( + t_key, t_kv_k, t_slot, t_value, t_kv_v, + /*compressLensOptional=*/nullptr, + /*compressSeqOffsetOptional=*/nullptr, + /*seqLensOptional=*/nullptr, + /*cacheModeOptional=*/nullptr, + /*scatterModeOptional=*/nullptr, + /*stridesOptional=*/nullptr, + /*offsetsOptional=*/nullptr, &ws, &exec); + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws); + aclnnScatterPaKvCache(arena.buf, ws, exec, stream); + } + + private: + mutable ascend::AclTensorCache kv_k_cache_; + + mutable ascend::AclTensorCache kv_v_cache_; + + mutable ascend::AclTensorCache key_cache_; + + mutable ascend::AclTensorCache value_cache_; + + mutable ascend::AclTensorCache slot_cache_; + + size_t v_offset_bytes_ = 0; +}; + +} // namespace infini::ops + +#endif // __has_include("aclnnop/aclnn_scatter_pa_kv_cache.h") + +#endif // INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_V2_H_ diff --git a/src/ascend/topk_topp_sampling/kernel_atb.h b/src/ascend/topk_topp_sampling/kernel_atb.h new file mode 100644 index 00000000..6ad02a85 --- /dev/null +++ b/src/ascend/topk_topp_sampling/kernel_atb.h @@ -0,0 +1,185 @@ +#ifndef INFINI_OPS_ASCEND_TOPK_TOPP_SAMPLING_KERNEL_ATB_H_ +#define INFINI_OPS_ASCEND_TOPK_TOPP_SAMPLING_KERNEL_ATB_H_ + +#ifdef INFINI_HAS_ATB + +#include +#include + +#include "acl/acl.h" +#include "ascend/atb_common_.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "atb/context.h" +#include "atb/infer_op_params.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "base/topk_topp_sampling.h" +#include "operator.h" + +namespace infini::ops { + +// ATB-based fused top-k/top-p sampling via `atb::infer::TopkToppSamplingParam` +// (implementation index 0). +// +// Uses `BATCH_TOPK_EXPONENTIAL_SAMPLING` which matches vLLM's Gumbel-trick +// sampling semantics (`q.exponential_()` -> `probs.div(q).argmax()`). +// Exponential sampling does not require `randSeeds`, making the ATB operation +// parameter-stable and cacheable across calls with the same `topk`. +// +// ATB constraint: input probabilities must be float16 or bfloat16. +// The caller must cast float32 probs to float16 before invoking this kernel. 
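+// (Illustrative caller-side fix in PyTorch: `probs = probs.to(torch.float16)`
+// before invoking the op; no cast is performed here.)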
+// +// ATB tensor layout (from `atb_ops_info.ini`): +// in0 (probs) : [B, V] float16/bf16 +// in1 (seeds) : [B, 1] int32 — placeholder for exponential mode +// in2 (unused) : [B, 1] float16/bf16 — placeholder +// in3 (exp_random) : [B, V] float16/bf16 — placeholder +// out0 (indices) : [B, 1] int32 +// out1 (out_probs) : [B, 1] float16/bf16 — placeholder +template <> +class Operator + : public TopkToppSampling { + public: + Operator(const Tensor probs, int64_t topk, double topp, Tensor out) + : TopkToppSampling(probs, topk, topp, out) { + atb::infer::TopkToppSamplingParam param; + param.topkToppSamplingType = + atb::infer::TopkToppSamplingParam::BATCH_TOPK_EXPONENTIAL_SAMPLING; + param.topk = static_cast(topk_); + + atb::Status s = atb::CreateOperation(param, &op_); + + if (s != atb::NO_ERROR) { + fprintf(stderr, + "[TopkToppSampling] atb::CreateOperation failed (status=%d)\n", + static_cast(s)); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + if (op_) atb::DestroyOperation(op_); + } + + Operator(const Operator&) = delete; + + Operator& operator=(const Operator&) = delete; + + void operator()(const Tensor probs, int64_t topk, double topp, + Tensor out) const override { + if (!op_) return; + + auto stream = static_cast(stream_); + atb::Context* ctx = ascend::GetAtbContext(stream); + + int64_t B = batch_size_; + int64_t V = vocab_size_; + aclDataType probs_dt = ascend::ToAclDtype(probs.dtype()); + uint64_t probs_elem = 2; // Float16 or bf16 — both 2 bytes. + void* probs_ptr = const_cast(probs.data()); + void* out_ptr = out.data(); + + // Auxiliary buffers: seeds [B,1] int32 + in2 [B,1] fp16 + out1 [B,1] fp16. + // Also allocate in3 [B,V] fp16 as a scratch buffer. + uint64_t seeds_bytes = static_cast(B) * 4; + uint64_t in2_bytes = static_cast(B) * probs_elem; + uint64_t out1_bytes = static_cast(B) * probs_elem; + uint64_t in3_bytes = static_cast(B * V) * probs_elem; + uint64_t aux_bytes = seeds_bytes + in2_bytes + out1_bytes + in3_bytes; + + // Build tensors using raw descriptors. + auto mk2d = [](aclDataType dt, int64_t d0, int64_t d1, void* data, + uint64_t elem_sz) -> atb::Tensor { + atb::Tensor t; + t.desc.dtype = dt; + t.desc.format = ACL_FORMAT_ND; + t.desc.shape.dimNum = 2; + t.desc.shape.dims[0] = d0; + t.desc.shape.dims[1] = d1; + t.deviceData = data; + t.dataSize = static_cast(d0 * d1) * elem_sz; + + return t; + }; + + // Ensure workspace covers both auxiliary buffers and ATB's own workspace. + auto& arena = ascend::GetWorkspacePool().Ensure(stream, aux_bytes); + auto* base = static_cast(arena.buf); + void* seeds_ptr = base; + void* in2_ptr = base + seeds_bytes; + void* in3_ptr = base + seeds_bytes + in2_bytes; + void* out1_ptr = base + seeds_bytes + in2_bytes + in3_bytes; + + atb::Tensor t_probs = mk2d(probs_dt, B, V, probs_ptr, probs_elem); + atb::Tensor t_seeds = mk2d(ACL_INT32, B, 1, seeds_ptr, 4); + atb::Tensor t_in2 = mk2d(probs_dt, B, 1, in2_ptr, probs_elem); + atb::Tensor t_in3 = mk2d(probs_dt, B, V, in3_ptr, probs_elem); + atb::Tensor t_out0 = mk2d(ACL_INT32, B, 1, out_ptr, 4); + atb::Tensor t_out1 = mk2d(probs_dt, B, 1, out1_ptr, probs_elem); + + atb::VariantPack vp; + vp.inTensors = {t_probs, t_seeds, t_in2, t_in3}; + vp.outTensors = {t_out0, t_out1}; + + uint64_t ws_size = 0; + atb::Status s = op_->Setup(vp, ws_size, ctx); + + if (s != atb::NO_ERROR) { + fprintf(stderr, "[TopkToppSampling] Setup failed (status=%d)\n", + static_cast(s)); + + return; + } + + // ATB workspace (separate from auxiliary buffers). 
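+    // Arena layout used below (illustrative): [seeds | in2 | in3 | out1 | ATB ws],
+    // with the ATB workspace starting at base + aux_bytes.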
+ uint8_t* ws_ptr = nullptr; + + if (ws_size > 0) { + auto& ws_arena = + ascend::GetWorkspacePool().Ensure(stream, aux_bytes + ws_size); + + // Re-derive auxiliary pointers from the (possibly reallocated) arena. + base = static_cast(ws_arena.buf); + ws_ptr = base + aux_bytes; + + // Update tensor data pointers in case the arena was reallocated. + seeds_ptr = base; + in2_ptr = base + seeds_bytes; + in3_ptr = base + seeds_bytes + in2_bytes; + out1_ptr = base + seeds_bytes + in2_bytes + in3_bytes; + + vp.inTensors[1].deviceData = seeds_ptr; + vp.inTensors[2].deviceData = in2_ptr; + vp.inTensors[3].deviceData = in3_ptr; + vp.outTensors[1].deviceData = out1_ptr; + + // Re-run Setup with updated pointers. + s = op_->Setup(vp, ws_size, ctx); + + if (s != atb::NO_ERROR) { + fprintf(stderr, "[TopkToppSampling] Setup (retry) failed (status=%d)\n", + static_cast(s)); + + return; + } + } + + s = op_->Execute(vp, ws_ptr, ws_size, ctx); + + if (s != atb::NO_ERROR) { + fprintf(stderr, "[TopkToppSampling] Execute failed (status=%d)\n", + static_cast(s)); + } + } + + private: + atb::Operation* op_ = nullptr; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_TOPK_TOPP_SAMPLING_KERNEL_ATB_H_ diff --git a/src/base/flash_attention.h b/src/base/flash_attention.h index e5952b51..678a89fc 100644 --- a/src/base/flash_attention.h +++ b/src/base/flash_attention.h @@ -9,31 +9,32 @@ namespace infini::ops { -// Fused multi-head / grouped-query attention. -// -// Interface follows vLLM v1 `AttentionImpl.forward()`: -// `vllm.v1.attention.backends.abstract.AttentionImpl` -// -// Layout: `query` / `key` / `value` are `[T, N, D]` (TND). -// Prefill uses `cu_seqlens_q` / `cu_seqlens_kv` for variable-length packing. -// Decode uses `block_table` for paged KV cache lookup. class FlashAttention : public Operator { public: + // `window_left` / `window_right` is the native InfiniOps pair-form + // window (left-context / right-context tokens, `-1` = disabled). + // `sliding_window` is a vLLM-style single-parameter shortcut: when + // set, it is normalized to `(sliding_window - 1, 0)` — i.e. causal + // sliding over the most recent `sliding_window` tokens. When both + // forms are supplied the normalized values must agree. Callers may + // use whichever form is more natural; the kernel only sees the + // resolved pair. 
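+  // Example (illustrative): sliding_window = 256 resolves to (255, 0); a
+  // caller may also pass window_left = 255 / window_right = 0 explicitly, but
+  // any other non-default pair combined with sliding_window trips the
+  // consistency asserts below.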
FlashAttention(const Tensor query, const Tensor key, const Tensor value, std::optional cu_seqlens_q, std::optional cu_seqlens_kv, std::optional block_table, int64_t num_heads, int64_t num_kv_heads, int64_t head_size, double scale, bool causal, int64_t window_left, int64_t window_right, - int64_t block_size, Tensor output) + int64_t block_size, Tensor output, + std::optional sliding_window = std::nullopt) : num_tokens_{query.size(0)}, num_heads_{num_heads}, num_kv_heads_{num_kv_heads}, head_size_{head_size}, scale_{scale}, causal_{causal}, - window_left_{window_left}, - window_right_{window_right}, + window_left_{resolveWindowLeft(window_left, sliding_window)}, + window_right_{resolveWindowRight(window_right, sliding_window)}, block_size_{block_size}, dtype_{query.dtype()}, query_shape_{query.shape()}, @@ -48,21 +49,45 @@ class FlashAttention : public Operator { has_cu_seqlens_kv_{cu_seqlens_kv.has_value()}, has_block_table_{block_table.has_value()} { assert(num_heads % num_kv_heads == 0 && - "`FlashAttention` requires num_heads divisible by num_kv_heads"); + "`FlashAttention` requires `num_heads` divisible by `num_kv_heads`"); assert(query.ndim() == 3 && "`FlashAttention` requires query to be 3D [T, N, D]"); } - virtual void operator()(const Tensor query, const Tensor key, - const Tensor value, - std::optional cu_seqlens_q, - std::optional cu_seqlens_kv, - std::optional block_table, int64_t num_heads, - int64_t num_kv_heads, int64_t head_size, double scale, - bool causal, int64_t window_left, - int64_t window_right, int64_t block_size, - Tensor output) const = 0; + virtual void operator()( + const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, bool causal, + int64_t window_left, int64_t window_right, int64_t block_size, + Tensor output, + std::optional sliding_window = std::nullopt) const = 0; + + private: + // Normalize the window representation. If both the explicit pair and + // `sliding_window` are supplied, assert the pair matches the derived + // `(sliding_window - 1, 0)` causal-sliding window. + static int64_t resolveWindowLeft(int64_t window_left, + std::optional sliding_window) { + if (!sliding_window.has_value()) return window_left; + int64_t derived = sliding_window.value() - 1; + assert( + (window_left == -1 || window_left == derived) && + "`FlashAttention`: `window_left` inconsistent with `sliding_window`"); + return derived; + } + + static int64_t resolveWindowRight(int64_t window_right, + std::optional sliding_window) { + if (!sliding_window.has_value()) return window_right; + assert( + (window_right == -1 || window_right == 0) && + "`FlashAttention`: `window_right` inconsistent with `sliding_window` " + "(vLLM sliding_window implies right=0)"); + return 0; + } + public: protected: Tensor::Size num_tokens_{0}; diff --git a/src/base/paged_attention.h b/src/base/paged_attention.h new file mode 100644 index 00000000..aa98e826 --- /dev/null +++ b/src/base/paged_attention.h @@ -0,0 +1,129 @@ +#ifndef INFINI_OPS_BASE_PAGED_ATTENTION_H_ +#define INFINI_OPS_BASE_PAGED_ATTENTION_H_ + +#include +#include +#include + +#include "operator.h" + +namespace infini::ops { + +// Paged decode attention operator. +// +// Performs multi-head attention over paged KV caches for decode (single-token +// queries per sequence). 
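+// For each sequence i, the single query token attends over the first
+// seq_lens[i] cached tokens, located via block_table[i].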
+// +// Interface follows vLLM's paged attention convention: +// - vLLM CUDA: `torch.ops.vllm.paged_attention_v1` uses the same query +// shape [batch, num_heads, head_size] and seq_lens [batch] int32. +// KV cache differs (5D on CUDA for vectorization, 4D here). +// - vLLM-Ascend: `torch_npu._npu_paged_attention` wraps ATB +// `PagedAttentionParam` with default `inputLayout` (`TYPE_BSND`). +// - ATB `PagedAttentionParam`: `headNum`, `kvHeadNum`, `qkScale`, +// `maskType` (default NORM), `inputLayout` (default `TYPE_BSND`). +// +// Input layout (BSND with S=1 for decode): +// query : [batch, num_heads, head_size] +// key_cache : [num_blocks, block_size, num_kv_heads, head_size] +// value_cache : [num_blocks, block_size, num_kv_heads, head_size] +// seq_lens : [batch] int32 — total context length per sequence +// block_table : [batch, max_num_blocks_per_seq] int32 +// +// Output layout: +// output : [batch, num_heads, head_size] +// +// Optional host tensors: `seq_lens_host` and `block_table_host` are CPU +// mirrors of `seq_lens` and `block_table`. They exist because CANN's +// paged-attention APIs mandate CPU-resident metadata — aclnn declares +// `qSeqLens` as a CPU tensor in its signature, and ATB +// `PagedAttentionParam` reads `aclIntArray*` parameters from the +// `hostData` field at `aclnnRunner::Setup()` time. Without caller- +// provided host tensors, the kernel must synchronously D2H-copy both +// each call, which (a) blocks the stream and (b) prevents NPUGraph +// capture (sync copies are not capturable). When the caller already +// has CPU-pinned copies (e.g. vLLM's `optimistic_seq_lens_cpu` and +// `BlockTable.get_cpu_tensor()`), passing them through lets the kernel +// skip both D2H copies and be captured into a full NPUGraph. +class PagedAttention : public Operator { + public: + PagedAttention(const Tensor query, const Tensor key_cache, + const Tensor value_cache, const Tensor seq_lens, + const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output, + std::optional seq_lens_host = std::nullopt, + std::optional block_table_host = std::nullopt) + : batch_size_{query.size(0)}, + num_heads_{num_heads}, + num_kv_heads_{num_kv_heads}, + head_size_{head_size}, + scale_{scale}, + block_size_{block_size}, + dtype_{query.dtype()}, + query_shape_{query.shape()}, + key_cache_shape_{key_cache.shape()}, + value_cache_shape_{value_cache.shape()}, + seq_lens_shape_{seq_lens.shape()}, + block_table_shape_{block_table.shape()}, + output_shape_{output.shape()}, + has_seq_lens_host_{seq_lens_host.has_value()}, + has_block_table_host_{block_table_host.has_value()} { + assert( + num_heads % num_kv_heads == 0 && + "`PagedAttention` requires `num_heads` divisible by `num_kv_heads`."); + assert(query.ndim() == 3 && + "`PagedAttention` requires query to be 3D [batch, num_heads, " + "head_size]."); + assert(key_cache.ndim() == 4 && + "`PagedAttention` requires key_cache to be 4D [num_blocks, " + "block_size, num_kv_heads, head_size]."); + assert(seq_lens.ndim() == 1 && + "`PagedAttention` requires seq_lens to be 1D [batch]."); + assert(block_table.ndim() == 2 && + "`PagedAttention` requires block_table to be 2D [batch, " + "max_num_blocks]."); + } + + virtual void operator()( + const Tensor query, const Tensor key_cache, const Tensor value_cache, + const Tensor seq_lens, const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, int64_t block_size, + Tensor output, std::optional 
seq_lens_host = std::nullopt, + std::optional block_table_host = std::nullopt) const = 0; + + protected: + Tensor::Size batch_size_{0}; + + int64_t num_heads_{0}; + + int64_t num_kv_heads_{0}; + + int64_t head_size_{0}; + + double scale_{0.0}; + + int64_t block_size_{0}; + + const DataType dtype_; + + Tensor::Shape query_shape_; + + Tensor::Shape key_cache_shape_; + + Tensor::Shape value_cache_shape_; + + Tensor::Shape seq_lens_shape_; + + Tensor::Shape block_table_shape_; + + Tensor::Shape output_shape_; + + bool has_seq_lens_host_{false}; + + bool has_block_table_host_{false}; +}; + +} // namespace infini::ops + +#endif // INFINI_OPS_BASE_PAGED_ATTENTION_H_ diff --git a/src/base/reshape_and_cache.h b/src/base/reshape_and_cache.h index 4bbd5db8..5d0adfad 100644 --- a/src/base/reshape_and_cache.h +++ b/src/base/reshape_and_cache.h @@ -8,15 +8,6 @@ namespace infini::ops { -// Scatter `key` / `value` tokens into a paged KV cache. -// -// Interface follows vLLM's `reshape_and_cache` kernel: -// `vllm._custom_ops.reshape_and_cache_flash` -// -// `kv_cache` layout: `[2, num_blocks, block_size, num_kv_heads, head_size]`. -// `slot_mapping`: 1D `[num_tokens]`, each entry is the linear slot index -// into the cache. Padding tokens must be filtered by the caller (no -// negative indices). class ReshapeAndCache : public Operator { public: ReshapeAndCache(const Tensor key, const Tensor value, const Tensor kv_cache, diff --git a/src/base/topk_topp_sampling.h b/src/base/topk_topp_sampling.h new file mode 100644 index 00000000..392b35e8 --- /dev/null +++ b/src/base/topk_topp_sampling.h @@ -0,0 +1,62 @@ +#ifndef INFINI_OPS_BASE_TOPK_TOPP_SAMPLING_H_ +#define INFINI_OPS_BASE_TOPK_TOPP_SAMPLING_H_ + +#include +#include + +#include "operator.h" + +namespace infini::ops { + +// Top-k/top-p sampling operator. +// +// Performs fused top-k and top-p filtering followed by random sampling +// from the filtered probability distribution. +// +// Input layout: +// probs : [batch_size, vocab_size] float16/float32 — probability distribution +// (softmax output, must sum to 1 along dim=-1). +// +// Parameters: +// topk : int64_t — number of highest-probability tokens to keep (0 = +// disabled). topp : double — cumulative probability threshold (0.0 = +// disabled). +// +// Output layout: +// out : [batch_size] int32 — sampled token indices. 
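+//
+// Reference semantics (a sketch; the fused ATB kernel need not materialize
+// these steps): keep the `topk` largest probabilities per row, then keep the
+// shortest descending-sorted prefix of those whose cumulative probability
+// reaches `topp`, renormalize the surviving mass, and draw one token index
+// per row.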
+class TopkToppSampling : public Operator { + public: + TopkToppSampling(const Tensor probs, int64_t topk, double topp, Tensor out) + : batch_size_{probs.size(0)}, + vocab_size_{probs.size(1)}, + topk_{topk}, + topp_{topp}, + dtype_{probs.dtype()} { + assert(probs.ndim() == 2 && + "`TopkToppSampling` requires `probs` to be 2D [batch_size, " + "vocab_size]."); + assert(out.ndim() == 1 && + "`TopkToppSampling` requires `out` to be 1D [batch_size]."); + assert(out.size(0) == probs.size(0) && + "`TopkToppSampling` requires `out` and `probs` to have the same " + "batch_size."); + } + + virtual void operator()(const Tensor probs, int64_t topk, double topp, + Tensor out) const = 0; + + protected: + Tensor::Size batch_size_{0}; + + Tensor::Size vocab_size_{0}; + + int64_t topk_{0}; + + double topp_{0.0}; + + const DataType dtype_; +}; + +} // namespace infini::ops + +#endif // INFINI_OPS_BASE_TOPK_TOPP_SAMPLING_H_ diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py new file mode 100644 index 00000000..1e0984ec --- /dev/null +++ b/tests/test_flash_attention.py @@ -0,0 +1,597 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ( + (32, 32, 128), # MHA + (32, 8, 128), # GQA (4x) + (16, 4, 64), # GQA (4x), smaller + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_prefill_single( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Single sequence prefill (no block table).""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_tokens = 16 + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention( + q, + k, + v, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ((32, 8, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_prefill_multi( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Multi-sequence prefill with cu_seqlens.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + seq_lens = [8, 12, 4] + num_tokens = sum(seq_lens) + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, 
head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + cu_seqlens_q = torch.tensor( + [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))], + dtype=torch.int64, + device=device, + ) + cu_seqlens_kv = cu_seqlens_q.clone() + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention_multi( + q, + k, + v, + seq_lens, + seq_lens, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ( + (32, 8, 128, 128), + (16, 4, 64, 128), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_decode( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Decode phase: single token per request with paged KV cache.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 3 + kv_len = 16 # Total KV length per request. + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + # Paged KV cache: vLLM standard layout [num_blocks, block_size, KV_N, D]. + kv_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: request i uses blocks [i*num_blocks_per_req, ...]. 
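+    # (Equivalently: torch.arange(num_blocks, dtype=torch.int32,
+    # device=device).reshape(num_reqs, num_blocks_per_req).)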
+ block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + cu_seqlens_q = torch.arange(0, num_reqs + 1, dtype=torch.int64, device=device) + cu_seqlens_kv = torch.tensor( + [i * kv_len for i in range(num_reqs + 1)], dtype=torch.int64, device=device + ) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + block_size, + o, + ), + lambda q, k, v, o: _ref_flash_attention_paged( + q, + k, + block_table, + cu_seqlens_q, + cu_seqlens_kv, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal=True, + ), + (query, kv_cache, kv_cache, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ((32, 8, 128, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ((torch.float16, 1e-3, 1e-3),), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_decode_cpu_cuseqlens( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Decode with CPU cu_seqlens_kv — exercises the D2H-free code path.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 3 + kv_len = 16 + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + kv_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + cu_seqlens_q = torch.arange(0, num_reqs + 1, dtype=torch.int64, device=device) + + # CPU cu_seqlens_kv — exercises `detail::extractSeqLengths` host path + # (direct pointer read, no D2H copy). 
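+    # `cu_seqlens_q` above stays an NPU tensor, so this call exercises the
+    # device path and the host path in the same launch.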
+ cu_seqlens_kv = torch.tensor( + [i * kv_len for i in range(num_reqs + 1)], dtype=torch.int64 + ) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + block_size, + o, + ), + lambda q, k, v, o: _ref_flash_attention_paged( + q, + k, + block_table, + cu_seqlens_q, + cu_seqlens_kv, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal=True, + ), + (query, kv_cache, kv_cache, output), + {}, + rtol=rtol, + atol=atol, + ) + + +def _flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, +): + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, + stream=get_stream(query.device), + ) + + return output + + +def _ref_flash_attention( + query, key, value, num_heads, num_kv_heads, head_size, scale, causal=True +): + """PyTorch SDPA reference for single-sequence prefill.""" + # [T, N, D] -> [N, T, D] + q = query.transpose(0, 1).float() + k = key.transpose(0, 1).float() + v = value.transpose(0, 1).float() + + # GQA: expand K/V to match num_heads. + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k = k.repeat_interleave(ratio, dim=0) + v = v.repeat_interleave(ratio, dim=0) + + # [N, T, D] -> [1, N, T, D] for scaled_dot_product_attention. + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, scale=scale, is_causal=causal + ) + + # [1, N, T, D] -> [T, N, D] -> original dtype. + return out.squeeze(0).transpose(0, 1).to(query.dtype) + + +def _ref_flash_attention_multi( + query, + key, + value, + seq_lens_q, + seq_lens_kv, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, +): + """PyTorch SDPA reference for multi-sequence prefill.""" + outputs = [] + offset = 0 + for sq, sk in zip(seq_lens_q, seq_lens_kv): + q = query[offset : offset + sq] + k = key[offset : offset + sq] + v = value[offset : offset + sq] + out = _ref_flash_attention( + q, k, v, num_heads, num_kv_heads, head_size, scale, causal + ) + outputs.append(out) + offset += sq + + return torch.cat(outputs, dim=0) + + +def _ref_flash_attention_paged( + query, + kv_cache_arg, + block_table, + cu_seqlens_q, + cu_seqlens_kv, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal=True, +): + """PyTorch SDPA reference for decode with paged KV cache.""" + cu_kv = cu_seqlens_kv.cpu() + bt = block_table.cpu() + cache = kv_cache_arg.cpu() + q_cpu = query.cpu() + num_reqs = bt.size(0) + outputs = [] + + for i in range(num_reqs): + q = q_cpu[i : i + 1] # [1, N, D] + kv_len = int(cu_kv[i + 1] - cu_kv[i]) + + # Gather KV from paged cache. + # cache: [num_blocks, KV_N, block_size, D] + blocks = bt[i] + k_pages = [] + v_pages = [] + remaining = kv_len + + for b in blocks: + if remaining <= 0: + break + + take = min(remaining, block_size) + # cache layout: [num_blocks, block_size, KV_N, D] + # Slice [take, KV_N, D], transpose to [KV_N, take, D] for cat. 
+ k_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1)) + v_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1)) + remaining -= take + k = torch.cat(k_pages, dim=1) # [KV_N, kv_len, D] + v = torch.cat(v_pages, dim=1) + + # Decode: Q_S=1 attends to all past KV positions; causal masking is + # not applicable here (it would mask everything beyond position 0). + out = _ref_flash_attention( + q, # [1, N, D] - already TND format + k.transpose(0, 1), # [KV_N, kv_len, D] -> [kv_len, KV_N, D] + v.transpose(0, 1), + num_heads, + num_kv_heads, + head_size, + scale, + causal=False, + ) + outputs.append(out) + + return torch.cat(outputs, dim=0).to(query.device) + + +@pytest.mark.parametrize("sliding_window", (4, 16)) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_sliding_window_equivalence(sliding_window, device): + """The vLLM-style `sliding_window=N` entry must produce the same output + as the native `window_left=N-1, window_right=0` pair. + """ + if not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_tokens = 32 + num_heads = 8 + num_kv_heads = 8 + head_size = 64 + scale = 1.0 / head_size**0.5 + dtype = torch.float16 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + + cu_seqlens_q = torch.tensor([0, num_tokens], dtype=torch.int64, device=device) + cu_seqlens_kv = torch.tensor([0, num_tokens], dtype=torch.int64, device=device) + + # Pair-form call. + out_pair = torch.empty_like(query) + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + sliding_window - 1, + 0, + 0, + out_pair, + stream=get_stream(query.device), + ) + + # vLLM-style single-parameter call. + out_sw = torch.empty_like(query) + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + -1, + 0, + out_sw, + sliding_window=sliding_window, + stream=get_stream(query.device), + ) + + assert torch.equal(out_pair, out_sw), ( + f"Max diff: {(out_pair.float() - out_sw.float()).abs().max().item()}" + ) diff --git a/tests/test_paged_attention.py b/tests/test_paged_attention.py new file mode 100644 index 00000000..c2258ffa --- /dev/null +++ b/tests/test_paged_attention.py @@ -0,0 +1,554 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_stream, randn_strided + + +def _atb_pa_unsupported_reason(): + """Return a reason string if ATB PagedAttention can't run here, else `""`. + + Uses a narrow SoC-name check rather than a try/except on the op under + test — the latter silently masks real regressions by converting any + runtime failure in `paged_attention` into a clean skip. 
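+
+    Called twice while building `_skip_no_atb_pa` below (once for the
+    condition, once for the reason); the helper is side-effect free, so the
+    double evaluation is harmless.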
+ """ + if not (hasattr(torch, "npu") and torch.npu.is_available()): + return "NPU not available" + + if not infini.ops.PagedAttention.active_implementation_indices("ascend"): + return "ATB PagedAttention implementation not registered for Ascend" + + return "" + + +_skip_no_atb_pa = pytest.mark.skipif( + bool(_atb_pa_unsupported_reason()), + reason=_atb_pa_unsupported_reason() or "ATB PagedAttention unsupported", +) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ( + (32, 8, 128, 128), + (16, 4, 64, 128), + (32, 32, 128, 128), # MHA + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_basic( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Basic paged decode attention with contiguous block assignments.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 4 + kv_len = 16 + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: request i uses blocks [i*num_blocks_per_req, ...]. + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + # Context lengths (total KV length per request). 
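+    # All requests share the same length here; per-request lengths are
+    # covered by `test_paged_attention_variable_seq_lens` below.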
+ seq_lens = torch.full((num_reqs,), kv_len, dtype=torch.int32, device=device) + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ((32, 8, 128, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_variable_seq_lens( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Paged decode attention where each request has a different KV length.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + kv_lens = [8, 32, 16, 128] + num_reqs = len(kv_lens) + max_blocks_per_req = max((kv + block_size - 1) // block_size for kv in kv_lens) + num_blocks = sum((kv + block_size - 1) // block_size for kv in kv_lens) + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: assign blocks sequentially. 
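+    # Rows with fewer than `max_blocks_per_req` blocks keep trailing zeros;
+    # those entries are never read because `seq_lens[i]` bounds the gather.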
+ block_table = torch.zeros( + (num_reqs, max_blocks_per_req), dtype=torch.int32, device=device + ) + block_idx = 0 + + for i in range(num_reqs): + n_blocks = (kv_lens[i] + block_size - 1) // block_size + + for j in range(n_blocks): + block_table[i, j] = block_idx + block_idx += 1 + + seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device) + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ((32, 8, 128, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ((torch.float16, 1e-3, 1e-3),), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_single_request( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Single request decode (batch_size=1).""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 1 + kv_len = 64 + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + block_table = torch.arange( + num_blocks_per_req, dtype=torch.int32, device=device + ).unsqueeze(0) + + seq_lens = torch.tensor([kv_len], dtype=torch.int32, device=device) + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ((32, 8, 128, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ((torch.float16, 1e-3, 1e-3),), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_host_tensors( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Paged decode with caller-provided host tensors (D2H-free path).""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 4 + kv_len = 16 + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, 
head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + seq_lens = torch.full((num_reqs,), kv_len, dtype=torch.int32, device=device) + + # CPU copies for the D2H-free path. + seq_lens_cpu = seq_lens.cpu().contiguous() + block_table_cpu = block_table.cpu().contiguous() + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention_with_host( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + seq_lens_cpu, + block_table_cpu, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +def _paged_attention_with_host( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, + seq_lens_host, + block_table_host, +): + """Call paged attention with caller-provided host tensors.""" + infini.ops.paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, + seq_lens_host=seq_lens_host, + block_table_host=block_table_host, + stream=get_stream(query.device), + ) + + return output + + +def _paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, +): + infini.ops.paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, + stream=get_stream(query.device), + ) + + return output + + +def _ref_paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, +): + """PyTorch SDPA reference for paged decode attention.""" + sl = seq_lens.cpu() + bt = block_table.cpu() + kc = key_cache.cpu().float() + vc = value_cache.cpu().float() + q_cpu = query.cpu().float() + num_reqs = bt.size(0) + outputs = [] + + for i in range(num_reqs): + q = q_cpu[i : i + 1] # [1, N, D] + kv_len = int(sl[i].item()) + + # Gather K and V from paged cache. + # Cache layout: [num_blocks, block_size, Nkv, D]. + blocks = bt[i] + k_pages = [] + v_pages = [] + remaining = kv_len + + for b in blocks: + if remaining <= 0: + break + + take = min(remaining, block_size) + k_pages.append(kc[int(b.item()), :take, :, :]) + v_pages.append(vc[int(b.item()), :take, :, :]) + remaining -= take + + # [kv_len, Nkv, D] + k = torch.cat(k_pages, dim=0) + v = torch.cat(v_pages, dim=0) + + # SDPA reference with GQA expansion. 
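+        # `repeat_interleave` below gives query heads [h*ratio, (h+1)*ratio)
+        # the same KV head h, i.e. the standard GQA grouping.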
+ # q: [1, N, D] -> [N, 1, D] + q_t = q.transpose(0, 1) + # k, v: [kv_len, Nkv, D] -> [Nkv, kv_len, D] + k_t = k.transpose(0, 1) + v_t = v.transpose(0, 1) + + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k_t = k_t.repeat_interleave(ratio, dim=0) + v_t = v_t.repeat_interleave(ratio, dim=0) + + # [N, 1, D] and [N, kv_len, D] -> [1, N, 1, D] and [1, N, kv_len, D] + q_4d = q_t.unsqueeze(0) + k_4d = k_t.unsqueeze(0) + v_4d = v_t.unsqueeze(0) + + # Decode: query attends to all past KV (no causal mask). + out = torch.nn.functional.scaled_dot_product_attention( + q_4d, + k_4d, + v_4d, + scale=scale, + is_causal=False, + ) + + # [1, N, 1, D] -> [1, N, D] + outputs.append(out.squeeze(0).transpose(0, 1).squeeze(0).unsqueeze(0)) + + return torch.cat(outputs, dim=0).to(query.dtype).to(query.device) diff --git a/tests/test_reshape_and_cache.py b/tests/test_reshape_and_cache.py new file mode 100644 index 00000000..4f69501f --- /dev/null +++ b/tests/test_reshape_and_cache.py @@ -0,0 +1,273 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_stream, randn_strided + +# ReshapeAndCache only works on NPU (aclrtMemcpy-based), so tests only +# parametrize on float16/bfloat16 and use explicit device parametrization. + +# `aclnnScatterPaKvCache` (index 1) requires Atlas A5 (SoC 260). It compiles +# on 910B (CANN 8.5.1 headers present) but produces wrong results at runtime. +_SKIP_INDEX_1 = pytest.mark.skip( + reason="`aclnnScatterPaKvCache` (index 1) requires Atlas A5; " + "not supported on Ascend 910B" +) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (1, 8, 128, 4, 16), + (4, 8, 128, 4, 16), + (8, 4, 64, 8, 32), + (16, 2, 128, 8, 64), + ), +) +@pytest.mark.parametrize( + "implementation_index", + (0, pytest.param(1, marks=_SKIP_INDEX_1), 2), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_contiguous( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + implementation_index, + dtype, + rtol, + atol, + device, +): + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + active_indices = infini.ops.ReshapeAndCache.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + # Layout: [2, num_blocks, block_size, num_kv_heads, head_size] + # Index 0 = key cache, index 1 = value cache. + kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + # Contiguous slot mapping: token i -> slot i. 
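+    # Slot s addresses block s // block_size at offset s % block_size (see
+    # `_ref_reshape_and_cache` at the bottom of this file).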
+ slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + return Payload( + lambda *args, **kwargs: _reshape_and_cache( + *args, **kwargs, implementation_index=implementation_index + ), + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (4, 8, 128, 4, 16), + (8, 4, 64, 8, 32), + ), +) +@pytest.mark.parametrize( + "implementation_index", + (0, pytest.param(1, marks=_SKIP_INDEX_1), 2), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_noncontiguous_slots( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + implementation_index, + dtype, + rtol, + atol, + device, +): + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + active_indices = infini.ops.ReshapeAndCache.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + # Non-contiguous slots: skip every other slot. + slot_mapping = torch.tensor( + [i * 2 for i in range(num_tokens)], dtype=torch.int64, device=device + ) + + return Payload( + lambda *args, **kwargs: _reshape_and_cache( + *args, **kwargs, implementation_index=implementation_index + ), + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (8, 8, 128, 4, 16), + (16, 4, 64, 8, 32), + ), +) +@pytest.mark.parametrize( + "implementation_index", + (0, pytest.param(1, marks=_SKIP_INDEX_1), 2), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_padding_slots( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + implementation_index, + dtype, + rtol, + atol, + device, +): + """Graph-padded decode: slots with `-1` must be skipped, not written. + + `aclnnInplaceIndexCopy` silently treats `slot=-1` as "last index" which + corrupts the last KV cache entry. The wrapper must filter `-1` slots + before calling the underlying op. 
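+
+    The reference below simply skips negative slots, so any stray write by
+    the op shows up as a mismatch against the zero-initialized cache.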
+ """ + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + active_indices = infini.ops.ReshapeAndCache.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + + # Every other token is a padding slot (`-1`); valid slots map to unique + # contiguous positions so a correct wrapper leaves the final entry of + # the last block untouched. + slot_values = [] + valid = 0 + + for i in range(num_tokens): + if i % 2 == 0: + slot_values.append(-1) + else: + slot_values.append(valid) + valid += 1 + + slot_mapping = torch.tensor(slot_values, dtype=torch.int64, device=device) + + return Payload( + lambda *args, **kwargs: _reshape_and_cache( + *args, **kwargs, implementation_index=implementation_index + ), + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +def _reshape_and_cache( + key, value, kv_cache, slot_mapping, kv_cache_out, implementation_index=0 +): + infini.ops.reshape_and_cache( + key, + value, + kv_cache, + slot_mapping, + kv_cache_out, + implementation_index=implementation_index, + stream=get_stream(key.device), + ) + + return kv_cache_out + + +def _ref_reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out): + kv_cache_out = kv_cache_out.clone() + slots = slot_mapping.cpu() + block_size = kv_cache_out.size(2) + + for i in range(key.size(0)): + slot = int(slots[i].item()) + + if slot < 0: + continue + + block_idx = slot // block_size + offset = slot % block_size + kv_cache_out[0, block_idx, offset, :, :] = key[i, :, :] + kv_cache_out[1, block_idx, offset, :, :] = value[i, :, :] + + return kv_cache_out