From ba04383786d8b536a974c34e30a49a664f41fd30 Mon Sep 17 00:00:00 2001 From: "Chuan(Richard) Li" Date: Thu, 26 Feb 2026 16:53:26 -0800 Subject: [PATCH 1/2] Add decode-phase Flash Attention kernel for LLM inference Implement a single-query-token Flash Attention kernel targeting the decode phase of autoregressive LLM inference. Uses online softmax with warp-level xor-shuffle reductions on AMD wave64. Includes correctness and performance tests against PyTorch SDPA reference. Made-with: Cursor --- kernels/flash_decode_attention.py | 284 +++++++++++++++++++ tests/kernels/test_flash_decode_attention.py | 173 +++++++++++ 2 files changed, 457 insertions(+) create mode 100644 kernels/flash_decode_attention.py create mode 100644 tests/kernels/test_flash_decode_attention.py diff --git a/kernels/flash_decode_attention.py b/kernels/flash_decode_attention.py new file mode 100644 index 00000000..8f9cc283 --- /dev/null +++ b/kernels/flash_decode_attention.py @@ -0,0 +1,284 @@ +"""Flash Decode Attention kernel builder. + +Single-query (decode-phase) attention using online softmax: + O[h,d] = sum_j( softmax(Q[h,:] . K[h,j,:] / sqrt(d_k))_j * V[h,j,d] ) + +Architecture: + Grid: (total_heads, 1, 1) -- one wavefront per (batch, head) + Block: (WARP_SIZE, 1, 1) -- AMD wave64, barrier-free dot product reduction + +Each thread owns ELEMS_PER_THREAD = head_dim / WARP_SIZE output elements. +Dot products Q.K[j] use intra-warp xor-shuffle sum reduction so all lanes +see the same score without shared-memory barriers. +Online softmax avoids materializing the full attention score matrix. + +Memory layout (row-major, batch and heads flattened into dim-0): + Q: [total_heads, head_dim] + K: [total_heads, seq_len, head_dim] + V: [total_heads, seq_len, head_dim] + O: [total_heads, head_dim] + +where total_heads = batch_size * num_heads. +""" + +from _mlir import ir + +from flydsl.dialects.ext import flir, arith, gpu +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_decode_attention" + +WARP_SIZE = 64 + + +def dtype_to_elem_type(dtype_str: str): + if dtype_str == "f32": + return T.f32() + if dtype_str == "f16": + return T.f16() + if dtype_str == "bf16": + return T.bf16() + raise ValueError(f"unsupported dtype: {dtype_str}") + + +def build_flash_decode_attention_module( + seq_len: int, + head_dim: int, + dtype_str: str = "f16", +): + """Build MLIR module for decode-phase flash attention. + + Args: + seq_len: KV cache sequence length (compile-time constant). + head_dim: per-head dimension, must be divisible by WARP_SIZE (64). + dtype_str: element type for Q/K/V/O ("f32", "f16", or "bf16"). + + Returns: + An MlirModule instance whose ``__call__`` launches the kernel. + """ + if head_dim % WARP_SIZE != 0: + raise ValueError( + f"head_dim ({head_dim}) must be divisible by WARP_SIZE ({WARP_SIZE})" + ) + + arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + BLOCK_THREADS = WARP_SIZE + ELEMS_PER_THREAD = head_dim // WARP_SIZE + + _state = {} + + class _FlashDecodeAttn(flir.MlirModule): + GPU_MODULE_NAME = "flash_decode_attn" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + _state["elem_type"] = dtype_to_elem_type(dtype_str) + _state["compute_type"] = T.f32() + + @flir.kernel + def flash_decode_attention_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + K: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + V: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + O: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + total_heads: lambda: T.index(), + ): + elem_type = _state["elem_type"] + compute_type = _state["compute_type"] + fm_fast = flir.arith.FastMathFlags.fast + + h = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + rsqrt_d = arith.constant(1.0 / (head_dim ** 0.5), type=compute_type) + + # Thread t owns output elements [d_base .. d_base + ELEMS_PER_THREAD). + c_ept = flir.const_index(ELEMS_PER_THREAD) + d_base = flir.arith.MulIOp( + arith.as_value(tid), arith.as_value(c_ept) + ).result + + # Pre-compute per-element head-dim indices. + d_indices = [] + for e in range_constexpr(ELEMS_PER_THREAD): + d_off = flir.const_index(e) + d_indices.append( + flir.arith.AddIOp( + arith.as_value(d_base), arith.as_value(d_off) + ).result + ) + + # Load this thread's Q elements into registers. + q_local = [] + for e in range_constexpr(ELEMS_PER_THREAD): + q_e = flir.memref.load(Q, [arith.as_value(h), d_indices[e]]) + q_f = ( + q_e + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(q_e)) + ) + q_local.append(q_f) + + # ---- online softmax state ---- + m = c_neg_inf # running max + l = c_zero_f # running denominator (sum of exp) + acc = [c_zero_f] * ELEMS_PER_THREAD # weighted V accumulator + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + + # ---- main loop over KV-cache positions (compile-time unrolled) ---- + for j_py in range_constexpr(seq_len): + j = flir.const_index(j_py) + + # Partial dot product: Q[d_base:d_base+EPT] . K[h, j, d_base:d_base+EPT] + partial = c_zero_f + for e in range_constexpr(ELEMS_PER_THREAD): + k_e = flir.memref.load( + K, [arith.as_value(h), arith.as_value(j), d_indices[e]] + ) + k_f = ( + k_e + if dtype_str == "f32" + else flir.arith.extf( + compute_type, arith.as_value(k_e) + ) + ) + qk = flir.arith.MulFOp( + arith.as_value(q_local[e]), + arith.as_value(k_f), + fastmath=fm_fast, + ).result + partial = flir.arith.AddFOp( + arith.as_value(partial), qk, fastmath=fm_fast + ).result + + # Warp-wide sum reduction (xor-shuffle, wave64). + w = arith.as_value(partial) + for sh in [32, 16, 8, 4, 2, 1]: + off = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp( + arith.as_value(w), off, width_i32, mode="xor" + ).shuffleResult + ) + w = flir.arith.AddFOp( + arith.as_value(w), peer, fastmath=fm_fast + ).result + + # score = dot(Q, K_j) / sqrt(head_dim) + score = flir.arith.MulFOp( + arith.as_value(w), + arith.as_value(rsqrt_d), + fastmath=fm_fast, + ).result + + # Online softmax update: + # m_new = max(m, score) + # correction = exp2((m_old - m_new) * log2e) + # p = exp2((score - m_new) * log2e) + # l = l * correction + p + # acc[e] = acc[e] * correction + p * V[h, j, e] + m_new = flir.arith.MaximumFOp( + arith.as_value(m), arith.as_value(score) + ).result + + diff_m = flir.arith.SubFOp( + arith.as_value(m), m_new, fastmath=fm_fast + ).result + corr_arg = flir.arith.MulFOp( + diff_m, arith.as_value(c_log2e), fastmath=fm_fast + ).result + correction = flir.math.exp2(corr_arg, fastmath=fm_fast) + + diff_s = flir.arith.SubFOp( + arith.as_value(score), m_new, fastmath=fm_fast + ).result + p_arg = flir.arith.MulFOp( + diff_s, arith.as_value(c_log2e), fastmath=fm_fast + ).result + p = flir.math.exp2(p_arg, fastmath=fm_fast) + + l_corr = flir.arith.MulFOp( + arith.as_value(l), + arith.as_value(correction), + fastmath=fm_fast, + ).result + l = flir.arith.AddFOp( + l_corr, arith.as_value(p), fastmath=fm_fast + ).result + + # Update accumulator with weighted V. + new_acc = [] + for e in range_constexpr(ELEMS_PER_THREAD): + v_e = flir.memref.load( + V, [arith.as_value(h), arith.as_value(j), d_indices[e]] + ) + v_f = ( + v_e + if dtype_str == "f32" + else flir.arith.extf( + compute_type, arith.as_value(v_e) + ) + ) + a_corr = flir.arith.MulFOp( + arith.as_value(acc[e]), + arith.as_value(correction), + fastmath=fm_fast, + ).result + pv = flir.arith.MulFOp( + arith.as_value(p), + arith.as_value(v_f), + fastmath=fm_fast, + ).result + new_acc.append( + flir.arith.AddFOp(a_corr, pv, fastmath=fm_fast).result + ) + + acc = new_acc + m = m_new + + # ---- store output: O[h, d] = acc[d] / l ---- + for e in range_constexpr(ELEMS_PER_THREAD): + out_f32 = flir.arith.DivFOp( + arith.as_value(acc[e]), + arith.as_value(l), + fastmath=fm_fast, + ).result + if dtype_str != "f32": + out_e = flir.arith.truncf(elem_type, out_f32) + else: + out_e = out_f32 + flir.memref.store( + arith.as_value(out_e), + O, + [arith.as_value(h), d_indices[e]], + ) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + K: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + V: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + O: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + total_heads: lambda: T.index(), + ): + c1 = flir.arith_ext.index(1) + gx = total_heads + bx = flir.arith_ext.index(BLOCK_THREADS) + flir.gpu_ext.LaunchFuncOp( + ["flash_decode_attn", "flash_decode_attention_kernel"], + grid_size=(gx, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, total_heads], + ) + + return _FlashDecodeAttn() diff --git a/tests/kernels/test_flash_decode_attention.py b/tests/kernels/test_flash_decode_attention.py new file mode 100644 index 00000000..08a04e19 --- /dev/null +++ b/tests/kernels/test_flash_decode_attention.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +""" +Flash Decode Attention Test + +Verifies the single-query (decode-phase) attention kernel: + O = softmax(Q @ K^T / sqrt(d)) @ V + +where Q has a single token per (batch, head). + +Grid: (batch_size * num_heads, 1, 1) +Block: (64, 1, 1) -- single AMD wave64 +""" + +import sys +import os +from pathlib import Path + +_repo = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_repo)) + +import pytest + +try: + import torch + import torch.nn.functional as F +except ImportError: + torch = None +if torch is None or not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + +from tests.test_common import run_perftest + +import flydsl +from kernels.flash_decode_attention import ( + build_flash_decode_attention_module, + KERNEL_NAME, +) + +WARMUP_ITERS = 5 +BENCH_ITERS = 20 + +DTYPE_MAP = { + "f32": torch.float32, + "f16": torch.float16, + "bf16": torch.bfloat16, +} + +ATOL_MAP = { + "f32": 1e-4, + "f16": 2e-2, + "bf16": 3e-2, +} + + +def run_test( + batch_size: int, + num_heads: int, + seq_len: int, + head_dim: int, + dtype_str: str = "f16", +): + total_heads = batch_size * num_heads + torch_dtype = DTYPE_MAP[dtype_str] + atol = ATOL_MAP[dtype_str] + + print( + f"\nTesting Flash Decode Attention " + f"(B={batch_size}, H={num_heads}, S={seq_len}, D={head_dim}, dtype={dtype_str})" + ) + + try: + m = build_flash_decode_attention_module(seq_len, head_dim, dtype_str) + exe = flydsl.compile(m) + except Exception as e: + print(f"[FAIL] Compile failed: {e}") + import traceback + traceback.print_exc() + return False + + torch.manual_seed(42) + + Q_ref = torch.randn(batch_size, num_heads, 1, head_dim, device="cuda", dtype=torch.float32) + K_ref = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32) + V_ref = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32) + + # PyTorch reference (always in fp32 for stability). + expected = F.scaled_dot_product_attention( + Q_ref, K_ref, V_ref, is_causal=False + ) # [B, H, 1, D] + expected = expected.squeeze(2).reshape(total_heads, head_dim).to(torch.float32) + + # Prepare device tensors in target dtype, flattened to [total_heads, ...]. + Q_dev = Q_ref.squeeze(2).reshape(total_heads, head_dim).to(torch_dtype).contiguous() + K_dev = K_ref.reshape(total_heads, seq_len, head_dim).to(torch_dtype).contiguous() + V_dev = V_ref.reshape(total_heads, seq_len, head_dim).to(torch_dtype).contiguous() + O_dev = torch.empty(total_heads, head_dim, device="cuda", dtype=torch_dtype) + + print(" Launching kernel...") + + def kernel_launch(): + exe(Q_dev, K_dev, V_dev, O_dev, total_heads) + + kernel_launch() + torch.cuda.synchronize() + + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), + num_iters=BENCH_ITERS, + num_warmup=WARMUP_ITERS, + ) + torch.cuda.synchronize() + avg_ms = avg_us / 1000.0 + + elem_bytes = 4 if dtype_str == "f32" else 2 + kv_bytes = 2 * total_heads * seq_len * head_dim * elem_bytes + bandwidth_gbs = kv_bytes / (avg_us / 1e6) / 1e9 + print(f" Kernel avg time: {avg_ms:.4f} ms (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") + print(f" Bandwidth (KV read): {bandwidth_gbs:.2f} GB/s") + + output_f32 = O_dev.to(torch.float32) + error = (output_f32 - expected).abs().max().item() + print(f" Max absolute error: {error:.2e} (atol={atol})") + + if error < atol: + print(" PASSED") + return True + else: + print(" FAILED") + print(" Expected (first 8):", expected[0, :8]) + print(" Got (first 8):", output_f32[0, :8]) + return False + + +def test_flash_decode_attention(): + """Pytest entry point -- small configs for CI.""" + configs = [ + # (batch, heads, seq_len, head_dim, dtype) + (1, 1, 64, 128, "f32"), + (2, 4, 128, 128, "f16"), + (1, 2, 64, 128, "bf16"), + ] + + shapes_env = os.environ.get("FLYDSL_FLASH_ATTN_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + b, h, s, d, dt = [x.strip() for x in p.split(",")] + configs.append((int(b), int(h), int(s), int(d), dt)) + + print("=" * 80) + print("Running Flash Decode Attention Tests") + print("=" * 80) + + failures = 0 + for batch, heads, seq_len, head_dim, dtype in configs: + if not run_test(batch, heads, seq_len, head_dim, dtype): + failures += 1 + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + + assert failures == 0, f"{failures} test(s) failed" + + +if __name__ == "__main__": + test_flash_decode_attention() From a4832bb46d5b23d647e94ab56c35e01d44e011ad Mon Sep 17 00:00:00 2001 From: "Chuan(Richard) Li" Date: Tue, 17 Mar 2026 03:02:16 -0700 Subject: [PATCH 2/2] [Fix][DSL] Fix while-loop lowering, nested control flow, and add fp32 GEMM output - Fix CanonicalizeWhile: clone condition op in the before region with block arguments so the loop condition is re-evaluated each iteration instead of using stale values from outside the WhileOp (#211) - Fix CanonicalizeWhile: add carry-variable rebinding at body start (from after-region block args) and explicit scf.yield at body end so the after region has a proper terminator with updated values (#211) - Fix CanonicalizeWhile: rebind carry variables from WhileOp results after the loop so subsequent code sees the final loop-carried values - Fix _ASTREWRITE_MARKER: always visit children of generated helper functions so nested control flow (while inside if, for inside if) is still lowered by subsequent AST transformers (#210) - Add fp32 output support to preshuffle GEMM (direct epilog path): skip arith.trunc_f when out_dtype="fp32" and adjust buffer resource byte calculations for 4-byte elements - Add AST-level unit tests for the while-loop transformation and fp32 output parameter validation Made-with: Cursor --- kernels/preshuffle_gemm.py | 22 ++- python/flydsl/compiler/ast_rewriter.py | 196 +++++++++++++++++++++---- tests/pyir/test_ast_rewriter.py | 170 +++++++++++++++++++++ 3 files changed, 357 insertions(+), 31 deletions(-) create mode 100644 tests/pyir/test_ast_rewriter.py diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index b889bcca..00c615ee 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -66,11 +66,17 @@ def compile_preshuffle_gemm_a8( "in_dtype must be one of ('fp8','int8','int4','fp16','bf16','fp4'), " f"got {in_dtype!r}" ) - if out_dtype not in ("fp16", "bf16"): + if out_dtype not in ("fp16", "bf16", "fp32"): raise ValueError( - f"out_dtype must be 'fp16' or 'bf16', got {out_dtype!r}" + f"out_dtype must be 'fp16', 'bf16', or 'fp32', got {out_dtype!r}" + ) + if out_dtype == "fp32" and use_cshuffle_epilog: + raise ValueError( + "fp32 output is only supported with the direct epilog " + "(use_cshuffle_epilog=False)" ) _out_is_bf16 = out_dtype == "bf16" + _out_is_fp32 = out_dtype == "fp32" is_fp4 = in_dtype == "fp4" is_int4 = in_dtype == "int4" is_int8 = (in_dtype == "int8") or is_int4 @@ -169,7 +175,11 @@ def _mfma_pack_ty(): return T.i16x4 return T.i64 + _out_elem_bytes = 4 if _out_is_fp32 else 2 + def _out_elem(): + if _out_is_fp32: + return T.f32 return T.bf16 if _out_is_bf16 else T.f16 epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" @@ -177,7 +187,7 @@ def _out_elem(): # ── LDS sizing (pure Python, no MLIR ops) ──────────────────────────────── lds_tile_bytes = int(tile_m) * int(lds_stride_bytes) // a_elem_vec_pack - lds_out_bytes = 2 * int(tile_m) * int(tile_n) if use_cshuffle_epilog else 0 + lds_out_bytes = _out_elem_bytes * int(tile_m) * int(tile_n) if use_cshuffle_epilog else 0 if int(lds_stage) == 2: assert lds_out_bytes % 2 == 0, "lds_out_bytes should be multiple of 2" @@ -278,7 +288,7 @@ def kernel_gemm( # ---- Buffer resources (runtime byte sizes for OOB protection) ---- _a_nrec = arith.index_cast(T.i64, c_m * (K * elem_bytes // a_elem_vec_pack)) - _c_nrec = arith.index_cast(T.i64, c_m * c_n * 2) + _c_nrec = arith.index_cast(T.i64, c_m * c_n * _out_elem_bytes) a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=False, num_records_bytes=_a_nrec) c_rsrc = buffer_ops.create_buffer_resource(arg_c, max_size=False, @@ -862,9 +872,9 @@ def body_row(*, mi, ii, row_in_tile, row): val_s = (val * s_a) * s_b_vals[ni] else: val_s = val - val_f16 = arith.trunc_f(_out_elem(), val_s) + val_out = val_s if _out_is_fp32 else arith.trunc_f(_out_elem(), val_s) idx_out = idx_base + (ni * 16) - buffer_ops.buffer_store(val_f16, c_rsrc, idx_out) + buffer_ops.buffer_store(val_out, c_rsrc, idx_out) mfma_epilog( use_cshuffle=False, arith=arith, range_constexpr=range_constexpr, diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index 6c0e1a91..9ed35766 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -129,8 +129,9 @@ def __init__(self, context, first_lineno): self.first_lineno = first_lineno def visit_FunctionDef(self, node: ast.FunctionDef): - if getattr(node, _ASTREWRITE_MARKER, False): - return node + # Always visit children so nested control flow inside generated + # helpers (e.g. __then_X / __else_X from scf_if_dispatch) is still + # lowered by subsequent transformers. node = self.generic_visit(node) return node @@ -454,6 +455,8 @@ def visit_Yield(self, node: ast.Yield) -> ast.Expr: @ASTRewriter.register class CanonicalizeWhile(Transformer): + _last_while_op = None + @staticmethod def scf_while_init(cond, *, loc=None, ip=None): if loc is None: @@ -464,15 +467,36 @@ def wrapper(): inits = list(cond.owner.operands) result_types = [i.type for i in inits] while_op = scf.WhileOp(result_types, inits, loc=loc, ip=ip) - while_op.regions[0].blocks.append(*[i.type for i in inits]) + CanonicalizeWhile._last_while_op = while_op + + # before region: clone the condition op with block arguments so + # the condition is re-evaluated with updated loop-carried values. + while_op.regions[0].blocks.append(*result_types) before = while_op.regions[0].blocks[0] - while_op.regions[1].blocks.append(*[i.type for i in inits]) + + with ir.InsertionPoint(before): + operand_map = {id(v): ba for v, ba in zip(inits, before.arguments)} + cond_def = cond.owner + new_operands = [operand_map.get(id(o), o) for o in cond_def.operands] + cond_rtypes = [r.type for r in cond_def.results] + attrs = {n: cond_def.attributes[n] for n in cond_def.attributes} + new_cond_op = ir.Operation.create( + cond_def.name, + results=cond_rtypes, + operands=new_operands, + attributes=attrs, + ) + scf.ConditionOp(new_cond_op.results[0], list(before.arguments)) + + # after region: the body will be traced here. + while_op.regions[1].blocks.append(*result_types) after = while_op.regions[1].blocks[0] - with ir.InsertionPoint(before) as ip: - cond_op = scf.ConditionOp(cond, list(before.arguments)) - cond.owner.move_before(cond_op) + with ir.InsertionPoint(after): - yield inits + # Yield block arguments so the caller can rebind Python vars. + yield list(after.arguments) + # Body has completed; scf_while_yield_ should have emitted + # the scf.YieldOp terminator already. if hasattr(CanonicalizeWhile.scf_while_init, "wrapper"): next(CanonicalizeWhile.scf_while_init.wrapper, False) @@ -487,47 +511,169 @@ def scf_while_gen(cond, *, loc=None, ip=None): yield CanonicalizeWhile.scf_while_init(cond, loc=loc, ip=ip) yield CanonicalizeWhile.scf_while_init(cond, loc=loc, ip=ip) + @staticmethod + def scf_while_yield_(yield_vals): + """Emit scf.YieldOp for the while-loop after region.""" + processed = [] + for v in yield_vals: + if isinstance(v, ir.Value): + processed.append(v) + elif hasattr(v, "ir_value"): + processed.append(v.ir_value()) + else: + processed.append(v) + scf.YieldOp(processed) + + @staticmethod + def scf_while_get_results_(): + """Return the results of the most recent WhileOp.""" + if CanonicalizeWhile._last_while_op is not None: + return list(CanonicalizeWhile._last_while_op.results) + return [] + @classmethod def rewrite_globals(cls): return { "scf_while_gen": cls.scf_while_gen, "scf_while_init": cls.scf_while_init, + "scf_while_yield_": cls.scf_while_yield_, + "scf_while_get_results_": cls.scf_while_get_results_, } + @staticmethod + def _extract_carry_names(test_node): + """Map {operand_index: variable_name} for Name nodes in Compare.""" + if isinstance(test_node, ast.NamedExpr): + return CanonicalizeWhile._extract_carry_names(test_node.value) + carry = {} + if isinstance(test_node, ast.Compare): + if isinstance(test_node.left, ast.Name): + carry[0] = test_node.left.id + for i, comp in enumerate(test_node.comparators): + if isinstance(comp, ast.Name): + carry[i + 1] = comp.id + elif isinstance(test_node, ast.Name): + carry[0] = test_node.id + return carry + def visit_While(self, node: ast.While) -> List[ast.AST]: if _is_constexpr(node.test): node.test = _unwrap_constexpr(node.test) node = self.generic_visit(node) return node node = self.generic_visit(node) + + carry_names = self._extract_carry_names(node.test) + if isinstance(node.test, ast.NamedExpr): test = node.test.value else: test = node.test - w = ast.Call(func=ast.Name("scf_while_gen", ctx=ast.Load()), args=[test], keywords=[]) - w = ast.copy_location(w, node) - assign = ast.Assign( - targets=[ast.Name(f"w_{node.lineno}", ctx=ast.Store())], - value=w, + + uid = node.lineno + init_name = f"__init__{uid}" + gen_name = f"__wgen_{uid}" + yield_name = f"__wyield_{uid}" + results_name = f"__wres_{uid}" + + # --- w = scf_while_gen(cond) --- + gen_call = ast.Call( + func=ast.Name("scf_while_gen", ctx=ast.Load()), + args=[test], keywords=[], + ) + gen_assign = ast.Assign( + targets=[ast.Name(gen_name, ctx=ast.Store())], + value=gen_call, ) - assign = ast.fix_missing_locations(ast.copy_location(assign, node)) + gen_assign = ast.fix_missing_locations(ast.copy_location(gen_assign, node)) - next_ = ast.Call( + # --- while (__init__ := next(w, False)): --- + next_call = ast.Call( func=ast.Name("next", ctx=ast.Load()), args=[ - ast.Name(f"w_{node.lineno}", ctx=ast.Load()), + ast.Name(gen_name, ctx=ast.Load()), ast.Constant(False, kind="bool"), ], keywords=[], ) - next_ = ast.fix_missing_locations(ast.copy_location(next_, node)) - if isinstance(node.test, ast.NamedExpr): - node.test.value = next_ - else: - new_test = ast.NamedExpr(target=ast.Name(f"__init__{node.lineno}", ctx=ast.Store()), value=next_) - new_test = ast.copy_location(new_test, node) - node.test = new_test + next_call = ast.fix_missing_locations(ast.copy_location(next_call, node)) + walrus = ast.NamedExpr( + target=ast.Name(init_name, ctx=ast.Store()), + value=next_call, + ) + walrus = ast.copy_location(walrus, node) + node.test = walrus + + # --- Body prologue: rebind carry vars from after-region block args --- + prologue = [] + for idx, name in carry_names.items(): + stmt = ast.Assign( + targets=[ast.Name(name, ctx=ast.Store())], + value=ast.Subscript( + value=ast.Name(init_name, ctx=ast.Load()), + slice=ast.Constant(idx), + ctx=ast.Load(), + ), + ) + prologue.append(ast.fix_missing_locations(ast.copy_location(stmt, node))) + + # --- Body epilogue: yield updated carry vars back --- + epilogue = [] + + # __wyield__ = list(__init__) + epilogue.append(ast.fix_missing_locations(ast.copy_location( + ast.Assign( + targets=[ast.Name(yield_name, ctx=ast.Store())], + value=ast.Call( + func=ast.Name("list", ctx=ast.Load()), + args=[ast.Name(init_name, ctx=ast.Load())], + keywords=[], + ), + ), node))) + + # __wyield__[i] = var (override with updated values) + for idx, name in carry_names.items(): + epilogue.append(ast.fix_missing_locations(ast.copy_location( + ast.Assign( + targets=[ast.Subscript( + value=ast.Name(yield_name, ctx=ast.Load()), + slice=ast.Constant(idx), + ctx=ast.Store(), + )], + value=ast.Name(name, ctx=ast.Load()), + ), node))) + + # scf_while_yield_(__wyield__) + epilogue.append(ast.fix_missing_locations(ast.copy_location( + ast.Expr(value=ast.Call( + func=ast.Name("scf_while_yield_", ctx=ast.Load()), + args=[ast.Name(yield_name, ctx=ast.Load())], + keywords=[], + )), node))) + + node.body = prologue + node.body + epilogue + + # --- Post-loop: rebind carry vars from WhileOp results --- + post = [] + post.append(ast.fix_missing_locations(ast.copy_location( + ast.Assign( + targets=[ast.Name(results_name, ctx=ast.Store())], + value=ast.Call( + func=ast.Name("scf_while_get_results_", ctx=ast.Load()), + args=[], keywords=[], + ), + ), node))) + + for idx, name in carry_names.items(): + post.append(ast.fix_missing_locations(ast.copy_location( + ast.Assign( + targets=[ast.Name(name, ctx=ast.Store())], + value=ast.Subscript( + value=ast.Name(results_name, ctx=ast.Load()), + slice=ast.Constant(idx), + ctx=ast.Load(), + ), + ), node))) node = ast.fix_missing_locations(node) - assign = ast.fix_missing_locations(assign) - return [assign, node] + return [gen_assign, node] + post diff --git a/tests/pyir/test_ast_rewriter.py b/tests/pyir/test_ast_rewriter.py new file mode 100644 index 00000000..63238f3b --- /dev/null +++ b/tests/pyir/test_ast_rewriter.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +"""Tests for the AST rewriter transformations. + +Focuses on the AST-level transformation output without requiring a full MLIR +context or GPU hardware. +""" + +import ast +import types +from textwrap import dedent + +import pytest + + +def _parse_func(src: str) -> ast.FunctionDef: + """Parse a single function from source and return its AST node.""" + module = ast.parse(dedent(src)) + assert isinstance(module.body[0], ast.FunctionDef) + return module.body[0] + + +def _apply_transformer(func_node, transformer_cls, first_lineno=0): + """Run one NodeTransformer on a function AST and return the result.""" + ctx = types.SimpleNamespace() + rewriter = transformer_cls(context=ctx, first_lineno=first_lineno) + return rewriter.generic_visit(func_node) + + +class TestCanonicalizeWhileAST: + """Verify the CanonicalizeWhile transformer produces the expected AST + pattern without executing anything at the MLIR level.""" + + @staticmethod + def _get_transformer(): + try: + from flydsl.compiler.ast_rewriter import CanonicalizeWhile + return CanonicalizeWhile + except ImportError: + pytest.skip("flydsl not importable") + + def test_while_loop_generates_rebind_prologue(self): + """Body should start with carry-variable rebinding from block args.""" + CW = self._get_transformer() + + src = """\ + def f(): + while x < 10: + x = x + 1 + """ + func_node = _parse_func(src) + result = _apply_transformer(func_node, CW) + code = ast.unparse(result) + + assert "scf_while_gen" in code, "should generate scf_while_gen call" + assert "__init__" in code, "should have walrus binding for block args" + assert "scf_while_yield_" in code, "should yield updated carry vars" + assert "scf_while_get_results_" in code, "should rebind from WhileOp results" + + def test_while_loop_generates_yield_epilogue(self): + """Body should end with an explicit yield of updated carry vars.""" + CW = self._get_transformer() + + src = """\ + def f(): + while i < N: + i = i + 1 + """ + func_node = _parse_func(src) + result = _apply_transformer(func_node, CW) + code = ast.unparse(result) + + assert "scf_while_yield_" in code + # Both carry vars (i and N) should appear in post-loop rebinding + assert "scf_while_get_results_" in code + + def test_constexpr_while_unchanged(self): + """while const_expr(cond) should stay as a plain Python while.""" + CW = self._get_transformer() + + src = """\ + def f(): + while const_expr(x < 10): + x = x + 1 + """ + func_node = _parse_func(src) + result = _apply_transformer(func_node, CW) + code = ast.unparse(result) + + assert "scf_while_gen" not in code, "const_expr while should NOT be lowered" + + +class TestNestedControlFlowAST: + """Verify the _ASTREWRITE_MARKER fix allows nested control flow.""" + + @staticmethod + def _get_if_transformer(): + try: + from flydsl.compiler.ast_rewriter import ReplaceIfWithDispatch + return ReplaceIfWithDispatch + except ImportError: + pytest.skip("flydsl not importable") + + @staticmethod + def _get_while_transformer(): + try: + from flydsl.compiler.ast_rewriter import CanonicalizeWhile + return CanonicalizeWhile + except ImportError: + pytest.skip("flydsl not importable") + + def test_while_inside_if_branch_is_lowered(self): + """A while loop nested inside a then-branch should be transformed + by CanonicalizeWhile even though the branch is a generated function.""" + IfT = self._get_if_transformer() + WhileT = self._get_while_transformer() + + src = """\ + def f(): + if some_cond(): + while x < 10: + x = x + 1 + """ + func_node = _parse_func(src) + # First pass: if → scf_if_dispatch + result = _apply_transformer(func_node, IfT) + # Second pass: while → scf_while_gen + result = _apply_transformer(result, WhileT) + code = ast.unparse(result) + + assert "scf_while_gen" in code, \ + "while loop inside if-branch should be lowered after marker fix" + + +class TestFP32OutputValidation: + """Test that compile_preshuffle_gemm_a8 accepts fp32 out_dtype.""" + + @staticmethod + def _get_compile_fn(): + try: + from kernels.preshuffle_gemm import compile_preshuffle_gemm_a8 + return compile_preshuffle_gemm_a8 + except ImportError: + pytest.skip("kernels.preshuffle_gemm not importable") + + def test_fp32_rejected_with_cshuffle(self): + """fp32 output + cshuffle should raise ValueError.""" + fn = self._get_compile_fn() + with pytest.raises(ValueError, match="fp32.*direct epilog"): + fn(K=128, tile_m=64, tile_n=64, tile_k=64, + out_dtype="fp32", use_cshuffle_epilog=True) + + def test_fp32_accepted_without_cshuffle(self): + """fp32 output + direct epilog should NOT raise ValueError for dtype.""" + fn = self._get_compile_fn() + # This will fail later (needs MLIR context) but should pass + # the dtype validation stage. + try: + fn(K=128, tile_m=64, tile_n=64, tile_k=64, + out_dtype="fp32", use_cshuffle_epilog=False) + except ValueError as e: + assert "out_dtype" not in str(e), \ + f"fp32 should be accepted as out_dtype, got: {e}" + except Exception: + pass # other errors (MLIR context, GPU) are expected + + def test_invalid_out_dtype_rejected(self): + """Unsupported out_dtype should raise ValueError.""" + fn = self._get_compile_fn() + with pytest.raises(ValueError, match="out_dtype"): + fn(K=128, tile_m=64, tile_n=64, tile_k=64, out_dtype="int8")