feat(ir): add LowerBreakContinue pass for break/continue in InCore functions #491
Hzfengsy wants to merge 1 commit into hw-native-sys:main
…nctions (hw-native-sys#448) Rewrites BreakStmt/ContinueStmt in InCore/AIC/AIV functions into equivalent structured control flow using phi-node IfStmt restructuring (for continue) and ForStmt-to-WhileStmt conversion (for break). Runs after InferTileMemorySpace and before InitMemRef in both Default and CCE strategies. Requires and preserves SSAForm and SplitIncoreOrch.
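A minimal sketch of the intended equivalence, written in plain Python rather than the PyPTO IR. It assumes the lowering carries the loop state through an explicit break flag and per-branch "phi" values, as the description suggests; the function names and the flag layout are illustrative and not taken from the pass.

```python
def loop_with_escapes(x0, n=10):
    """Reference loop using break and continue directly."""
    x = x0
    for i in range(n):
        if i < 3:
            continue          # skip the update, keep the carried value
        y = x + x
        if i > 7:
            break             # leave the loop, keep the carried value
        x = y
    return x


def lowered_loop(x0, n=10):
    """Same loop after a hand-written lowering (illustrative only).

    The for loop becomes a while loop with an explicit induction variable
    and a break flag; every branch selects the next carried value, which
    plays the role of a phi node.
    """
    i, x, broke = 0, x0, False
    while i < n and not broke:
        if i < 3:
            x_next, broke_next = x, False      # continue path
        else:
            y = x + x
            if i > 7:
                x_next, broke_next = x, True   # break path
            else:
                x_next, broke_next = y, False  # normal path
        x, broke = x_next, broke_next
        i += 1
    return x
```

Both functions should agree for any trip count, which is the property the unit tests in this PR would ideally check structurally.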
Walkthrough
This PR introduces a LowerBreakContinue pass that transforms break and continue statements in InCore functions into structured control flow using while loops and phi-node-based yields. It includes pass declarations, properties, Python bindings, registration, comprehensive implementation, documentation, and tests.
Estimated code review effort: 4 (Complex), ~60 minutes
Pre-merge checks: 3 passed, 2 failed (2 warnings)
Summary of Changes (Gemini Code Assist)
This pull request introduces a crucial compiler optimization pass designed to transform unstructured control flow (…
Pull request overview
This PR adds a new IR transform pass (LowerBreakContinue) and wires it into the default/CCE pass pipelines, with accompanying Python bindings, tests, and documentation, so codegen no longer needs to handle BreakStmt/ContinueStmt directly.
Changes:
- Implement the `LowerBreakContinue` C++ pass and register it in the build.
- Expose the pass to Python (bindings + stubs) and add it to `PassManager` strategies.
- Add extensive unit tests and new/updated pipeline documentation pages.
Reviewed changes
Copilot reviewed 12 out of 22 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/ut/ir/transforms/test_pass_manager.py | Updates default strategy pass count/order assertions to include LowerBreakContinue. |
| tests/ut/ir/transforms/test_lower_break_continue_pass.py | Adds unit tests covering many break/continue patterns and pipeline integration. |
| src/ir/transforms/lower_break_continue_pass.cpp | Implements the LowerBreakContinue lowering logic (for/while rewriting + phi-based restructuring). |
| python/pypto/pypto_core/passes.pyi | Adds lower_break_continue() to Python type stubs and exports. |
| python/pypto/ir/pass_manager.py | Inserts LowerBreakContinue into Default and CCE strategy pass lists after ResolveTransposeLayout. |
| python/bindings/modules/passes.cpp | Exposes lower_break_continue in Python bindings. |
| include/pypto/ir/transforms/passes.h | Declares the new pass in the public C++ pass API. |
| include/pypto/ir/transforms/pass_properties.h | Adds required/produced property declarations for the new pass. |
| docs/zh-cn/dev/passes/* | Adds/updates Chinese pass docs including the new pass and pipeline ordering. |
| docs/en/dev/passes/* | Adds/updates English pass docs including the new pass and pipeline ordering. |
| CMakeLists.txt | Adds the new pass source file to the build. |
| .claude/rules/pass-doc-ordering.md | Updates the documented pipeline numbering/order to include the new pass. |
    /// Flatten a statement into a vector of statements (unwrapping SeqStmts).
    static std::vector<StmtPtr> FlattenToVec(const StmtPtr& stmt) {
      if (auto seq = std::dynamic_pointer_cast<const SeqStmts>(stmt)) {
        return seq->stmts_;
      }
      return {stmt};
    }

    for (const auto& s : flat) {
      if (auto assign = std::dynamic_pointer_cast<const AssignStmt>(s)) {
        available_vars.insert(assign->var_->UniqueId());
      }

    // Case 2: IfStmt containing continue in a branch
    auto if_stmt = std::dynamic_pointer_cast<const IfStmt>(stmt);
    if (!if_stmt) continue;

    auto then_scan = scanner.Scan(if_stmt->then_body_);
    bool else_has_continue = false;
    if (if_stmt->else_body_.has_value()) {
      else_has_continue = scanner.Scan(*if_stmt->else_body_).has_continue;
    }

    if (!then_scan.has_continue && !else_has_continue) continue;

    bool escape_in_then = then_scan.has_continue;
    std::vector<StmtPtr> pre(stmts.begin(), stmts.begin() + static_cast<ptrdiff_t>(i));
    std::vector<StmtPtr> post(stmts.begin() + static_cast<ptrdiff_t>(i) + 1, stmts.end());

    auto continue_values = ResolveYieldAtContinue(original_yield_values, pre, iter_args);
    auto normal_stmts = CollectNormalPath(if_stmt, escape_in_then, post);

    // Case 2: IfStmt containing break in a branch
    auto if_stmt = std::dynamic_pointer_cast<const IfStmt>(stmt);
    if (!if_stmt) continue;

    auto then_scan = scanner.Scan(if_stmt->then_body_);
    bool else_has_break = false;
    if (if_stmt->else_body_.has_value()) {
      else_has_break = scanner.Scan(*if_stmt->else_body_).has_break;
    }

    if (!then_scan.has_break && !else_has_break) continue;

    bool escape_in_then = then_scan.has_break;
    std::vector<StmtPtr> pre(stmts.begin(), stmts.begin() + static_cast<ptrdiff_t>(i));
    std::vector<StmtPtr> post(stmts.begin() + static_cast<ptrdiff_t>(i) + 1, stmts.end());

    auto break_values = build_break_values(pre);
    auto normal_stmts = CollectNormalPath(if_stmt, escape_in_then, post);
    auto normal_result = ProcessBodyForBreak(normal_stmts, break_flag_index, while_iter_args,
Code Review
This pull request introduces a new LowerBreakContinue compiler pass, a significant and well-executed feature for rewriting break and continue statements into structured control flow. The C++ implementation is robust, handling various cases including nested loops, multiple iter_args, and combinations of break and continue. The transformation strategy is sound, and the feature is supported by an extensive and thorough suite of unit tests, which provides high confidence in its correctness. The integration into the build system, Python bindings, and pass manager is also well done. My feedback is focused on minor improvements to the documentation to ensure it accurately reflects the new pass's position within the compilation pipeline.
    - Input IR must be in SSA form (SSAForm required and preserved)
    - InCore scopes must be outlined (SplitIncoreOrch required and preserved)

    **When to use**: Run after InferTileMemorySpace and before InitMemRef.
For better precision and consistency with the PR description and pass_manager.py, it would be beneficial to state that this pass runs after ResolveTransposeLayout. While it is also after InferTileMemorySpace, ResolveTransposeLayout is its direct predecessor in the pipeline.
    -**When to use**: Run after InferTileMemorySpace and before InitMemRef.
    +**When to use**: Run after ResolveTransposeLayout and before InitMemRef.
    ## Pipeline Position

    ```text
    ... → InferTileMemorySpace → LowerBreakContinue → InitMemRef → ...
    ```
The pipeline position diagram appears to be missing the ResolveTransposeLayout pass. According to python/pypto/ir/pass_manager.py, this pass runs after ResolveTransposeLayout. The diagram should be updated to reflect the correct pass ordering for accuracy.
    -... → InferTileMemorySpace → LowerBreakContinue → InitMemRef → ...
    +... → ResolveTransposeLayout → LowerBreakContinue → InitMemRef → ...
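The ordering claim can be made checkable with a small helper. This is only a sketch: `PIPELINE_EXCERPT` is a hand-written list based on the order the reviews describe (InferTileMemorySpace, ResolveTransposeLayout, LowerBreakContinue, InitMemRef), and `neighbours` is a hypothetical helper, not part of `pass_manager.py`.

```python
# Hypothetical excerpt of the pass order, as described in the review comments.
PIPELINE_EXCERPT = [
    "InferTileMemorySpace",
    "ResolveTransposeLayout",
    "LowerBreakContinue",
    "InitMemRef",
]


def neighbours(passes, name):
    """Return the (predecessor, successor) of `name`, or None at either end."""
    idx = passes.index(name)
    prev = passes[idx - 1] if idx > 0 else None
    nxt = passes[idx + 1] if idx + 1 < len(passes) else None
    return prev, nxt


prev_pass, next_pass = neighbours(PIPELINE_EXCERPT, "LowerBreakContinue")
```

A docs check along these lines would flag the "after InferTileMemorySpace" wording, since the direct predecessor in this excerpt is ResolveTransposeLayout.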
    - 输入 IR 必须为 SSA 形式(SSAForm 必需且保持)
    - InCore 作用域必须已提取(SplitIncoreOrch 必需且保持)

    **使用时机**:在 InferTileMemorySpace 之后、InitMemRef 之前运行。
For better precision and consistency with the implementation in pass_manager.py, it would be beneficial to state that this pass runs after ResolveTransposeLayout. While it is also after InferTileMemorySpace, ResolveTransposeLayout is its direct predecessor in the pipeline.
    -**使用时机**:在 InferTileMemorySpace 之后、InitMemRef 之前运行。
    +**使用时机**:在 ResolveTransposeLayout 之后、InitMemRef 之前运行。
    ## 流水线位置

    ```text
    ... → InferTileMemorySpace → LowerBreakContinue → InitMemRef → ...
    ```
The pipeline position diagram appears to be missing the ResolveTransposeLayout pass. According to python/pypto/ir/pass_manager.py, this pass runs after ResolveTransposeLayout. The diagram should be updated to reflect the correct pass ordering for accuracy.
    -... → InferTileMemorySpace → LowerBreakContinue → InitMemRef → ...
    +... → ResolveTransposeLayout → LowerBreakContinue → InitMemRef → ...
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/en/dev/passes/11-lower_break_continue.md`:
- Line 18: Update the docs to reflect the actual pass ordering: change the
description for LowerBreakContinue to "Run after ResolveTransposeLayout and
before InitMemRef" (matching pass_manager.py where LowerBreakContinue is
scheduled) and adjust the pipeline diagram so LowerBreakContinue is placed after
ResolveTransposeLayout and before InitMemRef; reference the LowerBreakContinue
pass name and the ResolveTransposeLayout and InitMemRef pass names when making
the change.
In `@docs/zh-cn/dev/passes/11-lower_break_continue.md`:
- Line 18: Update the pipeline placement text and diagram so LowerBreakContinue
is described as running after ResolveTransposeLayout (not directly after
InferTileMemorySpace); specifically, change the sentence containing "在
InferTileMemorySpace 之后、InitMemRef 之前运行" to indicate it runs after
ResolveTransposeLayout and before InitMemRef, and make the same change in the
other occurrence referenced around lines 97-100; ensure mentions of
ResolveTransposeLayout and the LowerBreakContinue pass are consistent.
In `@python/bindings/modules/passes.cpp`:
- Around line 229-230: The docstring for the binding
passes.def("lower_break_continue", &pass::LowerBreakContinue, ...) is too
narrow: update the bound Python docstring to state that LowerBreakContinue
operates on all functions for which IsInCoreType() returns true (not just
InCore), explicitly mentioning it also runs on AIC and AIV functions; edit the
string literal passed to passes.def for lower_break_continue to broaden the
scope language accordingly while keeping the description otherwise the same.
In `@src/ir/transforms/lower_break_continue_pass.cpp`:
- Around line 350-385: The code incorrectly treats a branch as wholly escaping
when scanner.Scan only reports that some path inside it escapes; change the
logic to only collapse a branch when all paths in that branch escape (e.g.,
add/consume an "all_paths_escape" or equivalent from scanner.Scan for then_body_
and else_body_), otherwise perform path-sensitive rewriting: call
ProcessBodyForContinue/CollectNormalPath on the actual branch body to preserve
statements before the inner escape and only move the exact escape paths into
BuildEscapeIfStmt; update the escape_in_then/else checks to use the new
all-paths flag (rather than has_continue) and apply the same fix to the
analogous block around lines 429-451 (the branches that build new IfStmt and
call BuildEscapeIfStmt), referencing IfStmt, then_body_, else_body_,
scanner.Scan, CollectNormalPath, ProcessBodyForContinue, and BuildEscapeIfStmt.
- Around line 172-217: ResolveYieldAtContinue is incorrectly reusing
original_yield_values for the escape path; change it so that for each yield
index j you use iter_args[j] when an iter_arg exists (i.e., if j <
iter_args.size()) regardless of whether original_yield_values[j] is a constant,
expression, or a Var defined earlier, and only fall back to
original_yield_values[j] when there is no corresponding iter_arg; update the
loop that builds resolved (using original_yield_values, var, available_vars, and
iter_args) to implement this behavior.
In `@tests/ut/ir/transforms/test_lower_break_continue_pass.py`:
- Around line 25-911: Tests in test_lower_break_continue_pass.py only cover
pl.range (ForStmt) textual output and miss direct pl.while_ cases and structural
IR assertions; add new fixtures that use pl.while_ loops (e.g., create tests
similar to test_break_in_for/test_continue_in_for but defining kernels with
pl.while_ instead of pl.range) and assert that passes.lower_break_continue()
invokes the LowerWhileWithContinue/LowerWhileWithBreak paths by checking IR
shape via ir.assert_structural_equal or inspecting node types (e.g., ensure the
transformed function contains WhileStmt nodes or specific pass-produced
properties), and strengthen a few nested-if tests (like
test_continue_in_nested_if/test_break_in_nested_if) to assert the IR structure
(not just absence of "continue"/"break") after passes.lower_break_continue() so
regressions in lowering are caught.
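The tests rely on a `_has_bare_keyword` helper whose body is not shown in this review. A minimal sketch of what such a textual check might look like follows; the name `has_bare_keyword_sketch` and the regex are assumptions for illustration, and the real helper may differ.

```python
import re


def has_bare_keyword_sketch(source: str, keyword: str) -> bool:
    """Return True if some line consists only of `keyword` (plus an optional comment).

    Hypothetical stand-in for the test helper; it ignores occurrences inside
    identifiers, docstrings, or longer statements.
    """
    pattern = re.compile(
        rf"^[ \t]*{re.escape(keyword)}[ \t]*(#.*)?$",
        re.MULTILINE,
    )
    return pattern.search(source) is not None
```

A check like this only proves the keyword is gone from the printed program, which is why the comment above asks for structural assertions in addition.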
Files selected for processing (22)
- .claude/rules/pass-doc-ordering.md
- CMakeLists.txt
- docs/en/dev/passes/11-lower_break_continue.md
- docs/en/dev/passes/12-init_memref.md
- docs/en/dev/passes/13-basic_memory_reuse.md
- docs/en/dev/passes/14-insert_sync.md
- docs/en/dev/passes/15-allocate_memory_addr.md
- docs/en/dev/passes/16-utility_passes.md
- docs/zh-cn/dev/passes/11-lower_break_continue.md
- docs/zh-cn/dev/passes/12-init_memref.md
- docs/zh-cn/dev/passes/13-basic_memory_reuse.md
- docs/zh-cn/dev/passes/14-insert_sync.md
- docs/zh-cn/dev/passes/15-allocate_memory_addr.md
- docs/zh-cn/dev/passes/16-utility_passes.md
- include/pypto/ir/transforms/pass_properties.h
- include/pypto/ir/transforms/passes.h
- python/bindings/modules/passes.cpp
- python/pypto/ir/pass_manager.py
- python/pypto/pypto_core/passes.pyi
- src/ir/transforms/lower_break_continue_pass.cpp
- tests/ut/ir/transforms/test_lower_break_continue_pass.py
- tests/ut/ir/transforms/test_pass_manager.py
    - Input IR must be in SSA form (SSAForm required and preserved)
    - InCore scopes must be outlined (SplitIncoreOrch required and preserved)

    **When to use**: Run after InferTileMemorySpace and before InitMemRef.
Minor documentation inconsistency in pipeline position.
The text states "Run after InferTileMemorySpace" but looking at pass_manager.py, LowerBreakContinue actually runs after ResolveTransposeLayout. The pipeline diagram on line 100 is also slightly inconsistent with the actual ordering.
Consider updating to: "Run after ResolveTransposeLayout and before InitMemRef."
Suggested fix
    -**When to use**: Run after InferTileMemorySpace and before InitMemRef.
    +**When to use**: Run after ResolveTransposeLayout and before InitMemRef.
And update the pipeline diagram at line 100:
    -... → InferTileMemorySpace → LowerBreakContinue → InitMemRef → ...
    +... → ResolveTransposeLayout → LowerBreakContinue → InitMemRef → ...
    - 输入 IR 必须为 SSA 形式(SSAForm 必需且保持)
    - InCore 作用域必须已提取(SplitIncoreOrch 必需且保持)

    **使用时机**:在 InferTileMemorySpace 之后、InitMemRef 之前运行。
Update the pipeline placement to include ResolveTransposeLayout.
The doc currently implies LowerBreakContinue runs directly after InferTileMemorySpace, but the registered/tested order places it after ResolveTransposeLayout. The text and diagram should match the real pipeline.
Proposed fix
    -**使用时机**:在 InferTileMemorySpace 之后、InitMemRef 之前运行。
    +**使用时机**:在 ResolveTransposeLayout 之后、InitMemRef 之前运行。
    ...
    -... → InferTileMemorySpace → LowerBreakContinue → InitMemRef → ...
    +... → InferTileMemorySpace → ResolveTransposeLayout → LowerBreakContinue → InitMemRef → ...
Also applies to: 97-100
    passes.def("lower_break_continue", &pass::LowerBreakContinue,
               "Create a pass that lowers break/continue into equivalent control flow in InCore functions");
Broaden the Python docstring to match the actual pass scope.
The binding says this only lowers control flow in InCore functions, but the pass runs on all IsInCoreType() functions, including AIC and AIV. The current text will under-document the public API.
Proposed fix
     passes.def("lower_break_continue", &pass::LowerBreakContinue,
    -           "Create a pass that lowers break/continue into equivalent control flow in InCore functions");
    +           "Create a pass that lowers break/continue into equivalent control flow in InCore/AIC/AIV functions");
    static std::vector<ExprPtr> ResolveYieldAtContinue(const std::vector<ExprPtr>& original_yield_values,
                                                       const std::vector<StmtPtr>& pre_continue_stmts,
                                                       const std::vector<IterArgPtr>& iter_args) {
      // Collect unique IDs of all variables defined before the continue point
      std::unordered_set<uint64_t> available_vars;
      for (const auto& stmt : pre_continue_stmts) {
        auto flat = FlattenToVec(stmt);
        for (const auto& s : flat) {
          if (auto assign = std::dynamic_pointer_cast<const AssignStmt>(s)) {
            available_vars.insert(assign->var_->UniqueId());
          }
        }
      }

      // Iter args are always available within the loop body
      for (const auto& ia : iter_args) {
        available_vars.insert(ia->UniqueId());
      }

      std::vector<ExprPtr> resolved;
      resolved.reserve(original_yield_values.size());

      for (size_t j = 0; j < original_yield_values.size(); ++j) {
        const auto& val = original_yield_values[j];
        auto var = std::dynamic_pointer_cast<const Var>(val);
        if (!var) {
          // Not a variable reference (e.g., a constant) — use as-is
          resolved.push_back(val);
          continue;
        }

        if (available_vars.count(var->UniqueId())) {
          resolved.push_back(val);
          continue;
        }

        // Variable not available — use corresponding iter_arg
        if (j < iter_args.size()) {
          resolved.push_back(iter_args[j]);
        } else {
          resolved.push_back(val);
        }
      }

      return resolved;
    }
Escape paths must carry the current iter_args, not the trailing yield expressions.
break/continue skip the loop’s trailing YieldStmt, so the loop-carried state must stay at the current iter_arg values. This helper currently reuses original_yield_values when they are constants, expressions, or vars defined earlier in the body, which changes semantics. For example, y = ...; if cond: continue; x_iter = pl.yield_(y) will incorrectly advance x_iter on the continue path instead of preserving the current x_iter. The same bug affects break.
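The difference can be reproduced with ordinary Python loops. The sketch below is an analogy, not the pass itself: `carry_iter_arg=True` models yielding the current `iter_arg` on the continue path, while `carry_iter_arg=False` models reusing the trailing yield value `y` because it happens to be defined before the escape. All names here are illustrative.

```python
def reference_continue(x0, n=10):
    """Original semantics: continue skips the trailing yield."""
    x_iter = x0
    for i in range(n):
        y = x_iter + x_iter      # defined before the escape point
        if i < 5:
            continue
        x_iter = y               # trailing yield
    return x_iter


def lowered_continue(x0, n, carry_iter_arg):
    """If/else lowering where each branch picks the value to yield."""
    x_iter = x0
    for i in range(n):
        y = x_iter + x_iter
        if i < 5:
            # Escape path: either preserve the loop-carried value, or
            # (incorrectly) reuse the trailing yield expression.
            phi = x_iter if carry_iter_arg else y
        else:
            phi = y
        x_iter = phi
    return x_iter
```

With `x0 = 1`, the reference loop doubles five times, and only the `carry_iter_arg=True` variant matches it; the other variant doubles on every iteration.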
Proposed fix
     static std::vector<ExprPtr> ResolveYieldAtContinue(const std::vector<ExprPtr>& original_yield_values,
                                                        const std::vector<StmtPtr>& pre_continue_stmts,
                                                        const std::vector<IterArgPtr>& iter_args) {
    -  // Collect unique IDs of all variables defined before the continue point
    -  std::unordered_set<uint64_t> available_vars;
    -  for (const auto& stmt : pre_continue_stmts) {
    -    auto flat = FlattenToVec(stmt);
    -    for (const auto& s : flat) {
    -      if (auto assign = std::dynamic_pointer_cast<const AssignStmt>(s)) {
    -        available_vars.insert(assign->var_->UniqueId());
    -      }
    -    }
    -  }
    -
    -  // Iter args are always available within the loop body
    -  for (const auto& ia : iter_args) {
    -    available_vars.insert(ia->UniqueId());
    -  }
    -
       std::vector<ExprPtr> resolved;
       resolved.reserve(original_yield_values.size());
       for (size_t j = 0; j < original_yield_values.size(); ++j) {
    -    const auto& val = original_yield_values[j];
    -    auto var = std::dynamic_pointer_cast<const Var>(val);
    -    if (!var) {
    -      // Not a variable reference (e.g., a constant) — use as-is
    -      resolved.push_back(val);
    -      continue;
    -    }
    -
    -    if (available_vars.count(var->UniqueId())) {
    -      resolved.push_back(val);
    -      continue;
    -    }
    -
    -    // Variable not available — use corresponding iter_arg
    -    if (j < iter_args.size()) {
    -      resolved.push_back(iter_args[j]);
    -    } else {
    -      resolved.push_back(val);
    -    }
    +    resolved.push_back(j < iter_args.size() ? iter_args[j] : original_yield_values[j]);
       }
       return resolved;
     }
    // Case 2: IfStmt containing continue in a branch
    auto if_stmt = std::dynamic_pointer_cast<const IfStmt>(stmt);
    if (!if_stmt) continue;

    auto then_scan = scanner.Scan(if_stmt->then_body_);
    bool else_has_continue = false;
    if (if_stmt->else_body_.has_value()) {
      else_has_continue = scanner.Scan(*if_stmt->else_body_).has_continue;
    }

    if (!then_scan.has_continue && !else_has_continue) continue;

    bool escape_in_then = then_scan.has_continue;
    std::vector<StmtPtr> pre(stmts.begin(), stmts.begin() + static_cast<ptrdiff_t>(i));
    std::vector<StmtPtr> post(stmts.begin() + static_cast<ptrdiff_t>(i) + 1, stmts.end());

    auto continue_values = ResolveYieldAtContinue(original_yield_values, pre, iter_args);
    auto normal_stmts = CollectNormalPath(if_stmt, escape_in_then, post);
    auto normal_result =
        ProcessBodyForContinue(normal_stmts, iter_args, original_yield_values, name_counter, span);

    if (original_yield_values.empty()) {
      // No iter_args — no yields or phi needed
      auto empty_body = std::make_shared<SeqStmts>(std::vector<StmtPtr>{}, if_stmt->span_);
      auto filled_body = MakeSeq(std::move(normal_result.stmts), if_stmt->span_);
      StmtPtr then_body = escape_in_then ? static_cast<StmtPtr>(empty_body) : filled_body;
      StmtPtr else_body = escape_in_then ? filled_body : static_cast<StmtPtr>(empty_body);
      auto new_if = std::make_shared<IfStmt>(if_stmt->condition_, then_body, std::make_optional(else_body),
                                             std::vector<VarPtr>{}, if_stmt->span_);
      std::vector<StmtPtr> result(pre.begin(), pre.end());
      result.push_back(new_if);
      return BodyResult{std::move(result), {}};
    }

    return BuildEscapeIfStmt(if_stmt, escape_in_then, pre, continue_values, std::move(normal_result),
                             name_counter, span);
A nested break/continue is being widened to the whole outer branch.
scanner.Scan(...) only tells you that a branch contains an escape somewhere, but this code then treats the entire then/else body as escaping. That breaks cases like if a: if b: continue / if a: if b: break: the rewritten outer IfStmt will escape on a, not on a && b, and any statements before the inner escape inside that branch are dropped. This needs path-sensitive rewriting of the branch body, or an “all paths escape” check before collapsing the whole branch.
Also applies to: 429-451
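The widening can be illustrated with a plain-Python analogy. In the sketch below, `widened_lowering` mimics treating the whole outer branch as the escape (condition `a` instead of `a and b`, with the branch's earlier statements dropped), while `path_sensitive_lowering` keeps the inner condition. The skip flag is only an illustration of a path-sensitive rewrite, not the structure the pass would actually emit.

```python
def reference_nested(flags):
    """Original loop: continue only when both a and b hold."""
    trace = []
    for a, b in flags:
        if a:
            trace.append("pre")      # statement before the inner escape
            if b:
                continue
        trace.append("post")
    return trace


def widened_lowering(flags):
    """Mis-lowering: the entire `a` branch is treated as escaping."""
    trace = []
    for a, b in flags:
        if a:
            pass                     # "pre" is dropped, "post" is skipped
        else:
            trace.append("post")
    return trace


def path_sensitive_lowering(flags):
    """Lowering that escapes only on the exact inner path."""
    trace = []
    for a, b in flags:
        skip_rest = False
        if a:
            trace.append("pre")
            if b:
                skip_rest = True
        if not skip_rest:
            trace.append("post")
    return trace
```

For inputs where `a` is true but `b` is false, the widened version loses both the statement before the inner escape and the rest of the iteration.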
    def test_continue_in_for():
        """Continue in ForStmt restructured to if/else with phi-node yield."""

        @pl.program
        class Before:
            @pl.function(type=pl.FunctionType.InCore, strict_ssa=True)
            def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]:
                for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)):
                    if i < 5:
                        continue
                    y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter)
                    x_iter = pl.yield_(y)  # noqa: PLW2901
                return x_iter

        After = passes.lower_break_continue()(Before)
        printed = After.as_python()
        assert not _has_bare_keyword(printed, "continue")
        # Should still be a ForStmt (no break)
        assert "pl.range(" in printed
        # Phi-node approach: IfStmt with yields feeding a trailing yield
        assert "pl.yield_(" in printed


    def test_break_in_for():
        """Break in ForStmt converts to WhileStmt with break flag."""

        @pl.program
        class Before:
            @pl.function(type=pl.FunctionType.InCore, strict_ssa=True)
            def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]:
                for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)):
                    if i > 5:
                        break
                    y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter)
                    x_iter = pl.yield_(y)  # noqa: PLW2901
                return x_iter

        After = passes.lower_break_continue()(Before)
        printed = After.as_python()
        assert not _has_bare_keyword(printed, "break")
        assert "pl.while_" in printed


    def test_break_and_continue_in_for():
        """ForStmt with both break and continue."""

        @pl.program
        class Before:
            @pl.function(type=pl.FunctionType.InCore, strict_ssa=True)
            def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]:
                for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)):
                    if i < 3:
                        continue
                    y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter)
                    if i > 7:
                        break
                    x_iter = pl.yield_(y)  # noqa: PLW2901
                return x_iter

        After = passes.lower_break_continue()(Before)
        printed = After.as_python()
        assert not _has_bare_keyword(printed, "break")
        assert not _has_bare_keyword(printed, "continue")


    def test_no_break_continue_noop():
        """Pass is identity when no break/continue."""

        @pl.program
        class Before:
            @pl.function(type=pl.FunctionType.InCore, strict_ssa=True)
            def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]:
                for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)):
                    y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter)
                    x_iter = pl.yield_(y)  # noqa: PLW2901
                return x_iter

        After = passes.lower_break_continue()(Before)
        ir.assert_structural_equal(After, Before)


    def test_orchestration_untouched():
        """Non-InCore functions are not transformed."""

        @pl.program
        class Before:
            @pl.function
            def main(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]:
| for i in pl.range(0, 10, 1): | ||
| if i > 5: | ||
| break | ||
| return x_0 | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| ir.assert_structural_equal(After, Before) | ||
|
|
||
|
|
||
| def test_continue_multiple_iter_args(): | ||
| """Continue with multiple iter_args yields current iter_arg values.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel( | ||
| self, | ||
| a_0: pl.Tensor[[64], pl.FP32], | ||
| b_0: pl.Tensor[[64], pl.FP32], | ||
| ) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (a_iter, b_iter) in pl.range(0, 10, 1, init_values=(a_0, b_0)): | ||
| if i < 5: | ||
| continue | ||
| a_new: pl.Tensor[[64], pl.FP32] = pl.add(a_iter, b_iter) | ||
| b_new: pl.Tensor[[64], pl.FP32] = pl.add(b_iter, a_iter) | ||
| a_iter, b_iter = pl.yield_(a_new, b_new) # noqa: PLW2901 | ||
| return a_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_continue_with_pre_continue_assignment(): | ||
| """Continue after assignments — backward resolution yields iter_arg value.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| if i < 5: | ||
| continue | ||
| z: pl.Tensor[[64], pl.FP32] = pl.add(y, y) | ||
| x_iter = pl.yield_(z) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_break_negative_step(): | ||
| """Break in for loop with negative step uses > condition.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(10, 0, -1, init_values=(x_0,)): | ||
| if i < 3: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
|
|
||
|
|
||
| def test_aic_function_type(): | ||
| """Pass processes AIC function type.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.AIC, strict_ssa=True) | ||
| def aic_kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i < 5: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_continue_no_iter_args(): | ||
| """Continue in loop with no carried state.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i in pl.range(0, 10, 1): | ||
| if i < 5: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_0, x_0) # noqa: F841 | ||
| return x_0 | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_break_no_iter_args(): | ||
| """Break in loop with no carried state.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i in pl.range(0, 10, 1): | ||
| if i > 5: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_0, x_0) # noqa: F841 | ||
| return x_0 | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
|
|
||
|
|
||
| def test_nested_loops_only_inner(): | ||
| """Only inner loop with continue is transformed, outer loop unchanged.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_outer,) in pl.range(0, 4, 1, init_values=(x_0,)): | ||
| for j, (x_inner,) in pl.range(0, 8, 1, init_values=(x_outer,)): | ||
| if j < 2: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_inner, x_inner) | ||
| x_inner = pl.yield_(y) # noqa: PLW2901 | ||
| x_outer = pl.yield_(x_inner) # noqa: PLW2901 | ||
| return x_outer | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
| # Outer loop should still be a ForStmt | ||
| assert "pl.range(4" in printed or "pl.range(0, 4" in printed | ||
|
|
||
|
|
||
| def test_multiple_continues_in_body(): | ||
| """Two separate if-continue blocks in the same loop body.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i < 2: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| if i > 8: | ||
| continue | ||
| z: pl.Tensor[[64], pl.FP32] = pl.add(y, y) | ||
| x_iter = pl.yield_(z) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_both_outer_and_inner_loop_have_break(): | ||
| """Outer and inner loop both have break — both converted to WhileStmt.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_outer,) in pl.range(0, 4, 1, init_values=(x_0,)): | ||
| for j, (x_inner,) in pl.range(0, 8, 1, init_values=(x_outer,)): | ||
| if j > 3: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_inner, x_inner) | ||
| x_inner = pl.yield_(y) # noqa: PLW2901 | ||
| if i > 2: | ||
| break | ||
| x_outer = pl.yield_(x_inner) # noqa: PLW2901 | ||
| return x_outer | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
|
|
||
|
|
||
| def test_multi_function_program(): | ||
| """Program with InCore and Orchestration — only InCore transformed.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def incore_kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i > 5: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| @pl.function(type=pl.FunctionType.Orchestration) | ||
| def main(self, x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| y: pl.Tensor[[64], pl.FP32] = self.incore_kernel(x) | ||
| return y | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| # InCore function should have break lowered | ||
| assert not _has_bare_keyword(printed, "break") | ||
|
|
||
|
|
||
| def test_pass_properties(): | ||
| """Verify pass has correct properties.""" | ||
| p = passes.lower_break_continue() | ||
| assert p.get_name() == "LowerBreakContinue" | ||
|
|
||
| required = p.get_required_properties() | ||
| assert required.contains(passes.IRProperty.SSAForm) | ||
| assert required.contains(passes.IRProperty.SplitIncoreOrch) | ||
|
|
||
| produced = p.get_produced_properties() | ||
| assert produced.contains(passes.IRProperty.SSAForm) | ||
| assert produced.contains(passes.IRProperty.SplitIncoreOrch) | ||
|
|
||
|
|
||
| def test_pass_in_pipeline(): | ||
| """Verify pass is registered in both Default and CCE strategies.""" | ||
| for strategy in [OptimizationStrategy.Default, OptimizationStrategy.CCE]: | ||
| pm = PassManager.get_strategy(strategy) | ||
| names = pm.get_pass_names() | ||
| assert "LowerBreakContinue" in names | ||
| # Must come immediately after ResolveTransposeLayout | ||
| rtl_idx = names.index("ResolveTransposeLayout") | ||
| lbc_idx = names.index("LowerBreakContinue") | ||
| assert lbc_idx == rtl_idx + 1 | ||
|
|
||
|
|
||
| def test_pipeline_integration(): | ||
| """Pass works in a partial compilation pipeline.""" | ||
|
|
||
| @pl.program | ||
| class Input: | ||
| @pl.function | ||
| def main(self, x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| with pl.incore(): | ||
| for i in pl.range(10): | ||
| if i < 5: | ||
| continue | ||
| x = pl.add(x, x) | ||
| return x | ||
|
|
||
| after_ssa = passes.convert_to_ssa()(Input) | ||
| after_outline = passes.outline_incore_scopes()(after_ssa) | ||
| after_lower = passes.lower_break_continue()(after_outline) | ||
|
|
||
| printed = after_lower.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
| assert not _has_bare_keyword(printed, "break") | ||
|
|
||
|
|
||
| # =========================================================================== | ||
| # Nested loops | ||
| # =========================================================================== | ||
|
|
||
|
|
||
| def test_nested_continue_outer_break_inner(): | ||
| """Continue in outer loop, break in inner loop.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_outer,) in pl.range(0, 4, 1, init_values=(x_0,)): | ||
| for j, (x_inner,) in pl.range(0, 8, 1, init_values=(x_outer,)): | ||
| if j > 3: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_inner, x_inner) | ||
| x_inner = pl.yield_(y) # noqa: PLW2901 | ||
| if i < 2: | ||
| continue | ||
| x_outer = pl.yield_(x_inner) # noqa: PLW2901 | ||
| return x_outer | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert not _has_bare_keyword(printed, "continue") | ||
| # Inner loop should be while (has break), outer stays for (only continue) | ||
| assert "pl.while_" in printed | ||
| assert "pl.range(" in printed | ||
|
|
||
|
|
||
| def test_nested_break_outer_continue_inner(): | ||
| """Break in outer loop, continue in inner loop.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_outer,) in pl.range(0, 4, 1, init_values=(x_0,)): | ||
| for j, (x_inner,) in pl.range(0, 8, 1, init_values=(x_outer,)): | ||
| if j < 2: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_inner, x_inner) | ||
| x_inner = pl.yield_(y) # noqa: PLW2901 | ||
| if i > 2: | ||
| break | ||
| x_outer = pl.yield_(x_inner) # noqa: PLW2901 | ||
| return x_outer | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_nested_continue_both_loops(): | ||
| """Continue in both inner and outer loops.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_outer,) in pl.range(0, 4, 1, init_values=(x_0,)): | ||
| if i < 1: | ||
| continue | ||
| for j, (x_inner,) in pl.range(0, 8, 1, init_values=(x_outer,)): | ||
| if j < 2: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_inner, x_inner) | ||
| x_inner = pl.yield_(y) # noqa: PLW2901 | ||
| x_outer = pl.yield_(x_inner) # noqa: PLW2901 | ||
| return x_outer | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
| # Both should still be ForStmts (no break) | ||
| assert "pl.while_" not in printed | ||
|
|
||
|
|
||
| def test_nested_break_and_continue_inner(): | ||
| """Inner loop has both break and continue, outer is clean.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_outer,) in pl.range(0, 4, 1, init_values=(x_0,)): | ||
| for j, (x_inner,) in pl.range(0, 8, 1, init_values=(x_outer,)): | ||
| if j < 2: | ||
| continue | ||
| if j > 5: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_inner, x_inner) | ||
| x_inner = pl.yield_(y) # noqa: PLW2901 | ||
| x_outer = pl.yield_(x_inner) # noqa: PLW2901 | ||
| return x_outer | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert not _has_bare_keyword(printed, "continue") | ||
| # Inner loop should become while (has break) | ||
| assert "pl.while_" in printed | ||
|
|
||
|
|
||
| def test_three_level_nesting_break_at_each(): | ||
| """Three levels of nested loops, break at each level.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_l1,) in pl.range(0, 3, 1, init_values=(x_0,)): | ||
| for j, (x_l2,) in pl.range(0, 4, 1, init_values=(x_l1,)): | ||
| for k, (x_l3,) in pl.range(0, 5, 1, init_values=(x_l2,)): | ||
| if k > 2: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_l3, x_l3) | ||
| x_l3 = pl.yield_(y) # noqa: PLW2901 | ||
| if j > 1: | ||
| break | ||
| x_l2 = pl.yield_(x_l3) # noqa: PLW2901 | ||
| if i > 0: | ||
| break | ||
| x_l1 = pl.yield_(x_l2) # noqa: PLW2901 | ||
| return x_l1 | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
|
|
||
|
|
||
| # =========================================================================== | ||
| # Nested branches | ||
| # =========================================================================== | ||
|
|
||
|
|
||
| def test_continue_in_nested_if(): | ||
| """Continue inside a nested if (if inside if).""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i < 8: | ||
| if i < 3: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_break_in_nested_if(): | ||
| """Break inside a nested if.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i > 3: | ||
| if i > 7: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert "pl.while_" in printed | ||
|
|
||
|
|
||
| def test_continue_in_else_branch(): | ||
| """Continue in else branch of IfStmt (not then branch).""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i > 5: | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| else: | ||
| continue | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_break_in_else_branch(): | ||
| """Break in else branch of IfStmt.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i < 7: | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| else: | ||
| break | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert "pl.while_" in printed | ||
|
|
||
|
|
||
| def test_if_else_continue_then_break_else(): | ||
| """Continue in then branch, break in else branch of same IfStmt.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| if i < 3: | ||
| continue | ||
| elif i > 7: | ||
| break | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_normal_if_else_before_continue(): | ||
| """If/else without break/continue, followed by a continue guard.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i < 5: | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| else: | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_0) | ||
| if i < 2: | ||
| continue | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| # =========================================================================== | ||
| # Unconditional break/continue | ||
| # =========================================================================== | ||
|
|
||
|
|
||
| def test_unconditional_break(): | ||
| """Bare break as first statement — loop executes 0 iterations effectively.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| break | ||
| x_iter = pl.yield_(x_iter) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
|
|
||
|
|
||
| def test_unconditional_continue(): | ||
| """Bare continue as first statement — all iterations are skipped.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| continue | ||
| x_iter = pl.yield_(x_iter) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| # =========================================================================== | ||
| # Multiple break/continue patterns | ||
| # =========================================================================== | ||
|
|
||
|
|
||
| def test_back_to_back_breaks(): | ||
| """Two separate if-break blocks in the same loop body.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i > 8: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| if i > 5: | ||
| break | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert "pl.while_" in printed | ||
|
|
||
|
|
||
| def test_continue_then_break(): | ||
| """Continue guard first, then break guard in same body.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i < 2: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| if i > 7: | ||
| break | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_break_then_continue(): | ||
| """Break guard first, then continue guard in same body.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i > 8: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| if i < 3: | ||
| continue | ||
| z: pl.Tensor[[64], pl.FP32] = pl.add(y, y) | ||
| x_iter = pl.yield_(z) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| # =========================================================================== | ||
| # Complex nested combinations | ||
| # =========================================================================== | ||
|
|
||
|
|
||
| def test_nested_loop_inner_continue_outer_break_and_continue(): | ||
| """Inner loop has continue, outer loop has both break and continue.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_outer,) in pl.range(0, 6, 1, init_values=(x_0,)): | ||
| if i < 1: | ||
| continue | ||
| for j, (x_inner,) in pl.range(0, 8, 1, init_values=(x_outer,)): | ||
| if j < 2: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_inner, x_inner) | ||
| x_inner = pl.yield_(y) # noqa: PLW2901 | ||
| if i > 4: | ||
| break | ||
| x_outer = pl.yield_(x_inner) # noqa: PLW2901 | ||
| return x_outer | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_nested_loop_both_have_break_and_continue(): | ||
| """Both inner and outer loops have break and continue.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_outer,) in pl.range(0, 4, 1, init_values=(x_0,)): | ||
| if i < 1: | ||
| continue | ||
| for j, (x_inner,) in pl.range(0, 8, 1, init_values=(x_outer,)): | ||
| if j < 2: | ||
| continue | ||
| if j > 5: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_inner, x_inner) | ||
| x_inner = pl.yield_(y) # noqa: PLW2901 | ||
| if i > 2: | ||
| break | ||
| x_outer = pl.yield_(x_inner) # noqa: PLW2901 | ||
| return x_outer | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_deeply_nested_if_with_continue(): | ||
| """Continue inside three levels of nested ifs.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i < 8: | ||
| if i < 5: | ||
| if i < 2: | ||
| continue | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| def test_deeply_nested_if_with_break(): | ||
| """Break inside three levels of nested ifs.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 10, 1, init_values=(x_0,)): | ||
| if i > 3: | ||
| if i > 5: | ||
| if i > 7: | ||
| break | ||
| y: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| x_iter = pl.yield_(y) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert "pl.while_" in printed | ||
|
|
||
|
|
||
| def test_multiple_iter_args_with_break(): | ||
| """Break with multiple iter_args — all are carried through WhileStmt.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel( | ||
| self, | ||
| a_0: pl.Tensor[[64], pl.FP32], | ||
| b_0: pl.Tensor[[64], pl.FP32], | ||
| ) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (a_iter, b_iter) in pl.range(0, 10, 1, init_values=(a_0, b_0)): | ||
| if i > 5: | ||
| break | ||
| a_new: pl.Tensor[[64], pl.FP32] = pl.add(a_iter, b_iter) | ||
| b_new: pl.Tensor[[64], pl.FP32] = pl.add(b_iter, a_iter) | ||
| a_iter, b_iter = pl.yield_(a_new, b_new) # noqa: PLW2901 | ||
| return a_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "break") | ||
| assert "pl.while_" in printed | ||
|
|
||
|
|
||
| def test_computation_between_continues(): | ||
| """Multiple continues with computation in between each guard.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore, strict_ssa=True) | ||
| def kernel(self, x_0: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: | ||
| for i, (x_iter,) in pl.range(0, 20, 1, init_values=(x_0,)): | ||
| a: pl.Tensor[[64], pl.FP32] = pl.add(x_iter, x_iter) | ||
| if i < 5: | ||
| continue | ||
| b: pl.Tensor[[64], pl.FP32] = pl.add(a, a) | ||
| if i < 10: | ||
| continue | ||
| c: pl.Tensor[[64], pl.FP32] = pl.add(b, b) | ||
| if i < 15: | ||
| continue | ||
| x_iter = pl.yield_(c) # noqa: PLW2901 | ||
| return x_iter | ||
|
|
||
| After = passes.lower_break_continue()(Before) | ||
| printed = After.as_python() | ||
| assert not _has_bare_keyword(printed, "continue") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__, "-v"]) |
Please add direct pl.while_ cases and stronger structural assertions.
This suite is still almost entirely pl.range(...) + as_python() text checks, so it leaves the dedicated LowerWhileWithContinue/LowerWhileWithBreak paths untested and it won’t catch semantic regressions in nested branches if the rendered text still looks plausible. At minimum, add source pl.while_ fixtures and a few assertions on the rewritten IR shape for nested if cases instead of only checking that bare keywords disappeared.
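To make "semantic regression" concrete, here is a minimal plain-Python sketch of the two rewrites the pass is described as performing: continue becomes an if/else whose branches feed a single trailing update, and break becomes a while loop guarded by a flag. It does not use the project's `pl` DSL; the function names `original` and `lowered` and the variable `brk_flag` (standing in for the `__brk_flag` iter_arg) are illustrative assumptions, not code from the PR.

```python
def original(xs, lo, hi):
    """Reference loop that uses break and continue directly."""
    acc = 0
    for i in range(len(xs)):
        if i < lo:
            continue
        if i > hi:
            break
        acc = acc + xs[i]
    return acc


def lowered(xs, lo, hi):
    """Same loop with break/continue rewritten into structured control flow."""
    acc = 0
    i = 0
    brk_flag = False  # hypothetical stand-in for the __brk_flag iter_arg
    while i < len(xs) and not brk_flag:
        if i < lo:
            # continue path: carry the current values forward unchanged
            next_acc, next_brk = acc, False
        else:
            if i > hi:
                # break path: keep the current value and raise the flag
                next_acc, next_brk = acc, True
            else:
                # normal path: carry the newly computed value
                next_acc, next_brk = acc + xs[i], False
        # single trailing update, analogous to one yield fed by phi results
        acc, brk_flag = next_acc, next_brk
        i += 1
    return acc
```

A test that compares the two functions over a range of `lo`/`hi` values checks behaviour rather than rendered text, which is the kind of guarantee a keyword search alone cannot give.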
🧰 Tools
🪛 Ruff (0.15.5)
[warning] 97-97: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
[warning] 247-247: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
[warning] 470-470: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
[warning] 659-659: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
[warning] 676-676: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/ut/ir/transforms/test_lower_break_continue_pass.py` around lines 25 -
911, Tests in test_lower_break_continue_pass.py only cover pl.range (ForStmt)
textual output and miss direct pl.while_ cases and structural IR assertions; add
new fixtures that use pl.while_ loops (e.g., create tests similar to
test_break_in_for/test_continue_in_for but defining kernels with pl.while_
instead of pl.range) and assert that passes.lower_break_continue() invokes the
LowerWhileWithContinue/LowerWhileWithBreak paths by checking IR shape via
ir.assert_structural_equal or inspecting node types (e.g., ensure the
transformed function contains WhileStmt nodes or specific pass-produced
properties), and strengthen a few nested-if tests (like
test_continue_in_nested_if/test_break_in_nested_if) to assert the IR structure
(not just absence of "continue"/"break") after passes.lower_break_continue() so
regressions in lowering are caught.
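As a rough illustration of the difference between a text check and a structural check, the sketch below uses Python's standard `ast` module on plain Python strings. It is not PyPTO code: `count_nodes`, `SAMPLE_BEFORE`, `SAMPLE_AFTER`, and `TRICKY` are hypothetical names, and it assumes the text under inspection is valid Python syntax. In the project itself, the `ir.assert_structural_equal` check mentioned in the comment would play this role.

```python
import ast


def count_nodes(source: str, node_type: type) -> int:
    """Count syntax-tree nodes of the given type in a piece of Python source."""
    tree = ast.parse(source)
    return sum(isinstance(node, node_type) for node in ast.walk(tree))


# Hypothetical "before" shape: a for loop that exits early with break.
SAMPLE_BEFORE = '''
def kernel(xs):
    acc = 0
    for x in xs:
        if x < 0:
            break
        acc = acc + x
    return acc
'''

# Hypothetical "after" shape: a while loop guarded by an explicit flag.
SAMPLE_AFTER = '''
def kernel(xs):
    acc = 0
    i = 0
    done = False
    while i < len(xs) and not done:
        if xs[i] < 0:
            done = True
        else:
            acc = acc + xs[i]
        i = i + 1
    return acc
'''

# A substring test is fooled by the word appearing in a string literal;
# a node count is not.
TRICKY = 'label = "break"\n'

if __name__ == "__main__":
    print(count_nodes(SAMPLE_BEFORE, ast.Break))
    print(count_nodes(SAMPLE_AFTER, ast.Break), count_nodes(SAMPLE_AFTER, ast.While))
    print("break" in TRICKY, count_nodes(TRICKY, ast.Break))
```

Counting node types is still a weak shape check. Comparing against an expected tree, as the comment suggests, pins down nesting as well.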
Closing as duplicate of #494.
Summary
- Adds a `LowerBreakContinue` pass that rewrites `BreakStmt`/`ContinueStmt` in InCore/AIC/AIV functions into equivalent structured control flow
- Continue: restructures the loop body into an `IfStmt` with phi-node `return_vars`. The continue path yields the current iter_arg values, the normal path yields the computed values, and a single trailing `YieldStmt` uses the phi results
- Break: converts `ForStmt` to `WhileStmt` with a `__brk_flag` IterArg. The break path sets the flag to `True`, and the while condition checks `not __brk_flag`
- Runs after `ResolveTransposeLayout` and before `InitMemRef` in both Default and CCE strategies
- Requires and preserves the `SSAForm` and `SplitIncoreOrch` properties

Files Changed
- `src/ir/transforms/lower_break_continue_pass.cpp` (new)
- `passes.h` (factory declaration), `pass_properties.h` (properties)
- `CMakeLists.txt`
- `passes.cpp`, `passes.pyi`, `pass_manager.py`
- `test_lower_break_continue_pass.py` (40 tests), `test_pass_manager.py` (updated)
- `11-lower_break_continue.md` (en + zh-cn), doc renumbering for subsequent passes

Test plan
Fixes #448
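
The continue and break rewrites listed in the Summary can be modeled in plain Python. The sketch below is only an analogy, not the pass's actual output: the function names and the `brk_flag`, `acc_phi`, and `acc_next` variables are hypothetical stand-ins for the `__brk_flag` IterArg and the phi-node `return_vars`, and how the real pass advances the loop index on the break path is not modeled.

```python
def with_continue(xs):
    """Original shape: skip even values with continue."""
    acc = 0
    for x in xs:
        if x % 2 == 0:
            continue
        acc = acc + x
    return acc


def lowered_continue(xs):
    """Lowered shape: both paths feed one phi value, then a single yield."""
    acc = 0
    for x in xs:
        if x % 2 == 0:
            acc_phi = acc      # continue path: pass the carried value through
        else:
            acc_phi = acc + x  # normal path: use the computed value
        acc = acc_phi          # single trailing "yield" of the phi result
    return acc


def with_break(xs, limit):
    """Original shape: stop accumulating once the limit would be exceeded."""
    acc = 0
    for x in xs:
        if acc + x > limit:
            break
        acc = acc + x
    return acc


def lowered_break(xs, limit):
    """Lowered shape: while loop whose condition also checks a carried flag."""
    acc = 0
    i = 0
    brk_flag = False  # stands in for the __brk_flag iter arg
    while i < len(xs) and not brk_flag:
        x = xs[i]
        if acc + x > limit:
            brk_flag, acc_next = True, acc       # break path: set the flag
        else:
            brk_flag, acc_next = False, acc + x  # normal path
        acc = acc_next
        i += 1
    return acc


if __name__ == "__main__":
    data = [1, 2, 3, 4, 5]
    assert with_continue(data) == lowered_continue(data)
    assert with_break(data, 5) == lowered_break(data, 5)
```

In both rewrites every path through the body ends at the same point, so the loop has a single exit edge and the carried values are merged in one place.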