diff --git a/docs/IREEAttentionLowering.md b/docs/IREEAttentionLowering.md new file mode 100644 index 00000000..2b863bb1 --- /dev/null +++ b/docs/IREEAttentionLowering.md @@ -0,0 +1,335 @@ +# IREE Attention Lowering Pipeline + +## 1. ConvertAttentionToOnlineAttentionPass + +### Overview + +The `ConvertAttentionToOnlineAttentionPass` transforms a standard (offline) attention operation (`iree_linalg_ext.attention`) into an **online attention** operation (`iree_linalg_ext.online_attention`). Online attention computes attention in a **tiled/streaming** fashion, maintaining running max and running sum accumulators to perform numerically stable softmax incrementally — this is the core idea behind **FlashAttention**. + + +### Before the Pass + +A standard `iree_linalg_ext.attention` op with Q, K, V, scale, and output: + +```mlir +%result = iree_linalg_ext.attention { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1)>, // Q: (m, k1) + affine_map<(d0, d1, d2, d3) -> (d2, d1)>, // K: (k2, k1) + affine_map<(d0, d1, d2, d3) -> (d2, d3)>, // V: (k2, n) + affine_map<(d0, d1, d2, d3) -> ()>, // scale + affine_map<(d0, d1, d2, d3) -> (d0, d3)> // output: (m, n) + ]} + ins(%Q, %K, %V, %scale : tensor<16x64xf32>, tensor<4096x64xf32>, + tensor<4096x64xf32>, f32) + outs(%output : tensor<16x64xf32>) + -> tensor<16x64xf32> +``` + +### After the Pass + +The op is converted to `iree_linalg_ext.online_attention` with two additional accumulator outputs — **running max** and **running sum** — initialized to `-inf` and `0` respectively: + +```mlir +// Initialize accumulators +%empty_output = tensor.empty() : tensor<16x64xf32> +%empty_max = tensor.empty() : tensor<16xf32> +%cst_0 = arith.constant 0.000000e+00 : f32 // 0 for output +%cst_neg_inf = arith.constant -3.40282347E+38 : f32 // -inf for max +%cst_zero = arith.constant 0.000000e+00 : f32 // 0 for sum + +%output_acc = linalg.fill ins(%cst_0 : f32) outs(%empty_output) -> tensor<16x64xf32> +%max_init = linalg.fill ins(%cst_neg_inf : f32) outs(%empty_max) -> tensor<16xf32> +%sum_init = linalg.fill ins(%cst_zero : f32) outs(%empty_max) -> tensor<16xf32> + +// Online attention with streaming accumulators +%result:3 = iree_linalg_ext.online_attention { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1)>, // Q: (m, k1) + affine_map<(d0, d1, d2, d3) -> (d2, d1)>, // K: (k2, k1) + affine_map<(d0, d1, d2, d3) -> (d2, d3)>, // V: (k2, n) + affine_map<(d0, d1, d2, d3) -> ()>, // scale + affine_map<(d0, d1, d2, d3) -> (d0, d3)>, // output: (m, n) + affine_map<(d0, d1, d2, d3) -> (d0)>, // running max: (m) + affine_map<(d0, d1, d2, d3) -> (d0)> // running sum: (m) + ]} + ins(%Q, %K, %V, %scale : tensor<16x64xf32>, tensor<4096x64xf32>, + tensor<4096x64xf32>, f32) + outs(%output_acc, %max_init, %sum_init : tensor<16x64xf32>, + tensor<16xf32>, + tensor<16xf32>) { + ^bb0(%arg: f32): + iree_linalg_ext.yield %arg : f32 +} -> tensor<16x64xf32>, tensor<16xf32>, tensor<16xf32> + +// Final normalization: divide accumulated output by the final sum +%final = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0)>, // final sum: (m) + affine_map<(d0, d1) -> (d0, d1)>, // accumulated output: (m, n) + affine_map<(d0, d1) -> (d0, d1)> // normalized output: (m, n) + ], + iterator_types = ["parallel", "parallel"]} + ins(%result#2, %result#0 : tensor<16xf32>, tensor<16x64xf32>) + outs(%empty_output : tensor<16x64xf32>) { + ^bb0(%sum: f32, %acc: f32, %out: f32): + %cst_1 = arith.constant 1.000000e+00 : f32 + %inv_sum = arith.divf %cst_1, %sum : f32 + %normalized = arith.mulf %inv_sum, %acc : f32 + linalg.yield %normalized : f32 +} -> tensor<16x64xf32> +``` + +**Question**: not sure why they do it this way. why not embed the normalization also inside the op? + +### Key Transformations + +| Aspect | Before | After | +|--------|--------|-------| +| **Op** | `iree_linalg_ext.attention` | `iree_linalg_ext.online_attention` | +| **Q shape** | `tensor<16x64xf32>` | `tensor<16x64xf32>` (unchanged) | +| **K shape** | `tensor<4096x64xf32>` | `tensor<4096x64xf32>` (unchanged) | +| **V shape** | `tensor<4096x64xf32>` | `tensor<4096x64xf32>` (unchanged) | +| **Outputs** | 1 (result: `16x64xf32`) | 3 (result: `16x64xf32`, max: `16xf32`, sum: `16xf32`) | +| **Max accumulator** | N/A | Initialized to `-inf` (`-3.40282347E+38`) | +| **Sum accumulator** | N/A | Initialized to `0.0` | +| **Post-processing** | None | Division by final sum (normalization) | +| **Memory** | Materializes full attention matrix | Streams over K/V tiles | + +### Significance in the Pipeline + +This pass is a critical step in IREE's attention lowering pipeline. After this conversion, subsequent passes can **tile the online_attention op along the K2 (key sequence) dimension**, processing chunks of keys/values at a time while maintaining numerically stable softmax via the running max/sum — exactly the FlashAttention algorithm. + +--- + +## 2. DecomposeAttentionPass + +### Overview + +The `DecomposeAttentionPass` (`iree-linalg-ext-decompose-attention`) runs **after** tiling has been applied to the online attention op. It decomposes each tiled `iree_linalg_ext.online_attention` op into a sequence of primitive `linalg.generic` operations that implement the online softmax + attention algorithm explicitly. + +This is the pass that eliminates all custom attention ops and produces standard linalg operations that the rest of the compiler knows how to handle (vectorize, bufferize, map to hardware intrinsics, etc.). + +### Pipeline Context + +By the time `DecomposeAttentionPass` runs, the IR has been through: + +1. `ConvertAttentionToOnlineAttention` — introduced online_attention + max/sum accumulators +2. `TileAndDistributeToWorkgroups` — tiled across batch and query-sequence dims +3. `GPUApplyTilingLevel` (multiple times) — tiled the K2 (key-sequence) reduction dimension into chunks (e.g., tiles of 64 or 128) + +So the input to this pass is a **tiled** online_attention operating on a slice of K/V. + +### Before the Pass (Tiled Online Attention) + +After tiling, the online attention operates on a K2-tile (e.g., 16 keys at a time). This example shows a 16x64 Q-tile processing a 16x64 K-tile and V-tile: + +```mlir +// Inside an scf.for loop over K2 tiles: +// Step size is 16. +%results:3 = iree_linalg_ext.online_attention { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1)>, // Q tile: (m, k1) + affine_map<(d0, d1, d2, d3) -> (d2, d1)>, // K tile: (k2, k1) + affine_map<(d0, d1, d2, d3) -> (d2, d3)>, // V tile: (k2, n) + affine_map<(d0, d1, d2, d3) -> ()>, // scale + affine_map<(d0, d1, d2, d3) -> (d0, d3)>, // acc output: (m, n) + affine_map<(d0, d1, d2, d3) -> (d0)>, // running max: (m) + affine_map<(d0, d1, d2, d3) -> (d0)> // running sum: (m) + ]} + ins(%q_tile, %k_tile, %v_tile, %scale : tensor<16x64xf32>, + tensor<16x64xf32>, + tensor<16x64xf32>, f32) + outs(%acc, %old_max, %old_sum : tensor<16x64xf32>, + tensor<16xf32>, + tensor<16xf32>) + -> tensor<16x64xf32>, tensor<16xf32>, tensor<16xf32> +``` + +### After the Pass (Decomposed to linalg.generic) + +The pass decomposes the single online_attention op into **5 steps**: + +#### Step 1: Compute S = Q @ K^T * scale (matmul + scale) + +```mlir +// S[m, k2] = sum_k1(Q[m, k1] * K[k2, k1]) * scale +%empty_S = tensor.empty() : tensor<16x16xf32> +%zero_S = linalg.fill ins(%cst_0) outs(%empty_S) +%S = linalg.generic { + indexing_maps = [ + affine_map<(m, k2, k1) -> (m, k1)>, // Q + affine_map<(m, k2, k1) -> (k2, k1)>, // K + affine_map<(m, k2, k1) -> ()>, // scale + affine_map<(m, k2, k1) -> (m, k2)> // S (output) + ], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%q_tile, %k_tile, %scale : ...) + outs(%zero_S : tensor<16x16xf32>) { + ^bb0(%q: f32, %k: f32, %s: f32, %out: f32): + %mul = arith.mulf %q, %k : f32 + %scaled = arith.mulf %mul, %s : f32 + %add = arith.addf %scaled, %out : f32 + linalg.yield %add : f32 +} -> tensor<16x16xf32> +``` + +#### Step 2: Compute new_max = max(old_max, rowmax(S)) + +```mlir +// Row-wise max of S, then element-wise max with old_max +%new_max = linalg.generic { + indexing_maps = [ + affine_map<(m, k2) -> (m, k2)>, // S + affine_map<(m, k2) -> (m)> // max accumulator + ], + iterator_types = ["parallel", "reduction"]} + ins(%S : tensor<16x16xf32>) + outs(%old_max : tensor<16xf32>) { + ^bb0(%s_val: f32, %cur_max: f32): + %m = arith.maximumf %s_val, %cur_max : f32 + linalg.yield %m : f32 +} -> tensor<16xf32> +``` + +#### Step 3: Compute P = exp(S - new_max) and correction factor alpha = exp(old_max - new_max) + +```mlir +// Subtract new_max from S and exponentiate: P[m, k2] = exp(S[m, k2] - new_max[m]) +%P = linalg.generic { + indexing_maps = [ + affine_map<(m, k2) -> (m, k2)>, // S + affine_map<(m, k2) -> (m)>, // new_max + affine_map<(m, k2) -> (m, k2)> // P (output) + ], + iterator_types = ["parallel", "parallel"]} + ins(%S, %new_max : ...) + outs(%empty_S : tensor<16x16xf32>) { + ^bb0(%s_val: f32, %max_val: f32, %out: f32): + %sub = arith.subf %s_val, %max_val : f32 + %exp = math.exp %sub : f32 + linalg.yield %exp : f32 +} -> tensor<16x16xf32> + +// Correction factor: alpha[m] = exp(old_max[m] - new_max[m]) +%alpha = linalg.generic { + indexing_maps = [ + affine_map<(m) -> (m)>, // old_max + affine_map<(m) -> (m)>, // new_max + affine_map<(m) -> (m)> // alpha + ], + iterator_types = ["parallel"]} + ins(%old_max, %new_max : ...) + outs(%empty_alpha : tensor<16xf32>) { + ^bb0(%old_m: f32, %new_m: f32, %out: f32): + %sub = arith.subf %old_m, %new_m : f32 + %exp = math.exp %sub : f32 + linalg.yield %exp : f32 +} -> tensor<16xf32> +``` + +#### Step 4: Update sum = alpha * old_sum + rowsum(P) + +```mlir +// Scale old sum by correction factor, then add row sums of P +// new_sum[m] = alpha[m] * old_sum[m] + sum_k2(P[m, k2]) +%scaled_sum = linalg.generic { + indexing_maps = [ + affine_map<(m) -> (m)>, // old_sum + affine_map<(m) -> (m)>, // alpha + affine_map<(m) -> (m)> // output + ], + iterator_types = ["parallel"]} + ins(%old_sum, %alpha : ...) + outs(%empty_sum : tensor<16xf32>) { + ^bb0(%s: f32, %a: f32, %out: f32): + %mul = arith.mulf %s, %a : f32 + linalg.yield %mul : f32 +} -> tensor<16xf32> + +%new_sum = linalg.generic { + indexing_maps = [ + affine_map<(m, k2) -> (m, k2)>, // P + affine_map<(m, k2) -> (m)> // sum accumulator + ], + iterator_types = ["parallel", "reduction"]} + ins(%P : tensor<16x16xf32>) + outs(%scaled_sum : tensor<16xf32>) { + ^bb0(%p_val: f32, %cur_sum: f32): + %add = arith.addf %p_val, %cur_sum : f32 + linalg.yield %add : f32 +} -> tensor<16xf32> +``` + +#### Step 5: Update output = alpha * old_acc + P @ V + +```mlir +// Scale old accumulator by alpha: corrected_acc[m, n] = alpha[m] * old_acc[m, n] +%corrected_acc = linalg.generic { + indexing_maps = [ + affine_map<(m, n) -> (m)>, // alpha + affine_map<(m, n) -> (m, n)>, // old_acc + affine_map<(m, n) -> (m, n)> // output + ], + iterator_types = ["parallel", "parallel"]} + ins(%alpha, %old_acc : ...) + outs(%empty_acc : tensor<16x64xf32>) { + ^bb0(%a: f32, %acc: f32, %out: f32): + %mul = arith.mulf %a, %acc : f32 + linalg.yield %mul : f32 +} -> tensor<16x64xf32> + +// new_acc[m, n] = corrected_acc[m, n] + sum_k2(P[m, k2] * V[k2, n]) +%new_acc = linalg.generic { + indexing_maps = [ + affine_map<(m, n, k2) -> (m, k2)>, // P + affine_map<(m, n, k2) -> (k2, n)>, // V + affine_map<(m, n, k2) -> (m, n)> // acc (output) + ], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%P, %v_tile : ...) + outs(%corrected_acc : tensor<16x64xf32>) { + ^bb0(%p_val: f32, %v_val: f32, %acc: f32): + %mul = arith.mulf %p_val, %v_val : f32 + %add = arith.addf %mul, %acc : f32 + linalg.yield %add : f32 +} -> tensor<16x64xf32> +``` + +### Summary of Decomposition + +The online attention op is decomposed into these primitive operations: + +``` +┌─────────────────────────────────────────────────┐ +│ iree_linalg_ext.online_attention (1 tiled op) │ +└────────────────────┬────────────────────────────┘ + │ DecomposeAttentionPass + ▼ +┌─────────────────────────────────────────────────┐ +│ 1. S = Q @ K^T * scale (linalg.generic) │ +│ 2. new_max = max(old_max, rowmax(S)) (generic) │ +│ 3. P = exp(S - new_max) (generic) │ +│ alpha = exp(old_max - new_max) (generic) │ +│ 4. new_sum = alpha*old_sum + rowsum(P) (generic) │ +│ 5. new_acc = alpha*old_acc + P @ V (generic) │ +└─────────────────────────────────────────────────┘ +``` + +| Step | Operation | Type | Dims | +|------|-----------|------|------| +| 1 | `S = Q @ K^T * scale` | Matmul + scale | `[16, 16]` ← `[16, 64] × [16, 64]` | +| 2 | `new_max = max(old_max, rowmax(S))` | Row reduction | `[16]` ← `[16, 16]` | +| 3a | `P = exp(S - new_max)` | Elementwise | `[16, 16]` | +| 3b | `alpha = exp(old_max - new_max)` | Elementwise | `[16]` | +| 4 | `new_sum = alpha * old_sum + Σ P` | Scale + row reduction | `[16]` | +| 5 | `new_acc = alpha * old_acc + P @ V` | Scale + matmul | `[16, 64]` ← `[16, 16] × [16, 64]` | + +### Why This Matters + +After decomposition, all ops are standard `linalg.generic` operations. This enables: + +- **Vectorization** via IREE's vector distribution pipeline +- **Mapping to MMA intrinsics** (e.g., MFMA on MI300X) for the two matmuls (Steps 1 and 5) +- **Register-level tiling** and shared memory promotion for GPU targets +- The `scf.for` loop around these ops implements the streaming/online iteration over K/V chunks diff --git a/docs/softmax_lowering.md b/docs/softmax_lowering.md new file mode 100644 index 00000000..d15d2e85 --- /dev/null +++ b/docs/softmax_lowering.md @@ -0,0 +1,683 @@ +# Linalg softmax lowering to XeGPU (Currently supported in lighthouse) + +## Overview + +**Assumptions:** +Softmax dimension size is small (64 in this example). No tiling in reduction dim. + +The lowering process consists of seven stages: +1. **initial** - High-level tensor operations +2. **tiled-softmax** - Tiled softmax operations +3. **decomposed** - Decomposition into generic operations +4. **vectorized** - Vector operations +5. **bufferized** - Memory-based representation +6. **xegpu-initial** - GPU kernel with XeGPU operations +7. **xegpu-wg** - Work-group optimized XeGPU + +--- + +## Stage 1: Initial + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + %2 = tensor.empty() : tensor<1024x64xf32> + %3 = linalg.softmax dimension(1) ins(%1 : tensor<1024x64xf32>) + outs(%2 : tensor<1024x64xf32>) -> tensor<1024x64xf32> + // ... + return +} +``` +--- + +## Stage 2: Tiled Softmax + +**Notes** +- Work distribution via `scf.forall` (16 parallel iterations) +- Each tile processes 64x64 elements + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { + %4 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2) + // Extract 64x64 input slice + %extracted_slice = tensor.extract_slice ... + // Extract 64x64 output slice + %extracted_slice_0 = tensor.extract_slice ... + // Apply softmax to the tile + %5 = linalg.softmax dimension(1) ins(%extracted_slice : tensor<64x64xf32>) + outs(%extracted_slice_0 : tensor<64x64xf32>) -> tensor<64x64xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %5 into %arg3[%4, %c0] [64, 64] [1, 1] : + tensor<64x64xf32> into tensor<1024x64xf32> + } + } + // ... + return +} +``` + +--- + +## Stage 3: Decomposed + +**Notes** +- Softmax decomposed into 4 `linalg.generic` ops : max, sub+exp, sum, divide +- Uses `structured.structured_decompose_interface` implemented by `linalg.softmax` + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + + %2 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %1) -> (tensor<1024x64xf32>) { + %3 = affine.apply #map(%arg2) // %3 = %arg2 * 64 + %extracted_slice = tensor.extract_slice ... + + // Step 1: Find max along dimension 1 + %4 = tensor.empty() : tensor<64xf32> + %5 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<64xf32>) -> tensor<64xf32> + %6 = linalg.generic // ... + %11 = arith.maxnumf %in, %out : f32 + // ... + } -> tensor<64xf32> + + // Step 2: Subtract max and exponentiate + %7 = linalg.generic // ... + %11 = arith.subf %in, %in_2 : f32 + %12 = math.exp %11 : f32 + // ... + } -> tensor<64x64xf32> + + // Step 3: Sum exponentials + %8 = linalg.fill ins(%cst : f32) outs(%4 : tensor<64xf32>) -> tensor<64xf32> + %9 = linalg.generic // ... + %11 = arith.addf %in, %out : f32 + // ... + } -> tensor<64xf32> + + // Step 4: Normalize by sum + %10 = linalg.generic // ... + %11 = arith.divf %in, %in_2 : f32 + // ... + } -> tensor<64x64xf32> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %10 into %arg3[%3, 0] [64, 64] [1, 1] : + tensor<64x64xf32> into tensor<1024x64xf32> + } + } + return +} +``` + +--- + +## Stage 4: Vectorized + +**Notes** +- `linalg.generic` operations replaced with vector operations +- Vector transfers for reading/writing data + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { + %4 = affine.apply #map(%arg2) // %4 = %arg2 * 64 + %extracted_slice = tensor.extract_slice .. + + // Vector read: Load 64x64 tile + %5 = vector.transfer_read %1[%4, %c0], %0 {in_bounds = [true, true]} : + tensor<1024x64xf32>, vector<64x64xf32> + + // Max reduction: Reduce dimension 1 -> vector<64xf32> + %6 = vector.multi_reduction , %5, %cst_0 [1] : + vector<64x64xf32> to vector<64xf32> + + // Broadcast max values back to 64x64 and transpose + %7 = vector.broadcast %6 : vector<64xf32> to vector<64x64xf32> + %8 = vector.transpose %7, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + + // Subtract max and exponentiate + %9 = arith.subf %5, %8 : vector<64x64xf32> + %10 = math.exp %9 : vector<64x64xf32> + + // Sum reduction: Reduce dimension 1 -> vector<64xf32> + %11 = vector.multi_reduction , %10, %cst [1] : + vector<64x64xf32> to vector<64xf32> + + // Broadcast sums back to 64x64 and transpose + %12 = vector.broadcast %11 : vector<64xf32> to vector<64x64xf32> + %13 = vector.transpose %12, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + + // Normalize + %14 = arith.divf %10, %13 : vector<64x64xf32> + + // Vector write + %15 = vector.transfer_write %14, %extracted_slice[%c0, %c0] {in_bounds = [true, true]} : + vector<64x64xf32>, tensor<64x64xf32> + + scf.forall.in_parallel { + tensor.parallel_insert_slice %15 into %arg3[%4, 0] [64, 64] [1, 1] + } + } + return +} +``` +--- + +## Stage 5: Bufferized + +**Notes** +- Tensors eliminated, working directly with memrefs + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + + scf.forall (%arg2) in (16) { + %1 = affine.apply #map(%arg2) // %1 = %arg2 * 64 + + // Direct memref read + %2 = vector.transfer_read %arg1[%1, %c0], %0 {in_bounds = [true, true]} : + memref<1024x64xf32>, vector<64x64xf32> + + // Max reduction + %3 = vector.multi_reduction , %2, %cst_0 [1] : + vector<64x64xf32> to vector<64xf32> + %4 = vector.broadcast %3 : vector<64xf32> to vector<64x64xf32> + %5 = vector.transpose %4, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + + // Subtract and exp + %6 = arith.subf %2, %5 : vector<64x64xf32> + %7 = math.exp %6 : vector<64x64xf32> + + // Sum reduction + %8 = vector.multi_reduction , %7, %cst [1] : + vector<64x64xf32> to vector<64xf32> + %9 = vector.broadcast %8 : vector<64xf32> to vector<64x64xf32> + %10 = vector.transpose %9, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + + // Normalize + %11 = arith.divf %7, %10 : vector<64x64xf32> + + // Direct memref write + vector.transfer_write %11, %arg0[%1, %c0] {in_bounds = [true, true]} : + vector<64x64xf32>, memref<1024x64xf32> + } + return +} +``` + +--- + +## Stage 6: XeGPU-Initial + +**Notes** +- GPU kernel separated from host code (Gpu Outlining) +- `gpu.launch_func` invocation with grid/block dimensions +- Use `vector-to-xegpu` + +**Code:** + +**Host Side:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + gpu.launch_func @payload_kernel::@payload_kernel + blocks in (%c16, %c1, %c1) + threads in (%c128, %c1, %c1) + args(%arg1 : memref<1024x64xf32>, %arg0 : memref<1024x64xf32>) + return +} +``` + +**GPU Kernel:** +```mlir +gpu.module @payload_kernel [#xevm.target] { + gpu.func @payload_kernel(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) kernel + attributes {known_block_size = array, + known_grid_size = array} { + // ... + %block_id_x = gpu.block_id x + %0 = arith.muli %block_id_x, %c64 overflow : index + + // Create XeGPU tensor descriptor for load + %1 = xegpu.create_nd_tdesc %arg0 : memref<1024x64xf32> -> + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> + + // XeGPU block load + %2 = xegpu.load_nd %1[%0, 0] : + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> -> + vector<64x64xf32> + + // Same compute operations as before + %3 = vector.multi_reduction , %2, %cst_0 [1] : + vector<64x64xf32> to vector<64xf32> + %4 = vector.broadcast %3 : vector<64xf32> to vector<64x64xf32> + %5 = vector.transpose %4, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + %6 = arith.subf %2, %5 : vector<64x64xf32> + %7 = math.exp %6 : vector<64x64xf32> + %8 = vector.multi_reduction , %7, %cst [1] : + vector<64x64xf32> to vector<64xf32> + %9 = vector.broadcast %8 : vector<64xf32> to vector<64x64xf32> + %10 = vector.transpose %9, [1, 0] : vector<64x64xf32> to vector<64x64xf32> + %11 = arith.divf %7, %10 : vector<64x64xf32> + + // Create XeGPU tensor descriptor for store + %12 = xegpu.create_nd_tdesc %arg1 : memref<1024x64xf32> -> + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> + + // XeGPU block store + xegpu.store_nd %11, %12[%0, 0] : + vector<64x64xf32>, + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> + + gpu.return + } +} +``` + +--- + +## Stage 7: XeGPU-WG (Work-Group Optimized) + +**Notes** +- Sets the layout for anchor xegpu ops. Each Wg consistes of [8, 1] subgroups + doing 8x64 softmax slice. +- Only sets the layout for `store_nd`. Layout propagation does the rest. + +**Code (differences from xegpu-initial):** +```mlir +// Store operation now includes layout hints +xegpu.store_nd %11, %12[%0, 0] + <{layout = #xegpu.layout}> : + vector<64x64xf32>, + !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr> +``` + +--- +## What is reused form the lighthouse project? + +The softmax implementation leverages several reusable lighthouse infrastructure components: + +The workload-specific code is minimal (payload generation + tiling/decomposition strategy). Most infrastructure (execution, benchmarking, XeGPU lowering) is shared across different operations like matmul, softmax, etc. + +### Pipeline Infrastructure (`lighthouse.pipeline`) +- **`TransformDriver`**: Orchestrates application of transform schedules to payload modules +- **`apply_registered_pass`**: Applies named MLIR passes (e.g., `eliminate-empty-tensors`, `gpu-kernel-outlining`) +- **`canonicalize`, `match`, `match_and_split`, `PipelineInterrupt`**: Helper utilities for constructing transform sequences + +### Execution Infrastructure (`lighthouse.execution`) +- **`execute`**: Runs compiled MLIR modules on GPU with memory management +- **`benchmark`**: Benchmarks kernel execution with warmup and timing utilities +- **`GPUMemoryManager`**: Manages host-device memory transfers for GPU execution +- **`get_bench_wrapper_schedule`**: Wraps payload functions with benchmarking infrastructure + +### XeGPU Lowering (`lighthouse.schedule.xegpu.helper`) +- **`bundle_xegpu_to_binary`**: Common lowering path from XeGPU to executable binary (shared with matmul and other XeGPU workloads) + - Handles XeGPU peephole optimizations, layout propagation, and GPU-to-SPIRV/binary compilation + +### Payload Generation (`lighthouse.ingress.mlir_gen`) +- **`get_mlir_elem_type`**: Type conversion utilities for constructing MLIR types + + +--- + +# Supporting larger Softmax dimension sizes + +Previsouly we tiled all ops in the parallel dimension only (i.e. non softmax dim). Handling a larger softmax reduction dimension require tiling the softmax contituent ops in the reduction dimension. + +**Approach:** Tile reductions along dimension 1 (using appropriate step size) and fuse producers into consumers to avoid extra memory buffers. + +--- + +## Stage A - Tile div op + +**Notes:** +- Tile the division operation with step size 16 along dimension 1 +- Creates `scf.for` loop iterating over 64 elements in chunks of 16 + +**Key Changes:** +```mlir +// Before: Single division linalg.generic over 64x64 +scf.forall ... { + // Max, Center+Exp, Sum ops ... + %11 = linalg.generic {...} ins(%8, %10 : tensor<64x64xf32>, tensor<64xf32>) outs(%extracted_slice_0 : tensor<64x64xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %12 = arith.divf %in, %in_2 : f32 + linalg.yield %12 : f32 + } -> tensor<64x64xf32> +} + +// After: Division tiled into 64x16 chunks +scf.forall ... { + // Max, Center+Exp, Sum ops ... + %11 = scf.for %arg4 = %c0_2 to %c64 step %c16 iter_args(%arg5 = %extracted_slice_0) -> (tensor<64x64xf32>) { + %12 = linalg.generic {...} ins(%extracted_slice_3, %extracted_slice_4 : tensor<64x16xf32>, tensor<64xf32>) outs(%extracted_slice_5 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_6: f32, %out: f32): + %13 = arith.divf %in, %in_6 : f32 + linalg.yield %13 : f32 + } -> tensor<64x16xf32> + %inserted_slice = tensor.insert_slice %12 into %arg5[0, %arg4] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x64xf32> + scf.yield %inserted_slice : tensor<64x64xf32> + } +} +``` + +--- + +## Stage B - Fuse sub+exp into div loop + +**Notes:** +- Fuse the `sub+exp` producer (max_center_and_exp_op) into the div loop +- Recomputes exp values on-the-fly instead of materializing full 64x64 tensor + +**Key Changes:** +```mlir +// Original tiled div loop. +%11 = scf.for %arg4 = %c0_2 to %c64 step %c16 iter_args(%arg5 = %extracted_slice_0) -> (tensor<64x64xf32>) { + %extracted_slice_3 = tensor.extract_slice %extracted_slice[0, %arg4] [64, 16] [1, 1] + + // Fused recomputation: sub+exp computed per 16-element chunk + %12 = linalg.generic {...} ins(%extracted_slice_3, %extracted_slice_4 : tensor<64x16xf32>, tensor<64xf32>) + outs(%extracted_slice_5 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_8: f32, %out: f32): + %14 = arith.subf %in, %in_8 : f32 + %15 = math.exp %14 : f32 + linalg.yield %15 : f32 + } -> tensor<64x16xf32> + + // Division operation + %13 = linalg.generic {...} ins(%12, %extracted_slice_6 : tensor<64x16xf32>, tensor<64xf32>) + outs(%extracted_slice_7 : tensor<64x16xf32>) { ... } -> tensor<64x16xf32> + // ... +} +``` + +--- + +## Stage C - Tile sum reduction + +**Notes:** +- Tile the sum reduction using `structured_tile_reduction_using_for` +- Creates intermediate accumulator tensor (64x16) +- Final reduction via `linalg.reduce` over dimension 1 + +**Key Changes:** +```mlir +// Tiled sum reduction with intermediate accumulator +%10 = tensor.empty() : tensor<64x16xf32> +%11 = linalg.fill ins(%cst_2 : f32) outs(%10 : tensor<64x16xf32>) -> tensor<64x16xf32> + +%12 = scf.for %arg4 = %c0_3 to %c64 step %c16 iter_args(%arg5 = %11) -> (tensor<64x16xf32>) { + %extracted_slice_7 = tensor.extract_slice %8[0, %arg4] [64, 16] [1, 1] + %14 = linalg.generic {...} ins(%extracted_slice_7 : tensor<64x16xf32>) + outs(%extracted_slice_8 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %15 = arith.addf %in, %out : f32 + linalg.yield %15 : f32 + } -> tensor<64x16xf32> + // ... +} + +// Final reduction to 64xf32 +%reduced = linalg.reduce ins(%12 : tensor<64x16xf32>) outs(%9 : tensor<64xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %14 = arith.addf %in, %init : f32 + linalg.yield %14 : f32 + } +``` + +--- + +## Stage D - Fuse sub+exp into sum loop + +**Notes:** +- Fuse `sub+exp` into the sum reduction loop +- Stream computation: compute exp and accumulate in same loop + +**Key Changes:** +```mlir +// Original tiled mun loop. +%12 = scf.for %arg4 = %c0_3 to %c64 step %c16 iter_args(%arg5 = %11) -> (tensor<64x16xf32>) { + %extracted_slice_7 = tensor.extract_slice %extracted_slice[0, %arg4] [64, 16] [1, 1] + + // Fused recomputation: sub+exp + %14 = linalg.generic {...} ins(%extracted_slice_7, %extracted_slice_8 : tensor<64x16xf32>, tensor<64xf32>) + outs(%extracted_slice_9 : tensor<64x16xf32>) { + ^bb0(%in: f32, %in_11: f32, %out: f32): + %16 = arith.subf %in, %in_11 : f32 + %17 = math.exp %16 : f32 + linalg.yield %17 : f32 + } -> tensor<64x16xf32> + + // Accumulate sum + %15 = linalg.generic {...} ins(%14 : tensor<64x16xf32>) + outs(%extracted_slice_10 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %16 = arith.addf %in, %out : f32 + linalg.yield %16 : f32 + } -> tensor<64x16xf32> + // ... +} +``` + +--- + +## Stage E - Tile max reduction + +**Notes:** +- Tile max reduction similar to sum reduction +- Creates 64x16 intermediate accumulator +- Final reduction via `linalg.reduce` with maxnumf + +**Key Changes:** +```mlir +// Tiled max reduction +%7 = tensor.empty() : tensor<64x16xf32> +%8 = linalg.fill ins(%cst_1 : f32) outs(%7 : tensor<64x16xf32>) -> tensor<64x16xf32> + +%9 = scf.for %arg4 = %c0_2 to %c64 step %c16 iter_args(%arg5 = %8) -> (tensor<64x16xf32>) { + %extracted_slice_12 = tensor.extract_slice %extracted_slice[0, %arg4] [64, 16] [1, 1] + %16 = linalg.generic {...} ins(%extracted_slice_12 : tensor<64x16xf32>) + outs(%extracted_slice_13 : tensor<64x16xf32>) { + ^bb0(%in: f32, %out: f32): + %17 = arith.maxnumf %in, %out : f32 + linalg.yield %17 : f32 + } -> tensor<64x16xf32> + // ... +} + +// Final max reduction +%reduced = linalg.reduce ins(%9 : tensor<64x16xf32>) outs(%6 : tensor<64xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %16 = arith.maxnumf %in, %init : f32 + linalg.yield %16 : f32 + } +``` + +**Result:** Now all three major computations (max, sum, div) are tiled and operate on 64x16 chunks, with exp computation fused into both sum and div loops. + +--- + +## Stage F - Vectorization + +**Notes:** +- Convert tiled linalg operations to vector operations +- `scf.for` loops remain but operate on vectors +- Vector size: 64x16 for tiled operations +- Same buffer is used for max and sum streaming reductions. + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + %cst = arith.constant dense<0.000000e+00> : vector<64x16xf32> // 0.0 + %cst_1 = arith.constant dense<0xFFC00000> : vector<64x16xf32> // -inf + %3 = scf.forall (%arg2) in (16) shared_outs(%arg3 = %2) -> (tensor<1024x64xf32>) { + // ... + + // Vectorized max reduction loop + %5 = tensor.empty() : tensor<64x16xf32> + %6 = vector.transfer_write %cst_1, %5[%c0, %c0] : vector<64x16xf32>, tensor<64x16xf32> + %7 = scf.for %arg4 = %c0 to %c64 step %c16 iter_args(%arg5 = %6) -> (tensor<64x16xf32>) { + %15 = vector.transfer_read %1[%4, %arg4], %0 : tensor<1024x64xf32>, vector<64x16xf32> + %16 = vector.transfer_read %arg5[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> + %17 = arith.maxnumf %15, %16 : vector<64x16xf32> + %18 = vector.transfer_write %17, %arg5[%c0, %c0] : vector<64x16xf32>, tensor<64x16xf32> + scf.yield %18 : tensor<64x16xf32> + } + %8 = vector.transfer_read %7[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> + %9 = vector.multi_reduction , %8, %cst_2 [1] : vector<64x16xf32> to vector<64xf32> + %10 = vector.transfer_write %cst, %5[%c0, %c0] {in_bounds = [true, true]} : vector<64x16xf32>, tensor<64x16xf32> + // Vectorized sum reduction loop with fused sub+exp + %11 = scf.for %arg4 = %c0 to %c64 step %c16 iter_args(%arg5 = %10) -> (tensor<64x16xf32>) { + %15 = vector.transfer_read %1[%4, %arg4], %0 : tensor<1024x64xf32>, vector<64x16xf32> + %16 = vector.broadcast %9 : vector<64xf32> to vector<16x64xf32> + %17 = vector.transpose %16, [1, 0] : vector<16x64xf32> to vector<64x16xf32> + %18 = arith.subf %15, %17 : vector<64x16xf32> + %19 = math.exp %18 : vector<64x16xf32> + %20 = vector.transfer_read %arg5[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> + %21 = arith.addf %19, %20 : vector<64x16xf32> + %22 = vector.transfer_write %21, %arg5[%c0, %c0] : vector<64x16xf32>, tensor<64x16xf32> + scf.yield %22 : tensor<64x16xf32> + } + %12 = vector.transfer_read %11[%c0, %c0], %0 : tensor<64x16xf32>, vector<64x16xf32> + %13 = vector.multi_reduction , %12, %cst_0 [1] : vector<64x16xf32> to vector<64xf32> + + // Vectorized div loop with fused sub+exp + %14 = scf.for %arg4 = %c0 to %c64 step %c16 iter_args(%arg5 = %extracted_slice) -> (tensor<64x64xf32>) { + %15 = vector.transfer_read %1[%4, %arg4], %0 : tensor<1024x64xf32>, vector<64x16xf32> + %16 = vector.broadcast %9 : vector<64xf32> to vector<16x64xf32> + %17 = vector.transpose %16, [1, 0] : vector<16x64xf32> to vector<64x16xf32> + %18 = arith.subf %15, %17 : vector<64x16xf32> + %19 = math.exp %18 : vector<64x16xf32> + %20 = vector.broadcast %13 : vector<64xf32> to vector<16x64xf32> + %21 = vector.transpose %20, [1, 0] : vector<16x64xf32> to vector<64x16xf32> + %22 = arith.divf %19, %21 : vector<64x16xf32> + %23 = vector.transfer_write %22, %arg5[%c0, %arg4] : vector<64x16xf32>, tensor<64x64xf32> + scf.yield %23 : tensor<64x64xf32> + } + } + // ... +} +``` + +--- + +## Stage G - Bufferization + +**Notes:** +- Convert tensors to memrefs +- Allocate stack buffer for 64x16 accumulator: `memref.alloc()` + +**Code:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + // ... + scf.forall (%arg2) in (16) { + %1 = affine.apply #map(%arg2) + %subview = memref.subview %arg0[%1, 0] [64, 64] [1, 1] + + // Allocate accumulator buffer + %alloc = memref.alloc() {alignment = 64 : i64} : memref<64x16xf32> + + // Max reduction loop + vector.transfer_write %cst_1, %alloc[%c0, %c0] : vector<64x16xf32>, memref<64x16xf32> + scf.for %arg3 = %c0 to %c64 step %c16 { + %6 = vector.transfer_read %arg1[%1, %arg3], %0 : memref<1024x64xf32>, vector<64x16xf32> + %7 = vector.transfer_read %alloc[%c0, %c0], %0 : memref<64x16xf32>, vector<64x16xf32> + %8 = arith.maxnumf %6, %7 : vector<64x16xf32> + vector.transfer_write %8, %alloc[%c0, %c0] : vector<64x16xf32>, memref<64x16xf32> + } + %2 = vector.transfer_read %alloc[%c0, %c0], %0 : memref<64x16xf32>, vector<64x16xf32> + %3 = vector.multi_reduction , %2, %cst_2 [1] : vector<64x16xf32> to vector<64xf32> + + // Sum reduction loop (reuses %alloc) + // ... + + // Div loop (writes to %subview) + // ... + } +} +``` + +--- + +## Stage H - Promote buffers to stack + +**Notes:** +- Convert `memref.alloc()` to `memref.alloca()` for stack allocation +- **Why?** We can only allocate SLM inside GPU kernel. + +**Code:** +```mlir +scf.forall (%arg2) in (16) { + %1 = affine.apply #map(%arg2) + %subview = memref.subview %arg0[%1, 0] [64, 64] [1, 1] + + // Stack allocation instead of heap + %alloca = memref.alloca() {alignment = 64 : i64} : memref<64x16xf32> + + // ... same operations using %alloca ... +} +``` + +--- + +## Stage I - GPU outlining + +**Notes:** +- Convert `scf.forall` to `scf.parallel`, then to `gpu.launch` +- Extract GPU kernel into separate `gpu.module` +- Set thread count: 128 threads = (64 rows / 8 sg_rows) × 16 subgroup_size + +**Host Side:** +```mlir +func.func @payload(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) { + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + gpu.launch_func @payload_kernel::@payload_kernel + blocks in (%c16, %c1, %c1) + threads in (%c128, %c1, %c1) + args(%arg0 : memref<1024x64xf32>, %arg1 : memref<1024x64xf32>) + return +} +``` + +**GPU Kernel:** +```mlir +gpu.module @payload_kernel { + gpu.func @payload_kernel(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) kernel + attributes {known_block_size = array, + known_grid_size = array} { + %block_id_x = gpu.block_id x + %1 = arith.muli %block_id_x, %c64 overflow : index + %subview = memref.subview %arg0[%1, 0] [64, 64] [1, 1] + %alloca = memref.alloca() {alignment = 64 : i64} : memref<64x16xf32> + + // Three reduction loops (max, sum, div) with same structure + scf.for %arg2 = %c0 to %c64 step %c16 { + // Max: accumulate max values + // Sum: compute & accumulate exp(x - max) + // Div: compute exp(x - max) / sum + } + + gpu.return + } +} +``` + +**Summary:** At this stage, the kernel processes 64x16 chunks in streaming fashion through three sequential loops, minimizing memory footprint. + +## XeGPU improvements needed for e2e support (WIP) + +* Change `memeref.alloca()` address space to SLM. +* Support for lowering load/store using SLM to `xegpu.load/store_matrix` instead of `xegpu.load/store_nd`