Skip to content

[Fix][DSL] Fix while-loop lowering, nested control flow, and add fp32 GEMM output#232

Open
ChuanLi1101 wants to merge 3 commits intomainfrom
fix/while-loop-and-fp32-output
Open

[Fix][DSL] Fix while-loop lowering, nested control flow, and add fp32 GEMM output#232
ChuanLi1101 wants to merge 3 commits intomainfrom
fix/while-loop-and-fp32-output

Conversation

@ChuanLi1101
Copy link

Summary

  • Fix while-loop SCF lowering ([Issue]: While loop counter update issue #211, [Issue]: _ASTREWRITE_MARKER prevents later AST rewriter passes from visiting generated function bodies #210): The CanonicalizeWhile AST rewriter had three critical bugs: (1) the before region used original operands instead of block arguments, making the condition evaluate to a constant; (2) the after region had no scf.yield terminator; (3) carry variables were never rebound from block arguments or WhileOp results. This PR clones the condition op with remapped block arguments, adds carry-variable rebinding at body start, explicit scf_while_yield_ at body end, and post-loop rebinding from WhileOp results.

  • Fix nested control flow rewriting ([Issue]: _ASTREWRITE_MARKER prevents later AST rewriter passes from visiting generated function bodies #210): The _ASTREWRITE_MARKER check in Transformer.visit_FunctionDef prevented all transformers from visiting children of generated helper functions (__then_X / __else_X), so a while or for loop nested inside an if-branch was never lowered to SCF ops.

  • Add fp32 output to preshuffle GEMM: The direct epilog path now supports out_dtype="fp32" by skipping the arith.trunc_f (accumulator is already f32) and adjusting buffer resource byte calculations. The cshuffle epilog path rejects fp32 with a clear error.

Test plan

  • Verify test_ast_rewriter.py passes: AST-level checks for while-loop transformation, nested control flow, and fp32 parameter validation
  • Verify existing test_preshuffle_gemm.py tests still pass (no regression)
  • On GPU: test preshuffle GEMM with out_dtype="fp32" for correctness
  • On GPU: test a kernel using while loop with dynamic condition

Implement a single-query-token Flash Attention kernel targeting the decode phase of autoregressive LLM inference. Uses online softmax with warp-level xor-shuffle reductions on AMD wave64. Includes correctness and performance tests against PyTorch SDPA reference.

Made-with: Cursor
… GEMM output

- Fix CanonicalizeWhile: clone condition op in the before region with
  block arguments so the loop condition is re-evaluated each iteration
  instead of using stale values from outside the WhileOp (#211)

- Fix CanonicalizeWhile: add carry-variable rebinding at body start
  (from after-region block args) and explicit scf.yield at body end
  so the after region has a proper terminator with updated values (#211)

- Fix CanonicalizeWhile: rebind carry variables from WhileOp results
  after the loop so subsequent code sees the final loop-carried values

- Fix _ASTREWRITE_MARKER: always visit children of generated helper
  functions so nested control flow (while inside if, for inside if)
  is still lowered by subsequent AST transformers (#210)

- Add fp32 output support to preshuffle GEMM (direct epilog path):
  skip arith.trunc_f when out_dtype="fp32" and adjust buffer resource
  byte calculations for 4-byte elements

- Add AST-level unit tests for the while-loop transformation and
  fp32 output parameter validation

Made-with: Cursor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant