diff --git a/examples/xegpu/fused_attention.py b/examples/xegpu/fused_attention.py new file mode 100644 index 00000000..65302764 --- /dev/null +++ b/examples/xegpu/fused_attention.py @@ -0,0 +1,387 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + +""" +XeGPU fused attention benchmark. +""" + +import argparse +from typing import Optional +from functools import cached_property + +import numpy as np +from mlir import ir + +from lighthouse import dialects as lh_dialects +from lighthouse.execution.runner import Runner +from lighthouse.pipeline.driver import TransformDriver +from lighthouse.execution import GPUMemoryManager +from lighthouse.utils.numpy import mlir_to_numpy_dtype +from lighthouse.ingress.mlir_gen import get_mlir_elem_type +from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import ( + generate_gpu_fused_attention_payload, +) +from lighthouse.schedule.xegpu import fused_attention_schedule, xegpu_to_binary + + +def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: int): + """ + Complexity of fused attention operation. + + For each batch and head: + - Q @ K^T: O(n_ctx^2 * n_head) operations + - Softmax: O(n_ctx^2) operations + - Attention @ V: O(n_ctx^2 * n_head) operations + Total: approximately 2*n_ctx^2*n_head FLOPs per batch and head + """ + # Approximation: 2 * n_ctx^2 * n_head FLOPs per batch and head + flop_count = Z * H * 2 * n_ctx * n_ctx * n_head + # Memory: read Q, K, V and write output + memory_reads = 3 * Z * H * n_ctx * n_head * nbytes + memory_writes = Z * H * n_ctx * n_head * nbytes + return flop_count, memory_reads, memory_writes + + +def check_correctness( + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, + output_arr: np.ndarray, + verbose: int = 0, +) -> bool: + """ + Check correctness of fused attention output. + + Reference implementation: + - scores = Q @ K^T / sqrt(n_head) + - attention_weights = softmax(scores, dim=-1) + - output = attention_weights @ V + """ + # Use float32 for computation + Q_f32 = Q.astype(np.float32) + K_f32 = K.astype(np.float32) + V_f32 = V.astype(np.float32) + + Z, H, n_ctx, n_head = Q.shape + scale = 1.0 / np.sqrt(n_head) + + output_ref = np.zeros_like(Q_f32) + + # Compute reference for each batch and head + for z in range(Z): + for h in range(H): + # scores = Q @ K^T / sqrt(n_head) + scores = Q_f32[z, h] @ K_f32[z, h].T * scale + + # softmax along last dimension + max_vals = np.max(scores, axis=1, keepdims=True) + exp_vals = np.exp(scores - max_vals) + sum_vals = np.sum(exp_vals, axis=1, keepdims=True) + attention_weights = exp_vals / sum_vals + + # output = attention_weights @ V + output_ref[z, h] = attention_weights @ V_f32[z, h] + + output = output_arr.astype(np.float32) + + if verbose > 1: + print("Reference solution (first batch, first head, first 5 rows):") + print(output_ref[0, 0, :5]) + print("Computed solution (first batch, first head, first 5 rows):") + print(output[0, 0, :5]) + + # Check values match reference + values_ok = np.allclose(output, output_ref, rtol=1e-3, atol=1e-4) + + success = values_ok + + if verbose: + if success: + print("PASSED") + else: + print("FAILED!") + if not values_ok: + max_diff = np.abs(output - output_ref).max() + print(f" Values mismatch. Max abs diff: {max_diff:.6e}") + return success + + +class XeGPUFusedAttention: + """ + Fused attention workload on XeGPU. 
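+    It generates the payload IR, applies the XeGPU transform schedule, and
+    runs or benchmarks the compiled kernel.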
+ + Computes fused attention: + output = softmax(Q @ K^T / sqrt(n_head)) @ V + + All Q, K, V matrices have shape (Z, H, n_ctx, n_head) where: + - Z: batch size + - H: number of heads + - n_ctx: context length + - n_head: head dimension + """ + + def __init__( + self, + Z: int, + H: int, + n_ctx: int, + n_head: int, + dtype: str = "f16", + ): + self.Z = Z + self.H = H + self.n_ctx = n_ctx + self.n_head = n_head + self.shape = (Z, H, n_ctx, n_head) + assert dtype == "f16", "Only f16 type is supported for fused attention" + self.elem_type = get_mlir_elem_type(dtype) + self.dtype = mlir_to_numpy_dtype(self.elem_type) + self.memory_manager_class = GPUMemoryManager + self.payload_function_name = "payload" + + @cached_property + def _initial_host_arrays(self) -> tuple[np.ndarray]: + """Generate initial values on host with numpy.""" + np.random.seed(42) + # Initialize Q, K, V with small random values + Q = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + K = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + V = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype) + output_arr = np.zeros(self.shape, dtype=self.dtype) + return (output_arr, Q, K, V) + + def get_complexity(self) -> tuple[int, int, int]: + nbytes = np.dtype(self.dtype).itemsize + return fused_attention_complexity( + self.Z, self.H, self.n_ctx, self.n_head, nbytes + ) + + def payload_module(self) -> ir.Module: + """Generate MLIR module for fused attention payload.""" + return generate_gpu_fused_attention_payload( + func_name=self.payload_function_name, + Z=self.Z, + H=self.H, + n_ctx=self.n_ctx, + n_head=self.n_head, + dtype=self.elem_type, + ) + + def schedule_modules( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> list[ir.Module]: + """Generate transform schedule for fused attention.""" + schedules = [] + schedules.append(Runner.get_bench_wrapper_schedule(self.payload_function_name)) + + schedules.append( + fused_attention_schedule( + stop_at_stage=stop_at_stage, + parameters=parameters, + ) + ) + + if stop_at_stage and stop_at_stage != "final": + return schedules + + schedules.append(xegpu_to_binary()) + + return schedules + + def shared_libs(self) -> list[str]: + return ["libmlir_levelzero_runtime.so"] + + +def parse_cli(): + parser = argparse.ArgumentParser( + description="Fused Attention using MLIR XeGPU", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--batch-size", + type=int, + default=2, + help="Batch size (Z)", + ) + parser.add_argument( + "--num-heads", + type=int, + default=8, + help="Number of attention heads (H)", + ) + parser.add_argument( + "--n-ctx", + type=int, + default=4096, + help="Context length (sequence length)", + ) + parser.add_argument( + "--n-head", + type=int, + default=64, + help="Head dimension", + ) + parser.add_argument( + "--wg-rows", + type=int, + default=128, + help="Number of Q*K^T*V rows computed by each work group", + ) + parser.add_argument( + "--sg-rows", + type=int, + default=16, + help="Number of Q*K^T*V rows computed by each subgroup", + ) + parser.add_argument( + "--subgroup-size", + type=int, + default=16, + help="Subgroup size", + ) + parser.add_argument( + "--inner-loop-tile-size", + type=int, + default=64, + help="Tile size for the inner reduction dimension (K/V sequence length)", + ) + parser.add_argument( + "--nruns", + type=int, + default=1000, + help="Number of runs to average the execution time.", + ) + parser.add_argument( + "--nwarmup", + type=int, + default=20, + 
help="Number of warm-up iterations before benchmarking.", + ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the fused attention computation.", + ) + parser.add_argument( + "--dump-kernel", + type=str, + choices=[ + "initial", + "outer-tiled", + "inner-tiled", + "vectorized", + "bufferized", + "gpu-outlining", + "xegpu-initial", + "xegpu-wg", + "final", + ], + help="Dump kernel IR at different stages of lowering and exit without " + "executing the kernel.", + ) + parser.add_argument( + "--dump-schedule", + action="store_true", + help="Dump transform schedule.", + ) + parser.add_argument( + "--verbose", + "-v", + action="count", + default=0, + help="Increase output verbosity (e.g. print reference and computed solutions).", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_cli() + + params = { + "batch_size": args.batch_size, + "num_heads": args.num_heads, + "n_ctx": args.n_ctx, + "n_head": args.n_head, + "wg_rows": args.wg_rows, + "sg_rows": args.sg_rows, + "subgroup_size": args.subgroup_size, + "inner_loop_tile_size": args.inner_loop_tile_size, + } + + Z = args.batch_size + H = args.num_heads + n_ctx = args.n_ctx + n_head = args.n_head + dtype = "f16" + + with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUFusedAttention(Z=Z, H=H, n_ctx=n_ctx, n_head=n_head, dtype=dtype) + + if args.dump_kernel or args.dump_schedule: + pipeline = TransformDriver( + wload.schedule_modules( + stop_at_stage=args.dump_kernel, parameters=params + ) + ) + payload = pipeline.apply(wload.payload_module()) + if args.dump_kernel: + print(payload) + if args.dump_schedule: + for schedule_module in wload.schedule_modules(parameters=params): + print(schedule_module) + else: + pipeline = TransformDriver(wload.schedule_modules(parameters=params)) + payload = pipeline.apply(wload.payload_module()) + runner = Runner( + payload, + mem_manager_cls=wload.memory_manager_class, + shared_libs=wload.shared_libs(), + ) + if args.check_result: + # Setup callback function to copy result from device to host. + result_host_copy, argument_access_callback = ( + Runner.get_gpu_argument_access_callback(wload.shape, wload.dtype) + ) + + # Execute kernel once. + runner.execute( + host_input_buffers=wload._initial_host_arrays, + payload_function_name=wload.payload_function_name, + argument_access_callback=argument_access_callback, + ) + + # Compute reference solution on host. + Q, K, V = wload._initial_host_arrays[1:4] + success = check_correctness( + Q, + K, + V, + result_host_copy, + verbose=args.verbose, + ) + if not success: + raise ValueError("Result mismatch!") + else: + print("Result is correct. 
Proceeding to benchmark...") + + times = runner.benchmark( + host_input_buffers=wload._initial_host_arrays, + nruns=args.nruns, + nwarmup=args.nwarmup, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + print( + f"batch-size={Z} " + f"num-heads={H} " + f"n-ctx={n_ctx} " + f"n-head={n_head} " + f"dt={dtype} " + f"time(us): {elapsed:.2f} " + f"GFLOPS: {gflops:.2f} " + ) diff --git a/lighthouse/dialects/transform/transform_ext/__init__.py b/lighthouse/dialects/transform/transform_ext/__init__.py index 997522a2..eec36b6e 100644 --- a/lighthouse/dialects/transform/transform_ext/__init__.py +++ b/lighthouse/dialects/transform/transform_ext/__init__.py @@ -10,11 +10,13 @@ from .ops.get_tileable_consumers import get_tileable_consumers from .ops.get_tiling_sizes import get_tiling_sizes from .ops.update_address_space import update_address_space +from .ops.generate_fused_attention import generate_fused_attention __all__ = [ "TransformExtensionDialect", "convert_func_results_to_args", "extract_handle", + "generate_fused_attention", "get_named_attribute", "get_named_attribute", "get_tileable_consumers", diff --git a/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py new file mode 100644 index 00000000..35f9822e --- /dev/null +++ b/lighthouse/dialects/transform/transform_ext/ops/generate_fused_attention.py @@ -0,0 +1,518 @@ +"""Transform extension to generate fused attention computation.""" + +import numpy as np +from mlir import ir +from mlir.dialects import ext, transform, arith, scf, math, vector +from mlir.dialects.transform import DiagnosedSilenceableFailure + +from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect + + +class GenerateFusedAttention( + TransformExtensionDialect.Operation, name="generate_fused_attention" +): + """Generate tiled fused attention computation (flash attention optimization). + + Takes Q, K, V loads and scale constant from bufferized IR, and generates an inner + tiled loop that computes fused attention with online softmax using running max and sum. + + This implements the flash attention algorithm where: + 1. The computation is tiled along the reduction dimension (K/V sequence length) + 2. Online max and sum are maintained across tiles + 3. 
Output is incrementally updated with rescaled contributions + + Args: + q_load: Handle to Q load operation (vector.transfer_read) + k_load: Handle to K load operation (vector.transfer_read) + v_load: Handle to V load operation (vector.transfer_read) + scale: Handle to scale constant operation (arith.constant) + output: Handle to the output operation to replace (vector.contract) + tile_size: Tile size for the reduction dimension tiling (K/V sequence length) + """ + + q_load: ext.Operand[transform.AnyOpType] + k_load: ext.Operand[transform.AnyOpType] + v_load: ext.Operand[transform.AnyOpType] + scale: ext.Operand[transform.AnyOpType] + output: ext.Operand[transform.AnyOpType] + tile_size: ir.IntegerAttr + new_output: ext.Result[transform.AnyOpType[()]] = ext.infer_result() + + @classmethod + def attach_interface_impls(cls, ctx=None): + cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx) + + class TransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "GenerateFusedAttention", + rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> DiagnosedSilenceableFailure: + # Get payload operations + q_load_ops = state.get_payload_ops(op.q_load) + k_load_ops = state.get_payload_ops(op.k_load) + v_load_ops = state.get_payload_ops(op.v_load) + scale_ops = state.get_payload_ops(op.scale) + output_ops = state.get_payload_ops(op.output) + + if ( + len(q_load_ops) != 1 + or len(k_load_ops) != 1 + or len(v_load_ops) != 1 + or len(scale_ops) != 1 + or len(output_ops) != 1 + ): + return DiagnosedSilenceableFailure.emit_silenceable_error( + "Expected exactly one operation for each operand" + ) + + q_load_op = q_load_ops[0] + k_load_op = k_load_ops[0] + v_load_op = v_load_ops[0] + scale_op = scale_ops[0] + output_op = output_ops[0] + + # Extract the scale scalar value from scale_op (arith.constant) + scale_attr = scale_op.attributes["value"] + scale_dense_attr = ir.DenseElementsAttr(scale_attr) + scale_np_array = np.array(scale_dense_attr) + scale_value = float(scale_np_array.flat[0]) + + # Extract wg_rows and d_head from q_load result type + q_load_result = q_load_op.results[0] + q_vector_type = ir.VectorType(q_load_result.type) + wg_rows = q_vector_type.shape[0] + d_head = q_vector_type.shape[1] + + # Get tile size + tile_size_value = ir.IntegerAttr(op.tile_size).value + + # Get element type from q_load result + element_type = q_vector_type.element_type + + # Build the fused attention computation + with ir.InsertionPoint(output_op): + # Define m_i_init: vector of shape [wg_rows] with neg_inf values + m_i_vector_type = ir.VectorType.get([wg_rows], element_type) + neg_inf_value = ( + 0xFC00 if element_type == ir.F16Type.get() else float("-inf") + ) + m_i_values = np.full( + wg_rows, + neg_inf_value, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + m_i_init_attr = ir.DenseElementsAttr.get( + m_i_values, type=m_i_vector_type + ) + m_i_init = arith.constant(m_i_vector_type, m_i_init_attr) + + # Define l_i_init: vector of shape [wg_rows] with zero values + l_i_vector_type = ir.VectorType.get([wg_rows], element_type) + l_i_values = np.zeros( + wg_rows, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + l_i_init_attr = ir.DenseElementsAttr.get( + l_i_values, type=l_i_vector_type + ) + l_i_init = arith.constant(l_i_vector_type, l_i_init_attr) + + # Define 
acc_init: vector of shape [wg_rows, d_head] with zero values + acc_vector_type = ir.VectorType.get([wg_rows, d_head], element_type) + acc_values = np.zeros( + (wg_rows, d_head), + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + acc_init_attr = ir.DenseElementsAttr.get( + acc_values, type=acc_vector_type + ) + acc_init = arith.constant(acc_vector_type, acc_init_attr) + + # Get n_ctx from k_load result type (first dimension size) + k_load_result = k_load_op.results[0] + k_vector_type = ir.VectorType(k_load_result.type) + n_ctx = k_vector_type.shape[0] + # Define scale vector: vector of shape [wg_rows] with the scale value + scale_vector_type = ir.VectorType.get([wg_rows], element_type) + scale_values = np.full( + (wg_rows), + scale_value, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + scale_init_attr = ir.DenseElementsAttr.get( + scale_values, type=scale_vector_type + ) + scale_vector = arith.constant(scale_vector_type, scale_init_attr) + + # Create loop bounds + index_type = ir.IndexType.get() + c0 = arith.constant(index_type, 0) + c_n_ctx = arith.constant(index_type, n_ctx) + c_tile_size = arith.constant(index_type, tile_size_value) + + # Create scf.for loop that iterates from 0 to n_ctx in steps of tile_size + loop = scf.ForOp( + c0, c_n_ctx, c_tile_size, [m_i_init, l_i_init, acc_init] + ) + + with ir.InsertionPoint(loop.body): + # Get the loop induction variable and iter_args + loop_idx = loop.induction_variable + m_i = loop.inner_iter_args[0] + l_i = loop.inner_iter_args[1] + acc = loop.inner_iter_args[2] + + # Get common values for K/V tiling + k_memref = k_load_op.operands[0] + k_load_indices = list(k_load_op.operands[1:-1]) + padding = k_load_op.operands[-1] + in_bounds = k_load_op.attributes.get("in_bounds", None) + k_perm_map = k_load_op.attributes.get("permutation_map", None) + q_value = q_load_op.results[0] + + # Constants for K/V tiling (tile into chunks of 16) + k_subtile_size = 16 + num_k_tiles = tile_size_value // k_subtile_size + + # Create offset constants for each K tile + k_tile_offsets = [] + for i in range(num_k_tiles): + offset = arith.constant(index_type, i * k_subtile_size) + k_tile_offsets.append(offset) + + # Load and process K tiles (unrolled) + # Each K tile is [16, d_head], transposed to [d_head, 16], contracted to [wg_rows, 16] + qkt_chunks = [] + + # Create affine maps for Q@K contraction (used for all tiles) + affine_d0 = ir.AffineExpr.get_dim(0) + affine_d1 = ir.AffineExpr.get_dim(1) + affine_d2 = ir.AffineExpr.get_dim(2) + + q_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d2]) + k_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) + out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) + + indexing_maps = ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get(q_map), + ir.AffineMapAttr.get(k_map), + ir.AffineMapAttr.get(out_map), + ] + ) + + iterator_types = ir.ArrayAttr.get( + [ + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ] + ) + + # Accumulator for Q@K chunks + qkt_chunk_type = ir.VectorType.get( + [wg_rows, k_subtile_size], element_type + ) + qkt_chunk_acc_values = np.zeros( + (wg_rows, k_subtile_size), + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + qkt_chunk_acc_attr = ir.DenseElementsAttr.get( + qkt_chunk_acc_values, type=qkt_chunk_type + ) + qkt_chunk_acc = arith.constant(qkt_chunk_type, qkt_chunk_acc_attr) + + for tile_idx in range(num_k_tiles): + # 
Compute the offset index for this tile + k_tile_offset = arith.addi(loop_idx, k_tile_offsets[tile_idx]) + + # Update indices for this K tile + k_tile_indices = k_load_indices.copy() + k_tile_indices[-2] = k_tile_offset + + # Load K tile: [16, d_head] + k_tile_type = ir.VectorType.get( + [k_subtile_size, d_head], element_type + ) + k_tile = vector.TransferReadOp( + k_tile_type, + k_memref, + k_tile_indices, + k_perm_map, + padding, + in_bounds=in_bounds, + ).result + + # Transpose K tile: [16, d_head] -> [d_head, 16] + k_transpose_type = ir.VectorType.get( + [d_head, k_subtile_size], element_type + ) + k_transpose = vector.transpose(k_transpose_type, k_tile, [1, 0]) + + # Contract Q @ K_transpose: [wg_rows, d_head] @ [d_head, 16] -> [wg_rows, 16] + qkt_chunk = vector.contract( + qkt_chunk_type, + q_value, + k_transpose, + qkt_chunk_acc, + indexing_maps=indexing_maps, + iterator_types=iterator_types, + ) + qkt_chunks.append(qkt_chunk) + + # Elementwise maximum across all Q@K chunks + # Build tree of maximumf operations + qkt_max_combined = qkt_chunks[0] + for i in range(1, num_k_tiles): + qkt_max_combined = arith.maximumf( + qkt_max_combined, qkt_chunks[i] + ) + + # Final multi_reduction to get row-wise max: [wg_rows, 16] -> [wg_rows] + qkt_max = vector.multi_reduction( + kind="maxnumf", + source=qkt_max_combined, + acc=m_i_init, + reduction_dims=[1], + ) + + # Scale the max: qkt_max_scaled = qkt_max * scale + # Both have shape [wg_rows] + qkt_max_scaled = arith.mulf(qkt_max, scale_vector) + + # Compute m_ij = max(m_i, qkt_max_scaled) + # Both have shape [wg_rows] + m_ij = arith.maximumf(m_i, qkt_max_scaled) + + # Apply softmax to each Q@K chunk + # Scale constant for chunks: [wg_rows, 16] + scale_chunk_type = ir.VectorType.get( + [wg_rows, k_subtile_size], element_type + ) + scale_chunk_values = np.full( + (wg_rows, k_subtile_size), + scale_value, + dtype=np.float16 + if element_type == ir.F16Type.get() + else np.float32, + ) + scale_chunk_attr = ir.DenseElementsAttr.get( + scale_chunk_values, type=scale_chunk_type + ) + scale_chunk = arith.constant(scale_chunk_type, scale_chunk_attr) + + # Broadcast m_ij from [wg_rows] to [wg_rows, 16] + m_ij_bcasted_type = ir.VectorType.get( + [k_subtile_size, wg_rows], element_type + ) + m_ij_bcasted = vector.broadcast(m_ij_bcasted_type, m_ij) + m_ij_transposed_type = ir.VectorType.get( + [wg_rows, k_subtile_size], element_type + ) + m_ij_transposed = vector.transpose( + m_ij_transposed_type, m_ij_bcasted, [1, 0] + ) + + # Apply softmax to each chunk + qkt_exp_chunks = [] + for qkt_chunk in qkt_chunks: + # Scale: qkt_scaled = qkt_chunk * scale + qkt_scaled = arith.mulf(qkt_chunk, scale_chunk) + + # Center: qkt_centered = qkt_scaled - m_ij_transposed + qkt_centered = arith.subf(qkt_scaled, m_ij_transposed) + + # Exponential: qkt_exp = exp(qkt_centered) + qkt_exp = math.exp(qkt_centered) + qkt_exp_chunks.append(qkt_exp) + + # Elementwise sum across all exp chunks + qkt_exp_combined = qkt_exp_chunks[0] + for i in range(1, num_k_tiles): + qkt_exp_combined = arith.addf( + qkt_exp_combined, qkt_exp_chunks[i] + ) + + # Final multi_reduction to get row-wise sum: [wg_rows, 16] -> [wg_rows] + l_ij = vector.multi_reduction( + kind="add", + source=qkt_exp_combined, + acc=l_i_init, + reduction_dims=[1], + ) + + # Compute alpha = exp(m_i - m_ij) + m_diff = arith.subf(m_i, m_ij) + alpha = math.exp(m_diff) + + # Update l_i: l_i_updated = l_i * alpha + l_ij + l_i_scaled = arith.mulf(l_i, alpha) + l_i_updated = arith.addf(l_i_scaled, l_ij) + + # Broadcast alpha from 
[wg_rows] to [wg_rows, d_head] + alpha_bcasted_type = ir.VectorType.get( + [d_head, wg_rows], element_type + ) + alpha_bcasted = vector.broadcast(alpha_bcasted_type, alpha) + alpha_transposed_type = ir.VectorType.get( + [wg_rows, d_head], element_type + ) + alpha_transposed = vector.transpose( + alpha_transposed_type, alpha_bcasted, [1, 0] + ) + + # Update accumulator: acc_updated = acc * alpha_bcasted + acc_updated = arith.mulf(acc, alpha_transposed) + + # Load V tiles and compute attention-weighted values + # Get V load parameters + v_memref = v_load_op.operands[0] + v_load_indices = list(v_load_op.operands[1:-1]) + v_padding = v_load_op.operands[-1] + v_in_bounds = v_load_op.attributes.get("in_bounds", None) + v_perm_map = v_load_op.attributes.get("permutation_map", None) + + # Create affine maps for P@V contraction + qkt_exp_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d2]) + v_map = ir.AffineMap.get(3, 0, [affine_d2, affine_d1]) + pv_out_map = ir.AffineMap.get(3, 0, [affine_d0, affine_d1]) + + indexing_maps_pv = ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get(qkt_exp_map), + ir.AffineMapAttr.get(v_map), + ir.AffineMapAttr.get(pv_out_map), + ] + ) + + iterator_types_pv = ir.ArrayAttr.get( + [ + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ir.Attribute.parse("#vector.iterator_type"), + ] + ) + + # Load and process V tiles (unrolled), accumulating results + pv_out = acc_updated + for tile_idx in range(num_k_tiles): + # Compute the offset index for this V tile + v_tile_offset = arith.addi(loop_idx, k_tile_offsets[tile_idx]) + + # Update indices for this V tile + v_tile_indices = v_load_indices.copy() + v_tile_indices[-2] = v_tile_offset + + # Load V tile: [16, d_head] + v_tile_type = ir.VectorType.get( + [k_subtile_size, d_head], element_type + ) + v_tile = vector.TransferReadOp( + v_tile_type, + v_memref, + v_tile_indices, + v_perm_map, + v_padding, + in_bounds=v_in_bounds, + ).result + + # Contract qkt_exp_chunk @ v_tile: [wg_rows, 16] @ [16, d_head] -> [wg_rows, d_head] + # Accumulate into pv_out + pv_out = vector.contract( + acc_vector_type, + qkt_exp_chunks[tile_idx], + v_tile, + pv_out, + indexing_maps=indexing_maps_pv, + iterator_types=iterator_types_pv, + ) + + # Yield the updated iter args + scf.yield_([m_ij, l_i_updated, pv_out]) + + # Extract the final accumulator result (3rd output) from the loop + pv_out = loop.results[2] + l_i_out = loop.results[1] + with ir.InsertionPoint.after(loop): + # Normalize the output: output_final = pv_out / l_i_out + # Need to broadcast l_i_out from [wg_rows] to [wg_rows, d_head] + l_i_out_bcasted_type = ir.VectorType.get( + [d_head, wg_rows], element_type + ) + l_i_out_bcasted = vector.broadcast(l_i_out_bcasted_type, l_i_out) + l_i_out_transposed_type = ir.VectorType.get( + [wg_rows, d_head], element_type + ) + l_i_out_transposed = vector.transpose( + l_i_out_transposed_type, l_i_out_bcasted, [1, 0] + ) + output_final = arith.divf(pv_out, l_i_out_transposed) + + # Replace all uses of the original output operation with the final loop result + output_op.results[0].replace_all_uses_with(output_final) + + # Erase the original output operation + rewriter.erase_op(output_op) + + # Return the final output handle + results.set_ops(op.new_output, [output_final.owner]) + return DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "GenerateFusedAttention") -> bool: + return False + + class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + 
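# The output handle is consumed rather than read: apply() erases the matched op
+        # and replaces it with the generated loop nest, invalidating the handle.
+        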
def get_effects(op: ir.Operation, effects): + # Read Q, K, scale, V slices + transform.only_reads_handle(op.op_operands[:4], effects) + # Consume and replace output + transform.consumes_handle(op.op_operands[4:5], effects) + # Produce new output handle + transform.produces_handle(op.results, effects) + # Modify the payload + transform.modifies_payload(effects) + + +def generate_fused_attention( + q_load: ir.Value, + k_load: ir.Value, + v_load: ir.Value, + scale: ir.Value, + output: ir.Value, + tile_size: int | ir.IntegerAttr, +) -> ir.Value: + """Generate fused attention computation with inner tiling on bufferized IR. + + Args: + q_load: Handle to Q load operation (vector.transfer_read) + k_load: Handle to K load operation (vector.transfer_read) + v_load: Handle to V load operation (vector.transfer_read) + scale: Handle to scale constant operation (arith.constant) + output: Handle to output operation to replace (vector.contract) + tile_size: Tile size for the reduction dimension tiling (K/V sequence length) + + Returns: + Handle to the new output operation + """ + if not isinstance(tile_size, ir.IntegerAttr): + tile_size = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), tile_size) + + return GenerateFusedAttention( + q_load, k_load, v_load, scale, output, tile_size=tile_size + ).new_output diff --git a/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py new file mode 100644 index 00000000..73873dc6 --- /dev/null +++ b/lighthouse/ingress/mlir_gen/gpu_fused_attention_payload.py @@ -0,0 +1,136 @@ +"""Generate MLIR payload for GPU fused attention operation.""" + +import math + +from mlir import ir +from mlir.dialects import arith, bufferization, linalg, memref, tensor + +from lighthouse.utils.mlir import func_cif +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs +from lighthouse.ingress.mlir_gen.utils import emit_buf_to_tensor + + +def generate_gpu_fused_attention_payload( + func_name: str, + Z: int, + H: int, + n_ctx: int, + n_head: int, + dtype: ir.Type, +) -> ir.Module: + """ + Generate MLIR module for fused attention payload. 
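+    The payload is expressed as linalg ops on tensors (transpose, batch_matmul,
+    softmax, fill) so the transform schedule can tile, fuse, and lower it to XeGPU.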
+ + Computes fused attention: + output = softmax(Q @ K^T / sqrt(n_head)) @ V + + Args: + func_name: Name of the payload function + Z: Batch size + H: Number of attention heads + n_ctx: Context length (sequence length) + n_head: Head dimension + dtype: MLIR element type (e.g., F32Type) + + Returns: + MLIR module containing the fused attention payload function + """ + mod = ir.Module.create() + shape = (Z, H, n_ctx, n_head) + memref_t = ir.MemRefType.get(shape, dtype) + + with ir.InsertionPoint(mod.body): + # Collapse first 2 dimensions (Z, H) into a batch dimension + # From (Z, H, n_ctx, n_head) to (Z*H, n_ctx, n_head) + batch_dim = Z * H + collapsed_shape_3d = (batch_dim, n_ctx, n_head) + memref_3d_t = ir.MemRefType.get(collapsed_shape_3d, dtype) + + # Function signature: payload(output, Q, K, V) + @func_cif(memref_t, memref_t, memref_t, memref_t, name=func_name) + def payload(output, Q_arg, K_arg, V_arg): + # Collapse memrefs from 4D to 3D + Q_3d_memref = memref.collapse_shape( + memref_3d_t, + Q_arg, + reassociation=[[0, 1], [2], [3]], + ) + K_3d_memref = memref.collapse_shape( + memref_3d_t, + K_arg, + reassociation=[[0, 1], [2], [3]], + ) + V_3d_memref = memref.collapse_shape( + memref_3d_t, + V_arg, + reassociation=[[0, 1], [2], [3]], + ) + output_3d_memref = memref.collapse_shape( + memref_3d_t, + output, + reassociation=[[0, 1], [2], [3]], + ) + + # Convert 3D memrefs to tensors + Q_3d = emit_buf_to_tensor(Q_3d_memref, restrict=True) + K_3d = emit_buf_to_tensor(K_3d_memref, restrict=True) + V_3d = emit_buf_to_tensor(V_3d_memref, restrict=True) + + # Step 1: Transpose K to get K^T + # Permute from (batch_dim, n_ctx, n_head) to (batch_dim, n_head, n_ctx) + kt_shape_3d = (batch_dim, n_head, n_ctx) + kt_init = tensor.empty(kt_shape_3d, dtype) + K_transposed = linalg.transpose(K_3d, outs=[kt_init], permutation=[0, 2, 1]) + + # Step 2: Compute Q @ K^T using batch_matmul + # Q: (batch_dim, n_ctx, n_head) @ K^T: (batch_dim, n_head, n_ctx) + # Result: (batch_dim, n_ctx, n_ctx) + qkt_shape_3d = (batch_dim, n_ctx, n_ctx) + qkt_init = tensor.empty(qkt_shape_3d, dtype) + # Initialize with zeros for matmul accumulation + zero = arith.constant(dtype, 0.0) + qkt_init_filled = linalg.fill(zero, outs=[qkt_init]) + + # Batch matmul: Q @ K^T + qkt = linalg.batch_matmul(Q_3d, K_transposed, outs=[qkt_init_filled]) + + # Step 3: Scale by 1/sqrt(n_head) + scale_factor = 1.0 / math.sqrt(n_head) + scale_const = arith.constant(dtype, scale_factor) + + # Create a tensor filled with the scale factor + scale_tensor_init = tensor.empty(qkt_shape_3d, dtype) + scale_tensor = linalg.fill(scale_const, outs=[scale_tensor_init]) + + # Elementwise multiply qkt with scale tensor + scaled_qkt_init = tensor.empty(qkt_shape_3d, dtype) + scaled_qkt = linalg.mul(qkt, scale_tensor, outs=[scaled_qkt_init]) + + # Step 4: Apply softmax along the last dimension (dim=2 in 3D) + softmax_init = tensor.empty(qkt_shape_3d, dtype) + attention_weights = linalg.softmax( + result=[ir.RankedTensorType.get(qkt_shape_3d, dtype)], + input=scaled_qkt, + output=softmax_init, + dimension=2, + ) + + # Step 5: Multiply attention weights by V using batch_matmul + # attention_weights: (batch_dim, n_ctx, n_ctx) @ V: (batch_dim, n_ctx, n_head) + # Result: (batch_dim, n_ctx, n_head) + output_3d_init = tensor.empty(collapsed_shape_3d, dtype) + output_3d_init_filled = linalg.fill(zero, outs=[output_3d_init]) + + result_3d = linalg.batch_matmul( + attention_weights, V_3d, outs=[output_3d_init_filled] + ) + + # Materialize 3D result back to 3D output memref 
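(bufferization.materialize_in_destination below writes the tensor result into the caller-provided buffer through this collapsed view).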
+ bufferization.materialize_in_destination( + None, result_3d, output_3d_memref, restrict=True, writable=True + ) + + # Emit utility functions for GPU memory management + emit_gpu_util_funcs(dtype, rank=4) + + return mod diff --git a/lighthouse/schedule/xegpu/__init__.py b/lighthouse/schedule/xegpu/__init__.py index 23d9ef0c..76f4101c 100644 --- a/lighthouse/schedule/xegpu/__init__.py +++ b/lighthouse/schedule/xegpu/__init__.py @@ -1,8 +1,10 @@ from .xegpu_to_binary import xegpu_to_binary from .mlp_schedule import mlp_schedule from .softmax_schedule import softmax_schedule +from .fused_attention_schedule import fused_attention_schedule __all__ = [ + "fused_attention_schedule", "mlp_schedule", "softmax_schedule", "xegpu_to_binary", diff --git a/lighthouse/schedule/xegpu/fused_attention_schedule.py b/lighthouse/schedule/xegpu/fused_attention_schedule.py new file mode 100644 index 00000000..cb99e0e9 --- /dev/null +++ b/lighthouse/schedule/xegpu/fused_attention_schedule.py @@ -0,0 +1,477 @@ +"""Generate MLIR transform schedule for XeGPU fused attention operation.""" + +from typing import Optional + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured, loop, xegpu +from mlir.dialects.transform import bufferization as transform_bufferization +from mlir.dialects.bufferization import LayoutMapOption +from mlir.dialects.transform.vector import ( + apply_patterns_vector_cast_away_vector_leading_one_dim, + apply_patterns_vector_drop_unit_dims_with_shape_cast, +) + +from lighthouse.pipeline.helper import ( + canonicalize, + match, + match_and_split, + PipelineInterrupt, + apply_registered_pass, +) +from lighthouse.schedule import schedule_boilerplate +from lighthouse.dialects.transform.transform_ext import ( + generate_fused_attention, + update_address_space, +) + + +def fused_attention_schedule( + stop_at_stage: Optional[str] = None, + parameters: Optional[dict] = None, +) -> ir.Module: + """ + Generate transform schedule for fused attention operation. + + The schedule performs the following transformations: + 1. Tile the fused attention operation + 2. Vectorize operations + 3. Bufferize tensors + 4. Convert to GPU dialect + 5. 
Lower to XeGPU operations
+
+    Args:
+        stop_at_stage: Optional stage name to stop early (for debugging)
+        parameters: Dictionary with scheduling parameters:
+            - batch_size: Batch size (Z)
+            - num_heads: Number of attention heads (H)
+            - n_ctx: Context length
+            - n_head: Head dimension
+            - wg_rows: Number of Q*K^T*V rows computed by each work group
+            - sg_rows: Number of Q*K^T*V rows computed by each subgroup
+            - subgroup_size: Size of subgroup
+            - inner_loop_tile_size: Tile size for the inner reduction dimension
+
+    Returns:
+        MLIR module containing the transform schedule
+    """
+    assert parameters is not None, "Schedule parameters must be provided"
+
+    with schedule_boilerplate() as (schedule, named_seq):
+        # match the payload module
+        anytype = transform.AnyOpType.get()
+        func = match(named_seq.bodyTarget, ops={"func.func"})
+        payload_mod = transform.get_parent_op(
+            anytype,
+            func,
+            op_name="builtin.module",
+            deduplicate=True,
+        )
+
+        try:
+            bundle_xegpu_fused_attention_schedule(
+                payload_mod,
+                parameters=parameters,
+                stop_at_stage=stop_at_stage or "",
+            )
+        except PipelineInterrupt:
+            pass
+        finally:
+            transform.yield_()
+
+    return schedule
+
+
+def bundle_xegpu_fused_attention_schedule(
+    mod: ir.Value[transform.AnyOpType],
+    parameters: dict,
+    stop_at_stage: str = "",
+) -> ir.Value[transform.AnyOpType]:
+    """Schedule for lowering the fused attention payload to XeGPU workgroup level."""
+
+    if stop_at_stage == "initial":
+        raise PipelineInterrupt()
+
+    anytype = transform.AnyOpType.get()
+    # Match all matmul operations - there should be 2:
+    # 1. Q @ K^T
+    # 2. attention_weights @ V
+    matmul_ops = match_and_split(mod, ops={"linalg.batch_matmul"}, nhandles=2)
+
+    # Get the last matmul (attention_weights @ V)
+    last_matmul = matmul_ops[1]
+    func = transform.get_parent_op(
+        anytype,
+        last_matmul,
+        op_name="func.func",
+        deduplicate=True,
+    )
+
+    # Tile the last matmul in both batch and M dimensions.
+    wg_rows = parameters["wg_rows"]
+
+    tiled_matmul, forall_loop = structured.structured_tile_using_forall(
+        anytype,
+        anytype,
+        last_matmul,
+        num_threads=[],
+        tile_sizes=[],
+        static_tile_sizes=(1, wg_rows, 0, 0),
+    )
+    # Fuse the zero initialization of the output of the last matmul (tensor.empty) into the forall loop.
+    tiled_matmul_init = transform.get_producer_of_operand(
+        anytype, forall_loop, operand_number=0
+    )
+    _, forall_loop = structured.structured_fuse_into_containing_op(
+        anytype,
+        anytype,
+        producer_op=tiled_matmul_init,
+        containing_op=forall_loop,
+    )
+    transform.apply_cse(func)
+    canonicalize(func)
+
+    # Decompose softmax into generic ops
+    softmax_ops = match_and_split(func, ops={"linalg.softmax"}, nhandles=1)
+    softmax_op = softmax_ops[0]
+    structured.structured_decompose_interface(anytype, softmax_op)
+    transform.apply_cse(func)
+    canonicalize(func)
+
+    # Fuse all linalg.generic ops from softmax decomposition (4 ops: max, sub+exp, sum, div)
+    # Match and fuse in reverse order (from consumer to producer)
+    generic_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=4)
+    for generic_op in reversed(generic_ops):
+        _, forall_loop = structured.structured_fuse_into_containing_op(
+            anytype,
+            anytype,
+            producer_op=generic_op,
+            containing_op=forall_loop,
+        )
+    transform.apply_cse(func)
+    canonicalize(func)
+
+    # Max and add reductions use linalg.fill to initialize the reduction output. Fuse these fill ops as well.
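+    # (nhandles=5: the two matmul accumulator fills, the scale-tensor fill, and
+    # the two softmax reduction fills.)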
+ fill_ops = match_and_split(func, ops={"linalg.fill"}, nhandles=5) + # Max fill is the third fill op and add fill is the fourth fill op (based on the pattern of decomposition) + max_fill_op = fill_ops[2] + add_fill_op = fill_ops[3] + for fill_op in [max_fill_op, add_fill_op]: + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=fill_op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + linalg_mul_op = match_and_split(func, ops={"linalg.mul"}, nhandles=1)[0] + first_matmul = transform.get_producer_of_operand( + anytype, linalg_mul_op, operand_number=0 + ) + scale_fill_op = transform.get_producer_of_operand( + anytype, linalg_mul_op, operand_number=1 + ) + transpose_op = transform.get_producer_of_operand( + anytype, first_matmul, operand_number=1 + ) + matmul_fill_op = transform.get_producer_of_operand( + anytype, first_matmul, operand_number=2 + ) + for op in [ + linalg_mul_op, + scale_fill_op, + first_matmul, + matmul_fill_op, + transpose_op, + ]: + _, forall_loop = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=op, + containing_op=forall_loop, + ) + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "outer-tiled": + raise PipelineInterrupt() + + # vectorize + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + transform.apply_cse(func) + canonicalize(func) + # Try to remove any unit dimensions that may have been introduced due to tiling (e.g. batch dim of 1) + with ir.InsertionPoint(transform.apply_patterns(func).patterns): + apply_patterns_vector_cast_away_vector_leading_one_dim() + apply_patterns_vector_drop_unit_dims_with_shape_cast() + + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + # bufferize + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + mod = transform_bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + transform.apply_cse(mod) + canonicalize(mod) + # fold memref.subviews into vector.transfer_read/write ops + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + # promote memref.alloc to memref.alloca in payload function + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) + + # Extract q, k, v memrefs from the bufferized IR + # Match vector.contract ops to find the q, k, v loads + for_all = match(mod, ops={"scf.forall"}) + func = transform.get_parent_op(anytype, for_all, op_name="func.func") + contract_ops = match_and_split(func, ops={"vector.contract"}, nhandles=2) + + # First vector.contract is Q @ K^T + # Its first operand is the q load (vector.transfer_read) + # Its second operand is the k load (vector.transfer_read) + first_contract = contract_ops[0] + q_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=0 + ) + k_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=1 + ) + + # # Second vector.contract is attention_weights @ V + # # Its second operand is the v load (vector.transfer_read) + second_contract = contract_ops[1] + v_load = transform.get_producer_of_operand( + anytype, 
second_contract, operand_number=1 + ) + + # Match arith.mulf to get the scale parameter + # The scale is the second operand of arith.mulf (the constant) + mulf_op = match_and_split(func, ops={"arith.mulf"}, nhandles=1)[0] + scale = transform.get_producer_of_operand(anytype, mulf_op, operand_number=1) + + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + + # Generate fused attention computation with inner tiling + # This replaces the second vector.contract (attention_weights @ V) with a tiled + # loop that implements online softmax for efficient memory usage + tile_size = parameters.get( + "inner_loop_tile_size", 64 + ) # Tile size for reduction dimension (K/V sequence length) + generate_fused_attention( + q_load=q_load, + k_load=k_load, + v_load=v_load, + scale=scale, + output=second_contract, + tile_size=tile_size, + ) + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "inner-tiled": + raise PipelineInterrupt() + + # convert forall to parallel + wg_loops = match_and_split(mod, ops={"scf.forall"}) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # set the number of threads for the gpu.launch operation + launch_op = match_and_split(func, ops={"gpu.launch"}) + wg_rows = parameters["wg_rows"] + sg_rows = parameters["sg_rows"] + subgroup_size = parameters["subgroup_size"] + num_subgroups = wg_rows // sg_rows + num_threads = num_subgroups * subgroup_size + xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + + # set xevm target + mod = apply_registered_pass( + mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # for each gpu function in the gpu module, change memref.alloca address + # space to 3 (SLM) and convert vector to xegpu. 
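+    # After convert-vector-to-xegpu, tile loads become xegpu.load_nd and the
+    # contractions become xegpu.dpas; the anchors below attach layouts to them.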
+ gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) + for gpu_mod in gpu_mod_ops: + gpu_func = match(gpu_mod, ops={"gpu.func"}) + allocas = match_and_split(gpu_func, ops={"memref.alloca"}, nhandles=3) + for alloca in allocas: + # print("Updating address space for alloca:") + update_address_space(alloca, address_space=3) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + gpu_func = apply_registered_pass(gpu_func, "loop-invariant-code-motion") + + if stop_at_stage == "xegpu-initial": + raise PipelineInterrupt() + + # Define XeGPU layout parameters + q_sg_layout = [8, 1] + q_sg_data = [16, 64] + q_inst_data = [8, 16] + + k_sg_layout = [8, 1] + k_sg_data = [16, 64] + k_inst_data = [16, 16] + + v_sg_layout = k_sg_layout + v_sg_data = k_sg_data + v_inst_data = k_inst_data + + kt_sg_layout = [1, 8] + kt_sg_data = [64, 16] + kt_inst_data = [16, 16] + kt_order = [0, 1] + + out_sg_layout = q_sg_layout + out_sg_data = q_sg_data + out_inst_data = q_inst_data + + layout_128x16_sg_layout = [8, 1] + layout_128x16_sg_data = [16, 16] + layout_128x16_inst_data = [8, 16] + + qk_sg_layout = layout_128x16_sg_layout + qk_sg_data = layout_128x16_sg_data + qk_inst_data = layout_128x16_inst_data + + # Set layout attributes for xegpu.store_nd ops. + store_nd_op = match_and_split(gpu_func, ops={"xegpu.store_nd"}, nhandles=1)[0] + xegpu.set_anchor_layout( + store_nd_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + ) + + # Set layout for xegpu.load_nd ops (9 total: 1 Q, 4 K, 4 V) + load_nd_ops = match_and_split(gpu_func, ops={"xegpu.load_nd"}, nhandles=9) + + # First load_nd: Q layout + xegpu.set_anchor_layout( + load_nd_ops[0], sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data + ) + + # Next 4 load_nd ops: K layout + for i in range(1, 5): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=k_sg_layout, + sg_data=k_sg_data, + inst_data=k_inst_data, + ) + + # Last 4 load_nd ops: V layout + for i in range(5, 9): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data, + ) + + # Set layout for xegpu.dpas ops (8 total: 4 for Q@K, 4 for P@V) + dpas_ops = match_and_split(gpu_func, ops={"xegpu.dpas"}, nhandles=8) + + # Layouts for first 4 dpas ops (Q@K^T): + for i in range(4): + qk_dpas_op = dpas_ops[i] + # Index 0: Q layout + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=q_sg_layout, + sg_data=q_sg_data, + inst_data=q_inst_data, + index=0, + ) + # Index 1: K^T layout + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=kt_sg_layout, + sg_data=kt_sg_data, + inst_data=kt_inst_data, + order=kt_order, + index=1, + ) + # Index 2: QK output layout (128x16) + xegpu.set_anchor_layout( + qk_dpas_op, + sg_layout=layout_128x16_sg_layout, + sg_data=layout_128x16_sg_data, + inst_data=layout_128x16_inst_data, + index=2, + ) + + # Layouts for second 4 dpas ops (P@V): + for i in range(4, 8): + pv_dpas_op = dpas_ops[i] + # Index 0: QK (attention weights) layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=qk_sg_layout, + sg_data=qk_sg_data, + inst_data=qk_inst_data, + index=0, + ) + # Index 1: V layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data, + index=1, + ) + # Index 2: Output layout + xegpu.set_anchor_layout( + pv_dpas_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + index=2, + ) + + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + + return 
mod
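
For reference, the online-softmax recurrence that `GenerateFusedAttention` unrolls into vector ops corresponds to the following NumPy sketch for a single batch/head. The function name and tiling loop are illustrative only (not part of the library); it mirrors the `m_i`/`l_i`/`acc` state carried by the generated `scf.for`:

```python
import numpy as np


def online_softmax_attention(Q, K, V, tile=64):
    """Illustrative model of the recurrence built by GenerateFusedAttention."""
    n_ctx, d_head = Q.shape
    scale = 1.0 / np.sqrt(d_head)
    m_i = np.full(n_ctx, -np.inf)    # running row-wise max (m_i_init)
    l_i = np.zeros(n_ctx)            # running row-wise exp-sum (l_i_init)
    acc = np.zeros((n_ctx, d_head))  # running output accumulator (acc_init)
    for j in range(0, n_ctx, tile):
        S = Q @ K[j:j + tile].T                        # unscaled Q @ K^T tile
        m_ij = np.maximum(m_i, scale * S.max(axis=1))  # new running max
        P = np.exp(scale * S - m_ij[:, None])          # shifted exponentials
        alpha = np.exp(m_i - m_ij)                     # rescales the old state
        l_i = l_i * alpha + P.sum(axis=1)
        acc = acc * alpha[:, None] + P @ V[j:j + tile]
        m_i = m_ij
    return acc / l_i[:, None]
```

Up to floating-point error this agrees with the dense reference in `check_correctness`; the generated IR additionally splits each tile into 16-column subtiles (`k_subtile_size`) so that the `vector.contract` ops map onto DPAS-shaped operands.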