Draft
73 commits
eecee07
save work
charithaintc Mar 13, 2026
9ff9dbc
Merge branch 'main' into softmax_impl
charithaintc Mar 16, 2026
f991027
save work
charithaintc Mar 16, 2026
22415bb
save work
charithaintc Mar 16, 2026
cb9ead1
save work
charithaintc Mar 17, 2026
ac39be3
save work
charithaintc Mar 18, 2026
51d494e
save work
charithaintc Mar 20, 2026
7ac8852
save work
charithaintc Mar 20, 2026
fa2993d
Merge branch 'main' into softmax_impl
charithaintc Mar 20, 2026
d65bf9f
save work
charithaintc Mar 20, 2026
0bf3eb3
save working version
charithaintc Mar 24, 2026
fabd656
save working version
charithaintc Mar 24, 2026
1e63d7d
save working version
charithaintc Mar 24, 2026
64b5d73
save working version
charithaintc Mar 24, 2026
108f2c0
save working version
charithaintc Mar 25, 2026
a7e1e6c
save working version
charithaintc Mar 25, 2026
df53caa
precommit issues
charithaintc Mar 25, 2026
9bcc653
use linalg.softmax
charithaintc Mar 27, 2026
3f5cbce
save work
charithaintc Mar 30, 2026
6204d6c
add inner dim tiling
charithaintc Mar 30, 2026
a8ca522
Merge branch 'main' into softmax_impl
charithaintc Mar 31, 2026
1feb0d4
save fused version
charithaintc Apr 1, 2026
a28cf4a
save work
charithaintc Apr 1, 2026
79e2f73
save work
charithaintc Apr 3, 2026
55c175c
save work
charithaintc Apr 3, 2026
bf3a8c6
save work
charithaintc Apr 3, 2026
81da73e
Merge branch 'softmax_impl' into softmax_doc
charithaintc Apr 3, 2026
b083887
save work
charithaintc Apr 3, 2026
c02b66b
fused version
charithaintc Apr 3, 2026
bce6260
tiled reduction doc
charithaintc Apr 3, 2026
2df0777
tiled reduction doc
charithaintc Apr 3, 2026
56687b7
tiled reduction doc
charithaintc Apr 3, 2026
b8f5616
Merge branch 'main' into softmax_impl
charithaintc Apr 15, 2026
d2d4c49
save work
charithaintc Apr 15, 2026
32a345a
save work
charithaintc Apr 16, 2026
eacc9d8
save work
charithaintc Apr 17, 2026
f02f599
Merge branch 'main' into softmax_reduction_tiling
charithaintc Apr 20, 2026
1313477
working version
charithaintc Apr 20, 2026
f1857aa
working version
charithaintc Apr 20, 2026
240cf08
add initial version
charithaintc Apr 20, 2026
3262e4a
add initial version
charithaintc Apr 20, 2026
2135de3
payload done
charithaintc Apr 21, 2026
361e069
tiled last matmul
charithaintc Apr 21, 2026
4d0827e
change to batch matmul
charithaintc Apr 21, 2026
e379b68
save work
charithaintc Apr 22, 2026
38f2d97
save work
charithaintc Apr 27, 2026
0076ee2
save work
charithaintc Apr 27, 2026
ffd2f1e
Merge branch 'main' into softmax_reduction_tiling
charithaintc Apr 27, 2026
452faf7
update llvm
charithaintc Apr 27, 2026
e3512e6
Merge branch 'softmax_reduction_tiling' into flash_attention_tiling
charithaintc Apr 27, 2026
3248477
refactor code
charithaintc Apr 27, 2026
ce07760
refactor code
charithaintc Apr 28, 2026
a0c421b
parallel dim tiling done
charithaintc Apr 28, 2026
dba42ad
address comments
charithaintc Apr 29, 2026
ba380a3
initial version without reduction tiling
charithaintc Apr 29, 2026
5222af3
save xegpu wg version
charithaintc Apr 30, 2026
1dba7ec
save work
charithaintc May 6, 2026
32e0349
save work
charithaintc May 7, 2026
609e571
save work
charithaintc May 7, 2026
b6863a9
minimum buffer version
charithaintc May 11, 2026
22faeb5
minimum buffer version
charithaintc May 11, 2026
93e489c
Merge branch 'main' into flash_attention_tiling
charithaintc May 11, 2026
f9710f8
save work
charithaintc May 12, 2026
036d59f
save work
charithaintc May 14, 2026
14cbb3b
save valid imex version
charithaintc May 14, 2026
e4c347c
match dims of imex
charithaintc May 15, 2026
2c87d4a
save work
charithaintc May 15, 2026
c0ec544
save unrolled version
charithaintc May 15, 2026
60e3b2e
compile to binary now
charithaintc May 15, 2026
f07eec3
cleanup
charithaintc May 15, 2026
1eaa923
cleanup
charithaintc May 15, 2026
2302a34
cleanup
charithaintc May 15, 2026
36a2418
Merge branch 'main' into flash_attention_tiling_imex_version
charithaintc May 15, 2026
387 changes: 387 additions & 0 deletions examples/xegpu/fused_attention.py
@@ -0,0 +1,387 @@
# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s
# CHECK: module attributes {gpu.container_module} {

"""
XeGPU fused attention benchmark.
"""

import argparse
from typing import Optional
from functools import cached_property

import numpy as np
from mlir import ir

from lighthouse import dialects as lh_dialects
from lighthouse.execution.runner import Runner
from lighthouse.pipeline.driver import TransformDriver
from lighthouse.execution import GPUMemoryManager
from lighthouse.utils.numpy import mlir_to_numpy_dtype
from lighthouse.ingress.mlir_gen import get_mlir_elem_type
from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import (
    generate_gpu_fused_attention_payload,
)
from lighthouse.schedule.xegpu import fused_attention_schedule, xegpu_to_binary


def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: int):
    """
    Complexity of fused attention operation.

    For each batch and head:
    - Q @ K^T: O(n_ctx^2 * n_head) operations
    - Softmax: O(n_ctx^2) operations
    - Attention @ V: O(n_ctx^2 * n_head) operations
    Total: approximately 2*n_ctx^2*n_head FLOPs per batch and head
    """
    # Approximation: 2 * n_ctx^2 * n_head FLOPs per batch and head
    flop_count = Z * H * 2 * n_ctx * n_ctx * n_head
    # Memory: read Q, K, V and write output
    memory_reads = 3 * Z * H * n_ctx * n_head * nbytes
    memory_writes = Z * H * n_ctx * n_head * nbytes
    return flop_count, memory_reads, memory_writes
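

# Illustrative arithmetic (editor's example, not from the original source),
# plugging the default CLI sizes Z=2, H=8, n_ctx=4096, n_head=64 and f16
# (nbytes=2) into the formulas above:
#   flop_count    = 2*8 * 2*4096*4096*64 = 34,359,738,368 FLOPs (~34.4 GFLOP)
#   memory_reads  = 3*2*8*4096*64*2     = 25,165,824 bytes (24 MiB)
#   memory_writes = 2*8*4096*64*2       = 8,388,608 bytes (8 MiB)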


def check_correctness(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    output_arr: np.ndarray,
    verbose: int = 0,
) -> bool:
    """
    Check correctness of fused attention output.

    Reference implementation:
    - scores = Q @ K^T / sqrt(n_head)
    - attention_weights = softmax(scores, dim=-1)
    - output = attention_weights @ V
    """
    # Use float32 for computation
    Q_f32 = Q.astype(np.float32)
    K_f32 = K.astype(np.float32)
    V_f32 = V.astype(np.float32)

    Z, H, n_ctx, n_head = Q.shape
    scale = 1.0 / np.sqrt(n_head)

    output_ref = np.zeros_like(Q_f32)

    # Compute reference for each batch and head
    for z in range(Z):
        for h in range(H):
            # scores = Q @ K^T / sqrt(n_head)
            scores = Q_f32[z, h] @ K_f32[z, h].T * scale

            # Numerically stable softmax along the last dimension:
            # subtract the row max before exponentiating.
            max_vals = np.max(scores, axis=1, keepdims=True)
            exp_vals = np.exp(scores - max_vals)
            sum_vals = np.sum(exp_vals, axis=1, keepdims=True)
            attention_weights = exp_vals / sum_vals

            # output = attention_weights @ V
            output_ref[z, h] = attention_weights @ V_f32[z, h]

    output = output_arr.astype(np.float32)

    if verbose > 1:
        print("Reference solution (first batch, first head, first 5 rows):")
        print(output_ref[0, 0, :5])
        print("Computed solution (first batch, first head, first 5 rows):")
        print(output[0, 0, :5])

    # Check that the computed values match the reference.
    success = np.allclose(output, output_ref, rtol=1e-3, atol=1e-4)

    if verbose:
        if success:
            print("PASSED")
        else:
            print("FAILED!")
            max_diff = np.abs(output - output_ref).max()
            print(f" Values mismatch. Max abs diff: {max_diff:.6e}")
    return success
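

# Editor's note (a sketch, equivalent to the per-(z, h) loop above): the same
# reference can be computed for all batches and heads at once with batched
# einsums over the (Z, H, n_ctx, n_head) arrays:
#   scores = np.einsum("zhqd,zhkd->zhqk", Q_f32, K_f32) * scale
#   # ... numerically stable softmax over the last axis of `scores` ...
#   out_ref = np.einsum("zhqk,zhkd->zhqd", attention_weights, V_f32)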


class XeGPUFusedAttention:
    """
    Fused attention workload on XeGPU.

    Computes fused attention:
        output = softmax(Q @ K^T / sqrt(n_head)) @ V

    All Q, K, V matrices have shape (Z, H, n_ctx, n_head) where:
    - Z: batch size
    - H: number of heads
    - n_ctx: context length
    - n_head: head dimension
    """

    def __init__(
        self,
        Z: int,
        H: int,
        n_ctx: int,
        n_head: int,
        dtype: str = "f16",
    ):
        self.Z = Z
        self.H = H
        self.n_ctx = n_ctx
        self.n_head = n_head
        self.shape = (Z, H, n_ctx, n_head)
        assert dtype == "f16", "Only f16 type is supported for fused attention"
        self.elem_type = get_mlir_elem_type(dtype)
        self.dtype = mlir_to_numpy_dtype(self.elem_type)
        self.memory_manager_class = GPUMemoryManager
        self.payload_function_name = "payload"

    @cached_property
    def _initial_host_arrays(self) -> tuple[np.ndarray, ...]:
        """Generate initial values on host with numpy."""
        np.random.seed(42)
        # Initialize Q, K, V with small random values
        Q = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype)
        K = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype)
        V = np.random.uniform(-0.5, 0.5, self.shape).astype(self.dtype)
        output_arr = np.zeros(self.shape, dtype=self.dtype)
        return (output_arr, Q, K, V)

    def get_complexity(self) -> tuple[int, int, int]:
        nbytes = np.dtype(self.dtype).itemsize
        return fused_attention_complexity(
            self.Z, self.H, self.n_ctx, self.n_head, nbytes
        )

    def payload_module(self) -> ir.Module:
        """Generate MLIR module for fused attention payload."""
        return generate_gpu_fused_attention_payload(
            func_name=self.payload_function_name,
            Z=self.Z,
            H=self.H,
            n_ctx=self.n_ctx,
            n_head=self.n_head,
            dtype=self.elem_type,
        )

    def schedule_modules(
        self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None
    ) -> list[ir.Module]:
        """Generate transform schedule for fused attention."""
        schedules = []
        schedules.append(Runner.get_bench_wrapper_schedule(self.payload_function_name))

        schedules.append(
            fused_attention_schedule(
                stop_at_stage=stop_at_stage,
                parameters=parameters,
            )
        )

        if stop_at_stage and stop_at_stage != "final":
            return schedules

        schedules.append(xegpu_to_binary())

        return schedules

    def shared_libs(self) -> list[str]:
        return ["libmlir_levelzero_runtime.so"]


def parse_cli():
    parser = argparse.ArgumentParser(
        description="Fused Attention using MLIR XeGPU",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Batch size (Z)",
    )
    parser.add_argument(
        "--num-heads",
        type=int,
        default=8,
        help="Number of attention heads (H)",
    )
    parser.add_argument(
        "--n-ctx",
        type=int,
        default=4096,
        help="Context length (sequence length)",
    )
    parser.add_argument(
        "--n-head",
        type=int,
        default=64,
        help="Head dimension",
    )
    parser.add_argument(
        "--wg-rows",
        type=int,
        default=128,
        help="Number of Q*K^T*V rows computed by each work group",
    )
    parser.add_argument(
        "--sg-rows",
        type=int,
        default=16,
        help="Number of Q*K^T*V rows computed by each subgroup",
    )
    parser.add_argument(
        "--subgroup-size",
        type=int,
        default=16,
        help="Subgroup size",
    )
    parser.add_argument(
        "--inner-loop-tile-size",
        type=int,
        default=64,
        help="Tile size for the inner reduction dimension (K/V sequence length)",
    )
    parser.add_argument(
        "--nruns",
        type=int,
        default=1000,
        help="Number of runs to average the execution time.",
    )
    parser.add_argument(
        "--nwarmup",
        type=int,
        default=20,
        help="Number of warm-up iterations before benchmarking.",
    )
    parser.add_argument(
        "--check-result",
        action="store_true",
        help="Check the result of the fused attention computation.",
    )
    parser.add_argument(
        "--dump-kernel",
        type=str,
        choices=[
            "initial",
            "outer-tiled",
            "inner-tiled",
            "vectorized",
            "bufferized",
            "gpu-outlining",
            "xegpu-initial",
            "xegpu-wg",
            "final",
        ],
        help="Dump kernel IR at different stages of lowering and exit without "
        "executing the kernel.",
    )
    parser.add_argument(
        "--dump-schedule",
        action="store_true",
        help="Dump transform schedule.",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="count",
        default=0,
        help="Increase output verbosity (e.g. print reference and computed solutions).",
    )
    args = parser.parse_args()
    return args
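

# Example invocations (editor's illustration, using the flags defined above):
#   python fused_attention.py --check-result -v
#   python fused_attention.py --dump-kernel xegpu-wg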


if __name__ == "__main__":
    args = parse_cli()

    params = {
        "batch_size": args.batch_size,
        "num_heads": args.num_heads,
        "n_ctx": args.n_ctx,
        "n_head": args.n_head,
        "wg_rows": args.wg_rows,
        "sg_rows": args.sg_rows,
        "subgroup_size": args.subgroup_size,
        "inner_loop_tile_size": args.inner_loop_tile_size,
    }

    Z = args.batch_size
    H = args.num_heads
    n_ctx = args.n_ctx
    n_head = args.n_head
    dtype = "f16"

    with ir.Context(), ir.Location.unknown():
        lh_dialects.register_and_load()
        wload = XeGPUFusedAttention(Z=Z, H=H, n_ctx=n_ctx, n_head=n_head, dtype=dtype)

        if args.dump_kernel or args.dump_schedule:
            pipeline = TransformDriver(
                wload.schedule_modules(
                    stop_at_stage=args.dump_kernel, parameters=params
                )
            )
            payload = pipeline.apply(wload.payload_module())
            if args.dump_kernel:
                print(payload)
            if args.dump_schedule:
                for schedule_module in wload.schedule_modules(parameters=params):
                    print(schedule_module)
        else:
            pipeline = TransformDriver(wload.schedule_modules(parameters=params))
            payload = pipeline.apply(wload.payload_module())
            runner = Runner(
                payload,
                mem_manager_cls=wload.memory_manager_class,
                shared_libs=wload.shared_libs(),
            )
            if args.check_result:
                # Set up a callback function to copy the result from device to host.
                result_host_copy, argument_access_callback = (
                    Runner.get_gpu_argument_access_callback(wload.shape, wload.dtype)
                )

                # Execute kernel once.
                runner.execute(
                    host_input_buffers=wload._initial_host_arrays,
                    payload_function_name=wload.payload_function_name,
                    argument_access_callback=argument_access_callback,
                )

                # Compute reference solution on host.
                Q, K, V = wload._initial_host_arrays[1:4]
                success = check_correctness(
                    Q,
                    K,
                    V,
                    result_host_copy,
                    verbose=args.verbose,
                )
                if not success:
                    raise ValueError("Result mismatch!")
                else:
                    print("Result is correct. Proceeding to benchmark...")

            times = runner.benchmark(
                host_input_buffers=wload._initial_host_arrays,
                nruns=args.nruns,
                nwarmup=args.nwarmup,
            )
            times *= 1e6  # convert to microseconds
            elapsed = np.mean(times)
            flop_count = wload.get_complexity()[0]
            gflops = flop_count / (elapsed * 1e-6) / 1e9

            print(
                f"batch-size={Z} "
                f"num-heads={H} "
                f"n-ctx={n_ctx} "
                f"n-head={n_head} "
                f"dt={dtype} "
                f"time(us): {elapsed:.2f} "
                f"GFLOPS: {gflops:.2f} "
            )
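
# Illustrative check of the derived metric (editor's arithmetic, not from the
# original source): with the default sizes, flop_count ≈ 3.44e10 FLOPs; if
# elapsed = 500 us, then gflops = 3.44e10 / (500e-6) / 1e9 ≈ 68,720 GFLOPS.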