diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py new file mode 100644 index 00000000..2bf59882 --- /dev/null +++ b/examples/xegpu/fused_attention.py @@ -0,0 +1,359 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + +""" +XeGPU fused attention benchmark. +""" + +import argparse +from typing import Optional +from functools import cached_property + +import numpy as np +from mlir import ir + +from lighthouse import dialects as lh_dialects +from lighthouse.execution.runner import Runner +from lighthouse.pipeline.driver import TransformDriver +from lighthouse.execution import GPUMemoryManager +from lighthouse.utils.numpy import mlir_to_numpy_dtype +from lighthouse.ingress.mlir_gen import get_mlir_elem_type +from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import ( + generate_gpu_fused_attention_payload, +) +from lighthouse.schedule.xegpu.fused_attention_schedule import ( + get_fused_attention_schedule_module, +) + + +def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: int): + """ + Complexity of fused attention operation. + + For each batch and head: + - Q @ K^T: O(n_ctx^2 * n_head) operations + - Softmax: O(n_ctx^2) operations + - Attention @ V: O(n_ctx^2 * n_head) operations + Total: approximately 2*n_ctx^2*n_head FLOPs per batch and head + """ + # Approximation: 2 * n_ctx^2 * n_head FLOPs per batch and head + flop_count = Z * H * 2 * n_ctx * n_ctx * n_head + # Memory: read Q, K, V and write output + memory_reads = 3 * Z * H * n_ctx * n_head * nbytes + memory_writes = Z * H * n_ctx * n_head * nbytes + return flop_count, memory_reads, memory_writes + + +def check_correctness( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + output_arr: np.ndarray, + verbose: int = 0, +) -> bool: + """ + Check correctness of fused attention output. + + Reference implementation: + - scores = Q @ K^T / sqrt(n_head) + - attention_weights = softmax(scores, dim=-1) + - output = attention_weights @ V + """ + # Use float32 for computation + Q_f32 = Q.astype(np.float32) + K_f32 = K.astype(np.float32) + V_f32 = V.astype(np.float32) + + Z, H, n_ctx, n_head = Q.shape + scale = 1.0 / np.sqrt(n_head) + + output_ref = np.zeros_like(Q_f32) + + # Compute reference for each batch and head + for z in range(Z): + for h in range(H): + # scores = Q @ K^T / sqrt(n_head) + scores = Q_f32[z, h] @ K_f32[z, h].T * scale + + # softmax along last dimension + max_vals = np.max(scores, axis=1, keepdims=True) + exp_vals = np.exp(scores - max_vals) + sum_vals = np.sum(exp_vals, axis=1, keepdims=True) + attention_weights = exp_vals / sum_vals + + # output = attention_weights @ V + output_ref[z, h] = attention_weights @ V_f32[z, h] + + output = output_arr.astype(np.float32) + + if verbose > 1: + print("Reference solution (first batch, first head, first 5 rows):") + print(output_ref[0, 0, :5]) + print("Computed solution (first batch, first head, first 5 rows):") + print(output[0, 0, :5]) + + # Check values match reference + values_ok = np.allclose(output, output_ref, rtol=1e-3, atol=1e-4) + + success = values_ok + + if verbose: + if success: + print("PASSED") + else: + print("FAILED!") + if not values_ok: + max_diff = np.abs(output - output_ref).max() + print(f" Values mismatch. Max abs diff: {max_diff:.6e}") + return success + + +class XeGPUFusedAttention: + """ + Fused attention workload on XeGPU. + + Computes fused attention: + output = softmax(Q @ K^T / sqrt(n_head)) @ V + + All Q, K, V matrices have shape (Z, H, n_ctx, n_head) where: + - Z: batch size + - H: number of heads + - n_ctx: context length + - n_head: head dimension + """ + + def __init__( + self, + Z: int, + H: int, + n_ctx: int, + n_head: int, + dtype: str = "f32", + ): + self.Z = Z + self.H = H + self.n_ctx = n_ctx + self.n_head = n_head + self.shape = (Z, H, n_ctx, n_head) + assert dtype == "f32", "Only f32 type is supported for fused attention" + self.elem_type = get_mlir_elem_type(dtype) + self.dtype = mlir_to_numpy_dtype(self.elem_type) + self.memory_manager_class = GPUMemoryManager + self.payload_function_name = "payload" + + @cached_property + def _initial_host_arrays(self) -> tuple[np.ndarray]: + """Generate initial values on host with numpy.""" + np.random.seed(42) + # Initialize Q, K, V with small random values + Q = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + K = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + V = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + output_arr = np.zeros(self.shape, dtype=self.dtype) + return (output_arr, Q, K, V) + + def get_complexity(self) -> tuple[int, int, int]: + nbytes = np.dtype(self.dtype).itemsize + return fused_attention_complexity( + self.Z, self.H, self.n_ctx, self.n_head, nbytes + ) + + def payload_module(self) -> ir.Module: + """Generate MLIR module for fused attention payload.""" + return generate_gpu_fused_attention_payload( + func_name=self.payload_function_name, + Z=self.Z, + H=self.H, + n_ctx=self.n_ctx, + n_head=self.n_head, + dtype=self.elem_type, + ) + + def schedule_modules( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> list[ir.Module]: + """Generate transform schedule for fused attention.""" + return [ + Runner.get_bench_wrapper_schedule(self.payload_function_name), + get_fused_attention_schedule_module( + stop_at_stage=stop_at_stage, + parameters=parameters, + ), + ] + + def shared_libs(self) -> list[str]: + return ["libmlir_levelzero_runtime.so"] + + +def parse_cli(): + parser = argparse.ArgumentParser( + description="Fused Attention using MLIR XeGPU", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Batch size (Z)", + ) + parser.add_argument( + "--num-heads", + type=int, + default=1, + help="Number of attention heads (H)", + ) + parser.add_argument( + "--n-ctx", + type=int, + default=512, + help="Context length (sequence length)", + ) + parser.add_argument( + "--n-head", + type=int, + default=64, + help="Head dimension", + ) + parser.add_argument( + "--wg-tile-size", + type=int, + default=64, + help="Workgroup tile size for the collapsed batch dimension (Z*H*n_ctx)", + ) + parser.add_argument( + "--nruns", + type=int, + default=1000, + help="Number of runs to average the execution time.", + ) + parser.add_argument( + "--nwarmup", + type=int, + default=20, + help="Number of warm-up iterations before benchmarking.", + ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the fused attention computation.", + ) + parser.add_argument( + "--dump-kernel", + type=str, + choices=[ + "initial", + "outer-tiled", + "inner-tiled", + "vectorized", + "bufferized", + "gpu-outlining", + "xegpu-initial", + "xegpu-wg", + "final", + ], + help="Dump kernel IR at different stages of lowering and exit without " + "executing the kernel.", + ) + parser.add_argument( + "--dump-schedule", + action="store_true", + help="Dump transform schedule.", + ) + parser.add_argument( + "--verbose", + "-v", + action="count", + default=0, + help="Increase output verbosity (e.g. print reference and computed solutions).", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_cli() + + params = { + "batch_size": args.batch_size, + "num_heads": args.num_heads, + "n_ctx": args.n_ctx, + "n_head": args.n_head, + "wg_tile_size": args.wg_tile_size, + } + + Z = args.batch_size + H = args.num_heads + n_ctx = args.n_ctx + n_head = args.n_head + dtype = "f32" + + with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUFusedAttention(Z=Z, H=H, n_ctx=n_ctx, n_head=n_head, dtype=dtype) + + if args.dump_kernel or args.dump_schedule: + pipeline = TransformDriver( + wload.schedule_modules( + stop_at_stage=args.dump_kernel, parameters=params + ) + ) + payload = pipeline.apply(wload.payload_module()) + if args.dump_kernel: + print(payload) + if args.dump_schedule: + for schedule_module in wload.schedule_modules(parameters=params): + print(schedule_module) + else: + pipeline = TransformDriver(wload.schedule_modules(parameters=params)) + payload = pipeline.apply(wload.payload_module()) + runner = Runner( + payload, + mem_manager_cls=wload.memory_manager_class, + shared_libs=wload.shared_libs(), + ) + if args.check_result: + # Setup callback function to copy result from device to host. + result_host_copy, argument_access_callback = ( + Runner.get_gpu_argument_access_callback(wload.shape, wload.dtype) + ) + + # Execute kernel once. + runner.execute( + host_input_buffers=wload._initial_host_arrays, + payload_function_name=wload.payload_function_name, + argument_access_callback=argument_access_callback, + ) + + # Compute reference solution on host. + Q, K, V = wload._initial_host_arrays[1:4] + success = check_correctness( + Q, + K, + V, + result_host_copy, + verbose=args.verbose, + ) + if not success: + raise ValueError("Result mismatch!") + else: + print("Result is correct. Proceeding to benchmark...") + + times = runner.benchmark( + host_input_buffers=wload._initial_host_arrays, + nruns=args.nruns, + nwarmup=args.nwarmup, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + print( + f"batch-size={Z} " + f"num-heads={H} " + f"n-ctx={n_ctx} " + f"n-head={n_head} " + f"dt={dtype} " + f"time(us): {elapsed:.2f} " + f"GFLOPS: {gflops:.2f} " + ) diff --git a/examples/xegpu/softmax.py b/examples/xegpu/softmax.py index df9e9624..f75613d0 100644 --- a/examples/xegpu/softmax.py +++ b/examples/xegpu/softmax.py @@ -155,7 +155,7 @@ def parse_cli(): "--sizes", type=int, nargs=2, - default=[1024, 64], + default=[1024, 512], help="M,N matrix sizes (MxN)", ) parser.add_argument( @@ -176,6 +176,12 @@ def parse_cli(): default=16, help="Subgroup size.", ) + parser.add_argument( + "--reduction-step-size", + type=int, + default=16, + help="Step size for reduction loop tiling (optional).", + ) parser.add_argument( "--nruns", type=int, @@ -201,6 +207,7 @@ def parse_cli(): "tiled", "vectorized", "bufferized", + "gpu-outlining", "xegpu-initial", "xegpu-wg", "final", @@ -232,6 +239,7 @@ def parse_cli(): "wg_rows": args.wg_rows, "sg_rows": args.sg_rows, "subgroup_size": args.subgroup_size, + "reduction_step_size": args.reduction_step_size, } M, N = args.sizes @@ -282,6 +290,8 @@ def parse_cli(): ) if not success: raise ValueError("Result mismatch!") + else: + print("Result is correct. Proceeding to benchmark...") times = runner.benchmark( host_input_buffers=wload._initial_host_arrays, diff --git a/lighthouse/dialects/transform/transform_ext/__init__.py b/lighthouse/dialects/transform/transform_ext/__init__.py index aba08bf0..997522a2 100644 --- a/lighthouse/dialects/transform/transform_ext/__init__.py +++ b/lighthouse/dialects/transform/transform_ext/__init__.py @@ -9,6 +9,7 @@ from .ops.extract_handle import extract_handle from .ops.get_tileable_consumers import get_tileable_consumers from .ops.get_tiling_sizes import get_tiling_sizes +from .ops.update_address_space import update_address_space __all__ = [ "TransformExtensionDialect", @@ -21,5 +22,6 @@ "param_cmp_eq", "register_and_load", "replace", + "update_address_space", "wrap_in_benching_func", ] diff --git a/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py new file mode 100644 index 00000000..8d0b6041 --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/update_address_space.py @@ -0,0 +1,100 @@ +from mlir import ir +from mlir.dialects import ext, transform, memref +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect + + +class UpdateAddressSpace( + TransformExtensionDialect.Operation, name="update_address_space" +): + """Update the address space of a memref allocation operation. + + Takes a target memref allocation operation and updates its address space + to the provided value. + """ + + target: ext.Operand[transform.AnyOpType] + address_space: ir.IntegerAttr + updated_op: ext.Result[transform.AnyOpType[()]] = ext.infer_result() + + @classmethod + def attach_interface_impls(cls, ctx=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "UpdateAddressSpace", + rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + # Get the target operations to transform + target_op = state.get_payload_ops(op.target)[0] + # Get the address space value from the attribute + address_space_value = ir.IntegerAttr(op.address_space).value + new_ops = [] + + # Verify this is a memref.alloca operation + if target_op.OPERATION_NAME != "memref.alloca": + return DiagnosedSilenceableFailure.emit_silenceable_error( + f"Expected memref.alloca operation, got {target_op.OPERATION_NAME}" + ) + + # Get the current result type (should be a MemRefType) + old_result_type = target_op.results[0].type + memref_type = ir.MemRefType(old_result_type) + # Create a new memref type with the specified address space + new_memref_type = ir.MemRefType.get( + memref_type.shape, + memref_type.element_type, + layout=memref_type.layout, + memory_space=ir.Attribute.parse(f"{address_space_value}"), + ) + + # Replace the operation with a new one that has the updated type + with ir.InsertionPoint(target_op): + # Get the operands from the original alloca (dynamic sizes and symbols) + dynamic_sizes = list( + target_op.operands[: target_op.attributes["operandSegmentSizes"][0]] + ) + symbol_operands = list( + target_op.operands[target_op.attributes["operandSegmentSizes"][0] :] + ) + # Create a new alloca with the updated type + new_alloca = memref.alloca( + new_memref_type, dynamic_sizes, symbol_operands + ) + # Replace all uses of the old operation with the new one + # FIXME: This won't handle operations that consume the memref type and + # return a new memref (such as subview). + rewriter.replace_op(target_op, [new_alloca]) + new_ops.append(new_alloca.owner) + + # Set the results to the new operations + results.set_ops(op.updated_op, new_ops) + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "UpdateAddressSpace") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: ir.Operation, effects): + transform.consumes_handle(op.op_operands[:1], effects) + transform.produces_handle(op.results, effects) + transform.modifies_payload(effects) + + +def update_address_space( + target: ir.Value, + address_space: int | ir.IntegerAttr, +) -> ir.Value: + if not isinstance(address_space, ir.IntegerAttr): + address_space = ir.IntegerAttr.get( + ir.IntegerType.get_signless(64), address_space + ) + return UpdateAddressSpace(target, address_space=address_space).updated_op diff --git a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py new file mode 100644 index 00000000..d4a80856 --- /dev/null +++ b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py @@ -0,0 +1,137 @@ +"""Generate MLIR payload for GPU fused attention operation.""" + +from mlir import ir +from mlir.dialects import arith, bufferization, linalg, tensor + +from lighthouse.utils.mlir import func_cif +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs +from lighthouse.ingress.mlir_gen.utils import emit_buf_to_tensor + + +def generate_gpu_fused_attention_payload( + func_name: str, + Z: int, + H: int, + n_ctx: int, + n_head: int, + dtype: ir.Type, +) -> ir.Module: + """ + Generate MLIR module for fused attention payload. + + Computes fused attention: + output = softmax(Q @ K^T / sqrt(n_head)) @ V + + Args: + func_name: Name of the payload function + Z: Batch size + H: Number of attention heads + n_ctx: Context length (sequence length) + n_head: Head dimension + dtype: MLIR element type (e.g., F32Type) + + Returns: + MLIR module containing the fused attention payload function + """ + mod = ir.Module.create() + shape = (Z, H, n_ctx, n_head) + memref_t = ir.MemRefType.get(shape, dtype) + + with ir.InsertionPoint(mod.body): + # Function signature: payload(output, Q, K, V) + @func_cif(memref_t, memref_t, memref_t, memref_t, name=func_name) + def payload(output, Q_arg, K_arg, V_arg): + # Convert memrefs to tensors + emit_buf_to_tensor(output, restrict=True, writable=True) + Q_tensor = emit_buf_to_tensor(Q_arg, restrict=True) + K_tensor = emit_buf_to_tensor(K_arg, restrict=True) + V_tensor = emit_buf_to_tensor(V_arg, restrict=True) + + # Collapse first 3 dimensions (Z, H, n_ctx) into a single batch dimension + # From (Z, H, n_ctx, n_head) to (Z*H*n_ctx, n_head) + batch_dim = Z * H * n_ctx + collapsed_shape_2d = (batch_dim, n_head) + + Q_2d = tensor.collapse_shape( + ir.RankedTensorType.get(collapsed_shape_2d, dtype), + Q_tensor, + reassociation=[[0, 1, 2], [3]], + ) + K_2d = tensor.collapse_shape( + ir.RankedTensorType.get(collapsed_shape_2d, dtype), + K_tensor, + reassociation=[[0, 1, 2], [3]], + ) + V_2d = tensor.collapse_shape( + ir.RankedTensorType.get(collapsed_shape_2d, dtype), + V_tensor, + reassociation=[[0, 1, 2], [3]], + ) + + # Step 1: Transpose K to get K^T + # Transpose from (batch_dim, n_head) to (n_head, batch_dim) + kt_shape_2d = (n_head, batch_dim) + kt_init = tensor.empty(kt_shape_2d, dtype) + K_transposed = linalg.transpose(K_2d, outs=[kt_init], permutation=[1, 0]) + + # Step 2: Compute Q @ K^T using matmul + # Q: (batch_dim, n_head) @ K^T: (n_head, batch_dim) + # Result: (batch_dim, batch_dim) + qkt_shape_2d = (batch_dim, batch_dim) + qkt_init = tensor.empty(qkt_shape_2d, dtype) + # Initialize with zeros for matmul accumulation + zero = arith.constant(dtype, 0.0) + qkt_init_filled = linalg.fill(zero, outs=[qkt_init]) + + # Matmul: Q @ K^T + qkt = linalg.matmul(Q_2d, K_transposed, outs=[qkt_init_filled]) + + # # Step 3: Scale by 1/sqrt(n_head) + # scale_factor = 1.0 / math.sqrt(n_head) + # scale_const = arith.constant(dtype, scale_factor) + + # # Create a tensor filled with the scale factor + # scale_tensor_init = tensor.empty(qkt_shape_2d, dtype) + # scale_tensor = linalg.fill(scale_const, outs=[scale_tensor_init]) + + # # Elementwise multiply qkt with scale tensor + # scaled_qkt_init = tensor.empty(qkt_shape_2d, dtype) + # scaled_qkt = linalg.mul(qkt, scale_tensor, outs=[scaled_qkt_init]) + + # Step 4: Apply softmax along the last dimension (dim=1 in 2D) + softmax_init = tensor.empty(qkt_shape_2d, dtype) + attention_weights = linalg.softmax( + result=[ir.RankedTensorType.get(qkt_shape_2d, dtype)], + input=qkt, + output=softmax_init, + dimension=1, + ) + + # Step 5: Multiply attention weights by V using matmul + # attention_weights: (batch_dim, batch_dim) @ V: (batch_dim, n_head) + # Result: (batch_dim, n_head) + output_2d_init = tensor.empty(collapsed_shape_2d, dtype) + output_2d_init_filled = linalg.fill(zero, outs=[output_2d_init]) + + result_2d = linalg.matmul( + attention_weights, V_2d, outs=[output_2d_init_filled] + ) + + # Expand back to 4D: (Z*H*n_ctx, n_head) -> (Z, H, n_ctx, n_head) + result = tensor.expand_shape( + ir.RankedTensorType.get(shape, dtype), + result_2d, + reassociation=[[0, 1, 2], [3]], + output_shape=[], + static_output_shape=shape, + ) + + # Materialize result back to output memref + bufferization.materialize_in_destination( + None, result, output, restrict=True, writable=True + ) + + # Emit utility functions for GPU memory management + emit_gpu_util_funcs(dtype, rank=4) + + return mod diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py new file mode 100644 index 00000000..6217d835 --- /dev/null +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -0,0 +1,282 @@ +"""Generate MLIR transform schedule for XeGPU fused attention operation.""" + +from typing import Optional + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured + +from lighthouse.pipeline.helper import ( + match, + match_and_split, + PipelineInterrupt, +) +from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary + + +def get_fused_attention_schedule_module( + stop_at_stage: Optional[str] = None, + parameters: Optional[dict] = None, +) -> ir.Module: + """ + Generate transform schedule for fused attention operation. + + The schedule performs the following transformations: + 1. Tile the fused attention operation + 2. Vectorize operations + 3. Bufferize tensors + 4. Convert to GPU dialect + 5. Lower to XeGPU operations + + Args: + stop_at_stage: Optional stage name to stop early (for debugging) + parameters: Dictionary with scheduling parameters: + - batch_size: Batch size (Z) + - num_heads: Number of attention heads (H) + - n_ctx: Context length + - n_head: Head dimension + - wg_tile_size: Workgroup tile size for the collapsed batch dimension (Z*H*n_ctx) + + Returns: + MLIR module containing the transform schedule + """ + assert parameters is not None, "Schedule parameters must be provided" + + mod = ir.Module.create() + mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + + with ir.InsertionPoint(mod.body): + # Create a transform sequence with proper signature + named_sequence = transform.named_sequence( + "__transform_main", + [transform.AnyOpType.get()], # input: module + [], # no outputs + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + + with ir.InsertionPoint(named_sequence.body): + # match the payload module + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + + xegpu_fused_attention_transform_schedule( + payload_mod, + parameters=parameters, + stop_at_stage=stop_at_stage or "", + ) + + return mod + + +def xegpu_fused_attention_transform_schedule( + mod: ir.Value[transform.AnyOpType], + parameters: dict, + stop_at_stage: str = "", +): + """Transform schedule for fused attention payload.""" + try: + mod = bundle_xegpu_fused_attention_schedule( + mod, + parameters=parameters, + stop_at_stage=stop_at_stage, + ) + + mod = bundle_xegpu_to_binary( + mod, + stop_at_stage=stop_at_stage, + ) + except PipelineInterrupt: + pass + finally: + transform.yield_() + + +def bundle_xegpu_fused_attention_schedule( + mod: ir.Value[transform.AnyOpType], + parameters: dict, + stop_at_stage: str = "", +) -> ir.Value[transform.AnyOpType]: + """Schedule for lowering fused attention payload to xegpu wg level.""" + + if stop_at_stage == "initial": + raise PipelineInterrupt() + + anytype = transform.AnyOpType.get() + anyvalue = transform.AnyValueType.get() + # # Match all matmul operations - there should be 2: + # # 1. Q @ K^T + # # 2. attention_weights @ V + # matmul_ops = match_and_split(mod, ops={"linalg.batch_matmul"}, nhandles=2) + + # # Get the last matmul (attention_weights @ V) + # last_matmul = matmul_ops[1] + # func = transform.get_parent_op( + # anytype, + # last_matmul, + # op_name="func.func", + # deduplicate=True, + # ) + + # # Tile the last matmul in the batch dimension using tile_using_forall + # # Batch dimension is the first dimension (collapsed_dim = Z * H * n_ctx) + # # Extract workgroup tile size from parameters + # wg_tile_size = parameters["wg_tile_size"] + + # tiled_matmul, forall_loop = structured.structured_tile_using_forall( + # anytype, + # anytype, + # last_matmul, + # num_threads=[], + # tile_sizes=[], + # static_tile_sizes=(1, wg_tile_size, 0), + # ) + + # # Fuse the softmax producer into forall + # softmax_ops = match_and_split(func, ops={"linalg.softmax"}, nhandles=1) + # softmax_op = softmax_ops[0] + # fused_softmax_op, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=softmax_op, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) + + # # Fuse linalg.mul (scaling) into forall + # mul_ops = match_and_split(func, ops={"linalg.mul"}, nhandles=1) + # mul_op = mul_ops[0] + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=mul_op, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) + + # # Fuse the first matmul (Q @ K^T) into forall + # matmul_ops = match_and_split( + # func, ops={"linalg.batch_matmul"}, nhandles=2 + # ) # Two matmuls are present. + # first_matmul = matmul_ops[0] + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=first_matmul, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) + + # # Fuse linalg.transpose (K transpose) into forall + # transpose_ops = match_and_split(func, ops={"linalg.transpose"}, nhandles=1) + # transpose_op = transpose_ops[0] + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=transpose_op, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) + + # # At this point all of the key operations are fused into the forall loop. + # # Remaining linalg.fill ops can be fused trivially. + # fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=3) + # for fill_op in fill_ops: + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=fill_op, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) + + # # tensor.empty() holding the result of transpose can be fused. + # transpose_op = match_and_split(func, ops={"linalg.transpose"}, nhandles=1)[0] + # transpose_init = transform.get_producer_of_operand( + # anytype, transpose_op, operand_number=1 + # ) + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=transpose_init, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) + + # # tensor.empty() ops holding the result of the softmax can also be fused. + # softmax_op = match_and_split(func, ops={"linalg.softmax"}, nhandles=1)[0] + # softmax_init = transform.get_producer_of_operand( + # anytype, softmax_op, operand_number=1 + # ) + # _, forall_loop = structured.structured_fuse_into_containing_op( + # anytype, + # anytype, + # producer_op=softmax_init, + # containing_op=forall_loop, + # ) + # transform.apply_cse(func) + # canonicalize(func) + + if stop_at_stage == "outer-tiled": + raise PipelineInterrupt() + + # Match the last matmul (attention_weights @ V) + # There should be 2 matmuls: Q @ K^T and attention_weights @ V + matmul_ops = match_and_split(mod, ops={"linalg.matmul"}, nhandles=2) + last_matmul = matmul_ops[1] + + # Tile the last matmul in the K dimension only (reduction dimension) + # Matmul shape: (512, 512) @ (512, 64) + # Tile sizes: [M, N, K] = [0, 0, 64] - only tile the K dimension + _, _, _, sum_loop = structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=last_matmul, + tile_sizes=[0, 0, 32], + ) + + # transform.apply_cse(func) + # canonicalize(func) + + if stop_at_stage == "inner-tiled": + raise PipelineInterrupt() + + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + # bufferize (placeholder) + # mod = apply_registered_pass(mod, "eliminate-empty-tensors") + # identity_layout = LayoutMapOption.IdentityLayoutMap + # mod = transform_bufferization.OneShotBufferizeOp( + # mod, + # allow_return_allocs_from_loops=True, + # bufferize_function_boundaries=True, + # function_boundary_type_conversion=identity_layout, + # ).result + + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + + if stop_at_stage == "xegpu-initial": + raise PipelineInterrupt() + + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + + return mod diff --git a/lighthouse/schedule/xegpu/softmax_schedule.py b/lighthouse/schedule/xegpu/softmax_schedule.py index 0907bc5a..b406b5fc 100644 --- a/lighthouse/schedule/xegpu/softmax_schedule.py +++ b/lighthouse/schedule/xegpu/softmax_schedule.py @@ -16,6 +16,7 @@ PipelineInterrupt, ) from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary +from lighthouse.dialects.transform import transform_ext def get_softmax_schedule_module( @@ -39,6 +40,7 @@ def get_softmax_schedule_module( - sg_rows: Number of rows per subgroup - subgroup_size: Size of subgroup - sizes: Tuple with the sizes of the input tensors (e.g. (M, N)) + - reduction_step_size: Optional step size for tiling reduction loops Returns: MLIR module containing the transform schedule @@ -100,6 +102,22 @@ def xegpu_softmax_transform_schedule( transform.yield_() +def match_and_print_parent_function(op, msg): + """Get the parent function of an operation and print it. + + Args: + op: The operation whose parent function to find + func_name: Name label to use when printing the function + """ + anytype = transform.AnyOpType.get() + func = transform.get_parent_op( + anytype, + op, + op_name="func.func", + deduplicate=True, + ) + transform.print_(target=func, name=msg) + def bundle_xegpu_softmax_schedule( mod: ir.Value[transform.AnyOpType], parameters: dict, @@ -118,6 +136,8 @@ def bundle_xegpu_softmax_schedule( transform.AnyOpType.get(), mod, ops=["linalg.softmax"] ) + match_and_print_parent_function(softmax_op, "initial") + # Tile the softmax operation using tile_using_forall tiled_op, for_op = structured.structured_tile_using_forall( anytype, @@ -127,6 +147,7 @@ def bundle_xegpu_softmax_schedule( tile_sizes=[], static_tile_sizes=(parameters["wg_rows"],), ) + match_and_print_parent_function(for_op, "after tiling parallel dim") func = transform.get_parent_op( anytype, @@ -140,6 +161,90 @@ def bundle_xegpu_softmax_schedule( ) structured.structured_decompose_interface(anytype, softmax_ops) + + + linalg_ops = match_and_split( + func, ops={"linalg.generic", "linalg.fill"}, nhandles=6 + ) + match_and_print_parent_function(linalg_ops[0], "after decomposing softmax") + max_reduction = linalg_ops[1] + max_center_and_exp_op = linalg_ops[2] + sum_reduction = linalg_ops[4] + div_op = linalg_ops[5] + + reduction_step_size = parameters["reduction_step_size"] + + # Tile the division op and fuse the sub+exp producer into it + _, div_loop = structured.TileUsingForOp( + div_op, sizes=[0, reduction_step_size] + ).results + # Cleanup after tiling and fusion + transform.apply_cse(func) + canonicalize(func) + match_and_print_parent_function(div_loop, "after tiling div") + + # Fuse max_center_and_exp_op into the div loop + _, fused_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=max_center_and_exp_op, + containing_op=div_loop, + ) + # Cleanup after tiling and fusion + transform.apply_cse(func) + canonicalize(func) + match_and_print_parent_function(fused_loop, "after fusing max_center_and_exp into div loop") + + # Tile the sum reduction and fuse the sub+exp producer into it + _, _, _, sum_loop = structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=sum_reduction, + tile_sizes=[0, reduction_step_size], + ) + # Cleanup after tiling and fusion + transform.apply_cse(func) + canonicalize(func) + match_and_print_parent_function(sum_loop, "after tiling sum reduction") + + func = transform.get_parent_op( + anytype, + fused_loop, + op_name="func.func", + deduplicate=True, + ) + + # Re-match and split linalg generic ops, there are 5 at this point + linalg_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=5) + max_center_and_exp_op = linalg_ops[1] + + # Fuse max_center_and_exp_op into the sum reduction loop + _, fused_sum_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=max_center_and_exp_op, + containing_op=sum_loop, + ) + # Cleanup after tiling and fusion + transform.apply_cse(func) + canonicalize(func) + match_and_print_parent_function(fused_sum_loop, "after fusing max_center_and_exp into sum reduction loop") + + # Tile the max reduction. + max_reduction = linalg_ops[0] + _, _, _, max_loop = structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=max_reduction, + tile_sizes=[0, reduction_step_size], + ) + match_and_print_parent_function(max_loop, "after tiling max reduction") + + # Cleanup after tiling and fusion transform.apply_cse(func) canonicalize(func) @@ -171,6 +276,17 @@ def bundle_xegpu_softmax_schedule( transform.apply_cse(mod) canonicalize(mod) + # promote memref.alloc to memref.alloca in payload function + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) + if stop_at_stage == "bufferized": raise PipelineInterrupt() @@ -200,6 +316,9 @@ def bundle_xegpu_softmax_schedule( mod = apply_registered_pass(mod, "gpu-kernel-outlining") transform.apply_cse(mod) + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + # set xevm target mod = apply_registered_pass( mod, @@ -207,22 +326,33 @@ def bundle_xegpu_softmax_schedule( options={"O": "3", "chip": "bmg"}, ) - # convert vector to xegpu + # for each gpu function in the gpu module, change memref.alloca address + # space to 3 (SLM) and convert vector to xegpu. gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) for gpu_mod in gpu_mod_ops: gpu_func = match(gpu_mod, ops={"gpu.func"}) + allocas = match_and_split(gpu_func, ops={"memref.alloca"}) + for alloca in allocas: + transform_ext.update_address_space(alloca, address_space=3) gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") transform.apply_cse(gpu_func) + # Cleanup. + transform.apply_cse(mod) + canonicalize(mod) + if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - # Set layout attributes for xegpu.store_nd operations. - # FIXME: currently ecah subgroup is handling the entire row. - store_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) + # Set layout attributes for xegpu.store_nd and xegpu.store_matrix ops. + store_nd_ops = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1) + store_matrix_ops = match_and_split(gpu_func, ops={"xegpu.store_matrix"}, nhandles=4) sg_layout = [parameters["sg_rows"], 1] - sg_data = [parameters["sg_rows"], parameters["sizes"][1]] - xegpu.set_anchor_layout(store_ops[0], sg_layout=sg_layout, sg_data=sg_data) + sg_data = [parameters["sg_rows"], parameters["reduction_step_size"]] + for store_op in store_nd_ops: + xegpu.set_anchor_layout(store_op, sg_layout=sg_layout, sg_data=sg_data) + for store_op in store_matrix_ops: + xegpu.set_anchor_layout(store_op, sg_layout=sg_layout, sg_data=sg_data) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() diff --git a/reduction_tiling_docs/fused_attention_tiling.md b/reduction_tiling_docs/fused_attention_tiling.md new file mode 100644 index 00000000..24375a16 --- /dev/null +++ b/reduction_tiling_docs/fused_attention_tiling.md @@ -0,0 +1,240 @@ +# Attention Tiling + +## Linalg level implementation + +Input sizes: +- Q: 64x64 +- K: 4096x64 +- V: 4096x64 + +```mlir +func.func @attention(%Q : memref<64x64xf32>, %K: memref<4096x64xf32>, +%V: memref<4096x64xf32>, %out: memref<64x64xf32>) { + ... + // Transpose K + %k_transpose = linalg.transpose ... -> tensor<64x4096xf32> + + // QK^T + %QKT = linalg.matmul ins(%q, %k_transpose : tensor<64x64xf32>, tensor<64x4096xf32>) + outs(%empty_NxN : tensor<64x4096xf32>) -> tensor<64x4096xf32> + + // Fill with -inf + %t_minf = linalg.fill ins(%cst_minus_inf : f32) outs(%empty_N : tensor<64xf32>) -> tensor<64xf32> + + // Max reduce along rows + %max = linalg.reduce ins(%QKT : tensor<64x4096xf32>) ... %m = arith.maximumf %in, %init : f32 -> tensor<64xf32> + + // Broadcast max + %maxb = linalg.broadcast ins(%max: tensor<64xf32>) outs(%empty_NxN : tensor<64x4096xf32>) dimensions = [1] -> tensor<64x4096xf32> + + // Subtract + %sub = linalg.elemwise_binary {fun = #linalg.binary_fn} ... -> tensor<64x4096xf32> + + // Exp + %exp = linalg.elemwise_unary {fun = #linalg.unary_fn} ... -> tensor<64x4096xf32> + + // Fill with zeros + %t_zeros = linalg.fill ins(%c0f : f32) outs(%empty_N : tensor<64xf32>) -> tensor<64xf32> + + // Sum reduce along rows + %sum = linalg.reduce ... %s = arith.addf %in, %init : f32 ... -> tensor<64xf32> + + // Broadcast sum and div + %sums = linalg.broadcast ... -> tensor<64x4096xf32> + %p = linalg.elemwise_binary {fun = #linalg.binary_fn
} + ins(%exp, %sums : tensor<64x4096xf32>, tensor<64x4096xf32>) ... -> tensor<64x4096xf32> + + // Final matmul + %o = linalg.matmul ins(%p, %v : tensor<64x4096xf32>, tensor<4096x64xf32>) ... -> tensor<64x64xf32> + ... +} +``` + +--- + +## Stage 1: Tile the last matmul in K dim (tile size = 16) + +After tiling the final matmul `%o = linalg.matmul ins(%p, %v)` along the K dimension with tile size 16: + +```mlir +func.func @attention(%Q : memref<64x64xf32>, %K: memref<4096x64xf32>, +%V: memref<4096x64xf32>, %out: memref<64x64xf32>) { + ... + // Compute p = Softmax(Q @ K^T) + // ... + + // Final matmul TILED in K dimension (4096 / 16 = 256 tiles) + // Loop over K dimension: k = 0 to 4096 step 16 + %c0 = arith.constant 0 : index + %c4096 = arith.constant 4096 : index + %c16 = arith.constant 16 : index + + // Initialize output with zeros: 64x64 + %o_init = linalg.fill ins(%c0f : f32) outs(%empty_out : tensor<64x64xf32>) -> tensor<64x64xf32> + + %o = scf.for %k = %c0 to %c4096 step %c16 iter_args(%o_acc = %o_init) -> (tensor<64x64xf32>) { + // Extract slice from %p: 64x16 (from columns [k:k+16]) + %p_slice = tensor.extract_slice %p[0, %k][64, 16][1, 1] -> tensor<64x16xf32> + + // Extract slice from %v: 16x64 (from rows [k:k+16]) + %v_slice = tensor.extract_slice %v[%k, 0][16, 64][1, 1] -> tensor<16x64xf32> + + // Partial matmul: (64x16) @ (16x64) -> 64x64 + %partial = linalg.matmul ins(%p_slice, %v_slice : tensor<64x16xf32>, tensor<16x64xf32>) + outs(%empty_partial : tensor<64x64xf32>) -> tensor<64x64xf32> + + // Accumulate: 64x64 + 64x64 -> 64x64 + %o_new = linalg.elemwise_binary {fun = #linalg.binary_fn} + ins(%o_acc, %partial : tensor<64x64xf32>, tensor<64x64xf32>) ... -> tensor<64x64xf32> + + scf.yield %o_new : tensor<64x64xf32> + } + ... +} +``` + +## Stage 2: Tile and fuse the softmax computation (tile size = 16) + +After tiling the softmax computation in the reduction dimension with tile size 16 and fusing operations, following the pattern from the softmax lowering flow: + +```mlir +func.func @attention(%Q : memref<64x64xf32>, %K: memref<4096x64xf32>, +%V: memref<4096x64xf32>, %out: memref<64x64xf32>) { + ... + // === First matmul: Q @ K^T === + // Transpose K: 4096x64 -> 64x4096 + %k_transpose = linalg.transpose ... -> tensor<64x4096xf32> + + // QK^T: (64x64) @ (64x4096) -> 64x4096 + %QKT = linalg.matmul ins(%q, %k_transpose : tensor<64x64xf32>, tensor<64x4096xf32>) + outs(%empty_NxN : tensor<64x4096xf32>) -> tensor<64x4096xf32> + + + // === Tiled and fused softmax computation === + // Tile size = 16, number of tiles = 4096 / 16 = 256 + %c0 = arith.constant 0 : index + %c4096 = arith.constant 4096 : index + %c16 = arith.constant 16 : index + + // Initialize max buffer with -inf: 64x16 + %max_buffer_init = linalg.fill ins(%cst_minus_inf : f32) outs(%empty_max_buf : tensor<64x16xf32>) -> tensor<64x16xf32> + + // Loop 1: Max reduction (4096 / 16 = 256 iterations) + %max_buffer = scf.for %k = %c0 to %c4096 step %c16 iter_args(%max_acc = %max_buffer_init) -> (tensor<64x16xf32>) { + // Extract slice from QKT: 64x16 + %QKT_slice = tensor.extract_slice %QKT[0, %k][64, 16][1, 1] -> tensor<64x16xf32> + + // Max accumulation: 64x16 + %max_new = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%QKT_slice : tensor<64x16xf32>) outs(%max_acc : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %max_val = arith.maxnumf %in, %out : f32 + linalg.yield %max_val : f32 + } -> tensor<64x16xf32> + + scf.yield %max_new : tensor<64x16xf32> + } + + // Final max reduction: 64x16 -> 64 + %max = linalg.reduce ins(%max_buffer : tensor<64x16xf32>) outs(%empty_N : tensor<64xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %m = arith.maxnumf %in, %init : f32 + linalg.yield %m : f32 + } + } -> tensor<64xf32> + + + // Initialize sum buffer with zeros: 64x16 + %sum_buffer_init = linalg.fill ins(%c0f : f32) outs(%empty_sum_buf : tensor<64x16xf32>) -> tensor<64x16xf32> + + // Loop 2: Sum reduction with fused center+exp (256 iterations) + %sum_buffer = scf.for %k = %c0 to %c4096 step %c16 iter_args(%sum_acc = %sum_buffer_init) -> (tensor<64x16xf32>) { + // Extract slice from QKT: 64x16 + %QKT_slice = tensor.extract_slice %QKT[0, %k][64, 16][1, 1] -> tensor<64x16xf32> + + // Fused center+exp: 64x16 + %exp_slice = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%QKT_slice, %max : tensor<64x16xf32>, tensor<64xf32>) outs(%empty_slice : tensor<64x16xf32>) { + ^bb0(%in: f32, %max_val: f32, %out: f32): + %centered = arith.subf %in, %max_val : f32 + %exp_val = math.exp %centered : f32 + linalg.yield %exp_val : f32 + } -> tensor<64x16xf32> + + // Sum accumulation: 64x16 + %sum_new = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%exp_slice : tensor<64x16xf32>) outs(%sum_acc : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %sum_val = arith.addf %in, %out : f32 + linalg.yield %sum_val : f32 + } -> tensor<64x16xf32> + + scf.yield %sum_new : tensor<64x16xf32> + } + + // Final sum reduction: 64x16 -> 64 + %sum = linalg.reduce ins(%sum_buffer : tensor<64x16xf32>) outs(%empty_N : tensor<64xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %s = arith.addf %in, %init : f32 + linalg.yield %s : f32 + } + } -> tensor<64xf32> + + + // Initialize output buffer for softmax: 64x4096 + %p_init = linalg.fill ins(%c0f : f32) outs(%empty_NxN : tensor<64x4096xf32>) -> tensor<64x4096xf32> + + // Loop 3: Division with fused center+exp+div (256 iterations) + %p = scf.for %k = %c0 to %c4096 step %c16 iter_args(%p_acc = %p_init) -> (tensor<64x4096xf32>) { + // Extract slice from QKT: 64x16 + %QKT_slice = tensor.extract_slice %QKT[0, %k][64, 16][1, 1] -> tensor<64x16xf32> + + // Fused center+exp: 64x16 + %exp_slice = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%QKT_slice, %max : tensor<64x16xf32>, tensor<64xf32>) outs(%empty_slice : tensor<64x16xf32>) { + ^bb0(%in: f32, %max_val: f32, %out: f32): + %centered = arith.subf %in, %max_val : f32 + %exp_val = math.exp %centered : f32 + linalg.yield %exp_val : f32 + } -> tensor<64x16xf32> + + // Division: 64x16 + %p_slice = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%exp_slice, %sum : tensor<64x16xf32>, tensor<64xf32>) outs(%empty_slice : tensor<64x16xf32>) { + ^bb0(%exp_val: f32, %sum_val: f32, %out: f32): + %result = arith.divf %exp_val, %sum_val : f32 + linalg.yield %result : f32 + } -> tensor<64x16xf32> + + // Insert slice back: 64x16 -> 64x4096 + %p_new = tensor.insert_slice %p_slice into %p_acc[0, %k][64, 16][1, 1] -> tensor<64x4096xf32> + + scf.yield %p_new : tensor<64x4096xf32> + } + // Result: %p contains softmax(Q @ K^T) with shape 64x4096 + + + // === Final matmul TILED in K dimension (256 iterations) === + // Initialize output with zeros: 64x64 + %o_init = linalg.fill ins(%c0f : f32) outs(%empty_out : tensor<64x64xf32>) -> tensor<64x64xf32> + + %o = scf.for %k = %c0 to %c4096 step %c16 iter_args(%o_acc = %o_init) -> (tensor<64x64xf32>) { + // Extract slice from %p: 64x16 + %p_slice = tensor.extract_slice %p[0, %k][64, 16][1, 1] -> tensor<64x16xf32> + + // Extract slice from %v: 16x64 + %v_slice = tensor.extract_slice %v[%k, 0][16, 64][1, 1] -> tensor<16x64xf32> + + // Partial matmul: (64x16) @ (16x64) -> 64x64 + %partial = linalg.matmul ins(%p_slice, %v_slice : tensor<64x16xf32>, tensor<16x64xf32>) + outs(%empty_partial : tensor<64x64xf32>) -> tensor<64x64xf32> + + // Accumulate: 64x64 + 64x64 -> 64x64 + %o_new = linalg.elemwise_binary {fun = #linalg.binary_fn} + ins(%o_acc, %partial : tensor<64x64xf32>, tensor<64x64xf32>) ... -> tensor<64x64xf32> + + scf.yield %o_new : tensor<64x64xf32> + } + ... +} +``` diff --git a/reduction_tiling_docs/softmax_lowering_flow.md b/reduction_tiling_docs/softmax_lowering_flow.md new file mode 100644 index 00000000..45dd0249 --- /dev/null +++ b/reduction_tiling_docs/softmax_lowering_flow.md @@ -0,0 +1,724 @@ +# Softmax Lowering Flow: IR Transformation Stages + +**Input Shape**: `1024x512xf32` (1024 rows, 512 columns) +**Softmax Dimension**: dim=1 (along the 512-element rows) + +--- + +## Stage 1: Initial IR + +Single high-level `linalg.softmax` operation on the full tensor. + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %1 = bufferization.to_tensor %arg1 : tensor<1024x512xf32> + + // Single softmax op over entire tensor + %3 = linalg.softmax dimension(1) ins(%1 : tensor<1024x512xf32>) + outs(%2 : tensor<1024x512xf32>) -> tensor<1024x512xf32> + + bufferization.materialize_in_destination %3 in %arg0 +} +``` + +--- + +## Stage 2: After Tiling Parallel Dim + +Parallel dimension (rows) tiled into 16 chunks of 64 rows each. Introduces `scf.forall` for parallel execution. + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + // Parallel loop over 16 tiles (1024 / 64 = 16) + %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x512xf32>) { + %4 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2) + %slice = tensor.extract_slice %1[%4, 0] [64, 512] [1, 1] + + // Softmax on 64x512 slice + %5 = linalg.softmax dimension(1) ins(%slice : tensor<64x512xf32>) + outs(%slice_0 : tensor<64x512xf32>) -> tensor<64x512xf32> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %5 into %arg3[%4, 0] [64, 512] [1, 1] + } + } +} +``` + +--- + +## Stage 3: After Decomposing Softmax + +Softmax decomposed into 4 operations: max reduction → center+exp → sum reduction → division. + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x512xf32>) { + %slice = tensor.extract_slice %1[%4, 0] [64, 512] [1, 1] + + // 1. Max reduction: (64,512) -> (64,) + %7 = linalg.generic {indexing_maps = [map<(d0,d1) -> (d0,d1)>, map<(d0,d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%slice : tensor<64x512xf32>) outs(%6 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %12 = arith.maxnumf %in, %out : f32 + linalg.yield %12 : f32 + } -> tensor<64xf32> + + // 2. Center and exp: (64,512) -> (64,512) + %8 = linalg.generic {indexing_maps = [map<(d0,d1) -> (d0,d1)>, map<(d0,d1) -> (d0)>, map<(d0,d1) -> (d0,d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%slice, %7 : tensor<64x512xf32>, tensor<64xf32>) outs(%slice_0 : tensor<64x512xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %12 = arith.subf %in, %in_2 : f32 + %13 = math.exp %12 : f32 + linalg.yield %13 : f32 + } -> tensor<64x512xf32> + + // 3. Sum reduction: (64,512) -> (64,) + %10 = linalg.generic {indexing_maps = [map<(d0,d1) -> (d0,d1)>, map<(d0,d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%8 : tensor<64x512xf32>) outs(%9 : tensor<64xf32>) { + ^bb0(%in: f32, %out: f32): + %12 = arith.addf %in, %out : f32 + linalg.yield %12 : f32 + } -> tensor<64xf32> + + // 4. Division: (64,512) -> (64,512) + %11 = linalg.generic {indexing_maps = [map<(d0,d1) -> (d0,d1)>, map<(d0,d1) -> (d0)>, map<(d0,d1) -> (d0,d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%8, %10 : tensor<64x512xf32>, tensor<64xf32>) outs(%slice_0 : tensor<64x512xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %12 = arith.divf %in, %in_2 : f32 + linalg.yield %12 : f32 + } -> tensor<64x512xf32> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %11 into %arg3[%4, 0] [64, 512] [1, 1] + } + } +} +``` + +--- + +## Stage 4: After Tiling Division + +Division operation tiled along dimension 1 into chunks of 16 columns. + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x512xf32>) { + %slice = tensor.extract_slice %0[%3, 0] [64, 512] [1, 1] + + // Max reduction (64,512) -> (64,) + %6 = linalg.generic {iterator_types = ["parallel", "reduction"]} + ins(%slice) outs(%5) { maxnumf } -> tensor<64xf32> + + // Center and exp (64,512) -> (64,512) + %7 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice, %6) outs(%slice_1) { subf, exp } -> tensor<64x512xf32> + + // Sum reduction (64,512) -> (64,) + %9 = linalg.generic {iterator_types = ["parallel", "reduction"]} + ins(%7) outs(%8) { addf } -> tensor<64xf32> + + // Division tiled over columns: loop from 0 to 512 step 16 + %10 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %slice_1) -> (tensor<64x512xf32>) { + %slice_2 = tensor.extract_slice %7[0, %arg4] [64, 16] [1, 1] + %slice_3 = tensor.extract_slice %arg5[0, %arg4] [64, 16] [1, 1] + + // Division on 64x16 tile + %11 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_2, %9 : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_3 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_4: f32, %out: f32): + %12 = arith.divf %in, %in_4 : f32 + linalg.yield %12 : f32 + } -> tensor<64x16xf32> + + %inserted = tensor.insert_slice %11 into %arg5[0, %arg4] [64, 16] [1, 1] + scf.yield %inserted : tensor<64x512xf32> + } + + scf.forall.in_parallel { + tensor.parallel_insert_slice %10 into %arg3[%3, 0] [64, 512] [1, 1] + } + } +} +``` + +--- + +## Stage 5: After Fusing Max+Center+Exp into Division Loop + +The center-and-exp computation is fused into the division loop to recompute values on-the-fly. + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x512xf32>) { + %slice = tensor.extract_slice %0[%3, 0] [64, 512] [1, 1] + + // Max reduction (64,512) -> (64,) + %6 = linalg.generic {iterator_types = ["parallel", "reduction"]} + ins(%slice) outs(%5) { maxnumf } -> tensor<64xf32> + + // Center and exp (still materialized for sum reduction) + %7 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice, %6) outs(%slice_1) { subf, exp } -> tensor<64x512xf32> + + // Sum reduction (64,512) -> (64,) + %9 = linalg.generic {iterator_types = ["parallel", "reduction"]} + ins(%7) outs(%8) { addf } -> tensor<64xf32> + + // Division loop with fused center+exp+div + %10 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %slice_1) -> (tensor<64x512xf32>) { + %slice_2 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] // from original input + %slice_3 = tensor.extract_slice %arg5[0, %arg4] [64, 16] [1, 1] + + // Fused center+exp on 64x16 tile + %11 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_2, %6 : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_3 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_4: f32, %out: f32): + %13 = arith.subf %in, %in_4 : f32 + %14 = math.exp %13 : f32 + linalg.yield %14 : f32 + } -> tensor<64x16xf32> + + // Division on 64x16 tile + %12 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%11, %9 : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_3 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_4: f32, %out: f32): + %13 = arith.divf %in, %in_4 : f32 + linalg.yield %13 : f32 + } -> tensor<64x16xf32> + + %inserted = tensor.insert_slice %12 into %arg5[0, %arg4] [64, 16] [1, 1] + scf.yield %inserted : tensor<64x512xf32> + } + + scf.forall.in_parallel { + tensor.parallel_insert_slice %10 into %arg3[%3, 0] [64, 512] [1, 1] + } + } +} +``` + + +--- + +## Stage 6: After Tiling Sum Reduction + +Sum reduction tiled into chunks of 16 columns, introducing partial sums followed by a final reduction. + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x512xf32>) { + %slice = tensor.extract_slice %0[%3, 0] [64, 512] [1, 1] + + // Max reduction (64,512) -> (64,) + %6 = linalg.generic {iterator_types = ["parallel", "reduction"]} + ins(%slice) outs(%5) { maxnumf } -> tensor<64xf32> + + // Center and exp (64,512) -> (64,512) + %7 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice, %6) outs(%slice_1) { subf, exp } -> tensor<64x512xf32> + + // Tiled sum reduction: accumulate into 64x16 buffer + %11 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %10) -> (tensor<64x16xf32>) { + %slice_2 = tensor.extract_slice %7[0, %arg4] [64, 16] [1, 1] + + // Accumulate sums + %13 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_2 : tensor<64x16xf32>) outs(%arg5 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %14 = arith.addf %in, %out : f32 + linalg.yield %14 : f32 + } -> tensor<64x16xf32> + + scf.yield %13 : tensor<64x16xf32> + } + + // Final reduction: (64,16) -> (64,) + %reduced = linalg.reduce ins(%11 : tensor<64x16xf32>) outs(%8 : tensor<64xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %13 = arith.addf %in, %init : f32 + linalg.yield %13 : f32 + } + } + + // Division loop (same as before) + %12 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %slice_1) -> (tensor<64x512xf32>) { + // ... fused center+exp+div ... + } + + scf.forall.in_parallel { + tensor.parallel_insert_slice %12 into %arg3[%3, 0] [64, 512] [1, 1] + } + } +} +``` + +--- + +## Stage 7: After Fusing Max+Center+Exp into Sum Reduction Loop + +The center-and-exp computation is now fused into the sum reduction loop as well. + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x512xf32>) { + %slice = tensor.extract_slice %0[%3, 0] [64, 512] [1, 1] + + // Max reduction (64,512) -> (64,) + %6 = linalg.generic {iterator_types = ["parallel", "reduction"]} + ins(%slice) outs(%5) { maxnumf } -> tensor<64xf32> + + // Sum reduction loop with fused center+exp + %10 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %9) -> (tensor<64x16xf32>) { + %slice_2 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Fused center+exp + %12 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_2, %6 : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_3 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_4: f32, %out: f32): + %14 = arith.subf %in, %in_4 : f32 + %15 = math.exp %14 : f32 + linalg.yield %15 : f32 + } -> tensor<64x16xf32> + + // Accumulate into sum buffer + %13 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%12 : tensor<64x16xf32>) outs(%arg5 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %14 = arith.addf %in, %out : f32 + linalg.yield %14 : f32 + } -> tensor<64x16xf32> + + scf.yield %13 : tensor<64x16xf32> + } + + // Final reduction: (64,16) -> (64,) + %reduced = linalg.reduce ins(%10 : tensor<64x16xf32>) outs(%7 : tensor<64xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %12 = arith.addf %in, %init : f32 + linalg.yield %12 : f32 + } + } + + // Division loop with fused center+exp+div + %11 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %slice_1) -> (tensor<64x512xf32>) { + %slice_2 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Fused center+exp + %12 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_2, %6 : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_3 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_4: f32, %out: f32): + %14 = arith.subf %in, %in_4 : f32 + %15 = math.exp %14 : f32 + linalg.yield %15 : f32 + } -> tensor<64x16xf32> + + // Division + %13 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%12, %reduced : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_3 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_4: f32, %out: f32): + %14 = arith.divf %in, %in_4 : f32 + linalg.yield %14 : f32 + } -> tensor<64x16xf32> + + %inserted = tensor.insert_slice %13 into %arg5[0, %arg4] [64, 16] [1, 1] + scf.yield %inserted : tensor<64x512xf32> + } + + scf.forall.in_parallel { + tensor.parallel_insert_slice %11 into %arg3[%3, 0] [64, 512] [1, 1] + } + } +} +``` + +--- + +## Stage 8: After Tiling Max Reduction + +Max reduction also tiled into 16-column chunks with partial max followed by final reduction. + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x512xf32>) { + %slice = tensor.extract_slice %0[%3, 0] [64, 512] [1, 1] + + // Tiled max reduction: accumulate into 64x16 buffer + %8 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %7) -> (tensor<64x16xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Max accumulation + %14 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7 : tensor<64x16xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %15 = arith.maxnumf %in, %out : f32 + linalg.yield %15 : f32 + } -> tensor<64x16xf32> + + %inserted = tensor.insert_slice %14 into %arg5[0, 0] [64, 16] [1, 1] + scf.yield %inserted : tensor<64x16xf32> + } + + // Final max reduction: (64,16) -> (64,) + %reduced = linalg.reduce ins(%8 : tensor<64x16xf32>) outs(%5 : tensor<64xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %14 = arith.maxnumf %in, %init : f32 + linalg.yield %14 : f32 + } + } + + // Sum reduction loop with fused center+exp + %12 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %11) -> (tensor<64x16xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Fused center+exp using reduced max + %14 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7, %reduced : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %16 = arith.subf %in, %in_9 : f32 + %17 = math.exp %16 : f32 + linalg.yield %17 : f32 + } -> tensor<64x16xf32> + + // Sum accumulation + %15 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%14 : tensor<64x16xf32>) outs(%arg5 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %16 = arith.addf %in, %out : f32 + linalg.yield %16 : f32 + } -> tensor<64x16xf32> + + scf.yield %15 : tensor<64x16xf32> + } + + // Final sum reduction: (64,16) -> (64,) + %reduced_6 = linalg.reduce ins(%12 : tensor<64x16xf32>) outs(%9 : tensor<64xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %14 = arith.addf %in, %init : f32 + linalg.yield %14 : f32 + } + } + + // Division loop with fused center+exp+div + %13 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %slice_1) -> (tensor<64x512xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Fused center+exp + %14 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7, %reduced : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %16 = arith.subf %in, %in_9 : f32 + %17 = math.exp %16 : f32 + linalg.yield %17 : f32 + } -> tensor<64x16xf32> + + // Division + %15 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%14, %reduced_6 : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %16 = arith.divf %in, %in_9 : f32 + linalg.yield %16 : f32 + } -> tensor<64x16xf32> + + %inserted = tensor.insert_slice %15 into %arg5[0, %arg4] [64, 16] [1, 1] + scf.yield %inserted : tensor<64x512xf32> + } + + scf.forall.in_parallel { + tensor.parallel_insert_slice %13 into %arg3[%3, 0] [64, 512] [1, 1] + } + } +} +``` + +--- + +## Stage 9: Final Vectorized XeGPU Version + +After vectorization, bufferization, and conversion to XeGPU operations. Uses shared local memory (SLM) for partial reductions. + +```mlir +gpu.module @payload_kernel { + gpu.func @payload_kernel(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) kernel { + %block_id_x = gpu.block_id x + %0 = arith.muli %block_id_x, %c64 : index + %subview = memref.subview %arg0[%0, 0] [64, 512] [1, 1] + + // Allocate SLM buffer for partial reductions + %alloca = memref.alloca() : memref<64x16xf32, 3> + %1 = xegpu.create_mem_desc %alloca : !xegpu.mem_desc<64x16xf32> + + // Max reduction loop + xegpu.store_matrix %cst_2, %1[0, 0] // init with -inf + scf.for %arg2 = %c0 to %c512 step %c16 { + // Load 64x16 tile from global memory + %6 = xegpu.create_nd_tdesc %arg1 : !xegpu.tensor_desc<64x16xf32> + %7 = xegpu.load_nd %6[%0, %arg2] : vector<64x16xf32> + + // Load partial max from SLM, compute max, store back + %8 = xegpu.load_matrix %1[0, 0] : vector<64x16xf32> + %9 = arith.maxnumf %7, %8 : vector<64x16xf32> + xegpu.store_matrix %9, %1[0, 0] + } + + // Final max reduction across 16 columns + %2 = xegpu.load_matrix %1[0, 0] : vector<64x16xf32> + %3 = vector.multi_reduction , %2, %cst_1 [1] : vector<64x16xf32> to vector<64xf32> + + // Sum reduction loop + xegpu.store_matrix %cst_0, %1[0, 0] // init with 0.0 + scf.for %arg2 = %c0 to %c512 step %c16 { + // Load 64x16 tile + %6 = xegpu.create_nd_tdesc %arg1 : !xegpu.tensor_desc<64x16xf32> + %7 = xegpu.load_nd %6[%0, %arg2] : vector<64x16xf32> + + // Fused center+exp + %8 = vector.broadcast %3 : vector<64xf32> to vector<16x64xf32> + %9 = vector.transpose %8, [1, 0] : vector<64x16xf32> + %10 = arith.subf %7, %9 : vector<64x16xf32> + %11 = math.exp %10 : vector<64x16xf32> + + // Accumulate sum in SLM + %12 = xegpu.load_matrix %1[0, 0] : vector<64x16xf32> + %13 = arith.addf %11, %12 : vector<64x16xf32> + xegpu.store_matrix %13, %1[0, 0] + } + + // Final sum reduction across 16 columns + %4 = xegpu.load_matrix %1[0, 0] : vector<64x16xf32> + %5 = vector.multi_reduction , %4, %cst [1] : vector<64x16xf32> to vector<64xf32> + + // Division loop + scf.for %arg2 = %c0 to %c512 step %c16 { + // Load 64x16 tile + %6 = xegpu.create_nd_tdesc %arg1 : !xegpu.tensor_desc<64x16xf32> + %7 = xegpu.load_nd %6[%0, %arg2] : vector<64x16xf32> + + // Fused center+exp + %8 = vector.broadcast %3 : vector<64xf32> to vector<16x64xf32> + %9 = vector.transpose %8, [1, 0] : vector<64x16xf32> + %10 = arith.subf %7, %9 : vector<64x16xf32> + %11 = math.exp %10 : vector<64x16xf32> + + // Division + %12 = vector.broadcast %5 : vector<64xf32> to vector<16x64xf32> + %13 = vector.transpose %12, [1, 0] : vector<64x16xf32> + %14 = arith.divf %11, %13 : vector<64x16xf32> + + // Store result to global memory + %18 = xegpu.create_nd_tdesc %intptr : !xegpu.tensor_desc<64x16xf32> + xegpu.store_nd %14, %18[0, %arg2] + } + + gpu.return + } +} +``` + +--- + +## Summary of Transformations + +| Stage | Key Transformation | Loop Structure | +|-------|-------------------|----------------| +| 1 | Initial high-level softmax | No loops | +| 2 | Tile parallel dimension | `scf.forall(16)` | +| 3 | Decompose softmax | `scf.forall(16)` + 4 sequential ops | +| 4 | Tile division | `scf.forall(16)` → `scf.for(32)` | +| 5 | Fuse into division loop | Recompute center+exp in div loop | +| 6 | Tile sum reduction | Add sum loop + final reduction | +| 7 | Fuse into sum loop | Recompute center+exp in sum loop | +| 8 | Tile max reduction | Add max loop + final reduction | +| 9 | Vectorize + XeGPU | GPU kernel with SLM and vector ops | + +**Final computation pattern per GPU block:** +1. **Max reduction**: 32-iteration loop with SLM accumulation → final reduction +2. **Sum reduction**: 32-iteration loop (fused center+exp) with SLM accumulation → final reduction +3. **Division**: 32-iteration loop (fused center+exp+div) writing to global memory + +--- + +## Optimization: Fusing Max and Sum Reduction Loops + +After Stage 8, we can apply an additional optimization to fuse the max reduction loop and sum reduction loop into a single loop. This reduces the number of loops from 3 to 2. + +### Key Insight + +The optimization leverages the **online softmax algorithm**, which allows us to incrementally update both the global maximum and the global sum as we process each tile of the reduction dimension. For each 16-column tile: + +1. Compute the **local max** for the tile +2. Update the **global max** using the local max +3. Compute the **local centered sum** using exp(x - local_max) +4. **Rescale** the global sum by exp(global_max_old - global_max_new) +5. **Add** the rescaled local sum to the global sum + +This maintains numerical stability while processing tiles incrementally, since we adjust previous sums by the correction factor when we discover a new maximum. + +### Before: Separate Max and Sum Loops (3 loops total) + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x512xf32>) { + %slice = tensor.extract_slice %0[%3, 0] [64, 512] [1, 1] + + // Loop 1: Max reduction + %8 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %7) -> (tensor<64x16xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + %14 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7 : tensor<64x16xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %15 = arith.maxnumf %in, %out : f32 + linalg.yield %15 : f32 + } -> tensor<64x16xf32> + + scf.yield %14 : tensor<64x16xf32> + } + %reduced = linalg.reduce ins(%8) outs(%5) dimensions = [1] { maxnumf } + + // Loop 2: Sum reduction with center+exp + %12 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %11) -> (tensor<64x16xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Center+exp using global max + %14 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7, %reduced : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8) { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %16 = arith.subf %in, %in_9 : f32 + %17 = math.exp %16 : f32 + linalg.yield %17 : f32 + } -> tensor<64x16xf32> + + // Accumulate sum + %15 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%14) outs(%arg5) { + ^bb0(%in: f32, %out: f32): + %16 = arith.addf %in, %out : f32 + linalg.yield %16 : f32 + } -> tensor<64x16xf32> + + scf.yield %15 : tensor<64x16xf32> + } + %reduced_6 = linalg.reduce ins(%12) outs(%9) dimensions = [1] { addf } + + // Loop 3: Division with center+exp+div + %13 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %slice_1) -> (tensor<64x512xf32>) { + // ... fused center+exp+div ... + } + } +} +``` + +### After: Fused Max+Sum Loop (2 loops total) + +```mlir +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x512xf32>) { + %slice = tensor.extract_slice %0[%3, 0] [64, 512] [1, 1] + + // Loop 1: Fused max+sum reduction (online softmax) + %fused = scf.for %arg4 = %c0 to %c512 step %c16 + iter_args(%global_max = %init_max, %global_sum_buffer = %7) + -> (tensor<64xf32>, tensor<64x16xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Step 1: Compute local max for this tile + %local_max = linalg.reduce ins(%slice_7 : tensor<64x16xf32>) outs(%5 : tensor<64xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %max = arith.maxnumf %in, %init : f32 + linalg.yield %max : f32 + } + } + + // Step 2: Update global max + %new_global_max = linalg.generic {iterator_types = ["parallel"]} + ins(%global_max, %local_max : tensor<64xf32>, tensor<64xf32>) outs(%out_max : tensor<64xf32>) { + ^bb0(%old_max: f32, %curr_max: f32, %out: f32): + %updated = arith.maxnumf %old_max, %curr_max : f32 + linalg.yield %updated : f32 + } -> tensor<64xf32> + + // Step 3: Compute local centered sum: exp(x - local_max) + %local_exp = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7, %local_max : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %max: f32, %out: f32): + %centered = arith.subf %in, %max : f32 + %exp_val = math.exp %centered : f32 + linalg.yield %exp_val : f32 + } -> tensor<64x16xf32> + + // Reduce to get tile sum + %local_sum_buffer = linalg.reduce ins(%local_exp : tensor<64x16xf32>) outs(%9 : tensor<64x16xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %sum = arith.addf %in, %init : f32 + linalg.yield %sum : f32 + } + } + + // Step 4: Rescale global sum by exp(old_max - new_max) and add local sum + %updated_global_sum = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%global_sum_buffer, %global_max, %new_global_max, %local_sum_buffer : + tensor<64x16xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64x16xf32>) + outs(%out_sum : tensor<64x16xf32>) { + ^bb0(%old_sum: f32, %old_max: f32, %new_max: f32, %local: f32, %out: f32): + // Correction factor: exp(old_max - new_max) + %max_diff = arith.subf %old_max, %new_max : f32 + %scale = math.exp %max_diff : f32 + %rescaled_sum = arith.mulf %old_sum, %scale : f32 + + // Add local sum (already centered on local_max, need to rescale) + %local_scale_diff = arith.subf %local_max, %new_max : f32 + %local_scale = math.exp %local_scale_diff : f32 + %rescaled_local = arith.mulf %local, %local_scale : f32 + + %updated = arith.addf %rescaled_sum, %rescaled_local : f32 + linalg.yield %updated : f32 + } -> tensor<64x16xf32> + + scf.yield %new_global_max, %updated_global_sum : tensor<64xf32>, tensor<64x16xf32> + } + + // Extract final results + %final_max = %fused#0 : tensor<64xf32> + %final_sum_buffer = %fused#1 : tensor<64x16xf32> + %final_sum = linalg.reduce ins(%final_sum_buffer) outs(%9) dimensions = [1] { addf } + + // Loop 2: Division with center+exp+div + %13 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %slice_1) -> (tensor<64x512xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Center+exp + %14 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7, %final_max : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8) { + ^bb0(%in: f32, %max: f32, %out: f32): + %centered = arith.subf %in, %max : f32 + %exp_val = math.exp %centered : f32 + linalg.yield %exp_val : f32 + } -> tensor<64x16xf32> + + // Division + %15 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%14, %final_sum : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8) { + ^bb0(%exp_val: f32, %sum: f32, %out: f32): + %result = arith.divf %exp_val, %sum : f32 + linalg.yield %result : f32 + } -> tensor<64x16xf32> + + %inserted = tensor.insert_slice %15 into %arg5[0, %arg4] [64, 16] [1, 1] + scf.yield %inserted : tensor<64x512xf32> + } + } +} +``` + +### Benefits + +1. **Reduced loop count**: 3 loops → 2 loops (fused max+sum, division) +2. **Better memory locality**: Input data is read only twice instead of three times + +This optimization is particularly valuable for GPU implementations where memory bandwidth is the bottleneck, as reducing the number of passes over the input data can significantly improve performance despite the increased computational complexity per iteration. diff --git a/reduction_tiling_docs/softmax_lowering_flow.pdf b/reduction_tiling_docs/softmax_lowering_flow.pdf new file mode 100644 index 00000000..ed96b0bb Binary files /dev/null and b/reduction_tiling_docs/softmax_lowering_flow.pdf differ diff --git a/reduction_tiling_docs/softmax_lowering_flow.tex b/reduction_tiling_docs/softmax_lowering_flow.tex new file mode 100644 index 00000000..bc378c29 --- /dev/null +++ b/reduction_tiling_docs/softmax_lowering_flow.tex @@ -0,0 +1,567 @@ +\documentclass[aspectratio=169]{beamer} +\usepackage[utf8]{inputenc} +\usepackage{listings} +\usepackage{xcolor} +\usepackage{amsmath} +\usepackage{booktabs} + +\usetheme{default} +\usecolortheme{default} + +% MLIR syntax highlighting - simplified for slides +\lstdefinelanguage{MLIR}{ + morekeywords={func, memref, tensor, linalg, scf, forall, for, arith, math, gpu, xegpu, vector}, + sensitive=true, + morecomment=[l]{//}, + morestring=[b]", +} + +\lstset{ + language=MLIR, + basicstyle=\ttfamily\small, + keywordstyle=\color{blue}\bfseries, + commentstyle=\color{gray}\itshape, + stringstyle=\color{red}, + numbers=none, + backgroundcolor=\color{white}, + showspaces=false, + showstringspaces=false, + showtabs=false, + frame=single, + tabsize=2, + breaklines=true, + breakatwhitespace=false, + escapeinside={\%*}{*)} +} + +\title{Softmax Lowering Flow: IR Transformation Stages} +\author{} +\date{} + +\begin{document} + +\begin{frame} +\titlepage +\end{frame} + +\begin{frame}{Problem Setup} +\textbf{Input Shape}: $1024 \times 512 \times f32$ (1024 rows, 512 columns) + +\textbf{Softmax Dimension}: dim=1 (along the 512-element rows) + +\vspace{0.5cm} + +\textbf{Softmax Formula}: +$$\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$ +\end{frame} + +\begin{frame}[fragile]{Stage 1: Initial IR} +Single high-level \texttt{linalg.softmax} operation on the full tensor. + +\begin{lstlisting} +// Pseudo code: +output = softmax(input[1024x512], dim=1) +\end{lstlisting} + +\vspace{0.3cm} + +\textbf{Key Points}: +\begin{itemize} +\item Single operation on entire tensor +\item No parallelism or tiling yet +\end{itemize} +\end{frame} + +\begin{frame}[fragile]{Stage 2: After Tiling Parallel Dimension} +Parallel dimension (rows) tiled into 16 chunks of 64 rows each. + +\begin{lstlisting} +// Pseudo code: +parallel for tile_id in [0..16): + row_offset = tile_id * 64 + slice = input[row_offset:row_offset+64, 0:512] + output_slice = softmax(slice, dim=1) + output[row_offset:row_offset+64, :] = output_slice +\end{lstlisting} + +\vspace{0.3cm} + +\textbf{Key Points}: +\begin{itemize} +\item 16 parallel tiles: $1024 / 64 = 16$ +\item Each tile: $64 \times 512$ +\item Introduces \texttt{scf.forall} for parallel execution +\end{itemize} +\end{frame} + +\begin{frame}[fragile]{Stage 3: After Decomposing Softmax} +Softmax decomposed into 4 operations: +\textbf{max} $\rightarrow$ \textbf{center+exp} $\rightarrow$ \textbf{sum} $\rightarrow$ \textbf{division} + +\begin{lstlisting} +// Pseudo code: +parallel for tile_id in [0..16): + slice = input[tile_id*64:(tile_id+1)*64, :] + + // Step 1: Max reduction (64,512) -> (64,) + max_vals = reduce_max(slice, dim=1) + + // Step 2: Center and exp (64,512) -> (64,512) + exp_vals = exp(slice - max_vals) + + // Step 3: Sum reduction (64,512) -> (64,) + sum_vals = reduce_sum(exp_vals, dim=1) + + // Step 4: Division (64,512) -> (64,512) + output_slice = exp_vals / sum_vals +\end{lstlisting} +\end{frame} + +\begin{frame}[fragile]{Stage 4: After Tiling Division} +Division operation tiled along dimension 1 into chunks of 16 columns. + +\begin{lstlisting} +// Pseudo code: +parallel for tile_id in [0..16): + slice = input[tile_id*64:(tile_id+1)*64, :] + + max_vals = reduce_max(slice, dim=1) // (64,512)->(64,) + exp_vals = exp(slice - max_vals) // (64,512)->(64,512) + sum_vals = reduce_sum(exp_vals, dim=1) // (64,512)->(64,) + + // Division tiled: 32 iterations (512/16 = 32) + for col_offset in [0:512:16]: + exp_tile = exp_vals[:, col_offset:col_offset+16] + output[:, col_offset:col_offset+16] = exp_tile / sum_vals +\end{lstlisting} + +\vspace{0.2cm} +\textbf{Key}: Division loop operates on $64 \times 16$ tiles +\end{frame} + +\begin{frame}[fragile]{Stage 5: Fusing Center+Exp into Division Loop} +Recompute center-and-exp on-the-fly in division loop. + +\begin{lstlisting} +// Pseudo code: +parallel for tile_id in [0..16): + slice = input[tile_id*64:(tile_id+1)*64, :] + + max_vals = reduce_max(slice, dim=1) // (64,) + + // Still materialized for sum reduction + exp_vals = exp(slice - max_vals) // (64,512) + sum_vals = reduce_sum(exp_vals, dim=1) // (64,) + + // Division loop: recompute exp on-the-fly + for col_offset in [0:512:16]: + input_tile = slice[:, col_offset:col_offset+16] + exp_tile = exp(input_tile - max_vals) // Recomputed + output[:, col_offset:col_offset+16] = exp_tile / sum_vals +\end{lstlisting} + +\vspace{0.2cm} +\textbf{Benefit}: Reduces memory footprint (partial recomputation) +\end{frame} + +\begin{frame}[fragile]{Stage 6: After Tiling Sum Reduction} +Sum reduction tiled into 16-column chunks with partial sums. + +\begin{lstlisting} +// Pseudo code: +parallel for tile_id in [0..16): + slice = input[tile_id*64:(tile_id+1)*64, :] + + max_vals = reduce_max(slice, dim=1) // (64,) + exp_vals = exp(slice - max_vals) // (64,512) + + // Tiled sum reduction: accumulate into buffer (64,16) + sum_buffer = zeros(64, 16) + for col_offset in [0:512:16]: + exp_tile = exp_vals[:, col_offset:col_offset+16] + sum_buffer += exp_tile // Accumulate + + sum_vals = reduce_sum(sum_buffer, dim=1) // (64,16)->(64,) + + // Division loop (same as before) + for col_offset in [0:512:16]: ... +\end{lstlisting} + +\vspace{0.2cm} +\textbf{Key}: Sum uses partial accumulation buffer +\end{frame} + +\begin{frame}[fragile]{Stage 7: Fusing into Sum Reduction Loop} +Fuse center-and-exp into sum reduction loop as well. + +\begin{lstlisting} +// Pseudo code: +parallel for tile_id in [0..16): + slice = input[tile_id*64:(tile_id+1)*64, :] + + max_vals = reduce_max(slice, dim=1) // (64,) + + // Sum reduction with fused center+exp + sum_buffer = zeros(64, 16) + for col_offset in [0:512:16]: + input_tile = slice[:, col_offset:col_offset+16] + exp_tile = exp(input_tile - max_vals) // Fused + sum_buffer += exp_tile + + sum_vals = reduce_sum(sum_buffer, dim=1) // (64,) + + // Division loop with fused center+exp+div + for col_offset in [0:512:16]: + input_tile = slice[:, col_offset:col_offset+16] + exp_tile = exp(input_tile - max_vals) // Recomputed + output[:, col_offset:col_offset+16] = exp_tile / sum_vals +\end{lstlisting} +\end{frame} + +\section{Stage 8: After Tiling Max Reduction} + +Max reduction also tiled into 16-column chunks with partial max followed by final reduction. + +\begin{lstlisting} +func.func @payload(%arg0: memref<1024x512xf32>, %arg1: memref<1024x512xf32>) { + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x512xf32>) { + %slice = tensor.extract_slice %0[%3, 0] [64, 512] [1, 1] + + // Tiled max reduction: accumulate into 64x16 buffer + %8 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %7) -> (tensor<64x16xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Max accumulation + %14 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7 : tensor<64x16xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %15 = arith.maxnumf %in, %out : f32 + linalg.yield %15 : f32 + } -> tensor<64x16xf32> + + %inserted = tensor.insert_slice %14 into %arg5[0, 0] [64, 16] [1, 1] + scf.yield %inserted : tensor<64x16xf32> + } + + // Final max reduction: (64,16) -> (64,) + %reduced = linalg.reduce ins(%8 : tensor<64x16xf32>) outs(%5 : tensor<64xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %14 = arith.maxnumf %in, %init : f32 + linalg.yield %14 : f32 + } + } + + // Sum reduction loop with fused center+exp + %12 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %11) -> (tensor<64x16xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Fused center+exp using reduced max + %14 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7, %reduced : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %16 = arith.subf %in, %in_9 : f32 + %17 = math.exp %16 : f32 + linalg.yield %17 : f32 + } -> tensor<64x16xf32> + + // Sum accumulation + %15 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%14 : tensor<64x16xf32>) outs(%arg5 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %16 = arith.addf %in, %out : f32 + linalg.yield %16 : f32 + } -> tensor<64x16xf32> + + scf.yield %15 : tensor<64x16xf32> + } + + // Final sum reduction: (64,16) -> (64,) + %reduced_6 = linalg.reduce ins(%12 : tensor<64x16xf32>) outs(%9 : tensor<64xf32>) dimensions = [1] { + (%in: f32, %init: f32) { + %14 = arith.addf %in, %init : f32 + linalg.yield %14 : f32 + } + } + + // Division loop with fused center+exp+div + %13 = scf.for %arg4 = %c0 to %c512 step %c16 iter_args(%arg5 = %slice_1) -> (tensor<64x512xf32>) { + %slice_7 = tensor.extract_slice %slice[0, %arg4] [64, 16] [1, 1] + + // Fused center+exp + %14 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%slice_7, %reduced : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %16 = arith.subf %in, %in_9 : f32 + %17 = math.exp %16 : f32 + linalg.yield %17 : f32 + } -> tensor<64x16xf32> + + // Division + %15 = linalg.generic {iterator_types = ["parallel", "parallel"]} + ins(%14, %reduced_6 : tensor<64x16xf32>, tensor<64xf32>) outs(%slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %16 = arith.divf %in, %in_9 : f32 + linalg.yield %16 : f32 + } -> tensor<64x16xf32> + + %inserted = tensor.insert_slice %15 into %arg5[0, %arg4] [64, 16] [1, 1] + scf.yield %inserted : tensor<64x512xf32> + } + + scf.forall.in_parallel { + tensor.parallel_insert_slice %13 into %arg3[%3, 0] [64, 512] [1, 1] + } + } +} +\end{lstlisting} + +\begin{frame}[fragile]{Stage 8: After Tiling Max Reduction (Part 1)} +Max reduction tiled with partial buffers, creating 3 loops total. + +\begin{lstlisting} +// Pseudo code (continued from Stage 7): +parallel for tile_id in [0..16): + slice = input[tile_id*64:(tile_id+1)*64, :] + + // Loop 1: Tiled max reduction + max_buffer = fill(-inf, 64, 16) + for col_offset in [0:512:16]: + input_tile = slice[:, col_offset:col_offset+16] + max_buffer = max(max_buffer, input_tile) + max_vals = reduce_max(max_buffer, dim=1) // (64,) + + // Loop 2: Sum reduction with fused center+exp + sum_buffer = zeros(64, 16) + for col_offset in [0:512:16]: + input_tile = slice[:, col_offset:col_offset+16] + exp_tile = exp(input_tile - max_vals) + sum_buffer += exp_tile + sum_vals = reduce_sum(sum_buffer, dim=1) // (64,) +\end{lstlisting} +\end{frame} + +\begin{frame}[fragile]{Stage 8: After Tiling Max Reduction (Part 2)} +\begin{lstlisting} + // Loop 3: Division with fused center+exp+div + for col_offset in [0:512:16]: + input_tile = slice[:, col_offset:col_offset+16] + exp_tile = exp(input_tile - max_vals) + output[:, col_offset:col_offset+16] = exp_tile / sum_vals +\end{lstlisting} + +\vspace{0.3cm} + +\textbf{Key Points}: +\begin{itemize} +\item All 3 operations (max, sum, div) are now tiled +\item 3 separate loops over columns (32 iterations each) +\item Each loop processes $64 \times 16$ tiles +\end{itemize} +\end{frame} + +\begin{frame}[fragile]{Stage 9: Final Vectorized XeGPU Version} +After vectorization, bufferization, and XeGPU lowering. + +\begin{lstlisting} +// Pseudo code: +gpu.kernel: + block_id = get_block_id() + slice = input[block_id*64:(block_id+1)*64, :] + + // Allocate SLM buffer (64x16) for partial reductions + slm_buffer = alloc_shared_memory(64, 16) + + // Loop 1: Max reduction (32 iterations) + slm_buffer = fill(-inf) + for col_offset in [0:512:16]: + tile = load_vector(slice[:, col_offset:col_offset+16]) + slm_buffer = max(slm_buffer, tile) // Update in SLM + max_vals = reduce_across_cols(slm_buffer) + + // Loop 2: Sum reduction (32 iterations) + slm_buffer = zeros() + for col_offset in [0:512:16]: + tile = load_vector(slice[:, col_offset:col_offset+16]) + exp_tile = exp(tile - max_vals) + slm_buffer += exp_tile // Accumulate in SLM + sum_vals = reduce_across_cols(slm_buffer) +\end{lstlisting} +\end{frame} + +\begin{frame}[fragile]{Stage 9: Final Vectorized XeGPU Version (cont'd)} +\begin{lstlisting} + // Loop 3: Division (32 iterations) + for col_offset in [0:512:16]: + tile = load_vector(slice[:, col_offset:col_offset+16]) + exp_tile = exp(tile - max_vals) + result = exp_tile / sum_vals + store_vector(output[:, col_offset:col_offset+16], result) +\end{lstlisting} + +\vspace{0.3cm} + +\textbf{Key Features}: +\begin{itemize} +\item Uses vector operations ($64 \times 16$ SIMD) +\item Shared Local Memory (SLM) for partial reductions +\item XeGPU dialect for Intel GPU operations +\end{itemize} +\end{frame} + +\begin{frame}{Summary of Transformations} +\begin{table} +\centering +\small +\begin{tabular}{@{}rll@{}} +\toprule +\textbf{Stage} & \textbf{Key Transformation} & \textbf{Loop Structure} \\ +\midrule +1 & Initial high-level softmax & No loops \\ +2 & Tile parallel dimension & \texttt{forall(16)} \\ +3 & Decompose softmax & \texttt{forall(16)} + 4 ops \\ +4 & Tile division & \texttt{forall(16)} $\rightarrow$ \texttt{for(32)} \\ +5 & Fuse into division loop & Recompute center+exp \\ +6 & Tile sum reduction & Add sum loop \\ +7 & Fuse into sum loop & Recompute center+exp \\ +8 & Tile max reduction & 3 loops total \\ +9 & Vectorize + XeGPU & GPU with SLM \\ +\bottomrule +\end{tabular} +\end{table} + +\vspace{0.3cm} + +\textbf{Final pattern per GPU block}: 3 loops of 32 iterations each +\begin{enumerate} +\item Max reduction $\rightarrow$ SLM $\rightarrow$ final reduction +\item Sum reduction (fused center+exp) $\rightarrow$ SLM $\rightarrow$ final reduction +\item Division (fused center+exp+div) $\rightarrow$ global memory +\end{enumerate} +\end{frame} + +\begin{frame}{Optimization: Fusing Max and Sum Loops} +After Stage 8, we can fuse max and sum loops into one. + +\textbf{Result}: 3 loops $\rightarrow$ 2 loops + +\vspace{0.5cm} + +\textbf{Key Insight}: \emph{Online Softmax Algorithm} + +Incrementally update both global max and sum as we process each tile: + +\begin{enumerate} +\item Compute \textbf{local max} for the tile +\item Update \textbf{global max} = $\max(\text{old\_max}, \text{local\_max})$ +\item Compute \textbf{local sum} = $\sum \exp(x - \text{local\_max})$ +\item \textbf{Rescale} global sum by $\exp(\text{old\_max} - \text{new\_max})$ +\item \textbf{Add} rescaled local sum to global sum +\end{enumerate} + +\vspace{0.3cm} + +Maintains numerical stability while reducing memory bandwidth! +\end{frame} + +\begin{frame}[fragile]{Before Fusion: 3 Separate Loops} +\begin{lstlisting} +// Pseudo code: +parallel for tile_id in [0..16): + slice = input[tile_id*64:(tile_id+1)*64, :] + + // Loop 1: Max reduction + max_buffer = fill(-inf, 64, 16) + for col_offset in [0:512:16]: + tile = slice[:, col_offset:col_offset+16] + max_buffer = max(max_buffer, tile) + max_vals = reduce_max(max_buffer, dim=1) + + // Loop 2: Sum reduction + sum_buffer = zeros(64, 16) + for col_offset in [0:512:16]: + tile = slice[:, col_offset:col_offset+16] + exp_tile = exp(tile - max_vals) // Uses final max + sum_buffer += exp_tile + sum_vals = reduce_sum(sum_buffer, dim=1) + + // Loop 3: Division + for col_offset in [0:512:16]: ... +\end{lstlisting} +\end{frame} + +\begin{frame}[fragile]{After Fusion: 2 Loops (Online Softmax)} +\begin{lstlisting} +// Pseudo code: +parallel for tile_id in [0..16): + slice = input[tile_id*64:(tile_id+1)*64, :] + + // Loop 1: Fused max+sum (online softmax algorithm) + global_max = fill(-inf, 64) + global_sum = zeros(64) + + for col_offset in [0:512:16]: + tile = slice[:, col_offset:col_offset+16] + + // Update max incrementally + local_max = reduce_max(tile, dim=1) + new_max = max(global_max, local_max) + + // Compute local sum centered on local_max + local_sum = sum(exp(tile - local_max), dim=1) + + // Rescale and accumulate sum + correction = exp(global_max - new_max) + local_correction = exp(local_max - new_max) + global_sum = global_sum * correction + local_sum * local_correction + + global_max = new_max +\end{lstlisting} +\end{frame} + +\begin{frame}[fragile]{After Fusion: 2 Loops (cont'd)} +\begin{lstlisting} + // Loop 2: Division (same as before) + for col_offset in [0:512:16]: + tile = slice[:, col_offset:col_offset+16] + exp_tile = exp(tile - global_max) + output[:, col_offset:col_offset+16] = exp_tile / global_sum +\end{lstlisting} + +\vspace{0.5cm} + +\textbf{Benefits}: +\begin{itemize} +\item Reduced loop count: 3 $\rightarrow$ 2 loops +\item Better memory locality: Input read twice instead of three times +\item Same numerical stability (still uses max-centering) +\end{itemize} + +\vspace{0.3cm} + +\textbf{Trade-off}: More computation per iteration (exponentials for rescaling) +\end{frame} + +\begin{frame}{Summary} +\textbf{Softmax Lowering Journey}: 9 transformation stages + +\begin{enumerate} +\item Start: High-level \texttt{linalg.softmax} operation +\item Decompose into: max $\rightarrow$ center+exp $\rightarrow$ sum $\rightarrow$ div +\item Tile parallel dimension (16 workgroups) +\item Tile reduction dimension (32 iterations per loop) +\item Fuse operations to reduce memory footprint +\item Lower to GPU with vectorization and SLM +\item \textbf{Optional}: Fuse max+sum loops (online softmax) +\end{enumerate} + +\vspace{0.5cm} + +\textbf{Key Techniques}: +\begin{itemize} +\item Tiling for parallelism and memory hierarchy +\item Fusion for memory efficiency (recomputation) +\item Online softmax for bandwidth optimization +\end{itemize} +\end{frame} + +\end{document}