From eecee073a311b1ad4aad4bd24b97fdfc2caf3d03 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 13 Mar 2026 23:32:23 +0000 Subject: [PATCH 01/63] save work --- examples/xegpu/softmax.py | 336 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 336 insertions(+) create mode 100644 examples/xegpu/softmax.py diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py new file mode 100644 index 00000000..bb90f812 --- /dev/null +++ b/examples/xegpu/softmax.py @@ -0,0 +1,336 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + +""" +XeGPU softmax benchmark. +""" + +import argparse +import ctypes +from typing import Optional +from functools import cached_property + +import numpy as np +from mlir import ir +from mlir.execution_engine import ExecutionEngine +from mlir.dialects import linalg, gpu, bufferization, arith, tensor, func + +from lighthouse.workload import benchmark +from lighthouse.utils.memref import to_ctype as memref_to_ctype +from lighthouse.utils.numpy import numpy_to_ctype +from lighthouse.utils.mlir import func_cif +from lighthouse.ingress.mlir_gen import get_mlir_elem_type +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs, emit_buf_to_tensor + +from xegpu_workload import XeGPUWorkload + + +def softmax_complexity(M: int, N: int, nbytes: int): + """ + Complexity of softmax operation. + + For each row: + - O(N) to find max + - O(N) to compute exp(x - max) and sum + - O(N) to normalize + Total: 3*N operations per row, but with transcendental (exp) operations + """ + # Approximation: 5 FLOPs per element (max, sub, exp, sum, div) + # exp is expensive but we count it as ~1 FLOP for simplicity + flop_count = M * N * 5 + memory_reads = M * N * nbytes # read input + memory_writes = M * N * nbytes # write output + return flop_count, memory_reads, memory_writes + + +class XeGPUSoftmax(XeGPUWorkload): + """ + Softmax workload on XeGPU. + + Computes softmax along the last dimension (rows): + output[i, j] = exp(input[i, j] - max_i) / sum_i(exp(input[i, j] - max_i)) + + where max_i and sum_i are computed over row i. + """ + + def __init__( + self, + M: int, + N: int, + dtype: str = "f32", + ): + super().__init__() + self.M = M + self.N = N + self.shape = (M, N) + assert dtype == "f32", "Only f32 type is supported for softmax" + self.dtype_str = dtype + type_str_to_numpy = { + "f16": np.float16, + "f32": np.float32, + } + self.dtype = type_str_to_numpy[dtype] + + @cached_property + def _initial_host_arrays(self) -> tuple[np.ndarray]: + """Generate initial values on host with numpy.""" + np.random.seed(42) + # Use values in range [-0.5, 0.5] to avoid numerical issues + input_arr = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + return (input_arr,) + + @cached_property + def _reference_solution(self) -> np.ndarray: + """Compute reference solution on host with numpy.""" + (input_arr,) = self._initial_host_arrays + # Use float32 for computation + x = input_arr.astype(np.float32) + # Compute softmax along axis 1 (each row independently) + # Numerically stable version: subtract max before exp + max_vals = np.max(x, axis=1, keepdims=True) + exp_vals = np.exp(x - max_vals) + sum_vals = np.sum(exp_vals, axis=1, keepdims=True) + output = exp_vals / sum_vals + return output.astype(self.dtype) + + def _get_input_arrays( + self, execution_engine: ExecutionEngine + ) -> list[ctypes.Structure]: + # Allocate device memory for input and output + input_gpu = self._allocate_array( + "input", self.shape, self.dtype_str, execution_engine + ) + output_gpu = self._allocate_array( + "output", self.shape, self.dtype_str, execution_engine + ) + + # Copy input to device + (input_host,) = self._initial_host_arrays + copy_fn = f"gpu_copy_2d_{self.dtype_str}" + execution_engine.invoke( + copy_fn, numpy_to_ctype(input_host), memref_to_ctype(input_gpu) + ) + + # Return memrefs: [output, input] + return [output_gpu, input_gpu] + + def check_correctness( + self, execution_engine: ExecutionEngine, verbose: int = 0 + ) -> bool: + # Copy result from device to host + output_gpu = self.gpu_memrefs[("output", self.dtype_str)] + output_host = np.zeros(self.shape, dtype=self.dtype) + execution_engine.invoke( + f"gpu_copy_2d_{self.dtype_str}", + memref_to_ctype(output_gpu), + numpy_to_ctype(output_host), + ) + + output_ref = self._reference_solution + output_computed = output_host.astype(np.float32) + + if verbose > 1: + print("Reference solution (first 5 rows):") + print(output_ref[:5]) + print("Computed solution (first 5 rows):") + print(output_computed[:5]) + + # Check row sums are close to 1.0 + row_sums = np.sum(output_computed, axis=1) + sums_ok = np.allclose(row_sums, 1.0, rtol=1e-5, atol=1e-6) + + # Check values match reference + values_ok = np.allclose(output_computed, output_ref, rtol=1e-4, atol=1e-6) + + success = sums_ok and values_ok + + if verbose: + if success: + print("PASSED") + else: + print("FAILED!") + if not sums_ok: + print(f" Row sums check failed. Min: {row_sums.min():.6f}, Max: {row_sums.max():.6f}") + if not values_ok: + max_diff = np.abs(output_computed - output_ref).max() + print(f" Values mismatch. Max abs diff: {max_diff:.6e}") + return success + + def get_complexity(self) -> tuple[int, int, int]: + nbytes = np.dtype(self.dtype).itemsize + return softmax_complexity(self.M, self.N, nbytes) + + def payload_module(self) -> ir.Module: + """Generate MLIR module for softmax payload.""" + mod = ir.Module.create() + dtype = get_mlir_elem_type(self.dtype_str) + memref_t = ir.MemRefType.get(self.shape, dtype) + + with ir.InsertionPoint(mod.body): + # Function signature: payload(output, input) + @func_cif(memref_t, memref_t, name=self.payload_function_name) + def payload(output, input_arg): + # Convert memrefs to tensors + output_tensor = emit_buf_to_tensor(output, restrict=True, writable=True) + input_tensor = emit_buf_to_tensor(input_arg, restrict=True) + + # Create intermediate buffer for softmax (used internally by linalg.softmax) + # This stores the sum of exp values + M, N = self.shape + softmax_buf_type = ir.MemRefType.get((M,N), dtype) + softmax_buf = gpu.alloc(softmax_buf_type, None, [], [], []) + softmax_buf_tensor = emit_buf_to_tensor(softmax_buf, restrict=True, writable=True) + + # Compute softmax along dimension 1 (rows) + # linalg.softmax performs: exp(x - max(x)) / sum(exp(x - max(x))) + result = linalg.softmax( + (input_tensor.type,), input_tensor, softmax_buf_tensor, dimension=1 + ) + + # Materialize result back to output memref + bufferization.materialize_in_destination( + None, result, output, restrict=True, writable=True + ) + + # Cleanup + gpu.dealloc(None, [], softmax_buf) + + # Emit utility functions for GPU memory management + emit_gpu_util_funcs(dtype, rank=2) + + return mod + + def schedule_module( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> ir.Module: + """ + Generate transform schedule for softmax. + + For now, returns an empty schedule. In the future, this would contain + tiling, vectorization, and XeGPU-specific lowering transformations. + """ + # TODO: Implement proper transform schedule + # For now, create a minimal schedule that just applies bufferization + mod = ir.Module.create() + with ir.InsertionPoint(mod.body): + from mlir.dialects import transform + + # Create a simple transform sequence + @func_cif(name="__transform_main") + def transform_main(): + # Empty transform - just identity + # In a full implementation, this would tile, vectorize, + # and lower to XeGPU operations + pass + + return mod + + def shared_libs(self) -> list[str]: + return ["libmlir_levelzero_runtime.so"] + + +def parse_cli(): + parser = argparse.ArgumentParser( + description="Softmax using MLIR XeGPU", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--sizes", + type=int, + nargs=2, + default=[1024, 512], + help="M,N matrix sizes (MxN)", + ) + parser.add_argument( + "--wg-tile", + type=int, + nargs=2, + default=[64, 32], + help="Workgroup tile size M,N.", + ) + parser.add_argument( + "--nruns", + type=int, + default=1000, + help="Number of runs to average the execution time.", + ) + parser.add_argument( + "--nwarmup", + type=int, + default=20, + help="Number of warm-up iterations before benchmarking.", + ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the softmax computation.", + ) + parser.add_argument( + "--dump-kernel", + type=str, + choices=[ + "initial", + "tiled", + "vectorized", + "bufferized", + "xegpu-initial", + "xegpu-wg", + "final", + ], + help="Dump kernel IR at different stages of lowering and exit without " + "executing the kernel.", + ) + parser.add_argument( + "--dump-schedule", + action="store_true", + help="Dump transform schedule.", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_cli() + + params = { + "wg_m": args.wg_tile[0], + "wg_n": args.wg_tile[1], + } + + M, N = args.sizes + dtype = "f32" + + with ir.Context(), ir.Location.unknown(): + wload = XeGPUSoftmax(M=M, N=N, dtype=dtype) + + if args.dump_kernel or args.dump_schedule: + wload.lower_payload( + dump_payload=args.dump_kernel, + dump_schedule=args.dump_schedule, + schedule_parameters=params, + ) + else: + times = benchmark( + wload, + nruns=args.nruns, + nwarmup=args.nwarmup, + schedule_parameters=params, + check_correctness=args.check_result, + verbose=1, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + def list2str(a): + return ",".join(map(str, a)) + + parts = [ + f"sizes={list2str(args.sizes)}", + f"dt={dtype}", + f"wg-tile={list2str(args.wg_tile)}", + f"time(us): {elapsed:.2f}", + f"GFLOPS: {gflops:.2f}", + ] + print(" ".join(parts)) From f991027da04fd09f164f10b771ea470b9b577519 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 16 Mar 2026 22:46:40 +0000 Subject: [PATCH 02/63] save work --- examples/xegpu/softmax.py | 93 ++++++++++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 17 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index bb90f812..ab61f1e5 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -13,9 +13,9 @@ import numpy as np from mlir import ir from mlir.execution_engine import ExecutionEngine -from mlir.dialects import linalg, gpu, bufferization, arith, tensor, func +from mlir.dialects import linalg, gpu, bufferization, arith, tensor, func, math -from lighthouse.workload import benchmark +from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype from lighthouse.utils.mlir import func_cif @@ -174,35 +174,94 @@ def payload(output, input_arg): output_tensor = emit_buf_to_tensor(output, restrict=True, writable=True) input_tensor = emit_buf_to_tensor(input_arg, restrict=True) - # Create intermediate buffer for softmax (used internally by linalg.softmax) - # This stores the sum of exp values M, N = self.shape - softmax_buf_type = ir.MemRefType.get((M,N), dtype) - softmax_buf = gpu.alloc(softmax_buf_type, None, [], [], []) - softmax_buf_tensor = emit_buf_to_tensor(softmax_buf, restrict=True, writable=True) - # Compute softmax along dimension 1 (rows) - # linalg.softmax performs: exp(x - max(x)) / sum(exp(x - max(x))) - result = linalg.softmax( - (input_tensor.type,), input_tensor, softmax_buf_tensor, dimension=1 + # Define affine maps for indexing + # #map = affine_map<(d0, d1) -> (d0, d1)> (identity 2D) + # #map1 = affine_map<(d0, d1) -> (d0)> (broadcast/reduce along d1) + d0 = ir.AffineDimExpr.get(0) + d1 = ir.AffineDimExpr.get(1) + map_2d = ir.AffineMap.get(2, 0, [d0, d1]) + map_1d = ir.AffineMap.get(2, 0, [d0]) + + # Step 1: Find max - linalg.generic reduction + neg_inf = arith.constant(dtype, float('-inf')) + max_init = tensor.empty((M,), dtype) + max_filled = linalg.fill(neg_inf, outs=[max_init]) + + @linalg.generic( + [input_tensor], # inputs + [max_filled], # outputs + [map_2d, map_1d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.reduction], # iterator_types + ) + def row_max(in_val, acc): + return arith.maximumf(in_val, acc) + + # Step 2: Subtract max (broadcast) - linalg.generic elementwise + output_init = tensor.empty((M, N), dtype) + + @linalg.generic( + [input_tensor, row_max], # inputs + [output_init], # outputs + [map_2d, map_1d, map_2d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types + ) + def shifted(in_val, max_val, out): + return arith.subf(in_val, max_val) + + # Step 3: Compute exp - linalg.generic elementwise + @linalg.generic( + [shifted], # inputs + [output_init], # outputs + [map_2d, map_2d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types + ) + def exp_vals(in_val, out): + return math.exp(in_val) + + # Step 4: Sum exp values - linalg.generic reduction + # Create collapsed tensor for sum init + # sum_init_2d = tensor.empty((M, 1), dtype) + sum_init = tensor.empty((M,), dtype) + # sum_init = tensor.CollapseShapeOp(sum_init_2d, [[0, 1]]) + + + zero = arith.constant(dtype, 0.0) + sum_filled = linalg.fill(zero, outs=[sum_init]) + + @linalg.generic( + [exp_vals], # inputs + [sum_filled], # outputs + [map_2d, map_1d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.reduction], # iterator_types ) + def row_sum(in_val, acc): + return arith.addf(in_val, acc) + + # Step 5: Divide by sum (broadcast) - linalg.generic elementwise + @linalg.generic( + [exp_vals, row_sum], # inputs + [output_init], # outputs + [map_2d, map_1d, map_2d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types + ) + def result(exp_val, sum_val, out): + return arith.divf(exp_val, sum_val) # Materialize result back to output memref bufferization.materialize_in_destination( None, result, output, restrict=True, writable=True ) - - # Cleanup - gpu.dealloc(None, [], softmax_buf) # Emit utility functions for GPU memory management emit_gpu_util_funcs(dtype, rank=2) return mod - def schedule_module( + def schedule_modules( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None - ) -> ir.Module: + ) -> list[ir.Module]: """ Generate transform schedule for softmax. @@ -223,7 +282,7 @@ def transform_main(): # and lower to XeGPU operations pass - return mod + return [get_bench_wrapper_schedule(self), mod] def shared_libs(self) -> list[str]: return ["libmlir_levelzero_runtime.so"] From 22415bb54f34727337b66392f4585f09110d2b6a Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 16 Mar 2026 22:48:34 +0000 Subject: [PATCH 03/63] save work --- examples/xegpu/softmax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index ab61f1e5..dff8bd99 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -224,7 +224,7 @@ def exp_vals(in_val, out): # Create collapsed tensor for sum init # sum_init_2d = tensor.empty((M, 1), dtype) sum_init = tensor.empty((M,), dtype) - # sum_init = tensor.CollapseShapeOp(sum_init_2d, [[0, 1]]) + # tensor.CollapseShapeOp(sum_init, sum_init_2d, [[0, 1]]) zero = arith.constant(dtype, 0.0) From cb9ead174134465610ef797dbb33fbee0196e1cd Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 17 Mar 2026 22:05:47 +0000 Subject: [PATCH 04/63] save work --- examples/xegpu/softmax.py | 64 +++++++++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index dff8bd99..97abd49a 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -269,20 +269,66 @@ def schedule_modules( tiling, vectorization, and XeGPU-specific lowering transformations. """ # TODO: Implement proper transform schedule - # For now, create a minimal schedule that just applies bufferization + # For now, create a minimal schedule that prints the last linalg operation mod = ir.Module.create() + mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + with ir.InsertionPoint(mod.body): from mlir.dialects import transform + from mlir.dialects.transform import structured + + # Create a transform sequence with proper signature + named_sequence = transform.named_sequence( + "__transform_main", + [transform.AnyOpType.get()], # input: module + [], # no outputs + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}] + ) - # Create a simple transform sequence - @func_cif(name="__transform_main") - def transform_main(): - # Empty transform - just identity - # In a full implementation, this would tile, vectorize, - # and lower to XeGPU operations - pass + with ir.InsertionPoint(named_sequence.body): + # Get the input module (bodyTarget) + payload_mod = named_sequence.bodyTarget + + # Match all linalg.generic operations + # We have 5 generic ops in softmax: max, sub, exp, sum, div + generic_ops = structured.structured_match( + transform.AnyOpType.get(), + payload_mod, + ops=["linalg.generic"] + ) + + # Split the handle into individual operation handles + # For softmax, we have 5 operations + anytype = transform.AnyOpType.get() + split_ops = transform.split_handle( + (anytype, anytype, anytype, anytype, anytype), # 5 result types + generic_ops + ) + + # The last operation (index 4) is the division + last_op = split_ops[-1] + + # Print the last operation before tiling + # transform.print_(target=last_op, name="last_linalg_generic_before_tiling") + + # Tile the last operation using tile_using_forall + # Tile sizes: [64, 64] for the two parallel dimensions (M, N) + tiled_op, for_op = structured.structured_tile_using_forall( + anytype, anytype, + last_op, + num_threads=[], + tile_sizes=[], + static_tile_sizes=[64, 64], + ) + + # Print the tiled operation + # transform.print_(target=tiled_op, name="tiled_linalg_generic") + # transform.print_(target=for_op, name="forall_op") + + # Required: yield to end the transform sequence + transform.yield_() - return [get_bench_wrapper_schedule(self), mod] + return [mod] def shared_libs(self) -> list[str]: return ["libmlir_levelzero_runtime.so"] From ac39be33c3b90e6eb77c16237611fd33d4490695 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 18 Mar 2026 22:41:32 +0000 Subject: [PATCH 05/63] save work --- examples/xegpu/softmax.py | 86 +++++++++++++++++++++++++++++++++++---- 1 file changed, 78 insertions(+), 8 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index 97abd49a..6d6f87eb 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -14,6 +14,8 @@ from mlir import ir from mlir.execution_engine import ExecutionEngine from mlir.dialects import linalg, gpu, bufferization, arith, tensor, func, math +from mlir.dialects import transform +from mlir.dialects.transform import structured, loop from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.memref import to_ctype as memref_to_ctype @@ -21,9 +23,25 @@ from lighthouse.utils.mlir import func_cif from lighthouse.ingress.mlir_gen import get_mlir_elem_type from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs, emit_buf_to_tensor +from lighthouse.pipeline.helper import ( + apply_registered_pass, + canonicalize, + match, +) +from mlir.dialects.transform import bufferization as transform_bufferization +from mlir.dialects.bufferization import LayoutMapOption from xegpu_workload import XeGPUWorkload +def match_and_split(*args, nhandles=1, **kwargs): + """Helper function that splits matched handles.""" + matched = match(*args, **kwargs) + anytype = transform.AnyOpType.get() + matched_ops = transform.split_handle((anytype,) * nhandles, matched) + if nhandles == 1: + matched_ops = [matched_ops] + return matched_ops + def softmax_complexity(M: int, N: int, nbytes: int): """ @@ -274,8 +292,7 @@ def schedule_modules( mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() with ir.InsertionPoint(mod.body): - from mlir.dialects import transform - from mlir.dialects.transform import structured + # Create a transform sequence with proper signature named_sequence = transform.named_sequence( @@ -305,8 +322,11 @@ def schedule_modules( generic_ops ) - # The last operation (index 4) is the division - last_op = split_ops[-1] + # Reverse split_ops to have operations in reverse order + split_ops = list(reversed(split_ops)) + + # The first operation (after reversal) is the division - this is the consumer + last_op = split_ops[0] # Print the last operation before tiling # transform.print_(target=last_op, name="last_linalg_generic_before_tiling") @@ -318,12 +338,62 @@ def schedule_modules( last_op, num_threads=[], tile_sizes=[], - static_tile_sizes=[64, 64], + static_tile_sizes=(64,), ) - # Print the tiled operation - # transform.print_(target=tiled_op, name="tiled_linalg_generic") - # transform.print_(target=for_op, name="forall_op") + # Fuse the producer operations into the forall loop + # Iterate through remaining operations (already in reverse order) + current_forall = for_op + for producer_op in split_ops[1:]: + fused_op, current_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, + producer_op, + current_forall + ) + + func = transform.get_parent_op( + anytype, + current_forall, + op_name="func.func", + deduplicate=True, + ) + transform.apply_cse(func) + canonicalize(func) + func = apply_registered_pass(func, "eliminate-empty-tensors") + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + identity_layout = LayoutMapOption.IdentityLayoutMap + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + payload_mod = transform_bufferization.OneShotBufferizeOp( + payload_mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + payload_mod = apply_registered_pass(payload_mod, "fold-memref-alias-ops") + transform.apply_cse(payload_mod) + canonicalize(payload_mod) + + # convert forall to parallel + wg_loops = match_and_split(payload_mod, ops={"scf.forall"}) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert to scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + # Required: yield to end the transform sequence transform.yield_() From 51d494e23aa34e3166210e080aa23663e98924c2 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 20 Mar 2026 17:10:46 +0000 Subject: [PATCH 06/63] save work --- examples/xegpu/softmax.py | 67 +++++++++++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 10 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index 6d6f87eb..be29cdd0 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -15,7 +15,7 @@ from mlir.execution_engine import ExecutionEngine from mlir.dialects import linalg, gpu, bufferization, arith, tensor, func, math from mlir.dialects import transform -from mlir.dialects.transform import structured, loop +from mlir.dialects.transform import structured, loop, xegpu from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.memref import to_ctype as memref_to_ctype @@ -338,7 +338,7 @@ def schedule_modules( last_op, num_threads=[], tile_sizes=[], - static_tile_sizes=(64,), + static_tile_sizes=(parameters["wg_rows"],), ) # Fuse the producer operations into the forall loop @@ -393,6 +393,39 @@ def schedule_modules( func = apply_registered_pass(func, "lower-affine") transform.apply_cse(func) canonicalize(func) + # set the number of threads for the gpu.launch operation + launch_op = match_and_split(func, ops={"gpu.launch"}) + num_threads = parameters["sg_rows"] * parameters["subgroup_size"] + xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + payload_mod = apply_registered_pass(payload_mod, "gpu-kernel-outlining") + # transform.PrintOp(target=payload_mod, name="before_gpu_outlining") + # transform.apply_cse(payload_mod) + + # set xevm target + # payload_mod = apply_registered_pass( + # payload_mod, + # "xevm-attach-target", + # options={"O": "3", "chip": "bmg"}, + # ) + + # # convert vector to xegpu + # gpu_mod_ops = match_and_split(payload_mod, ops={"gpu.module"}) + # for gpu_mod in gpu_mod_ops: + # gpu_func = match(gpu_mod, ops={"gpu.func"}) + # gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + # transform.apply_cse(gpu_func) + # Required: yield to end the transform sequence @@ -413,15 +446,26 @@ def parse_cli(): "--sizes", type=int, nargs=2, - default=[1024, 512], + default=[1024, 64], help="M,N matrix sizes (MxN)", ) parser.add_argument( - "--wg-tile", + "--wg-rows", type=int, - nargs=2, - default=[64, 32], - help="Workgroup tile size M,N.", + default=64, + help="Number of rows per workgroup.", + ) + parser.add_argument( + "--sg-rows", + type=int, + default=8, + help="Number of rows per subgroup.", + ) + parser.add_argument( + "--subgroup-size", + type=int, + default=16, + help="Subgroup size.", ) parser.add_argument( "--nruns", @@ -468,8 +512,9 @@ def parse_cli(): args = parse_cli() params = { - "wg_m": args.wg_tile[0], - "wg_n": args.wg_tile[1], + "wg_rows": args.wg_rows, + "sg_rows": args.sg_rows, + "subgroup_size": args.subgroup_size, } M, N = args.sizes @@ -504,7 +549,9 @@ def list2str(a): parts = [ f"sizes={list2str(args.sizes)}", f"dt={dtype}", - f"wg-tile={list2str(args.wg_tile)}", + f"wg-rows={args.wg_rows}", + f"sg-rows={args.sg_rows}", + f"subgroup-size={args.subgroup_size}", f"time(us): {elapsed:.2f}", f"GFLOPS: {gflops:.2f}", ] From 7ac8852412f06100c2cf188c37c2f18254e38c83 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 20 Mar 2026 21:37:04 +0000 Subject: [PATCH 07/63] save work --- examples/xegpu/softmax.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index be29cdd0..56b43c21 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -408,23 +408,34 @@ def schedule_modules( op_name="builtin.module", deduplicate=True, ) - payload_mod = apply_registered_pass(payload_mod, "gpu-kernel-outlining") + # payload = match(payload_mod, ops={"func.func"}) # transform.PrintOp(target=payload_mod, name="before_gpu_outlining") - # transform.apply_cse(payload_mod) + payload_mod = apply_registered_pass(payload_mod, "gpu-kernel-outlining") + transform.apply_cse(payload_mod) # set xevm target - # payload_mod = apply_registered_pass( - # payload_mod, - # "xevm-attach-target", - # options={"O": "3", "chip": "bmg"}, - # ) - - # # convert vector to xegpu - # gpu_mod_ops = match_and_split(payload_mod, ops={"gpu.module"}) + payload_mod = apply_registered_pass( + payload_mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # convert vector to xegpu + gpu_mod = match_and_split(payload_mod, ops={"gpu.module"}) # for gpu_mod in gpu_mod_ops: - # gpu_func = match(gpu_mod, ops={"gpu.func"}) - # gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") - # transform.apply_cse(gpu_func) + gpu_func = match(gpu_mod[0], ops={"gpu.func"}) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + + # Set layout attributes for xegpu.store_nd operations + store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) + # for store_op in store_ops: + xegpu.set_op_layout_attr(store_ops[0], sg_layout=[8, 1], sg_data=[8, 64]) + + payload_mod = apply_registered_pass( + payload_mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} + ) + From d65bf9fe47d2c95c5fdc986177e604ec4c0b6f82 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 20 Mar 2026 22:58:52 +0000 Subject: [PATCH 08/63] save work --- examples/xegpu/softmax.py | 88 ++++++++++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 20 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index 56b43c21..2debc8c8 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -311,14 +311,14 @@ def schedule_modules( generic_ops = structured.structured_match( transform.AnyOpType.get(), payload_mod, - ops=["linalg.generic"] + ops=["linalg.generic", "linalg.fill"] ) # Split the handle into individual operation handles # For softmax, we have 5 operations anytype = transform.AnyOpType.get() split_ops = transform.split_handle( - (anytype, anytype, anytype, anytype, anytype), # 5 result types + (anytype, anytype, anytype, anytype, anytype, anytype, anytype), # 7 result types generic_ops ) @@ -350,20 +350,23 @@ def schedule_modules( producer_op, current_forall ) - - func = transform.get_parent_op( - anytype, - current_forall, - op_name="func.func", - deduplicate=True, - ) - transform.apply_cse(func) - canonicalize(func) - func = apply_registered_pass(func, "eliminate-empty-tensors") + transform.annotate(current_forall, "gpu_loop") + + transform.apply_cse(payload_mod) + # canonicalize(payload_mod) + + # Vectorize and bufferize sequence + func = match(payload_mod, ops={"func.func"}) func = structured.VectorizeChildrenAndApplyPatternsOp( func, fold_type_extensions_into_contract=True, ).result + loops = match_and_split(payload_mod, ops={"scf.forall"}) + loop.loop_hoist_loop_invariant_subsets(loops[0]) + transform.apply_cse(payload_mod) + canonicalize(payload_mod) + # transform.PrintOp(target=payload_mod, name="vectorize") + identity_layout = LayoutMapOption.IdentityLayoutMap payload_mod = transform.get_parent_op( anytype, @@ -377,28 +380,73 @@ def schedule_modules( bufferize_function_boundaries=True, function_boundary_type_conversion=identity_layout, ).result + # payload_mod = transform_bufferization.OneShotBufferizeOp( + # payload_mod, + # allow_return_allocs_from_loops=False, + # bufferize_function_boundaries=True, + # function_boundary_type_conversion=identity_layout, + # ).result payload_mod = apply_registered_pass(payload_mod, "fold-memref-alias-ops") + payload_mod = apply_registered_pass(payload_mod, "drop-equivalent-buffer-results") + payload_mod = apply_registered_pass( + payload_mod, + "buffer-results-to-out-params", + options={ + "add-result-attr": "true", + "hoist-dynamic-allocs": "true", + "hoist-static-allocs": "true", + "modify-public-functions": "true" + } + ) transform.apply_cse(payload_mod) canonicalize(payload_mod) + # # # transform.PrintOp(target=payload_mod, name="bufferize") + + # # func = match(payload_mod, ops={"func.func"}) + gpu_loop = match(payload_mod, op_attrs={"gpu_loop": ir.UnitAttr.get()}) + # # gpu_loop = transform.split_handle(anytype, gpu_loop) + gpu_loop = loop.loop_forall_to_parallel([anytype], gpu_loop) + + # # func = apply_registered_pass(payload_mod, "eliminate-empty-tensors") + # # func = structured.VectorizeChildrenAndApplyPatternsOp( + # # func, + # # fold_type_extensions_into_contract=True, + # # ).result + # # identity_layout = LayoutMapOption.IdentityLayoutMap + payload_mod = transform.get_parent_op( + anytype, + gpu_loop, + op_name="func.func", + deduplicate=True, + ) + # payload_mod = transform_bufferization.OneShotBufferizeOp( + # payload_mod, + # allow_return_allocs_from_loops=True, + # bufferize_function_boundaries=True, + # function_boundary_type_conversion=identity_layout, + # ).result + # payload_mod = apply_registered_pass(payload_mod, "fold-memref-alias-ops") + # transform.apply_cse(payload_mod) + # canonicalize(payload_mod) - # convert forall to parallel - wg_loops = match_and_split(payload_mod, ops={"scf.forall"}) - for wg_loop in wg_loops: - wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) - func = transform.get_parent_op(anytype, wg_loop) + # # convert forall to parallel + # wg_loops = match_and_split(payload_mod, ops={"scf.forall"}) + # for wg_loop in wg_loops: + # wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + # func = transform.get_parent_op(anytype, wg_loop) # convert to scf.parallel to gpu.launch - func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(payload_mod, "gpu-map-parallel-loops") func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") func = apply_registered_pass(func, "lower-affine") transform.apply_cse(func) canonicalize(func) - # set the number of threads for the gpu.launch operation + # # set the number of threads for the gpu.launch operation launch_op = match_and_split(func, ops={"gpu.launch"}) num_threads = parameters["sg_rows"] * parameters["subgroup_size"] xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) - # outline gpu func + # # outline gpu func func = apply_registered_pass(func, "lower-affine") canonicalize(func) func = apply_registered_pass(func, "gpu-launch-sink-index-computations") From 0bf3eb32a7fa9fbd66b0a7a238fb302607c2896a Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 24 Mar 2026 20:32:41 +0000 Subject: [PATCH 09/63] save working version --- examples/xegpu/softmax.py | 138 ++++++++++++-------------------------- 1 file changed, 42 insertions(+), 96 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index 2debc8c8..e95286b5 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -17,6 +17,7 @@ from mlir.dialects import transform from mlir.dialects.transform import structured, loop, xegpu +from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype @@ -54,7 +55,6 @@ def softmax_complexity(M: int, N: int, nbytes: int): Total: 3*N operations per row, but with transcendental (exp) operations """ # Approximation: 5 FLOPs per element (max, sub, exp, sum, div) - # exp is expensive but we count it as ~1 FLOP for simplicity flop_count = M * N * 5 memory_reads = M * N * nbytes # read input memory_writes = M * N * nbytes # write output @@ -303,8 +303,15 @@ def schedule_modules( ) with ir.InsertionPoint(named_sequence.body): - # Get the input module (bodyTarget) - payload_mod = named_sequence.bodyTarget + # match the payload module + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) # Match all linalg.generic operations # We have 5 generic ops in softmax: max, sub, exp, sum, div @@ -315,7 +322,7 @@ def schedule_modules( ) # Split the handle into individual operation handles - # For softmax, we have 5 operations + # For softmax, we have 7 operations anytype = transform.AnyOpType.get() split_ops = transform.split_handle( (anytype, anytype, anytype, anytype, anytype, anytype, anytype), # 7 result types @@ -350,117 +357,59 @@ def schedule_modules( producer_op, current_forall ) - transform.annotate(current_forall, "gpu_loop") - - transform.apply_cse(payload_mod) - # canonicalize(payload_mod) + + func = transform.get_parent_op( + anytype, + current_forall, + op_name="func.func", + deduplicate=True, + ) + transform.apply_cse(func) + canonicalize(func) - # Vectorize and bufferize sequence - func = match(payload_mod, ops={"func.func"}) func = structured.VectorizeChildrenAndApplyPatternsOp( func, fold_type_extensions_into_contract=True, ).result - loops = match_and_split(payload_mod, ops={"scf.forall"}) - loop.loop_hoist_loop_invariant_subsets(loops[0]) - transform.apply_cse(payload_mod) - canonicalize(payload_mod) - # transform.PrintOp(target=payload_mod, name="vectorize") - + transform.apply_cse(func) + canonicalize(func) + payload_mod = apply_registered_pass(payload_mod, "eliminate-empty-tensors") identity_layout = LayoutMapOption.IdentityLayoutMap - payload_mod = transform.get_parent_op( - anytype, - func, - op_name="builtin.module", - deduplicate=True, - ) payload_mod = transform_bufferization.OneShotBufferizeOp( payload_mod, allow_return_allocs_from_loops=True, bufferize_function_boundaries=True, function_boundary_type_conversion=identity_layout, ).result - # payload_mod = transform_bufferization.OneShotBufferizeOp( - # payload_mod, - # allow_return_allocs_from_loops=False, - # bufferize_function_boundaries=True, - # function_boundary_type_conversion=identity_layout, - # ).result + # fold memref.subviews into vector.transfer_read/write ops payload_mod = apply_registered_pass(payload_mod, "fold-memref-alias-ops") - payload_mod = apply_registered_pass(payload_mod, "drop-equivalent-buffer-results") - payload_mod = apply_registered_pass( - payload_mod, - "buffer-results-to-out-params", - options={ - "add-result-attr": "true", - "hoist-dynamic-allocs": "true", - "hoist-static-allocs": "true", - "modify-public-functions": "true" - } - ) transform.apply_cse(payload_mod) canonicalize(payload_mod) - # # # transform.PrintOp(target=payload_mod, name="bufferize") - - # # func = match(payload_mod, ops={"func.func"}) - gpu_loop = match(payload_mod, op_attrs={"gpu_loop": ir.UnitAttr.get()}) - # # gpu_loop = transform.split_handle(anytype, gpu_loop) - gpu_loop = loop.loop_forall_to_parallel([anytype], gpu_loop) - # # func = apply_registered_pass(payload_mod, "eliminate-empty-tensors") - # # func = structured.VectorizeChildrenAndApplyPatternsOp( - # # func, - # # fold_type_extensions_into_contract=True, - # # ).result - # # identity_layout = LayoutMapOption.IdentityLayoutMap - payload_mod = transform.get_parent_op( - anytype, - gpu_loop, - op_name="func.func", - deduplicate=True, - ) - # payload_mod = transform_bufferization.OneShotBufferizeOp( - # payload_mod, - # allow_return_allocs_from_loops=True, - # bufferize_function_boundaries=True, - # function_boundary_type_conversion=identity_layout, - # ).result - # payload_mod = apply_registered_pass(payload_mod, "fold-memref-alias-ops") - # transform.apply_cse(payload_mod) - # canonicalize(payload_mod) - - # # convert forall to parallel - # wg_loops = match_and_split(payload_mod, ops={"scf.forall"}) - # for wg_loop in wg_loops: - # wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) - # func = transform.get_parent_op(anytype, wg_loop) - + # convert forall to parallel + wg_loops = match_and_split(payload_mod, ops={"scf.forall"}) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) # convert to scf.parallel to gpu.launch - func = apply_registered_pass(payload_mod, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "gpu-map-parallel-loops") func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") func = apply_registered_pass(func, "lower-affine") transform.apply_cse(func) canonicalize(func) - # # set the number of threads for the gpu.launch operation + + # set the number of threads for the gpu.launch operation launch_op = match_and_split(func, ops={"gpu.launch"}) num_threads = parameters["sg_rows"] * parameters["subgroup_size"] xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) - # # outline gpu func + # outline gpu func func = apply_registered_pass(func, "lower-affine") canonicalize(func) func = apply_registered_pass(func, "gpu-launch-sink-index-computations") - payload_mod = transform.get_parent_op( - anytype, - func, - op_name="builtin.module", - deduplicate=True, - ) - # payload = match(payload_mod, ops={"func.func"}) - # transform.PrintOp(target=payload_mod, name="before_gpu_outlining") payload_mod = apply_registered_pass(payload_mod, "gpu-kernel-outlining") transform.apply_cse(payload_mod) - + # set xevm target payload_mod = apply_registered_pass( payload_mod, @@ -469,12 +418,12 @@ def schedule_modules( ) # convert vector to xegpu - gpu_mod = match_and_split(payload_mod, ops={"gpu.module"}) - # for gpu_mod in gpu_mod_ops: - gpu_func = match(gpu_mod[0], ops={"gpu.func"}) - gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") - transform.apply_cse(gpu_func) - + gpu_mod_ops = match_and_split(payload_mod, ops={"gpu.module"}) + for gpu_mod in gpu_mod_ops: + gpu_func = match(gpu_mod, ops={"gpu.func"}) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + # Set layout attributes for xegpu.store_nd operations store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) # for store_op in store_ops: @@ -483,14 +432,10 @@ def schedule_modules( payload_mod = apply_registered_pass( payload_mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} ) - - - - # Required: yield to end the transform sequence transform.yield_() - return [mod] + return [get_bench_wrapper_schedule(self), mod] def shared_libs(self) -> list[str]: return ["libmlir_levelzero_runtime.so"] @@ -580,6 +525,7 @@ def parse_cli(): dtype = "f32" with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() wload = XeGPUSoftmax(M=M, N=N, dtype=dtype) if args.dump_kernel or args.dump_schedule: From fabd656c1267437174b7c7b9ed472c0b76556ee9 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 24 Mar 2026 21:48:44 +0000 Subject: [PATCH 10/63] save working version --- examples/xegpu/softmax.py | 288 +----------------- .../ingress/mlir_gen/gpu_softmax_payload.py | 121 ++++++++ lighthouse/schedule/xegpu/softmax_schedule.py | 195 ++++++++++++ 3 files changed, 332 insertions(+), 272 deletions(-) create mode 100644 lighthouse/ingress/mlir_gen/gpu_softmax_payload.py create mode 100644 lighthouse/schedule/xegpu/softmax_schedule.py diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index e95286b5..474eb1a3 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -13,36 +13,17 @@ import numpy as np from mlir import ir from mlir.execution_engine import ExecutionEngine -from mlir.dialects import linalg, gpu, bufferization, arith, tensor, func, math -from mlir.dialects import transform -from mlir.dialects.transform import structured, loop, xegpu from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark, get_bench_wrapper_schedule from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype -from lighthouse.utils.mlir import func_cif from lighthouse.ingress.mlir_gen import get_mlir_elem_type -from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs, emit_buf_to_tensor -from lighthouse.pipeline.helper import ( - apply_registered_pass, - canonicalize, - match, -) -from mlir.dialects.transform import bufferization as transform_bufferization -from mlir.dialects.bufferization import LayoutMapOption +from lighthouse.ingress.mlir_gen.gpu_softmax_payload import generate_gpu_softmax_payload +from lighthouse.schedule.xegpu.softmax_schedule import get_softmax_schedule_module from xegpu_workload import XeGPUWorkload -def match_and_split(*args, nhandles=1, **kwargs): - """Helper function that splits matched handles.""" - matched = match(*args, **kwargs) - anytype = transform.AnyOpType.get() - matched_ops = transform.split_handle((anytype,) * nhandles, matched) - if nhandles == 1: - matched_ops = [matched_ops] - return matched_ops - def softmax_complexity(M: int, N: int, nbytes: int): """ @@ -180,262 +161,25 @@ def get_complexity(self) -> tuple[int, int, int]: def payload_module(self) -> ir.Module: """Generate MLIR module for softmax payload.""" - mod = ir.Module.create() dtype = get_mlir_elem_type(self.dtype_str) - memref_t = ir.MemRefType.get(self.shape, dtype) - - with ir.InsertionPoint(mod.body): - # Function signature: payload(output, input) - @func_cif(memref_t, memref_t, name=self.payload_function_name) - def payload(output, input_arg): - # Convert memrefs to tensors - output_tensor = emit_buf_to_tensor(output, restrict=True, writable=True) - input_tensor = emit_buf_to_tensor(input_arg, restrict=True) - - M, N = self.shape - - # Define affine maps for indexing - # #map = affine_map<(d0, d1) -> (d0, d1)> (identity 2D) - # #map1 = affine_map<(d0, d1) -> (d0)> (broadcast/reduce along d1) - d0 = ir.AffineDimExpr.get(0) - d1 = ir.AffineDimExpr.get(1) - map_2d = ir.AffineMap.get(2, 0, [d0, d1]) - map_1d = ir.AffineMap.get(2, 0, [d0]) - - # Step 1: Find max - linalg.generic reduction - neg_inf = arith.constant(dtype, float('-inf')) - max_init = tensor.empty((M,), dtype) - max_filled = linalg.fill(neg_inf, outs=[max_init]) - - @linalg.generic( - [input_tensor], # inputs - [max_filled], # outputs - [map_2d, map_1d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.reduction], # iterator_types - ) - def row_max(in_val, acc): - return arith.maximumf(in_val, acc) - - # Step 2: Subtract max (broadcast) - linalg.generic elementwise - output_init = tensor.empty((M, N), dtype) - - @linalg.generic( - [input_tensor, row_max], # inputs - [output_init], # outputs - [map_2d, map_1d, map_2d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types - ) - def shifted(in_val, max_val, out): - return arith.subf(in_val, max_val) - - # Step 3: Compute exp - linalg.generic elementwise - @linalg.generic( - [shifted], # inputs - [output_init], # outputs - [map_2d, map_2d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types - ) - def exp_vals(in_val, out): - return math.exp(in_val) - - # Step 4: Sum exp values - linalg.generic reduction - # Create collapsed tensor for sum init - # sum_init_2d = tensor.empty((M, 1), dtype) - sum_init = tensor.empty((M,), dtype) - # tensor.CollapseShapeOp(sum_init, sum_init_2d, [[0, 1]]) - - - zero = arith.constant(dtype, 0.0) - sum_filled = linalg.fill(zero, outs=[sum_init]) - - @linalg.generic( - [exp_vals], # inputs - [sum_filled], # outputs - [map_2d, map_1d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.reduction], # iterator_types - ) - def row_sum(in_val, acc): - return arith.addf(in_val, acc) - - # Step 5: Divide by sum (broadcast) - linalg.generic elementwise - @linalg.generic( - [exp_vals, row_sum], # inputs - [output_init], # outputs - [map_2d, map_1d, map_2d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types - ) - def result(exp_val, sum_val, out): - return arith.divf(exp_val, sum_val) - - # Materialize result back to output memref - bufferization.materialize_in_destination( - None, result, output, restrict=True, writable=True - ) - - # Emit utility functions for GPU memory management - emit_gpu_util_funcs(dtype, rank=2) - - return mod + return generate_gpu_softmax_payload( + func_name=self.payload_function_name, + M=self.M, + N=self.N, + dtype=dtype, + ) def schedule_modules( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None ) -> list[ir.Module]: - """ - Generate transform schedule for softmax. - - For now, returns an empty schedule. In the future, this would contain - tiling, vectorization, and XeGPU-specific lowering transformations. - """ - # TODO: Implement proper transform schedule - # For now, create a minimal schedule that prints the last linalg operation - mod = ir.Module.create() - mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() - - with ir.InsertionPoint(mod.body): - - - # Create a transform sequence with proper signature - named_sequence = transform.named_sequence( - "__transform_main", - [transform.AnyOpType.get()], # input: module - [], # no outputs - arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}] - ) - - with ir.InsertionPoint(named_sequence.body): - # match the payload module - anytype = transform.AnyOpType.get() - func = match(named_sequence.bodyTarget, ops={"func.func"}) - payload_mod = transform.get_parent_op( - anytype, - func, - op_name="builtin.module", - deduplicate=True, - ) - - # Match all linalg.generic operations - # We have 5 generic ops in softmax: max, sub, exp, sum, div - generic_ops = structured.structured_match( - transform.AnyOpType.get(), - payload_mod, - ops=["linalg.generic", "linalg.fill"] - ) - - # Split the handle into individual operation handles - # For softmax, we have 7 operations - anytype = transform.AnyOpType.get() - split_ops = transform.split_handle( - (anytype, anytype, anytype, anytype, anytype, anytype, anytype), # 7 result types - generic_ops - ) - - # Reverse split_ops to have operations in reverse order - split_ops = list(reversed(split_ops)) - - # The first operation (after reversal) is the division - this is the consumer - last_op = split_ops[0] - - # Print the last operation before tiling - # transform.print_(target=last_op, name="last_linalg_generic_before_tiling") - - # Tile the last operation using tile_using_forall - # Tile sizes: [64, 64] for the two parallel dimensions (M, N) - tiled_op, for_op = structured.structured_tile_using_forall( - anytype, anytype, - last_op, - num_threads=[], - tile_sizes=[], - static_tile_sizes=(parameters["wg_rows"],), - ) - - # Fuse the producer operations into the forall loop - # Iterate through remaining operations (already in reverse order) - current_forall = for_op - for producer_op in split_ops[1:]: - fused_op, current_forall = structured.structured_fuse_into_containing_op( - anytype, anytype, - producer_op, - current_forall - ) - - func = transform.get_parent_op( - anytype, - current_forall, - op_name="func.func", - deduplicate=True, - ) - transform.apply_cse(func) - canonicalize(func) - - func = structured.VectorizeChildrenAndApplyPatternsOp( - func, - fold_type_extensions_into_contract=True, - ).result - transform.apply_cse(func) - canonicalize(func) - payload_mod = apply_registered_pass(payload_mod, "eliminate-empty-tensors") - identity_layout = LayoutMapOption.IdentityLayoutMap - payload_mod = transform_bufferization.OneShotBufferizeOp( - payload_mod, - allow_return_allocs_from_loops=True, - bufferize_function_boundaries=True, - function_boundary_type_conversion=identity_layout, - ).result - # fold memref.subviews into vector.transfer_read/write ops - payload_mod = apply_registered_pass(payload_mod, "fold-memref-alias-ops") - transform.apply_cse(payload_mod) - canonicalize(payload_mod) - - # convert forall to parallel - wg_loops = match_and_split(payload_mod, ops={"scf.forall"}) - for wg_loop in wg_loops: - wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) - func = transform.get_parent_op(anytype, wg_loop) - # convert to scf.parallel to gpu.launch - func = apply_registered_pass(func, "gpu-map-parallel-loops") - func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") - func = apply_registered_pass(func, "lower-affine") - transform.apply_cse(func) - canonicalize(func) - - # set the number of threads for the gpu.launch operation - launch_op = match_and_split(func, ops={"gpu.launch"}) - num_threads = parameters["sg_rows"] * parameters["subgroup_size"] - xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) - - # outline gpu func - func = apply_registered_pass(func, "lower-affine") - canonicalize(func) - func = apply_registered_pass(func, "gpu-launch-sink-index-computations") - payload_mod = apply_registered_pass(payload_mod, "gpu-kernel-outlining") - transform.apply_cse(payload_mod) - - # set xevm target - payload_mod = apply_registered_pass( - payload_mod, - "xevm-attach-target", - options={"O": "3", "chip": "bmg"}, - ) - - # convert vector to xegpu - gpu_mod_ops = match_and_split(payload_mod, ops={"gpu.module"}) - for gpu_mod in gpu_mod_ops: - gpu_func = match(gpu_mod, ops={"gpu.func"}) - gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") - transform.apply_cse(gpu_func) - - # Set layout attributes for xegpu.store_nd operations - store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) - # for store_op in store_ops: - xegpu.set_op_layout_attr(store_ops[0], sg_layout=[8, 1], sg_data=[8, 64]) - - payload_mod = apply_registered_pass( - payload_mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} - ) - # Required: yield to end the transform sequence - transform.yield_() - - return [get_bench_wrapper_schedule(self), mod] + """Generate transform schedule for softmax.""" + return [ + get_bench_wrapper_schedule(self), + get_softmax_schedule_module( + stop_at_stage=stop_at_stage, + parameters=parameters, + ), + ] def shared_libs(self) -> list[str]: return ["libmlir_levelzero_runtime.so"] diff --git a/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py b/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py new file mode 100644 index 00000000..26fd9042 --- /dev/null +++ b/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py @@ -0,0 +1,121 @@ +"""Generate MLIR payload for GPU softmax operation.""" + +from mlir import ir +from mlir.dialects import linalg, bufferization, arith, tensor, func, math + +from lighthouse.utils.mlir import func_cif +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs, emit_buf_to_tensor + + +def generate_gpu_softmax_payload( + func_name: str, + M: int, + N: int, + dtype: ir.Type, +) -> ir.Module: + """ + Generate MLIR module for softmax payload. + + Computes softmax along the last dimension (rows): + output[i, j] = exp(input[i, j] - max_i) / sum_i(exp(input[i, j] - max_i)) + + where max_i and sum_i are computed over row i. + + Args: + func_name: Name of the payload function + M: Number of rows + N: Number of columns + dtype: MLIR element type (e.g., F32Type) + + Returns: + MLIR module containing the softmax payload function + """ + mod = ir.Module.create() + shape = (M, N) + memref_t = ir.MemRefType.get(shape, dtype) + + with ir.InsertionPoint(mod.body): + # Function signature: payload(output, input) + @func_cif(memref_t, memref_t, name=func_name) + def payload(output, input_arg): + # Convert memrefs to tensors + output_tensor = emit_buf_to_tensor(output, restrict=True, writable=True) + input_tensor = emit_buf_to_tensor(input_arg, restrict=True) + + # Define affine maps for indexing + # #map = affine_map<(d0, d1) -> (d0, d1)> (identity 2D) + # #map1 = affine_map<(d0, d1) -> (d0)> (broadcast/reduce along d1) + d0 = ir.AffineDimExpr.get(0) + d1 = ir.AffineDimExpr.get(1) + map_2d = ir.AffineMap.get(2, 0, [d0, d1]) + map_1d = ir.AffineMap.get(2, 0, [d0]) + + # Step 1: Find max - linalg.generic reduction + neg_inf = arith.constant(dtype, float('-inf')) + max_init = tensor.empty((M,), dtype) + max_filled = linalg.fill(neg_inf, outs=[max_init]) + + @linalg.generic( + [input_tensor], # inputs + [max_filled], # outputs + [map_2d, map_1d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.reduction], # iterator_types + ) + def row_max(in_val, acc): + return arith.maximumf(in_val, acc) + + # Step 2: Subtract max (broadcast) - linalg.generic elementwise + output_init = tensor.empty((M, N), dtype) + + @linalg.generic( + [input_tensor, row_max], # inputs + [output_init], # outputs + [map_2d, map_1d, map_2d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types + ) + def shifted(in_val, max_val, out): + return arith.subf(in_val, max_val) + + # Step 3: Compute exp - linalg.generic elementwise + @linalg.generic( + [shifted], # inputs + [output_init], # outputs + [map_2d, map_2d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types + ) + def exp_vals(in_val, out): + return math.exp(in_val) + + # Step 4: Sum exp values - linalg.generic reduction + sum_init = tensor.empty((M,), dtype) + zero = arith.constant(dtype, 0.0) + sum_filled = linalg.fill(zero, outs=[sum_init]) + + @linalg.generic( + [exp_vals], # inputs + [sum_filled], # outputs + [map_2d, map_1d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.reduction], # iterator_types + ) + def row_sum(in_val, acc): + return arith.addf(in_val, acc) + + # Step 5: Divide by sum (broadcast) - linalg.generic elementwise + @linalg.generic( + [exp_vals, row_sum], # inputs + [output_init], # outputs + [map_2d, map_1d, map_2d], # indexing_maps + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types + ) + def result(exp_val, sum_val, out): + return arith.divf(exp_val, sum_val) + + # Materialize result back to output memref + bufferization.materialize_in_destination( + None, result, output, restrict=True, writable=True + ) + + # Emit utility functions for GPU memory management + emit_gpu_util_funcs(dtype, rank=2) + + return mod diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py new file mode 100644 index 00000000..e01f9af6 --- /dev/null +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -0,0 +1,195 @@ +"""Generate MLIR transform schedule for XeGPU softmax operation.""" + +from typing import Optional + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured, loop, xegpu +from mlir.dialects.transform import bufferization as transform_bufferization +from mlir.dialects.bufferization import LayoutMapOption + +from lighthouse.pipeline.helper import ( + apply_registered_pass, + canonicalize, + match, +) + + +def match_and_split(*args, nhandles=1, **kwargs): + """Helper function that splits matched handles.""" + matched = match(*args, **kwargs) + anytype = transform.AnyOpType.get() + matched_ops = transform.split_handle((anytype,) * nhandles, matched) + if nhandles == 1: + matched_ops = [matched_ops] + return matched_ops + + +def get_softmax_schedule_module( + stop_at_stage: Optional[str] = None, + parameters: Optional[dict] = None, +) -> ir.Module: + """ + Generate transform schedule for softmax operation. + + The schedule performs the following transformations: + 1. Tile the consumer operation (division) using forall + 2. Fuse producer operations into the forall loop + 3. Vectorize operations + 4. Bufferize tensors + 5. Convert to GPU dialect + 6. Lower to XeGPU operations + + Args: + stop_at_stage: Optional stage name to stop early (for debugging) + parameters: Dictionary with scheduling parameters: + - wg_rows: Number of rows per workgroup + - sg_rows: Number of rows per subgroup + - subgroup_size: Size of subgroup + + Returns: + MLIR module containing the transform schedule + """ + assert parameters is not None, "Schedule parameters must be provided" + + mod = ir.Module.create() + mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + + with ir.InsertionPoint(mod.body): + # Create a transform sequence with proper signature + named_sequence = transform.named_sequence( + "__transform_main", + [transform.AnyOpType.get()], # input: module + [], # no outputs + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}] + ) + + with ir.InsertionPoint(named_sequence.body): + # match the payload module + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + + # Match all linalg.generic and linalg.fill operations + # We have 7 operations in softmax: + # fill(max_init), max, sub, exp, fill(sum_init), sum, div + generic_ops = structured.structured_match( + transform.AnyOpType.get(), + payload_mod, + ops=["linalg.generic", "linalg.fill"] + ) + + # Split the handle into individual operation handles + anytype = transform.AnyOpType.get() + split_ops = transform.split_handle( + (anytype, anytype, anytype, anytype, anytype, anytype, anytype), # 7 result types + generic_ops + ) + + # Reverse split_ops to have operations in reverse order + split_ops = list(reversed(split_ops)) + + # The first operation (after reversal) is the division - this is the consumer + last_op = split_ops[0] + + # Tile the last operation using tile_using_forall + tiled_op, for_op = structured.structured_tile_using_forall( + anytype, anytype, + last_op, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(parameters["wg_rows"],), + ) + + # Fuse the producer operations into the forall loop + # Iterate through remaining operations (already in reverse order) + current_forall = for_op + for producer_op in split_ops[1:]: + fused_op, current_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, + producer_op, + current_forall + ) + + func = transform.get_parent_op( + anytype, + current_forall, + op_name="func.func", + deduplicate=True, + ) + transform.apply_cse(func) + canonicalize(func) + + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + transform.apply_cse(func) + canonicalize(func) + payload_mod = apply_registered_pass(payload_mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + payload_mod = transform_bufferization.OneShotBufferizeOp( + payload_mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + # fold memref.subviews into vector.transfer_read/write ops + payload_mod = apply_registered_pass(payload_mod, "fold-memref-alias-ops") + transform.apply_cse(payload_mod) + canonicalize(payload_mod) + + # convert forall to parallel + wg_loops = match_and_split(payload_mod, ops={"scf.forall"}) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + # convert scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # set the number of threads for the gpu.launch operation + launch_op = match_and_split(func, ops={"gpu.launch"}) + num_threads = parameters["sg_rows"] * parameters["subgroup_size"] + xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + payload_mod = apply_registered_pass(payload_mod, "gpu-kernel-outlining") + transform.apply_cse(payload_mod) + + # set xevm target + payload_mod = apply_registered_pass( + payload_mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # convert vector to xegpu + gpu_mod_ops = match_and_split(payload_mod, ops={"gpu.module"}) + for gpu_mod in gpu_mod_ops: + gpu_func = match(gpu_mod, ops={"gpu.func"}) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + + # Set layout attributes for xegpu.store_nd operations + store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) + xegpu.set_op_layout_attr(store_ops[0], sg_layout=[8, 1], sg_data=[8, 64]) + + payload_mod = apply_registered_pass( + payload_mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} + ) + # Required: yield to end the transform sequence + transform.yield_() + + return mod From 1e63d7de4e88047d2251643e5649a477fd0a7c1b Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 24 Mar 2026 22:15:39 +0000 Subject: [PATCH 11/63] save working version --- lighthouse/pipeline/helper.py | 16 + lighthouse/schedule/xegpu/mlp_schedule.py | 18 +- lighthouse/schedule/xegpu/softmax_schedule.py | 307 +++++++++++------- 3 files changed, 203 insertions(+), 138 deletions(-) diff --git a/lighthouse/pipeline/helper.py b/lighthouse/pipeline/helper.py index 213b45df..9c4820d5 100644 --- a/lighthouse/pipeline/helper.py +++ b/lighthouse/pipeline/helper.py @@ -35,3 +35,19 @@ def cleanup_func(target): func = structured.MatchOp.match_op_names(target, ["func.func"]).result transform.apply_cse(func) canonicalize(func) + + +class PipelineInterrupt(Exception): + """Exception to signal early termination of the transform schedule.""" + + pass + + +def match_and_split(*args, nhandles=1, **kwargs): + """Helper function that splits matched handles.""" + matched = match(*args, **kwargs) + anytype = transform.AnyOpType.get() + matched_ops = transform.split_handle((anytype,) * nhandles, matched) + if nhandles == 1: + matched_ops = [matched_ops] + return matched_ops diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index e9fa7909..94cbdf01 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -11,6 +11,8 @@ apply_registered_pass, canonicalize, match, + match_and_split, + PipelineInterrupt, ) from lighthouse.dialects import smt_ext, transform_smt_ext as td_smt_ext @@ -33,22 +35,6 @@ MIN_NB_THREADS = 16 -class PipelineInterrupt(Exception): - """Exception to signal early termination of the transform schedule.""" - - pass - - -def match_and_split(*args, nhandles=1, **kwargs): - """Helper function that splits matched handles.""" - matched = match(*args, **kwargs) - anytype = transform.AnyOpType.get() - matched_ops = transform.split_handle((anytype,) * nhandles, matched) - if nhandles == 1: - matched_ops = [matched_ops] - return matched_ops - - @KnobValue.ast_rewrite(in_exprs=True) def params_with_constraints_imposed( params: dict[str, int | None], knob_name_prefix="" diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index e01f9af6..98acd4e6 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -12,19 +12,11 @@ apply_registered_pass, canonicalize, match, + match_and_split, + PipelineInterrupt, ) -def match_and_split(*args, nhandles=1, **kwargs): - """Helper function that splits matched handles.""" - matched = match(*args, **kwargs) - anytype = transform.AnyOpType.get() - matched_ops = transform.split_handle((anytype,) * nhandles, matched) - if nhandles == 1: - matched_ops = [matched_ops] - return matched_ops - - def get_softmax_schedule_module( stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None, @@ -75,121 +67,192 @@ def get_softmax_schedule_module( deduplicate=True, ) - # Match all linalg.generic and linalg.fill operations - # We have 7 operations in softmax: - # fill(max_init), max, sub, exp, fill(sum_init), sum, div - generic_ops = structured.structured_match( - transform.AnyOpType.get(), + xegpu_softmax_transform_schedule( payload_mod, - ops=["linalg.generic", "linalg.fill"] - ) - - # Split the handle into individual operation handles - anytype = transform.AnyOpType.get() - split_ops = transform.split_handle( - (anytype, anytype, anytype, anytype, anytype, anytype, anytype), # 7 result types - generic_ops - ) - - # Reverse split_ops to have operations in reverse order - split_ops = list(reversed(split_ops)) - - # The first operation (after reversal) is the division - this is the consumer - last_op = split_ops[0] - - # Tile the last operation using tile_using_forall - tiled_op, for_op = structured.structured_tile_using_forall( - anytype, anytype, - last_op, - num_threads=[], - tile_sizes=[], - static_tile_sizes=(parameters["wg_rows"],), + parameters=parameters, + stop_at_stage=stop_at_stage or "", ) + + return mod - # Fuse the producer operations into the forall loop - # Iterate through remaining operations (already in reverse order) - current_forall = for_op - for producer_op in split_ops[1:]: - fused_op, current_forall = structured.structured_fuse_into_containing_op( - anytype, anytype, - producer_op, - current_forall - ) - - func = transform.get_parent_op( - anytype, - current_forall, - op_name="func.func", - deduplicate=True, - ) - transform.apply_cse(func) - canonicalize(func) - - func = structured.VectorizeChildrenAndApplyPatternsOp( - func, - fold_type_extensions_into_contract=True, - ).result - transform.apply_cse(func) - canonicalize(func) - payload_mod = apply_registered_pass(payload_mod, "eliminate-empty-tensors") - identity_layout = LayoutMapOption.IdentityLayoutMap - payload_mod = transform_bufferization.OneShotBufferizeOp( - payload_mod, - allow_return_allocs_from_loops=True, - bufferize_function_boundaries=True, - function_boundary_type_conversion=identity_layout, - ).result - # fold memref.subviews into vector.transfer_read/write ops - payload_mod = apply_registered_pass(payload_mod, "fold-memref-alias-ops") - transform.apply_cse(payload_mod) - canonicalize(payload_mod) - - # convert forall to parallel - wg_loops = match_and_split(payload_mod, ops={"scf.forall"}) - for wg_loop in wg_loops: - wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) - func = transform.get_parent_op(anytype, wg_loop) - # convert scf.parallel to gpu.launch - func = apply_registered_pass(func, "gpu-map-parallel-loops") - func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") - func = apply_registered_pass(func, "lower-affine") - transform.apply_cse(func) - canonicalize(func) - - # set the number of threads for the gpu.launch operation - launch_op = match_and_split(func, ops={"gpu.launch"}) - num_threads = parameters["sg_rows"] * parameters["subgroup_size"] - xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) - - # outline gpu func - func = apply_registered_pass(func, "lower-affine") - canonicalize(func) - func = apply_registered_pass(func, "gpu-launch-sink-index-computations") - payload_mod = apply_registered_pass(payload_mod, "gpu-kernel-outlining") - transform.apply_cse(payload_mod) - - # set xevm target - payload_mod = apply_registered_pass( - payload_mod, - "xevm-attach-target", - options={"O": "3", "chip": "bmg"}, - ) - # convert vector to xegpu - gpu_mod_ops = match_and_split(payload_mod, ops={"gpu.module"}) - for gpu_mod in gpu_mod_ops: - gpu_func = match(gpu_mod, ops={"gpu.func"}) - gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") - transform.apply_cse(gpu_func) - - # Set layout attributes for xegpu.store_nd operations - store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) - xegpu.set_op_layout_attr(store_ops[0], sg_layout=[8, 1], sg_data=[8, 64]) - - payload_mod = apply_registered_pass( - payload_mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} - ) - # Required: yield to end the transform sequence - transform.yield_() +def xegpu_softmax_transform_schedule( + mod: ir.Value[transform.AnyOpType], + parameters: dict, + stop_at_stage: str = "", +): + """Transform schedule for softmax payload.""" + try: + mod = bundle_xegpu_softmax_schedule( + mod, + parameters=parameters, + stop_at_stage=stop_at_stage, + ) + + mod = bundle_xegpu_to_binary( + mod, + stop_at_stage=stop_at_stage, + ) + except PipelineInterrupt: + pass + finally: + transform.yield_() + + +def bundle_xegpu_softmax_schedule( + mod: ir.Value[transform.AnyOpType], + parameters: dict, + stop_at_stage: str = "", +) -> ir.Value[transform.AnyOpType]: + """Schedule for lowering softmax payload to xegpu wg level.""" + if stop_at_stage == "initial": + raise PipelineInterrupt() + + anytype = transform.AnyOpType.get() + + # Match all linalg.generic and linalg.fill operations + # We have 7 operations in softmax: + # fill(max_init), max, sub, exp, fill(sum_init), sum, div + generic_ops = structured.structured_match( + transform.AnyOpType.get(), + mod, + ops=["linalg.generic", "linalg.fill"] + ) + + # Split the handle into individual operation handles + split_ops = transform.split_handle( + (anytype, anytype, anytype, anytype, anytype, anytype, anytype), # 7 result types + generic_ops + ) + + # Reverse split_ops to have operations in reverse order + split_ops = list(reversed(split_ops)) + + # The first operation (after reversal) is the division - this is the consumer + last_op = split_ops[0] + + # Tile the last operation using tile_using_forall + tiled_op, for_op = structured.structured_tile_using_forall( + anytype, anytype, + last_op, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(parameters["wg_rows"],), + ) + + # Fuse the producer operations into the forall loop + # Iterate through remaining operations (already in reverse order) + current_forall = for_op + for producer_op in split_ops[1:]: + fused_op, current_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, + producer_op, + current_forall + ) + + func = transform.get_parent_op( + anytype, + current_forall, + op_name="func.func", + deduplicate=True, + ) + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "tiled": + raise PipelineInterrupt() + + # vectorize + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + # bufferize + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + mod = transform_bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + # fold memref.subviews into vector.transfer_read/write ops + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + + # convert forall to parallel + wg_loops = match_and_split(mod, ops={"scf.forall"}) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # set the number of threads for the gpu.launch operation + launch_op = match_and_split(func, ops={"gpu.launch"}) + num_threads = parameters["sg_rows"] * parameters["subgroup_size"] + xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + + # set xevm target + mod = apply_registered_pass( + mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # convert vector to xegpu + gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) + for gpu_mod in gpu_mod_ops: + gpu_func = match(gpu_mod, ops={"gpu.func"}) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + + if stop_at_stage == "xegpu-initial": + raise PipelineInterrupt() + + # Set layout attributes for xegpu.store_nd operations + store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) + xegpu.set_op_layout_attr(store_ops[0], sg_layout=[8, 1], sg_data=[8, 64]) + + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + + return mod + + +def bundle_xegpu_to_binary( + mod: ir.Value, stop_at_stage: str = "" +) -> ir.Value[transform.AnyOpType]: + """Schedule for lowering xegpu wg level to binary.""" + # upstream xegpu/xevm pipeline is payload independent. + mod = apply_registered_pass( + mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} + ) + + if stop_at_stage == "final": + raise PipelineInterrupt() + return mod From 64b5d73f52813059ad51293e028c0c1f2cabd459 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 24 Mar 2026 22:47:05 +0000 Subject: [PATCH 12/63] save working version --- examples/xegpu/softmax.py | 1 + lighthouse/schedule/xegpu/softmax_schedule.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index 474eb1a3..c31af47a 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -260,6 +260,7 @@ def parse_cli(): args = parse_cli() params = { + "sizes": args.sizes, "wg_rows": args.wg_rows, "sg_rows": args.sg_rows, "subgroup_size": args.subgroup_size, diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 98acd4e6..26bc5f3e 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -206,7 +206,8 @@ def bundle_xegpu_softmax_schedule( # set the number of threads for the gpu.launch operation launch_op = match_and_split(func, ops={"gpu.launch"}) - num_threads = parameters["sg_rows"] * parameters["subgroup_size"] + num_subgroups = parameters["wg_rows"] // parameters["sg_rows"] + num_threads = num_subgroups * parameters["subgroup_size"] xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) # outline gpu func @@ -233,9 +234,12 @@ def bundle_xegpu_softmax_schedule( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - # Set layout attributes for xegpu.store_nd operations + # Set layout attributes for xegpu.store_nd operations. + # FIXME: currently ecah subgroup is handling the entire row. store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) - xegpu.set_op_layout_attr(store_ops[0], sg_layout=[8, 1], sg_data=[8, 64]) + sg_layout = [parameters["sg_rows"], 1] + sg_data = [parameters["sg_rows"], parameters["sizes"][1]] + xegpu.set_op_layout_attr(store_ops[0], sg_layout=sg_layout, sg_data=sg_data) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From 108f2c09e8ee111b6b96f97553bc9bacead0013e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 25 Mar 2026 20:48:48 +0000 Subject: [PATCH 13/63] save working version --- lighthouse/schedule/xegpu/mlp_schedule.py | 13 +------------ lighthouse/schedule/xegpu/softmax_schedule.py | 17 ++--------------- 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index 94cbdf01..ab11a75c 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -14,6 +14,7 @@ match_and_split, PipelineInterrupt, ) +from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary from lighthouse.dialects import smt_ext, transform_smt_ext as td_smt_ext from lighthouse.dialects.transform_tune_ext import knob, KnobValue @@ -600,15 +601,3 @@ def annotate_ab_load(tile, layout_load, layout_dpas): canonicalize(gpu_func) transform.apply_cse(gpu_func) - - -def bundle_xegpu_to_binary( - mod: ir.Value, stop_at_stage: str = "" -) -> ir.Value[transform.AnyOpType]: - """Schedule for lowering xegpu wg level to binary.""" - # upstream xegpu/xevm pipeline is payload independent. - mod = apply_registered_pass( - mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} - ) - - return mod diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 26bc5f3e..9e5226b9 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -15,6 +15,7 @@ match_and_split, PipelineInterrupt, ) +from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary def get_softmax_schedule_module( @@ -38,6 +39,7 @@ def get_softmax_schedule_module( - wg_rows: Number of rows per workgroup - sg_rows: Number of rows per subgroup - subgroup_size: Size of subgroup + - sizes: Tuple with the sizes of the input tensors (e.g. (M, N)) Returns: MLIR module containing the transform schedule @@ -245,18 +247,3 @@ def bundle_xegpu_softmax_schedule( raise PipelineInterrupt() return mod - - -def bundle_xegpu_to_binary( - mod: ir.Value, stop_at_stage: str = "" -) -> ir.Value[transform.AnyOpType]: - """Schedule for lowering xegpu wg level to binary.""" - # upstream xegpu/xevm pipeline is payload independent. - mod = apply_registered_pass( - mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} - ) - - if stop_at_stage == "final": - raise PipelineInterrupt() - - return mod From a7e1e6c7535c5dd085c14ebea39ae13ac25eeb69 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 25 Mar 2026 20:58:29 +0000 Subject: [PATCH 14/63] save working version --- lighthouse/schedule/xegpu/helper.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 lighthouse/schedule/xegpu/helper.py diff --git a/lighthouse/schedule/xegpu/helper.py b/lighthouse/schedule/xegpu/helper.py new file mode 100644 index 00000000..1452301e --- /dev/null +++ b/lighthouse/schedule/xegpu/helper.py @@ -0,0 +1,21 @@ +"""Helper functions for XeGPU scheduling.""" + +from mlir import ir +from mlir.dialects import transform + +from lighthouse.pipeline.helper import apply_registered_pass, PipelineInterrupt + + +def bundle_xegpu_to_binary( + mod: ir.Value, stop_at_stage: str = "" +) -> ir.Value[transform.AnyOpType]: + """Schedule for lowering xegpu wg level to binary.""" + # upstream xegpu/xevm pipeline is payload independent. + mod = apply_registered_pass( + mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} + ) + + if stop_at_stage == "final": + raise PipelineInterrupt() + + return mod From df53caa54ff2c066393f3fca9479203cc0ef471a Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 25 Mar 2026 21:09:29 +0000 Subject: [PATCH 15/63] precommit issues --- examples/xegpu/softmax.py | 12 +-- .../ingress/mlir_gen/gpu_softmax_payload.py | 68 +++++++++------ lighthouse/schedule/xegpu/helper.py | 2 +- lighthouse/schedule/xegpu/softmax_schedule.py | 87 ++++++++++--------- 4 files changed, 97 insertions(+), 72 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index c31af47a..e3cf5840 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -28,7 +28,7 @@ def softmax_complexity(M: int, N: int, nbytes: int): """ Complexity of softmax operation. - + For each row: - O(N) to find max - O(N) to compute exp(x - max) and sum @@ -127,7 +127,7 @@ def check_correctness( output_ref = self._reference_solution output_computed = output_host.astype(np.float32) - + if verbose > 1: print("Reference solution (first 5 rows):") print(output_ref[:5]) @@ -137,10 +137,10 @@ def check_correctness( # Check row sums are close to 1.0 row_sums = np.sum(output_computed, axis=1) sums_ok = np.allclose(row_sums, 1.0, rtol=1e-5, atol=1e-6) - + # Check values match reference values_ok = np.allclose(output_computed, output_ref, rtol=1e-4, atol=1e-6) - + success = sums_ok and values_ok if verbose: @@ -149,7 +149,9 @@ def check_correctness( else: print("FAILED!") if not sums_ok: - print(f" Row sums check failed. Min: {row_sums.min():.6f}, Max: {row_sums.max():.6f}") + print( + f" Row sums check failed. Min: {row_sums.min():.6f}, Max: {row_sums.max():.6f}" + ) if not values_ok: max_diff = np.abs(output_computed - output_ref).max() print(f" Values mismatch. Max abs diff: {max_diff:.6e}") diff --git a/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py b/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py index 26fd9042..4568448e 100644 --- a/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py @@ -1,10 +1,13 @@ """Generate MLIR payload for GPU softmax operation.""" from mlir import ir -from mlir.dialects import linalg, bufferization, arith, tensor, func, math +from mlir.dialects import linalg, bufferization, arith, tensor, math from lighthouse.utils.mlir import func_cif -from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs, emit_buf_to_tensor +from lighthouse.ingress.mlir_gen.gpu_utils import ( + emit_gpu_util_funcs, + emit_buf_to_tensor, +) def generate_gpu_softmax_payload( @@ -15,33 +18,33 @@ def generate_gpu_softmax_payload( ) -> ir.Module: """ Generate MLIR module for softmax payload. - + Computes softmax along the last dimension (rows): output[i, j] = exp(input[i, j] - max_i) / sum_i(exp(input[i, j] - max_i)) - + where max_i and sum_i are computed over row i. - + Args: func_name: Name of the payload function M: Number of rows - N: Number of columns + N: Number of columns dtype: MLIR element type (e.g., F32Type) - + Returns: MLIR module containing the softmax payload function """ mod = ir.Module.create() shape = (M, N) memref_t = ir.MemRefType.get(shape, dtype) - + with ir.InsertionPoint(mod.body): # Function signature: payload(output, input) @func_cif(memref_t, memref_t, name=func_name) def payload(output, input_arg): # Convert memrefs to tensors - output_tensor = emit_buf_to_tensor(output, restrict=True, writable=True) + emit_buf_to_tensor(output, restrict=True, writable=True) input_tensor = emit_buf_to_tensor(input_arg, restrict=True) - + # Define affine maps for indexing # #map = affine_map<(d0, d1) -> (d0, d1)> (identity 2D) # #map1 = affine_map<(d0, d1) -> (d0)> (broadcast/reduce along d1) @@ -49,67 +52,82 @@ def payload(output, input_arg): d1 = ir.AffineDimExpr.get(1) map_2d = ir.AffineMap.get(2, 0, [d0, d1]) map_1d = ir.AffineMap.get(2, 0, [d0]) - + # Step 1: Find max - linalg.generic reduction - neg_inf = arith.constant(dtype, float('-inf')) + neg_inf = arith.constant(dtype, float("-inf")) max_init = tensor.empty((M,), dtype) max_filled = linalg.fill(neg_inf, outs=[max_init]) - + @linalg.generic( [input_tensor], # inputs [max_filled], # outputs [map_2d, map_1d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.reduction], # iterator_types + [ + linalg.IteratorType.parallel, + linalg.IteratorType.reduction, + ], # iterator_types ) def row_max(in_val, acc): return arith.maximumf(in_val, acc) - + # Step 2: Subtract max (broadcast) - linalg.generic elementwise output_init = tensor.empty((M, N), dtype) - + @linalg.generic( [input_tensor, row_max], # inputs [output_init], # outputs [map_2d, map_1d, map_2d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types + [ + linalg.IteratorType.parallel, + linalg.IteratorType.parallel, + ], # iterator_types ) def shifted(in_val, max_val, out): return arith.subf(in_val, max_val) - + # Step 3: Compute exp - linalg.generic elementwise @linalg.generic( [shifted], # inputs [output_init], # outputs [map_2d, map_2d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types + [ + linalg.IteratorType.parallel, + linalg.IteratorType.parallel, + ], # iterator_types ) def exp_vals(in_val, out): return math.exp(in_val) - + # Step 4: Sum exp values - linalg.generic reduction sum_init = tensor.empty((M,), dtype) zero = arith.constant(dtype, 0.0) sum_filled = linalg.fill(zero, outs=[sum_init]) - + @linalg.generic( [exp_vals], # inputs [sum_filled], # outputs [map_2d, map_1d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.reduction], # iterator_types + [ + linalg.IteratorType.parallel, + linalg.IteratorType.reduction, + ], # iterator_types ) def row_sum(in_val, acc): return arith.addf(in_val, acc) - + # Step 5: Divide by sum (broadcast) - linalg.generic elementwise @linalg.generic( [exp_vals, row_sum], # inputs [output_init], # outputs [map_2d, map_1d, map_2d], # indexing_maps - [linalg.IteratorType.parallel, linalg.IteratorType.parallel], # iterator_types + [ + linalg.IteratorType.parallel, + linalg.IteratorType.parallel, + ], # iterator_types ) def result(exp_val, sum_val, out): return arith.divf(exp_val, sum_val) - + # Materialize result back to output memref bufferization.materialize_in_destination( None, result, output, restrict=True, writable=True diff --git a/lighthouse/schedule/xegpu/helper.py b/lighthouse/schedule/xegpu/helper.py index 1452301e..0c1d93a7 100644 --- a/lighthouse/schedule/xegpu/helper.py +++ b/lighthouse/schedule/xegpu/helper.py @@ -14,7 +14,7 @@ def bundle_xegpu_to_binary( mod = apply_registered_pass( mod, "gpu-lower-to-xevm-pipeline", options={"xegpu-op-level": "workgroup"} ) - + if stop_at_stage == "final": raise PipelineInterrupt() diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 9e5226b9..d9701b84 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -24,7 +24,7 @@ def get_softmax_schedule_module( ) -> ir.Module: """ Generate transform schedule for softmax operation. - + The schedule performs the following transformations: 1. Tile the consumer operation (division) using forall 2. Fuse producer operations into the forall loop @@ -32,32 +32,32 @@ def get_softmax_schedule_module( 4. Bufferize tensors 5. Convert to GPU dialect 6. Lower to XeGPU operations - + Args: stop_at_stage: Optional stage name to stop early (for debugging) parameters: Dictionary with scheduling parameters: - wg_rows: Number of rows per workgroup - - sg_rows: Number of rows per subgroup + - sg_rows: Number of rows per subgroup - subgroup_size: Size of subgroup - sizes: Tuple with the sizes of the input tensors (e.g. (M, N)) - + Returns: MLIR module containing the transform schedule """ assert parameters is not None, "Schedule parameters must be provided" - + mod = ir.Module.create() mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() - + with ir.InsertionPoint(mod.body): # Create a transform sequence with proper signature named_sequence = transform.named_sequence( "__transform_main", [transform.AnyOpType.get()], # input: module [], # no outputs - arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}] + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], ) - + with ir.InsertionPoint(named_sequence.body): # match the payload module anytype = transform.AnyOpType.get() @@ -68,13 +68,13 @@ def get_softmax_schedule_module( op_name="builtin.module", deduplicate=True, ) - + xegpu_softmax_transform_schedule( payload_mod, parameters=parameters, stop_at_stage=stop_at_stage or "", ) - + return mod @@ -107,36 +107,43 @@ def bundle_xegpu_softmax_schedule( stop_at_stage: str = "", ) -> ir.Value[transform.AnyOpType]: """Schedule for lowering softmax payload to xegpu wg level.""" - + if stop_at_stage == "initial": raise PipelineInterrupt() - + anytype = transform.AnyOpType.get() - + # Match all linalg.generic and linalg.fill operations - # We have 7 operations in softmax: + # We have 7 operations in softmax: # fill(max_init), max, sub, exp, fill(sum_init), sum, div generic_ops = structured.structured_match( - transform.AnyOpType.get(), - mod, - ops=["linalg.generic", "linalg.fill"] + transform.AnyOpType.get(), mod, ops=["linalg.generic", "linalg.fill"] ) - + # Split the handle into individual operation handles split_ops = transform.split_handle( - (anytype, anytype, anytype, anytype, anytype, anytype, anytype), # 7 result types - generic_ops + ( + anytype, + anytype, + anytype, + anytype, + anytype, + anytype, + anytype, + ), # 7 result types + generic_ops, ) - + # Reverse split_ops to have operations in reverse order split_ops = list(reversed(split_ops)) - + # The first operation (after reversal) is the division - this is the consumer last_op = split_ops[0] # Tile the last operation using tile_using_forall tiled_op, for_op = structured.structured_tile_using_forall( - anytype, anytype, + anytype, + anytype, last_op, num_threads=[], tile_sizes=[], @@ -148,11 +155,9 @@ def bundle_xegpu_softmax_schedule( current_forall = for_op for producer_op in split_ops[1:]: fused_op, current_forall = structured.structured_fuse_into_containing_op( - anytype, anytype, - producer_op, - current_forall + anytype, anytype, producer_op, current_forall ) - + func = transform.get_parent_op( anytype, current_forall, @@ -161,10 +166,10 @@ def bundle_xegpu_softmax_schedule( ) transform.apply_cse(func) canonicalize(func) - + if stop_at_stage == "tiled": raise PipelineInterrupt() - + # vectorize func = structured.VectorizeChildrenAndApplyPatternsOp( func, @@ -172,10 +177,10 @@ def bundle_xegpu_softmax_schedule( ).result transform.apply_cse(func) canonicalize(func) - + if stop_at_stage == "vectorized": raise PipelineInterrupt() - + # bufferize mod = apply_registered_pass(mod, "eliminate-empty-tensors") identity_layout = LayoutMapOption.IdentityLayoutMap @@ -189,36 +194,36 @@ def bundle_xegpu_softmax_schedule( mod = apply_registered_pass(mod, "fold-memref-alias-ops") transform.apply_cse(mod) canonicalize(mod) - + if stop_at_stage == "bufferized": raise PipelineInterrupt() - + # convert forall to parallel wg_loops = match_and_split(mod, ops={"scf.forall"}) for wg_loop in wg_loops: wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) func = transform.get_parent_op(anytype, wg_loop) - + # convert scf.parallel to gpu.launch func = apply_registered_pass(func, "gpu-map-parallel-loops") func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") func = apply_registered_pass(func, "lower-affine") transform.apply_cse(func) canonicalize(func) - + # set the number of threads for the gpu.launch operation launch_op = match_and_split(func, ops={"gpu.launch"}) num_subgroups = parameters["wg_rows"] // parameters["sg_rows"] num_threads = num_subgroups * parameters["subgroup_size"] xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) - + # outline gpu func func = apply_registered_pass(func, "lower-affine") canonicalize(func) func = apply_registered_pass(func, "gpu-launch-sink-index-computations") mod = apply_registered_pass(mod, "gpu-kernel-outlining") transform.apply_cse(mod) - + # set xevm target mod = apply_registered_pass( mod, @@ -232,18 +237,18 @@ def bundle_xegpu_softmax_schedule( gpu_func = match(gpu_mod, ops={"gpu.func"}) gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") transform.apply_cse(gpu_func) - + if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - + # Set layout attributes for xegpu.store_nd operations. # FIXME: currently ecah subgroup is handling the entire row. store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) sg_layout = [parameters["sg_rows"], 1] sg_data = [parameters["sg_rows"], parameters["sizes"][1]] xegpu.set_op_layout_attr(store_ops[0], sg_layout=sg_layout, sg_data=sg_data) - + if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() - + return mod From 9bcc6538adc23d8cb2f9186bfb24cb581919c68d Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 27 Mar 2026 22:36:53 +0000 Subject: [PATCH 16/63] use linalg.softmax --- .../ingress/mlir_gen/gpu_softmax_payload.py | 90 ++----------------- lighthouse/schedule/xegpu/softmax_schedule.py | 60 ++++--------- 2 files changed, 27 insertions(+), 123 deletions(-) diff --git a/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py b/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py index 4568448e..05dc148a 100644 --- a/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_softmax_payload.py @@ -1,7 +1,7 @@ """Generate MLIR payload for GPU softmax operation.""" from mlir import ir -from mlir.dialects import linalg, bufferization, arith, tensor, math +from mlir.dialects import linalg, bufferization, tensor from lighthouse.utils.mlir import func_cif from lighthouse.ingress.mlir_gen.gpu_utils import ( @@ -45,88 +45,16 @@ def payload(output, input_arg): emit_buf_to_tensor(output, restrict=True, writable=True) input_tensor = emit_buf_to_tensor(input_arg, restrict=True) - # Define affine maps for indexing - # #map = affine_map<(d0, d1) -> (d0, d1)> (identity 2D) - # #map1 = affine_map<(d0, d1) -> (d0)> (broadcast/reduce along d1) - d0 = ir.AffineDimExpr.get(0) - d1 = ir.AffineDimExpr.get(1) - map_2d = ir.AffineMap.get(2, 0, [d0, d1]) - map_1d = ir.AffineMap.get(2, 0, [d0]) + # Create output tensor and fill with zeros + output_init = tensor.empty(shape, dtype) - # Step 1: Find max - linalg.generic reduction - neg_inf = arith.constant(dtype, float("-inf")) - max_init = tensor.empty((M,), dtype) - max_filled = linalg.fill(neg_inf, outs=[max_init]) - - @linalg.generic( - [input_tensor], # inputs - [max_filled], # outputs - [map_2d, map_1d], # indexing_maps - [ - linalg.IteratorType.parallel, - linalg.IteratorType.reduction, - ], # iterator_types - ) - def row_max(in_val, acc): - return arith.maximumf(in_val, acc) - - # Step 2: Subtract max (broadcast) - linalg.generic elementwise - output_init = tensor.empty((M, N), dtype) - - @linalg.generic( - [input_tensor, row_max], # inputs - [output_init], # outputs - [map_2d, map_1d, map_2d], # indexing_maps - [ - linalg.IteratorType.parallel, - linalg.IteratorType.parallel, - ], # iterator_types - ) - def shifted(in_val, max_val, out): - return arith.subf(in_val, max_val) - - # Step 3: Compute exp - linalg.generic elementwise - @linalg.generic( - [shifted], # inputs - [output_init], # outputs - [map_2d, map_2d], # indexing_maps - [ - linalg.IteratorType.parallel, - linalg.IteratorType.parallel, - ], # iterator_types - ) - def exp_vals(in_val, out): - return math.exp(in_val) - - # Step 4: Sum exp values - linalg.generic reduction - sum_init = tensor.empty((M,), dtype) - zero = arith.constant(dtype, 0.0) - sum_filled = linalg.fill(zero, outs=[sum_init]) - - @linalg.generic( - [exp_vals], # inputs - [sum_filled], # outputs - [map_2d, map_1d], # indexing_maps - [ - linalg.IteratorType.parallel, - linalg.IteratorType.reduction, - ], # iterator_types - ) - def row_sum(in_val, acc): - return arith.addf(in_val, acc) - - # Step 5: Divide by sum (broadcast) - linalg.generic elementwise - @linalg.generic( - [exp_vals, row_sum], # inputs - [output_init], # outputs - [map_2d, map_1d, map_2d], # indexing_maps - [ - linalg.IteratorType.parallel, - linalg.IteratorType.parallel, - ], # iterator_types + # Apply softmax along dimension 1 (last dimension) + result = linalg.softmax( + result=[ir.RankedTensorType.get(shape, dtype)], + input=input_tensor, + output=output_init, + dimension=1, ) - def result(exp_val, sum_val, out): - return arith.divf(exp_val, sum_val) # Materialize result back to output memref bufferization.materialize_in_destination( diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index d9701b84..35c3ab4c 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -26,12 +26,11 @@ def get_softmax_schedule_module( Generate transform schedule for softmax operation. The schedule performs the following transformations: - 1. Tile the consumer operation (division) using forall - 2. Fuse producer operations into the forall loop - 3. Vectorize operations - 4. Bufferize tensors - 5. Convert to GPU dialect - 6. Lower to XeGPU operations + 1. Tile the linalg.softmax operation using forall + 2. Vectorize operations + 3. Bufferize tensors + 4. Convert to GPU dialect + 5. Lower to XeGPU operations Args: stop_at_stage: Optional stage name to stop early (for debugging) @@ -113,57 +112,34 @@ def bundle_xegpu_softmax_schedule( anytype = transform.AnyOpType.get() - # Match all linalg.generic and linalg.fill operations - # We have 7 operations in softmax: - # fill(max_init), max, sub, exp, fill(sum_init), sum, div - generic_ops = structured.structured_match( - transform.AnyOpType.get(), mod, ops=["linalg.generic", "linalg.fill"] + # Match linalg.softmax operation + # We have only 1 operation: linalg.softmax + softmax_op = structured.structured_match( + transform.AnyOpType.get(), mod, ops=["linalg.softmax"] ) - # Split the handle into individual operation handles - split_ops = transform.split_handle( - ( - anytype, - anytype, - anytype, - anytype, - anytype, - anytype, - anytype, - ), # 7 result types - generic_ops, - ) - - # Reverse split_ops to have operations in reverse order - split_ops = list(reversed(split_ops)) - - # The first operation (after reversal) is the division - this is the consumer - last_op = split_ops[0] - - # Tile the last operation using tile_using_forall + # Tile the softmax operation using tile_using_forall tiled_op, for_op = structured.structured_tile_using_forall( anytype, anytype, - last_op, + softmax_op, num_threads=[], tile_sizes=[], static_tile_sizes=(parameters["wg_rows"],), ) - # Fuse the producer operations into the forall loop - # Iterate through remaining operations (already in reverse order) - current_forall = for_op - for producer_op in split_ops[1:]: - fused_op, current_forall = structured.structured_fuse_into_containing_op( - anytype, anytype, producer_op, current_forall - ) - func = transform.get_parent_op( anytype, - current_forall, + for_op, op_name="func.func", deduplicate=True, ) + # Decompose softmax into linalg.generic operations + softmax_ops = structured.structured_match( + transform.AnyOpType.get(), func, ops=["linalg.softmax"] + ) + structured.structured_decompose_interface(anytype, softmax_ops) + transform.apply_cse(func) canonicalize(func) From 3f5cbceacfca0cf15ea47bab5e29e1b62de48b4e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 30 Mar 2026 19:28:26 +0000 Subject: [PATCH 17/63] save work --- examples/xegpu/softmax.py | 17 +++++++++-- lighthouse/schedule/xegpu/softmax_schedule.py | 29 +++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index e3cf5840..27a58ec2 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -217,6 +217,12 @@ def parse_cli(): default=16, help="Subgroup size.", ) + parser.add_argument( + "--reduction-step-size", + type=int, + default=16, + help="Step size for reduction loop tiling (optional).", + ) parser.add_argument( "--nruns", type=int, @@ -266,6 +272,7 @@ def parse_cli(): "wg_rows": args.wg_rows, "sg_rows": args.sg_rows, "subgroup_size": args.subgroup_size, + "reduction_step_size": args.reduction_step_size, } M, N = args.sizes @@ -304,7 +311,13 @@ def list2str(a): f"wg-rows={args.wg_rows}", f"sg-rows={args.sg_rows}", f"subgroup-size={args.subgroup_size}", - f"time(us): {elapsed:.2f}", - f"GFLOPS: {gflops:.2f}", ] + if args.reduction_step_size is not None: + parts.append(f"reduction-step-size={args.reduction_step_size}") + parts.extend( + [ + f"time(us): {elapsed:.2f}", + f"GFLOPS: {gflops:.2f}", + ] + ) print(" ".join(parts)) diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 35c3ab4c..11410612 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -39,6 +39,7 @@ def get_softmax_schedule_module( - sg_rows: Number of rows per subgroup - subgroup_size: Size of subgroup - sizes: Tuple with the sizes of the input tensors (e.g. (M, N)) + - reduction_step_size: Optional step size for tiling reduction loops Returns: MLIR module containing the transform schedule @@ -140,6 +141,34 @@ def bundle_xegpu_softmax_schedule( ) structured.structured_decompose_interface(anytype, softmax_ops) + linalg_ops = match_and_split( + func, ops={"linalg.generic", "linalg.fill"}, nhandles=6 + ) + init_max_reduction = linalg_ops[0] + max_reduction = linalg_ops[1] + max_center_and_exp_op = linalg_ops[2] + init_sum_reduction = linalg_ops[3] + sum_reduction = linalg_ops[4] + div_op = linalg_ops[5] + + reduction_step_size = parameters["reduction_step_size"] + + # Tile the max reduction using TileReductionUsingFor + _, _, _, for_op = structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=max_reduction, + tile_sizes=[0, reduction_step_size], + ) + + # Fuse the init_max_reduction into the for loop + # fused_init, new_for_loop = structured.structured_fuse_into_containing_op( + # anytype, anytype, init_max_reduction, for_op + # ) + transform.PrintOp(target=init_max_reduction) + transform.apply_cse(func) canonicalize(func) From 6204d6c2a59cb8d816f87b2fc811d96163a613a3 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 30 Mar 2026 22:21:34 +0000 Subject: [PATCH 18/63] add inner dim tiling --- lighthouse/schedule/xegpu/softmax_schedule.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 11410612..9bda77fc 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -144,30 +144,28 @@ def bundle_xegpu_softmax_schedule( linalg_ops = match_and_split( func, ops={"linalg.generic", "linalg.fill"}, nhandles=6 ) - init_max_reduction = linalg_ops[0] max_reduction = linalg_ops[1] max_center_and_exp_op = linalg_ops[2] - init_sum_reduction = linalg_ops[3] sum_reduction = linalg_ops[4] div_op = linalg_ops[5] reduction_step_size = parameters["reduction_step_size"] - # Tile the max reduction using TileReductionUsingFor - _, _, _, for_op = structured.structured_tile_reduction_using_for( - [anytype], - anytype, - anytype, - anytype, - target=max_reduction, - tile_sizes=[0, reduction_step_size], - ) - - # Fuse the init_max_reduction into the for loop - # fused_init, new_for_loop = structured.structured_fuse_into_containing_op( - # anytype, anytype, init_max_reduction, for_op - # ) - transform.PrintOp(target=init_max_reduction) + # Tile all reduction ops using the same step size + reduction_ops = [max_reduction, sum_reduction] + for reduction_op in reduction_ops: + structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=reduction_op, + tile_sizes=[0, reduction_step_size], + ) + # Tile elementwise ops to match the reduction tile size + elementwise_ops = [max_center_and_exp_op, div_op] + for elementwise_op in elementwise_ops: + structured.TileUsingForOp(elementwise_op, sizes=[0, reduction_step_size]) transform.apply_cse(func) canonicalize(func) From 1feb0d48d94e6cc46fb6cbf7c01dc9825e43f73a Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 1 Apr 2026 22:09:26 +0000 Subject: [PATCH 19/63] save fused version --- examples/xegpu/softmax.py | 1 + lighthouse/schedule/xegpu/softmax_schedule.py | 76 +++++++++++++++---- 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index 324d6e0c..8049f0f4 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -212,6 +212,7 @@ def parse_cli(): "tiled", "vectorized", "bufferized", + "gpu-outlining", "xegpu-initial", "xegpu-wg", "final", diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 361dfdfe..b42b6479 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -140,6 +140,7 @@ def bundle_xegpu_softmax_schedule( transform.AnyOpType.get(), func, ops=["linalg.softmax"] ) structured.structured_decompose_interface(anytype, softmax_ops) + # transform.print_(target=func, name="After structured_decompose_interface") linalg_ops = match_and_split( func, ops={"linalg.generic", "linalg.fill"}, nhandles=6 @@ -151,22 +152,60 @@ def bundle_xegpu_softmax_schedule( reduction_step_size = parameters["reduction_step_size"] - # Tile all reduction ops using the same step size - reduction_ops = [max_reduction, sum_reduction] - for reduction_op in reduction_ops: - structured.structured_tile_reduction_using_for( - [anytype], - anytype, - anytype, - anytype, - target=reduction_op, - tile_sizes=[0, reduction_step_size], - ) - # Tile elementwise ops to match the reduction tile size - elementwise_ops = [max_center_and_exp_op, div_op] - for elementwise_op in elementwise_ops: - structured.TileUsingForOp(elementwise_op, sizes=[0, reduction_step_size]) + # Tile the division op and fuse the sub+exp producer into it + _, div_loop = structured.TileUsingForOp( + div_op, sizes=[0, reduction_step_size] + ).results + + # Fuse max_center_and_exp_op into the div loop + _, fused_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=max_center_and_exp_op, + containing_op=div_loop, + ) + + # Tile the sum reduction and fuse the sub+exp producer into it + _, _, _, sum_loop = structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=sum_reduction, + tile_sizes=[0, reduction_step_size], + ) + + func = transform.get_parent_op( + anytype, + fused_loop, + op_name="func.func", + deduplicate=True, + ) + # Re-match and split linalg generic ops, there are 5 at this point + linalg_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=5) + max_center_and_exp_op = linalg_ops[1] + + # Fuse max_center_and_exp_op into the sum reduction loop + _, fused_sum_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=max_center_and_exp_op, + containing_op=sum_loop, + ) + + # Tile the max reduction. + max_reduction = linalg_ops[0] + structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=max_reduction, + tile_sizes=[0, reduction_step_size], + ) + + # Cleanup after tiling and fusion transform.apply_cse(func) canonicalize(func) @@ -227,6 +266,9 @@ def bundle_xegpu_softmax_schedule( mod = apply_registered_pass(mod, "gpu-kernel-outlining") transform.apply_cse(mod) + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + # set xevm target mod = apply_registered_pass( mod, @@ -241,6 +283,10 @@ def bundle_xegpu_softmax_schedule( gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") transform.apply_cse(gpu_func) + # Cleanup. + transform.apply_cse(mod) + canonicalize(mod) + if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() From a28cf4a1035cf12956aa9922ff0fbd3a46e51314 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 1 Apr 2026 22:30:40 +0000 Subject: [PATCH 20/63] save work --- lighthouse/schedule/xegpu/softmax_schedule.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index b42b6479..18186b89 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -237,6 +237,17 @@ def bundle_xegpu_softmax_schedule( transform.apply_cse(mod) canonicalize(mod) + # promote memref.alloc to memref.alloca in payload function + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) + if stop_at_stage == "bufferized": raise PipelineInterrupt() From 79e2f737caae45431f207744187a7119c5504b12 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 3 Apr 2026 19:16:35 +0000 Subject: [PATCH 21/63] save work --- lighthouse/schedule/xegpu/softmax_schedule.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 18186b89..bda5f819 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -303,10 +303,10 @@ def bundle_xegpu_softmax_schedule( # Set layout attributes for xegpu.store_nd operations. # FIXME: currently ecah subgroup is handling the entire row. - store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) + store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=5) sg_layout = [parameters["sg_rows"], 1] - sg_data = [parameters["sg_rows"], parameters["sizes"][1]] - xegpu.set_anchor_layout(store_ops[0], sg_layout=sg_layout, sg_data=sg_data) + sg_data = [parameters["sg_rows"], parameters["reduction_step_size"]] + xegpu.set_anchor_layout(store_ops[-1], sg_layout=sg_layout, sg_data=sg_data) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From 55c175c0f69c511f6d753b0f7c55b5480d5303b3 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 3 Apr 2026 22:24:15 +0000 Subject: [PATCH 22/63] save work --- docs/softmax_lowering.md | 402 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 402 insertions(+) create mode 100644 docs/softmax_lowering.md diff --git a/docs/softmax_lowering.md b/docs/softmax_lowering.md new file mode 100644 index 00000000..f67779ae --- /dev/null +++ b/docs/softmax_lowering.md @@ -0,0 +1,402 @@ +# Linalg softmax lowering in XeGPU pipeline + +## Overview + +The lowering process consists of seven stages: +1. **initial** - High-level tensor operations +2. **tiled-softmax** - Tiled softmax operations +3. **decomposed** - Decomposition into constituent operations +4. **vectorized** - Vector operations +5. **bufferized** - Memory-based representation +6. **xegpu-initial** - GPU kernel with XeGPU operations +7. **xegpu-wg** - Work-group optimized XeGPU + +--- + +## Stage 1: Initial + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + %2 = tensor.empty() : tensor<1024x64xf32> + %3 = linalg.softmax dimension(1) ins(%1 : tensor<1024x64xf32>) + outs(%2 : tensor<1024x64xf32>) -> tensor<1024x64xf32> + // ... + return +} +``` +--- + +## Stage 2: Tiled Softmax + +**Key Characteristics:** +- Work distribution via `scf.forall` (16 parallel iterations) +- Each tile processes 64x64 elements + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { + %4 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2) + // Extract 64x64 input slice + %extracted_slice = tensor.extract_slice ... + // Extract 64x64 output slice + %extracted_slice_0 = tensor.extract_slice ... + // Apply softmax to the tile + %5 = linalg.softmax dimension(1) ins(%extracted_slice : tensor<64x64xf32>) + outs(%extracted_slice_0 : tensor<64x64xf32>) -> tensor<64x64xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %5 into %arg3[%4, %c0] [64, 64] [1, 1] : + tensor<64x64xf32> into tensor<1024x64xf32> + } + } + // ... + return +} +``` + +--- + +## Stage 3: Decomposed + +**Key Characteristics:** +- Softmax decomposed into 4 constituent `linalg.generic` ops : max, sub+exp, sum, divide +- Uses `structured.structured_decompose_interface` implemented by `linalg.softmax` + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 0xFFC00000 : f32 // -inf for max reduction + %0 = bufferization.to_tensor %arg1 restrict : memref<1024x64xf32> to tensor<1024x64xf32> + %1 = tensor.empty() : tensor<1024x64xf32> + + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x64xf32>) { + %3 = affine.apply #map(%arg2) // %3 = %arg2 * 64 + %extracted_slice = tensor.extract_slice %0[%3, 0] [64, 64] [1, 1] : + tensor<1024x64xf32> to tensor<64x64xf32> + + // Step 1: Find max along dimension 1 + %4 = tensor.empty() : tensor<64xf32> + %5 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<64xf32>) -> tensor<64xf32> + %6 = linalg.generic {indexing_maps = [#map1, #map2], + iterator_types = ["parallel", "reduction"]} + ins(%extracted_slice : tensor<64x64xf32>) outs(%5 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %11 = arith.maxnumf %in, %out : f32 + linalg.yield %11 : f32 + } -> tensor<64xf32> + + // Step 2: Subtract max and exponentiate + %7 = linalg.generic {indexing_maps = [#map1, #map2, #map1], + iterator_types = ["parallel", "parallel"]} + ins(%extracted_slice, %6 : tensor<64x64xf32>, tensor<64xf32>) + outs(%extracted_slice_1 : tensor<64x64xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %11 = arith.subf %in, %in_2 : f32 + %12 = math.exp %11 : f32 + linalg.yield %12 : f32 + } -> tensor<64x64xf32> + + // Step 3: Sum exponentials + %8 = linalg.fill ins(%cst : f32) outs(%4 : tensor<64xf32>) -> tensor<64xf32> + %9 = linalg.generic {indexing_maps = [#map1, #map2], + iterator_types = ["parallel", "reduction"]} + ins(%7 : tensor<64x64xf32>) outs(%8 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %11 = arith.addf %in, %out : f32 + linalg.yield %11 : f32 + } -> tensor<64xf32> + + // Step 4: Normalize by sum + %10 = linalg.generic {indexing_maps = [#map1, #map2, #map1], + iterator_types = ["parallel", "parallel"]} + ins(%7, %9 : tensor<64x64xf32>, tensor<64xf32>) + outs(%extracted_slice_1 : tensor<64x64xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %11 = arith.divf %in, %in_2 : f32 + linalg.yield %11 : f32 + } -> tensor<64x64xf32> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %10 into %arg3[%3, 0] [64, 64] [1, 1] : + tensor<64x64xf32> into tensor<1024x64xf32> + } + } + return +} +``` + +**What Happens:** +- The 1024x64 input is divided into 16 tiles of 64x64 each +- Softmax algorithm made explicit: + 1. **Max reduction**: Find maximum value per row (for numerical stability) + 2. **Exp**: Compute exp(x - max) for each element + 3. **Sum reduction**: Sum exponentials per row + 4. **Normalize**: Divide each element by its row sum +- Each tile is processed independently, enabling parallelization +- Results are inserted back into the output tensor + +--- + +## Stage 4: Vectorized + +**Key Characteristics:** +- `linalg.generic` operations replaced with vector operations +- SIMD-friendly representation using `vector<64x64xf32>` +- Explicit vector multi-reductions +- Vector transfers for reading/writing data + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + %cst = arith.constant dense<0.000000e+00> : vector<64xf32> + %0 = ub.poison : f32 + %cst_0 = arith.constant dense<0xFFC00000> : vector<64xf32> + %c0 = arith.constant 0 : index + %1 = bufferization.to_tensor %arg1 restrict : memref<1024x64xf32> to tensor<1024x64xf32> + %2 = tensor.empty() : tensor<1024x64xf32> + + %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { + %4 = affine.apply #map(%arg2) // %4 = %arg2 * 64 + %extracted_slice = tensor.extract_slice %arg3[%4, 0] [64, 64] [1, 1] + + // Vector read: Load 64x64 tile + %5 = vector.transfer_read %1[%4, %c0], %0 {in_bounds = [true, true]} : + tensor<1024x64xf32>, vector<64x64xf32> + + // Max reduction: Reduce dimension 1 -> vector<64xf32> + %6 = vector.multi_reduction , %5, %cst_0 [1] : + vector<64x64xf32> to vector<64xf32> + + // Broadcast max values back to 64x64 and transpose + %7 = vector.broadcast %6 : vector<64xf32> to vector<64x64xf32> + %8 = vector.transpose %7, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + + // Subtract max and exponentiate + %9 = arith.subf %5, %8 : vector<64x64xf32> + %10 = math.exp %9 : vector<64x64xf32> + + // Sum reduction: Reduce dimension 1 -> vector<64xf32> + %11 = vector.multi_reduction , %10, %cst [1] : + vector<64x64xf32> to vector<64xf32> + + // Broadcast sums back to 64x64 and transpose + %12 = vector.broadcast %11 : vector<64xf32> to vector<64x64xf32> + %13 = vector.transpose %12, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + + // Normalize + %14 = arith.divf %10, %13 : vector<64x64xf32> + + // Vector write + %15 = vector.transfer_write %14, %extracted_slice[%c0, %c0] {in_bounds = [true, true]} : + vector<64x64xf32>, tensor<64x64xf32> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %15 into %arg3[%4, 0] [64, 64] [1, 1] + } + } + return +} +``` + +**What Happens:** +- Linalg operations converted to vector dialect operations +- `vector.transfer_read` loads entire 64x64 tile at once +- `vector.multi_reduction` performs SIMD reductions (max and sum) +- `vector.broadcast` and `vector.transpose` handle dimension alignment +- All arithmetic operations work on vectors, enabling SIMD execution +- `vector.transfer_write` stores results back + +--- + +## Stage 5: Bufferized + +**Key Characteristics:** +- Tensors eliminated, working directly with memrefs +- Vector operations read/write directly from/to memory +- No more tensor extract/insert operations +- Simplified control flow + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + %cst = arith.constant dense<0.000000e+00> : vector<64xf32> + %0 = ub.poison : f32 + %cst_0 = arith.constant dense<0xFFC00000> : vector<64xf32> + %c0 = arith.constant 0 : index + + scf.forall (%arg2) in (16) { + %1 = affine.apply #map(%arg2) // %1 = %arg2 * 64 + + // Direct memref read + %2 = vector.transfer_read %arg1[%1, %c0], %0 {in_bounds = [true, true]} : + memref<1024x64xf32>, vector<64x64xf32> + + // Max reduction + %3 = vector.multi_reduction , %2, %cst_0 [1] : + vector<64x64xf32> to vector<64xf32> + %4 = vector.broadcast %3 : vector<64xf32> to vector<64x64xf32> + %5 = vector.transpose %4, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + + // Subtract and exp + %6 = arith.subf %2, %5 : vector<64x64xf32> + %7 = math.exp %6 : vector<64x64xf32> + + // Sum reduction + %8 = vector.multi_reduction , %7, %cst [1] : + vector<64x64xf32> to vector<64xf32> + %9 = vector.broadcast %8 : vector<64xf32> to vector<64x64xf32> + %10 = vector.transpose %9, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + + // Normalize + %11 = arith.divf %7, %10 : vector<64x64xf32> + + // Direct memref write + vector.transfer_write %11, %arg0[%1, %c0] {in_bounds = [true, true]} : + vector<64x64xf32>, memref<1024x64xf32> + } + return +} +``` + +**What Happens:** +- All tensor operations converted to memref-based operations +- `scf.forall` no longer uses `shared_outs`, simplified to pure side effects +- Vector transfers work directly on input/output memrefs +- Memory layout is now explicit +- This representation is ready for GPU kernel extraction + +--- + +## Stage 6: XeGPU-Initial + +**Key Characteristics:** +- GPU kernel separated from host code +- `gpu.launch_func` invocation with grid/block dimensions +- XeGPU tensor descriptors for memory access +- Block-based load/store operations + +**Code:** + +**Host Side:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + gpu.launch_func @payload_kernel::@payload_kernel + blocks in (%c16, %c1, %c1) + threads in (%c128, %c1, %c1) + args(%arg1 : memref<1024x64xf32>, %arg0 : memref<1024x64xf32>) + return +} +``` + +**GPU Kernel:** +```mlir +gpu.module @payload_kernel [#xevm.target] { + gpu.func @payload_kernel(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) kernel + attributes {known_block_size = array, + known_grid_size = array} { + %cst = arith.constant dense<0.000000e+00> : vector<64xf32> + %cst_0 = arith.constant dense<0xFFC00000> : vector<64xf32> + %c64 = arith.constant 64 : index + %block_id_x = gpu.block_id x + %0 = arith.muli %block_id_x, %c64 overflow : index + + // Create XeGPU tensor descriptor for load + %1 = xegpu.create_nd_tdesc %arg0 : memref<1024x64xf32> -> + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> + + // XeGPU block load + %2 = xegpu.load_nd %1[%0, 0] : + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> -> + vector<64x64xf32> + + // Same compute operations as before + %3 = vector.multi_reduction , %2, %cst_0 [1] : + vector<64x64xf32> to vector<64xf32> + %4 = vector.broadcast %3 : vector<64xf32> to vector<64x64xf32> + %5 = vector.transpose %4, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + %6 = arith.subf %2, %5 : vector<64x64xf32> + %7 = math.exp %6 : vector<64x64xf32> + %8 = vector.multi_reduction , %7, %cst [1] : + vector<64x64xf32> to vector<64xf32> + %9 = vector.broadcast %8 : vector<64xf32> to vector<64x64xf32> + %10 = vector.transpose %9, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + %11 = arith.divf %7, %10 : vector<64x64xf32> + + // Create XeGPU tensor descriptor for store + %12 = xegpu.create_nd_tdesc %arg1 : memref<1024x64xf32> -> + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> + + // XeGPU block store + xegpu.store_nd %11, %12[%0, 0] : + vector<64x64xf32>, + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> + + gpu.return + } +} +``` + +**What Happens:** +- Host function now calls `gpu.launch_func` with 16 blocks, 128 threads per block +- Separate `gpu.module` contains the kernel code +- `gpu.block_id x` replaces the loop iterator +- `xegpu.create_nd_tdesc` creates tensor descriptors for memory regions +- `xegpu.load_nd` performs hardware-optimized block loads +- `xegpu.store_nd` performs hardware-optimized block stores +- XeGPU operations map directly to Intel GPU instructions +- Boundary checking disabled for performance (sizes known at compile time) + +--- + +## Stage 7: XeGPU-WG (Work-Group Optimized) + +**Key Characteristics:** +- Additional layout hints for work-group optimization +- Sub-group layout specification: `sg_layout` and `sg_data` +- Optimized memory access patterns for Intel XeGPU + +**Code (differences from xegpu-initial):** +```mlir +// Store operation now includes layout hints +xegpu.store_nd %11, %12[%0, 0] + <{layout = #xegpu.layout}> : + vector<64x64xf32>, + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> +``` + +**What Happens:** +- The `layout` attribute provides explicit sub-group (SG) tiling information: + - `sg_layout = [8, 1]`: Data is distributed across 8 sub-groups in the first dimension + - `sg_data = [8, 64]`: Each sub-group handles an 8x64 slice of data +- This layout specification: + - Matches the 64x64 total data size (8 sub-groups × 8 rows = 64 rows, 64 columns) + - Optimizes coalesced memory accesses + - Enables efficient SIMD execution within each sub-group + - Aligns with Intel GPU hardware execution model (128 threads = 8 sub-groups of 16 threads) +- The layout is applied to the store operation to optimize write patterns +- Load operation remains unchanged as reads are typically more flexible + +--- + +## Summary + +The lowering pipeline progressively transforms abstract operations into hardware-specific instructions: + +| Stage | Abstraction Level | Key Operations | +|-------|------------------|----------------| +| **initial** | High-level ML | `linalg.softmax` on full tensor | +| **tiled-softmax** | Tiled high-level | `linalg.softmax` on tiles, `scf.forall` | +| **decomposed** | Tiled computation | `linalg.generic` with reductions | +| **vectorized** | Vector operations | `vector.multi_reduction`, `vector.transfer_read/write` | +| **bufferized** | Memory-based | Direct memref operations with vectors | +| **xegpu-initial** | GPU-specific | `xegpu.load_nd`, `xegpu.store_nd`, `gpu.launch_func` | +| **xegpu-wg** | Hardware-optimized | Layout hints for sub-group optimization | + +Each stage maintains the same computational semantics while providing increasingly detailed control over execution and memory access patterns, ultimately targeting efficient execution on Intel XeGPU hardware. From bf3a8c66a058cf4f7e8443d7247f2494437672a1 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 3 Apr 2026 22:51:42 +0000 Subject: [PATCH 23/63] save work --- docs/softmax_lowering.md | 158 ++++++++------------------------------- 1 file changed, 31 insertions(+), 127 deletions(-) diff --git a/docs/softmax_lowering.md b/docs/softmax_lowering.md index f67779ae..c5608e37 100644 --- a/docs/softmax_lowering.md +++ b/docs/softmax_lowering.md @@ -1,7 +1,10 @@ -# Linalg softmax lowering in XeGPU pipeline +# Linalg softmax lowering to XeGPU (Currently supported in lighthouse) ## Overview +**Assumptions:** +Softmax dimension size is small (64 in this example). + The lowering process consists of seven stages: 1. **initial** - High-level tensor operations 2. **tiled-softmax** - Tiled softmax operations @@ -30,7 +33,7 @@ func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { ## Stage 2: Tiled Softmax -**Key Characteristics:** +**Notes** - Work distribution via `scf.forall` (16 parallel iterations) - Each tile processes 64x64 elements @@ -61,63 +64,45 @@ func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { ## Stage 3: Decomposed -**Key Characteristics:** +**Notes** - Softmax decomposed into 4 constituent `linalg.generic` ops : max, sub+exp, sum, divide - Uses `structured.structured_decompose_interface` implemented by `linalg.softmax` **Code:** ```mlir func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %cst_0 = arith.constant 0xFFC00000 : f32 // -inf for max reduction - %0 = bufferization.to_tensor %arg1 restrict : memref<1024x64xf32> to tensor<1024x64xf32> - %1 = tensor.empty() : tensor<1024x64xf32> + // ... %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x64xf32>) { %3 = affine.apply #map(%arg2) // %3 = %arg2 * 64 - %extracted_slice = tensor.extract_slice %0[%3, 0] [64, 64] [1, 1] : - tensor<1024x64xf32> to tensor<64x64xf32> + %extracted_slice = tensor.extract_slice ... // Step 1: Find max along dimension 1 %4 = tensor.empty() : tensor<64xf32> %5 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<64xf32>) -> tensor<64xf32> - %6 = linalg.generic {indexing_maps = [#map1, #map2], - iterator_types = ["parallel", "reduction"]} - ins(%extracted_slice : tensor<64x64xf32>) outs(%5 : tensor<64xf32>) { - ^bb0(%in: f32, %out: f32): + %6 = linalg.generic // ... %11 = arith.maxnumf %in, %out : f32 - linalg.yield %11 : f32 + // ... } -> tensor<64xf32> // Step 2: Subtract max and exponentiate - %7 = linalg.generic {indexing_maps = [#map1, #map2, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%extracted_slice, %6 : tensor<64x64xf32>, tensor<64xf32>) - outs(%extracted_slice_1 : tensor<64x64xf32>) { - ^bb0(%in: f32, %in_2: f32, %out: f32): + %7 = linalg.generic // ... %11 = arith.subf %in, %in_2 : f32 %12 = math.exp %11 : f32 - linalg.yield %12 : f32 + // ... } -> tensor<64x64xf32> // Step 3: Sum exponentials %8 = linalg.fill ins(%cst : f32) outs(%4 : tensor<64xf32>) -> tensor<64xf32> - %9 = linalg.generic {indexing_maps = [#map1, #map2], - iterator_types = ["parallel", "reduction"]} - ins(%7 : tensor<64x64xf32>) outs(%8 : tensor<64xf32>) { - ^bb0(%in: f32, %out: f32): + %9 = linalg.generic // ... %11 = arith.addf %in, %out : f32 - linalg.yield %11 : f32 + // ... } -> tensor<64xf32> // Step 4: Normalize by sum - %10 = linalg.generic {indexing_maps = [#map1, #map2, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%7, %9 : tensor<64x64xf32>, tensor<64xf32>) - outs(%extracted_slice_1 : tensor<64x64xf32>) { - ^bb0(%in: f32, %in_2: f32, %out: f32): + %10 = linalg.generic // ... %11 = arith.divf %in, %in_2 : f32 - linalg.yield %11 : f32 + // ... } -> tensor<64x64xf32> scf.forall.in_parallel { @@ -129,39 +114,21 @@ func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { } ``` -**What Happens:** -- The 1024x64 input is divided into 16 tiles of 64x64 each -- Softmax algorithm made explicit: - 1. **Max reduction**: Find maximum value per row (for numerical stability) - 2. **Exp**: Compute exp(x - max) for each element - 3. **Sum reduction**: Sum exponentials per row - 4. **Normalize**: Divide each element by its row sum -- Each tile is processed independently, enabling parallelization -- Results are inserted back into the output tensor - --- ## Stage 4: Vectorized -**Key Characteristics:** +**Notes** - `linalg.generic` operations replaced with vector operations -- SIMD-friendly representation using `vector<64x64xf32>` -- Explicit vector multi-reductions - Vector transfers for reading/writing data **Code:** ```mlir func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - %cst = arith.constant dense<0.000000e+00> : vector<64xf32> - %0 = ub.poison : f32 - %cst_0 = arith.constant dense<0xFFC00000> : vector<64xf32> - %c0 = arith.constant 0 : index - %1 = bufferization.to_tensor %arg1 restrict : memref<1024x64xf32> to tensor<1024x64xf32> - %2 = tensor.empty() : tensor<1024x64xf32> - + // ... %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { %4 = affine.apply #map(%arg2) // %4 = %arg2 * 64 - %extracted_slice = tensor.extract_slice %arg3[%4, 0] [64, 64] [1, 1] + %extracted_slice = tensor.extract_slice .. // Vector read: Load 64x64 tile %5 = vector.transfer_read %1[%4, %c0], %0 {in_bounds = [true, true]} : @@ -201,32 +168,17 @@ func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { return } ``` - -**What Happens:** -- Linalg operations converted to vector dialect operations -- `vector.transfer_read` loads entire 64x64 tile at once -- `vector.multi_reduction` performs SIMD reductions (max and sum) -- `vector.broadcast` and `vector.transpose` handle dimension alignment -- All arithmetic operations work on vectors, enabling SIMD execution -- `vector.transfer_write` stores results back - --- ## Stage 5: Bufferized -**Key Characteristics:** +**Notes** - Tensors eliminated, working directly with memrefs -- Vector operations read/write directly from/to memory -- No more tensor extract/insert operations -- Simplified control flow **Code:** ```mlir func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - %cst = arith.constant dense<0.000000e+00> : vector<64xf32> - %0 = ub.poison : f32 - %cst_0 = arith.constant dense<0xFFC00000> : vector<64xf32> - %c0 = arith.constant 0 : index + // ... scf.forall (%arg2) in (16) { %1 = affine.apply #map(%arg2) // %1 = %arg2 * 64 @@ -262,31 +214,21 @@ func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { } ``` -**What Happens:** -- All tensor operations converted to memref-based operations -- `scf.forall` no longer uses `shared_outs`, simplified to pure side effects -- Vector transfers work directly on input/output memrefs -- Memory layout is now explicit -- This representation is ready for GPU kernel extraction - --- ## Stage 6: XeGPU-Initial -**Key Characteristics:** -- GPU kernel separated from host code +**Notes** +- GPU kernel separated from host code (Gpu Outlining) - `gpu.launch_func` invocation with grid/block dimensions -- XeGPU tensor descriptors for memory access -- Block-based load/store operations +- Use `vector-to-xegpu` **Code:** **Host Side:** ```mlir func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c128 = arith.constant 128 : index + // ... gpu.launch_func @payload_kernel::@payload_kernel blocks in (%c16, %c1, %c1) threads in (%c128, %c1, %c1) @@ -301,9 +243,7 @@ gpu.module @payload_kernel [#xevm.target] { gpu.func @payload_kernel(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) kernel attributes {known_block_size = array, known_grid_size = array} { - %cst = arith.constant dense<0.000000e+00> : vector<64xf32> - %cst_0 = arith.constant dense<0xFFC00000> : vector<64xf32> - %c64 = arith.constant 64 : index + // ... %block_id_x = gpu.block_id x %0 = arith.muli %block_id_x, %c64 overflow : index @@ -343,24 +283,14 @@ gpu.module @payload_kernel [#xevm.target] { } ``` -**What Happens:** -- Host function now calls `gpu.launch_func` with 16 blocks, 128 threads per block -- Separate `gpu.module` contains the kernel code -- `gpu.block_id x` replaces the loop iterator -- `xegpu.create_nd_tdesc` creates tensor descriptors for memory regions -- `xegpu.load_nd` performs hardware-optimized block loads -- `xegpu.store_nd` performs hardware-optimized block stores -- XeGPU operations map directly to Intel GPU instructions -- Boundary checking disabled for performance (sizes known at compile time) - --- ## Stage 7: XeGPU-WG (Work-Group Optimized) -**Key Characteristics:** -- Additional layout hints for work-group optimization -- Sub-group layout specification: `sg_layout` and `sg_data` -- Optimized memory access patterns for Intel XeGPU +**Notes** +- Sets the layout for anchor xegpu ops. Each Wg consistes of [8, 1] subgroups + doing 8x64 softmax slice. +- Only sets the layotu for `store_nd`. Layout propagation does the rest. **Code (differences from xegpu-initial):** ```mlir @@ -371,32 +301,6 @@ xegpu.store_nd %11, %12[%0, 0] !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> ``` -**What Happens:** -- The `layout` attribute provides explicit sub-group (SG) tiling information: - - `sg_layout = [8, 1]`: Data is distributed across 8 sub-groups in the first dimension - - `sg_data = [8, 64]`: Each sub-group handles an 8x64 slice of data -- This layout specification: - - Matches the 64x64 total data size (8 sub-groups × 8 rows = 64 rows, 64 columns) - - Optimizes coalesced memory accesses - - Enables efficient SIMD execution within each sub-group - - Aligns with Intel GPU hardware execution model (128 threads = 8 sub-groups of 16 threads) -- The layout is applied to the store operation to optimize write patterns -- Load operation remains unchanged as reads are typically more flexible - --- -## Summary - -The lowering pipeline progressively transforms abstract operations into hardware-specific instructions: - -| Stage | Abstraction Level | Key Operations | -|-------|------------------|----------------| -| **initial** | High-level ML | `linalg.softmax` on full tensor | -| **tiled-softmax** | Tiled high-level | `linalg.softmax` on tiles, `scf.forall` | -| **decomposed** | Tiled computation | `linalg.generic` with reductions | -| **vectorized** | Vector operations | `vector.multi_reduction`, `vector.transfer_read/write` | -| **bufferized** | Memory-based | Direct memref operations with vectors | -| **xegpu-initial** | GPU-specific | `xegpu.load_nd`, `xegpu.store_nd`, `gpu.launch_func` | -| **xegpu-wg** | Hardware-optimized | Layout hints for sub-group optimization | - -Each stage maintains the same computational semantics while providing increasingly detailed control over execution and memory access patterns, ultimately targeting efficient execution on Intel XeGPU hardware. +# Supporting larger Softmax dimension sizes. From b083887154b083d184430c0c87baf887d155dcca Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 3 Apr 2026 23:03:15 +0000 Subject: [PATCH 24/63] save work --- examples/xegpu/softmax.py | 8 -- lighthouse/schedule/xegpu/softmax_schedule.py | 90 +------------------ 2 files changed, 3 insertions(+), 95 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index 8049f0f4..afca86e3 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -181,12 +181,6 @@ def parse_cli(): default=16, help="Subgroup size.", ) - parser.add_argument( - "--reduction-step-size", - type=int, - default=16, - help="Step size for reduction loop tiling (optional).", - ) parser.add_argument( "--nruns", type=int, @@ -212,7 +206,6 @@ def parse_cli(): "tiled", "vectorized", "bufferized", - "gpu-outlining", "xegpu-initial", "xegpu-wg", "final", @@ -244,7 +237,6 @@ def parse_cli(): "wg_rows": args.wg_rows, "sg_rows": args.sg_rows, "subgroup_size": args.subgroup_size, - "reduction_step_size": args.reduction_step_size, } M, N = args.sizes diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index bda5f819..0907bc5a 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -39,7 +39,6 @@ def get_softmax_schedule_module( - sg_rows: Number of rows per subgroup - subgroup_size: Size of subgroup - sizes: Tuple with the sizes of the input tensors (e.g. (M, N)) - - reduction_step_size: Optional step size for tiling reduction loops Returns: MLIR module containing the transform schedule @@ -140,72 +139,7 @@ def bundle_xegpu_softmax_schedule( transform.AnyOpType.get(), func, ops=["linalg.softmax"] ) structured.structured_decompose_interface(anytype, softmax_ops) - # transform.print_(target=func, name="After structured_decompose_interface") - linalg_ops = match_and_split( - func, ops={"linalg.generic", "linalg.fill"}, nhandles=6 - ) - max_reduction = linalg_ops[1] - max_center_and_exp_op = linalg_ops[2] - sum_reduction = linalg_ops[4] - div_op = linalg_ops[5] - - reduction_step_size = parameters["reduction_step_size"] - - # Tile the division op and fuse the sub+exp producer into it - _, div_loop = structured.TileUsingForOp( - div_op, sizes=[0, reduction_step_size] - ).results - - # Fuse max_center_and_exp_op into the div loop - _, fused_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=max_center_and_exp_op, - containing_op=div_loop, - ) - - # Tile the sum reduction and fuse the sub+exp producer into it - _, _, _, sum_loop = structured.structured_tile_reduction_using_for( - [anytype], - anytype, - anytype, - anytype, - target=sum_reduction, - tile_sizes=[0, reduction_step_size], - ) - - func = transform.get_parent_op( - anytype, - fused_loop, - op_name="func.func", - deduplicate=True, - ) - - # Re-match and split linalg generic ops, there are 5 at this point - linalg_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=5) - max_center_and_exp_op = linalg_ops[1] - - # Fuse max_center_and_exp_op into the sum reduction loop - _, fused_sum_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=max_center_and_exp_op, - containing_op=sum_loop, - ) - - # Tile the max reduction. - max_reduction = linalg_ops[0] - structured.structured_tile_reduction_using_for( - [anytype], - anytype, - anytype, - anytype, - target=max_reduction, - tile_sizes=[0, reduction_step_size], - ) - - # Cleanup after tiling and fusion transform.apply_cse(func) canonicalize(func) @@ -237,17 +171,6 @@ def bundle_xegpu_softmax_schedule( transform.apply_cse(mod) canonicalize(mod) - # promote memref.alloc to memref.alloca in payload function - func = match(mod, ops={"func.func"}) - func = apply_registered_pass( - func, - "promote-buffers-to-stack", - options={ - "max-alloc-size-in-bytes": "8192", - "max-rank-of-allocated-memref": "2", - }, - ) - if stop_at_stage == "bufferized": raise PipelineInterrupt() @@ -277,9 +200,6 @@ def bundle_xegpu_softmax_schedule( mod = apply_registered_pass(mod, "gpu-kernel-outlining") transform.apply_cse(mod) - if stop_at_stage == "gpu-outlining": - raise PipelineInterrupt() - # set xevm target mod = apply_registered_pass( mod, @@ -294,19 +214,15 @@ def bundle_xegpu_softmax_schedule( gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") transform.apply_cse(gpu_func) - # Cleanup. - transform.apply_cse(mod) - canonicalize(mod) - if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() # Set layout attributes for xegpu.store_nd operations. # FIXME: currently ecah subgroup is handling the entire row. - store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=5) + store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) sg_layout = [parameters["sg_rows"], 1] - sg_data = [parameters["sg_rows"], parameters["reduction_step_size"]] - xegpu.set_anchor_layout(store_ops[-1], sg_layout=sg_layout, sg_data=sg_data) + sg_data = [parameters["sg_rows"], parameters["sizes"][1]] + xegpu.set_anchor_layout(store_ops[0], sg_layout=sg_layout, sg_data=sg_data) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From c02b66b6b95e34e1dfdda31508e97ec50d507a03 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 3 Apr 2026 23:16:12 +0000 Subject: [PATCH 25/63] fused version --- examples/xegpu/softmax.py | 8 ++ lighthouse/schedule/xegpu/softmax_schedule.py | 90 ++++++++++++++++++- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index afca86e3..8049f0f4 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -181,6 +181,12 @@ def parse_cli(): default=16, help="Subgroup size.", ) + parser.add_argument( + "--reduction-step-size", + type=int, + default=16, + help="Step size for reduction loop tiling (optional).", + ) parser.add_argument( "--nruns", type=int, @@ -206,6 +212,7 @@ def parse_cli(): "tiled", "vectorized", "bufferized", + "gpu-outlining", "xegpu-initial", "xegpu-wg", "final", @@ -237,6 +244,7 @@ def parse_cli(): "wg_rows": args.wg_rows, "sg_rows": args.sg_rows, "subgroup_size": args.subgroup_size, + "reduction_step_size": args.reduction_step_size, } M, N = args.sizes diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 0907bc5a..bda5f819 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -39,6 +39,7 @@ def get_softmax_schedule_module( - sg_rows: Number of rows per subgroup - subgroup_size: Size of subgroup - sizes: Tuple with the sizes of the input tensors (e.g. (M, N)) + - reduction_step_size: Optional step size for tiling reduction loops Returns: MLIR module containing the transform schedule @@ -139,7 +140,72 @@ def bundle_xegpu_softmax_schedule( transform.AnyOpType.get(), func, ops=["linalg.softmax"] ) structured.structured_decompose_interface(anytype, softmax_ops) + # transform.print_(target=func, name="After structured_decompose_interface") + linalg_ops = match_and_split( + func, ops={"linalg.generic", "linalg.fill"}, nhandles=6 + ) + max_reduction = linalg_ops[1] + max_center_and_exp_op = linalg_ops[2] + sum_reduction = linalg_ops[4] + div_op = linalg_ops[5] + + reduction_step_size = parameters["reduction_step_size"] + + # Tile the division op and fuse the sub+exp producer into it + _, div_loop = structured.TileUsingForOp( + div_op, sizes=[0, reduction_step_size] + ).results + + # Fuse max_center_and_exp_op into the div loop + _, fused_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=max_center_and_exp_op, + containing_op=div_loop, + ) + + # Tile the sum reduction and fuse the sub+exp producer into it + _, _, _, sum_loop = structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=sum_reduction, + tile_sizes=[0, reduction_step_size], + ) + + func = transform.get_parent_op( + anytype, + fused_loop, + op_name="func.func", + deduplicate=True, + ) + + # Re-match and split linalg generic ops, there are 5 at this point + linalg_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=5) + max_center_and_exp_op = linalg_ops[1] + + # Fuse max_center_and_exp_op into the sum reduction loop + _, fused_sum_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=max_center_and_exp_op, + containing_op=sum_loop, + ) + + # Tile the max reduction. + max_reduction = linalg_ops[0] + structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=max_reduction, + tile_sizes=[0, reduction_step_size], + ) + + # Cleanup after tiling and fusion transform.apply_cse(func) canonicalize(func) @@ -171,6 +237,17 @@ def bundle_xegpu_softmax_schedule( transform.apply_cse(mod) canonicalize(mod) + # promote memref.alloc to memref.alloca in payload function + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) + if stop_at_stage == "bufferized": raise PipelineInterrupt() @@ -200,6 +277,9 @@ def bundle_xegpu_softmax_schedule( mod = apply_registered_pass(mod, "gpu-kernel-outlining") transform.apply_cse(mod) + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + # set xevm target mod = apply_registered_pass( mod, @@ -214,15 +294,19 @@ def bundle_xegpu_softmax_schedule( gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") transform.apply_cse(gpu_func) + # Cleanup. + transform.apply_cse(mod) + canonicalize(mod) + if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() # Set layout attributes for xegpu.store_nd operations. # FIXME: currently ecah subgroup is handling the entire row. - store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) + store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=5) sg_layout = [parameters["sg_rows"], 1] - sg_data = [parameters["sg_rows"], parameters["sizes"][1]] - xegpu.set_anchor_layout(store_ops[0], sg_layout=sg_layout, sg_data=sg_data) + sg_data = [parameters["sg_rows"], parameters["reduction_step_size"]] + xegpu.set_anchor_layout(store_ops[-1], sg_layout=sg_layout, sg_data=sg_data) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From bce6260e57741e5439efa90c9640aff19f90c8f0 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 3 Apr 2026 23:35:41 +0000 Subject: [PATCH 26/63] tiled reduction doc --- docs/softmax_lowering.md | 331 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 330 insertions(+), 1 deletion(-) diff --git a/docs/softmax_lowering.md b/docs/softmax_lowering.md index c5608e37..a31e6f4c 100644 --- a/docs/softmax_lowering.md +++ b/docs/softmax_lowering.md @@ -303,4 +303,333 @@ xegpu.store_nd %11, %12[%0, 0] --- -# Supporting larger Softmax dimension sizes. +# Supporting larger Softmax dimension sizes + +When the softmax dimension is larger than what can fit efficiently in registers, additional tiling and fusion transformations are applied to the reduction dimension. This section shows the intermediate stages between "decomposed" and "vectorized". + +**Approach:** Tile reductions along dimension 1 (step size = 16) and fuse producers into consumers to enable streaming computation. + +--- + +## Decomposed → Tiled: Stage A - Tile div op + +**Notes:** +- Tile the division operation with step size 16 along dimension 1 +- Creates `scf.for` loop iterating over 64 elements in chunks of 16 + +**Key Changes:** +```mlir +// Before: Single division linalg.generic over 64x64 +%11 = linalg.generic {...} ins(%8, %10 : tensor<64x64xf32>, tensor<64xf32>) + outs(%extracted_slice_0 : tensor<64x64xf32>) { ... } -> tensor<64x64xf32> + +// After: Division tiled into 64x16 chunks +%11 = scf.for %arg4 = %c0_2 to %c64 step %c16 iter_args(%arg5 = %extracted_slice_0) -> (tensor<64x64xf32>) { + %extracted_slice_3 = tensor.extract_slice %8[0, %arg4] [64, 16] [1, 1] + %12 = linalg.generic {...} ins(%extracted_slice_3, %extracted_slice_4 : tensor<64x16xf32>, tensor<64xf32>) + outs(%extracted_slice_5 : tensor<64x16xf32>) { ... } -> tensor<64x16xf32> + %inserted_slice = tensor.insert_slice %12 into %arg5[0, %arg4] [64, 16] [1, 1] + scf.yield %inserted_slice +} +``` + +--- + +## Stage B - Fuse sub+exp into div loop + +**Notes:** +- Fuse the `sub+exp` producer (max_center_and_exp_op) into the div loop +- Recomputes exp values on-the-fly instead of materializing full 64x64 tensor + +**Key Changes:** +```mlir +%11 = scf.for %arg4 = %c0_2 to %c64 step %c16 iter_args(%arg5 = %extracted_slice_0) -> (tensor<64x64xf32>) { + %extracted_slice_3 = tensor.extract_slice %extracted_slice[0, %arg4] [64, 16] [1, 1] + + // Fused: sub+exp computed per 16-element chunk + %12 = linalg.generic {...} ins(%extracted_slice_3, %extracted_slice_4 : tensor<64x16xf32>, tensor<64xf32>) + outs(%extracted_slice_5 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_8: f32, %out: f32): + %14 = arith.subf %in, %in_8 : f32 + %15 = math.exp %14 : f32 + linalg.yield %15 : f32 + } -> tensor<64x16xf32> + + // Division operation + %13 = linalg.generic {...} ins(%12, %extracted_slice_6 : tensor<64x16xf32>, tensor<64xf32>) + outs(%extracted_slice_7 : tensor<64x16xf32>) { ... } -> tensor<64x16xf32> + // ... +} +``` + +--- + +## Stage C - Tile sum reduction + +**Notes:** +- Tile the sum reduction using `structured_tile_reduction_using_for` +- Creates intermediate accumulator tensor (64x16) +- Final reduction via `linalg.reduce` over dimension 1 + +**Key Changes:** +```mlir +// Tiled sum reduction with intermediate accumulator +%10 = tensor.empty() : tensor<64x16xf32> +%11 = linalg.fill ins(%cst_2 : f32) outs(%10 : tensor<64x16xf32>) -> tensor<64x16xf32> + +%12 = scf.for %arg4 = %c0_3 to %c64 step %c16 iter_args(%arg5 = %11) -> (tensor<64x16xf32>) { + %extracted_slice_7 = tensor.extract_slice %8[0, %arg4] [64, 16] [1, 1] + %14 = linalg.generic {...} ins(%extracted_slice_7 : tensor<64x16xf32>) + outs(%extracted_slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %15 = arith.addf %in, %out : f32 + linalg.yield %15 : f32 + } -> tensor<64x16xf32> + // ... +} + +// Final reduction to 64xf32 +%reduced = linalg.reduce ins(%12 : tensor<64x16xf32>) outs(%9 : tensor<64xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %14 = arith.addf %in, %init : f32 + linalg.yield %14 : f32 + } +``` + +--- + +## Stage D - Fuse sub+exp into sum loop + +**Notes:** +- Fuse `sub+exp` into the sum reduction loop +- Stream computation: compute exp and accumulate in same loop + +**Key Changes:** +```mlir +%12 = scf.for %arg4 = %c0_3 to %c64 step %c16 iter_args(%arg5 = %11) -> (tensor<64x16xf32>) { + %extracted_slice_7 = tensor.extract_slice %extracted_slice[0, %arg4] [64, 16] [1, 1] + + // Fused: sub+exp + %14 = linalg.generic {...} ins(%extracted_slice_7, %extracted_slice_8 : tensor<64x16xf32>, tensor<64xf32>) + outs(%extracted_slice_9 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_11: f32, %out: f32): + %16 = arith.subf %in, %in_11 : f32 + %17 = math.exp %16 : f32 + linalg.yield %17 : f32 + } -> tensor<64x16xf32> + + // Accumulate sum + %15 = linalg.generic {...} ins(%14 : tensor<64x16xf32>) + outs(%extracted_slice_10 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %16 = arith.addf %in, %out : f32 + linalg.yield %16 : f32 + } -> tensor<64x16xf32> + // ... +} +``` + +--- + +## Stage E - Tile max reduction + +**Notes:** +- Tile max reduction similar to sum reduction +- Creates 64x16 intermediate accumulator +- Final reduction via `linalg.reduce` with maxnumf + +**Key Changes:** +```mlir +// Tiled max reduction +%7 = tensor.empty() : tensor<64x16xf32> +%8 = linalg.fill ins(%cst_1 : f32) outs(%7 : tensor<64x16xf32>) -> tensor<64x16xf32> + +%9 = scf.for %arg4 = %c0_2 to %c64 step %c16 iter_args(%arg5 = %8) -> (tensor<64x16xf32>) { + %extracted_slice_12 = tensor.extract_slice %extracted_slice[0, %arg4] [64, 16] [1, 1] + %16 = linalg.generic {...} ins(%extracted_slice_12 : tensor<64x16xf32>) + outs(%extracted_slice_13 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %17 = arith.maxnumf %in, %out : f32 + linalg.yield %17 : f32 + } -> tensor<64x16xf32> + // ... +} + +// Final max reduction +%reduced = linalg.reduce ins(%9 : tensor<64x16xf32>) outs(%6 : tensor<64xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %16 = arith.maxnumf %in, %init : f32 + linalg.yield %16 : f32 + } +``` + +**Result:** Now all three major computations (max, sum, div) are tiled and operate on 64x16 chunks, with exp computation fused into both sum and div loops. + +--- + +## Stage F - Vectorization + +**Notes:** +- Convert tiled linalg operations to vector operations +- `scf.for` loops remain but operate on vectors +- Vector size: 64x16 for tiled operations + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { + // ... + + // Vectorized max reduction loop + %6 = vector.transfer_write %cst_1, %5[%c0, %c0] : vector<64x16xf32>, tensor<64x16xf32> + %7 = scf.for %arg4 = %c0 to %c64 step %c16 iter_args(%arg5 = %6) -> (tensor<64x16xf32>) { + %15 = vector.transfer_read %1[%4, %arg4], %0 : tensor<1024x64xf32>, vector<64x16xf32> + %16 = vector.transfer_read %arg5[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> + %17 = arith.maxnumf %15, %16 : vector<64x16xf32> + %18 = vector.transfer_write %17, %arg5[%c0, %c0] : vector<64x16xf32>, tensor<64x16xf32> + scf.yield %18 : tensor<64x16xf32> + } + %8 = vector.transfer_read %7[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> + %9 = vector.multi_reduction , %8, %cst_2 [1] : vector<64x16xf32> to vector<64xf32> + + // Vectorized sum reduction loop with fused sub+exp + %11 = scf.for %arg4 = %c0 to %c64 step %c16 iter_args(%arg5 = %10) -> (tensor<64x16xf32>) { + %15 = vector.transfer_read %1[%4, %arg4], %0 : tensor<1024x64xf32>, vector<64x16xf32> + %16 = vector.broadcast %9 : vector<64xf32> to vector<16x64xf32> + %17 = vector.transpose %16, [1, 0] : vector<16x64xf32> to vector<64x16xf32> + %18 = arith.subf %15, %17 : vector<64x16xf32> + %19 = math.exp %18 : vector<64x16xf32> + %20 = vector.transfer_read %arg5[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> + %21 = arith.addf %19, %20 : vector<64x16xf32> + %22 = vector.transfer_write %21, %arg5[%c0, %c0] : vector<64x16xf32>, tensor<64x16xf32> + scf.yield %22 : tensor<64x16xf32> + } + %12 = vector.transfer_read %11[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> + %13 = vector.multi_reduction , %12, %cst_0 [1] : vector<64x16xf32> to vector<64xf32> + + // Vectorized div loop with fused sub+exp + %14 = scf.for %arg4 = %c0 to %c64 step %c16 iter_args(%arg5 = %extracted_slice) -> (tensor<64x64xf32>) { + %15 = vector.transfer_read %1[%4, %arg4], %0 : tensor<1024x64xf32>, vector<64x16xf32> + %16 = vector.broadcast %9 : vector<64xf32> to vector<16x64xf32> + %17 = vector.transpose %16, [1, 0] : vector<16x64xf32> to vector<64x16xf32> + %18 = arith.subf %15, %17 : vector<64x16xf32> + %19 = math.exp %18 : vector<64x16xf32> + %20 = vector.broadcast %13 : vector<64xf32> to vector<16x64xf32> + %21 = vector.transpose %20, [1, 0] : vector<16x64xf32> to vector<64x16xf32> + %22 = arith.divf %19, %21 : vector<64x16xf32> + %23 = vector.transfer_write %22, %arg5[%c0, %arg4] : vector<64x16xf32>, tensor<64x64xf32> + scf.yield %23 : tensor<64x64xf32> + } + } + // ... +} +``` + +--- + +## Stage G - Bufferization + +**Notes:** +- Convert tensors to memrefs +- Allocate stack buffer for 64x16 accumulator: `memref.alloc()` + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + scf.forall (%arg2) in (16) { + %1 = affine.apply #map(%arg2) + %subview = memref.subview %arg0[%1, 0] [64, 64] [1, 1] + + // Allocate accumulator buffer + %alloc = memref.alloc() {alignment = 64 : i64} : memref<64x16xf32> + + // Max reduction loop + vector.transfer_write %cst_1, %alloc[%c0, %c0] : vector<64x16xf32>, memref<64x16xf32> + scf.for %arg3 = %c0 to %c64 step %c16 { + %6 = vector.transfer_read %arg1[%1, %arg3], %0 : memref<1024x64xf32>, vector<64x16xf32> + %7 = vector.transfer_read %alloc[%c0, %c0], %0 : memref<64x16xf32>, vector<64x16xf32> + %8 = arith.maxnumf %6, %7 : vector<64x16xf32> + vector.transfer_write %8, %alloc[%c0, %c0] : vector<64x16xf32>, memref<64x16xf32> + } + %2 = vector.transfer_read %alloc[%c0, %c0], %0 : memref<64x16xf32>, vector<64x16xf32> + %3 = vector.multi_reduction , %2, %cst_2 [1] : vector<64x16xf32> to vector<64xf32> + + // Sum reduction loop (reuses %alloc) + // ... + + // Div loop (writes to %subview) + // ... + } +} +``` + +--- + +## Stage H - Promote buffers to stack + +**Notes:** +- Convert `memref.alloc()` to `memref.alloca()` for stack allocation +- Reduces memory allocation overhead + +**Code:** +```mlir +scf.forall (%arg2) in (16) { + %1 = affine.apply #map(%arg2) + %subview = memref.subview %arg0[%1, 0] [64, 64] [1, 1] + + // Stack allocation instead of heap + %alloca = memref.alloca() {alignment = 64 : i64} : memref<64x16xf32> + + // ... same operations using %alloca ... +} +``` + +--- + +## Stage I - GPU outlining + +**Notes:** +- Convert `scf.forall` to `scf.parallel`, then to `gpu.launch` +- Extract GPU kernel into separate `gpu.module` +- Set thread count: 128 threads = (64 rows / 8 sg_rows) × 16 subgroup_size + +**Host Side:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + gpu.launch_func @payload_kernel::@payload_kernel + blocks in (%c16, %c1, %c1) + threads in (%c128, %c1, %c1) + args(%arg0 : memref<1024x64xf32>, %arg1 : memref<1024x64xf32>) + return +} +``` + +**GPU Kernel:** +```mlir +gpu.module @payload_kernel { + gpu.func @payload_kernel(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) kernel + attributes {known_block_size = array, + known_grid_size = array} { + %block_id_x = gpu.block_id x + %1 = arith.muli %block_id_x, %c64 overflow : index + %subview = memref.subview %arg0[%1, 0] [64, 64] [1, 1] + %alloca = memref.alloca() {alignment = 64 : i64} : memref<64x16xf32> + + // Three reduction loops (max, sum, div) with same structure + scf.for %arg2 = %c0 to %c64 step %c16 { + // Max: accumulate max values + // Sum: compute & accumulate exp(x - max) + // Div: compute exp(x - max) / sum + } + + gpu.return + } +} +``` + +**Summary:** At this stage, the kernel processes 64x16 chunks in streaming fashion through three sequential loops, minimizing memory footprint. From 2df0777727704d1e3e14f2b6d66911908fedab27 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 3 Apr 2026 23:36:42 +0000 Subject: [PATCH 27/63] tiled reduction doc --- lighthouse/schedule/xegpu/softmax_schedule.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index bda5f819..0248919f 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -140,7 +140,7 @@ def bundle_xegpu_softmax_schedule( transform.AnyOpType.get(), func, ops=["linalg.softmax"] ) structured.structured_decompose_interface(anytype, softmax_ops) - # transform.print_(target=func, name="After structured_decompose_interface") + transform.print_(target=func, name="Aftemr structured_decompose_interface") linalg_ops = match_and_split( func, ops={"linalg.generic", "linalg.fill"}, nhandles=6 @@ -156,6 +156,8 @@ def bundle_xegpu_softmax_schedule( _, div_loop = structured.TileUsingForOp( div_op, sizes=[0, reduction_step_size] ).results + + transform.print_(target=func, name="After tiling div op") # Fuse max_center_and_exp_op into the div loop _, fused_loop = structured.structured_fuse_into_containing_op( @@ -164,6 +166,8 @@ def bundle_xegpu_softmax_schedule( producer_op=max_center_and_exp_op, containing_op=div_loop, ) + transform.print_(target=func, name="After fusing max_center_and_exp_op into div loop") + # Tile the sum reduction and fuse the sub+exp producer into it _, _, _, sum_loop = structured.structured_tile_reduction_using_for( @@ -174,6 +178,8 @@ def bundle_xegpu_softmax_schedule( target=sum_reduction, tile_sizes=[0, reduction_step_size], ) + + transform.print_(target=func, name="After tiling sum reduction") func = transform.get_parent_op( anytype, @@ -193,6 +199,7 @@ def bundle_xegpu_softmax_schedule( producer_op=max_center_and_exp_op, containing_op=sum_loop, ) + transform.print_(target=func, name="After fusing max_center_and_exp_op into sum loop") # Tile the max reduction. max_reduction = linalg_ops[0] @@ -204,6 +211,8 @@ def bundle_xegpu_softmax_schedule( target=max_reduction, tile_sizes=[0, reduction_step_size], ) + transform.print_(target=func, name="After tiling max reduction") + # Cleanup after tiling and fusion transform.apply_cse(func) @@ -219,6 +228,8 @@ def bundle_xegpu_softmax_schedule( ).result transform.apply_cse(func) canonicalize(func) + + transform.print_(target=func, name="After vectorization") if stop_at_stage == "vectorized": raise PipelineInterrupt() @@ -236,6 +247,8 @@ def bundle_xegpu_softmax_schedule( mod = apply_registered_pass(mod, "fold-memref-alias-ops") transform.apply_cse(mod) canonicalize(mod) + + transform.print_(target=mod, name="After bufferization") # promote memref.alloc to memref.alloca in payload function func = match(mod, ops={"func.func"}) @@ -247,6 +260,8 @@ def bundle_xegpu_softmax_schedule( "max-rank-of-allocated-memref": "2", }, ) + + transform.print_(target=func, name="After promoting buffers to stack") if stop_at_stage == "bufferized": raise PipelineInterrupt() @@ -276,6 +291,8 @@ def bundle_xegpu_softmax_schedule( func = apply_registered_pass(func, "gpu-launch-sink-index-computations") mod = apply_registered_pass(mod, "gpu-kernel-outlining") transform.apply_cse(mod) + + transform.print_(target=mod, name="After GPU outlining") if stop_at_stage == "gpu-outlining": raise PipelineInterrupt() From 56687b7571d3160474dea5dfd4503c307a258dad Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 3 Apr 2026 23:37:03 +0000 Subject: [PATCH 28/63] tiled reduction doc --- lighthouse/schedule/xegpu/softmax_schedule.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 0248919f..7912a45b 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -156,7 +156,7 @@ def bundle_xegpu_softmax_schedule( _, div_loop = structured.TileUsingForOp( div_op, sizes=[0, reduction_step_size] ).results - + transform.print_(target=func, name="After tiling div op") # Fuse max_center_and_exp_op into the div loop @@ -166,8 +166,9 @@ def bundle_xegpu_softmax_schedule( producer_op=max_center_and_exp_op, containing_op=div_loop, ) - transform.print_(target=func, name="After fusing max_center_and_exp_op into div loop") - + transform.print_( + target=func, name="After fusing max_center_and_exp_op into div loop" + ) # Tile the sum reduction and fuse the sub+exp producer into it _, _, _, sum_loop = structured.structured_tile_reduction_using_for( @@ -178,7 +179,7 @@ def bundle_xegpu_softmax_schedule( target=sum_reduction, tile_sizes=[0, reduction_step_size], ) - + transform.print_(target=func, name="After tiling sum reduction") func = transform.get_parent_op( @@ -199,7 +200,9 @@ def bundle_xegpu_softmax_schedule( producer_op=max_center_and_exp_op, containing_op=sum_loop, ) - transform.print_(target=func, name="After fusing max_center_and_exp_op into sum loop") + transform.print_( + target=func, name="After fusing max_center_and_exp_op into sum loop" + ) # Tile the max reduction. max_reduction = linalg_ops[0] @@ -213,7 +216,6 @@ def bundle_xegpu_softmax_schedule( ) transform.print_(target=func, name="After tiling max reduction") - # Cleanup after tiling and fusion transform.apply_cse(func) canonicalize(func) @@ -228,7 +230,7 @@ def bundle_xegpu_softmax_schedule( ).result transform.apply_cse(func) canonicalize(func) - + transform.print_(target=func, name="After vectorization") if stop_at_stage == "vectorized": @@ -247,7 +249,7 @@ def bundle_xegpu_softmax_schedule( mod = apply_registered_pass(mod, "fold-memref-alias-ops") transform.apply_cse(mod) canonicalize(mod) - + transform.print_(target=mod, name="After bufferization") # promote memref.alloc to memref.alloca in payload function @@ -260,7 +262,7 @@ def bundle_xegpu_softmax_schedule( "max-rank-of-allocated-memref": "2", }, ) - + transform.print_(target=func, name="After promoting buffers to stack") if stop_at_stage == "bufferized": @@ -291,7 +293,7 @@ def bundle_xegpu_softmax_schedule( func = apply_registered_pass(func, "gpu-launch-sink-index-computations") mod = apply_registered_pass(mod, "gpu-kernel-outlining") transform.apply_cse(mod) - + transform.print_(target=mod, name="After GPU outlining") if stop_at_stage == "gpu-outlining": From d2d4c49f70c10328a3e57da2b6b359f2539ad5df Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 15 Apr 2026 23:50:39 +0000 Subject: [PATCH 29/63] save work --- .../transform/transform_ext/__init__.py | 2 + .../transform_ext/ops/update_address_space.py | 102 ++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 lighthouse/dialects/transform/transform_ext/ops/update_address_space.py diff --git a/lighthouse/dialects/transform/transform_ext/__init__.py b/lighthouse/dialects/transform/transform_ext/__init__.py index aba08bf0..fb705805 100644 --- a/lighthouse/dialects/transform/transform_ext/__init__.py +++ b/lighthouse/dialects/transform/transform_ext/__init__.py @@ -9,6 +9,7 @@ from .ops.extract_handle import extract_handle from .ops.get_tileable_consumers import get_tileable_consumers from .ops.get_tiling_sizes import get_tiling_sizes +from .ops.update_address_space import update_address_space __all__ = [ "TransformExtensionDialect", @@ -22,4 +23,5 @@ "register_and_load", "replace", "wrap_in_benching_func", + "update_address_space", ] diff --git a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py new file mode 100644 index 00000000..e0752f34 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py @@ -0,0 +1,102 @@ +from mlir import ir +from mlir.dialects import ext, transform, memref +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect + + +class UpdateAddressSpace(TransformExtensionDialect.Operation, name="update_address_space"): + """Update the address space of a memref allocation operation. + + Takes a target memref allocation operation and updates its address space + to the provided value. + """ + + target: ext.Operand[transform.AnyOpType] + address_space: ir.IntegerAttr + updated_op: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True) + + @classmethod + def attach_interface_impls(cls, ctx=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "UpdateAddressSpace", + rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + # Get the target operations to transform + target_ops = state.get_payload_ops(op.target) + + # Get the address space value from the attribute + address_space_value = ir.IntegerAttr(op.address_space).value + + new_ops = [] + + for target_op in target_ops: + # Verify this is a memref.alloca operation + if target_op.OPERATION_NAME != "memref.alloca": + return DiagnosedSilenceableFailure.emit_silenceable_error( + f"Expected memref.alloca operation, got {target_op.OPERATION_NAME}" + ) + + # Get the current result type (should be a MemRefType) + old_result_type = target_op.results[0].type + + memref_type = ir.MemRefType(old_result_type) + + # Create a new memref type with the specified address space + new_memref_type = ir.MemRefType.get( + memref_type.shape, + memref_type.element_type, + layout=memref_type.layout, + memory_space=ir.Attribute.parse(f"{address_space_value}") + ) + print(new_memref_type) + + # Replace the operation with a new one that has the updated type + with ir.InsertionPoint(target_op): + + # Get the operands from the original alloca (dynamic sizes and symbols) + dynamic_sizes = list(target_op.operands[:target_op.attributes["operandSegmentSizes"][0]]) + symbol_operands = list(target_op.operands[target_op.attributes["operandSegmentSizes"][0]:]) + + # Create a new alloca with the updated type + new_alloca = memref.alloca(new_memref_type, dynamic_sizes, symbol_operands) + print(new_alloca) + + # Replace all uses of the old operation with the new one + # rewriter.replace_all_uses_with(target_op.results[0], new_alloca.results[0]) + + # Erase the old operation + rewriter.replace_op(target_op, [new_alloca]) + + new_ops.append(new_alloca.owner) + + # Set the results to the new operations + results.set_ops(op.updated_op, new_ops) + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "UpdateAddressSpace") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: ir.Operation, effects): + transform.consumes_handle(op.op_operands[:1], effects) + transform.produces_handle(op.results, effects) + transform.modifies_payload(effects) + + +def update_address_space( + target: ir.Value, + address_space: int | ir.IntegerAttr, +) -> ir.Value: + if not isinstance(address_space, ir.IntegerAttr): + address_space = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), address_space) + return UpdateAddressSpace(target, address_space=address_space).updated_op From 32a345a87913e8f6a4ad982752dbd1fa234a03e2 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 16 Apr 2026 23:16:02 +0000 Subject: [PATCH 30/63] save work --- .../transform/transform_ext/__init__.py | 2 +- .../transform_ext/ops/update_address_space.py | 41 ++++++++++--------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/__init__.py b/lighthouse/dialects/transform/transform_ext/__init__.py index fb705805..997522a2 100644 --- a/lighthouse/dialects/transform/transform_ext/__init__.py +++ b/lighthouse/dialects/transform/transform_ext/__init__.py @@ -22,6 +22,6 @@ "param_cmp_eq", "register_and_load", "replace", - "wrap_in_benching_func", "update_address_space", + "wrap_in_benching_func", ] diff --git a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py index e0752f34..09313647 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py +++ b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py @@ -5,7 +5,9 @@ from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect -class UpdateAddressSpace(TransformExtensionDialect.Operation, name="update_address_space"): +class UpdateAddressSpace( + TransformExtensionDialect.Operation, name="update_address_space" +): """Update the address space of a memref allocation operation. Takes a target memref allocation operation and updates its address space @@ -31,10 +33,8 @@ def apply( ) -> DiagnosedSilenceableFailure: # Get the target operations to transform target_ops = state.get_payload_ops(op.target) - # Get the address space value from the attribute address_space_value = ir.IntegerAttr(op.address_space).value - new_ops = [] for target_op in target_ops: @@ -46,35 +46,36 @@ def apply( # Get the current result type (should be a MemRefType) old_result_type = target_op.results[0].type - memref_type = ir.MemRefType(old_result_type) - # Create a new memref type with the specified address space new_memref_type = ir.MemRefType.get( memref_type.shape, memref_type.element_type, layout=memref_type.layout, - memory_space=ir.Attribute.parse(f"{address_space_value}") + memory_space=ir.Attribute.parse(f"{address_space_value}"), ) - print(new_memref_type) # Replace the operation with a new one that has the updated type with ir.InsertionPoint(target_op): - # Get the operands from the original alloca (dynamic sizes and symbols) - dynamic_sizes = list(target_op.operands[:target_op.attributes["operandSegmentSizes"][0]]) - symbol_operands = list(target_op.operands[target_op.attributes["operandSegmentSizes"][0]:]) - + dynamic_sizes = list( + target_op.operands[ + : target_op.attributes["operandSegmentSizes"][0] + ] + ) + symbol_operands = list( + target_op.operands[ + target_op.attributes["operandSegmentSizes"][0] : + ] + ) # Create a new alloca with the updated type - new_alloca = memref.alloca(new_memref_type, dynamic_sizes, symbol_operands) - print(new_alloca) - + new_alloca = memref.alloca( + new_memref_type, dynamic_sizes, symbol_operands + ) # Replace all uses of the old operation with the new one - # rewriter.replace_all_uses_with(target_op.results[0], new_alloca.results[0]) - - # Erase the old operation + # FIXME: This won't handle operations that consume the memref type and + # return a new memref (such as subview). rewriter.replace_op(target_op, [new_alloca]) - new_ops.append(new_alloca.owner) # Set the results to the new operations @@ -98,5 +99,7 @@ def update_address_space( address_space: int | ir.IntegerAttr, ) -> ir.Value: if not isinstance(address_space, ir.IntegerAttr): - address_space = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), address_space) + address_space = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), address_space + ) return UpdateAddressSpace(target, address_space=address_space).updated_op From eacc9d876840c7c2d27597fbfd97d605e6ac4319 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 17 Apr 2026 00:10:42 +0000 Subject: [PATCH 31/63] save work --- .../transform_ext/ops/update_address_space.py | 73 +++++++++---------- lighthouse/schedule/xegpu/softmax_schedule.py | 31 ++------ 2 files changed, 42 insertions(+), 62 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py index 09313647..2c40bfce 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py +++ b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py @@ -32,51 +32,46 @@ def apply( state: transform.TransformState, ) -> DiagnosedSilenceableFailure: # Get the target operations to transform - target_ops = state.get_payload_ops(op.target) + target_op = state.get_payload_ops(op.target)[0] # Get the address space value from the attribute address_space_value = ir.IntegerAttr(op.address_space).value new_ops = [] - for target_op in target_ops: - # Verify this is a memref.alloca operation - if target_op.OPERATION_NAME != "memref.alloca": - return DiagnosedSilenceableFailure.emit_silenceable_error( - f"Expected memref.alloca operation, got {target_op.OPERATION_NAME}" - ) - - # Get the current result type (should be a MemRefType) - old_result_type = target_op.results[0].type - memref_type = ir.MemRefType(old_result_type) - # Create a new memref type with the specified address space - new_memref_type = ir.MemRefType.get( - memref_type.shape, - memref_type.element_type, - layout=memref_type.layout, - memory_space=ir.Attribute.parse(f"{address_space_value}"), + # Verify this is a memref.alloca operation + if target_op.OPERATION_NAME != "memref.alloca": + return DiagnosedSilenceableFailure.emit_silenceable_error( + f"Expected memref.alloca operation, got {target_op.OPERATION_NAME}" ) - # Replace the operation with a new one that has the updated type - with ir.InsertionPoint(target_op): - # Get the operands from the original alloca (dynamic sizes and symbols) - dynamic_sizes = list( - target_op.operands[ - : target_op.attributes["operandSegmentSizes"][0] - ] - ) - symbol_operands = list( - target_op.operands[ - target_op.attributes["operandSegmentSizes"][0] : - ] - ) - # Create a new alloca with the updated type - new_alloca = memref.alloca( - new_memref_type, dynamic_sizes, symbol_operands - ) - # Replace all uses of the old operation with the new one - # FIXME: This won't handle operations that consume the memref type and - # return a new memref (such as subview). - rewriter.replace_op(target_op, [new_alloca]) - new_ops.append(new_alloca.owner) + # Get the current result type (should be a MemRefType) + old_result_type = target_op.results[0].type + memref_type = ir.MemRefType(old_result_type) + # Create a new memref type with the specified address space + new_memref_type = ir.MemRefType.get( + memref_type.shape, + memref_type.element_type, + layout=memref_type.layout, + memory_space=ir.Attribute.parse(f"{address_space_value}"), + ) + + # Replace the operation with a new one that has the updated type + with ir.InsertionPoint(target_op): + # Get the operands from the original alloca (dynamic sizes and symbols) + dynamic_sizes = list( + target_op.operands[: target_op.attributes["operandSegmentSizes"][0]] + ) + symbol_operands = list( + target_op.operands[target_op.attributes["operandSegmentSizes"][0] :] + ) + # Create a new alloca with the updated type + new_alloca = memref.alloca( + new_memref_type, dynamic_sizes, symbol_operands + ) + # Replace all uses of the old operation with the new one + # FIXME: This won't handle operations that consume the memref type and + # return a new memref (such as subview). + rewriter.replace_op(target_op, [new_alloca]) + new_ops.append(new_alloca.owner) # Set the results to the new operations results.set_ops(op.updated_op, new_ops) diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 7912a45b..8876d052 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -16,6 +16,7 @@ PipelineInterrupt, ) from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary +from lighthouse.dialects.transform import transform_ext def get_softmax_schedule_module( @@ -140,7 +141,6 @@ def bundle_xegpu_softmax_schedule( transform.AnyOpType.get(), func, ops=["linalg.softmax"] ) structured.structured_decompose_interface(anytype, softmax_ops) - transform.print_(target=func, name="Aftemr structured_decompose_interface") linalg_ops = match_and_split( func, ops={"linalg.generic", "linalg.fill"}, nhandles=6 @@ -157,8 +157,6 @@ def bundle_xegpu_softmax_schedule( div_op, sizes=[0, reduction_step_size] ).results - transform.print_(target=func, name="After tiling div op") - # Fuse max_center_and_exp_op into the div loop _, fused_loop = structured.structured_fuse_into_containing_op( anytype, @@ -166,9 +164,6 @@ def bundle_xegpu_softmax_schedule( producer_op=max_center_and_exp_op, containing_op=div_loop, ) - transform.print_( - target=func, name="After fusing max_center_and_exp_op into div loop" - ) # Tile the sum reduction and fuse the sub+exp producer into it _, _, _, sum_loop = structured.structured_tile_reduction_using_for( @@ -180,8 +175,6 @@ def bundle_xegpu_softmax_schedule( tile_sizes=[0, reduction_step_size], ) - transform.print_(target=func, name="After tiling sum reduction") - func = transform.get_parent_op( anytype, fused_loop, @@ -200,9 +193,6 @@ def bundle_xegpu_softmax_schedule( producer_op=max_center_and_exp_op, containing_op=sum_loop, ) - transform.print_( - target=func, name="After fusing max_center_and_exp_op into sum loop" - ) # Tile the max reduction. max_reduction = linalg_ops[0] @@ -214,7 +204,6 @@ def bundle_xegpu_softmax_schedule( target=max_reduction, tile_sizes=[0, reduction_step_size], ) - transform.print_(target=func, name="After tiling max reduction") # Cleanup after tiling and fusion transform.apply_cse(func) @@ -231,8 +220,6 @@ def bundle_xegpu_softmax_schedule( transform.apply_cse(func) canonicalize(func) - transform.print_(target=func, name="After vectorization") - if stop_at_stage == "vectorized": raise PipelineInterrupt() @@ -250,8 +237,6 @@ def bundle_xegpu_softmax_schedule( transform.apply_cse(mod) canonicalize(mod) - transform.print_(target=mod, name="After bufferization") - # promote memref.alloc to memref.alloca in payload function func = match(mod, ops={"func.func"}) func = apply_registered_pass( @@ -263,8 +248,6 @@ def bundle_xegpu_softmax_schedule( }, ) - transform.print_(target=func, name="After promoting buffers to stack") - if stop_at_stage == "bufferized": raise PipelineInterrupt() @@ -294,8 +277,6 @@ def bundle_xegpu_softmax_schedule( mod = apply_registered_pass(mod, "gpu-kernel-outlining") transform.apply_cse(mod) - transform.print_(target=mod, name="After GPU outlining") - if stop_at_stage == "gpu-outlining": raise PipelineInterrupt() @@ -306,12 +287,16 @@ def bundle_xegpu_softmax_schedule( options={"O": "3", "chip": "bmg"}, ) - # convert vector to xegpu + # for each gpu function in the gpu module, change memref.alloca address + # space to 3 (SLM) and convert vector to xegpu. gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) for gpu_mod in gpu_mod_ops: gpu_func = match(gpu_mod, ops={"gpu.func"}) - gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") - transform.apply_cse(gpu_func) + allocas = match_and_split(gpu_func, ops={"memref.alloca"}) + for alloca in allocas: + transform_ext.update_address_space(alloca, address_space=3) + # gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + # transform.apply_cse(gpu_func) # Cleanup. transform.apply_cse(mod) From 1313477172c7442e322c6bbb1cbd614c415e8172 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 20 Apr 2026 20:34:32 +0000 Subject: [PATCH 32/63] working version --- examples/xegpu/softmax.py | 4 +++- .../transform_ext/ops/update_address_space.py | 2 +- lighthouse/schedule/xegpu/softmax_schedule.py | 15 +++++++++------ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index cc130297..f75613d0 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -155,7 +155,7 @@ def parse_cli(): "--sizes", type=int, nargs=2, - default=[1024, 64], + default=[1024, 512], help="M,N matrix sizes (MxN)", ) parser.add_argument( @@ -290,6 +290,8 @@ def parse_cli(): ) if not success: raise ValueError("Result mismatch!") + else: + print("Result is correct. Proceeding to benchmark...") times = runner.benchmark( host_input_buffers=wload._initial_host_arrays, diff --git a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py index 2c40bfce..8d0b6041 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py +++ b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py @@ -16,7 +16,7 @@ class UpdateAddressSpace( target: ext.Operand[transform.AnyOpType] address_space: ir.IntegerAttr - updated_op: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True) + updated_op: ext.Result[transform.AnyOpType[()]] = ext.infer_result() @classmethod def attach_interface_impls(cls, ctx=None): diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 8876d052..862677df 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -295,8 +295,8 @@ def bundle_xegpu_softmax_schedule( allocas = match_and_split(gpu_func, ops={"memref.alloca"}) for alloca in allocas: transform_ext.update_address_space(alloca, address_space=3) - # gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") - # transform.apply_cse(gpu_func) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) # Cleanup. transform.apply_cse(mod) @@ -305,12 +305,15 @@ def bundle_xegpu_softmax_schedule( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - # Set layout attributes for xegpu.store_nd operations. - # FIXME: currently ecah subgroup is handling the entire row. - store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=5) + # Set layout attributes for xegpu.store_nd and xegpu.store_matrix ops. + store_nd_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) + store_matrix_ops = match_and_split(gpu_func, ops={"xegpu.store_matrix"}, nhandles=4) sg_layout = [parameters["sg_rows"], 1] sg_data = [parameters["sg_rows"], parameters["reduction_step_size"]] - xegpu.set_anchor_layout(store_ops[-1], sg_layout=sg_layout, sg_data=sg_data) + for store_op in store_nd_ops: + xegpu.set_anchor_layout(store_op, sg_layout=sg_layout, sg_data=sg_data) + for store_op in store_matrix_ops: + xegpu.set_anchor_layout(store_op, sg_layout=sg_layout, sg_data=sg_data) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From f1857aabfd159b7ab30d9a3b5fb3217efff781c9 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 20 Apr 2026 20:36:54 +0000 Subject: [PATCH 33/63] working version --- docs/softmax_lowering.md | 635 --------------------------------------- 1 file changed, 635 deletions(-) delete mode 100644 docs/softmax_lowering.md diff --git a/docs/softmax_lowering.md b/docs/softmax_lowering.md deleted file mode 100644 index a31e6f4c..00000000 --- a/docs/softmax_lowering.md +++ /dev/null @@ -1,635 +0,0 @@ -# Linalg softmax lowering to XeGPU (Currently supported in lighthouse) - -## Overview - -**Assumptions:** -Softmax dimension size is small (64 in this example). - -The lowering process consists of seven stages: -1. **initial** - High-level tensor operations -2. **tiled-softmax** - Tiled softmax operations -3. **decomposed** - Decomposition into constituent operations -4. **vectorized** - Vector operations -5. **bufferized** - Memory-based representation -6. **xegpu-initial** - GPU kernel with XeGPU operations -7. **xegpu-wg** - Work-group optimized XeGPU - ---- - -## Stage 1: Initial - -**Code:** -```mlir -func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - // ... - %2 = tensor.empty() : tensor<1024x64xf32> - %3 = linalg.softmax dimension(1) ins(%1 : tensor<1024x64xf32>) - outs(%2 : tensor<1024x64xf32>) -> tensor<1024x64xf32> - // ... - return -} -``` ---- - -## Stage 2: Tiled Softmax - -**Notes** -- Work distribution via `scf.forall` (16 parallel iterations) -- Each tile processes 64x64 elements - -**Code:** -```mlir -func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - // ... - %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { - %4 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2) - // Extract 64x64 input slice - %extracted_slice = tensor.extract_slice ... - // Extract 64x64 output slice - %extracted_slice_0 = tensor.extract_slice ... - // Apply softmax to the tile - %5 = linalg.softmax dimension(1) ins(%extracted_slice : tensor<64x64xf32>) - outs(%extracted_slice_0 : tensor<64x64xf32>) -> tensor<64x64xf32> - scf.forall.in_parallel { - tensor.parallel_insert_slice %5 into %arg3[%4, %c0] [64, 64] [1, 1] : - tensor<64x64xf32> into tensor<1024x64xf32> - } - } - // ... - return -} -``` - ---- - -## Stage 3: Decomposed - -**Notes** -- Softmax decomposed into 4 constituent `linalg.generic` ops : max, sub+exp, sum, divide -- Uses `structured.structured_decompose_interface` implemented by `linalg.softmax` - -**Code:** -```mlir -func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - // ... - - %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x64xf32>) { - %3 = affine.apply #map(%arg2) // %3 = %arg2 * 64 - %extracted_slice = tensor.extract_slice ... - - // Step 1: Find max along dimension 1 - %4 = tensor.empty() : tensor<64xf32> - %5 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<64xf32>) -> tensor<64xf32> - %6 = linalg.generic // ... - %11 = arith.maxnumf %in, %out : f32 - // ... - } -> tensor<64xf32> - - // Step 2: Subtract max and exponentiate - %7 = linalg.generic // ... - %11 = arith.subf %in, %in_2 : f32 - %12 = math.exp %11 : f32 - // ... - } -> tensor<64x64xf32> - - // Step 3: Sum exponentials - %8 = linalg.fill ins(%cst : f32) outs(%4 : tensor<64xf32>) -> tensor<64xf32> - %9 = linalg.generic // ... - %11 = arith.addf %in, %out : f32 - // ... - } -> tensor<64xf32> - - // Step 4: Normalize by sum - %10 = linalg.generic // ... - %11 = arith.divf %in, %in_2 : f32 - // ... - } -> tensor<64x64xf32> - - scf.forall.in_parallel { - tensor.parallel_insert_slice %10 into %arg3[%3, 0] [64, 64] [1, 1] : - tensor<64x64xf32> into tensor<1024x64xf32> - } - } - return -} -``` - ---- - -## Stage 4: Vectorized - -**Notes** -- `linalg.generic` operations replaced with vector operations -- Vector transfers for reading/writing data - -**Code:** -```mlir -func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - // ... - %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { - %4 = affine.apply #map(%arg2) // %4 = %arg2 * 64 - %extracted_slice = tensor.extract_slice .. - - // Vector read: Load 64x64 tile - %5 = vector.transfer_read %1[%4, %c0], %0 {in_bounds = [true, true]} : - tensor<1024x64xf32>, vector<64x64xf32> - - // Max reduction: Reduce dimension 1 -> vector<64xf32> - %6 = vector.multi_reduction , %5, %cst_0 [1] : - vector<64x64xf32> to vector<64xf32> - - // Broadcast max values back to 64x64 and transpose - %7 = vector.broadcast %6 : vector<64xf32> to vector<64x64xf32> - %8 = vector.transpose %7, [1, 0] : vector<64x64xf32> to vector<64x64xf32> - - // Subtract max and exponentiate - %9 = arith.subf %5, %8 : vector<64x64xf32> - %10 = math.exp %9 : vector<64x64xf32> - - // Sum reduction: Reduce dimension 1 -> vector<64xf32> - %11 = vector.multi_reduction , %10, %cst [1] : - vector<64x64xf32> to vector<64xf32> - - // Broadcast sums back to 64x64 and transpose - %12 = vector.broadcast %11 : vector<64xf32> to vector<64x64xf32> - %13 = vector.transpose %12, [1, 0] : vector<64x64xf32> to vector<64x64xf32> - - // Normalize - %14 = arith.divf %10, %13 : vector<64x64xf32> - - // Vector write - %15 = vector.transfer_write %14, %extracted_slice[%c0, %c0] {in_bounds = [true, true]} : - vector<64x64xf32>, tensor<64x64xf32> - - scf.forall.in_parallel { - tensor.parallel_insert_slice %15 into %arg3[%4, 0] [64, 64] [1, 1] - } - } - return -} -``` ---- - -## Stage 5: Bufferized - -**Notes** -- Tensors eliminated, working directly with memrefs - -**Code:** -```mlir -func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - // ... - - scf.forall (%arg2) in (16) { - %1 = affine.apply #map(%arg2) // %1 = %arg2 * 64 - - // Direct memref read - %2 = vector.transfer_read %arg1[%1, %c0], %0 {in_bounds = [true, true]} : - memref<1024x64xf32>, vector<64x64xf32> - - // Max reduction - %3 = vector.multi_reduction , %2, %cst_0 [1] : - vector<64x64xf32> to vector<64xf32> - %4 = vector.broadcast %3 : vector<64xf32> to vector<64x64xf32> - %5 = vector.transpose %4, [1, 0] : vector<64x64xf32> to vector<64x64xf32> - - // Subtract and exp - %6 = arith.subf %2, %5 : vector<64x64xf32> - %7 = math.exp %6 : vector<64x64xf32> - - // Sum reduction - %8 = vector.multi_reduction , %7, %cst [1] : - vector<64x64xf32> to vector<64xf32> - %9 = vector.broadcast %8 : vector<64xf32> to vector<64x64xf32> - %10 = vector.transpose %9, [1, 0] : vector<64x64xf32> to vector<64x64xf32> - - // Normalize - %11 = arith.divf %7, %10 : vector<64x64xf32> - - // Direct memref write - vector.transfer_write %11, %arg0[%1, %c0] {in_bounds = [true, true]} : - vector<64x64xf32>, memref<1024x64xf32> - } - return -} -``` - ---- - -## Stage 6: XeGPU-Initial - -**Notes** -- GPU kernel separated from host code (Gpu Outlining) -- `gpu.launch_func` invocation with grid/block dimensions -- Use `vector-to-xegpu` - -**Code:** - -**Host Side:** -```mlir -func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - // ... - gpu.launch_func @payload_kernel::@payload_kernel - blocks in (%c16, %c1, %c1) - threads in (%c128, %c1, %c1) - args(%arg1 : memref<1024x64xf32>, %arg0 : memref<1024x64xf32>) - return -} -``` - -**GPU Kernel:** -```mlir -gpu.module @payload_kernel [#xevm.target] { - gpu.func @payload_kernel(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) kernel - attributes {known_block_size = array, - known_grid_size = array} { - // ... - %block_id_x = gpu.block_id x - %0 = arith.muli %block_id_x, %c64 overflow : index - - // Create XeGPU tensor descriptor for load - %1 = xegpu.create_nd_tdesc %arg0 : memref<1024x64xf32> -> - !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> - - // XeGPU block load - %2 = xegpu.load_nd %1[%0, 0] : - !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> -> - vector<64x64xf32> - - // Same compute operations as before - %3 = vector.multi_reduction , %2, %cst_0 [1] : - vector<64x64xf32> to vector<64xf32> - %4 = vector.broadcast %3 : vector<64xf32> to vector<64x64xf32> - %5 = vector.transpose %4, [1, 0] : vector<64x64xf32> to vector<64x64xf32> - %6 = arith.subf %2, %5 : vector<64x64xf32> - %7 = math.exp %6 : vector<64x64xf32> - %8 = vector.multi_reduction , %7, %cst [1] : - vector<64x64xf32> to vector<64xf32> - %9 = vector.broadcast %8 : vector<64xf32> to vector<64x64xf32> - %10 = vector.transpose %9, [1, 0] : vector<64x64xf32> to vector<64x64xf32> - %11 = arith.divf %7, %10 : vector<64x64xf32> - - // Create XeGPU tensor descriptor for store - %12 = xegpu.create_nd_tdesc %arg1 : memref<1024x64xf32> -> - !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> - - // XeGPU block store - xegpu.store_nd %11, %12[%0, 0] : - vector<64x64xf32>, - !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> - - gpu.return - } -} -``` - ---- - -## Stage 7: XeGPU-WG (Work-Group Optimized) - -**Notes** -- Sets the layout for anchor xegpu ops. Each Wg consistes of [8, 1] subgroups - doing 8x64 softmax slice. -- Only sets the layotu for `store_nd`. Layout propagation does the rest. - -**Code (differences from xegpu-initial):** -```mlir -// Store operation now includes layout hints -xegpu.store_nd %11, %12[%0, 0] - <{layout = #xegpu.layout}> : - vector<64x64xf32>, - !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> -``` - ---- - -# Supporting larger Softmax dimension sizes - -When the softmax dimension is larger than what can fit efficiently in registers, additional tiling and fusion transformations are applied to the reduction dimension. This section shows the intermediate stages between "decomposed" and "vectorized". - -**Approach:** Tile reductions along dimension 1 (step size = 16) and fuse producers into consumers to enable streaming computation. - ---- - -## Decomposed → Tiled: Stage A - Tile div op - -**Notes:** -- Tile the division operation with step size 16 along dimension 1 -- Creates `scf.for` loop iterating over 64 elements in chunks of 16 - -**Key Changes:** -```mlir -// Before: Single division linalg.generic over 64x64 -%11 = linalg.generic {...} ins(%8, %10 : tensor<64x64xf32>, tensor<64xf32>) - outs(%extracted_slice_0 : tensor<64x64xf32>) { ... } -> tensor<64x64xf32> - -// After: Division tiled into 64x16 chunks -%11 = scf.for %arg4 = %c0_2 to %c64 step %c16 iter_args(%arg5 = %extracted_slice_0) -> (tensor<64x64xf32>) { - %extracted_slice_3 = tensor.extract_slice %8[0, %arg4] [64, 16] [1, 1] - %12 = linalg.generic {...} ins(%extracted_slice_3, %extracted_slice_4 : tensor<64x16xf32>, tensor<64xf32>) - outs(%extracted_slice_5 : tensor<64x16xf32>) { ... } -> tensor<64x16xf32> - %inserted_slice = tensor.insert_slice %12 into %arg5[0, %arg4] [64, 16] [1, 1] - scf.yield %inserted_slice -} -``` - ---- - -## Stage B - Fuse sub+exp into div loop - -**Notes:** -- Fuse the `sub+exp` producer (max_center_and_exp_op) into the div loop -- Recomputes exp values on-the-fly instead of materializing full 64x64 tensor - -**Key Changes:** -```mlir -%11 = scf.for %arg4 = %c0_2 to %c64 step %c16 iter_args(%arg5 = %extracted_slice_0) -> (tensor<64x64xf32>) { - %extracted_slice_3 = tensor.extract_slice %extracted_slice[0, %arg4] [64, 16] [1, 1] - - // Fused: sub+exp computed per 16-element chunk - %12 = linalg.generic {...} ins(%extracted_slice_3, %extracted_slice_4 : tensor<64x16xf32>, tensor<64xf32>) - outs(%extracted_slice_5 : tensor<64x16xf32>) { - ^bb0(%in: f32, %in_8: f32, %out: f32): - %14 = arith.subf %in, %in_8 : f32 - %15 = math.exp %14 : f32 - linalg.yield %15 : f32 - } -> tensor<64x16xf32> - - // Division operation - %13 = linalg.generic {...} ins(%12, %extracted_slice_6 : tensor<64x16xf32>, tensor<64xf32>) - outs(%extracted_slice_7 : tensor<64x16xf32>) { ... } -> tensor<64x16xf32> - // ... -} -``` - ---- - -## Stage C - Tile sum reduction - -**Notes:** -- Tile the sum reduction using `structured_tile_reduction_using_for` -- Creates intermediate accumulator tensor (64x16) -- Final reduction via `linalg.reduce` over dimension 1 - -**Key Changes:** -```mlir -// Tiled sum reduction with intermediate accumulator -%10 = tensor.empty() : tensor<64x16xf32> -%11 = linalg.fill ins(%cst_2 : f32) outs(%10 : tensor<64x16xf32>) -> tensor<64x16xf32> - -%12 = scf.for %arg4 = %c0_3 to %c64 step %c16 iter_args(%arg5 = %11) -> (tensor<64x16xf32>) { - %extracted_slice_7 = tensor.extract_slice %8[0, %arg4] [64, 16] [1, 1] - %14 = linalg.generic {...} ins(%extracted_slice_7 : tensor<64x16xf32>) - outs(%extracted_slice_8 : tensor<64x16xf32>) { - ^bb0(%in: f32, %out: f32): - %15 = arith.addf %in, %out : f32 - linalg.yield %15 : f32 - } -> tensor<64x16xf32> - // ... -} - -// Final reduction to 64xf32 -%reduced = linalg.reduce ins(%12 : tensor<64x16xf32>) outs(%9 : tensor<64xf32>) dimensions = [1] - (%in: f32, %init: f32) { - %14 = arith.addf %in, %init : f32 - linalg.yield %14 : f32 - } -``` - ---- - -## Stage D - Fuse sub+exp into sum loop - -**Notes:** -- Fuse `sub+exp` into the sum reduction loop -- Stream computation: compute exp and accumulate in same loop - -**Key Changes:** -```mlir -%12 = scf.for %arg4 = %c0_3 to %c64 step %c16 iter_args(%arg5 = %11) -> (tensor<64x16xf32>) { - %extracted_slice_7 = tensor.extract_slice %extracted_slice[0, %arg4] [64, 16] [1, 1] - - // Fused: sub+exp - %14 = linalg.generic {...} ins(%extracted_slice_7, %extracted_slice_8 : tensor<64x16xf32>, tensor<64xf32>) - outs(%extracted_slice_9 : tensor<64x16xf32>) { - ^bb0(%in: f32, %in_11: f32, %out: f32): - %16 = arith.subf %in, %in_11 : f32 - %17 = math.exp %16 : f32 - linalg.yield %17 : f32 - } -> tensor<64x16xf32> - - // Accumulate sum - %15 = linalg.generic {...} ins(%14 : tensor<64x16xf32>) - outs(%extracted_slice_10 : tensor<64x16xf32>) { - ^bb0(%in: f32, %out: f32): - %16 = arith.addf %in, %out : f32 - linalg.yield %16 : f32 - } -> tensor<64x16xf32> - // ... -} -``` - ---- - -## Stage E - Tile max reduction - -**Notes:** -- Tile max reduction similar to sum reduction -- Creates 64x16 intermediate accumulator -- Final reduction via `linalg.reduce` with maxnumf - -**Key Changes:** -```mlir -// Tiled max reduction -%7 = tensor.empty() : tensor<64x16xf32> -%8 = linalg.fill ins(%cst_1 : f32) outs(%7 : tensor<64x16xf32>) -> tensor<64x16xf32> - -%9 = scf.for %arg4 = %c0_2 to %c64 step %c16 iter_args(%arg5 = %8) -> (tensor<64x16xf32>) { - %extracted_slice_12 = tensor.extract_slice %extracted_slice[0, %arg4] [64, 16] [1, 1] - %16 = linalg.generic {...} ins(%extracted_slice_12 : tensor<64x16xf32>) - outs(%extracted_slice_13 : tensor<64x16xf32>) { - ^bb0(%in: f32, %out: f32): - %17 = arith.maxnumf %in, %out : f32 - linalg.yield %17 : f32 - } -> tensor<64x16xf32> - // ... -} - -// Final max reduction -%reduced = linalg.reduce ins(%9 : tensor<64x16xf32>) outs(%6 : tensor<64xf32>) dimensions = [1] - (%in: f32, %init: f32) { - %16 = arith.maxnumf %in, %init : f32 - linalg.yield %16 : f32 - } -``` - -**Result:** Now all three major computations (max, sum, div) are tiled and operate on 64x16 chunks, with exp computation fused into both sum and div loops. - ---- - -## Stage F - Vectorization - -**Notes:** -- Convert tiled linalg operations to vector operations -- `scf.for` loops remain but operate on vectors -- Vector size: 64x16 for tiled operations - -**Code:** -```mlir -func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - // ... - %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { - // ... - - // Vectorized max reduction loop - %6 = vector.transfer_write %cst_1, %5[%c0, %c0] : vector<64x16xf32>, tensor<64x16xf32> - %7 = scf.for %arg4 = %c0 to %c64 step %c16 iter_args(%arg5 = %6) -> (tensor<64x16xf32>) { - %15 = vector.transfer_read %1[%4, %arg4], %0 : tensor<1024x64xf32>, vector<64x16xf32> - %16 = vector.transfer_read %arg5[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> - %17 = arith.maxnumf %15, %16 : vector<64x16xf32> - %18 = vector.transfer_write %17, %arg5[%c0, %c0] : vector<64x16xf32>, tensor<64x16xf32> - scf.yield %18 : tensor<64x16xf32> - } - %8 = vector.transfer_read %7[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> - %9 = vector.multi_reduction , %8, %cst_2 [1] : vector<64x16xf32> to vector<64xf32> - - // Vectorized sum reduction loop with fused sub+exp - %11 = scf.for %arg4 = %c0 to %c64 step %c16 iter_args(%arg5 = %10) -> (tensor<64x16xf32>) { - %15 = vector.transfer_read %1[%4, %arg4], %0 : tensor<1024x64xf32>, vector<64x16xf32> - %16 = vector.broadcast %9 : vector<64xf32> to vector<16x64xf32> - %17 = vector.transpose %16, [1, 0] : vector<16x64xf32> to vector<64x16xf32> - %18 = arith.subf %15, %17 : vector<64x16xf32> - %19 = math.exp %18 : vector<64x16xf32> - %20 = vector.transfer_read %arg5[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> - %21 = arith.addf %19, %20 : vector<64x16xf32> - %22 = vector.transfer_write %21, %arg5[%c0, %c0] : vector<64x16xf32>, tensor<64x16xf32> - scf.yield %22 : tensor<64x16xf32> - } - %12 = vector.transfer_read %11[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> - %13 = vector.multi_reduction , %12, %cst_0 [1] : vector<64x16xf32> to vector<64xf32> - - // Vectorized div loop with fused sub+exp - %14 = scf.for %arg4 = %c0 to %c64 step %c16 iter_args(%arg5 = %extracted_slice) -> (tensor<64x64xf32>) { - %15 = vector.transfer_read %1[%4, %arg4], %0 : tensor<1024x64xf32>, vector<64x16xf32> - %16 = vector.broadcast %9 : vector<64xf32> to vector<16x64xf32> - %17 = vector.transpose %16, [1, 0] : vector<16x64xf32> to vector<64x16xf32> - %18 = arith.subf %15, %17 : vector<64x16xf32> - %19 = math.exp %18 : vector<64x16xf32> - %20 = vector.broadcast %13 : vector<64xf32> to vector<16x64xf32> - %21 = vector.transpose %20, [1, 0] : vector<16x64xf32> to vector<64x16xf32> - %22 = arith.divf %19, %21 : vector<64x16xf32> - %23 = vector.transfer_write %22, %arg5[%c0, %arg4] : vector<64x16xf32>, tensor<64x64xf32> - scf.yield %23 : tensor<64x64xf32> - } - } - // ... -} -``` - ---- - -## Stage G - Bufferization - -**Notes:** -- Convert tensors to memrefs -- Allocate stack buffer for 64x16 accumulator: `memref.alloc()` - -**Code:** -```mlir -func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - // ... - scf.forall (%arg2) in (16) { - %1 = affine.apply #map(%arg2) - %subview = memref.subview %arg0[%1, 0] [64, 64] [1, 1] - - // Allocate accumulator buffer - %alloc = memref.alloc() {alignment = 64 : i64} : memref<64x16xf32> - - // Max reduction loop - vector.transfer_write %cst_1, %alloc[%c0, %c0] : vector<64x16xf32>, memref<64x16xf32> - scf.for %arg3 = %c0 to %c64 step %c16 { - %6 = vector.transfer_read %arg1[%1, %arg3], %0 : memref<1024x64xf32>, vector<64x16xf32> - %7 = vector.transfer_read %alloc[%c0, %c0], %0 : memref<64x16xf32>, vector<64x16xf32> - %8 = arith.maxnumf %6, %7 : vector<64x16xf32> - vector.transfer_write %8, %alloc[%c0, %c0] : vector<64x16xf32>, memref<64x16xf32> - } - %2 = vector.transfer_read %alloc[%c0, %c0], %0 : memref<64x16xf32>, vector<64x16xf32> - %3 = vector.multi_reduction , %2, %cst_2 [1] : vector<64x16xf32> to vector<64xf32> - - // Sum reduction loop (reuses %alloc) - // ... - - // Div loop (writes to %subview) - // ... - } -} -``` - ---- - -## Stage H - Promote buffers to stack - -**Notes:** -- Convert `memref.alloc()` to `memref.alloca()` for stack allocation -- Reduces memory allocation overhead - -**Code:** -```mlir -scf.forall (%arg2) in (16) { - %1 = affine.apply #map(%arg2) - %subview = memref.subview %arg0[%1, 0] [64, 64] [1, 1] - - // Stack allocation instead of heap - %alloca = memref.alloca() {alignment = 64 : i64} : memref<64x16xf32> - - // ... same operations using %alloca ... -} -``` - ---- - -## Stage I - GPU outlining - -**Notes:** -- Convert `scf.forall` to `scf.parallel`, then to `gpu.launch` -- Extract GPU kernel into separate `gpu.module` -- Set thread count: 128 threads = (64 rows / 8 sg_rows) × 16 subgroup_size - -**Host Side:** -```mlir -func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { - %c16 = arith.constant 16 : index - %c1 = arith.constant 1 : index - %c128 = arith.constant 128 : index - gpu.launch_func @payload_kernel::@payload_kernel - blocks in (%c16, %c1, %c1) - threads in (%c128, %c1, %c1) - args(%arg0 : memref<1024x64xf32>, %arg1 : memref<1024x64xf32>) - return -} -``` - -**GPU Kernel:** -```mlir -gpu.module @payload_kernel { - gpu.func @payload_kernel(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) kernel - attributes {known_block_size = array, - known_grid_size = array} { - %block_id_x = gpu.block_id x - %1 = arith.muli %block_id_x, %c64 overflow : index - %subview = memref.subview %arg0[%1, 0] [64, 64] [1, 1] - %alloca = memref.alloca() {alignment = 64 : i64} : memref<64x16xf32> - - // Three reduction loops (max, sum, div) with same structure - scf.for %arg2 = %c0 to %c64 step %c16 { - // Max: accumulate max values - // Sum: compute & accumulate exp(x - max) - // Div: compute exp(x - max) / sum - } - - gpu.return - } -} -``` - -**Summary:** At this stage, the kernel processes 64x16 chunks in streaming fashion through three sequential loops, minimizing memory footprint. From 240cf084338baf9f30ebe2315d09a489170530d9 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 20 Apr 2026 22:27:53 +0000 Subject: [PATCH 34/63] add initial version --- examples/xegpu/fused_attention.py | 339 ++++++++++++++++++ .../mlir_gen/gpu_fused_attention_payload.py | 69 ++++ .../xegpu/fused_attention_schedule.py | 164 +++++++++ 3 files changed, 572 insertions(+) create mode 100644 examples/xegpu/fused_attention.py create mode 100644 lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py create mode 100644 lighthouse/schedule/xegpu/fused_attention_schedule.py diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py new file mode 100644 index 00000000..5aa84e2a --- /dev/null +++ b/examples/xegpu/fused_attention.py @@ -0,0 +1,339 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + +""" +XeGPU fused attention benchmark. +""" + +import argparse +from typing import Optional +from functools import cached_property + +import numpy as np +from mlir import ir + +from lighthouse import dialects as lh_dialects +from lighthouse.execution.runner import Runner +from lighthouse.pipeline.driver import TransformDriver +from lighthouse.execution import GPUMemoryManager +from lighthouse.utils.numpy import mlir_to_numpy_dtype +from lighthouse.ingress.mlir_gen import get_mlir_elem_type +from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import generate_gpu_fused_attention_payload +from lighthouse.schedule.xegpu.fused_attention_schedule import get_fused_attention_schedule_module + + +def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: int): + """ + Complexity of fused attention operation. + + For each batch and head: + - Q @ K^T: O(n_ctx^2 * n_head) operations + - Softmax: O(n_ctx^2) operations + - Attention @ V: O(n_ctx^2 * n_head) operations + Total: approximately 2*n_ctx^2*n_head FLOPs per batch and head + """ + # Approximation: 2 * n_ctx^2 * n_head FLOPs per batch and head + flop_count = Z * H * 2 * n_ctx * n_ctx * n_head + # Memory: read Q, K, V and write output + memory_reads = 3 * Z * H * n_ctx * n_head * nbytes + memory_writes = Z * H * n_ctx * n_head * nbytes + return flop_count, memory_reads, memory_writes + + +def check_correctness( + Q: np.ndarray, K: np.ndarray, V: np.ndarray, output_arr: np.ndarray, verbose: int = 0 +) -> bool: + """ + Check correctness of fused attention output. + + Reference implementation: + - scores = Q @ K^T / sqrt(n_head) + - attention_weights = softmax(scores, dim=-1) + - output = attention_weights @ V + """ + # Use float32 for computation + Q_f32 = Q.astype(np.float32) + K_f32 = K.astype(np.float32) + V_f32 = V.astype(np.float32) + + Z, H, n_ctx, n_head = Q.shape + scale = 1.0 / np.sqrt(n_head) + + output_ref = np.zeros_like(Q_f32) + + # Compute reference for each batch and head + for z in range(Z): + for h in range(H): + # scores = Q @ K^T / sqrt(n_head) + scores = Q_f32[z, h] @ K_f32[z, h].T * scale + + # softmax along last dimension + max_vals = np.max(scores, axis=1, keepdims=True) + exp_vals = np.exp(scores - max_vals) + sum_vals = np.sum(exp_vals, axis=1, keepdims=True) + attention_weights = exp_vals / sum_vals + + # output = attention_weights @ V + output_ref[z, h] = attention_weights @ V_f32[z, h] + + output = output_arr.astype(np.float32) + + if verbose > 1: + print("Reference solution (first batch, first head, first 5 rows):") + print(output_ref[0, 0, :5]) + print("Computed solution (first batch, first head, first 5 rows):") + print(output[0, 0, :5]) + + # Check values match reference + values_ok = np.allclose(output, output_ref, rtol=1e-3, atol=1e-4) + + success = values_ok + + if verbose: + if success: + print("PASSED") + else: + print("FAILED!") + if not values_ok: + max_diff = np.abs(output - output_ref).max() + print(f" Values mismatch. Max abs diff: {max_diff:.6e}") + return success + + +class XeGPUFusedAttention: + """ + Fused attention workload on XeGPU. + + Computes fused attention: + output = softmax(Q @ K^T / sqrt(n_head)) @ V + + All Q, K, V matrices have shape (Z, H, n_ctx, n_head) where: + - Z: batch size + - H: number of heads + - n_ctx: context length + - n_head: head dimension + """ + + def __init__( + self, + Z: int, + H: int, + n_ctx: int, + n_head: int, + dtype: str = "f32", + ): + self.Z = Z + self.H = H + self.n_ctx = n_ctx + self.n_head = n_head + self.shape = (Z, H, n_ctx, n_head) + assert dtype == "f32", "Only f32 type is supported for fused attention" + self.elem_type = get_mlir_elem_type(dtype) + self.dtype = mlir_to_numpy_dtype(self.elem_type) + self.memory_manager_class = GPUMemoryManager + self.payload_function_name = "payload" + + @cached_property + def _initial_host_arrays(self) -> tuple[np.ndarray]: + """Generate initial values on host with numpy.""" + np.random.seed(42) + # Initialize Q, K, V with small random values + Q = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + K = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + V = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + output_arr = np.zeros(self.shape, dtype=self.dtype) + return (output_arr, Q, K, V) + + def get_complexity(self) -> tuple[int, int, int]: + nbytes = np.dtype(self.dtype).itemsize + return fused_attention_complexity(self.Z, self.H, self.n_ctx, self.n_head, nbytes) + + def payload_module(self) -> ir.Module: + """Generate MLIR module for fused attention payload.""" + return generate_gpu_fused_attention_payload( + func_name=self.payload_function_name, + Z=self.Z, + H=self.H, + n_ctx=self.n_ctx, + n_head=self.n_head, + dtype=self.elem_type, + ) + + def schedule_modules( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> list[ir.Module]: + """Generate transform schedule for fused attention.""" + return [ + Runner.get_bench_wrapper_schedule(self.payload_function_name), + get_fused_attention_schedule_module( + stop_at_stage=stop_at_stage, + parameters=parameters, + ), + ] + + def shared_libs(self) -> list[str]: + return ["libmlir_levelzero_runtime.so"] + + +def parse_cli(): + parser = argparse.ArgumentParser( + description="Fused Attention using MLIR XeGPU", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--batch-size", + type=int, + default=2, + help="Batch size (Z)", + ) + parser.add_argument( + "--num-heads", + type=int, + default=8, + help="Number of attention heads (H)", + ) + parser.add_argument( + "--n-ctx", + type=int, + default=512, + help="Context length (sequence length)", + ) + parser.add_argument( + "--n-head", + type=int, + default=64, + help="Head dimension", + ) + parser.add_argument( + "--nruns", + type=int, + default=1000, + help="Number of runs to average the execution time.", + ) + parser.add_argument( + "--nwarmup", + type=int, + default=20, + help="Number of warm-up iterations before benchmarking.", + ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the fused attention computation.", + ) + parser.add_argument( + "--dump-kernel", + type=str, + choices=[ + "initial", + "tiled", + "vectorized", + "bufferized", + "gpu-outlining", + "xegpu-initial", + "xegpu-wg", + "final", + ], + help="Dump kernel IR at different stages of lowering and exit without " + "executing the kernel.", + ) + parser.add_argument( + "--dump-schedule", + action="store_true", + help="Dump transform schedule.", + ) + parser.add_argument( + "--verbose", + "-v", + action="count", + default=0, + help="Increase output verbosity (e.g. print reference and computed solutions).", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_cli() + + params = { + "batch_size": args.batch_size, + "num_heads": args.num_heads, + "n_ctx": args.n_ctx, + "n_head": args.n_head, + } + + Z = args.batch_size + H = args.num_heads + n_ctx = args.n_ctx + n_head = args.n_head + dtype = "f32" + + with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUFusedAttention(Z=Z, H=H, n_ctx=n_ctx, n_head=n_head, dtype=dtype) + + if args.dump_kernel or args.dump_schedule: + pipeline = TransformDriver( + wload.schedule_modules( + stop_at_stage=args.dump_kernel, parameters=params + ) + ) + payload = pipeline.apply(wload.payload_module()) + if args.dump_kernel: + print(payload) + if args.dump_schedule: + for schedule_module in wload.schedule_modules(parameters=params): + print(schedule_module) + else: + pipeline = TransformDriver(wload.schedule_modules(parameters=params)) + payload = pipeline.apply(wload.payload_module()) + runner = Runner( + payload, + mem_manager_cls=wload.memory_manager_class, + shared_libs=wload.shared_libs(), + ) + if args.check_result: + # Setup callback function to copy result from device to host. + result_host_copy, argument_access_callback = ( + Runner.get_gpu_argument_access_callback(wload.shape, wload.dtype) + ) + + # Execute kernel once. + runner.execute( + host_input_buffers=wload._initial_host_arrays, + payload_function_name=wload.payload_function_name, + argument_access_callback=argument_access_callback, + ) + + # Compute reference solution on host. + Q, K, V = wload._initial_host_arrays[1:4] + success = check_correctness( + Q, K, V, + result_host_copy, + verbose=args.verbose, + ) + if not success: + raise ValueError("Result mismatch!") + else: + print("Result is correct. Proceeding to benchmark...") + + times = runner.benchmark( + host_input_buffers=wload._initial_host_arrays, + nruns=args.nruns, + nwarmup=args.nwarmup, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + print( + f"batch-size={Z} " + f"num-heads={H} " + f"n-ctx={n_ctx} " + f"n-head={n_head} " + f"dt={dtype} " + f"time(us): {elapsed:.2f} " + f"GFLOPS: {gflops:.2f} " + ) diff --git a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py new file mode 100644 index 00000000..788aa175 --- /dev/null +++ b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py @@ -0,0 +1,69 @@ +"""Generate MLIR payload for GPU fused attention operation.""" + +from mlir import ir +from mlir.dialects import linalg, bufferization, tensor, arith + +from lighthouse.utils.mlir import func_cif +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs +from lighthouse.ingress.mlir_gen.utils import emit_buf_to_tensor + + +def generate_gpu_fused_attention_payload( + func_name: str, + Z: int, + H: int, + n_ctx: int, + n_head: int, + dtype: ir.Type, +) -> ir.Module: + """ + Generate MLIR module for fused attention payload. + + Computes fused attention: + output = softmax(Q @ K^T / sqrt(n_head)) @ V + + Args: + func_name: Name of the payload function + Z: Batch size + H: Number of attention heads + n_ctx: Context length (sequence length) + n_head: Head dimension + dtype: MLIR element type (e.g., F32Type) + + Returns: + MLIR module containing the fused attention payload function + """ + mod = ir.Module.create() + shape = (Z, H, n_ctx, n_head) + memref_t = ir.MemRefType.get(shape, dtype) + + with ir.InsertionPoint(mod.body): + # Function signature: payload(output, Q, K, V) + @func_cif(memref_t, memref_t, memref_t, memref_t, name=func_name) + def payload(output, Q_arg, K_arg, V_arg): + # Convert memrefs to tensors + emit_buf_to_tensor(output, restrict=True, writable=True) + Q_tensor = emit_buf_to_tensor(Q_arg, restrict=True) + K_tensor = emit_buf_to_tensor(K_arg, restrict=True) + V_tensor = emit_buf_to_tensor(V_arg, restrict=True) + + # TODO: Implement fused attention computation + # This will involve: + # 1. Q @ K^T (batch matmul with transpose) + # 2. Scale by 1/sqrt(n_head) + # 3. Softmax along last dimension + # 4. Result @ V (batch matmul) + + # Placeholder: create empty output tensor + output_init = tensor.empty(shape, dtype) + result = output_init + + # Materialize result back to output memref + bufferization.materialize_in_destination( + None, result, output, restrict=True, writable=True + ) + + # Emit utility functions for GPU memory management + emit_gpu_util_funcs(dtype, rank=4) + + return mod diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py new file mode 100644 index 00000000..a05e2ca9 --- /dev/null +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -0,0 +1,164 @@ +"""Generate MLIR transform schedule for XeGPU fused attention operation.""" + +from typing import Optional + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured, loop, xegpu +from mlir.dialects.transform import bufferization as transform_bufferization +from mlir.dialects.bufferization import LayoutMapOption + +from lighthouse.pipeline.helper import ( + apply_registered_pass, + canonicalize, + match, + match_and_split, + PipelineInterrupt, +) +from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary +from lighthouse.dialects.transform import transform_ext + + +def get_fused_attention_schedule_module( + stop_at_stage: Optional[str] = None, + parameters: Optional[dict] = None, +) -> ir.Module: + """ + Generate transform schedule for fused attention operation. + + The schedule performs the following transformations: + 1. Tile the fused attention operation + 2. Vectorize operations + 3. Bufferize tensors + 4. Convert to GPU dialect + 5. Lower to XeGPU operations + + Args: + stop_at_stage: Optional stage name to stop early (for debugging) + parameters: Dictionary with scheduling parameters: + - batch_size: Batch size (Z) + - num_heads: Number of attention heads (H) + - n_ctx: Context length + - n_head: Head dimension + + Returns: + MLIR module containing the transform schedule + """ + assert parameters is not None, "Schedule parameters must be provided" + + mod = ir.Module.create() + mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + + with ir.InsertionPoint(mod.body): + # Create a transform sequence with proper signature + named_sequence = transform.named_sequence( + "__transform_main", + [transform.AnyOpType.get()], # input: module + [], # no outputs + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + + with ir.InsertionPoint(named_sequence.body): + # match the payload module + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + + xegpu_fused_attention_transform_schedule( + payload_mod, + parameters=parameters, + stop_at_stage=stop_at_stage or "", + ) + + return mod + + +def xegpu_fused_attention_transform_schedule( + mod: ir.Value[transform.AnyOpType], + parameters: dict, + stop_at_stage: str = "", +): + """Transform schedule for fused attention payload.""" + try: + mod = bundle_xegpu_fused_attention_schedule( + mod, + parameters=parameters, + stop_at_stage=stop_at_stage, + ) + + mod = bundle_xegpu_to_binary( + mod, + stop_at_stage=stop_at_stage, + ) + except PipelineInterrupt: + pass + finally: + transform.yield_() + + +def bundle_xegpu_fused_attention_schedule( + mod: ir.Value[transform.AnyOpType], + parameters: dict, + stop_at_stage: str = "", +) -> ir.Value[transform.AnyOpType]: + """Schedule for lowering fused attention payload to xegpu wg level.""" + + if stop_at_stage == "initial": + raise PipelineInterrupt() + + anytype = transform.AnyOpType.get() + + # TODO: Implement tiling, fusion, and lowering for fused attention + # This will involve: + # 1. Matching and tiling matmul operations (Q @ K^T) + # 2. Fusing softmax operation + # 3. Tiling second matmul (attention @ V) + # 4. Vectorization + # 5. Bufferization + # 6. GPU outlining + # 7. XeGPU lowering + + func = match(mod, ops={"func.func"}) + + if stop_at_stage == "tiled": + raise PipelineInterrupt() + + # vectorize (placeholder) + # func = structured.VectorizeChildrenAndApplyPatternsOp( + # func, + # fold_type_extensions_into_contract=True, + # ).result + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + # bufferize (placeholder) + # mod = apply_registered_pass(mod, "eliminate-empty-tensors") + # identity_layout = LayoutMapOption.IdentityLayoutMap + # mod = transform_bufferization.OneShotBufferizeOp( + # mod, + # allow_return_allocs_from_loops=True, + # bufferize_function_boundaries=True, + # function_boundary_type_conversion=identity_layout, + # ).result + + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + + if stop_at_stage == "xegpu-initial": + raise PipelineInterrupt() + + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + + return mod From 3262e4a58ee3e7756b3098033815f3a719041407 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 20 Apr 2026 22:28:04 +0000 Subject: [PATCH 35/63] add initial version --- examples/xegpu/fused_attention.py | 22 ++++++++++++++----- .../mlir_gen/gpu_fused_attention_payload.py | 2 +- .../xegpu/fused_attention_schedule.py | 6 ----- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py index 5aa84e2a..c42faf0e 100644 --- a/examples/xegpu/fused_attention.py +++ b/examples/xegpu/fused_attention.py @@ -18,8 +18,12 @@ from lighthouse.execution import GPUMemoryManager from lighthouse.utils.numpy import mlir_to_numpy_dtype from lighthouse.ingress.mlir_gen import get_mlir_elem_type -from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import generate_gpu_fused_attention_payload -from lighthouse.schedule.xegpu.fused_attention_schedule import get_fused_attention_schedule_module +from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import ( + generate_gpu_fused_attention_payload, +) +from lighthouse.schedule.xegpu.fused_attention_schedule import ( + get_fused_attention_schedule_module, +) def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: int): @@ -41,7 +45,11 @@ def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: def check_correctness( - Q: np.ndarray, K: np.ndarray, V: np.ndarray, output_arr: np.ndarray, verbose: int = 0 + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + output_arr: np.ndarray, + verbose: int = 0, ) -> bool: """ Check correctness of fused attention output. @@ -146,7 +154,9 @@ def _initial_host_arrays(self) -> tuple[np.ndarray]: def get_complexity(self) -> tuple[int, int, int]: nbytes = np.dtype(self.dtype).itemsize - return fused_attention_complexity(self.Z, self.H, self.n_ctx, self.n_head, nbytes) + return fused_attention_complexity( + self.Z, self.H, self.n_ctx, self.n_head, nbytes + ) def payload_module(self) -> ir.Module: """Generate MLIR module for fused attention payload.""" @@ -309,7 +319,9 @@ def parse_cli(): # Compute reference solution on host. Q, K, V = wload._initial_host_arrays[1:4] success = check_correctness( - Q, K, V, + Q, + K, + V, result_host_copy, verbose=args.verbose, ) diff --git a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py index 788aa175..73046604 100644 --- a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py @@ -1,7 +1,7 @@ """Generate MLIR payload for GPU fused attention operation.""" from mlir import ir -from mlir.dialects import linalg, bufferization, tensor, arith +from mlir.dialects import bufferization, tensor from lighthouse.utils.mlir import func_cif from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index a05e2ca9..5fa47770 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -4,19 +4,13 @@ from mlir import ir from mlir.dialects import transform -from mlir.dialects.transform import structured, loop, xegpu -from mlir.dialects.transform import bufferization as transform_bufferization -from mlir.dialects.bufferization import LayoutMapOption from lighthouse.pipeline.helper import ( - apply_registered_pass, canonicalize, match, - match_and_split, PipelineInterrupt, ) from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary -from lighthouse.dialects.transform import transform_ext def get_fused_attention_schedule_module( From 2135de3537ed55cca6da480ec38fcfa82cf83dd7 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 21 Apr 2026 20:59:59 +0000 Subject: [PATCH 36/63] payload done --- .../mlir_gen/gpu_fused_attention_payload.py | 90 ++++++++++++++++--- 1 file changed, 80 insertions(+), 10 deletions(-) diff --git a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py index 73046604..87818d08 100644 --- a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py @@ -1,7 +1,9 @@ """Generate MLIR payload for GPU fused attention operation.""" +import math + from mlir import ir -from mlir.dialects import bufferization, tensor +from mlir.dialects import arith, bufferization, linalg, tensor from lighthouse.utils.mlir import func_cif from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs @@ -47,16 +49,84 @@ def payload(output, Q_arg, K_arg, V_arg): K_tensor = emit_buf_to_tensor(K_arg, restrict=True) V_tensor = emit_buf_to_tensor(V_arg, restrict=True) - # TODO: Implement fused attention computation - # This will involve: - # 1. Q @ K^T (batch matmul with transpose) - # 2. Scale by 1/sqrt(n_head) - # 3. Softmax along last dimension - # 4. Result @ V (batch matmul) + # Collapse first 3 dimensions (Z, H, n_ctx) into a single dimension + # From (Z, H, n_ctx, n_head) to (Z*H*n_ctx, n_head) + collapsed_dim = Z * H * n_ctx + collapsed_shape_2d = (collapsed_dim, n_head) + + Q_2d = tensor.collapse_shape( + ir.RankedTensorType.get(collapsed_shape_2d, dtype), + Q_tensor, + reassociation=[[0, 1, 2], [3]], + ) + K_2d = tensor.collapse_shape( + ir.RankedTensorType.get(collapsed_shape_2d, dtype), + K_tensor, + reassociation=[[0, 1, 2], [3]], + ) + V_2d = tensor.collapse_shape( + ir.RankedTensorType.get(collapsed_shape_2d, dtype), + V_tensor, + reassociation=[[0, 1, 2], [3]], + ) + + # Step 1: Transpose K to get K^T + # Permute from (collapsed_dim, n_head) to (n_head, collapsed_dim) + kt_shape_2d = (n_head, collapsed_dim) + kt_init = tensor.empty(kt_shape_2d, dtype) + K_transposed = linalg.transpose(K_2d, outs=[kt_init], permutation=[1, 0]) + + # Step 2: Compute Q @ K^T + # Q: (collapsed_dim, n_head) @ K^T: (n_head, collapsed_dim) + # Result: (collapsed_dim, collapsed_dim) + qkt_shape_2d = (collapsed_dim, collapsed_dim) + qkt_init = tensor.empty(qkt_shape_2d, dtype) + # Initialize with zeros for matmul accumulation + zero = arith.constant(dtype, 0.0) + qkt_init_filled = linalg.fill(zero, outs=[qkt_init]) + + # Matmul: Q @ K^T + qkt = linalg.matmul(Q_2d, K_transposed, outs=[qkt_init_filled]) + + # Step 3: Scale by 1/sqrt(n_head) + scale_factor = 1.0 / math.sqrt(n_head) + scale_const = arith.constant(dtype, scale_factor) - # Placeholder: create empty output tensor - output_init = tensor.empty(shape, dtype) - result = output_init + # Create a tensor filled with the scale factor + scale_tensor_init = tensor.empty(qkt_shape_2d, dtype) + scale_tensor = linalg.fill(scale_const, outs=[scale_tensor_init]) + + # Elementwise multiply qkt with scale tensor + scaled_qkt_init = tensor.empty(qkt_shape_2d, dtype) + scaled_qkt = linalg.mul(qkt, scale_tensor, outs=[scaled_qkt_init]) + + # Step 4: Apply softmax along the last dimension (dim=1 in 2D) + softmax_init = tensor.empty(qkt_shape_2d, dtype) + attention_weights = linalg.softmax( + result=[ir.RankedTensorType.get(qkt_shape_2d, dtype)], + input=scaled_qkt, + output=softmax_init, + dimension=1, + ) + + # Step 5: Multiply attention weights by V + # attention_weights: (collapsed_dim, collapsed_dim) @ V: (collapsed_dim, n_head) + # Result: (collapsed_dim, n_head) + output_2d_init = tensor.empty(collapsed_shape_2d, dtype) + output_2d_init_filled = linalg.fill(zero, outs=[output_2d_init]) + + result_2d = linalg.matmul( + attention_weights, V_2d, outs=[output_2d_init_filled] + ) + + # Expand back to 4D: (Z*H*n_ctx, n_head) -> (Z, H, n_ctx, n_head) + result = tensor.expand_shape( + ir.RankedTensorType.get(shape, dtype), + result_2d, + reassociation=[[0, 1, 2], [3]], + output_shape=[], + static_output_shape=shape, + ) # Materialize result back to output memref bufferization.materialize_in_destination( From 361e069c7f079dd781c96b7331d97f465bd62446 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 21 Apr 2026 21:51:32 +0000 Subject: [PATCH 37/63] tiled last matmul --- examples/xegpu/fused_attention.py | 7 ++++ .../xegpu/fused_attention_schedule.py | 42 +++++++++++++------ 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py index c42faf0e..213915f0 100644 --- a/examples/xegpu/fused_attention.py +++ b/examples/xegpu/fused_attention.py @@ -214,6 +214,12 @@ def parse_cli(): default=64, help="Head dimension", ) + parser.add_argument( + "--wg-tile-size", + type=int, + default=64, + help="Workgroup tile size for the collapsed batch dimension (Z*H*n_ctx)", + ) parser.add_argument( "--nruns", type=int, @@ -271,6 +277,7 @@ def parse_cli(): "num_heads": args.num_heads, "n_ctx": args.n_ctx, "n_head": args.n_head, + "wg_tile_size": args.wg_tile_size, } Z = args.batch_size diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 5fa47770..d66efa24 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -4,10 +4,12 @@ from mlir import ir from mlir.dialects import transform +from mlir.dialects.transform import structured from lighthouse.pipeline.helper import ( canonicalize, match, + match_and_split, PipelineInterrupt, ) from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary @@ -34,6 +36,7 @@ def get_fused_attention_schedule_module( - num_heads: Number of attention heads (H) - n_ctx: Context length - n_head: Head dimension + - wg_tile_size: Workgroup tile size for the collapsed batch dimension (Z*H*n_ctx) Returns: MLIR module containing the transform schedule @@ -106,18 +109,33 @@ def bundle_xegpu_fused_attention_schedule( raise PipelineInterrupt() anytype = transform.AnyOpType.get() - - # TODO: Implement tiling, fusion, and lowering for fused attention - # This will involve: - # 1. Matching and tiling matmul operations (Q @ K^T) - # 2. Fusing softmax operation - # 3. Tiling second matmul (attention @ V) - # 4. Vectorization - # 5. Bufferization - # 6. GPU outlining - # 7. XeGPU lowering - - func = match(mod, ops={"func.func"}) + # Match all matmul operations - there should be 2: + # 1. Q @ K^T + # 2. attention_weights @ V + matmul_ops = match_and_split(mod, ops={"linalg.matmul"}, nhandles=2) + + # Get the last matmul (attention_weights @ V) + last_matmul = matmul_ops[1] + func = transform.get_parent_op( + anytype, + last_matmul, + op_name="func.func", + deduplicate=True, + ) + + # Tile the last matmul in the batch dimension using tile_using_forall + # Batch dimension is the first dimension (collapsed_dim = Z * H * n_ctx) + # Extract workgroup tile size from parameters + wg_tile_size = parameters["wg_tile_size"] + + tiled_matmul, forall_loop = structured.structured_tile_using_forall( + anytype, + anytype, + last_matmul, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(wg_tile_size, 0), + ) if stop_at_stage == "tiled": raise PipelineInterrupt() From 4d0827ec0b4e533a07ffe13fff42f7c9447722e8 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 21 Apr 2026 23:59:37 +0000 Subject: [PATCH 38/63] change to batch matmul --- .../mlir_gen/gpu_fused_attention_payload.py | 82 +++++++++---------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py index 87818d08..c2f3d4ec 100644 --- a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py @@ -49,81 +49,81 @@ def payload(output, Q_arg, K_arg, V_arg): K_tensor = emit_buf_to_tensor(K_arg, restrict=True) V_tensor = emit_buf_to_tensor(V_arg, restrict=True) - # Collapse first 3 dimensions (Z, H, n_ctx) into a single dimension - # From (Z, H, n_ctx, n_head) to (Z*H*n_ctx, n_head) - collapsed_dim = Z * H * n_ctx - collapsed_shape_2d = (collapsed_dim, n_head) + # Collapse first 2 dimensions (Z, H) into a batch dimension + # From (Z, H, n_ctx, n_head) to (Z*H, n_ctx, n_head) + batch_dim = Z * H + collapsed_shape_3d = (batch_dim, n_ctx, n_head) - Q_2d = tensor.collapse_shape( - ir.RankedTensorType.get(collapsed_shape_2d, dtype), + Q_3d = tensor.collapse_shape( + ir.RankedTensorType.get(collapsed_shape_3d, dtype), Q_tensor, - reassociation=[[0, 1, 2], [3]], + reassociation=[[0, 1], [2], [3]], ) - K_2d = tensor.collapse_shape( - ir.RankedTensorType.get(collapsed_shape_2d, dtype), + K_3d = tensor.collapse_shape( + ir.RankedTensorType.get(collapsed_shape_3d, dtype), K_tensor, - reassociation=[[0, 1, 2], [3]], + reassociation=[[0, 1], [2], [3]], ) - V_2d = tensor.collapse_shape( - ir.RankedTensorType.get(collapsed_shape_2d, dtype), + V_3d = tensor.collapse_shape( + ir.RankedTensorType.get(collapsed_shape_3d, dtype), V_tensor, - reassociation=[[0, 1, 2], [3]], + reassociation=[[0, 1], [2], [3]], ) # Step 1: Transpose K to get K^T - # Permute from (collapsed_dim, n_head) to (n_head, collapsed_dim) - kt_shape_2d = (n_head, collapsed_dim) - kt_init = tensor.empty(kt_shape_2d, dtype) - K_transposed = linalg.transpose(K_2d, outs=[kt_init], permutation=[1, 0]) - - # Step 2: Compute Q @ K^T - # Q: (collapsed_dim, n_head) @ K^T: (n_head, collapsed_dim) - # Result: (collapsed_dim, collapsed_dim) - qkt_shape_2d = (collapsed_dim, collapsed_dim) - qkt_init = tensor.empty(qkt_shape_2d, dtype) + # Permute from (batch_dim, n_ctx, n_head) to (batch_dim, n_head, n_ctx) + kt_shape_3d = (batch_dim, n_head, n_ctx) + kt_init = tensor.empty(kt_shape_3d, dtype) + K_transposed = linalg.transpose(K_3d, outs=[kt_init], permutation=[0, 2, 1]) + + # Step 2: Compute Q @ K^T using batch_matmul + # Q: (batch_dim, n_ctx, n_head) @ K^T: (batch_dim, n_head, n_ctx) + # Result: (batch_dim, n_ctx, n_ctx) + qkt_shape_3d = (batch_dim, n_ctx, n_ctx) + qkt_init = tensor.empty(qkt_shape_3d, dtype) # Initialize with zeros for matmul accumulation zero = arith.constant(dtype, 0.0) qkt_init_filled = linalg.fill(zero, outs=[qkt_init]) - # Matmul: Q @ K^T - qkt = linalg.matmul(Q_2d, K_transposed, outs=[qkt_init_filled]) + # Batch matmul: Q @ K^T + qkt = linalg.batch_matmul(Q_3d, K_transposed, outs=[qkt_init_filled]) # Step 3: Scale by 1/sqrt(n_head) scale_factor = 1.0 / math.sqrt(n_head) scale_const = arith.constant(dtype, scale_factor) # Create a tensor filled with the scale factor - scale_tensor_init = tensor.empty(qkt_shape_2d, dtype) + scale_tensor_init = tensor.empty(qkt_shape_3d, dtype) scale_tensor = linalg.fill(scale_const, outs=[scale_tensor_init]) # Elementwise multiply qkt with scale tensor - scaled_qkt_init = tensor.empty(qkt_shape_2d, dtype) + scaled_qkt_init = tensor.empty(qkt_shape_3d, dtype) scaled_qkt = linalg.mul(qkt, scale_tensor, outs=[scaled_qkt_init]) - # Step 4: Apply softmax along the last dimension (dim=1 in 2D) - softmax_init = tensor.empty(qkt_shape_2d, dtype) + # Step 4: Apply softmax along the last dimension (dim=2 in 3D) + softmax_init = tensor.empty(qkt_shape_3d, dtype) attention_weights = linalg.softmax( - result=[ir.RankedTensorType.get(qkt_shape_2d, dtype)], + result=[ir.RankedTensorType.get(qkt_shape_3d, dtype)], input=scaled_qkt, output=softmax_init, - dimension=1, + dimension=2, ) - # Step 5: Multiply attention weights by V - # attention_weights: (collapsed_dim, collapsed_dim) @ V: (collapsed_dim, n_head) - # Result: (collapsed_dim, n_head) - output_2d_init = tensor.empty(collapsed_shape_2d, dtype) - output_2d_init_filled = linalg.fill(zero, outs=[output_2d_init]) + # Step 5: Multiply attention weights by V using batch_matmul + # attention_weights: (batch_dim, n_ctx, n_ctx) @ V: (batch_dim, n_ctx, n_head) + # Result: (batch_dim, n_ctx, n_head) + output_3d_init = tensor.empty(collapsed_shape_3d, dtype) + output_3d_init_filled = linalg.fill(zero, outs=[output_3d_init]) - result_2d = linalg.matmul( - attention_weights, V_2d, outs=[output_2d_init_filled] + result_3d = linalg.batch_matmul( + attention_weights, V_3d, outs=[output_3d_init_filled] ) - # Expand back to 4D: (Z*H*n_ctx, n_head) -> (Z, H, n_ctx, n_head) + # Expand back to 4D: (Z*H, n_ctx, n_head) -> (Z, H, n_ctx, n_head) result = tensor.expand_shape( ir.RankedTensorType.get(shape, dtype), - result_2d, - reassociation=[[0, 1, 2], [3]], + result_3d, + reassociation=[[0, 1], [2], [3]], output_shape=[], static_output_shape=shape, ) From e379b683225f6943ae0b9645c0b3a5aed4bec420 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 22 Apr 2026 16:26:23 +0000 Subject: [PATCH 39/63] save work --- examples/xegpu/fused_attention.py | 3 +- .../xegpu/fused_attention_schedule.py | 109 ++++++++++++++++-- 2 files changed, 104 insertions(+), 8 deletions(-) diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py index 213915f0..da61b3f7 100644 --- a/examples/xegpu/fused_attention.py +++ b/examples/xegpu/fused_attention.py @@ -242,7 +242,8 @@ def parse_cli(): type=str, choices=[ "initial", - "tiled", + "outer-tiled", + "inner-tiled", "vectorized", "bufferized", "gpu-outlining", diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index d66efa24..6ea15081 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -109,6 +109,7 @@ def bundle_xegpu_fused_attention_schedule( raise PipelineInterrupt() anytype = transform.AnyOpType.get() + anyvalue = transform.AnyValueType.get() # Match all matmul operations - there should be 2: # 1. Q @ K^T # 2. attention_weights @ V @@ -137,17 +138,111 @@ def bundle_xegpu_fused_attention_schedule( static_tile_sizes=(wg_tile_size, 0), ) - if stop_at_stage == "tiled": - raise PipelineInterrupt() + # Fuse the softmax producer into forall + softmax_ops = match_and_split(func, ops={"linalg.softmax"}, nhandles=1) + softmax_op = softmax_ops[0] + fused_softmax_op, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=softmax_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) - # vectorize (placeholder) - # func = structured.VectorizeChildrenAndApplyPatternsOp( - # func, - # fold_type_extensions_into_contract=True, - # ).result + # Fuse linalg.mul (scaling) into forall + mul_ops = match_and_split(func, ops={"linalg.mul"}, nhandles=1) + mul_op = mul_ops[0] + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=mul_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + # Fuse the first matmul (Q @ K^T) into forall + matmul_ops = match_and_split( + func, ops={"linalg.matmul"}, nhandles=2 + ) # Two matmuls are present. + first_matmul = matmul_ops[0] + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=first_matmul, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + # Fuse linalg.transpose (K transpose) into forall + transpose_ops = match_and_split(func, ops={"linalg.transpose"}, nhandles=1) + transpose_op = transpose_ops[0] + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=transpose_op, + containing_op=forall_loop, + ) transform.apply_cse(func) canonicalize(func) + # At this point all of the key operations are fused into the forall loop. + # Remaining linalg.fill ops can be fused trivially. + fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=3) + for fill_op in fill_ops: + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=fill_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + # tensor.empty() holding the result of transpose can be fused. + transpose_op = match_and_split(func, ops={"linalg.transpose"}, nhandles=1)[0] + transpose_init = transform.get_producer_of_operand( + anytype, transpose_op, operand_number=1 + ) + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=transpose_init, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + # tensor.empty() ops holding the result of the softmax can also be fused. + softmax_op = match_and_split(func, ops={"linalg.softmax"}, nhandles=1)[0] + softmax_init = transform.get_producer_of_operand( + anytype, softmax_op, operand_number=1 + ) + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=softmax_init, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "outer-tiled": + raise PipelineInterrupt() + + # # vectorize (placeholder) + # # func = structured.VectorizeChildrenAndApplyPatternsOp( + # # func, + # # fold_type_extensions_into_contract=True, + # # ).result + # transform.apply_cse(func) + # canonicalize(func) + + if stop_at_stage == "inner-tiled": + raise PipelineInterrupt() + if stop_at_stage == "vectorized": raise PipelineInterrupt() From 38f2d97502459124b8eb82d1de167c5b79404126 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 27 Apr 2026 18:49:20 +0000 Subject: [PATCH 40/63] save work --- .../xegpu/fused_attention_schedule.py | 111 +++++++++--------- 1 file changed, 55 insertions(+), 56 deletions(-) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 6ea15081..c64f2956 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -113,7 +113,7 @@ def bundle_xegpu_fused_attention_schedule( # Match all matmul operations - there should be 2: # 1. Q @ K^T # 2. attention_weights @ V - matmul_ops = match_and_split(mod, ops={"linalg.matmul"}, nhandles=2) + matmul_ops = match_and_split(mod, ops={"linalg.batch_matmul"}, nhandles=2) # Get the last matmul (attention_weights @ V) last_matmul = matmul_ops[1] @@ -124,9 +124,7 @@ def bundle_xegpu_fused_attention_schedule( deduplicate=True, ) - # Tile the last matmul in the batch dimension using tile_using_forall - # Batch dimension is the first dimension (collapsed_dim = Z * H * n_ctx) - # Extract workgroup tile size from parameters + # Tile the last matmul in both batch and M dimensions. wg_tile_size = parameters["wg_tile_size"] tiled_matmul, forall_loop = structured.structured_tile_using_forall( @@ -135,7 +133,7 @@ def bundle_xegpu_fused_attention_schedule( last_matmul, num_threads=[], tile_sizes=[], - static_tile_sizes=(wg_tile_size, 0), + static_tile_sizes=(1, wg_tile_size, 0, 0) ) # Fuse the softmax producer into forall @@ -163,8 +161,9 @@ def bundle_xegpu_fused_attention_schedule( canonicalize(func) # Fuse the first matmul (Q @ K^T) into forall + # TODO: This fusion does not work?? matmul_ops = match_and_split( - func, ops={"linalg.matmul"}, nhandles=2 + func, ops={"linalg.batch_matmul"}, nhandles=2 ) # Two matmuls are present. first_matmul = matmul_ops[0] _, forall_loop = structured.structured_fuse_into_containing_op( @@ -176,58 +175,58 @@ def bundle_xegpu_fused_attention_schedule( transform.apply_cse(func) canonicalize(func) - # Fuse linalg.transpose (K transpose) into forall - transpose_ops = match_and_split(func, ops={"linalg.transpose"}, nhandles=1) - transpose_op = transpose_ops[0] - _, forall_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=transpose_op, - containing_op=forall_loop, - ) - transform.apply_cse(func) - canonicalize(func) - - # At this point all of the key operations are fused into the forall loop. - # Remaining linalg.fill ops can be fused trivially. - fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=3) - for fill_op in fill_ops: - _, forall_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=fill_op, - containing_op=forall_loop, - ) - transform.apply_cse(func) - canonicalize(func) + # # Fuse linalg.transpose (K transpose) into forall + # transpose_ops = match_and_split(func, ops={"linalg.transpose"}, nhandles=1) + # transpose_op = transpose_ops[0] + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=transpose_op, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) - # tensor.empty() holding the result of transpose can be fused. - transpose_op = match_and_split(func, ops={"linalg.transpose"}, nhandles=1)[0] - transpose_init = transform.get_producer_of_operand( - anytype, transpose_op, operand_number=1 - ) - _, forall_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=transpose_init, - containing_op=forall_loop, - ) - transform.apply_cse(func) - canonicalize(func) + # # At this point all of the key operations are fused into the forall loop. + # # Remaining linalg.fill ops can be fused trivially. + # fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=3) + # for fill_op in fill_ops: + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=fill_op, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) + + # # tensor.empty() holding the result of transpose can be fused. + # transpose_op = match_and_split(func, ops={"linalg.transpose"}, nhandles=1)[0] + # transpose_init = transform.get_producer_of_operand( + # anytype, transpose_op, operand_number=1 + # ) + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=transpose_init, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) - # tensor.empty() ops holding the result of the softmax can also be fused. - softmax_op = match_and_split(func, ops={"linalg.softmax"}, nhandles=1)[0] - softmax_init = transform.get_producer_of_operand( - anytype, softmax_op, operand_number=1 - ) - _, forall_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=softmax_init, - containing_op=forall_loop, - ) - transform.apply_cse(func) - canonicalize(func) + # # tensor.empty() ops holding the result of the softmax can also be fused. + # softmax_op = match_and_split(func, ops={"linalg.softmax"}, nhandles=1)[0] + # softmax_init = transform.get_producer_of_operand( + # anytype, softmax_op, operand_number=1 + # ) + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=softmax_init, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) if stop_at_stage == "outer-tiled": raise PipelineInterrupt() From 0076ee264a08b4e1729061fb73ec9c3f89b951dd Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 27 Apr 2026 18:49:28 +0000 Subject: [PATCH 41/63] save work --- lighthouse/schedule/xegpu/fused_attention_schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index c64f2956..4b690eb7 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -133,7 +133,7 @@ def bundle_xegpu_fused_attention_schedule( last_matmul, num_threads=[], tile_sizes=[], - static_tile_sizes=(1, wg_tile_size, 0, 0) + static_tile_sizes=(1, wg_tile_size, 0, 0), ) # Fuse the softmax producer into forall From 452faf77a685753cce53a240a151a13ed3d1d311 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 27 Apr 2026 19:05:40 +0000 Subject: [PATCH 42/63] update llvm --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cf5fe939..0b9d10e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "lighthouse" dynamic = ["version"] requires-python = ">=3.10,<3.13" # Bounds are due to torch-mlir's packaging dependencies = [ - "mlir-python-bindings==20260417+27769d7b5", + "mlir-python-bindings==20260427+652700b4c", "pyyaml>=6.0", ] From 32484775781539bf6826aedabe436c03daddf2d7 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 27 Apr 2026 21:47:35 +0000 Subject: [PATCH 43/63] refactor code --- examples/xegpu/fused_attention.py | 4 +- .../xegpu/fused_attention_schedule.py | 66 +++++-------------- 2 files changed, 20 insertions(+), 50 deletions(-) diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py index da61b3f7..2131100f 100644 --- a/examples/xegpu/fused_attention.py +++ b/examples/xegpu/fused_attention.py @@ -22,7 +22,7 @@ generate_gpu_fused_attention_payload, ) from lighthouse.schedule.xegpu.fused_attention_schedule import ( - get_fused_attention_schedule_module, + fused_attention_schedule, ) @@ -175,7 +175,7 @@ def schedule_modules( """Generate transform schedule for fused attention.""" return [ Runner.get_bench_wrapper_schedule(self.payload_function_name), - get_fused_attention_schedule_module( + fused_attention_schedule( stop_at_stage=stop_at_stage, parameters=parameters, ), diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 4b690eb7..85277787 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -12,10 +12,10 @@ match_and_split, PipelineInterrupt, ) -from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary +from lighthouse.schedule import schedule_boilerplate -def get_fused_attention_schedule_module( +def fused_attention_schedule( stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None, ) -> ir.Module: @@ -43,59 +43,29 @@ def get_fused_attention_schedule_module( """ assert parameters is not None, "Schedule parameters must be provided" - mod = ir.Module.create() - mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() - - with ir.InsertionPoint(mod.body): - # Create a transform sequence with proper signature - named_sequence = transform.named_sequence( - "__transform_main", - [transform.AnyOpType.get()], # input: module - [], # no outputs - arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + with schedule_boilerplate() as (schedule, named_seq): + # match the payload module + anytype = transform.AnyOpType.get() + func = match(named_seq.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, ) - with ir.InsertionPoint(named_sequence.body): - # match the payload module - anytype = transform.AnyOpType.get() - func = match(named_sequence.bodyTarget, ops={"func.func"}) - payload_mod = transform.get_parent_op( - anytype, - func, - op_name="builtin.module", - deduplicate=True, - ) - - xegpu_fused_attention_transform_schedule( + try: + bundle_xegpu_fused_attention_schedule( payload_mod, parameters=parameters, stop_at_stage=stop_at_stage or "", ) + except PipelineInterrupt: + pass + finally: + transform.yield_() - return mod - - -def xegpu_fused_attention_transform_schedule( - mod: ir.Value[transform.AnyOpType], - parameters: dict, - stop_at_stage: str = "", -): - """Transform schedule for fused attention payload.""" - try: - mod = bundle_xegpu_fused_attention_schedule( - mod, - parameters=parameters, - stop_at_stage=stop_at_stage, - ) - - mod = bundle_xegpu_to_binary( - mod, - stop_at_stage=stop_at_stage, - ) - except PipelineInterrupt: - pass - finally: - transform.yield_() + return schedule def bundle_xegpu_fused_attention_schedule( From ce07760b017aaf1b67b659603e223380401c28b1 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 28 Apr 2026 18:43:50 +0000 Subject: [PATCH 44/63] refactor code --- .../xegpu/fused_attention_schedule.py | 100 +++++++++--------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 85277787..992dff61 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -145,58 +145,58 @@ def bundle_xegpu_fused_attention_schedule( transform.apply_cse(func) canonicalize(func) - # # Fuse linalg.transpose (K transpose) into forall - # transpose_ops = match_and_split(func, ops={"linalg.transpose"}, nhandles=1) - # transpose_op = transpose_ops[0] - # _, forall_loop = structured.structured_fuse_into_containing_op( - # anytype, - # anytype, - # producer_op=transpose_op, - # containing_op=forall_loop, - # ) - # transform.apply_cse(func) - # canonicalize(func) + # Fuse linalg.transpose (K transpose) into forall + transpose_ops = match_and_split(func, ops={"linalg.transpose"}, nhandles=1) + transpose_op = transpose_ops[0] + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=transpose_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) - # # At this point all of the key operations are fused into the forall loop. - # # Remaining linalg.fill ops can be fused trivially. - # fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=3) - # for fill_op in fill_ops: - # _, forall_loop = structured.structured_fuse_into_containing_op( - # anytype, - # anytype, - # producer_op=fill_op, - # containing_op=forall_loop, - # ) - # transform.apply_cse(func) - # canonicalize(func) - - # # tensor.empty() holding the result of transpose can be fused. - # transpose_op = match_and_split(func, ops={"linalg.transpose"}, nhandles=1)[0] - # transpose_init = transform.get_producer_of_operand( - # anytype, transpose_op, operand_number=1 - # ) - # _, forall_loop = structured.structured_fuse_into_containing_op( - # anytype, - # anytype, - # producer_op=transpose_init, - # containing_op=forall_loop, - # ) - # transform.apply_cse(func) - # canonicalize(func) + # At this point all of the key operations are fused into the forall loop. + # Remaining linalg.fill ops can be fused trivially. + fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=3) + for fill_op in fill_ops: + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=fill_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) - # # tensor.empty() ops holding the result of the softmax can also be fused. - # softmax_op = match_and_split(func, ops={"linalg.softmax"}, nhandles=1)[0] - # softmax_init = transform.get_producer_of_operand( - # anytype, softmax_op, operand_number=1 - # ) - # _, forall_loop = structured.structured_fuse_into_containing_op( - # anytype, - # anytype, - # producer_op=softmax_init, - # containing_op=forall_loop, - # ) - # transform.apply_cse(func) - # canonicalize(func) + # tensor.empty() holding the result of transpose can be fused. + transpose_op = match_and_split(func, ops={"linalg.transpose"}, nhandles=1)[0] + transpose_init = transform.get_producer_of_operand( + anytype, transpose_op, operand_number=1 + ) + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=transpose_init, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + # tensor.empty() ops holding the result of the softmax can also be fused. + softmax_op = match_and_split(func, ops={"linalg.softmax"}, nhandles=1)[0] + softmax_init = transform.get_producer_of_operand( + anytype, softmax_op, operand_number=1 + ) + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=softmax_init, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) if stop_at_stage == "outer-tiled": raise PipelineInterrupt() From a0c421bacf8c14f817edca4a119af6e7f2c10b25 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 28 Apr 2026 23:33:00 +0000 Subject: [PATCH 45/63] parallel dim tiling done --- .../xegpu/fused_attention_schedule.py | 112 ++++++++---------- 1 file changed, 48 insertions(+), 64 deletions(-) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 992dff61..d838c7f0 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -105,96 +105,80 @@ def bundle_xegpu_fused_attention_schedule( tile_sizes=[], static_tile_sizes=(1, wg_tile_size, 0, 0), ) - - # Fuse the softmax producer into forall - softmax_ops = match_and_split(func, ops={"linalg.softmax"}, nhandles=1) - softmax_op = softmax_ops[0] - fused_softmax_op, forall_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=softmax_op, - containing_op=forall_loop, + # Fuse the zero initialization of the output of the last matmul (tensor.empty) into the forall loop. + tiled_matmul_init = transform.get_producer_of_operand( + anytype, forall_loop, operand_number=0 ) - transform.apply_cse(func) - canonicalize(func) - - # Fuse linalg.mul (scaling) into forall - mul_ops = match_and_split(func, ops={"linalg.mul"}, nhandles=1) - mul_op = mul_ops[0] _, forall_loop = structured.structured_fuse_into_containing_op( anytype, anytype, - producer_op=mul_op, + producer_op=tiled_matmul_init, containing_op=forall_loop, ) transform.apply_cse(func) canonicalize(func) - # Fuse the first matmul (Q @ K^T) into forall - # TODO: This fusion does not work?? - matmul_ops = match_and_split( - func, ops={"linalg.batch_matmul"}, nhandles=2 - ) # Two matmuls are present. - first_matmul = matmul_ops[0] - _, forall_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=first_matmul, - containing_op=forall_loop, - ) + # Decompose softmax into generic ops + softmax_ops = match_and_split(func, ops={"linalg.softmax"}, nhandles=1) + softmax_op = softmax_ops[0] + structured.structured_decompose_interface(anytype, softmax_op) transform.apply_cse(func) canonicalize(func) - # Fuse linalg.transpose (K transpose) into forall - transpose_ops = match_and_split(func, ops={"linalg.transpose"}, nhandles=1) - transpose_op = transpose_ops[0] - _, forall_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=transpose_op, - containing_op=forall_loop, - ) + # Fuse all linalg.generic ops from softmax decomposition (4 ops: max, sub+exp, sum, div) + # Match and fuse in reverse order (from consumer to producer) + generic_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=4) + for generic_op in reversed(generic_ops): + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=generic_op, + containing_op=forall_loop, + ) transform.apply_cse(func) canonicalize(func) - # At this point all of the key operations are fused into the forall loop. - # Remaining linalg.fill ops can be fused trivially. - fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=3) - for fill_op in fill_ops: + # Max and add reductions use linalg.fill to intialize the reduction output. Fuse these fill ops as well. + fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=5) + # Max fill is the third fill op and add fill is the fourth fill op (based on the pattern of decomposition) + max_fill_op = fill_ops[2] + add_fill_op = fill_ops[3] + for fill_op in [max_fill_op, add_fill_op]: _, forall_loop = structured.structured_fuse_into_containing_op( anytype, anytype, producer_op=fill_op, containing_op=forall_loop, ) - transform.apply_cse(func) - canonicalize(func) - - # tensor.empty() holding the result of transpose can be fused. - transpose_op = match_and_split(func, ops={"linalg.transpose"}, nhandles=1)[0] - transpose_init = transform.get_producer_of_operand( - anytype, transpose_op, operand_number=1 - ) - _, forall_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=transpose_init, - containing_op=forall_loop, - ) transform.apply_cse(func) canonicalize(func) - # tensor.empty() ops holding the result of the softmax can also be fused. - softmax_op = match_and_split(func, ops={"linalg.softmax"}, nhandles=1)[0] - softmax_init = transform.get_producer_of_operand( - anytype, softmax_op, operand_number=1 + linalg_mul_op = match_and_split(func, ops={"linalg.mul"}, nhandles=1)[0] + first_matmul = transform.get_producer_of_operand( + anytype, linalg_mul_op, operand_number=0 ) - _, forall_loop = structured.structured_fuse_into_containing_op( - anytype, - anytype, - producer_op=softmax_init, - containing_op=forall_loop, + scale_fill_op = transform.get_producer_of_operand( + anytype, linalg_mul_op, operand_number=1 + ) + transpose_op = transform.get_producer_of_operand( + anytype, first_matmul, operand_number=1 + ) + matmul_fill_op = transform.get_producer_of_operand( + anytype, first_matmul, operand_number=2 ) + for op in [ + linalg_mul_op, + scale_fill_op, + first_matmul, + matmul_fill_op, + transpose_op, + ]: + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=op, + containing_op=forall_loop, + ) transform.apply_cse(func) canonicalize(func) From dba42ad3fcc9583056eabc3c8434121aa3874997 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 29 Apr 2026 16:28:39 +0000 Subject: [PATCH 46/63] address comments --- .../transform/transform_ext/ops/update_address_space.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py index 8d0b6041..d1160ad7 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py +++ b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py @@ -38,7 +38,7 @@ def apply( new_ops = [] # Verify this is a memref.alloca operation - if target_op.OPERATION_NAME != "memref.alloca": + if not isinstance(target_op.opview, memref.AllocaOp): return DiagnosedSilenceableFailure.emit_silenceable_error( f"Expected memref.alloca operation, got {target_op.OPERATION_NAME}" ) From ba380a3a0e84d5b5b517b030ddf9d32b1acaa4aa Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 29 Apr 2026 21:52:26 +0000 Subject: [PATCH 47/63] initial verion without reduction tiling --- examples/xegpu/fused_attention.py | 20 +++- .../mlir_gen/gpu_fused_attention_payload.py | 61 +++++----- .../xegpu/fused_attention_schedule.py | 105 ++++++++++++++---- 3 files changed, 131 insertions(+), 55 deletions(-) diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py index 2131100f..ebc4297d 100644 --- a/examples/xegpu/fused_attention.py +++ b/examples/xegpu/fused_attention.py @@ -215,10 +215,22 @@ def parse_cli(): help="Head dimension", ) parser.add_argument( - "--wg-tile-size", + "--wg-rows", type=int, default=64, - help="Workgroup tile size for the collapsed batch dimension (Z*H*n_ctx)", + help="Number of Q*K^T*V rows computed by each work group", + ) + parser.add_argument( + "--sg-rows", + type=int, + default=8, + help="Number of Q*K^T*V rows computed by each subgroup", + ) + parser.add_argument( + "--subgroup-size", + type=int, + default=16, + help="Subgroup size", ) parser.add_argument( "--nruns", @@ -278,7 +290,9 @@ def parse_cli(): "num_heads": args.num_heads, "n_ctx": args.n_ctx, "n_head": args.n_head, - "wg_tile_size": args.wg_tile_size, + "wg_rows": args.wg_rows, + "sg_rows": args.sg_rows, + "subgroup_size": args.subgroup_size, } Z = args.batch_size diff --git a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py index c2f3d4ec..73873dc6 100644 --- a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py @@ -3,7 +3,7 @@ import math from mlir import ir -from mlir.dialects import arith, bufferization, linalg, tensor +from mlir.dialects import arith, bufferization, linalg, memref, tensor from lighthouse.utils.mlir import func_cif from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs @@ -40,36 +40,42 @@ def generate_gpu_fused_attention_payload( memref_t = ir.MemRefType.get(shape, dtype) with ir.InsertionPoint(mod.body): + # Collapse first 2 dimensions (Z, H) into a batch dimension + # From (Z, H, n_ctx, n_head) to (Z*H, n_ctx, n_head) + batch_dim = Z * H + collapsed_shape_3d = (batch_dim, n_ctx, n_head) + memref_3d_t = ir.MemRefType.get(collapsed_shape_3d, dtype) + # Function signature: payload(output, Q, K, V) @func_cif(memref_t, memref_t, memref_t, memref_t, name=func_name) def payload(output, Q_arg, K_arg, V_arg): - # Convert memrefs to tensors - emit_buf_to_tensor(output, restrict=True, writable=True) - Q_tensor = emit_buf_to_tensor(Q_arg, restrict=True) - K_tensor = emit_buf_to_tensor(K_arg, restrict=True) - V_tensor = emit_buf_to_tensor(V_arg, restrict=True) - - # Collapse first 2 dimensions (Z, H) into a batch dimension - # From (Z, H, n_ctx, n_head) to (Z*H, n_ctx, n_head) - batch_dim = Z * H - collapsed_shape_3d = (batch_dim, n_ctx, n_head) - - Q_3d = tensor.collapse_shape( - ir.RankedTensorType.get(collapsed_shape_3d, dtype), - Q_tensor, + # Collapse memrefs from 4D to 3D + Q_3d_memref = memref.collapse_shape( + memref_3d_t, + Q_arg, + reassociation=[[0, 1], [2], [3]], + ) + K_3d_memref = memref.collapse_shape( + memref_3d_t, + K_arg, reassociation=[[0, 1], [2], [3]], ) - K_3d = tensor.collapse_shape( - ir.RankedTensorType.get(collapsed_shape_3d, dtype), - K_tensor, + V_3d_memref = memref.collapse_shape( + memref_3d_t, + V_arg, reassociation=[[0, 1], [2], [3]], ) - V_3d = tensor.collapse_shape( - ir.RankedTensorType.get(collapsed_shape_3d, dtype), - V_tensor, + output_3d_memref = memref.collapse_shape( + memref_3d_t, + output, reassociation=[[0, 1], [2], [3]], ) + # Convert 3D memrefs to tensors + Q_3d = emit_buf_to_tensor(Q_3d_memref, restrict=True) + K_3d = emit_buf_to_tensor(K_3d_memref, restrict=True) + V_3d = emit_buf_to_tensor(V_3d_memref, restrict=True) + # Step 1: Transpose K to get K^T # Permute from (batch_dim, n_ctx, n_head) to (batch_dim, n_head, n_ctx) kt_shape_3d = (batch_dim, n_head, n_ctx) @@ -119,18 +125,9 @@ def payload(output, Q_arg, K_arg, V_arg): attention_weights, V_3d, outs=[output_3d_init_filled] ) - # Expand back to 4D: (Z*H, n_ctx, n_head) -> (Z, H, n_ctx, n_head) - result = tensor.expand_shape( - ir.RankedTensorType.get(shape, dtype), - result_3d, - reassociation=[[0, 1], [2], [3]], - output_shape=[], - static_output_shape=shape, - ) - - # Materialize result back to output memref + # Materialize 3D result back to 3D output memref bufferization.materialize_in_destination( - None, result, output, restrict=True, writable=True + None, result_3d, output_3d_memref, restrict=True, writable=True ) # Emit utility functions for GPU memory management diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index d838c7f0..5544f9be 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -4,15 +4,19 @@ from mlir import ir from mlir.dialects import transform -from mlir.dialects.transform import structured +from mlir.dialects.transform import structured, loop, xegpu +from mlir.dialects.transform import bufferization as transform_bufferization +from mlir.dialects.bufferization import LayoutMapOption from lighthouse.pipeline.helper import ( canonicalize, match, match_and_split, PipelineInterrupt, + apply_registered_pass, ) from lighthouse.schedule import schedule_boilerplate +from lighthouse.dialects.transform import transform_ext def fused_attention_schedule( @@ -36,7 +40,9 @@ def fused_attention_schedule( - num_heads: Number of attention heads (H) - n_ctx: Context length - n_head: Head dimension - - wg_tile_size: Workgroup tile size for the collapsed batch dimension (Z*H*n_ctx) + - wg_rows: Number of Q*K^T*V rows computed by each work group + - sg_rows: Number of Q*K^T*V rows computed by each subgroup + - subgroup_size: Size of subgroup Returns: MLIR module containing the transform schedule @@ -95,7 +101,7 @@ def bundle_xegpu_fused_attention_schedule( ) # Tile the last matmul in both batch and M dimensions. - wg_tile_size = parameters["wg_tile_size"] + wg_rows = parameters["wg_rows"] tiled_matmul, forall_loop = structured.structured_tile_using_forall( anytype, @@ -103,7 +109,7 @@ def bundle_xegpu_fused_attention_schedule( last_matmul, num_threads=[], tile_sizes=[], - static_tile_sizes=(1, wg_tile_size, 0, 0), + static_tile_sizes=(1, wg_rows, 0, 0), ) # Fuse the zero initialization of the output of the last matmul (tensor.empty) into the forall loop. tiled_matmul_init = transform.get_producer_of_operand( @@ -185,39 +191,98 @@ def bundle_xegpu_fused_attention_schedule( if stop_at_stage == "outer-tiled": raise PipelineInterrupt() - # # vectorize (placeholder) - # # func = structured.VectorizeChildrenAndApplyPatternsOp( - # # func, - # # fold_type_extensions_into_contract=True, - # # ).result - # transform.apply_cse(func) - # canonicalize(func) - if stop_at_stage == "inner-tiled": raise PipelineInterrupt() + # vectorize (placeholder) + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + transform.apply_cse(func) + canonicalize(func) + if stop_at_stage == "vectorized": raise PipelineInterrupt() # bufferize (placeholder) - # mod = apply_registered_pass(mod, "eliminate-empty-tensors") - # identity_layout = LayoutMapOption.IdentityLayoutMap - # mod = transform_bufferization.OneShotBufferizeOp( - # mod, - # allow_return_allocs_from_loops=True, - # bufferize_function_boundaries=True, - # function_boundary_type_conversion=identity_layout, - # ).result + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + mod = transform_bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + transform.apply_cse(mod) + canonicalize(mod) if stop_at_stage == "bufferized": raise PipelineInterrupt() + # convert forall to parallel + wg_loops = match_and_split(mod, ops={"scf.forall"}) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # set the number of threads for the gpu.launch operation + launch_op = match_and_split(func, ops={"gpu.launch"}) + wg_rows = parameters["wg_rows"] + sg_rows = parameters["sg_rows"] + subgroup_size = parameters["subgroup_size"] + num_subgroups = wg_rows // sg_rows + num_threads = num_subgroups * subgroup_size + xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + if stop_at_stage == "gpu-outlining": raise PipelineInterrupt() + # set xevm target + mod = apply_registered_pass( + mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # for each gpu function in the gpu module, change memref.alloca address + # space to 3 (SLM) and convert vector to xegpu. + gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) + for gpu_mod in gpu_mod_ops: + gpu_func = match(gpu_mod, ops={"gpu.func"}) + allocas = match_and_split(gpu_func, ops={"memref.alloca"}) + for alloca in allocas: + transform_ext.update_address_space(alloca, address_space=3) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + + # Cleanup. + transform.apply_cse(mod) + canonicalize(mod) + if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() + # Set layout attributes for xegpu store operations + # Note: The exact operations to match depend on what gets generated + # after vectorization and xegpu conversion for fused attention + # This is a placeholder that may need adjustment based on the actual IR + gpu_func = match(gpu_mod_ops[0], ops={"gpu.func"}) + if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From 5222af3fe232844a5b646f18144904dae7e3395d Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 30 Apr 2026 22:28:20 +0000 Subject: [PATCH 48/63] save xegpu wg version --- examples/xegpu/fused_attention.py | 23 +++++++++++------ lighthouse/schedule/xegpu/__init__.py | 2 ++ .../xegpu/fused_attention_schedule.py | 25 ++++++++++--------- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py index ebc4297d..5c334bda 100644 --- a/examples/xegpu/fused_attention.py +++ b/examples/xegpu/fused_attention.py @@ -21,9 +21,7 @@ from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import ( generate_gpu_fused_attention_payload, ) -from lighthouse.schedule.xegpu.fused_attention_schedule import ( - fused_attention_schedule, -) +from lighthouse.schedule.xegpu import fused_attention_schedule, xegpu_to_binary def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: int): @@ -173,13 +171,22 @@ def schedule_modules( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None ) -> list[ir.Module]: """Generate transform schedule for fused attention.""" - return [ - Runner.get_bench_wrapper_schedule(self.payload_function_name), + schedules = [] + schedules.append(Runner.get_bench_wrapper_schedule(self.payload_function_name)) + + schedules.append( fused_attention_schedule( stop_at_stage=stop_at_stage, parameters=parameters, - ), - ] + ) + ) + + if stop_at_stage and stop_at_stage != "final": + return schedules + + schedules.append(xegpu_to_binary()) + + return schedules def shared_libs(self) -> list[str]: return ["libmlir_levelzero_runtime.so"] @@ -205,7 +212,7 @@ def parse_cli(): parser.add_argument( "--n-ctx", type=int, - default=512, + default=128, help="Context length (sequence length)", ) parser.add_argument( diff --git a/lighthouse/schedule/xegpu/__init__.py b/lighthouse/schedule/xegpu/__init__.py index 23d9ef0c..76f4101c 100644 --- a/lighthouse/schedule/xegpu/__init__.py +++ b/lighthouse/schedule/xegpu/__init__.py @@ -1,8 +1,10 @@ from .xegpu_to_binary import xegpu_to_binary from .mlp_schedule import mlp_schedule from .softmax_schedule import softmax_schedule +from .fused_attention_schedule import fused_attention_schedule __all__ = [ + "fused_attention_schedule", "mlp_schedule", "softmax_schedule", "xegpu_to_binary", diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 5544f9be..f3658bae 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -7,6 +7,9 @@ from mlir.dialects.transform import structured, loop, xegpu from mlir.dialects.transform import bufferization as transform_bufferization from mlir.dialects.bufferization import LayoutMapOption +from mlir.dialects.transform.vector import ( + apply_patterns_vector_cast_away_vector_leading_one_dim, +) from lighthouse.pipeline.helper import ( canonicalize, @@ -16,7 +19,6 @@ apply_registered_pass, ) from lighthouse.schedule import schedule_boilerplate -from lighthouse.dialects.transform import transform_ext def fused_attention_schedule( @@ -85,7 +87,6 @@ def bundle_xegpu_fused_attention_schedule( raise PipelineInterrupt() anytype = transform.AnyOpType.get() - anyvalue = transform.AnyValueType.get() # Match all matmul operations - there should be 2: # 1. Q @ K^T # 2. attention_weights @ V @@ -194,18 +195,21 @@ def bundle_xegpu_fused_attention_schedule( if stop_at_stage == "inner-tiled": raise PipelineInterrupt() - # vectorize (placeholder) + # vectorize func = structured.VectorizeChildrenAndApplyPatternsOp( func, fold_type_extensions_into_contract=True, ).result transform.apply_cse(func) canonicalize(func) + # Try to remove any unit dimensions that may have been introduced due to tiling (e.g. batch dim of 1) + with ir.InsertionPoint(transform.apply_patterns(func).patterns): + apply_patterns_vector_cast_away_vector_leading_one_dim() if stop_at_stage == "vectorized": raise PipelineInterrupt() - # bufferize (placeholder) + # bufferize mod = apply_registered_pass(mod, "eliminate-empty-tensors") identity_layout = LayoutMapOption.IdentityLayoutMap mod = transform_bufferization.OneShotBufferizeOp( @@ -264,9 +268,6 @@ def bundle_xegpu_fused_attention_schedule( gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) for gpu_mod in gpu_mod_ops: gpu_func = match(gpu_mod, ops={"gpu.func"}) - allocas = match_and_split(gpu_func, ops={"memref.alloca"}) - for alloca in allocas: - transform_ext.update_address_space(alloca, address_space=3) gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") transform.apply_cse(gpu_func) @@ -277,11 +278,11 @@ def bundle_xegpu_fused_attention_schedule( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - # Set layout attributes for xegpu store operations - # Note: The exact operations to match depend on what gets generated - # after vectorization and xegpu conversion for fused attention - # This is a placeholder that may need adjustment based on the actual IR - gpu_func = match(gpu_mod_ops[0], ops={"gpu.func"}) + # Set layout attributes for xegpu.store_nd ops. + store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] + sg_layout = [sg_rows, 1] + sg_data = [sg_rows, parameters["n_head"]] + xegpu.set_anchor_layout(store_nd_op, sg_layout=sg_layout, sg_data=sg_data) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From 1dba7ec1d739cee298e90d3427afa5e5629ea2f5 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 6 May 2026 20:42:33 +0000 Subject: [PATCH 49/63] save work --- examples/xegpu/fused_attention.py | 8 +- .../transform/transform_ext/__init__.py | 2 + .../ops/generate_fused_attention.py | 258 ++++++++++++++++++ .../xegpu/fused_attention_schedule.py | 48 ++++ 4 files changed, 312 insertions(+), 4 deletions(-) create mode 100644 lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py index 5c334bda..101e3ff9 100644 --- a/examples/xegpu/fused_attention.py +++ b/examples/xegpu/fused_attention.py @@ -126,14 +126,14 @@ def __init__( H: int, n_ctx: int, n_head: int, - dtype: str = "f32", + dtype: str = "f16", ): self.Z = Z self.H = H self.n_ctx = n_ctx self.n_head = n_head self.shape = (Z, H, n_ctx, n_head) - assert dtype == "f32", "Only f32 type is supported for fused attention" + assert dtype == "f16", "Only f16 type is supported for fused attention" self.elem_type = get_mlir_elem_type(dtype) self.dtype = mlir_to_numpy_dtype(self.elem_type) self.memory_manager_class = GPUMemoryManager @@ -212,7 +212,7 @@ def parse_cli(): parser.add_argument( "--n-ctx", type=int, - default=128, + default=1024, help="Context length (sequence length)", ) parser.add_argument( @@ -306,7 +306,7 @@ def parse_cli(): H = args.num_heads n_ctx = args.n_ctx n_head = args.n_head - dtype = "f32" + dtype = "f16" with ir.Context(), ir.Location.unknown(): lh_dialects.register_and_load() diff --git a/lighthouse/dialects/transform/transform_ext/__init__.py b/lighthouse/dialects/transform/transform_ext/__init__.py index 997522a2..eec36b6e 100644 --- a/lighthouse/dialects/transform/transform_ext/__init__.py +++ b/lighthouse/dialects/transform/transform_ext/__init__.py @@ -10,11 +10,13 @@ from .ops.get_tileable_consumers import get_tileable_consumers from .ops.get_tiling_sizes import get_tiling_sizes from .ops.update_address_space import update_address_space +from .ops.generate_fused_attention import generate_fused_attention __all__ = [ "TransformExtensionDialect", "convert_func_results_to_args", "extract_handle", + "generate_fused_attention", "get_named_attribute", "get_named_attribute", "get_tileable_consumers", diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py new file mode 100644 index 00000000..7dc63552 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -0,0 +1,258 @@ +"""Transform extension to generate fused attention computation.""" + +from mlir import ir +from mlir.dialects import ext, transform, arith, scf, linalg, tensor +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect + + +class GenerateFusedAttention( + TransformExtensionDialect.Operation, name="generate_fused_attention" +): + """Generate tiled fused attention computation (flash attention optimization). + + Takes Q, K, V slices and output tensor, and generates an inner tiled loop + that computes fused attention with online softmax using running max and sum. + + This implements the flash attention algorithm where: + 1. The computation is tiled along the inner K dimension (sequence length) + 2. Online max and sum are maintained across tiles + 3. Output is incrementally updated with rescaled contributions + + Args: + q_slice: Handle to Q slice operation (tensor.extract_slice) + k_slice: Handle to K slice operation (tensor.extract_slice) + scale_slice: Handle to scaling factor slice operation (tensor.extract_slice) + v_slice: Handle to V slice operation (tensor.extract_slice) + output: Handle to the output operation to replace + tile_size: Size of inner dimension tiles (default: from attributes) + """ + + q_slice: ext.Operand[transform.AnyOpType] + k_slice: ext.Operand[transform.AnyOpType] + scale_slice: ext.Operand[transform.AnyOpType] + v_slice: ext.Operand[transform.AnyOpType] + output: ext.Operand[transform.AnyOpType] + tile_size: ir.IntegerAttr + new_output: ext.Result[transform.AnyOpType[()]] = ext.infer_result() + + @classmethod + def attach_interface_impls(cls, ctx=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "GenerateFusedAttention", + rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + # Get payload operations + q_slice_ops = state.get_payload_ops(op.q_slice) + k_slice_ops = state.get_payload_ops(op.k_slice) + scale_slice_ops = state.get_payload_ops(op.scale_slice) + v_slice_ops = state.get_payload_ops(op.v_slice) + output_ops = state.get_payload_ops(op.output) + + if ( + len(q_slice_ops) != 1 + or len(k_slice_ops) != 1 + or len(scale_slice_ops) != 1 + or len(v_slice_ops) != 1 + or len(output_ops) != 1 + ): + return DiagnosedSilenceableFailure.emit_silenceable_error( + "Expected exactly one operation for each operand" + ) + + q_slice_op = q_slice_ops[0] + k_slice_op = k_slice_ops[0] + scale_slice_op = scale_slice_ops[0] + v_slice_op = v_slice_ops[0] + output_op = output_ops[0] + + # Get tile size + tile_size_value = ir.IntegerAttr(op.tile_size).value + + # Get the result types and shapes + q_result = q_slice_op.results[0] + k_result = k_slice_op.results[0] + scale_result = scale_slice_op.results[0] + v_result = v_slice_op.results[0] + output_result = output_op.results[0] + + # Extract shape information from the slice operations + q_type = ir.RankedTensorType(q_result.type) + k_type = ir.RankedTensorType(k_result.type) + scale_type = ir.RankedTensorType(scale_result.type) + v_type = ir.RankedTensorType(v_result.type) + output_type = ir.RankedTensorType(output_result.type) + print(f"Q type: {q_type}, K type: {k_type}, Scale type: {scale_type}, V type: {v_type}, Output type: {output_type}") + + element_type = q_type.element_type + index_type = ir.IndexType.get() + + # Build the fused attention computation + with ir.InsertionPoint(output_op): + # Collapse the unit batch dimension to get 2D tensors + # Q: [1, seq_q, head_dim] -> [seq_q, head_dim] + # K: [1, seq_k, head_dim] -> [seq_k, head_dim] + # V: [1, seq_k, head_dim] -> [seq_k, head_dim] + # Scale: [1, seq_q, 1] -> [seq_q, 1] + q_2d_ty = ir.RankedTensorType.get((q_type.shape[1], q_type.shape[2]), element_type) + k_2d_ty = ir.RankedTensorType.get((k_type.shape[1], k_type.shape[2]), element_type) + v_2d_ty = ir.RankedTensorType.get((v_type.shape[1], v_type.shape[2]), element_type) + scale_2d_ty = ir.RankedTensorType.get((scale_type.shape[1], scale_type.shape[2]), element_type) + q_2d = tensor.collapse_shape(q_2d_ty, src=q_result, reassociation=[[0, 1], [2]]) + k_2d = tensor.collapse_shape(k_2d_ty, src=k_result, reassociation=[[0, 1], [2]]) + v_2d = tensor.collapse_shape(v_2d_ty, src=v_result, reassociation=[[0, 1], [2]]) + scale_2d = tensor.collapse_shape(scale_2d_ty, src=scale_result, reassociation=[[0, 1], [2]]) + + # Get dimensions from 2D tensors + # Q: [seq_q, head_dim] + # K: [seq_k, head_dim] + # V: [seq_k, head_dim] + seq_q_dim = q_type.shape[1] + head_dim = q_type.shape[2] + seq_k_size = arith.constant(index_type, k_type.shape[1]) + print(f"Seq Q dim: {seq_q_dim}, Head dim: {head_dim}, Sequence length K: {seq_k_size}") + + # Initialize max to -inf + # Shape: [seq_q] (1D for 2D tensors) + neg_inf = arith.constant( + element_type, float("-inf") if element_type == ir.F32Type.get() else -1e10 + ) + max_shape = [seq_q_dim] + max_init = tensor.empty(max_shape, element_type) + running_max = linalg.fill(neg_inf, outs=[max_init]) + + # Initialize sum to 0 + # Shape: [seq_q] (1D for 2D tensors) + zero = arith.constant(element_type, 0.0) + sum_init = tensor.empty(max_shape, element_type) + running_sum = linalg.fill(zero, outs=[sum_init]) + + # Initialize output accumulator to 0 + # Shape: [seq_q, head_dim] (2D) + output_2d_shape = [seq_q_dim, head_dim] + output_2d_init = tensor.empty(output_2d_shape, element_type) + output_acc = linalg.fill(zero, outs=[output_2d_init]) + + # Create tiled loop over K dimension + c0 = arith.constant(index_type, 0) + tile_size_const = arith.constant(index_type, tile_size_value) + + # Build the scf.for loop + loop_result = scf.ForOp( + c0, + seq_k_size, + tile_size_const, + [running_max, running_sum, output_acc], + ) + + with ir.InsertionPoint(loop_result.body): + # Get loop iteration variable and current state + k_idx = loop_result.induction_variable + old_max = loop_result.inner_iter_args[0] + old_sum = loop_result.inner_iter_args[1] + old_output = loop_result.inner_iter_args[2] + + # Slice K and V for this tile + # TODO: Implement proper slicing logic here + # This is a placeholder - needs to be filled with actual slice operations + + # Compute Q @ K^T for this tile + # Q: [seq_q, head_dim] + # K_tile: [tile_size, head_dim] (transposed from [head_dim, tile_size]) + # Result: [seq_q, tile_size] + + # Compute new max across this tile + # new_max: [seq_q] = max(old_max, row_max(Q @ K^T_tile)) + + # Compute exp(Q @ K^T_tile - new_max) + # exp_scores: [seq_q, tile_size] + + # Update running sum with rescaling + # new_sum: [seq_q] = old_sum * exp(old_max - new_max) + row_sum(exp_scores) + + # Update output with rescaling + # new_output: [seq_q, head_dim] = old_output * (old_sum * exp(old_max - new_max) / new_sum) + # + (exp_scores @ V_tile) / new_sum + + # For now, yield the unchanged values (placeholder) + scf.yield_([old_max, old_sum, old_output]) + + # Extract final results from loop + final_max = loop_result.results[0] + final_sum = loop_result.results[1] + final_output_2d = loop_result.results[2] + + # Expand the 2D output back to 3D to match the original output shape + # [seq_q, head_dim] -> [1, seq_q, head_dim] + final_output_3d = tensor.expand_shape(output_type, src=final_output_2d, reassociation=[[0, 1], [2]], output_shape=[], static_output_shape=output_type.shape) + + # Create a dummy add operation to wrap the final output + # This is needed because replace_op requires an operation, not a value + zero_const = arith.constant(element_type, 0.0) + # Create a linalg.add that adds 0 (identity operation) + zero_tensor_shape = output_type.shape + zero_tensor_init = tensor.empty(zero_tensor_shape, element_type) + zero_tensor = linalg.fill(zero_const, outs=[zero_tensor_init]) + + # Create the add operation: final_output_3d + 0 + output_init_for_add = tensor.empty(zero_tensor_shape, element_type) + dummy_add = linalg.add(final_output_3d, zero_tensor, outs=[output_init_for_add]) + + # Replace the original output operation with the dummy add + rewriter.replace_op(output_op, dummy_add.owner) + + results.set_ops(op.new_output, [dummy_add.owner]) + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "GenerateFusedAttention") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: ir.Operation, effects): + # Read Q, K, scale, V slices + transform.only_reads_handle(op.op_operands[:4], effects) + # Consume and replace output + transform.consumes_handle(op.op_operands[4:5], effects) + # Produce new output handle + transform.produces_handle(op.results, effects) + # Modify the payload + transform.modifies_payload(effects) + + +def generate_fused_attention( + q_slice: ir.Value, + k_slice: ir.Value, + scale_slice: ir.Value, + v_slice: ir.Value, + output: ir.Value, + tile_size: int | ir.IntegerAttr, +) -> ir.Value: + """Generate fused attention computation with inner tiling. + + Args: + q_slice: Handle to Q slice operation + k_slice: Handle to K slice operation + scale_slice: Handle to scaling factor slice operation + v_slice: Handle to V slice operation + output: Handle to output operation to replace + tile_size: Size of tiles along the K dimension + + Returns: + Handle to the new output operation + """ + if not isinstance(tile_size, ir.IntegerAttr): + tile_size = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), tile_size) + + return GenerateFusedAttention( + q_slice, k_slice, scale_slice, v_slice, output, tile_size=tile_size + ).new_output diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index f3658bae..33b837c0 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -9,6 +9,7 @@ from mlir.dialects.bufferization import LayoutMapOption from mlir.dialects.transform.vector import ( apply_patterns_vector_cast_away_vector_leading_one_dim, + apply_patterns_vector_drop_unit_dims_with_shape_cast, ) from lighthouse.pipeline.helper import ( @@ -19,6 +20,7 @@ apply_registered_pass, ) from lighthouse.schedule import schedule_boilerplate +from lighthouse.dialects.transform.transform_ext import generate_fused_attention def fused_attention_schedule( @@ -192,6 +194,51 @@ def bundle_xegpu_fused_attention_schedule( if stop_at_stage == "outer-tiled": raise PipelineInterrupt() + # Match Q, K, V slices inside the forall loop + # K slice is the first operand of the transpose op + transpose_op = match_and_split(forall_loop, ops={"linalg.transpose"}, nhandles=1)[0] + k_slice = transform.get_producer_of_operand( + anytype, transpose_op, operand_number=0 + ) + # Q slice is the first operand of the first batch matmul + batch_matmuls = match_and_split(forall_loop, ops={"linalg.batch_matmul"}, nhandles=2) + q_slice = transform.get_producer_of_operand( + anytype, batch_matmuls[0], operand_number=0 + ) + # V slice is the second operand of the last batch matmul (inside the forall loop) + # Need to match the tiled version of the last matmul inside the loop + last_matmul = batch_matmuls[1] + v_slice = transform.get_producer_of_operand( + anytype, last_matmul, operand_number=1 + ) + + # Match the scaling operation (linalg.mul) to get the scaling factor + # The QK output is scaled before softmax: QK * scale + mul_op = match_and_split(forall_loop, ops={"linalg.mul"}, nhandles=1)[0] + scale_slice = transform.get_producer_of_operand( + anytype, mul_op, operand_number=1 + ) + # transform.print_(target=k_slice, name="k_slice") + # transform.print_(target=q_slice, name="q_slice") + # transform.print_(target=v_slice, name="v_slice") + # transform.print_(target=scale_slice, name="scale_slice") + # transform.print_(target=last_matmul, name="tiled_attention_weights_v_matmul") + + # Generate fused attention computation with inner tiling (flash attention) + # This replaces the current unfused computation with a tiled loop that + # maintains online max and sum for efficient memory usage + tile_size = 64 + new_output = generate_fused_attention( + q_slice=q_slice, + k_slice=k_slice, + scale_slice=scale_slice, + v_slice=v_slice, + output=last_matmul, + tile_size=tile_size, + ) + # transform.apply_cse(func) + # canonicalize(func) + if stop_at_stage == "inner-tiled": raise PipelineInterrupt() @@ -205,6 +252,7 @@ def bundle_xegpu_fused_attention_schedule( # Try to remove any unit dimensions that may have been introduced due to tiling (e.g. batch dim of 1) with ir.InsertionPoint(transform.apply_patterns(func).patterns): apply_patterns_vector_cast_away_vector_leading_one_dim() + apply_patterns_vector_drop_unit_dims_with_shape_cast() if stop_at_stage == "vectorized": raise PipelineInterrupt() From 32e03498ea9d9414800a52e7c24acbcfb7117615 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 7 May 2026 23:10:54 +0000 Subject: [PATCH 50/63] save work --- .../ops/generate_fused_attention.py | 165 ++++++++++++++++-- .../xegpu/fused_attention_schedule.py | 4 +- 2 files changed, 148 insertions(+), 21 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index 7dc63552..2de20221 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -1,7 +1,7 @@ """Transform extension to generate fused attention computation.""" from mlir import ir -from mlir.dialects import ext, transform, arith, scf, linalg, tensor +from mlir.dialects import ext, transform, arith, scf, linalg, tensor, math from mlir.dialects.transform import DiagnosedSilenceableFailure from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect @@ -118,7 +118,6 @@ def apply( seq_q_dim = q_type.shape[1] head_dim = q_type.shape[2] seq_k_size = arith.constant(index_type, k_type.shape[1]) - print(f"Seq Q dim: {seq_q_dim}, Head dim: {head_dim}, Sequence length K: {seq_k_size}") # Initialize max to -inf # Shape: [seq_q] (1D for 2D tensors) @@ -161,33 +160,161 @@ def apply( old_output = loop_result.inner_iter_args[2] # Slice K and V for this tile - # TODO: Implement proper slicing logic here - # This is a placeholder - needs to be filled with actual slice operations + # K: [seq_k, head_dim] -> K_tile: [tile_size, head_dim] + # V: [seq_k, head_dim] -> V_tile: [tile_size, head_dim] + one = arith.constant(index_type, 1) + k_tile_type = ir.RankedTensorType.get([tile_size_value, head_dim], element_type) + k_tile = tensor.extract_slice( + k_tile_type, + source=k_2d, + offsets=[k_idx], + sizes=[], + strides=[], + static_offsets=[ir.ShapedType.get_dynamic_size(), 0], + static_sizes=[tile_size_value, head_dim], + static_strides=[1, 1], + ) + + v_tile_type = ir.RankedTensorType.get([tile_size_value, head_dim], element_type) + v_tile = tensor.extract_slice( + v_tile_type, + source=v_2d, + offsets=[k_idx], + sizes=[], + strides=[], + static_offsets=[ir.ShapedType.get_dynamic_size(), 0], + static_sizes=[tile_size_value, head_dim], + static_strides=[1, 1], + ) + # Transpose K_tile: [tile_size, head_dim] -> [head_dim, tile_size] + k_tile_t_shape = [head_dim, tile_size_value] + k_tile_t_init = tensor.empty(k_tile_t_shape, element_type) + k_tile_t = linalg.transpose(k_tile, outs=[k_tile_t_init], permutation=[1, 0]) # Compute Q @ K^T for this tile # Q: [seq_q, head_dim] - # K_tile: [tile_size, head_dim] (transposed from [head_dim, tile_size]) + # K_tile_T: [head_dim, tile_size] # Result: [seq_q, tile_size] + qk_shape = [seq_q_dim, tile_size_value] + qk_init = tensor.empty(qk_shape, element_type) + qk_filled = linalg.fill(zero, outs=[qk_init]) + qk = linalg.matmul(q_2d, k_tile_t, outs=[qk_filled]) + + # Compute row-wise max of qk + # row_max: [seq_q] = max(qk, axis=1) + row_max_init = tensor.empty([seq_q_dim], element_type) + row_max_filled = linalg.fill(neg_inf, outs=[row_max_init]) + dims_attr = ir.DenseI64ArrayAttr.get([1]) + f16 = ir.F16Type.get() + + @linalg.reduce( + result=[ir.RankedTensorType.get([seq_q_dim], element_type)], + inputs=[qk], + inits=[row_max_filled], + dimensions=dims_attr, + ) + def row_max(elem : f16 , acc: f16): + return arith.maximumf(elem, acc) # Compute new max across this tile - # new_max: [seq_q] = max(old_max, row_max(Q @ K^T_tile)) - - # Compute exp(Q @ K^T_tile - new_max) - # exp_scores: [seq_q, tile_size] + # new_max: [seq_q] = max(old_max, row_max) + new_max_init = tensor.empty([seq_q_dim], element_type) + new_max = linalg.max(old_max, row_max, outs=[new_max_init]) + + # Compute exp(qk - new_max) + # First broadcast new_max to [seq_q, 1] then to [seq_q, tile_size] + new_max_2d_type = ir.RankedTensorType.get([seq_q_dim, 1], element_type) + new_max_2d_init = tensor.empty([seq_q_dim, tile_size_value], element_type) + new_max_2d = linalg.broadcast(new_max, outs=[new_max_2d_init], dimensions=[0]) + + # exp_scores: [seq_q, tile_size] = exp(qk - new_max_2d) + exp_scores_init = tensor.empty(qk_shape, element_type) + + @linalg.map( + result=[ir.RankedTensorType.get(qk_shape, element_type)], + inputs=[qk, new_max_2d], + init=exp_scores_init, + ) + def exp_scores(qk_val : f16, max_val: f16, _ : f16): + diff = arith.subf(qk_val, max_val) + return math.exp(diff) + + # Compute row-wise sum of exp_scores + # row_sum_exp: [seq_q] = sum(exp_scores, axis=1) + row_sum_exp_init = tensor.empty([seq_q_dim], element_type) + row_sum_exp_filled = linalg.fill(zero, outs=[row_sum_exp_init]) + + @linalg.reduce( + result=[ir.RankedTensorType.get([seq_q_dim], element_type)], + inputs=[exp_scores], + inits=[row_sum_exp_filled], + dimensions=dims_attr, + ) + def row_sum_exp(elem: f16, acc: f16): + return arith.addf(elem, acc) + + # Compute correction factor for old values: exp(old_max - new_max) + correction_init = tensor.empty([seq_q_dim], element_type) + + @linalg.map( + result=[ir.RankedTensorType.get([seq_q_dim], element_type)], + inputs=[old_max, new_max], + init=correction_init, + ) + def correction(old_val: f16, new_val: f16, _: f16): + diff = arith.subf(old_val, new_val) + return math.exp(diff) # Update running sum with rescaling - # new_sum: [seq_q] = old_sum * exp(old_max - new_max) + row_sum(exp_scores) + # new_sum: [seq_q] = old_sum * correction + row_sum_exp + new_sum_init = tensor.empty([seq_q_dim], element_type) + + @linalg.map( + result=[ir.RankedTensorType.get([seq_q_dim], element_type)], + inputs=[old_sum, correction, row_sum_exp], + init=new_sum_init, + ) + def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): + rescaled = arith.mulf(old_s, corr) + return arith.addf(rescaled, new_s) + + # Compute exp_scores @ V_tile + # exp_scores: [seq_q, tile_size] + # V_tile: [tile_size, head_dim] + # Result: [seq_q, head_dim] + exp_v_init = tensor.empty([seq_q_dim, head_dim], element_type) + exp_v_filled = linalg.fill(zero, outs=[exp_v_init]) + exp_v = linalg.matmul(exp_scores, v_tile, outs=[exp_v_filled]) # Update output with rescaling - # new_output: [seq_q, head_dim] = old_output * (old_sum * exp(old_max - new_max) / new_sum) - # + (exp_scores @ V_tile) / new_sum - - # For now, yield the unchanged values (placeholder) - scf.yield_([old_max, old_sum, old_output]) - - # Extract final results from loop - final_max = loop_result.results[0] - final_sum = loop_result.results[1] + # new_output: [seq_q, head_dim] = old_output * (correction * old_sum / new_sum) + (exp_v / new_sum) + # First compute rescale factor: correction / new_sum (broadcasted to [seq_q, 1]) + rescale_factor_div_init = tensor.empty([seq_q_dim], element_type) + rescale_factor_div = linalg.div(correction, new_sum, outs=[rescale_factor_div_init]) + rescale_factor_mul_init = tensor.empty([seq_q_dim], element_type) + rescale_factor_mul = linalg.mul(rescale_factor_div, old_sum, outs=[rescale_factor_mul_init]) + rescale_factor_2d_init = tensor.empty([seq_q_dim, tile_size_value], element_type) + rescale_factor_2d = linalg.broadcast(rescale_factor_mul, outs=[rescale_factor_2d_init], dimensions=[0]) + + # Rescale old output + rescaled_old_init = tensor.empty([seq_q_dim, head_dim], element_type) + rescaled_old = linalg.mul(old_output, rescale_factor_2d, outs=[rescaled_old_init]) + + # Compute: exp_v / new_sum (broadcast new_sum to [seq_q, tile_size]) + norm_factor_2d_init = tensor.empty([seq_q_dim, tile_size_value], element_type) + norm_factor_2d = linalg.broadcast(new_sum, outs=[norm_factor_2d_init], dimensions=[0]) + + # Normalize new contribution + normalized_exp_v_init = tensor.empty([seq_q_dim, head_dim], element_type) + normalized_exp_v = linalg.div(exp_v, norm_factor_2d, outs=[normalized_exp_v_init]) + + # Add both contributions + new_output_init = tensor.empty([seq_q_dim, head_dim], element_type) + new_output = linalg.add(rescaled_old, normalized_exp_v, outs=[new_output_init]) + + scf.yield_([new_max, new_sum, new_output]) + + # Extract final result from loop final_output_2d = loop_result.results[2] # Expand the 2D output back to 3D to match the original output shape diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 33b837c0..7c6beae5 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -236,8 +236,8 @@ def bundle_xegpu_fused_attention_schedule( output=last_matmul, tile_size=tile_size, ) - # transform.apply_cse(func) - # canonicalize(func) + transform.apply_cse(func) + canonicalize(func) if stop_at_stage == "inner-tiled": raise PipelineInterrupt() From 609e571602452d51c02dc9b987ca9445377a406a Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 7 May 2026 23:11:04 +0000 Subject: [PATCH 51/63] save work --- .../ops/generate_fused_attention.py | 125 +++++++++++++----- .../xegpu/fused_attention_schedule.py | 16 +-- 2 files changed, 100 insertions(+), 41 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index 2de20221..14d6ba46 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -90,7 +90,9 @@ def apply( scale_type = ir.RankedTensorType(scale_result.type) v_type = ir.RankedTensorType(v_result.type) output_type = ir.RankedTensorType(output_result.type) - print(f"Q type: {q_type}, K type: {k_type}, Scale type: {scale_type}, V type: {v_type}, Output type: {output_type}") + print( + f"Q type: {q_type}, K type: {k_type}, Scale type: {scale_type}, V type: {v_type}, Output type: {output_type}" + ) element_type = q_type.element_type index_type = ir.IndexType.get() @@ -102,14 +104,30 @@ def apply( # K: [1, seq_k, head_dim] -> [seq_k, head_dim] # V: [1, seq_k, head_dim] -> [seq_k, head_dim] # Scale: [1, seq_q, 1] -> [seq_q, 1] - q_2d_ty = ir.RankedTensorType.get((q_type.shape[1], q_type.shape[2]), element_type) - k_2d_ty = ir.RankedTensorType.get((k_type.shape[1], k_type.shape[2]), element_type) - v_2d_ty = ir.RankedTensorType.get((v_type.shape[1], v_type.shape[2]), element_type) - scale_2d_ty = ir.RankedTensorType.get((scale_type.shape[1], scale_type.shape[2]), element_type) - q_2d = tensor.collapse_shape(q_2d_ty, src=q_result, reassociation=[[0, 1], [2]]) - k_2d = tensor.collapse_shape(k_2d_ty, src=k_result, reassociation=[[0, 1], [2]]) - v_2d = tensor.collapse_shape(v_2d_ty, src=v_result, reassociation=[[0, 1], [2]]) - scale_2d = tensor.collapse_shape(scale_2d_ty, src=scale_result, reassociation=[[0, 1], [2]]) + q_2d_ty = ir.RankedTensorType.get( + (q_type.shape[1], q_type.shape[2]), element_type + ) + k_2d_ty = ir.RankedTensorType.get( + (k_type.shape[1], k_type.shape[2]), element_type + ) + v_2d_ty = ir.RankedTensorType.get( + (v_type.shape[1], v_type.shape[2]), element_type + ) + scale_2d_ty = ir.RankedTensorType.get( + (scale_type.shape[1], scale_type.shape[2]), element_type + ) + q_2d = tensor.collapse_shape( + q_2d_ty, src=q_result, reassociation=[[0, 1], [2]] + ) + k_2d = tensor.collapse_shape( + k_2d_ty, src=k_result, reassociation=[[0, 1], [2]] + ) + v_2d = tensor.collapse_shape( + v_2d_ty, src=v_result, reassociation=[[0, 1], [2]] + ) + scale_2d = tensor.collapse_shape( + scale_2d_ty, src=scale_result, reassociation=[[0, 1], [2]] + ) # Get dimensions from 2D tensors # Q: [seq_q, head_dim] @@ -122,7 +140,8 @@ def apply( # Initialize max to -inf # Shape: [seq_q] (1D for 2D tensors) neg_inf = arith.constant( - element_type, float("-inf") if element_type == ir.F32Type.get() else -1e10 + element_type, + float("-inf") if element_type == ir.F32Type.get() else -1e10, ) max_shape = [seq_q_dim] max_init = tensor.empty(max_shape, element_type) @@ -163,7 +182,9 @@ def apply( # K: [seq_k, head_dim] -> K_tile: [tile_size, head_dim] # V: [seq_k, head_dim] -> V_tile: [tile_size, head_dim] one = arith.constant(index_type, 1) - k_tile_type = ir.RankedTensorType.get([tile_size_value, head_dim], element_type) + k_tile_type = ir.RankedTensorType.get( + [tile_size_value, head_dim], element_type + ) k_tile = tensor.extract_slice( k_tile_type, source=k_2d, @@ -175,7 +196,9 @@ def apply( static_strides=[1, 1], ) - v_tile_type = ir.RankedTensorType.get([tile_size_value, head_dim], element_type) + v_tile_type = ir.RankedTensorType.get( + [tile_size_value, head_dim], element_type + ) v_tile = tensor.extract_slice( v_tile_type, source=v_2d, @@ -189,7 +212,9 @@ def apply( # Transpose K_tile: [tile_size, head_dim] -> [head_dim, tile_size] k_tile_t_shape = [head_dim, tile_size_value] k_tile_t_init = tensor.empty(k_tile_t_shape, element_type) - k_tile_t = linalg.transpose(k_tile, outs=[k_tile_t_init], permutation=[1, 0]) + k_tile_t = linalg.transpose( + k_tile, outs=[k_tile_t_init], permutation=[1, 0] + ) # Compute Q @ K^T for this tile # Q: [seq_q, head_dim] @@ -213,7 +238,7 @@ def apply( inits=[row_max_filled], dimensions=dims_attr, ) - def row_max(elem : f16 , acc: f16): + def row_max(elem: f16, acc: f16): return arith.maximumf(elem, acc) # Compute new max across this tile @@ -223,9 +248,15 @@ def row_max(elem : f16 , acc: f16): # Compute exp(qk - new_max) # First broadcast new_max to [seq_q, 1] then to [seq_q, tile_size] - new_max_2d_type = ir.RankedTensorType.get([seq_q_dim, 1], element_type) - new_max_2d_init = tensor.empty([seq_q_dim, tile_size_value], element_type) - new_max_2d = linalg.broadcast(new_max, outs=[new_max_2d_init], dimensions=[0]) + new_max_2d_type = ir.RankedTensorType.get( + [seq_q_dim, 1], element_type + ) + new_max_2d_init = tensor.empty( + [seq_q_dim, tile_size_value], element_type + ) + new_max_2d = linalg.broadcast( + new_max, outs=[new_max_2d_init], dimensions=[0] + ) # exp_scores: [seq_q, tile_size] = exp(qk - new_max_2d) exp_scores_init = tensor.empty(qk_shape, element_type) @@ -235,7 +266,7 @@ def row_max(elem : f16 , acc: f16): inputs=[qk, new_max_2d], init=exp_scores_init, ) - def exp_scores(qk_val : f16, max_val: f16, _ : f16): + def exp_scores(qk_val: f16, max_val: f16, _: f16): diff = arith.subf(qk_val, max_val) return math.exp(diff) @@ -290,27 +321,51 @@ def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): # new_output: [seq_q, head_dim] = old_output * (correction * old_sum / new_sum) + (exp_v / new_sum) # First compute rescale factor: correction / new_sum (broadcasted to [seq_q, 1]) rescale_factor_div_init = tensor.empty([seq_q_dim], element_type) - rescale_factor_div = linalg.div(correction, new_sum, outs=[rescale_factor_div_init]) + rescale_factor_div = linalg.div( + correction, new_sum, outs=[rescale_factor_div_init] + ) rescale_factor_mul_init = tensor.empty([seq_q_dim], element_type) - rescale_factor_mul = linalg.mul(rescale_factor_div, old_sum, outs=[rescale_factor_mul_init]) - rescale_factor_2d_init = tensor.empty([seq_q_dim, tile_size_value], element_type) - rescale_factor_2d = linalg.broadcast(rescale_factor_mul, outs=[rescale_factor_2d_init], dimensions=[0]) + rescale_factor_mul = linalg.mul( + rescale_factor_div, old_sum, outs=[rescale_factor_mul_init] + ) + rescale_factor_2d_init = tensor.empty( + [seq_q_dim, tile_size_value], element_type + ) + rescale_factor_2d = linalg.broadcast( + rescale_factor_mul, + outs=[rescale_factor_2d_init], + dimensions=[0], + ) # Rescale old output - rescaled_old_init = tensor.empty([seq_q_dim, head_dim], element_type) - rescaled_old = linalg.mul(old_output, rescale_factor_2d, outs=[rescaled_old_init]) + rescaled_old_init = tensor.empty( + [seq_q_dim, head_dim], element_type + ) + rescaled_old = linalg.mul( + old_output, rescale_factor_2d, outs=[rescaled_old_init] + ) # Compute: exp_v / new_sum (broadcast new_sum to [seq_q, tile_size]) - norm_factor_2d_init = tensor.empty([seq_q_dim, tile_size_value], element_type) - norm_factor_2d = linalg.broadcast(new_sum, outs=[norm_factor_2d_init], dimensions=[0]) + norm_factor_2d_init = tensor.empty( + [seq_q_dim, tile_size_value], element_type + ) + norm_factor_2d = linalg.broadcast( + new_sum, outs=[norm_factor_2d_init], dimensions=[0] + ) # Normalize new contribution - normalized_exp_v_init = tensor.empty([seq_q_dim, head_dim], element_type) - normalized_exp_v = linalg.div(exp_v, norm_factor_2d, outs=[normalized_exp_v_init]) + normalized_exp_v_init = tensor.empty( + [seq_q_dim, head_dim], element_type + ) + normalized_exp_v = linalg.div( + exp_v, norm_factor_2d, outs=[normalized_exp_v_init] + ) # Add both contributions new_output_init = tensor.empty([seq_q_dim, head_dim], element_type) - new_output = linalg.add(rescaled_old, normalized_exp_v, outs=[new_output_init]) + new_output = linalg.add( + rescaled_old, normalized_exp_v, outs=[new_output_init] + ) scf.yield_([new_max, new_sum, new_output]) @@ -319,7 +374,13 @@ def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): # Expand the 2D output back to 3D to match the original output shape # [seq_q, head_dim] -> [1, seq_q, head_dim] - final_output_3d = tensor.expand_shape(output_type, src=final_output_2d, reassociation=[[0, 1], [2]], output_shape=[], static_output_shape=output_type.shape) + final_output_3d = tensor.expand_shape( + output_type, + src=final_output_2d, + reassociation=[[0, 1], [2]], + output_shape=[], + static_output_shape=output_type.shape, + ) # Create a dummy add operation to wrap the final output # This is needed because replace_op requires an operation, not a value @@ -331,7 +392,9 @@ def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): # Create the add operation: final_output_3d + 0 output_init_for_add = tensor.empty(zero_tensor_shape, element_type) - dummy_add = linalg.add(final_output_3d, zero_tensor, outs=[output_init_for_add]) + dummy_add = linalg.add( + final_output_3d, zero_tensor, outs=[output_init_for_add] + ) # Replace the original output operation with the dummy add rewriter.replace_op(output_op, dummy_add.owner) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 7c6beae5..d5902522 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -197,27 +197,23 @@ def bundle_xegpu_fused_attention_schedule( # Match Q, K, V slices inside the forall loop # K slice is the first operand of the transpose op transpose_op = match_and_split(forall_loop, ops={"linalg.transpose"}, nhandles=1)[0] - k_slice = transform.get_producer_of_operand( - anytype, transpose_op, operand_number=0 - ) + k_slice = transform.get_producer_of_operand(anytype, transpose_op, operand_number=0) # Q slice is the first operand of the first batch matmul - batch_matmuls = match_and_split(forall_loop, ops={"linalg.batch_matmul"}, nhandles=2) + batch_matmuls = match_and_split( + forall_loop, ops={"linalg.batch_matmul"}, nhandles=2 + ) q_slice = transform.get_producer_of_operand( anytype, batch_matmuls[0], operand_number=0 ) # V slice is the second operand of the last batch matmul (inside the forall loop) # Need to match the tiled version of the last matmul inside the loop last_matmul = batch_matmuls[1] - v_slice = transform.get_producer_of_operand( - anytype, last_matmul, operand_number=1 - ) + v_slice = transform.get_producer_of_operand(anytype, last_matmul, operand_number=1) # Match the scaling operation (linalg.mul) to get the scaling factor # The QK output is scaled before softmax: QK * scale mul_op = match_and_split(forall_loop, ops={"linalg.mul"}, nhandles=1)[0] - scale_slice = transform.get_producer_of_operand( - anytype, mul_op, operand_number=1 - ) + scale_slice = transform.get_producer_of_operand(anytype, mul_op, operand_number=1) # transform.print_(target=k_slice, name="k_slice") # transform.print_(target=q_slice, name="q_slice") # transform.print_(target=v_slice, name="v_slice") From b6863a937c19fdd8977d9ca3214733bc413a28b9 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 11 May 2026 22:44:00 +0000 Subject: [PATCH 52/63] minimum buffer version --- .../transform_ext/ops/generate_fused_attention.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index 14d6ba46..08c782ad 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -244,7 +244,7 @@ def row_max(elem: f16, acc: f16): # Compute new max across this tile # new_max: [seq_q] = max(old_max, row_max) new_max_init = tensor.empty([seq_q_dim], element_type) - new_max = linalg.max(old_max, row_max, outs=[new_max_init]) + new_max = linalg.max(old_max, row_max, outs=[old_max]) # Compute exp(qk - new_max) # First broadcast new_max to [seq_q, 1] then to [seq_q, tile_size] @@ -298,12 +298,12 @@ def correction(old_val: f16, new_val: f16, _: f16): # Update running sum with rescaling # new_sum: [seq_q] = old_sum * correction + row_sum_exp - new_sum_init = tensor.empty([seq_q_dim], element_type) + # new_sum_init = tensor.empty([seq_q_dim], element_type) @linalg.map( result=[ir.RankedTensorType.get([seq_q_dim], element_type)], inputs=[old_sum, correction, row_sum_exp], - init=new_sum_init, + init=old_sum, ) def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): rescaled = arith.mulf(old_s, corr) @@ -342,7 +342,7 @@ def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): [seq_q_dim, head_dim], element_type ) rescaled_old = linalg.mul( - old_output, rescale_factor_2d, outs=[rescaled_old_init] + old_output, rescale_factor_2d, outs=[old_output] ) # Compute: exp_v / new_sum (broadcast new_sum to [seq_q, tile_size]) @@ -364,7 +364,7 @@ def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): # Add both contributions new_output_init = tensor.empty([seq_q_dim, head_dim], element_type) new_output = linalg.add( - rescaled_old, normalized_exp_v, outs=[new_output_init] + rescaled_old, normalized_exp_v, outs=[old_output] ) scf.yield_([new_max, new_sum, new_output]) From 22faeb5ba3bd41085910508624292c0642294c0f Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 11 May 2026 22:53:19 +0000 Subject: [PATCH 53/63] minimum buffer version --- .../transform_ext/ops/generate_fused_attention.py | 8 ++++---- lighthouse/schedule/xegpu/fused_attention_schedule.py | 11 +++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index 08c782ad..ffc3e5e2 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -338,9 +338,9 @@ def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): ) # Rescale old output - rescaled_old_init = tensor.empty( - [seq_q_dim, head_dim], element_type - ) + # rescaled_old_init = tensor.empty( + # [seq_q_dim, head_dim], element_type + # ) rescaled_old = linalg.mul( old_output, rescale_factor_2d, outs=[old_output] ) @@ -362,7 +362,7 @@ def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): ) # Add both contributions - new_output_init = tensor.empty([seq_q_dim, head_dim], element_type) + # new_output_init = tensor.empty([seq_q_dim, head_dim], element_type) new_output = linalg.add( rescaled_old, normalized_exp_v, outs=[old_output] ) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index d5902522..08c71655 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -20,7 +20,7 @@ apply_registered_pass, ) from lighthouse.schedule import schedule_boilerplate -from lighthouse.dialects.transform.transform_ext import generate_fused_attention +from lighthouse.dialects.transform.transform_ext import generate_fused_attention, update_address_space def fused_attention_schedule( @@ -246,9 +246,9 @@ def bundle_xegpu_fused_attention_schedule( transform.apply_cse(func) canonicalize(func) # Try to remove any unit dimensions that may have been introduced due to tiling (e.g. batch dim of 1) - with ir.InsertionPoint(transform.apply_patterns(func).patterns): - apply_patterns_vector_cast_away_vector_leading_one_dim() - apply_patterns_vector_drop_unit_dims_with_shape_cast() + # with ir.InsertionPoint(transform.apply_patterns(func).patterns): + # apply_patterns_vector_cast_away_vector_leading_one_dim() + # apply_patterns_vector_drop_unit_dims_with_shape_cast() if stop_at_stage == "vectorized": raise PipelineInterrupt() @@ -312,6 +312,9 @@ def bundle_xegpu_fused_attention_schedule( gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) for gpu_mod in gpu_mod_ops: gpu_func = match(gpu_mod, ops={"gpu.func"}) + allocas = match_and_split(gpu_func, ops={"memref.alloca"}) + for alloca in allocas: + update_address_space(alloca, address_space=3) gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") transform.apply_cse(gpu_func) From f9710f864082d3cbbc07ce5d34057f05ecb3cedf Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 12 May 2026 21:20:01 +0000 Subject: [PATCH 54/63] save work --- .../ops/generate_fused_attention.py | 6 +- .../xegpu/fused_attention_schedule.py | 73 +++++++++++++++++-- 2 files changed, 68 insertions(+), 11 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index ffc3e5e2..5988d6ec 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -90,9 +90,9 @@ def apply( scale_type = ir.RankedTensorType(scale_result.type) v_type = ir.RankedTensorType(v_result.type) output_type = ir.RankedTensorType(output_result.type) - print( - f"Q type: {q_type}, K type: {k_type}, Scale type: {scale_type}, V type: {v_type}, Output type: {output_type}" - ) + # print( + # f"Q type: {q_type}, K type: {k_type}, Scale type: {scale_type}, V type: {v_type}, Output type: {output_type}" + # ) element_type = q_type.element_type index_type = ir.IndexType.get() diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 08c71655..a8fbaec0 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -20,7 +20,10 @@ apply_registered_pass, ) from lighthouse.schedule import schedule_boilerplate -from lighthouse.dialects.transform.transform_ext import generate_fused_attention, update_address_space +from lighthouse.dialects.transform.transform_ext import ( + generate_fused_attention, + update_address_space, +) def fused_attention_schedule( @@ -246,9 +249,9 @@ def bundle_xegpu_fused_attention_schedule( transform.apply_cse(func) canonicalize(func) # Try to remove any unit dimensions that may have been introduced due to tiling (e.g. batch dim of 1) - # with ir.InsertionPoint(transform.apply_patterns(func).patterns): - # apply_patterns_vector_cast_away_vector_leading_one_dim() - # apply_patterns_vector_drop_unit_dims_with_shape_cast() + with ir.InsertionPoint(transform.apply_patterns(func).patterns): + apply_patterns_vector_cast_away_vector_leading_one_dim() + apply_patterns_vector_drop_unit_dims_with_shape_cast() if stop_at_stage == "vectorized": raise PipelineInterrupt() @@ -264,6 +267,21 @@ def bundle_xegpu_fused_attention_schedule( ).result transform.apply_cse(mod) canonicalize(mod) + # fold memref.subviews into vector.transfer_read/write ops + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + # promote memref.alloc to memref.alloca in payload function + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) if stop_at_stage == "bufferized": raise PipelineInterrupt() @@ -312,8 +330,9 @@ def bundle_xegpu_fused_attention_schedule( gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) for gpu_mod in gpu_mod_ops: gpu_func = match(gpu_mod, ops={"gpu.func"}) - allocas = match_and_split(gpu_func, ops={"memref.alloca"}) + allocas = match_and_split(gpu_func, ops={"memref.alloca"}, nhandles=3) for alloca in allocas: + # print("Updating address space for alloca:") update_address_space(alloca, address_space=3) gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") transform.apply_cse(gpu_func) @@ -327,9 +346,47 @@ def bundle_xegpu_fused_attention_schedule( # Set layout attributes for xegpu.store_nd ops. store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] - sg_layout = [sg_rows, 1] - sg_data = [sg_rows, parameters["n_head"]] - xegpu.set_anchor_layout(store_nd_op, sg_layout=sg_layout, sg_data=sg_data) + out_sg_layout = [num_subgroups, 1] + out_sg_data = [sg_rows, parameters["n_head"]] + xegpu.set_anchor_layout(store_nd_op, sg_layout=out_sg_layout, sg_data=out_sg_data) + + # Set layout attributes for xegpu.store_matrix ops. same as store_nd ops. + store_matrix_ops = match_and_split(gpu_func, ops={"xegpu.store_matrix"}, nhandles=2) + for store_matrix_op in store_matrix_ops: + xegpu.set_anchor_layout( + store_matrix_op, sg_layout=out_sg_layout, sg_data=out_sg_data + ) + + # Set layout for xegpu.dpas ops + dpas_ops = match_and_split(gpu_func, ops={"xegpu.dpas"}, nhandles=2) + # layouts for the first dpas op (Q*K^T): + first_dpas_op = dpas_ops[0] + qk_a_sg_layout = out_sg_layout + qk_a_sg_data = out_sg_data + qk_b_sg_layout = [1, num_subgroups] + qk_b_sg_data = [parameters["n_head"], num_subgroups] + qk_cd_sg_layout = out_sg_layout + qk_cd_sg_data = [sg_rows, 16] + xegpu.set_anchor_layout( + first_dpas_op, sg_layout=qk_a_sg_layout, sg_data=qk_a_sg_data, index=0 + ) + xegpu.set_anchor_layout( + first_dpas_op, sg_layout=qk_b_sg_layout, sg_data=qk_b_sg_data, index=1 + ) + xegpu.set_anchor_layout( + first_dpas_op, sg_layout=qk_cd_sg_layout, sg_data=qk_cd_sg_data, index=2 + ) + # layouts for the second dpas op (attention_weights*V): + second_dpas_op = dpas_ops[1] + xegpu.set_anchor_layout( + second_dpas_op, sg_layout=qk_cd_sg_layout, sg_data=qk_cd_sg_data, index=0 + ) + xegpu.set_anchor_layout( + second_dpas_op, sg_layout=qk_b_sg_layout, sg_data=qk_b_sg_data, index=1 + ) + xegpu.set_anchor_layout( + second_dpas_op, sg_layout=out_sg_layout, sg_data=out_sg_data, index=2 + ) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From 036d59fe8c65f945e4b7e3009e0ef3a682ef37df Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 14 May 2026 23:43:20 +0000 Subject: [PATCH 55/63] save work --- .../ops/generate_fused_attention.py | 652 +++++++++--------- .../xegpu/fused_attention_schedule.py | 83 ++- 2 files changed, 377 insertions(+), 358 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index 5988d6ec..e6f9e6f1 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -1,7 +1,8 @@ """Transform extension to generate fused attention computation.""" +import numpy as np from mlir import ir -from mlir.dialects import ext, transform, arith, scf, linalg, tensor, math +from mlir.dialects import ext, transform, arith, scf, linalg, tensor, math, vector from mlir.dialects.transform import DiagnosedSilenceableFailure from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect @@ -12,27 +13,27 @@ class GenerateFusedAttention( ): """Generate tiled fused attention computation (flash attention optimization). - Takes Q, K, V slices and output tensor, and generates an inner tiled loop - that computes fused attention with online softmax using running max and sum. + Takes Q, K, V loads and scale constant from bufferized IR, and generates an inner + tiled loop that computes fused attention with online softmax using running max and sum. This implements the flash attention algorithm where: - 1. The computation is tiled along the inner K dimension (sequence length) + 1. The computation is tiled along the reduction dimension (K/V sequence length) 2. Online max and sum are maintained across tiles 3. Output is incrementally updated with rescaled contributions Args: - q_slice: Handle to Q slice operation (tensor.extract_slice) - k_slice: Handle to K slice operation (tensor.extract_slice) - scale_slice: Handle to scaling factor slice operation (tensor.extract_slice) - v_slice: Handle to V slice operation (tensor.extract_slice) - output: Handle to the output operation to replace - tile_size: Size of inner dimension tiles (default: from attributes) + q_load: Handle to Q load operation (vector.transfer_read) + k_load: Handle to K load operation (vector.transfer_read) + v_load: Handle to V load operation (vector.transfer_read) + scale: Handle to scale constant operation (arith.constant) + output: Handle to the output operation to replace (vector.contract) + tile_size: Tile size for the reduction dimension tiling (K/V sequence length) """ - q_slice: ext.Operand[transform.AnyOpType] - k_slice: ext.Operand[transform.AnyOpType] - scale_slice: ext.Operand[transform.AnyOpType] - v_slice: ext.Operand[transform.AnyOpType] + q_load: ext.Operand[transform.AnyOpType] + k_load: ext.Operand[transform.AnyOpType] + v_load: ext.Operand[transform.AnyOpType] + scale: ext.Operand[transform.AnyOpType] output: ext.Operand[transform.AnyOpType] tile_size: ir.IntegerAttr new_output: ext.Result[transform.AnyOpType[()]] = ext.infer_result() @@ -51,355 +52,334 @@ def apply( state: transform.TransformState, ) -> DiagnosedSilenceableFailure: # Get payload operations - q_slice_ops = state.get_payload_ops(op.q_slice) - k_slice_ops = state.get_payload_ops(op.k_slice) - scale_slice_ops = state.get_payload_ops(op.scale_slice) - v_slice_ops = state.get_payload_ops(op.v_slice) + q_load_ops = state.get_payload_ops(op.q_load) + k_load_ops = state.get_payload_ops(op.k_load) + v_load_ops = state.get_payload_ops(op.v_load) + scale_ops = state.get_payload_ops(op.scale) output_ops = state.get_payload_ops(op.output) if ( - len(q_slice_ops) != 1 - or len(k_slice_ops) != 1 - or len(scale_slice_ops) != 1 - or len(v_slice_ops) != 1 + len(q_load_ops) != 1 + or len(k_load_ops) != 1 + or len(v_load_ops) != 1 + or len(scale_ops) != 1 or len(output_ops) != 1 ): return DiagnosedSilenceableFailure.emit_silenceable_error( "Expected exactly one operation for each operand" ) - q_slice_op = q_slice_ops[0] - k_slice_op = k_slice_ops[0] - scale_slice_op = scale_slice_ops[0] - v_slice_op = v_slice_ops[0] + q_load_op = q_load_ops[0] + k_load_op = k_load_ops[0] + v_load_op = v_load_ops[0] + scale_op = scale_ops[0] output_op = output_ops[0] + # Extract the scale scalar value from scale_op (arith.constant) + scale_attr = scale_op.attributes["value"] + # Extract the scalar scale value from the scale_attr DenseElementsAttr + scale_dense_attr = ir.DenseElementsAttr(scale_attr) + # Get the first element as the scale value (all elements are the same in a splat) + scale_np_array = np.array(scale_dense_attr) + scale_value = float(scale_np_array.flat[0]) + + # Extract wg_rows and d_head from q_load result type + # q_load is vector.transfer_read that produces a vector + q_load_result = q_load_op.results[0] + q_vector_type = ir.VectorType(q_load_result.type) + wg_rows = q_vector_type.shape[0] + d_head = q_vector_type.shape[1] + # Get tile size tile_size_value = ir.IntegerAttr(op.tile_size).value - # Get the result types and shapes - q_result = q_slice_op.results[0] - k_result = k_slice_op.results[0] - scale_result = scale_slice_op.results[0] - v_result = v_slice_op.results[0] - output_result = output_op.results[0] - - # Extract shape information from the slice operations - q_type = ir.RankedTensorType(q_result.type) - k_type = ir.RankedTensorType(k_result.type) - scale_type = ir.RankedTensorType(scale_result.type) - v_type = ir.RankedTensorType(v_result.type) - output_type = ir.RankedTensorType(output_result.type) - # print( - # f"Q type: {q_type}, K type: {k_type}, Scale type: {scale_type}, V type: {v_type}, Output type: {output_type}" - # ) - - element_type = q_type.element_type - index_type = ir.IndexType.get() + # Get element type from q_load result + element_type = q_vector_type.element_type # Build the fused attention computation with ir.InsertionPoint(output_op): - # Collapse the unit batch dimension to get 2D tensors - # Q: [1, seq_q, head_dim] -> [seq_q, head_dim] - # K: [1, seq_k, head_dim] -> [seq_k, head_dim] - # V: [1, seq_k, head_dim] -> [seq_k, head_dim] - # Scale: [1, seq_q, 1] -> [seq_q, 1] - q_2d_ty = ir.RankedTensorType.get( - (q_type.shape[1], q_type.shape[2]), element_type - ) - k_2d_ty = ir.RankedTensorType.get( - (k_type.shape[1], k_type.shape[2]), element_type - ) - v_2d_ty = ir.RankedTensorType.get( - (v_type.shape[1], v_type.shape[2]), element_type - ) - scale_2d_ty = ir.RankedTensorType.get( - (scale_type.shape[1], scale_type.shape[2]), element_type - ) - q_2d = tensor.collapse_shape( - q_2d_ty, src=q_result, reassociation=[[0, 1], [2]] - ) - k_2d = tensor.collapse_shape( - k_2d_ty, src=k_result, reassociation=[[0, 1], [2]] - ) - v_2d = tensor.collapse_shape( - v_2d_ty, src=v_result, reassociation=[[0, 1], [2]] - ) - scale_2d = tensor.collapse_shape( - scale_2d_ty, src=scale_result, reassociation=[[0, 1], [2]] - ) - - # Get dimensions from 2D tensors - # Q: [seq_q, head_dim] - # K: [seq_k, head_dim] - # V: [seq_k, head_dim] - seq_q_dim = q_type.shape[1] - head_dim = q_type.shape[2] - seq_k_size = arith.constant(index_type, k_type.shape[1]) - - # Initialize max to -inf - # Shape: [seq_q] (1D for 2D tensors) - neg_inf = arith.constant( - element_type, - float("-inf") if element_type == ir.F32Type.get() else -1e10, - ) - max_shape = [seq_q_dim] - max_init = tensor.empty(max_shape, element_type) - running_max = linalg.fill(neg_inf, outs=[max_init]) - - # Initialize sum to 0 - # Shape: [seq_q] (1D for 2D tensors) - zero = arith.constant(element_type, 0.0) - sum_init = tensor.empty(max_shape, element_type) - running_sum = linalg.fill(zero, outs=[sum_init]) - - # Initialize output accumulator to 0 - # Shape: [seq_q, head_dim] (2D) - output_2d_shape = [seq_q_dim, head_dim] - output_2d_init = tensor.empty(output_2d_shape, element_type) - output_acc = linalg.fill(zero, outs=[output_2d_init]) - - # Create tiled loop over K dimension + # 1. Define m_i_init: vector of shape [wg_rows] with neg_inf values + m_i_vector_type = ir.VectorType.get([wg_rows], element_type) + neg_inf_value = 0xFC00 if element_type == ir.F16Type.get() else float("-inf") + m_i_values = np.full(wg_rows, neg_inf_value, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) + m_i_init_attr = ir.DenseElementsAttr.get(m_i_values, type=m_i_vector_type) + m_i_init = arith.constant(m_i_vector_type, m_i_init_attr) + + # 2. Define l_i_init: vector of shape [wg_rows] with zero values + l_i_vector_type = ir.VectorType.get([wg_rows], element_type) + l_i_values = np.zeros(wg_rows, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) + l_i_init_attr = ir.DenseElementsAttr.get(l_i_values, type=l_i_vector_type) + l_i_init = arith.constant(l_i_vector_type, l_i_init_attr) + + # 3. Define acc_init: vector of shape [wg_rows, d_head] with zero values + acc_vector_type = ir.VectorType.get([wg_rows, d_head], element_type) + acc_values = np.zeros((wg_rows, d_head), dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) + acc_init_attr = ir.DenseElementsAttr.get(acc_values, type=acc_vector_type) + acc_init = arith.constant(acc_vector_type, acc_init_attr) + + # Get n_ctx from k_load result type (first dimension size) + k_load_result = k_load_op.results[0] + k_vector_type = ir.VectorType(k_load_result.type) + n_ctx = k_vector_type.shape[0] + + + + scale_vector_type = ir.VectorType.get([wg_rows], element_type) + scale_values = np.full((wg_rows), scale_value, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) + scale_init_attr = ir.DenseElementsAttr.get(scale_values, type=scale_vector_type) + scale_vector = arith.constant(scale_vector_type, scale_init_attr) + + # Create loop bounds + index_type = ir.IndexType.get() c0 = arith.constant(index_type, 0) - tile_size_const = arith.constant(index_type, tile_size_value) - - # Build the scf.for loop - loop_result = scf.ForOp( - c0, - seq_k_size, - tile_size_const, - [running_max, running_sum, output_acc], - ) + c_n_ctx = arith.constant(index_type, n_ctx) + c_tile_size = arith.constant(index_type, tile_size_value) - with ir.InsertionPoint(loop_result.body): - # Get loop iteration variable and current state - k_idx = loop_result.induction_variable - old_max = loop_result.inner_iter_args[0] - old_sum = loop_result.inner_iter_args[1] - old_output = loop_result.inner_iter_args[2] - - # Slice K and V for this tile - # K: [seq_k, head_dim] -> K_tile: [tile_size, head_dim] - # V: [seq_k, head_dim] -> V_tile: [tile_size, head_dim] - one = arith.constant(index_type, 1) - k_tile_type = ir.RankedTensorType.get( - [tile_size_value, head_dim], element_type - ) - k_tile = tensor.extract_slice( - k_tile_type, - source=k_2d, - offsets=[k_idx], - sizes=[], - strides=[], - static_offsets=[ir.ShapedType.get_dynamic_size(), 0], - static_sizes=[tile_size_value, head_dim], - static_strides=[1, 1], - ) + # Create scf.for loop that iterates from 0 to n_ctx in steps of tile_size + loop = scf.ForOp(c0, c_n_ctx, c_tile_size, [m_i_init, l_i_init, acc_init]) - v_tile_type = ir.RankedTensorType.get( - [tile_size_value, head_dim], element_type - ) - v_tile = tensor.extract_slice( - v_tile_type, - source=v_2d, - offsets=[k_idx], - sizes=[], - strides=[], - static_offsets=[ir.ShapedType.get_dynamic_size(), 0], - static_sizes=[tile_size_value, head_dim], - static_strides=[1, 1], - ) - # Transpose K_tile: [tile_size, head_dim] -> [head_dim, tile_size] - k_tile_t_shape = [head_dim, tile_size_value] - k_tile_t_init = tensor.empty(k_tile_t_shape, element_type) - k_tile_t = linalg.transpose( - k_tile, outs=[k_tile_t_init], permutation=[1, 0] - ) + with ir.InsertionPoint(loop.body): + # Get the loop induction variable and iter_args + loop_idx = loop.induction_variable + m_i = loop.inner_iter_args[0] + l_i = loop.inner_iter_args[1] + acc = loop.inner_iter_args[2] - # Compute Q @ K^T for this tile - # Q: [seq_q, head_dim] - # K_tile_T: [head_dim, tile_size] - # Result: [seq_q, tile_size] - qk_shape = [seq_q_dim, tile_size_value] - qk_init = tensor.empty(qk_shape, element_type) - qk_filled = linalg.fill(zero, outs=[qk_init]) - qk = linalg.matmul(q_2d, k_tile_t, outs=[qk_filled]) - - # Compute row-wise max of qk - # row_max: [seq_q] = max(qk, axis=1) - row_max_init = tensor.empty([seq_q_dim], element_type) - row_max_filled = linalg.fill(neg_inf, outs=[row_max_init]) - dims_attr = ir.DenseI64ArrayAttr.get([1]) - f16 = ir.F16Type.get() - - @linalg.reduce( - result=[ir.RankedTensorType.get([seq_q_dim], element_type)], - inputs=[qk], - inits=[row_max_filled], - dimensions=dims_attr, - ) - def row_max(elem: f16, acc: f16): - return arith.maximumf(elem, acc) - - # Compute new max across this tile - # new_max: [seq_q] = max(old_max, row_max) - new_max_init = tensor.empty([seq_q_dim], element_type) - new_max = linalg.max(old_max, row_max, outs=[old_max]) - - # Compute exp(qk - new_max) - # First broadcast new_max to [seq_q, 1] then to [seq_q, tile_size] - new_max_2d_type = ir.RankedTensorType.get( - [seq_q_dim, 1], element_type - ) - new_max_2d_init = tensor.empty( - [seq_q_dim, tile_size_value], element_type - ) - new_max_2d = linalg.broadcast( - new_max, outs=[new_max_2d_init], dimensions=[0] - ) + # Load the current K tile: shape [tile_size, d_head] + # Use the same memref and indices as k_load, but replace second-to-last index with loop_idx + k_memref = k_load_op.operands[0] + k_tile_type = ir.VectorType.get([tile_size_value, d_head], element_type) - # exp_scores: [seq_q, tile_size] = exp(qk - new_max_2d) - exp_scores_init = tensor.empty(qk_shape, element_type) + # Get the indices from original k_load (all operands except the first one which is the memref) + # and the last one which is the padding value + k_load_indices = list(k_load_op.operands[1:-1]) - @linalg.map( - result=[ir.RankedTensorType.get(qk_shape, element_type)], - inputs=[qk, new_max_2d], - init=exp_scores_init, - ) - def exp_scores(qk_val: f16, max_val: f16, _: f16): - diff = arith.subf(qk_val, max_val) - return math.exp(diff) - - # Compute row-wise sum of exp_scores - # row_sum_exp: [seq_q] = sum(exp_scores, axis=1) - row_sum_exp_init = tensor.empty([seq_q_dim], element_type) - row_sum_exp_filled = linalg.fill(zero, outs=[row_sum_exp_init]) - - @linalg.reduce( - result=[ir.RankedTensorType.get([seq_q_dim], element_type)], - inputs=[exp_scores], - inits=[row_sum_exp_filled], - dimensions=dims_attr, - ) - def row_sum_exp(elem: f16, acc: f16): - return arith.addf(elem, acc) + # Replace the second-to-last index with loop_idx + k_tile_indices = k_load_indices + k_tile_indices[-2] = loop_idx # Assuming the reduction dimension is the last index before padding - # Compute correction factor for old values: exp(old_max - new_max) - correction_init = tensor.empty([seq_q_dim], element_type) + # Get the padding value (last operand of k_load) + padding = k_load_op.operands[-1] - @linalg.map( - result=[ir.RankedTensorType.get([seq_q_dim], element_type)], - inputs=[old_max, new_max], - init=correction_init, - ) - def correction(old_val: f16, new_val: f16, _: f16): - diff = arith.subf(old_val, new_val) - return math.exp(diff) - - # Update running sum with rescaling - # new_sum: [seq_q] = old_sum * correction + row_sum_exp - # new_sum_init = tensor.empty([seq_q_dim], element_type) - - @linalg.map( - result=[ir.RankedTensorType.get([seq_q_dim], element_type)], - inputs=[old_sum, correction, row_sum_exp], - init=old_sum, - ) - def new_sum(old_s: f16, corr: f16, new_s: f16, _: f16): - rescaled = arith.mulf(old_s, corr) - return arith.addf(rescaled, new_s) - - # Compute exp_scores @ V_tile - # exp_scores: [seq_q, tile_size] - # V_tile: [tile_size, head_dim] - # Result: [seq_q, head_dim] - exp_v_init = tensor.empty([seq_q_dim, head_dim], element_type) - exp_v_filled = linalg.fill(zero, outs=[exp_v_init]) - exp_v = linalg.matmul(exp_scores, v_tile, outs=[exp_v_filled]) - - # Update output with rescaling - # new_output: [seq_q, head_dim] = old_output * (correction * old_sum / new_sum) + (exp_v / new_sum) - # First compute rescale factor: correction / new_sum (broadcasted to [seq_q, 1]) - rescale_factor_div_init = tensor.empty([seq_q_dim], element_type) - rescale_factor_div = linalg.div( - correction, new_sum, outs=[rescale_factor_div_init] - ) - rescale_factor_mul_init = tensor.empty([seq_q_dim], element_type) - rescale_factor_mul = linalg.mul( - rescale_factor_div, old_sum, outs=[rescale_factor_mul_init] - ) - rescale_factor_2d_init = tensor.empty( - [seq_q_dim, tile_size_value], element_type - ) - rescale_factor_2d = linalg.broadcast( - rescale_factor_mul, - outs=[rescale_factor_2d_init], - dimensions=[0], - ) + # Get in_bounds attribute if it exists + in_bounds = k_load_op.attributes.get("in_bounds", None) - # Rescale old output - # rescaled_old_init = tensor.empty( - # [seq_q_dim, head_dim], element_type - # ) - rescaled_old = linalg.mul( - old_output, rescale_factor_2d, outs=[old_output] - ) + k_perm_map = k_load_op.attributes.get("permutation_map", None) - # Compute: exp_v / new_sum (broadcast new_sum to [seq_q, tile_size]) - norm_factor_2d_init = tensor.empty( - [seq_q_dim, tile_size_value], element_type + # Create vector.transfer_read for K tile + k_tile = vector.TransferReadOp( + k_tile_type, + k_memref, + k_load_indices, + k_perm_map, + padding, + in_bounds=in_bounds + ).result + # print(f"k_tile: {k_tile}") + + # Step 1: Transpose K tile from [tile_size, d_head] to [d_head, tile_size] + k_transpose_type = ir.VectorType.get([d_head, tile_size_value], element_type) + # vector.transpose with permutation [1, 0] swaps the two dimensions + k_transpose = vector.transpose(k_transpose_type, k_tile, [1, 0]) + # print(f"k_transpose: {k_transpose}") + + # Step 2: Compute Q * K_transpose using vector.contract + # Q shape: [wg_rows, d_head] + # K_transpose shape: [d_head, tile_size] + # Output shape: [wg_rows, tile_size] + # Contraction: Q[i, k] * K_transpose[k, j] -> QKT[i, j] + # indexing_maps: affine_map<(i, j, k) -> (i, k)>, affine_map<(i, j, k) -> (k, j)>, affine_map<(i, j, k) -> (i, j)> + # iterator_types: ["parallel", "parallel", "reduction"] + + q_value = q_load_op.results[0] + qkt_type = ir.VectorType.get([wg_rows, tile_size_value], element_type) + + # Create zero-initialized accumulator for the contraction + qkt_acc_values = np.zeros((wg_rows, tile_size_value), dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) + qkt_acc_attr = ir.DenseElementsAttr.get(qkt_acc_values, type=qkt_type) + qkt_acc = arith.constant(qkt_type, qkt_acc_attr) + + # Create affine maps for the contraction + affine_d0 = ir.AffineExpr.get_dim(0) + affine_d1 = ir.AffineExpr.get_dim(1) + affine_d2 = ir.AffineExpr.get_dim(2) + + # Map for Q: (i, j, k) -> (i, k) + q_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d2]) + # Map for K_transpose: (i, j, k) -> (k, j) + k_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) + # Map for output QKT: (i, j, k) -> (i, j) + out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) + + indexing_maps = ir.ArrayAttr.get([ + ir.AffineMapAttr.get(q_map), + ir.AffineMapAttr.get(k_map), + ir.AffineMapAttr.get(out_map) + ]) + + iterator_types = ir.ArrayAttr.get([ + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type") + ]) + + qkt = vector.contract( + qkt_type, + q_value, + k_transpose, + qkt_acc, + indexing_maps=indexing_maps, + iterator_types=iterator_types ) - norm_factor_2d = linalg.broadcast( - new_sum, outs=[norm_factor_2d_init], dimensions=[0] + # print(f"qkt: {qkt}") + + # Step 3: Max reduction over the inner dimension of QKT + # QKT shape: [wg_rows, tile_size] + # Result shape: [wg_rows] + # We need to compute max along dimension 1 (tile_size dimension) + + qkt_max = vector.multi_reduction( + kind="maxnumf", + source=qkt, + acc=m_i_init, + reduction_dims=[1] ) - # Normalize new contribution - normalized_exp_v_init = tensor.empty( - [seq_q_dim, head_dim], element_type - ) - normalized_exp_v = linalg.div( - exp_v, norm_factor_2d, outs=[normalized_exp_v_init] + # Step 4: Scale the max: qkt_max_scaled = qkt_max * scale + # Both have shape [wg_rows] + qkt_max_scaled = arith.mulf(qkt_max, scale_vector) + + # Step 5: Compute m_ij = max(m_i, qkt_max_scaled) + # Both have shape [wg_rows] + m_ij = arith.maximumf(m_i, qkt_max_scaled) + + # Step 6: Scale QKT matrix: qkt_scaled = qkt * scale_2d + # Need to broadcast scale from [wg_rows] to [wg_rows, tile_size] + scale_2d_type = ir.VectorType.get([wg_rows, tile_size_value], element_type) + scale_2d_values = np.full((wg_rows, tile_size_value), scale_value, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) + scale_2d_attr = ir.DenseElementsAttr.get(scale_2d_values, type=scale_2d_type) + scale_2d = arith.constant(scale_2d_type, scale_2d_attr) + qkt_scaled = arith.mulf(qkt, scale_2d) + + # Step 7: Broadcast m_ij from [wg_rows] to [wg_rows, tile_size] + m_ij_bcasted_type = ir.VectorType.get([wg_rows, tile_size_value], element_type) + m_ij_bcasted = vector.broadcast(m_ij_bcasted_type, m_ij) + + # Step 8: Center the scores: qkt_centered = qkt_scaled - m_ij_bcasted + qkt_centered = arith.subf(qkt_scaled, m_ij_bcasted) + + # Step 9: Compute exponential: qkt_exp = exp(qkt_centered) + qkt_exp = math.exp(qkt_centered) + + # Step 10: Sum reduction along inner dimension: l_ij = sum(qkt_exp, dim=1) + # Shape [wg_rows, tile_size] -> [wg_rows] + l_ij = vector.multi_reduction( + kind="add", + source=qkt_exp, + acc=l_i_init, + reduction_dims=[1] ) - # Add both contributions - # new_output_init = tensor.empty([seq_q_dim, head_dim], element_type) - new_output = linalg.add( - rescaled_old, normalized_exp_v, outs=[old_output] + # Step 11: Compute alpha = exp(m_i - m_ij) + m_diff = arith.subf(m_i, m_ij) + alpha = math.exp(m_diff) + + # Step 12: Update l_i: l_i_updated = l_i * alpha + l_ij + l_i_scaled = arith.mulf(l_i, alpha) + l_i_updated = arith.addf(l_i_scaled, l_ij) + + # Step 13: Broadcast alpha from [wg_rows] to [wg_rows, d_head] + alpha_bcasted_type = ir.VectorType.get([wg_rows, d_head], element_type) + alpha_bcasted = vector.broadcast(alpha_bcasted_type, alpha) + + # Step 14: Update accumulator: acc_updated = acc * alpha_bcasted + acc_updated = arith.mulf(acc, alpha_bcasted) + + # Step 15: Load the current V tile: shape [tile_size, d_head] + # Use the same memref and indices as v_load, but replace second-to-last index with loop_idx + v_memref = v_load_op.operands[0] + v_tile_type = ir.VectorType.get([tile_size_value, d_head], element_type) + + # Get the indices from original v_load (all operands except the first one which is the memref) + # and the last one which is the padding value + v_load_indices = list(v_load_op.operands[1:-1]) + + # Replace the second-to-last index with loop_idx + v_tile_indices = v_load_indices + v_tile_indices[-2] = loop_idx # Assuming the reduction dimension is the second-to-last index + + # Get the padding value (last operand of v_load) + v_padding = v_load_op.operands[-1] + + # Get in_bounds attribute if it exists + v_in_bounds = v_load_op.attributes.get("in_bounds", None) + + v_perm_map = v_load_op.attributes.get("permutation_map", None) + + # Create vector.transfer_read for V tile + v_tile = vector.TransferReadOp( + v_tile_type, + v_memref, + v_load_indices, + v_perm_map, + v_padding, + in_bounds=v_in_bounds + ).result + + # Step 16: Compute attention-weighted values: pv_out = qkt_exp @ v_tile + # qkt_exp shape: [wg_rows, tile_size] + # v_tile shape: [tile_size, d_head] + # Output shape: [wg_rows, d_head] + # Contraction: qkt_exp[i, k] * v_tile[k, j] -> pv_out[i, j] + + # Create affine maps for the contraction + affine_d0 = ir.AffineExpr.get_dim(0) + affine_d1 = ir.AffineExpr.get_dim(1) + affine_d2 = ir.AffineExpr.get_dim(2) + + # Map for qkt_exp: (i, j, k) -> (i, k) + qkt_exp_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d2]) + # Map for v_tile: (i, j, k) -> (k, j) + v_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) + # Map for output pv_out: (i, j, k) -> (i, j) + pv_out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) + + indexing_maps_pv = ir.ArrayAttr.get([ + ir.AffineMapAttr.get(qkt_exp_map), + ir.AffineMapAttr.get(v_map), + ir.AffineMapAttr.get(pv_out_map) + ]) + + iterator_types_pv = ir.ArrayAttr.get([ + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type") + ]) + + pv_out = vector.contract( + acc_vector_type, + qkt_exp, + v_tile, + acc_updated, + indexing_maps=indexing_maps_pv, + iterator_types=iterator_types_pv ) - scf.yield_([new_max, new_sum, new_output]) + # Yield the updated iter args + scf.yield_([m_ij, l_i_updated, pv_out]) - # Extract final result from loop - final_output_2d = loop_result.results[2] + # Extract the final accumulator result (3rd output) from the loop + final_output = loop.results[2] - # Expand the 2D output back to 3D to match the original output shape - # [seq_q, head_dim] -> [1, seq_q, head_dim] - final_output_3d = tensor.expand_shape( - output_type, - src=final_output_2d, - reassociation=[[0, 1], [2]], - output_shape=[], - static_output_shape=output_type.shape, - ) - # Create a dummy add operation to wrap the final output - # This is needed because replace_op requires an operation, not a value - zero_const = arith.constant(element_type, 0.0) - # Create a linalg.add that adds 0 (identity operation) - zero_tensor_shape = output_type.shape - zero_tensor_init = tensor.empty(zero_tensor_shape, element_type) - zero_tensor = linalg.fill(zero_const, outs=[zero_tensor_init]) - - # Create the add operation: final_output_3d + 0 - output_init_for_add = tensor.empty(zero_tensor_shape, element_type) - dummy_add = linalg.add( - final_output_3d, zero_tensor, outs=[output_init_for_add] - ) + # Replace all uses of the original output operation with the final loop result + output_op.results[0].replace_all_uses_with(final_output) - # Replace the original output operation with the dummy add - rewriter.replace_op(output_op, dummy_add.owner) + # Erase the original output operation + rewriter.erase_op(output_op) - results.set_ops(op.new_output, [dummy_add.owner]) + # Return the final output handle + results.set_ops(op.new_output, [final_output.owner]) return DiagnosedSilenceableFailure.Success @staticmethod @@ -420,22 +400,22 @@ def get_effects(op: ir.Operation, effects): def generate_fused_attention( - q_slice: ir.Value, - k_slice: ir.Value, - scale_slice: ir.Value, - v_slice: ir.Value, + q_load: ir.Value, + k_load: ir.Value, + v_load: ir.Value, + scale: ir.Value, output: ir.Value, tile_size: int | ir.IntegerAttr, ) -> ir.Value: - """Generate fused attention computation with inner tiling. + """Generate fused attention computation with inner tiling on bufferized IR. Args: - q_slice: Handle to Q slice operation - k_slice: Handle to K slice operation - scale_slice: Handle to scaling factor slice operation - v_slice: Handle to V slice operation - output: Handle to output operation to replace - tile_size: Size of tiles along the K dimension + q_load: Handle to Q load operation (vector.transfer_read) + k_load: Handle to K load operation (vector.transfer_read) + v_load: Handle to V load operation (vector.transfer_read) + scale: Handle to scale constant operation (arith.constant) + output: Handle to output operation to replace (vector.contract) + tile_size: Tile size for the reduction dimension tiling (K/V sequence length) Returns: Handle to the new output operation @@ -444,5 +424,5 @@ def generate_fused_attention( tile_size = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), tile_size) return GenerateFusedAttention( - q_slice, k_slice, scale_slice, v_slice, output, tile_size=tile_size + q_load, k_load, v_load, scale, output, tile_size=tile_size ).new_output diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index a8fbaec0..2533f0de 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -226,17 +226,17 @@ def bundle_xegpu_fused_attention_schedule( # Generate fused attention computation with inner tiling (flash attention) # This replaces the current unfused computation with a tiled loop that # maintains online max and sum for efficient memory usage - tile_size = 64 - new_output = generate_fused_attention( - q_slice=q_slice, - k_slice=k_slice, - scale_slice=scale_slice, - v_slice=v_slice, - output=last_matmul, - tile_size=tile_size, - ) - transform.apply_cse(func) - canonicalize(func) + # tile_size = 64 + # new_output = generate_fused_attention( + # q_slice=q_slice, + # k_slice=k_slice, + # scale_slice=scale_slice, + # v_slice=v_slice, + # output=last_matmul, + # tile_size=tile_size, + # ) + # transform.apply_cse(func) + # canonicalize(func) if stop_at_stage == "inner-tiled": raise PipelineInterrupt() @@ -283,6 +283,55 @@ def bundle_xegpu_fused_attention_schedule( }, ) + # Extract q, k, v memrefs from the bufferized IR + # Match vector.contract ops to find the q, k, v loads + for_all = match(mod, ops={"scf.forall"}) + func = transform.get_parent_op(anytype, for_all, op_name="func.func") + contract_ops = match_and_split(func, ops={"vector.contract"}, nhandles=2) + + # First vector.contract is Q @ K^T + # Its first operand is the q load (vector.transfer_read) + # Its second operand is the k load (vector.transfer_read) + first_contract = contract_ops[0] + q_load = transform.get_producer_of_operand(anytype, first_contract, operand_number=0) + k_load = transform.get_producer_of_operand(anytype, first_contract, operand_number=1) + + # # Second vector.contract is attention_weights @ V + # # Its second operand is the v load (vector.transfer_read) + second_contract = contract_ops[1] + v_load = transform.get_producer_of_operand(anytype, second_contract, operand_number=1) + + # Extract memrefs from the loads (first operand of vector.transfer_read) + # q_memref = transform.get_operand(anytype, q_load, [0]) + # k_memref = transform.get_operand(anytype, k_load, [0]) + # v_memref = transform.get_operand(anytype, v_load, [0]) + + # Match arith.mulf to get the scale parameter + # The scale is the second operand of arith.mulf (the constant) + mulf_op = match_and_split(func, ops={"arith.mulf"}, nhandles=1)[0] + scale = transform.get_producer_of_operand(anytype, mulf_op, operand_number=1) + + # Debug prints to verify we got the right memrefs and scale + # transform.print_(target=q_load, name="q_memref") + # transform.print_(target=k_load, name="k_memref") + # transform.print_(target=v_load, name="v_memref") + # transform.print_(target=scale, name="scale") + + # Generate fused attention computation with inner tiling + # This replaces the second vector.contract (attention_weights @ V) with a tiled + # loop that implements online softmax for efficient memory usage + tile_size = 64 # Tile size for reduction dimension (K/V sequence length) + new_output = generate_fused_attention( + q_load=q_load, + k_load=k_load, + v_load=v_load, + scale=scale, + output=second_contract, + tile_size=tile_size, + ) + # transform.apply_cse(func) + # canonicalize(func) + if stop_at_stage == "bufferized": raise PipelineInterrupt() @@ -336,10 +385,7 @@ def bundle_xegpu_fused_attention_schedule( update_address_space(alloca, address_space=3) gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") transform.apply_cse(gpu_func) - - # Cleanup. - transform.apply_cse(mod) - canonicalize(mod) + gpu_func = apply_registered_pass(gpu_func, "loop-invariant-code-motion") if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() @@ -350,13 +396,6 @@ def bundle_xegpu_fused_attention_schedule( out_sg_data = [sg_rows, parameters["n_head"]] xegpu.set_anchor_layout(store_nd_op, sg_layout=out_sg_layout, sg_data=out_sg_data) - # Set layout attributes for xegpu.store_matrix ops. same as store_nd ops. - store_matrix_ops = match_and_split(gpu_func, ops={"xegpu.store_matrix"}, nhandles=2) - for store_matrix_op in store_matrix_ops: - xegpu.set_anchor_layout( - store_matrix_op, sg_layout=out_sg_layout, sg_data=out_sg_data - ) - # Set layout for xegpu.dpas ops dpas_ops = match_and_split(gpu_func, ops={"xegpu.dpas"}, nhandles=2) # layouts for the first dpas op (Q*K^T): From 14cbb3b4628605f72208d2ce714e63e82fd5d6a5 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 14 May 2026 23:57:20 +0000 Subject: [PATCH 56/63] save valid imex version --- .../ops/generate_fused_attention.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index e6f9e6f1..ba8644cb 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -264,9 +264,11 @@ def apply( # Step 7: Broadcast m_ij from [wg_rows] to [wg_rows, tile_size] m_ij_bcasted_type = ir.VectorType.get([wg_rows, tile_size_value], element_type) m_ij_bcasted = vector.broadcast(m_ij_bcasted_type, m_ij) + m_ij_transposed_type = ir.VectorType.get([tile_size_value, wg_rows], element_type) + m_ij_transposed = vector.transpose(m_ij_transposed_type, m_ij_bcasted, [1, 0]) - # Step 8: Center the scores: qkt_centered = qkt_scaled - m_ij_bcasted - qkt_centered = arith.subf(qkt_scaled, m_ij_bcasted) + # Step 8: Center the scores: qkt_centered = qkt_scaled - m_ij_transposed + qkt_centered = arith.subf(qkt_scaled, m_ij_transposed) # Step 9: Compute exponential: qkt_exp = exp(qkt_centered) qkt_exp = math.exp(qkt_centered) @@ -291,9 +293,11 @@ def apply( # Step 13: Broadcast alpha from [wg_rows] to [wg_rows, d_head] alpha_bcasted_type = ir.VectorType.get([wg_rows, d_head], element_type) alpha_bcasted = vector.broadcast(alpha_bcasted_type, alpha) + alpha_transposed_type = ir.VectorType.get([d_head, wg_rows], element_type) + alpha_transposed = vector.transpose(alpha_transposed_type, alpha_bcasted, [1, 0]) # Step 14: Update accumulator: acc_updated = acc * alpha_bcasted - acc_updated = arith.mulf(acc, alpha_bcasted) + acc_updated = arith.mulf(acc, alpha_transposed) # Step 15: Load the current V tile: shape [tile_size, d_head] # Use the same memref and indices as v_load, but replace second-to-last index with loop_idx @@ -369,17 +373,26 @@ def apply( scf.yield_([m_ij, l_i_updated, pv_out]) # Extract the final accumulator result (3rd output) from the loop - final_output = loop.results[2] + pv_out = loop.results[2] + l_i_out = loop.results[1] + with ir.InsertionPoint.after(loop): + # Step 17: Normalize the output: output_final = pv_out / l_i_out + # Need to broadcast l_i_out from [wg_rows] to [wg_rows, d_head] + l_i_out_bcasted_type = ir.VectorType.get([wg_rows, d_head], element_type) + l_i_out_bcasted = vector.broadcast(l_i_out_bcasted_type, l_i_out) + l_i_out_transposed_type = ir.VectorType.get([d_head, wg_rows], element_type) + l_i_out_transposed = vector.transpose(l_i_out_transposed_type, l_i_out_bcasted, [1, 0]) + output_final = arith.divf(pv_out, l_i_out_transposed) # Replace all uses of the original output operation with the final loop result - output_op.results[0].replace_all_uses_with(final_output) + output_op.results[0].replace_all_uses_with(output_final) # Erase the original output operation rewriter.erase_op(output_op) # Return the final output handle - results.set_ops(op.new_output, [final_output.owner]) + results.set_ops(op.new_output, [output_final.owner]) return DiagnosedSilenceableFailure.Success @staticmethod From e4c347c8981424106e22590019c5d6703c6e25f1 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 15 May 2026 17:23:15 +0000 Subject: [PATCH 57/63] match dims of imex --- examples/xegpu/fused_attention.py | 6 +++--- .../transform_ext/ops/generate_fused_attention.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py index 101e3ff9..c375470c 100644 --- a/examples/xegpu/fused_attention.py +++ b/examples/xegpu/fused_attention.py @@ -212,7 +212,7 @@ def parse_cli(): parser.add_argument( "--n-ctx", type=int, - default=1024, + default=4096, help="Context length (sequence length)", ) parser.add_argument( @@ -224,13 +224,13 @@ def parse_cli(): parser.add_argument( "--wg-rows", type=int, - default=64, + default=128, help="Number of Q*K^T*V rows computed by each work group", ) parser.add_argument( "--sg-rows", type=int, - default=8, + default=16, help="Number of Q*K^T*V rows computed by each subgroup", ) parser.add_argument( diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index ba8644cb..facde5ec 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -262,9 +262,9 @@ def apply( qkt_scaled = arith.mulf(qkt, scale_2d) # Step 7: Broadcast m_ij from [wg_rows] to [wg_rows, tile_size] - m_ij_bcasted_type = ir.VectorType.get([wg_rows, tile_size_value], element_type) + m_ij_bcasted_type = ir.VectorType.get([tile_size_value, wg_rows], element_type) m_ij_bcasted = vector.broadcast(m_ij_bcasted_type, m_ij) - m_ij_transposed_type = ir.VectorType.get([tile_size_value, wg_rows], element_type) + m_ij_transposed_type = ir.VectorType.get([wg_rows, tile_size_value], element_type) m_ij_transposed = vector.transpose(m_ij_transposed_type, m_ij_bcasted, [1, 0]) # Step 8: Center the scores: qkt_centered = qkt_scaled - m_ij_transposed @@ -291,9 +291,9 @@ def apply( l_i_updated = arith.addf(l_i_scaled, l_ij) # Step 13: Broadcast alpha from [wg_rows] to [wg_rows, d_head] - alpha_bcasted_type = ir.VectorType.get([wg_rows, d_head], element_type) + alpha_bcasted_type = ir.VectorType.get([d_head, wg_rows], element_type) alpha_bcasted = vector.broadcast(alpha_bcasted_type, alpha) - alpha_transposed_type = ir.VectorType.get([d_head, wg_rows], element_type) + alpha_transposed_type = ir.VectorType.get([wg_rows, d_head], element_type) alpha_transposed = vector.transpose(alpha_transposed_type, alpha_bcasted, [1, 0]) # Step 14: Update accumulator: acc_updated = acc * alpha_bcasted @@ -378,9 +378,9 @@ def apply( with ir.InsertionPoint.after(loop): # Step 17: Normalize the output: output_final = pv_out / l_i_out # Need to broadcast l_i_out from [wg_rows] to [wg_rows, d_head] - l_i_out_bcasted_type = ir.VectorType.get([wg_rows, d_head], element_type) + l_i_out_bcasted_type = ir.VectorType.get([d_head, wg_rows], element_type) l_i_out_bcasted = vector.broadcast(l_i_out_bcasted_type, l_i_out) - l_i_out_transposed_type = ir.VectorType.get([d_head, wg_rows], element_type) + l_i_out_transposed_type = ir.VectorType.get([wg_rows, d_head], element_type) l_i_out_transposed = vector.transpose(l_i_out_transposed_type, l_i_out_bcasted, [1, 0]) output_final = arith.divf(pv_out, l_i_out_transposed) From 2c87d4a70e78246bf86cba5a55be863ec045dcb3 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 15 May 2026 20:12:22 +0000 Subject: [PATCH 58/63] save work --- .../xegpu/fused_attention_schedule.py | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 2533f0de..539952e1 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -329,8 +329,8 @@ def bundle_xegpu_fused_attention_schedule( output=second_contract, tile_size=tile_size, ) - # transform.apply_cse(func) - # canonicalize(func) + transform.apply_cse(func) + canonicalize(func) if stop_at_stage == "bufferized": raise PipelineInterrupt() @@ -390,42 +390,42 @@ def bundle_xegpu_fused_attention_schedule( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - # Set layout attributes for xegpu.store_nd ops. - store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] - out_sg_layout = [num_subgroups, 1] - out_sg_data = [sg_rows, parameters["n_head"]] - xegpu.set_anchor_layout(store_nd_op, sg_layout=out_sg_layout, sg_data=out_sg_data) - - # Set layout for xegpu.dpas ops - dpas_ops = match_and_split(gpu_func, ops={"xegpu.dpas"}, nhandles=2) - # layouts for the first dpas op (Q*K^T): - first_dpas_op = dpas_ops[0] - qk_a_sg_layout = out_sg_layout - qk_a_sg_data = out_sg_data - qk_b_sg_layout = [1, num_subgroups] - qk_b_sg_data = [parameters["n_head"], num_subgroups] - qk_cd_sg_layout = out_sg_layout - qk_cd_sg_data = [sg_rows, 16] - xegpu.set_anchor_layout( - first_dpas_op, sg_layout=qk_a_sg_layout, sg_data=qk_a_sg_data, index=0 - ) - xegpu.set_anchor_layout( - first_dpas_op, sg_layout=qk_b_sg_layout, sg_data=qk_b_sg_data, index=1 - ) - xegpu.set_anchor_layout( - first_dpas_op, sg_layout=qk_cd_sg_layout, sg_data=qk_cd_sg_data, index=2 - ) - # layouts for the second dpas op (attention_weights*V): - second_dpas_op = dpas_ops[1] - xegpu.set_anchor_layout( - second_dpas_op, sg_layout=qk_cd_sg_layout, sg_data=qk_cd_sg_data, index=0 - ) - xegpu.set_anchor_layout( - second_dpas_op, sg_layout=qk_b_sg_layout, sg_data=qk_b_sg_data, index=1 - ) - xegpu.set_anchor_layout( - second_dpas_op, sg_layout=out_sg_layout, sg_data=out_sg_data, index=2 - ) + # # Set layout attributes for xegpu.store_nd ops. + # store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] + # out_sg_layout = [num_subgroups, 1] + # out_sg_data = [sg_rows, parameters["n_head"]] + # xegpu.set_anchor_layout(store_nd_op, sg_layout=out_sg_layout, sg_data=out_sg_data) + + # # Set layout for xegpu.dpas ops + # dpas_ops = match_and_split(gpu_func, ops={"xegpu.dpas"}, nhandles=2) + # # layouts for the first dpas op (Q*K^T): + # first_dpas_op = dpas_ops[0] + # qk_a_sg_layout = out_sg_layout + # qk_a_sg_data = out_sg_data + # qk_b_sg_layout = [1, num_subgroups] + # qk_b_sg_data = [parameters["n_head"], num_subgroups] + # qk_cd_sg_layout = out_sg_layout + # qk_cd_sg_data = [sg_rows, 16] + # xegpu.set_anchor_layout( + # first_dpas_op, sg_layout=qk_a_sg_layout, sg_data=qk_a_sg_data, index=0 + # ) + # xegpu.set_anchor_layout( + # first_dpas_op, sg_layout=qk_b_sg_layout, sg_data=[64, 64], order=[0, 1], index=1 + # ) + # xegpu.set_anchor_layout( + # first_dpas_op, sg_layout=qk_cd_sg_layout, sg_data=qk_cd_sg_data, index=2 + # ) + # # layouts for the second dpas op (attention_weights*V): + # second_dpas_op = dpas_ops[1] + # xegpu.set_anchor_layout( + # second_dpas_op, sg_layout=qk_cd_sg_layout, sg_data=qk_cd_sg_data, index=0 + # ) + # xegpu.set_anchor_layout( + # second_dpas_op, sg_layout=[8, 1], sg_data=[64, 64], index=1 + # ) + # xegpu.set_anchor_layout( + # second_dpas_op, sg_layout=out_sg_layout, sg_data=out_sg_data, index=2 + # ) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From c0ec544f2880f058596171669cb017fb6db29338 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 15 May 2026 20:38:13 +0000 Subject: [PATCH 59/63] save unolled version --- .../ops/generate_fused_attention.py | 252 +++++++++--------- 1 file changed, 123 insertions(+), 129 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index facde5ec..656f3413 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -145,70 +145,35 @@ def apply( l_i = loop.inner_iter_args[1] acc = loop.inner_iter_args[2] - # Load the current K tile: shape [tile_size, d_head] - # Use the same memref and indices as k_load, but replace second-to-last index with loop_idx + # Get common values for K/V tiling k_memref = k_load_op.operands[0] - k_tile_type = ir.VectorType.get([tile_size_value, d_head], element_type) - - # Get the indices from original k_load (all operands except the first one which is the memref) - # and the last one which is the padding value k_load_indices = list(k_load_op.operands[1:-1]) - - # Replace the second-to-last index with loop_idx - k_tile_indices = k_load_indices - k_tile_indices[-2] = loop_idx # Assuming the reduction dimension is the last index before padding - - # Get the padding value (last operand of k_load) padding = k_load_op.operands[-1] - - # Get in_bounds attribute if it exists in_bounds = k_load_op.attributes.get("in_bounds", None) - k_perm_map = k_load_op.attributes.get("permutation_map", None) + q_value = q_load_op.results[0] - # Create vector.transfer_read for K tile - k_tile = vector.TransferReadOp( - k_tile_type, - k_memref, - k_load_indices, - k_perm_map, - padding, - in_bounds=in_bounds - ).result - # print(f"k_tile: {k_tile}") - - # Step 1: Transpose K tile from [tile_size, d_head] to [d_head, tile_size] - k_transpose_type = ir.VectorType.get([d_head, tile_size_value], element_type) - # vector.transpose with permutation [1, 0] swaps the two dimensions - k_transpose = vector.transpose(k_transpose_type, k_tile, [1, 0]) - # print(f"k_transpose: {k_transpose}") - - # Step 2: Compute Q * K_transpose using vector.contract - # Q shape: [wg_rows, d_head] - # K_transpose shape: [d_head, tile_size] - # Output shape: [wg_rows, tile_size] - # Contraction: Q[i, k] * K_transpose[k, j] -> QKT[i, j] - # indexing_maps: affine_map<(i, j, k) -> (i, k)>, affine_map<(i, j, k) -> (k, j)>, affine_map<(i, j, k) -> (i, j)> - # iterator_types: ["parallel", "parallel", "reduction"] + # Constants for K/V tiling (tile into chunks of 16) + k_subtile_size = 16 + num_k_tiles = tile_size_value // k_subtile_size - q_value = q_load_op.results[0] - qkt_type = ir.VectorType.get([wg_rows, tile_size_value], element_type) + # Create offset constants for each K tile + k_tile_offsets = [] + for i in range(num_k_tiles): + offset = arith.constant(index_type, i * k_subtile_size) + k_tile_offsets.append(offset) - # Create zero-initialized accumulator for the contraction - qkt_acc_values = np.zeros((wg_rows, tile_size_value), dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) - qkt_acc_attr = ir.DenseElementsAttr.get(qkt_acc_values, type=qkt_type) - qkt_acc = arith.constant(qkt_type, qkt_acc_attr) + # Step 1: Load and process K tiles (unrolled) + # Each K tile is [16, d_head], transposed to [d_head, 16], contracted to [wg_rows, 16] + qkt_chunks = [] - # Create affine maps for the contraction + # Create affine maps for Q@K contraction (used for all tiles) affine_d0 = ir.AffineExpr.get_dim(0) affine_d1 = ir.AffineExpr.get_dim(1) affine_d2 = ir.AffineExpr.get_dim(2) - # Map for Q: (i, j, k) -> (i, k) q_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d2]) - # Map for K_transpose: (i, j, k) -> (k, j) k_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) - # Map for output QKT: (i, j, k) -> (i, j) out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) indexing_maps = ir.ArrayAttr.get([ @@ -223,24 +188,56 @@ def apply( ir.Attribute.parse("#vector.iterator_type") ]) - qkt = vector.contract( - qkt_type, - q_value, - k_transpose, - qkt_acc, - indexing_maps=indexing_maps, - iterator_types=iterator_types - ) - # print(f"qkt: {qkt}") - - # Step 3: Max reduction over the inner dimension of QKT - # QKT shape: [wg_rows, tile_size] - # Result shape: [wg_rows] - # We need to compute max along dimension 1 (tile_size dimension) - + # Accumulator for Q@K chunks + qkt_chunk_type = ir.VectorType.get([wg_rows, k_subtile_size], element_type) + qkt_chunk_acc_values = np.zeros((wg_rows, k_subtile_size), dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) + qkt_chunk_acc_attr = ir.DenseElementsAttr.get(qkt_chunk_acc_values, type=qkt_chunk_type) + qkt_chunk_acc = arith.constant(qkt_chunk_type, qkt_chunk_acc_attr) + + for tile_idx in range(num_k_tiles): + # Compute the offset index for this tile + k_tile_offset = arith.addi(loop_idx, k_tile_offsets[tile_idx]) + + # Update indices for this K tile + k_tile_indices = k_load_indices.copy() + k_tile_indices[-2] = k_tile_offset + + # Load K tile: [16, d_head] + k_tile_type = ir.VectorType.get([k_subtile_size, d_head], element_type) + k_tile = vector.TransferReadOp( + k_tile_type, + k_memref, + k_tile_indices, + k_perm_map, + padding, + in_bounds=in_bounds + ).result + + # Transpose K tile: [16, d_head] -> [d_head, 16] + k_transpose_type = ir.VectorType.get([d_head, k_subtile_size], element_type) + k_transpose = vector.transpose(k_transpose_type, k_tile, [1, 0]) + + # Contract Q @ K_transpose: [wg_rows, d_head] @ [d_head, 16] -> [wg_rows, 16] + qkt_chunk = vector.contract( + qkt_chunk_type, + q_value, + k_transpose, + qkt_chunk_acc, + indexing_maps=indexing_maps, + iterator_types=iterator_types + ) + qkt_chunks.append(qkt_chunk) + + # Step 2: Elementwise maximum across all Q@K chunks + # Build tree of maximumf operations + qkt_max_combined = qkt_chunks[0] + for i in range(1, num_k_tiles): + qkt_max_combined = arith.maximumf(qkt_max_combined, qkt_chunks[i]) + + # Step 3: Final multi_reduction to get row-wise max: [wg_rows, 16] -> [wg_rows] qkt_max = vector.multi_reduction( kind="maxnumf", - source=qkt, + source=qkt_max_combined, acc=m_i_init, reduction_dims=[1] ) @@ -253,31 +250,41 @@ def apply( # Both have shape [wg_rows] m_ij = arith.maximumf(m_i, qkt_max_scaled) - # Step 6: Scale QKT matrix: qkt_scaled = qkt * scale_2d - # Need to broadcast scale from [wg_rows] to [wg_rows, tile_size] - scale_2d_type = ir.VectorType.get([wg_rows, tile_size_value], element_type) - scale_2d_values = np.full((wg_rows, tile_size_value), scale_value, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) - scale_2d_attr = ir.DenseElementsAttr.get(scale_2d_values, type=scale_2d_type) - scale_2d = arith.constant(scale_2d_type, scale_2d_attr) - qkt_scaled = arith.mulf(qkt, scale_2d) + # Step 6-9: Apply softmax to each Q@K chunk + # Scale constant for chunks: [wg_rows, 16] + scale_chunk_type = ir.VectorType.get([wg_rows, k_subtile_size], element_type) + scale_chunk_values = np.full((wg_rows, k_subtile_size), scale_value, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) + scale_chunk_attr = ir.DenseElementsAttr.get(scale_chunk_values, type=scale_chunk_type) + scale_chunk = arith.constant(scale_chunk_type, scale_chunk_attr) - # Step 7: Broadcast m_ij from [wg_rows] to [wg_rows, tile_size] - m_ij_bcasted_type = ir.VectorType.get([tile_size_value, wg_rows], element_type) + # Broadcast m_ij from [wg_rows] to [wg_rows, 16] + m_ij_bcasted_type = ir.VectorType.get([k_subtile_size, wg_rows], element_type) m_ij_bcasted = vector.broadcast(m_ij_bcasted_type, m_ij) - m_ij_transposed_type = ir.VectorType.get([wg_rows, tile_size_value], element_type) + m_ij_transposed_type = ir.VectorType.get([wg_rows, k_subtile_size], element_type) m_ij_transposed = vector.transpose(m_ij_transposed_type, m_ij_bcasted, [1, 0]) - # Step 8: Center the scores: qkt_centered = qkt_scaled - m_ij_transposed - qkt_centered = arith.subf(qkt_scaled, m_ij_transposed) + # Apply softmax to each chunk + qkt_exp_chunks = [] + for qkt_chunk in qkt_chunks: + # Scale: qkt_scaled = qkt_chunk * scale + qkt_scaled = arith.mulf(qkt_chunk, scale_chunk) + + # Center: qkt_centered = qkt_scaled - m_ij_transposed + qkt_centered = arith.subf(qkt_scaled, m_ij_transposed) - # Step 9: Compute exponential: qkt_exp = exp(qkt_centered) - qkt_exp = math.exp(qkt_centered) + # Exponential: qkt_exp = exp(qkt_centered) + qkt_exp = math.exp(qkt_centered) + qkt_exp_chunks.append(qkt_exp) - # Step 10: Sum reduction along inner dimension: l_ij = sum(qkt_exp, dim=1) - # Shape [wg_rows, tile_size] -> [wg_rows] + # Step 10: Elementwise sum across all exp chunks + qkt_exp_combined = qkt_exp_chunks[0] + for i in range(1, num_k_tiles): + qkt_exp_combined = arith.addf(qkt_exp_combined, qkt_exp_chunks[i]) + + # Final multi_reduction to get row-wise sum: [wg_rows, 16] -> [wg_rows] l_ij = vector.multi_reduction( kind="add", - source=qkt_exp, + source=qkt_exp_combined, acc=l_i_init, reduction_dims=[1] ) @@ -299,53 +306,17 @@ def apply( # Step 14: Update accumulator: acc_updated = acc * alpha_bcasted acc_updated = arith.mulf(acc, alpha_transposed) - # Step 15: Load the current V tile: shape [tile_size, d_head] - # Use the same memref and indices as v_load, but replace second-to-last index with loop_idx + # Step 15-16: Load V tiles and compute attention-weighted values + # Get V load parameters v_memref = v_load_op.operands[0] - v_tile_type = ir.VectorType.get([tile_size_value, d_head], element_type) - - # Get the indices from original v_load (all operands except the first one which is the memref) - # and the last one which is the padding value v_load_indices = list(v_load_op.operands[1:-1]) - - # Replace the second-to-last index with loop_idx - v_tile_indices = v_load_indices - v_tile_indices[-2] = loop_idx # Assuming the reduction dimension is the second-to-last index - - # Get the padding value (last operand of v_load) v_padding = v_load_op.operands[-1] - - # Get in_bounds attribute if it exists v_in_bounds = v_load_op.attributes.get("in_bounds", None) - v_perm_map = v_load_op.attributes.get("permutation_map", None) - # Create vector.transfer_read for V tile - v_tile = vector.TransferReadOp( - v_tile_type, - v_memref, - v_load_indices, - v_perm_map, - v_padding, - in_bounds=v_in_bounds - ).result - - # Step 16: Compute attention-weighted values: pv_out = qkt_exp @ v_tile - # qkt_exp shape: [wg_rows, tile_size] - # v_tile shape: [tile_size, d_head] - # Output shape: [wg_rows, d_head] - # Contraction: qkt_exp[i, k] * v_tile[k, j] -> pv_out[i, j] - - # Create affine maps for the contraction - affine_d0 = ir.AffineExpr.get_dim(0) - affine_d1 = ir.AffineExpr.get_dim(1) - affine_d2 = ir.AffineExpr.get_dim(2) - - # Map for qkt_exp: (i, j, k) -> (i, k) + # Create affine maps for P@V contraction qkt_exp_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d2]) - # Map for v_tile: (i, j, k) -> (k, j) v_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) - # Map for output pv_out: (i, j, k) -> (i, j) pv_out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) indexing_maps_pv = ir.ArrayAttr.get([ @@ -360,14 +331,37 @@ def apply( ir.Attribute.parse("#vector.iterator_type") ]) - pv_out = vector.contract( - acc_vector_type, - qkt_exp, - v_tile, - acc_updated, - indexing_maps=indexing_maps_pv, - iterator_types=iterator_types_pv - ) + # Load and process V tiles (unrolled), accumulating results + pv_out = acc_updated + for tile_idx in range(num_k_tiles): + # Compute the offset index for this V tile + v_tile_offset = arith.addi(loop_idx, k_tile_offsets[tile_idx]) + + # Update indices for this V tile + v_tile_indices = v_load_indices.copy() + v_tile_indices[-2] = v_tile_offset + + # Load V tile: [16, d_head] + v_tile_type = ir.VectorType.get([k_subtile_size, d_head], element_type) + v_tile = vector.TransferReadOp( + v_tile_type, + v_memref, + v_tile_indices, + v_perm_map, + v_padding, + in_bounds=v_in_bounds + ).result + + # Contract qkt_exp_chunk @ v_tile: [wg_rows, 16] @ [16, d_head] -> [wg_rows, d_head] + # Accumulate into pv_out + pv_out = vector.contract( + acc_vector_type, + qkt_exp_chunks[tile_idx], + v_tile, + pv_out, + indexing_maps=indexing_maps_pv, + iterator_types=iterator_types_pv + ) # Yield the updated iter args scf.yield_([m_ij, l_i_updated, pv_out]) From 60e3b2e5f9721925e4bc07cd7f62700c2856650a Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 15 May 2026 21:16:35 +0000 Subject: [PATCH 60/63] compile to binary now --- .../xegpu/fused_attention_schedule.py | 159 ++++++++++++++---- 1 file changed, 123 insertions(+), 36 deletions(-) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 539952e1..5d78e748 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -390,42 +390,129 @@ def bundle_xegpu_fused_attention_schedule( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - # # Set layout attributes for xegpu.store_nd ops. - # store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] - # out_sg_layout = [num_subgroups, 1] - # out_sg_data = [sg_rows, parameters["n_head"]] - # xegpu.set_anchor_layout(store_nd_op, sg_layout=out_sg_layout, sg_data=out_sg_data) - - # # Set layout for xegpu.dpas ops - # dpas_ops = match_and_split(gpu_func, ops={"xegpu.dpas"}, nhandles=2) - # # layouts for the first dpas op (Q*K^T): - # first_dpas_op = dpas_ops[0] - # qk_a_sg_layout = out_sg_layout - # qk_a_sg_data = out_sg_data - # qk_b_sg_layout = [1, num_subgroups] - # qk_b_sg_data = [parameters["n_head"], num_subgroups] - # qk_cd_sg_layout = out_sg_layout - # qk_cd_sg_data = [sg_rows, 16] - # xegpu.set_anchor_layout( - # first_dpas_op, sg_layout=qk_a_sg_layout, sg_data=qk_a_sg_data, index=0 - # ) - # xegpu.set_anchor_layout( - # first_dpas_op, sg_layout=qk_b_sg_layout, sg_data=[64, 64], order=[0, 1], index=1 - # ) - # xegpu.set_anchor_layout( - # first_dpas_op, sg_layout=qk_cd_sg_layout, sg_data=qk_cd_sg_data, index=2 - # ) - # # layouts for the second dpas op (attention_weights*V): - # second_dpas_op = dpas_ops[1] - # xegpu.set_anchor_layout( - # second_dpas_op, sg_layout=qk_cd_sg_layout, sg_data=qk_cd_sg_data, index=0 - # ) - # xegpu.set_anchor_layout( - # second_dpas_op, sg_layout=[8, 1], sg_data=[64, 64], index=1 - # ) - # xegpu.set_anchor_layout( - # second_dpas_op, sg_layout=out_sg_layout, sg_data=out_sg_data, index=2 - # ) + # Define XeGPU layout parameters + q_sg_layout = [8, 1] + q_sg_data = [16, 64] + q_inst_data = [8, 16] + + k_sg_layout = [8, 1] + k_sg_data = [16, 64] + k_inst_data = [16, 16] + + v_sg_layout = k_sg_layout + v_sg_data = k_sg_data + v_inst_data = k_inst_data + + kt_sg_layout = [1, 8] + kt_sg_data = [64, 16] + kt_inst_data = [16, 16] + kt_order = [0, 1] + + out_sg_layout = q_sg_layout + out_sg_data = q_sg_data + out_inst_data = q_inst_data + + + layout_128x16_sg_layout = [8, 1] + layout_128x16_sg_data = [16, 16] + layout_128x16_inst_data = [8, 16] + + qk_sg_layout = layout_128x16_sg_layout + qk_sg_data = layout_128x16_sg_data + qk_inst_data = layout_128x16_inst_data + + # Set layout attributes for xegpu.store_nd ops. + store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] + xegpu.set_anchor_layout(store_nd_op, sg_layout=out_sg_layout, sg_data=out_sg_data, inst_data=out_inst_data) + + # Set layout for xegpu.load_nd ops (9 total: 1 Q, 4 K, 4 V) + load_nd_ops = match_and_split(gpu_func, ops={"xegpu.load_nd"}, nhandles=9) + + # First load_nd: Q layout + xegpu.set_anchor_layout( + load_nd_ops[0], + sg_layout=q_sg_layout, + sg_data=q_sg_data, + inst_data=q_inst_data + ) + + # Next 4 load_nd ops: K layout + for i in range(1, 5): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=k_sg_layout, + sg_data=k_sg_data, + inst_data=k_inst_data + ) + + # Last 4 load_nd ops: V layout + for i in range(5, 9): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data + ) + + # Set layout for xegpu.dpas ops (8 total: 4 for Q@K, 4 for P@V) + dpas_ops = match_and_split(gpu_func, ops={"xegpu.dpas"}, nhandles=8) + + # Layouts for first 4 dpas ops (Q@K^T): + for i in range(4): + qk_dpas_op = dpas_ops[i] + # Index 0: Q layout + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=q_sg_layout, + sg_data=q_sg_data, + inst_data=q_inst_data, + index=0 + ) + # Index 1: K^T layout + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=kt_sg_layout, + sg_data=kt_sg_data, + inst_data=kt_inst_data, + order=kt_order, + index=1 + ) + # Index 2: QK output layout (128x16) + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=layout_128x16_sg_layout, + sg_data=layout_128x16_sg_data, + inst_data=layout_128x16_inst_data, + index=2 + ) + + # Layouts for second 4 dpas ops (P@V): + for i in range(4, 8): + pv_dpas_op = dpas_ops[i] + # Index 0: QK (attention weights) layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=qk_sg_layout, + sg_data=qk_sg_data, + inst_data=qk_inst_data, + index=0 + ) + # Index 1: V layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data, + index=1 + ) + # Index 2: Output layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + index=2 + ) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() From f07eec34286cb23eb9b2a42d6c0c1258d8a9b375 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 15 May 2026 21:29:44 +0000 Subject: [PATCH 61/63] cleanup --- .../xegpu/fused_attention_schedule.py | 101 +++++------------- 1 file changed, 28 insertions(+), 73 deletions(-) diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 5d78e748..5a64abf5 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -197,50 +197,6 @@ def bundle_xegpu_fused_attention_schedule( if stop_at_stage == "outer-tiled": raise PipelineInterrupt() - # Match Q, K, V slices inside the forall loop - # K slice is the first operand of the transpose op - transpose_op = match_and_split(forall_loop, ops={"linalg.transpose"}, nhandles=1)[0] - k_slice = transform.get_producer_of_operand(anytype, transpose_op, operand_number=0) - # Q slice is the first operand of the first batch matmul - batch_matmuls = match_and_split( - forall_loop, ops={"linalg.batch_matmul"}, nhandles=2 - ) - q_slice = transform.get_producer_of_operand( - anytype, batch_matmuls[0], operand_number=0 - ) - # V slice is the second operand of the last batch matmul (inside the forall loop) - # Need to match the tiled version of the last matmul inside the loop - last_matmul = batch_matmuls[1] - v_slice = transform.get_producer_of_operand(anytype, last_matmul, operand_number=1) - - # Match the scaling operation (linalg.mul) to get the scaling factor - # The QK output is scaled before softmax: QK * scale - mul_op = match_and_split(forall_loop, ops={"linalg.mul"}, nhandles=1)[0] - scale_slice = transform.get_producer_of_operand(anytype, mul_op, operand_number=1) - # transform.print_(target=k_slice, name="k_slice") - # transform.print_(target=q_slice, name="q_slice") - # transform.print_(target=v_slice, name="v_slice") - # transform.print_(target=scale_slice, name="scale_slice") - # transform.print_(target=last_matmul, name="tiled_attention_weights_v_matmul") - - # Generate fused attention computation with inner tiling (flash attention) - # This replaces the current unfused computation with a tiled loop that - # maintains online max and sum for efficient memory usage - # tile_size = 64 - # new_output = generate_fused_attention( - # q_slice=q_slice, - # k_slice=k_slice, - # scale_slice=scale_slice, - # v_slice=v_slice, - # output=last_matmul, - # tile_size=tile_size, - # ) - # transform.apply_cse(func) - # canonicalize(func) - - if stop_at_stage == "inner-tiled": - raise PipelineInterrupt() - # vectorize func = structured.VectorizeChildrenAndApplyPatternsOp( func, @@ -293,35 +249,33 @@ def bundle_xegpu_fused_attention_schedule( # Its first operand is the q load (vector.transfer_read) # Its second operand is the k load (vector.transfer_read) first_contract = contract_ops[0] - q_load = transform.get_producer_of_operand(anytype, first_contract, operand_number=0) - k_load = transform.get_producer_of_operand(anytype, first_contract, operand_number=1) + q_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=0 + ) + k_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=1 + ) # # Second vector.contract is attention_weights @ V # # Its second operand is the v load (vector.transfer_read) second_contract = contract_ops[1] - v_load = transform.get_producer_of_operand(anytype, second_contract, operand_number=1) - - # Extract memrefs from the loads (first operand of vector.transfer_read) - # q_memref = transform.get_operand(anytype, q_load, [0]) - # k_memref = transform.get_operand(anytype, k_load, [0]) - # v_memref = transform.get_operand(anytype, v_load, [0]) + v_load = transform.get_producer_of_operand( + anytype, second_contract, operand_number=1 + ) # Match arith.mulf to get the scale parameter # The scale is the second operand of arith.mulf (the constant) mulf_op = match_and_split(func, ops={"arith.mulf"}, nhandles=1)[0] scale = transform.get_producer_of_operand(anytype, mulf_op, operand_number=1) - # Debug prints to verify we got the right memrefs and scale - # transform.print_(target=q_load, name="q_memref") - # transform.print_(target=k_load, name="k_memref") - # transform.print_(target=v_load, name="v_memref") - # transform.print_(target=scale, name="scale") + if stop_at_stage == "bufferized": + raise PipelineInterrupt() # Generate fused attention computation with inner tiling # This replaces the second vector.contract (attention_weights @ V) with a tiled # loop that implements online softmax for efficient memory usage tile_size = 64 # Tile size for reduction dimension (K/V sequence length) - new_output = generate_fused_attention( + generate_fused_attention( q_load=q_load, k_load=k_load, v_load=v_load, @@ -332,7 +286,7 @@ def bundle_xegpu_fused_attention_schedule( transform.apply_cse(func) canonicalize(func) - if stop_at_stage == "bufferized": + if stop_at_stage == "inner-tiled": raise PipelineInterrupt() # convert forall to parallel @@ -412,7 +366,6 @@ def bundle_xegpu_fused_attention_schedule( out_sg_data = q_sg_data out_inst_data = q_inst_data - layout_128x16_sg_layout = [8, 1] layout_128x16_sg_data = [16, 16] layout_128x16_inst_data = [8, 16] @@ -423,17 +376,19 @@ def bundle_xegpu_fused_attention_schedule( # Set layout attributes for xegpu.store_nd ops. store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] - xegpu.set_anchor_layout(store_nd_op, sg_layout=out_sg_layout, sg_data=out_sg_data, inst_data=out_inst_data) + xegpu.set_anchor_layout( + store_nd_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + ) # Set layout for xegpu.load_nd ops (9 total: 1 Q, 4 K, 4 V) load_nd_ops = match_and_split(gpu_func, ops={"xegpu.load_nd"}, nhandles=9) # First load_nd: Q layout xegpu.set_anchor_layout( - load_nd_ops[0], - sg_layout=q_sg_layout, - sg_data=q_sg_data, - inst_data=q_inst_data + load_nd_ops[0], sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data ) # Next 4 load_nd ops: K layout @@ -442,7 +397,7 @@ def bundle_xegpu_fused_attention_schedule( load_nd_ops[i], sg_layout=k_sg_layout, sg_data=k_sg_data, - inst_data=k_inst_data + inst_data=k_inst_data, ) # Last 4 load_nd ops: V layout @@ -451,7 +406,7 @@ def bundle_xegpu_fused_attention_schedule( load_nd_ops[i], sg_layout=v_sg_layout, sg_data=v_sg_data, - inst_data=v_inst_data + inst_data=v_inst_data, ) # Set layout for xegpu.dpas ops (8 total: 4 for Q@K, 4 for P@V) @@ -466,7 +421,7 @@ def bundle_xegpu_fused_attention_schedule( sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data, - index=0 + index=0, ) # Index 1: K^T layout xegpu.set_anchor_layout( @@ -475,7 +430,7 @@ def bundle_xegpu_fused_attention_schedule( sg_data=kt_sg_data, inst_data=kt_inst_data, order=kt_order, - index=1 + index=1, ) # Index 2: QK output layout (128x16) xegpu.set_anchor_layout( @@ -483,7 +438,7 @@ def bundle_xegpu_fused_attention_schedule( sg_layout=layout_128x16_sg_layout, sg_data=layout_128x16_sg_data, inst_data=layout_128x16_inst_data, - index=2 + index=2, ) # Layouts for second 4 dpas ops (P@V): @@ -495,7 +450,7 @@ def bundle_xegpu_fused_attention_schedule( sg_layout=qk_sg_layout, sg_data=qk_sg_data, inst_data=qk_inst_data, - index=0 + index=0, ) # Index 1: V layout xegpu.set_anchor_layout( @@ -503,7 +458,7 @@ def bundle_xegpu_fused_attention_schedule( sg_layout=v_sg_layout, sg_data=v_sg_data, inst_data=v_inst_data, - index=1 + index=1, ) # Index 2: Output layout xegpu.set_anchor_layout( @@ -511,7 +466,7 @@ def bundle_xegpu_fused_attention_schedule( sg_layout=out_sg_layout, sg_data=out_sg_data, inst_data=out_inst_data, - index=2 + index=2, ) if stop_at_stage == "xegpu-wg": From 1eaa9237e59c0c5d7263f9156f4c7b9639e3e257 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 15 May 2026 21:33:49 +0000 Subject: [PATCH 62/63] cleanup --- examples/xegpu/fused_attention.py | 7 +++++++ lighthouse/schedule/xegpu/fused_attention_schedule.py | 4 +++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py index c375470c..65302764 100644 --- a/examples/xegpu/fused_attention.py +++ b/examples/xegpu/fused_attention.py @@ -239,6 +239,12 @@ def parse_cli(): default=16, help="Subgroup size", ) + parser.add_argument( + "--inner-loop-tile-size", + type=int, + default=64, + help="Tile size for the inner reduction dimension (K/V sequence length)", + ) parser.add_argument( "--nruns", type=int, @@ -300,6 +306,7 @@ def parse_cli(): "wg_rows": args.wg_rows, "sg_rows": args.sg_rows, "subgroup_size": args.subgroup_size, + "inner_loop_tile_size": args.inner_loop_tile_size, } Z = args.batch_size diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py index 5a64abf5..cb99e0e9 100644 --- a/lighthouse/schedule/xegpu/fused_attention_schedule.py +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -274,7 +274,9 @@ def bundle_xegpu_fused_attention_schedule( # Generate fused attention computation with inner tiling # This replaces the second vector.contract (attention_weights @ V) with a tiled # loop that implements online softmax for efficient memory usage - tile_size = 64 # Tile size for reduction dimension (K/V sequence length) + tile_size = parameters.get( + "inner_loop_tile_size", 64 + ) # Tile size for reduction dimension (K/V sequence length) generate_fused_attention( q_load=q_load, k_load=k_load, From 2302a342ade4470fb6e0ad30bd7a9af50b519e95 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 15 May 2026 21:43:12 +0000 Subject: [PATCH 63/63] cleanup --- .../ops/generate_fused_attention.py | 243 ++++++++++++------ 1 file changed, 163 insertions(+), 80 deletions(-) diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py index 656f3413..35f9822e 100644 --- a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -2,7 +2,7 @@ import numpy as np from mlir import ir -from mlir.dialects import ext, transform, arith, scf, linalg, tensor, math, vector +from mlir.dialects import ext, transform, arith, scf, math, vector from mlir.dialects.transform import DiagnosedSilenceableFailure from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect @@ -77,14 +77,11 @@ def apply( # Extract the scale scalar value from scale_op (arith.constant) scale_attr = scale_op.attributes["value"] - # Extract the scalar scale value from the scale_attr DenseElementsAttr scale_dense_attr = ir.DenseElementsAttr(scale_attr) - # Get the first element as the scale value (all elements are the same in a splat) scale_np_array = np.array(scale_dense_attr) scale_value = float(scale_np_array.flat[0]) # Extract wg_rows and d_head from q_load result type - # q_load is vector.transfer_read that produces a vector q_load_result = q_load_op.results[0] q_vector_type = ir.VectorType(q_load_result.type) wg_rows = q_vector_type.shape[0] @@ -98,35 +95,65 @@ def apply( # Build the fused attention computation with ir.InsertionPoint(output_op): - # 1. Define m_i_init: vector of shape [wg_rows] with neg_inf values + # Define m_i_init: vector of shape [wg_rows] with neg_inf values m_i_vector_type = ir.VectorType.get([wg_rows], element_type) - neg_inf_value = 0xFC00 if element_type == ir.F16Type.get() else float("-inf") - m_i_values = np.full(wg_rows, neg_inf_value, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) - m_i_init_attr = ir.DenseElementsAttr.get(m_i_values, type=m_i_vector_type) + neg_inf_value = ( + 0xFC00 if element_type == ir.F16Type.get() else float("-inf") + ) + m_i_values = np.full( + wg_rows, + neg_inf_value, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + m_i_init_attr = ir.DenseElementsAttr.get( + m_i_values, type=m_i_vector_type + ) m_i_init = arith.constant(m_i_vector_type, m_i_init_attr) - # 2. Define l_i_init: vector of shape [wg_rows] with zero values + # Define l_i_init: vector of shape [wg_rows] with zero values l_i_vector_type = ir.VectorType.get([wg_rows], element_type) - l_i_values = np.zeros(wg_rows, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) - l_i_init_attr = ir.DenseElementsAttr.get(l_i_values, type=l_i_vector_type) + l_i_values = np.zeros( + wg_rows, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + l_i_init_attr = ir.DenseElementsAttr.get( + l_i_values, type=l_i_vector_type + ) l_i_init = arith.constant(l_i_vector_type, l_i_init_attr) - # 3. Define acc_init: vector of shape [wg_rows, d_head] with zero values + # Define acc_init: vector of shape [wg_rows, d_head] with zero values acc_vector_type = ir.VectorType.get([wg_rows, d_head], element_type) - acc_values = np.zeros((wg_rows, d_head), dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) - acc_init_attr = ir.DenseElementsAttr.get(acc_values, type=acc_vector_type) + acc_values = np.zeros( + (wg_rows, d_head), + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + acc_init_attr = ir.DenseElementsAttr.get( + acc_values, type=acc_vector_type + ) acc_init = arith.constant(acc_vector_type, acc_init_attr) # Get n_ctx from k_load result type (first dimension size) k_load_result = k_load_op.results[0] k_vector_type = ir.VectorType(k_load_result.type) n_ctx = k_vector_type.shape[0] - - - + # Define scale vector: vector of shape [wg_rows] with the scale value scale_vector_type = ir.VectorType.get([wg_rows], element_type) - scale_values = np.full((wg_rows), scale_value, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) - scale_init_attr = ir.DenseElementsAttr.get(scale_values, type=scale_vector_type) + scale_values = np.full( + (wg_rows), + scale_value, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + scale_init_attr = ir.DenseElementsAttr.get( + scale_values, type=scale_vector_type + ) scale_vector = arith.constant(scale_vector_type, scale_init_attr) # Create loop bounds @@ -136,7 +163,9 @@ def apply( c_tile_size = arith.constant(index_type, tile_size_value) # Create scf.for loop that iterates from 0 to n_ctx in steps of tile_size - loop = scf.ForOp(c0, c_n_ctx, c_tile_size, [m_i_init, l_i_init, acc_init]) + loop = scf.ForOp( + c0, c_n_ctx, c_tile_size, [m_i_init, l_i_init, acc_init] + ) with ir.InsertionPoint(loop.body): # Get the loop induction variable and iter_args @@ -163,7 +192,7 @@ def apply( offset = arith.constant(index_type, i * k_subtile_size) k_tile_offsets.append(offset) - # Step 1: Load and process K tiles (unrolled) + # Load and process K tiles (unrolled) # Each K tile is [16, d_head], transposed to [d_head, 16], contracted to [wg_rows, 16] qkt_chunks = [] @@ -176,22 +205,35 @@ def apply( k_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) - indexing_maps = ir.ArrayAttr.get([ - ir.AffineMapAttr.get(q_map), - ir.AffineMapAttr.get(k_map), - ir.AffineMapAttr.get(out_map) - ]) + indexing_maps = ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get(q_map), + ir.AffineMapAttr.get(k_map), + ir.AffineMapAttr.get(out_map), + ] + ) - iterator_types = ir.ArrayAttr.get([ - ir.Attribute.parse("#vector.iterator_type"), - ir.Attribute.parse("#vector.iterator_type"), - ir.Attribute.parse("#vector.iterator_type") - ]) + iterator_types = ir.ArrayAttr.get( + [ + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ] + ) # Accumulator for Q@K chunks - qkt_chunk_type = ir.VectorType.get([wg_rows, k_subtile_size], element_type) - qkt_chunk_acc_values = np.zeros((wg_rows, k_subtile_size), dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) - qkt_chunk_acc_attr = ir.DenseElementsAttr.get(qkt_chunk_acc_values, type=qkt_chunk_type) + qkt_chunk_type = ir.VectorType.get( + [wg_rows, k_subtile_size], element_type + ) + qkt_chunk_acc_values = np.zeros( + (wg_rows, k_subtile_size), + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + qkt_chunk_acc_attr = ir.DenseElementsAttr.get( + qkt_chunk_acc_values, type=qkt_chunk_type + ) qkt_chunk_acc = arith.constant(qkt_chunk_type, qkt_chunk_acc_attr) for tile_idx in range(num_k_tiles): @@ -203,18 +245,22 @@ def apply( k_tile_indices[-2] = k_tile_offset # Load K tile: [16, d_head] - k_tile_type = ir.VectorType.get([k_subtile_size, d_head], element_type) + k_tile_type = ir.VectorType.get( + [k_subtile_size, d_head], element_type + ) k_tile = vector.TransferReadOp( k_tile_type, k_memref, k_tile_indices, k_perm_map, padding, - in_bounds=in_bounds + in_bounds=in_bounds, ).result # Transpose K tile: [16, d_head] -> [d_head, 16] - k_transpose_type = ir.VectorType.get([d_head, k_subtile_size], element_type) + k_transpose_type = ir.VectorType.get( + [d_head, k_subtile_size], element_type + ) k_transpose = vector.transpose(k_transpose_type, k_tile, [1, 0]) # Contract Q @ K_transpose: [wg_rows, d_head] @ [d_head, 16] -> [wg_rows, 16] @@ -224,44 +270,62 @@ def apply( k_transpose, qkt_chunk_acc, indexing_maps=indexing_maps, - iterator_types=iterator_types + iterator_types=iterator_types, ) qkt_chunks.append(qkt_chunk) - # Step 2: Elementwise maximum across all Q@K chunks + # Elementwise maximum across all Q@K chunks # Build tree of maximumf operations qkt_max_combined = qkt_chunks[0] for i in range(1, num_k_tiles): - qkt_max_combined = arith.maximumf(qkt_max_combined, qkt_chunks[i]) + qkt_max_combined = arith.maximumf( + qkt_max_combined, qkt_chunks[i] + ) - # Step 3: Final multi_reduction to get row-wise max: [wg_rows, 16] -> [wg_rows] + # Final multi_reduction to get row-wise max: [wg_rows, 16] -> [wg_rows] qkt_max = vector.multi_reduction( kind="maxnumf", source=qkt_max_combined, acc=m_i_init, - reduction_dims=[1] + reduction_dims=[1], ) - # Step 4: Scale the max: qkt_max_scaled = qkt_max * scale + # Scale the max: qkt_max_scaled = qkt_max * scale # Both have shape [wg_rows] qkt_max_scaled = arith.mulf(qkt_max, scale_vector) - # Step 5: Compute m_ij = max(m_i, qkt_max_scaled) + # Compute m_ij = max(m_i, qkt_max_scaled) # Both have shape [wg_rows] m_ij = arith.maximumf(m_i, qkt_max_scaled) - # Step 6-9: Apply softmax to each Q@K chunk + # Apply softmax to each Q@K chunk # Scale constant for chunks: [wg_rows, 16] - scale_chunk_type = ir.VectorType.get([wg_rows, k_subtile_size], element_type) - scale_chunk_values = np.full((wg_rows, k_subtile_size), scale_value, dtype=np.float16 if element_type == ir.F16Type.get() else np.float32) - scale_chunk_attr = ir.DenseElementsAttr.get(scale_chunk_values, type=scale_chunk_type) + scale_chunk_type = ir.VectorType.get( + [wg_rows, k_subtile_size], element_type + ) + scale_chunk_values = np.full( + (wg_rows, k_subtile_size), + scale_value, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + scale_chunk_attr = ir.DenseElementsAttr.get( + scale_chunk_values, type=scale_chunk_type + ) scale_chunk = arith.constant(scale_chunk_type, scale_chunk_attr) # Broadcast m_ij from [wg_rows] to [wg_rows, 16] - m_ij_bcasted_type = ir.VectorType.get([k_subtile_size, wg_rows], element_type) + m_ij_bcasted_type = ir.VectorType.get( + [k_subtile_size, wg_rows], element_type + ) m_ij_bcasted = vector.broadcast(m_ij_bcasted_type, m_ij) - m_ij_transposed_type = ir.VectorType.get([wg_rows, k_subtile_size], element_type) - m_ij_transposed = vector.transpose(m_ij_transposed_type, m_ij_bcasted, [1, 0]) + m_ij_transposed_type = ir.VectorType.get( + [wg_rows, k_subtile_size], element_type + ) + m_ij_transposed = vector.transpose( + m_ij_transposed_type, m_ij_bcasted, [1, 0] + ) # Apply softmax to each chunk qkt_exp_chunks = [] @@ -276,37 +340,45 @@ def apply( qkt_exp = math.exp(qkt_centered) qkt_exp_chunks.append(qkt_exp) - # Step 10: Elementwise sum across all exp chunks + # Elementwise sum across all exp chunks qkt_exp_combined = qkt_exp_chunks[0] for i in range(1, num_k_tiles): - qkt_exp_combined = arith.addf(qkt_exp_combined, qkt_exp_chunks[i]) + qkt_exp_combined = arith.addf( + qkt_exp_combined, qkt_exp_chunks[i] + ) # Final multi_reduction to get row-wise sum: [wg_rows, 16] -> [wg_rows] l_ij = vector.multi_reduction( kind="add", source=qkt_exp_combined, acc=l_i_init, - reduction_dims=[1] + reduction_dims=[1], ) - # Step 11: Compute alpha = exp(m_i - m_ij) + # Compute alpha = exp(m_i - m_ij) m_diff = arith.subf(m_i, m_ij) alpha = math.exp(m_diff) - # Step 12: Update l_i: l_i_updated = l_i * alpha + l_ij + # Update l_i: l_i_updated = l_i * alpha + l_ij l_i_scaled = arith.mulf(l_i, alpha) l_i_updated = arith.addf(l_i_scaled, l_ij) - # Step 13: Broadcast alpha from [wg_rows] to [wg_rows, d_head] - alpha_bcasted_type = ir.VectorType.get([d_head, wg_rows], element_type) + # Broadcast alpha from [wg_rows] to [wg_rows, d_head] + alpha_bcasted_type = ir.VectorType.get( + [d_head, wg_rows], element_type + ) alpha_bcasted = vector.broadcast(alpha_bcasted_type, alpha) - alpha_transposed_type = ir.VectorType.get([wg_rows, d_head], element_type) - alpha_transposed = vector.transpose(alpha_transposed_type, alpha_bcasted, [1, 0]) + alpha_transposed_type = ir.VectorType.get( + [wg_rows, d_head], element_type + ) + alpha_transposed = vector.transpose( + alpha_transposed_type, alpha_bcasted, [1, 0] + ) - # Step 14: Update accumulator: acc_updated = acc * alpha_bcasted + # Update accumulator: acc_updated = acc * alpha_bcasted acc_updated = arith.mulf(acc, alpha_transposed) - # Step 15-16: Load V tiles and compute attention-weighted values + # Load V tiles and compute attention-weighted values # Get V load parameters v_memref = v_load_op.operands[0] v_load_indices = list(v_load_op.operands[1:-1]) @@ -319,17 +391,21 @@ def apply( v_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) pv_out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) - indexing_maps_pv = ir.ArrayAttr.get([ - ir.AffineMapAttr.get(qkt_exp_map), - ir.AffineMapAttr.get(v_map), - ir.AffineMapAttr.get(pv_out_map) - ]) + indexing_maps_pv = ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get(qkt_exp_map), + ir.AffineMapAttr.get(v_map), + ir.AffineMapAttr.get(pv_out_map), + ] + ) - iterator_types_pv = ir.ArrayAttr.get([ - ir.Attribute.parse("#vector.iterator_type"), - ir.Attribute.parse("#vector.iterator_type"), - ir.Attribute.parse("#vector.iterator_type") - ]) + iterator_types_pv = ir.ArrayAttr.get( + [ + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ] + ) # Load and process V tiles (unrolled), accumulating results pv_out = acc_updated @@ -342,14 +418,16 @@ def apply( v_tile_indices[-2] = v_tile_offset # Load V tile: [16, d_head] - v_tile_type = ir.VectorType.get([k_subtile_size, d_head], element_type) + v_tile_type = ir.VectorType.get( + [k_subtile_size, d_head], element_type + ) v_tile = vector.TransferReadOp( v_tile_type, v_memref, v_tile_indices, v_perm_map, v_padding, - in_bounds=v_in_bounds + in_bounds=v_in_bounds, ).result # Contract qkt_exp_chunk @ v_tile: [wg_rows, 16] @ [16, d_head] -> [wg_rows, d_head] @@ -360,7 +438,7 @@ def apply( v_tile, pv_out, indexing_maps=indexing_maps_pv, - iterator_types=iterator_types_pv + iterator_types=iterator_types_pv, ) # Yield the updated iter args @@ -370,15 +448,20 @@ def apply( pv_out = loop.results[2] l_i_out = loop.results[1] with ir.InsertionPoint.after(loop): - # Step 17: Normalize the output: output_final = pv_out / l_i_out + # Normalize the output: output_final = pv_out / l_i_out # Need to broadcast l_i_out from [wg_rows] to [wg_rows, d_head] - l_i_out_bcasted_type = ir.VectorType.get([d_head, wg_rows], element_type) + l_i_out_bcasted_type = ir.VectorType.get( + [d_head, wg_rows], element_type + ) l_i_out_bcasted = vector.broadcast(l_i_out_bcasted_type, l_i_out) - l_i_out_transposed_type = ir.VectorType.get([wg_rows, d_head], element_type) - l_i_out_transposed = vector.transpose(l_i_out_transposed_type, l_i_out_bcasted, [1, 0]) + l_i_out_transposed_type = ir.VectorType.get( + [wg_rows, d_head], element_type + ) + l_i_out_transposed = vector.transpose( + l_i_out_transposed_type, l_i_out_bcasted, [1, 0] + ) output_final = arith.divf(pv_out, l_i_out_transposed) - # Replace all uses of the original output operation with the final loop result output_op.results[0].replace_all_uses_with(output_final)