feat(paged-attention): add paged attention support with arithmetic shape expressions #470
Conversation
📝 Walkthrough
Adds 2-block InCore kernels and a multitier paged-attention program builder that orchestrates x2/x1 block processing with offset-based loads/stores, extends tests for the new kernels and multitier flow, enhances shape parsing to evaluate binary operators in dimension expressions, and tweaks tensor init detection for arange patterns.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Q as Query
    participant KV as KV Cache
    participant QK as QK Matmul (2-block)
    participant SF as Softmax Prepare (2-block)
    participant PV as PV Matmul (2-block)
    participant Out as Output
    Q->>KV: fetch KV blocks 0-1
    KV-->>QK: K0-1, V0-1
    QK->>QK: compute Q·K for blocks 0-1
    QK->>SF: S_ij (blocks 0-1)
    SF->>SF: scale & softmax → P_ij, m_i, l_i
    SF->>PV: P_ij (blocks 0-1)
    PV->>PV: compute partial O_i
    PV->>Out: store/update O_i
    Q->>KV: fetch KV blocks 2-3
    KV-->>QK: K2-3, V2-3
    QK->>SF: S_ij (blocks 2-3)
    SF->>PV: P_ij (blocks 2-3)
    PV->>Out: online update O_i
    Q->>KV: handle remaining 1-block path
    KV-->>Out: finalize output (x1 tier)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: 2 passed, 1 failed (warning)
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the flexibility and performance of tensor operations within the language. It introduces the ability to define tensor shapes using arithmetic expressions, which simplifies dynamic sizing. It also optimizes the paged attention mechanism by implementing specialized 2-block kernels and a multi-tier processing program, leading to more efficient handling of KV-cache blocks. Together these changes make the system more robust and performant for advanced tensor computations.
Code Review
This pull request introduces support for arithmetic expressions in tensor shape dimensions, a valuable enhancement. The implementation in type_resolver.py is robust and well-designed. Additionally, it adds 2-block kernel variants for paged attention and a new multi-tier program, complete with comprehensive tests. My review identifies a few opportunities to improve efficiency in the new paged attention example by hoisting redundant load and slice operations out of loops.
```python
qi_l1_0 = pl.load(qi, offsets=[0, 0], shapes=[16, 128], target_memory=pl.MemorySpace.Mat)
kj0_l1 = pl.load(kj, offsets=[0, 0], shapes=[128, 128], target_memory=pl.MemorySpace.Mat, transpose=True)
qi_l0a_0 = pl.move(qi_l1_0, target_memory=pl.MemorySpace.Left)
kj0_l0b = pl.move(kj0_l1, target_memory=pl.MemorySpace.Right)
sij_h0_l0c = pl.matmul(qi_l0a_0, kj0_l0b)
_ = pl.store(sij_h0_l0c, offsets=[0, 0], output_tensor=output)
# Second block: qi @ kj[128:256].T -> output[:,128:256]
qi_l1_1 = pl.load(qi, offsets=[0, 0], shapes=[16, 128], target_memory=pl.MemorySpace.Mat)
kj1_l1 = pl.load(
    kj, offsets=[0, 128], shapes=[128, 128], target_memory=pl.MemorySpace.Mat, transpose=True
)
qi_l0a_1 = pl.move(qi_l1_1, target_memory=pl.MemorySpace.Left)
kj1_l0b = pl.move(kj1_l1, target_memory=pl.MemorySpace.Right)
sij_h1_l0c = pl.matmul(qi_l0a_1, kj1_l0b)
out = pl.store(sij_h1_l0c, offsets=[0, 128], output_tensor=output)
```
The qi tensor is loaded and moved to L0A memory twice, once for each block. Since it's the same qi for both operations, it can be loaded and moved just once and then reused for both matmul calls. This avoids a redundant load and move, improving kernel efficiency.
Suggested change:

```python
qi_l1 = pl.load(qi, offsets=[0, 0], shapes=[16, 128], target_memory=pl.MemorySpace.Mat)
qi_l0a = pl.move(qi_l1, target_memory=pl.MemorySpace.Left)
# First block: qi @ kj[0:128].T -> output[:,0:128]
kj0_l1 = pl.load(kj, offsets=[0, 0], shapes=[128, 128], target_memory=pl.MemorySpace.Mat, transpose=True)
kj0_l0b = pl.move(kj0_l1, target_memory=pl.MemorySpace.Right)
sij_h0_l0c = pl.matmul(qi_l0a, kj0_l0b)
_ = pl.store(sij_h0_l0c, offsets=[0, 0], output_tensor=output)
# Second block: qi @ kj[128:256].T -> output[:,128:256]
kj1_l1 = pl.load(
    kj, offsets=[0, 128], shapes=[128, 128], target_memory=pl.MemorySpace.Mat, transpose=True
)
kj1_l0b = pl.move(kj1_l1, target_memory=pl.MemorySpace.Right)
sij_h1_l0c = pl.matmul(qi_l0a, kj1_l0b)
out = pl.store(sij_h1_l0c, offsets=[0, 128], output_tensor=output)
```
```python
qi_x2: pl.Tensor[[q_tile, head_dim_cfg], pl.BF16] = pl.slice(
    query,
    [q_tile, head_dim_cfg],  # type: ignore[reportArgumentType]
    [cur_offset, 0],
)
```

```python
qi: pl.Tensor[[q_tile, head_dim_cfg], pl.BF16] = pl.slice(
    query,
    [q_tile, head_dim_cfg],  # type: ignore[reportArgumentType]
    [cur_offset, 0],
)
```
Similar to the kernelx2 loop, the qi tensor is sliced from query inside the bn3 loop. This is inefficient as cur_offset is constant throughout the loop. This slice should be performed once before the loop. Furthermore, this slice is identical to the one for qi_x2 in the preceding loop. A single slice operation before both loops would be the most efficient approach.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/ir_parser/paged_attention_example.py`:
- Around line 568-582: The code assumes the two-block pair is physically
contiguous but only reads bidx_x2_0; fix by reading the next index (e.g.,
bidx_x2_1 = pl.tensor.read(block_table, [b_idx * block_num_cfg + bn2 + 1]) ),
check contiguity (bidx_x2_1 == bidx_x2_0 + 1) and only perform the
2*block_size_cfg slice into key_cache/value_cache (producing kj_2b and vj_2b)
when contiguous; otherwise fall back to the existing x1 path and slice the two
blocks individually (or iterate two x1 iterations) to avoid reading past
allocated cache rows and feeding wrong K/V pairs. Ensure checks use bn2,
block_table, key_cache, value_cache, kj_2b and vj_2b symbols so behavior is
switched deterministically.
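The contiguity guard described above can be sketched in plain Python over flat lists. All names here (`gather_two_blocks`, the row-major cache layout) are illustrative assumptions, not the `pl` DSL API; the real kernel works on tensors via `pl.tensor.read` and `pl.slice`.

```python
def gather_two_blocks(block_table, b_idx, block_num, bn2, cache, block_size):
    """Fetch a 2-block K/V span, falling back to two 1-block reads
    when the physical blocks are not contiguous (illustrative sketch)."""
    bidx0 = block_table[b_idx * block_num + bn2]
    bidx1 = block_table[b_idx * block_num + bn2 + 1]
    if bidx1 == bidx0 + 1:
        # Contiguous: a single slice of 2 * block_size rows is safe.
        start = bidx0 * block_size
        return cache[start : start + 2 * block_size]
    # Non-contiguous: gather each block separately (the x1 path).
    s0, s1 = bidx0 * block_size, bidx1 * block_size
    return cache[s0 : s0 + block_size] + cache[s1 : s1 + block_size]
```

With `block_table = [0, 1, 3, 2]`, the pair at `bn2 = 0` maps to physical blocks 0 and 1 (contiguous, one wide slice), while the pair at `bn2 = 2` maps to 3 and 2 (non-contiguous, two separate gathers), which is exactly the wrong-K/V-pair hazard the comment flags.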
In `@python/pypto/language/parser/type_resolver.py`:
- Around line 587-588: The new arithmetic path currently routes dimensions
through _validate_dim_value which only accepts int/DynVar, causing
closure-provided ir.Expr dims to be rejected; update _validate_dim_value (and
any callers such as the branch in _resolve_shape_dim that handles ast.Name and
the other branch around the code at lines 607-610) to accept ir.Expr as a valid
dimension type (or detect ir.Expr and bypass/convert it) so that
ExprEvaluator-produced ir.Expr from closures is allowed; ensure the change
preserves existing int/DynVar checks and integrates with ExprEvaluator's lifting
behavior.
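As a rough illustration of what evaluating binary operators in dimension expressions involves, here is a standalone sketch; it is not the actual ExprEvaluator, which per the comment above must also accept symbolic `DynVar`/`ir.Expr` operands rather than requiring concrete ints.

```python
import ast
import operator

# Hypothetical mini-evaluator for dimension expressions such as
# "block_size * 2" or "head_dim + 64", resolving names from a
# closure-like environment of concrete integer bindings.
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.FloorDiv: operator.floordiv}

def eval_dim(expr: str, env: dict[str, int]) -> int:
    def walk(node: ast.AST) -> int:
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.Name):
            return env[node.id]
        if isinstance(node, ast.Constant) and isinstance(node.value, int):
            return node.value
        raise TypeError(f"unsupported dimension expression: {ast.dump(node)}")
    return walk(ast.parse(expr, mode="eval").body)
```

For example, `eval_dim("block_size * 2", {"block_size": 128})` yields `256`; the symbolic case the review targets would instead return an `ir.Expr` that downstream validation must accept.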
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 9a344827-464e-4140-b1b5-1165f21b3e11
📒 Files selected for processing (3)
- examples/ir_parser/paged_attention_example.py
- python/pypto/language/parser/type_resolver.py
- tests/st/codegen/test_paged_attention.py
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/st/codegen/test_paged_attention.py (1)
667-699: Bind loop variables in nested helper defs to avoid the late-binding behavior flagged by Ruff B023.

Lines 667-699 close over the loop-scoped variables `b`, `cur_seq`, and `qi`. The nested functions `_2block` and `_1block` would misbehave if these variables changed after definition, even though they're called immediately. Bind them as default parameters to capture their current values:

Proposed fix

```diff
-    def _2block(bn2: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _2block(
+        bn2: int, *, _b: int = b, _cur_seq: int = cur_seq, _qi: torch.Tensor = qi
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@
-        v0 = min(block_size, cur_seq - bn2 * block_size)
-        v1 = min(block_size, cur_seq - (bn2 + 1) * block_size)
-        bidx0 = int(block_table[b, bn2].item())
-        bidx1 = int(block_table[b, bn2 + 1].item())
+        v0 = min(block_size, _cur_seq - bn2 * block_size)
+        v1 = min(block_size, _cur_seq - (bn2 + 1) * block_size)
+        bidx0 = int(block_table[_b, bn2].item())
+        bidx1 = int(block_table[_b, bn2 + 1].item())
@@
-        sij = torch.mm(qi, kj.T) * scale_value  # [q_tile, valid_2b]
+        sij = torch.mm(_qi, kj.T) * scale_value  # [q_tile, valid_2b]
@@
-    def _1block(
-        bn: int,
-    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+    def _1block(
+        bn: int, *, _b: int = b, _cur_seq: int = cur_seq, _qi: torch.Tensor = qi
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
@@
-        v = min(block_size, cur_seq - bn * block_size)
+        v = min(block_size, _cur_seq - bn * block_size)
@@
-        kj = key_cache[int(block_table[b, bn].item()), :v]
-        vj = value_cache[int(block_table[b, bn].item()), :v]
-        sij = torch.mm(qi, kj.T) * scale_value
+        kj = key_cache[int(block_table[_b, bn].item()), :v]
+        vj = value_cache[int(block_table[_b, bn].item()), :v]
+        sij = torch.mm(_qi, kj.T) * scale_value
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/st/codegen/test_paged_attention.py` around lines 667 - 699, Nested helper functions _2block and _1block close over loop-scoped variables (b, cur_seq, qi) causing late-binding issues flagged by Ruff B023; fix by binding those loop variables as default parameters on the function definitions (e.g., add b=b, cur_seq=cur_seq, qi=qi to _2block and b=b, cur_seq=cur_seq, qi=qi to _1block, or only the ones each uses) so the current values are captured at definition time while keeping the existing logic in _2block and _1block unchanged.
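The B023 hazard is easy to reproduce without torch: Python closures capture variables, not values, so every nested function sees the loop variable's final value unless the value is frozen as a default parameter at definition time, which is the pattern the review suggests.

```python
# Late binding: each lambda closes over the variable i itself,
# so after the loop every closure observes its final value.
late = [lambda: i for i in range(3)]

# Early binding: the default parameter captures the current value
# of i at definition time.
early = [lambda i=i: i for i in range(3)]

late_results = [f() for f in late]    # every closure returns 2
early_results = [f() for f in early]  # returns 0, 1, 2 in order
```

In the test under review the helpers are called immediately inside the loop body, so the bug is latent rather than active, but binding defaults both silences Ruff and makes the capture explicit.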
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/st/codegen/test_paged_attention.py`:
- Around line 661-663: The loop using range(num_heads // q_tile) (variable
q_idx) drops tail heads when num_heads % q_tile != 0; change the iteration to
cover the final partial tile (e.g., compute full_tiles, rem = divmod(num_heads,
q_tile) and iterate over range(full_tiles + (1 if rem else 0))) and slice qi =
query[b, q_off : min(q_off + q_tile, num_heads), :] so the last tile can be
smaller; likewise update the corresponding write path that only writes full
tiles (the write at the location that handles full tiles around line 726) to
handle partial-tail tiles by writing only the valid head entries for the last
tile. Ensure you reference num_heads, q_tile, q_idx, q_off, qi and the write
that emits per-tile outputs so both read and write handle the remainder.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 8a4e3a77-9651-41a9-be3b-8fa44f57b211
📒 Files selected for processing (2)
- tests/st/codegen/test_paged_attention.py
- tests/st/harness/adapters/golden_generator.py
```python
for q_idx in range(num_heads // q_tile):
    q_off = q_idx * q_tile
    qi = query[b, q_off : q_off + q_tile, :]  # [q_tile, head_dim]
```
Handle non-multiple-of-q_tile heads in the multitier expected path.
Line 661 uses `num_heads // q_tile`, which silently drops tail heads when `num_heads` is not divisible by 16. Line 726 then only writes full tiles. This can produce incorrect expected outputs for valid constructor inputs.
Proposed fix

```diff
-        for q_idx in range(num_heads // q_tile):
-            q_off = q_idx * q_tile
-            qi = query[b, q_off : q_off + q_tile, :]  # [q_tile, head_dim]
+        for q_off in range(0, num_heads, q_tile):
+            q_tile_size = min(q_tile, num_heads - q_off)
+            qi = query[b, q_off : q_off + q_tile_size, :]
@@
-        out[b, q_off : q_off + q_tile, :] = oi / li
+        out[b, q_off : q_off + q_tile_size, :] = oi / li
```

Also applies to: line 726.
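The tail-tile arithmetic behind this fix generalizes to a tiny torch-free helper (`tile_ranges` is a hypothetical name, shown only to illustrate the offset/size pattern, not code from the PR):

```python
def tile_ranges(total: int, tile: int) -> list[tuple[int, int]]:
    """(offset, size) pairs covering `total` items in tiles of at most
    `tile`, including the partial tail tile when total % tile != 0."""
    return [(off, min(tile, total - off)) for off in range(0, total, tile)]
```

With `num_heads = 40` and `q_tile = 16` this yields `(0, 16), (16, 16), (32, 8)`, so the final 8 heads are covered instead of being dropped by `num_heads // q_tile`.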
```python
sij_h0_l0c = pl.matmul(qi_l0a_0, kj0_l0b)
_ = pl.store(sij_h0_l0c, offsets=[0, 0], output_tensor=output)
# Second block: qi @ kj[128:256].T -> output[:,128:256]
qi_l1_1 = pl.load(qi, offsets=[0, 0], shapes=[16, 128], target_memory=pl.MemorySpace.Mat)
```

```python
)
qi_l0a_1 = pl.move(qi_l1_1, target_memory=pl.MemorySpace.Left)
kj1_l0b = pl.move(kj1_l1, target_memory=pl.MemorySpace.Right)
sij_h1_l0c = pl.matmul(qi_l0a_1, kj1_l0b)
```
Previously `transpose` didn't support non-square matrices; now `transpose` has been folded into `load`, so these can be merged. Fixed.
Force-pushed from b5430c6 to 6ac4980
Force-pushed from 6ac4980 to ae0dac9