
feat(paged-attention): add paged attention support with arithmetic shape expressions #470

Open
wangqin1723-max wants to merge 1 commit into hw-native-sys:main from
wangqin1723-max:feature/paged-attention

Conversation


@wangqin1723-max wangqin1723-max commented Mar 12, 2026

  • Support arithmetic expressions in tensor shape dimensions (language parser)
  • Add paged attention example and PTOAS codegen tests
  • Synchronize orchestration codegen with the runtime's uint64→uint32 refactor
  • Support discontinuous block_table in paged attention


coderabbitai bot commented Mar 12, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds 2-block InCore kernels and a multitier paged-attention program builder that orchestrates x2/x1 block processing with offset-based loads/stores, extends tests for the new kernels and multitier flow, enhances shape parsing to evaluate binary operators in dimension expressions, and tweaks tensor init detection for arange patterns.

Changes

Cohort / File(s) Summary
Paged Attention 2-Block Kernels & Multitier Orchestration
examples/ir_parser/paged_attention_example.py
Added kernel_qk_matmul_2block, kernel_softmax_prepare_2block, kernel_pv_matmul_2block and build_paged_attention_multitier_program. Replaced single-block calls with 2-block variants, added two-tier (x2/x1) bn-loop orchestration, and refactored many loads/stores to use explicit offsets/shapes.
Type Resolver Shape Expression Parsing
python/pypto/language/parser/type_resolver.py
Introduced _parse_shape_dim_expr with static and IR binop mappings (_BINOP_STATIC, _BINOP_IR) to support Add/Sub/Mult/FloorDiv/Mod in shape dims; delegates element parsing and improves error handling.
Paged Attention Tests
tests/st/codegen/test_paged_attention.py
Added tests and test-case classes for 2-block kernels (QKMatmul2BlockTestCase, SoftmaxPrepare2BlockTestCase, PVMatmul2BlockTestCase) and multitier examples (PagedAttentionMultitierExampleTestCase, PTOAS variants); extended test imports and added expected-compute logic for multitier scenarios.
Test Harness: Golden Init Pattern
tests/st/harness/adapters/golden_generator.py
Detects 1D arange(n) tensor pattern and emits torch.arange(n, dtype=...) initializer when matched (new pattern check inserted among existing tensor init heuristics).

Sequence Diagram(s)

sequenceDiagram
    participant Q as Query
    participant KV as KV Cache
    participant QK as QK Matmul (2-block)
    participant SF as Softmax Prepare (2-block)
    participant PV as PV Matmul (2-block)
    participant Out as Output

    Q->>KV: fetch KV blocks 0-1
    KV-->>QK: K0-1, V0-1
    QK->>QK: compute Q·K for blocks 0-1
    QK->>SF: S_ij (blocks 0-1)
    SF->>SF: scale & softmax → P_ij, m_i, l_i
    SF->>PV: P_ij (blocks 0-1)
    PV->>PV: compute partial O_i
    PV->>Out: store/update O_i

    Q->>KV: fetch KV blocks 2-3
    KV-->>QK: K2-3, V2-3
    QK->>SF: S_ij (blocks 2-3)
    SF->>PV: P_ij (blocks 2-3)
    PV->>Out: online update O_i

    Q->>KV: handle remaining 1-block path
    KV-->>Out: finalize output (x1 tier)
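The m_i / l_i / O_i flow in the diagram is the standard online-softmax running update. A minimal scalar sketch (illustrative only, not the PR's kernel code; the names m_i, l_i, o_i mirror the diagram):

```python
import math

def online_softmax_update(m_i, l_i, o_i, s_block, v_block):
    """Fold one block of scores/values into the running max m_i,
    normalizer l_i, and unnormalized weighted sum o_i."""
    m_new = max(m_i, max(s_block))
    correction = math.exp(m_i - m_new)          # rescale the old accumulators
    p = [math.exp(s - m_new) for s in s_block]  # unnormalized block weights
    l_new = l_i * correction + sum(p)
    o_new = o_i * correction + sum(pi * v for pi, v in zip(p, v_block))
    return m_new, l_new, o_new

# Start from m = -inf, l = 0, o = 0; the final output is o / l
# once every KV block has been folded in.
```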

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • Hzfengsy

Poem

🐰 Two blocks hop in, kernels hum and glide,
Tiers align, offsets map the stride,
Shapes now count with + and × in tune,
Tests march forward under a silver moon,
A rabbit cheers—multitier code, rejoice! 🎉

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 47.06%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed — The title accurately reflects the main changes: adding paged attention support and arithmetic shape expressions. Both features are prominent across multiple files in the PR.
  • Description check ✅ Passed — The PR description is directly related to the changeset, covering arithmetic expressions in shape dimensions, paged attention examples, and codegen tests.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and performance of tensor operations within the language. It introduces the ability to define tensor shapes using arithmetic expressions, which simplifies dynamic sizing. Furthermore, it optimizes the paged attention mechanism by implementing specialized 2-block kernels and a multi-tier processing program, leading to more efficient handling of KV-cache blocks. These changes collectively contribute to a more robust and performant system for advanced tensor computations.

Highlights

  • Extended Tensor Shape Dimension Parsing: The TypeResolver now supports arithmetic expressions (addition, subtraction, multiplication, floor division, modulo) within tensor shape dimensions, allowing for more dynamic and flexible shape definitions. This involves a recursive evaluation combining static integer folding and dynamic IR expression construction.
  • Introduced 2-Block Paged Attention Kernels: New kernel variants (kernel_qk_matmul_2block, kernel_softmax_prepare_2block, kernel_pv_matmul_2block) were added to the paged attention example, designed to process two physically contiguous blocks per operation for improved efficiency.
  • Implemented Multi-Tier Paged Attention Program: A new build_paged_attention_multitier_program was added, which orchestrates paged attention using a 2-tier loop structure, leveraging the new 2-block kernels for optimized processing of block pairs.
  • Standardized Load/Store Calls: pl.load and pl.store calls in the paged attention example were updated to use explicit keyword arguments for offsets, shapes, and output_tensor, enhancing readability and maintainability.
  • Comprehensive Test Coverage for New Features: Dedicated test cases were added to validate the functionality of the new 2-block kernels and the build_paged_attention_multitier_program, ensuring correctness and stability.
Changelog
  • examples/ir_parser/paged_attention_example.py
    • Added kernel_qk_matmul_2block, kernel_softmax_prepare_2block, and kernel_pv_matmul_2block for 2-block processing.
    • Introduced build_paged_attention_multitier_program to orchestrate paged attention with a 2-tier loop structure.
    • Migrated pl.load and pl.store calls to use keyword arguments for clarity.
  • python/pypto/language/parser/type_resolver.py
    • Refactored _parse_shape_elements to delegate to _parse_shape_dim_expr for recursive parsing.
    • Implemented _parse_shape_dim_expr to handle integer literals, variables, and arithmetic BinOp expressions in shape dimensions, supporting static folding and dynamic IR expression generation.
    • Defined _BINOP_STATIC and _BINOP_IR dictionaries for arithmetic operations.
  • tests/st/codegen/test_paged_attention.py
    • Imported new kernel functions and the build_paged_attention_multitier_program.
    • Added PagedAttentionMultitierExampleTestCase and PagedAttentionMultitierExamplePTOASTestCase to test the new multi-tier paged attention program.
    • Included QKMatmul2BlockTestCase, SoftmaxPrepare2BlockTestCase, and PVMatmul2BlockTestCase to verify the functionality of the new 2-block kernels.
    • Updated pytest.mark.parametrize decorators to include the new test cases.
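The static-folding half of the shape-dim parsing described above can be sketched as follows. This is a hedged illustration, not the resolver's actual code: the real `_parse_shape_dim_expr` also lowers dynamic dims to IR expressions via `_BINOP_IR`, which is omitted here, and the helper name `fold_shape_dim` is an assumption.

```python
import ast

# Static fold table for arithmetic in shape dimensions, mirroring the
# _BINOP_STATIC mapping the changelog describes (Add/Sub/Mult/FloorDiv/Mod).
_BINOP_STATIC = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.FloorDiv: lambda a, b: a // b,
    ast.Mod: lambda a, b: a % b,
}

def fold_shape_dim(node: ast.expr) -> int:
    """Recursively fold an arithmetic shape-dim expression over int literals."""
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _BINOP_STATIC:
        return _BINOP_STATIC[type(node.op)](
            fold_shape_dim(node.left), fold_shape_dim(node.right)
        )
    raise TypeError(f"unsupported shape dim expression: {ast.dump(node)}")
```

For example, a dim written as `2 * 128 + 16` folds statically to 272; only dims referencing variables need the IR path.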

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for arithmetic expressions in tensor shape dimensions, a valuable enhancement. The implementation in type_resolver.py is robust and well-designed. Additionally, it adds 2-block kernel variants for paged attention and a new multi-tier program, complete with comprehensive tests. My review identifies a few opportunities to improve efficiency in the new paged attention example by hoisting redundant load and slice operations out of loops.

Comment on lines +207 to +221
qi_l1_0 = pl.load(qi, offsets=[0, 0], shapes=[16, 128], target_memory=pl.MemorySpace.Mat)
kj0_l1 = pl.load(kj, offsets=[0, 0], shapes=[128, 128], target_memory=pl.MemorySpace.Mat, transpose=True)
qi_l0a_0 = pl.move(qi_l1_0, target_memory=pl.MemorySpace.Left)
kj0_l0b = pl.move(kj0_l1, target_memory=pl.MemorySpace.Right)
sij_h0_l0c = pl.matmul(qi_l0a_0, kj0_l0b)
_ = pl.store(sij_h0_l0c, offsets=[0, 0], output_tensor=output)
# Second block: qi @ kj[128:256].T -> output[:,128:256]
qi_l1_1 = pl.load(qi, offsets=[0, 0], shapes=[16, 128], target_memory=pl.MemorySpace.Mat)
kj1_l1 = pl.load(
kj, offsets=[0, 128], shapes=[128, 128], target_memory=pl.MemorySpace.Mat, transpose=True
)
qi_l0a_1 = pl.move(qi_l1_1, target_memory=pl.MemorySpace.Left)
kj1_l0b = pl.move(kj1_l1, target_memory=pl.MemorySpace.Right)
sij_h1_l0c = pl.matmul(qi_l0a_1, kj1_l0b)
out = pl.store(sij_h1_l0c, offsets=[0, 128], output_tensor=output)
medium

The qi tensor is loaded and moved to L0A memory twice, once for each block. Since it's the same qi for both operations, it can be loaded and moved just once and then reused for both matmul calls. This avoids a redundant load and move, improving kernel efficiency.

Suggested change
qi_l1 = pl.load(qi, offsets=[0, 0], shapes=[16, 128], target_memory=pl.MemorySpace.Mat)
qi_l0a = pl.move(qi_l1, target_memory=pl.MemorySpace.Left)
# First block: qi @ kj[0:128].T -> output[:,0:128]
kj0_l1 = pl.load(kj, offsets=[0, 0], shapes=[128, 128], target_memory=pl.MemorySpace.Mat, transpose=True)
kj0_l0b = pl.move(kj0_l1, target_memory=pl.MemorySpace.Right)
sij_h0_l0c = pl.matmul(qi_l0a, kj0_l0b)
_ = pl.store(sij_h0_l0c, offsets=[0, 0], output_tensor=output)
# Second block: qi @ kj[128:256].T -> output[:,128:256]
kj1_l1 = pl.load(
kj, offsets=[0, 128], shapes=[128, 128], target_memory=pl.MemorySpace.Mat, transpose=True
)
kj1_l0b = pl.move(kj1_l1, target_memory=pl.MemorySpace.Right)
sij_h1_l0c = pl.matmul(qi_l0a, kj1_l0b)
out = pl.store(sij_h1_l0c, offsets=[0, 128], output_tensor=output)

Comment on lines +562 to +566
qi_x2: pl.Tensor[[q_tile, head_dim_cfg], pl.BF16] = pl.slice(
query,
[q_tile, head_dim_cfg], # type: ignore[reportArgumentType]
[cur_offset, 0],
)
medium

The qi_x2 tensor is sliced from query inside the bn2 loop. Since cur_offset does not change within this loop, the slice operation is redundant on each iteration. It can be hoisted out of the loop (before line 561) to improve efficiency.

Collaborator

Please fix this.

Comment on lines +629 to +633
qi: pl.Tensor[[q_tile, head_dim_cfg], pl.BF16] = pl.slice(
query,
[q_tile, head_dim_cfg], # type: ignore[reportArgumentType]
[cur_offset, 0],
)
medium

Similar to the kernelx2 loop, the qi tensor is sliced from query inside the bn3 loop. This is inefficient as cur_offset is constant throughout the loop. This slice should be performed once before the loop. Furthermore, this slice is identical to the one for qi_x2 in the preceding loop. A single slice operation before both loops would be the most efficient approach.

Collaborator

Please fix this.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/ir_parser/paged_attention_example.py`:
- Around line 568-582: The code assumes the two-block pair is physically
contiguous but only reads bidx_x2_0; fix by reading the next index (e.g.,
bidx_x2_1 = pl.tensor.read(block_table, [b_idx * block_num_cfg + bn2 + 1]) ),
check contiguity (bidx_x2_1 == bidx_x2_0 + 1) and only perform the
2*block_size_cfg slice into key_cache/value_cache (producing kj_2b and vj_2b)
when contiguous; otherwise fall back to the existing x1 path and slice the two
blocks individually (or iterate two x1 iterations) to avoid reading past
allocated cache rows and feeding wrong K/V pairs. Ensure checks use bn2,
block_table, key_cache, value_cache, kj_2b and vj_2b symbols so behavior is
switched deterministically.
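The contiguity guard described above can be sketched in plain Python (illustrative only; `block_table` here is a list of per-sequence lists of physical block indices, and the helper name is an assumption):

```python
def two_block_is_contiguous(block_table, b, bn2):
    """True when logical blocks bn2 and bn2+1 of sequence b occupy physically
    adjacent cache rows, so the fused 2*block_size slice is safe; otherwise
    the caller should fall back to two single-block iterations."""
    bidx0 = block_table[b][bn2]
    bidx1 = block_table[b][bn2 + 1]
    return bidx1 == bidx0 + 1
```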

In `@python/pypto/language/parser/type_resolver.py`:
- Around line 587-588: The new arithmetic path currently routes dimensions
through _validate_dim_value which only accepts int/DynVar, causing
closure-provided ir.Expr dims to be rejected; update _validate_dim_value (and
any callers such as the branch in _resolve_shape_dim that handles ast.Name and
the other branch around the code at lines 607-610) to accept ir.Expr as a valid
dimension type (or detect ir.Expr and bypass/convert it) so that
ExprEvaluator-produced ir.Expr from closures is allowed; ensure the change
preserves existing int/DynVar checks and integrates with ExprEvaluator's lifting
behavior.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 9a344827-464e-4140-b1b5-1165f21b3e11

📥 Commits

Reviewing files that changed from the base of the PR and between 1f5ff02 and a850769.

📒 Files selected for processing (3)
  • examples/ir_parser/paged_attention_example.py
  • python/pypto/language/parser/type_resolver.py
  • tests/st/codegen/test_paged_attention.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/st/codegen/test_paged_attention.py (1)

667-699: Bind loop variables in nested helper defs to avoid late-binding behavior flagged by Ruff B023.

Lines 667-699 close over loop-scoped variables (b, cur_seq, qi). Because Python closures bind names late, the nested functions _2block and _1block would silently pick up later values of these variables if they changed after definition, even though they're called immediately here. Bind them as default parameters to capture their current values:

Proposed fix
-                def _2block(bn2: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                def _2block(
+                    bn2: int, *, _b: int = b, _cur_seq: int = cur_seq, _qi: torch.Tensor = qi
+                ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@
-                    v0 = min(block_size, cur_seq - bn2 * block_size)
-                    v1 = min(block_size, cur_seq - (bn2 + 1) * block_size)
-                    bidx0 = int(block_table[b, bn2].item())
-                    bidx1 = int(block_table[b, bn2 + 1].item())
+                    v0 = min(block_size, _cur_seq - bn2 * block_size)
+                    v1 = min(block_size, _cur_seq - (bn2 + 1) * block_size)
+                    bidx0 = int(block_table[_b, bn2].item())
+                    bidx1 = int(block_table[_b, bn2 + 1].item())
@@
-                    sij = torch.mm(qi, kj.T) * scale_value  # [q_tile, valid_2b]
+                    sij = torch.mm(_qi, kj.T) * scale_value  # [q_tile, valid_2b]
@@
-                def _1block(
-                    bn: int,
-                ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+                def _1block(
+                    bn: int, *, _b: int = b, _cur_seq: int = cur_seq, _qi: torch.Tensor = qi
+                ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
@@
-                    v = min(block_size, cur_seq - bn * block_size)
+                    v = min(block_size, _cur_seq - bn * block_size)
@@
-                    kj = key_cache[int(block_table[b, bn].item()), :v]
-                    vj = value_cache[int(block_table[b, bn].item()), :v]
-                    sij = torch.mm(qi, kj.T) * scale_value
+                    kj = key_cache[int(block_table[_b, bn].item()), :v]
+                    vj = value_cache[int(block_table[_b, bn].item()), :v]
+                    sij = torch.mm(_qi, kj.T) * scale_value
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/st/codegen/test_paged_attention.py` around lines 667 - 699, Nested
helper functions _2block and _1block close over loop-scoped variables (b,
cur_seq, qi) causing late-binding issues flagged by Ruff B023; fix by binding
those loop variables as default parameters on the function definitions (e.g.,
add b=b, cur_seq=cur_seq, qi=qi to _2block and b=b, cur_seq=cur_seq, qi=qi to
_1block, or only the ones each uses) so the current values are captured at
definition time while keeping the existing logic in _2block and _1block
unchanged.
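A self-contained demonstration of the B023 pitfall and the default-parameter fix (not the test-file code, just the pattern):

```python
def late_bound():
    # Each lambda looks up b when called, so every closure sees the final b.
    fns = [lambda: b for b in range(3)]
    return [f() for f in fns]

def default_bound():
    # Default parameter values are evaluated at definition time,
    # so each closure captures the b of its own iteration.
    fns = [lambda _b=b: _b for b in range(3)]
    return [f() for f in fns]
```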
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/st/codegen/test_paged_attention.py`:
- Around line 661-663: The loop using range(num_heads // q_tile) (variable
q_idx) drops tail heads when num_heads % q_tile != 0; change the iteration to
cover the final partial tile (e.g., compute full_tiles, rem = divmod(num_heads,
q_tile) and iterate over range(full_tiles + (1 if rem else 0))) and slice qi =
query[b, q_off : min(q_off + q_tile, num_heads), :] so the last tile can be
smaller; likewise update the corresponding write path that only writes full
tiles (the write at the location that handles full tiles around line 726) to
handle partial-tail tiles by writing only the valid head entries for the last
tile. Ensure you reference num_heads, q_tile, q_idx, q_off, qi and the write
that emits per-tile outputs so both read and write handle the remainder.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8a4e3a77-9651-41a9-be3b-8fa44f57b211

📥 Commits

Reviewing files that changed from the base of the PR and between a850769 and 99c11a1.

📒 Files selected for processing (2)
  • tests/st/codegen/test_paged_attention.py
  • tests/st/harness/adapters/golden_generator.py

Comment on lines +661 to +663
for q_idx in range(num_heads // q_tile):
q_off = q_idx * q_tile
qi = query[b, q_off : q_off + q_tile, :] # [q_tile, head_dim]

⚠️ Potential issue | 🟡 Minor

Handle non-multiple-of-q_tile heads in multitier expected path.

Line 661 uses num_heads // q_tile, which silently drops tail heads when num_heads is not divisible by q_tile (16). Line 726 then only writes full tiles. This can produce incorrect expected outputs for valid constructor inputs.

Proposed fix
-            for q_idx in range(num_heads // q_tile):
-                q_off = q_idx * q_tile
-                qi = query[b, q_off : q_off + q_tile, :]  # [q_tile, head_dim]
+            for q_off in range(0, num_heads, q_tile):
+                q_tile_size = min(q_tile, num_heads - q_off)
+                qi = query[b, q_off : q_off + q_tile_size, :]
@@
-                out[b, q_off : q_off + q_tile, :] = oi / li
+                out[b, q_off : q_off + q_tile_size, :] = oi / li

Also applies to: 726-726
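The suggested tiling loop can be sketched as a small helper (illustrative values; the function name is an assumption): iterate by offset and let the last tile shrink so no tail heads are dropped.

```python
def tile_ranges(num_heads: int, q_tile: int) -> list[tuple[int, int]]:
    """Return (offset, size) pairs covering all heads; the final tile
    may be smaller than q_tile when num_heads % q_tile != 0."""
    return [
        (off, min(q_tile, num_heads - off))
        for off in range(0, num_heads, q_tile)
    ]
```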

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/st/codegen/test_paged_attention.py` around lines 661 - 663, The loop
using range(num_heads // q_tile) (variable q_idx) drops tail heads when
num_heads % q_tile != 0; change the iteration to cover the final partial tile
(e.g., compute full_tiles, rem = divmod(num_heads, q_tile) and iterate over
range(full_tiles + (1 if rem else 0))) and slice qi = query[b, q_off : min(q_off
+ q_tile, num_heads), :] so the last tile can be smaller; likewise update the
corresponding write path that only writes full tiles (the write at the location
that handles full tiles around line 726) to handle partial-tail tiles by writing
only the valid head entries for the last tile. Ensure you reference num_heads,
q_tile, q_idx, q_off, qi and the write that emits per-tile outputs so both read
and write handle the remainder.

sij_h0_l0c = pl.matmul(qi_l0a_0, kj0_l0b)
_ = pl.store(sij_h0_l0c, offsets=[0, 0], output_tensor=output)
# Second block: qi @ kj[128:256].T -> output[:,128:256]
qi_l1_1 = pl.load(qi, offsets=[0, 0], shapes=[16, 128], target_memory=pl.MemorySpace.Mat)
Collaborator

Why does qi need to be loaded twice?

Contributor Author

fixed

)
qi_l0a_1 = pl.move(qi_l1_1, target_memory=pl.MemorySpace.Left)
kj1_l0b = pl.move(kj1_l1, target_memory=pl.MemorySpace.Right)
sij_h1_l0c = pl.matmul(qi_l0a_1, kj1_l0b)
Collaborator

Can't the two kj blocks be concatenated and fed through a single matmul?

Contributor Author

Previously transpose didn't support non-square matrices; now transpose has been folded into load, so the two blocks can be combined. Fixed.


@wangqin1723-max wangqin1723-max force-pushed the feature/paged-attention branch 2 times, most recently from b5430c6 to 6ac4980 on March 16, 2026 08:27
@wangqin1723-max wangqin1723-max changed the title from "feat(language): support arithmetic expressions in tensor shape dimens…" to "feat(paged-attention): add paged attention support with arithmetic shape expressions" on Mar 16, 2026
@wangqin1723-max wangqin1723-max force-pushed the feature/paged-attention branch from 6ac4980 to ae0dac9 on March 16, 2026 09:26