test(st): Add dynamic-shape paged attention and emit tensor view layout attribute #578

Crystal-wzy wants to merge 1 commit into hw-native-sys:main
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the PyPTO codegen by introducing support for dynamic shapes in paged attention kernels. It addresses an issue where tensor view layout attributes were not correctly emitted, particularly for transposed key caches. By integrating dynamic variables into kernel type annotations and providing a comprehensive example and test suite, this change enables more flexible and robust code generation for advanced deep learning models, ensuring correct representation and execution on target hardware such as Ascend910B.
📝 Walkthrough

Adds a new example implementing dynamic-shaped, paged attention with module-level dynamic InCore kernels and an orchestration program builder; adds integration tests for the dynamic kernels; and updates PTO codegen to emit explicit tensor view layout annotations in generated code.
Sequence Diagram

sequenceDiagram
participant User as User/Test
participant Builder as build_dynamic_paged_attention_program()
participant Program as DynamicPagedAttentionProgram
participant Orchestrator as paged_attention()
participant Init as dyn_kernel_init_inplace
participant QK as dyn_kernel_qk_matmul
participant Softmax as dyn_kernel_softmax_prepare
participant PV as dyn_kernel_pv_matmul
participant Update as dyn_kernel_online_update
User->>Builder: request program (batch, heads, head_dim, block_size, max_blocks, q_tile)
Builder->>Program: construct program class with dynamic variables
Builder-->>User: return Program
User->>Program: invoke paged_attention(...)
Program->>Orchestrator: start orchestration
Orchestrator->>Init: init OI/LI/MI accumulators
Init-->>Orchestrator: zero-filled accumulators
loop per batch / head / block
Orchestrator->>QK: load Q, load/transposed K, compute S_ij
QK-->>Orchestrator: S_ij (scores)
Orchestrator->>Softmax: scale S_ij, compute max/exp/normalize -> P_ij, update mi/li
Softmax-->>Orchestrator: P_ij, mi, li
Orchestrator->>PV: load V, compute P_ij * V -> O_new
PV-->>Orchestrator: O_new
Orchestrator->>Update: online update with is_first/is_last -> update OI/LI/MI, write dst
Update-->>Orchestrator: updated accumulators and dst
end
Orchestrator-->>Program: assemble final output
Program-->>User: return output
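The loop in the diagram is a standard online-softmax (flash-attention-style) accumulation. A minimal NumPy sketch of the update it describes, assuming the usual mi/li/oi semantics (running max, running denominator, running numerator) — names are illustrative, not the `pl` API:

```python
import numpy as np

def online_attention(q, k_blocks, v_blocks, scale=1.0):
    """Fold KV blocks into running (mi, li, oi) accumulators, mirroring the
    QK -> softmax-prepare -> PV -> online-update pipeline in the diagram."""
    mi = np.full((q.shape[0], 1), -np.inf)             # running row max (MI)
    li = np.zeros((q.shape[0], 1))                     # running denominator (LI)
    oi = np.zeros((q.shape[0], v_blocks[0].shape[1]))  # running numerator (OI)
    for k, v in zip(k_blocks, v_blocks):
        s = scale * (q @ k.T)                          # S_ij (scores)
        m_new = np.maximum(mi, s.max(axis=1, keepdims=True))
        p = np.exp(s - m_new)                          # P_ij
        alpha = np.exp(mi - m_new)                     # rescale old accumulators
        li = alpha * li + p.sum(axis=1, keepdims=True)
        oi = alpha * oi + p @ v
        mi = m_new
    return oi / li                                     # dst, emitted on the last block
```

The final division by li corresponds to the `is_last` branch of `dyn_kernel_online_update`; intermediate blocks only update the accumulators.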
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces dynamic-shape paged attention, including a new example, tests, and a codegen fix to emit tensor layout attributes. The changes are well-structured. However, I've identified a few correctness and maintainability issues in the new dynamic_paged_attention_example.py file. Specifically, some values from the config tensor are not being used, is_first/is_last flag creation can be improved, and there's a minor opportunity for code deduplication. I've also suggested a small simplification in the C++ codegen change. Please see the detailed comments for suggestions.
batch_cfg: pl.Scalar[pl.INT64] = pl.tensor.read(config, [0])
num_heads_cfg: pl.Scalar[pl.INT64] = pl.tensor.read(config, [1])
head_dim_cfg: pl.Scalar[pl.INT64] = pl.tensor.read(config, [3])
block_size_cfg: pl.Scalar[pl.INT64] = pl.tensor.read(config, [4])
block_num_cfg: pl.Scalar[pl.INT64] = pl.tensor.read(config, [5])
The config tensor is documented to contain kv_head_num (at index 2) and scale_bits (at index 6), but these values are not read or used in the orchestration function. Instead, a hardcoded scale of 1.0 is passed to dyn_kernel_softmax_prepare. This is a correctness issue as the scaling factor is a crucial part of the attention mechanism and is used in the reference implementation. The scale value should be derived from config and used. The unused kv_head_num should either be used or removed from the documentation and config to avoid confusion.
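On the reading side, one plausible convention for a `scale_bits` slot in an INT64 config tensor — and this packing is an assumption, not something the example confirms — is to store the IEEE-754 bit pattern of the FP32 scale and reinterpret it when read:

```python
import struct

def pack_scale(scale: float) -> int:
    # Assumed convention: store the float32 scale's IEEE-754 bit pattern
    # in the int64 config slot (e.g. config[6], "scale_bits").
    return struct.unpack("<I", struct.pack("<f", scale))[0]

def unpack_scale(bits: int) -> float:
    # Recover the float32 scale from the low 32 bits of the config value.
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]
```

Whatever the actual convention, the recovered value should replace the hardcoded 1.0 passed to dyn_kernel_softmax_prepare.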
if bn == 0:
    is_first: pl.Scalar[pl.INT64] = pl.yield_(1)  # type: ignore[reportArgumentType]
else:
    is_first: pl.Scalar[pl.INT64] = pl.yield_(0)  # type: ignore[reportArgumentType]
if bn == bn_this_batch - 1:
    is_last: pl.Scalar[pl.INT64] = pl.yield_(1)  # type: ignore[reportArgumentType]
else:
    is_last: pl.Scalar[pl.INT64] = pl.yield_(0)  # type: ignore[reportArgumentType]
The creation of is_first and is_last flags using pl.yield_ is verbose and their type is annotated as pl.Scalar[pl.INT64], which mismatches the dyn_kernel_online_update function's signature expecting pl.Scalar[pl.BOOL]. This can be simplified by using direct boolean comparisons, which is cleaner and ensures type correctness.
is_first = bn == 0
is_last = bn == bn_this_batch - 1

li_out = pl.store(li_updated_dn, [0, 0], li)
if is_last:
    dst_tile = pl.row_expand_div(oi_updated, li_updated_dn)
    dst_out = pl.store(dst_tile, [0, 0], dst)
    oi_out = pl.store(oi_updated, [0, 0], oi)
else:
    zero_tile = pl.tile.full([_Q_TILE, _HEAD_DIM], dtype=pl.FP32, value=0.0)
    dst_out = pl.store(zero_tile, [0, 0], dst)
    oi_out = pl.store(oi_updated, [0, 0], oi)
The assignment to oi_out is duplicated inside the if is_last: and else: blocks. You can hoist this assignment out of the conditional to reduce code duplication and improve readability, as it's the same in both branches.
li_out = pl.store(li_updated_dn, [0, 0], li)
if is_last:
    dst_tile = pl.row_expand_div(oi_updated, li_updated_dn)
    dst_out = pl.store(dst_tile, [0, 0], dst)
    oi_out = pl.store(oi_updated, [0, 0], oi)
else:
    zero_tile = pl.tile.full([_Q_TILE, _HEAD_DIM], dtype=pl.FP32, value=0.0)
    dst_out = pl.store(zero_tile, [0, 0], dst)
    oi_out = pl.store(oi_updated, [0, 0], oi)

Suggested change:

li_out = pl.store(li_updated_dn, [0, 0], li)
oi_out = pl.store(oi_updated, [0, 0], oi)
if is_last:
    dst_tile = pl.row_expand_div(oi_updated, li_updated_dn)
    dst_out = pl.store(dst_tile, [0, 0], dst)
else:
    zero_tile = pl.tile.full([_Q_TILE, _HEAD_DIM], dtype=pl.FP32, value=0.0)
    dst_out = pl.store(zero_tile, [0, 0], dst)
switch (tensor_type->tensor_view_.value().layout) {
    case ir::TensorLayout::ND: layout_str = "nd"; break;
    case ir::TensorLayout::DN: layout_str = "dn"; break;
    case ir::TensorLayout::NZ: layout_str = "nz"; break;
}
In this switch statement, layout_str is initialized to "nd" before the if block. The case ir::TensorLayout::ND: then redundantly re-assigns the same value. To improve clarity and remove redundancy, you can replace the assignment with a simple break.
switch (tensor_type->tensor_view_.value().layout) {
    case ir::TensorLayout::ND: layout_str = "nd"; break;
    case ir::TensorLayout::DN: layout_str = "dn"; break;
    case ir::TensorLayout::NZ: layout_str = "nz"; break;
}

Suggested change:

switch (tensor_type->tensor_view_.value().layout) {
    case ir::TensorLayout::DN: layout_str = "dn"; break;
    case ir::TensorLayout::NZ: layout_str = "nz"; break;
    case ir::TensorLayout::ND: break;
}
Actionable comments posted: 2
🧹 Nitpick comments (1)
examples/ir_parser/dynamic_paged_attention_example.py (1)
352-359: Use one scalar type for is_first/is_last.
dyn_kernel_online_update() expects pl.Scalar[pl.BOOL], but this site materializes pl.Scalar[pl.INT64]. Please align one side or the other instead of relying on an implicit conversion.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/ir_parser/dynamic_paged_attention_example.py` around lines 352 - 359, The yields for is_first and is_last currently produce pl.Scalar[pl.INT64]; change them to produce pl.Scalar[pl.BOOL] to match dyn_kernel_online_update()'s expected type by yielding boolean values (True/False) or explicitly casting to pl.Scalar[pl.BOOL]; update the two yield_ calls that set is_first and the two that set is_last so they return booleans instead of 0/1 and ensure the annotation states pl.Scalar[pl.BOOL] to keep the types consistent with dyn_kernel_online_update().
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/ir_parser/dynamic_paged_attention_example.py`:
- Around line 299-350: The partial-final-block path fails because
dyn_kernel_softmax_prepare assumes a full BLOCK_SIZE_DYN for both its input sij
and output out_pij and still loads _BLOCK_SIZE columns, which breaks when
valid_len < block_size_cfg; update dyn_kernel_softmax_prepare (and any related
kernel constants like BLOCK_SIZE_DYN/_BLOCK_SIZE) to accept the actual valid_len
(or a runtime mask) so it only reads/writes valid_len columns, or alternatively
ensure you pass a padded sij slice of size block_size_cfg (with safe padding)
into dyn_kernel_softmax_prepare and keep pij_f16 sized to block_size_cfg;
reference symbols: dyn_kernel_softmax_prepare, sij_valid, pij_f16, valid_len,
BLOCK_SIZE_DYN, _BLOCK_SIZE.
- Around line 200-227: The builder accepts configurations that the dyn_kernel_*
InCore kernels cannot handle because they hardcode 16x128x128 tiles and the
orchestration ceil-divides num_heads without a tail path; add fast-fail
validation at the start of build_dynamic_paged_attention_program to check that
q_tile equals the module tile constant (_Q_TILE, 16), head_dim equals the kernel
tile size (128), block_size equals the kernel block size (128), and that
num_heads is a multiple of q_tile (no remainder), and raise a clear ValueError
if any check fails so callers cannot pass unsupported shapes to dyn_kernel_* or
the orchestration logic.
---
Nitpick comments:
In `@examples/ir_parser/dynamic_paged_attention_example.py`:
- Around line 352-359: The yields for is_first and is_last currently produce
pl.Scalar[pl.INT64]; change them to produce pl.Scalar[pl.BOOL] to match
dyn_kernel_online_update()'s expected type by yielding boolean values
(True/False) or explicitly casting to pl.Scalar[pl.BOOL]; update the two yield_
calls that set is_first and the two that set is_last so they return booleans
instead of 0/1 and ensure the annotation states pl.Scalar[pl.BOOL] to keep the
types consistent with dyn_kernel_online_update().
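The partial-final-block fix described in the first inline comment amounts to masking the invalid tail columns before the softmax. A NumPy sketch, under the assumption that valid_len is the runtime number of valid columns in the last KV block:

```python
import numpy as np

def masked_softmax(scores, valid_len):
    # Mask score columns beyond valid_len to -inf so the padded tail of a
    # partial final KV block receives zero probability mass.
    s = scores.astype(np.float64).copy()
    s[:, valid_len:] = -np.inf
    m = s.max(axis=1, keepdims=True)   # row max for numerical stability
    e = np.exp(s - m)
    return e / e.sum(axis=1, keepdims=True)
```

The alternative fix noted in the prompt — padding sij to block_size_cfg with safe values — is equivalent as long as the padding maps to zero probability.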
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 442b2496-c90a-4598-8150-a93bedf6c70d
📒 Files selected for processing (3)
examples/ir_parser/dynamic_paged_attention_example.py
src/codegen/pto/pto_codegen.cpp
tests/st/codegen/test_dynamic_paged_attention.py
def build_dynamic_paged_attention_program(
    batch: int,
    num_heads: int,
    head_dim: int,
    block_size: int,
    max_num_blocks_per_req: int,
    q_tile: int = 16,
):
    """Build a paged-attention @pl.program whose InCore kernels use dynamic shapes.

    InCore kernel tensor type annotations reference module-level pl.dynamic()
    variables (Q_HEADS, HEAD_DIM_DYN, BLOCK_SIZE_DYN) so their shapes are
    resolved at runtime from the concrete tensors passed by the orchestration
    function. Load operations use the module-level constants (_Q_TILE,
    _HEAD_DIM, _BLOCK_SIZE) for the tile sizes.

    The orchestration function is identical in structure to the static version
    in paged_attention_example.py (same pl.slice masking, same pipeline).

    Parameters
    ----------
    batch: number of requests in the batch
    num_heads: number of query heads
    head_dim: per-head feature dimension
    block_size: KV-cache block size (rows per physical block)
    max_num_blocks_per_req: maximum number of KV blocks per request
    q_tile: query-head tile size (must equal _Q_TILE = 16)
    """
Fail fast on shapes this example does not actually support.
The API looks fully dynamic, but every dyn_kernel_* body still loads fixed 16x128x128 tiles. The orchestration also ceil-divides num_heads without a tail path, so a remainder head-group still slices a full q_tile rows. Right now the builder accepts configurations the kernels cannot execute safely.
Suggested guardrails
def build_dynamic_paged_attention_program(
    batch: int,
    num_heads: int,
    head_dim: int,
    block_size: int,
    max_num_blocks_per_req: int,
    q_tile: int = 16,
):
    """Build a paged-attention `@pl.program` whose InCore kernels use dynamic shapes.
@@
    q_tile: query-head tile size (must equal _Q_TILE = 16)
    """
+   if (q_tile, head_dim, block_size) != (_Q_TILE, _HEAD_DIM, _BLOCK_SIZE):
+       raise ValueError(
+           "This example currently supports only "
+           f"q_tile={_Q_TILE}, head_dim={_HEAD_DIM}, block_size={_BLOCK_SIZE}"
+       )
+   if num_heads % q_tile != 0:
+       raise ValueError("num_heads must be divisible by q_tile")
+
    query_rows = batch * num_heads

Also applies to: 269-318
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/ir_parser/dynamic_paged_attention_example.py` around lines 200 -
227, The builder accepts configurations that the dyn_kernel_* InCore kernels
cannot handle because they hardcode 16x128x128 tiles and the orchestration
ceil-divides num_heads without a tail path; add fast-fail validation at the
start of build_dynamic_paged_attention_program to check that q_tile equals the
module tile constant (_Q_TILE, 16), head_dim equals the kernel tile size (128),
block_size equals the kernel block size (128), and that num_heads is a multiple
of q_tile (no remainder), and raise a clear ValueError if any check fails so
callers cannot pass unsupported shapes to dyn_kernel_* or the orchestration
logic.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/ir_parser/dynamic_paged_attention_example.py`:
- Around line 46-57: The docstring for dyn_kernel_init_inplace is incorrect: the
function does not zero-initialize but merely returns the provided tensors (oi,
li, mi) as a passthrough for type/shape binding. Update the docstring to state
it is a no-op passthrough used for type/shape binding (or optionally rename the
function to something like dyn_kernel_bind_shapes or dyn_kernel_passthrough) and
mention that actual zero-initialization happens via pl.create_tensor before this
function is called.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 4a07963b-94e9-413e-9387-3a7934421976
📒 Files selected for processing (3)
examples/ir_parser/dynamic_paged_attention_example.py
src/codegen/pto/pto_codegen.cpp
tests/st/codegen/test_dynamic_paged_attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/codegen/pto/pto_codegen.cpp
- tests/st/codegen/test_dynamic_paged_attention.py
Force-pushed from 9c98ff1 to 1412fd9 (Compare)
def build_dynamic_paged_attention_program(
    batch: int,
oi: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
li: pl.Out[pl.Tensor[[16, 1], pl.FP32]],
mi: pl.Out[pl.Tensor[[16, 1], pl.FP32]],
li: pl.Out[pl.Tensor[[16, 1], pl.FP32, pl.DN]],
| """ | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "batch,num_heads,head_dim,block_size,context_len,max_model_len", |
Force-pushed from e32e9d9 to 2b71a81 (Compare)
## Summary
- Add `examples/ir_parser/dynamic_paged_attention_example.py` with
`build_dynamic_paged_attention_program()` builder that defines five InCore
kernel closures (`dyn_kernel_init_inplace`, `dyn_kernel_qk_matmul`,
`dyn_kernel_softmax_prepare`, `dyn_kernel_pv_matmul`, `dyn_kernel_online_update`);
type annotations use module-level `pl.dynamic()` variables
(`Q_HEADS`, `HEAD_DIM_DYN`, `BLOCK_SIZE_DYN`), load ops use concrete
closure variables (`_Q_TILE`, `_HEAD_DIM`, `_BLOCK_SIZE`)
- Add `tests/st/codegen/test_dynamic_paged_attention.py` with
`DynamicPagedAttentionTestCase` inheriting golden reference and tensor
definitions from `PagedAttentionTestCase`, targeting Ascend910B PTO backend;
3 parametrized configurations
- Extend `tests/st/codegen/test_paged_attention.py`: `PagedAttentionTestCase`
accepts `context_len: int | list[int]` and constructs `context_lens` tensor
from a list when heterogeneous per-request lengths are needed (required by
`DynamicPagedAttentionTestCase`)
- Fix `PTOCodegen::EmitMakeTensorViews` to emit `{layout = #pto.layout<nd|dn|nz>}`
attribute on tensor view assignments so DN-layout tensors (e.g. transposed
key_cache in paged attention) are correctly represented in generated PTO IR
- Fix `examples/ir_parser/paged_attention_example.py`: annotate `[16, 1]`
output tensors in `kernel_init_inplace` and `kernel_softmax_prepare` with `pl.DN`
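The int-or-list `context_len` parameter described above can be sketched as follows (the helper name is illustrative, not the test suite's actual API):

```python
import numpy as np

def make_context_lens(batch, context_len):
    # A scalar context_len is broadcast to every request; a list supplies
    # heterogeneous per-request lengths, one entry per request.
    lens = [context_len] * batch if isinstance(context_len, int) else list(context_len)
    if len(lens) != batch:
        raise ValueError("context_len list must have one entry per request")
    return np.asarray(lens, dtype=np.int64)
```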
## Testing
- [x] DynamicPagedAttentionTestCase passes on 910B PTO hardware
- [x] Pre-commit hooks pass (ruff, pyright, clang-format, cpplint)