Fix tensor dimension issues and refactor attention scope in Qwen3 prefill #22
xzhxzhxzh123 merged 2 commits into hw-native-sys:main
Conversation
…fill

- Fix hidden_states slice dimension from 2D to 3D to match tensor shape
- Remove unsupported valid_shape parameter from create_tensor calls
- Add reshape after slice to adapt dimensions for downstream ops
- Separate KV cache update loop from attention loop in Scope 2
- Fix down_acc_3d usage in output assembly
- Adjust MLP_OUT_CHUNK from 256 to 64
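The MLP_OUT_CHUNK change affects how the MLP output dimension is tiled. A minimal sketch of the idea — the output width used here is a hypothetical stand-in, not taken from the repo:

```python
# Illustrative only: how a smaller MLP_OUT_CHUNK changes tiling granularity.
MLP_OUT_CHUNK = 64        # was 256 before this PR
INTERMEDIATE = 25600      # hypothetical MLP output width (stand-in value)

# The output dimension is walked in MLP_OUT_CHUNK-sized tiles;
# shrinking the chunk from 256 to 64 quadruples the tile count.
tiles = [(c, min(c + MLP_OUT_CHUNK, INTERMEDIATE))
         for c in range(0, INTERMEDIATE, MLP_OUT_CHUNK)]
print(len(tiles))  # 400 tiles instead of 100
```

Finer tiles trade per-tile overhead for smaller working sets, which is a common reason to shrink a chunk constant like this.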
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.

No actionable comments were generated in the recent review. 🎉
📝 Walkthrough

This PR updates the Qwen3 32B prefill example: reduces MLP output tiling granularity, changes batch-parallel and KV-cache update scheduling, refactors tensor slicing/reshape patterns and valid_shape usages, adjusts attention masking and output assembly, and removes the work_dir parameter from compile_and_run.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes (Gemini Code Assist)

This pull request addresses several critical issues and refactorings within the Qwen3 prefill model, primarily focusing on tensor dimension correctness and optimizing the attention mechanism. The changes ensure proper tensor shape handling throughout the computation, refine the attention scope for better clarity and stability, and update build configurations to target a specific backend, ultimately enhancing the model's robustness and execution on the intended platform.
Code Review
This pull request introduces several fixes and refactorings to the Qwen3 prefill implementation. Key changes include correcting tensor dimensions for slicing operations, removing unsupported API parameters, and refactoring the attention scope for clarity and correctness. The changes appear to be solid improvements. I've identified one area of code duplication that could be refactored to improve maintainability.
```python
x_chunk = pl.reshape(
    pl.cast(
        pl.slice(hidden_states, [1, TOK_TILE, K_CHUNK], [b, p0, k0],
                 valid_shape=[1, valid_tok, K_CHUNK]),
        target_type=pl.FP32,
    ),
    [TOK_TILE, K_CHUNK]
)
```
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/qwen3/qwen3_32b_prefill.py (1)

Lines 433-443: ⚠️ Potential issue | 🟡 Minor — API inconsistency: `work_dir` parameter removed while other qwen3 modules retain it.

The `work_dir: str | None = None` parameter was removed from `compile_and_run`, but the related modules (`qwen3_32b_decode.py` at line 411 and `qwen3_32b_training_forward_and_backward.py` at line 931) still include this parameter. This creates an inconsistent API surface across the qwen3 examples. Consider either:

- Retaining `work_dir` for consistency with sibling modules, or
- Removing `work_dir` from all qwen3 modules if it's no longer needed
work_dirfrom all qwen3 modules if it's no longer needed🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/qwen3/qwen3_32b_prefill.py` around lines 433 - 443, The compile_and_run function signature in qwen3_32b_prefill.py removed the work_dir parameter, causing an inconsistent API with sibling modules (see qwen3_32b_decode.py and qwen3_32b_training_forward_and_backward.py); either restore work_dir: str | None = None to the compile_and_run signature and propagate it to any internal calls/variables that need a working directory, or remove work_dir from the other modules so all three functions (compile_and_run in qwen3_32b_prefill.py, qwen3_32b_decode.py, and qwen3_32b_training_forward_and_backward.py) share the same signature; update callers to match the chosen approach and ensure any file/dump behavior (e.g., dump_passes or file paths) uses the unified work_dir handling.
🧹 Nitpick comments (2)
examples/qwen3/qwen3_32b_prefill.py (2)
Lines 160-167: Same indentation issue as above. This block has the same inconsistent indentation pattern as lines 138-145.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/qwen3/qwen3_32b_prefill.py` around lines 160 - 167, The x_chunk construction has inconsistent indentation around the nested pl.reshape/pl.cast/pl.slice calls (symbols: x_chunk, pl.reshape, pl.cast, pl.slice, hidden_states, TOK_TILE, K_CHUNK, valid_tok, b, p0, k0); fix by aligning the chained function calls and their arguments consistently with the surrounding blocks (same style used for earlier similar blocks) so each nested call and its parameters are indented uniformly and closing parentheses line up with their opening calls.
Lines 138-145: Inconsistent indentation in multi-line expression.

The `pl.cast(` and its contents are not properly indented relative to `pl.reshape(`. While Python parses this correctly due to parentheses, it harms readability.

🔧 Suggested indentation fix
```diff
-        x_chunk = pl.reshape(
-        pl.cast(
-            pl.slice(hidden_states, [1, TOK_TILE, K_CHUNK], [b, p0, k0],
-                     valid_shape=[1, valid_tok, K_CHUNK]),
+        x_chunk = pl.reshape(
+            pl.cast(
+                pl.slice(hidden_states, [1, TOK_TILE, K_CHUNK], [b, p0, k0],
+                         valid_shape=[1, valid_tok, K_CHUNK]),
                 target_type=pl.FP32,
-        ),
-        [TOK_TILE, K_CHUNK]
-        )
+            ),
+            [TOK_TILE, K_CHUNK]
+        )
```
Verify each finding against the current code and only fix it if needed. In `@examples/qwen3/qwen3_32b_prefill.py` around lines 138 - 145, The multi-line expression building x_chunk uses pl.reshape(pl.cast(pl.slice(...))) but has inconsistent indentation that reduces readability; reformat the nested calls so the arguments to pl.reshape are aligned and the pl.cast( and its pl.slice(...) block are indented one level under pl.reshape, making pl.slice(hidden_states, [1, TOK_TILE, K_CHUNK], [b, p0, k0], valid_shape=[1, valid_tok, K_CHUNK]) clearly nested inside pl.cast which is the first argument to pl.reshape; locate x_chunk, pl.reshape, pl.cast, and pl.slice to apply this consistent indentation.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: ee21d379-0d7d-483c-b9c5-887467c864e8
📒 Files selected for processing (1)
examples/qwen3/qwen3_32b_prefill.py