
feat(ascend): op-cache-attn group — ReshapeAndCache, FlashAttention, PagedAttention, TopkToppSampling #67

Open
zhangyue207 wants to merge 1 commit into master from feat/ascend-op-cache-attn

Conversation

@zhangyue207
Collaborator

Summary

Adds four KV-cache and attention Ascend operators (ReshapeAndCache,
FlashAttention, PagedAttention, TopkToppSampling), completing the Ascend
operator set needed for transformer decode.

Part 4 of 4 in the Ascend operator split. Parallel-reviewable with
op-simple and op-norm-rope (operator sets are disjoint).

Depends on: feat/ascend-framework-pr must merge first.

Operators

| op | impl | notes |
|---|---|---|
| ReshapeAndCache | 3 impls: `aclnnInplaceIndexCopy` (kernel.h); custom AscendC (kernel_v2.h); ATB `ReshapeAndCacheParam` (kernel_atb.h) | int64 `slot_mapping` handled via cached async `aclnnCast` (no D2H sync; NPUGraph-compatible) |
| FlashAttention | `aclnnFusedInferAttentionScoreV4` (prefill + paged decode) | supports both the `(window_left, window_right)` pair and a `std::optional<int64_t> sliding_window` entry (vLLM-style, additive) |
| PagedAttention | ATB `PagedAttentionParam` (impl=0) | optional CPU-pinned host tensors (`seq_lens_host` / `block_table_host`) enable NPUGraph capture by avoiding a per-layer D2H sync |
| TopkToppSampling | ATB `TopkToppSamplingParam` | |
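
For orientation, a minimal CPU reference of the ReshapeAndCache contract follows; it is an illustrative sketch only (flat-slot cache layout, assumed vLLM padding convention), not one of the three implementations above:

```cpp
// CPU reference for the ReshapeAndCache contract (illustrative only):
// token i's key/value row is scattered into the paged cache at the flat
// slot slot_mapping[i]. Backend kernels use their own cache layouts.
#include <cstdint>
#include <vector>

void ReshapeAndCacheRef(const std::vector<float>& key,
                        const std::vector<float>& value,
                        std::vector<float>& key_cache,
                        std::vector<float>& value_cache,
                        const std::vector<int32_t>& slot_mapping,
                        int64_t row /* num_kv_heads * head_size */) {
  for (size_t i = 0; i < slot_mapping.size(); ++i) {
    const int64_t slot = slot_mapping[i];
    if (slot < 0) continue;  // assumed vLLM padding convention (-1 = skip)
    for (int64_t d = 0; d < row; ++d) {
      key_cache[slot * row + d] = key[i * row + d];
      value_cache[slot * row + d] = value[i * row + d];
    }
  }
}
```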

vLLM API alignment

perf(reshape_and_cache): async int64→int32 slot_mapping

ATB ReshapeAndCacheParam requires an int32 slot_mapping. The previous
implementation handled int64 (PyTorch / vLLM's native dtype) via a D2H copy,
CPU cast, H2D copy, and aclrtSynchronizeStream, which stalled the stream and
made the int64 path impossible to capture in an NPUGraph. This is replaced
with a cached async aclnnCast conversion on-stream; performance matches the
int32 pass-through and the whole op is now graph-capturable.
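
A hedged sketch of the new path, assuming pre-built aclTensor views and a workspace/output buffer cached across calls (helper name and caching strategy are illustrative, not the PR's code):

```cpp
// Sketch of the on-stream int64 -> int32 slot_mapping cast. aclnnCast
// enqueues on the given stream, so there is no D2H round trip and no
// aclrtSynchronizeStream, which is what keeps the op NPUGraph-capturable.
#include <cassert>
#include <cstdint>
#include "acl/acl.h"              // aclrtStream
#include "aclnnop/aclnn_cast.h"   // assumed CANN header for aclnnCast

// slot_i64 is the caller's tensor; slot_i32 wraps a device buffer cached
// across calls so graph capture sees a stable output address.
void CastSlotMappingAsync(const aclTensor* slot_i64, aclTensor* slot_i32,
                          void* ws_buf, uint64_t ws_cap, aclrtStream stream) {
  uint64_t ws_needed = 0;
  aclOpExecutor* executor = nullptr;
  aclnnStatus st = aclnnCastGetWorkspaceSize(slot_i64, ACL_INT32, slot_i32,
                                             &ws_needed, &executor);
  assert(st == ACL_SUCCESS && ws_needed <= ws_cap &&
         "aclnnCastGetWorkspaceSize failed or cached workspace too small");
  st = aclnnCast(ws_buf, ws_needed, executor, stream);  // async, on-stream
  assert(st == ACL_SUCCESS && "aclnnCast failed");
}
```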

feat(flash_attention): add sliding_window entry (additive)

The native window_left / window_right pair is kept as-is; an optional
std::optional<int64_t> sliding_window entry is added:

  • pair only → unchanged behavior
  • sliding_window only → normalized to (sliding_window - 1, 0) causal
    sliding (vLLM convention)
  • both → asserted consistent

test_flash_attention_sliding_window_equivalence asserts bit-exact
equivalence between the two entry points.
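
A minimal sketch of the normalization rules above; the function name and the -1 "unlimited" sentinel are assumptions for illustration, not the PR's API:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

// Resolves the two entry points into one (window_left, window_right) pair.
std::pair<int64_t, int64_t> ResolveWindow(
    std::optional<int64_t> window_left, std::optional<int64_t> window_right,
    std::optional<int64_t> sliding_window) {
  if (sliding_window) {
    // vLLM convention: a window of size W covers the current token plus the
    // previous W - 1 tokens, i.e. causal (W - 1, 0).
    const int64_t left = *sliding_window - 1;
    // If the pair is also given, both entries must describe the same window.
    assert(window_left.value_or(left) == left &&
           window_right.value_or(0) == 0 &&
           "sliding_window inconsistent with (window_left, window_right)");
    return {left, 0};
  }
  // Pair-only path: unchanged behavior (-1 = unlimited, assumed sentinel).
  return {window_left.value_or(-1), window_right.value_or(-1)};
}
```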

docs(paged_attention): host tensor contract

The class comment in src/base/paged_attention.h explains why seq_lens_host /
block_table_host exist (the CANN qSeqLens CPU-resident contract, ATB hostData,
and the NPUGraph capture prerequisite) so that future backend implementors
understand the API contract.
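
To make the contract concrete, a sketch of how an ATB input carries both device and host pointers; field names follow ATB's public atb::Tensor, and the helper itself is hypothetical:

```cpp
// Illustrative sketch (not the PR's code). PagedAttention-style ATB kernels
// read sequence lengths through hostData at enqueue time, so a CPU-pinned
// seq_lens_host mirror must stay alive and in sync; with it, no per-layer
// D2H copy is issued and the op can be captured into an NPUGraph.
#include <cstdint>
#include <atb/types.h>  // assumed ATB header defining atb::Tensor

atb::Tensor MakeSeqLensTensor(void* device_ptr, int32_t* host_ptr,
                              uint64_t bytes, const atb::TensorDesc& desc) {
  atb::Tensor t;
  t.desc = desc;
  t.deviceData = device_ptr;  // regular NPU buffer used by the kernel
  t.hostData = host_ptr;      // CPU-pinned mirror read at launch time
  t.dataSize = bytes;
  return t;
}
```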

Base headers

  • NEW: src/base/paged_attention.h, topk_topp_sampling.h
  • MODIFY: src/base/reshape_and_cache.h, flash_attention.h

Verification

  • python3 .ci/run.py --local --gpu-id <N> (Ascend 910B + CANN 8.5.1):
    3129 passed / 1798 skipped / 0 failed

Test plan

  • python3 .ci/run.py --local
  • test_flash_attention_sliding_window_equivalence (pair vs
    sliding_window bit-exact): 2 passed
  • test_reshape_and_cache (int32 + int64 paths): 32 passed
  • test_paged_attention (910B skip removed after CANN 8.5.1 fix): 10
    passed
  • clang-format passes locally
  • No CUDA / Metax / Cambricon / Moore / Iluvatar regressions (CI-verified)

@zhangyue207
Collaborator Author

merge test

[gw0] [ 99%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape3-input_strides3-gate_strides3-out_strides3] 
tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape4-None-None-None] 
[gw0] [ 99%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape4-None-None-None] 
tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape5-input_strides5-gate_strides5-out_strides5] 
[gw0] [ 99%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape5-input_strides5-gate_strides5-out_strides5] 
tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape6-None-None-None] 
[gw0] [ 99%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape6-None-None-None] 
tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape7-input_strides7-gate_strides7-out_strides7] 
[gw0] [100%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape7-input_strides7-gate_strides7-out_strides7] 

----------- generated xml file: /workspace/results/test-results.xml ------------
===================== 3767 passed, 1664 skipped in 45.95s ======================
========== Summary ==========
[warn] job ascend_npu: container exited with 137 (likely docker teardown SIGKILL after clean pytest); junit XML reports no failures — treating as success
EXIT=0

@zhangyue207 zhangyue207 force-pushed the feat/ascend-op-cache-attn branch 13 times, most recently from 083573a to e3b6f16 on April 21, 2026 at 06:40
//
// When cu_seqlens is a CPU tensor (device type kCpu), the data pointer is
// already on the host and can be read directly — no D2H sync needed.
inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens,
Collaborator

Function naming should be UpperCamelCase.

// convention for npu_fused_infer_attention_score actual_seq_lengths.
//
// When cu_seqlens is a CPU tensor, reads directly from host memory.
inline aclIntArray* cumSeqLengths(const Tensor& cu_seqlens,
Collaborator

Same as above; worth sweeping the related files once for consistency.

Comment on lines +331 to +338
assert(gws == ACL_SUCCESS &&
"aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (decode)");

auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed);
aclError ret =
aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, executor, stream);
assert(ret == ACL_SUCCESS &&
"aclnnFusedInferAttentionScoreV4 failed (decode)");
Collaborator

Function names inside the assert messages should use code formatting.


if (!has_block_table_host_) {
bt_host_ = std::malloc(bt_host_bytes_);
assert(bt_host_ && "Host buffer allocation for `block_table` failed");
Collaborator

Error messages should start with a lowercase letter; suggest checking and fixing the other related files for consistency as well.

Comment on lines +59 to +62
int64_t B = static_cast<int64_t>(batch_size_);
int64_t N = num_heads_;
int64_t Nkv = num_kv_heads_;
int64_t D = head_size_;
Collaborator

Variable naming should be snake_case.

…PagedAttention, TopkToppSampling

Four KV-cache and attention operators:

| op | impl |
|---|---|
| ReshapeAndCache | 3 impls: aclnnInplaceIndexCopy (kernel.h); custom AscendC (kernel_v2.h); ATB `ReshapeAndCacheParam` (kernel_atb.h, int64 `slot_mapping` handled via cached async `aclnnCast`) |
| FlashAttention | `aclnnFusedInferAttentionScoreV4` (prefill + paged decode). Supports both the native `(window_left, window_right)` pair and a new `std::optional<int64_t> sliding_window` entry (additive, vLLM-style) |
| PagedAttention | ATB `PagedAttentionParam` with optional CPU-pinned host tensors (`seq_lens_host` / `block_table_host`) that make the op NPUGraph-capturable |
| TopkToppSampling | ATB `TopkToppSamplingParam` |

Includes vLLM API alignment commits:
- `perf(reshape_and_cache)`: int64 slot_mapping routed through cached
  async `aclnnCast` (no D2H sync, NPUGraph-compatible)
- `feat(flash_attention)`: add `sliding_window` entry, additive
- `docs(paged_attention)`: base class comment explains the CPU-host
  tensor contract

New `src/base/<op>.h`: paged_attention, topk_topp_sampling.
Modified: reshape_and_cache, flash_attention.
@zhangyue207 zhangyue207 force-pushed the feat/ascend-op-cache-attn branch from e3b6f16 to 6b8b32f on April 22, 2026 at 06:43
