Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
eecee07
save work
charithaintc Mar 13, 2026
9ff9dbc
Merge branch 'main' into softmax_impl
charithaintc Mar 16, 2026
f991027
save work
charithaintc Mar 16, 2026
22415bb
save work
charithaintc Mar 16, 2026
cb9ead1
save work
charithaintc Mar 17, 2026
ac39be3
save work
charithaintc Mar 18, 2026
51d494e
save work
charithaintc Mar 20, 2026
7ac8852
save work
charithaintc Mar 20, 2026
fa2993d
Merge branch 'main' into softmax_impl
charithaintc Mar 20, 2026
d65bf9f
save work
charithaintc Mar 20, 2026
0bf3eb3
save working version
charithaintc Mar 24, 2026
fabd656
save working version
charithaintc Mar 24, 2026
1e63d7d
save working version
charithaintc Mar 24, 2026
64b5d73
save working version
charithaintc Mar 24, 2026
108f2c0
save working version
charithaintc Mar 25, 2026
a7e1e6c
save working version
charithaintc Mar 25, 2026
df53caa
precommit issues
charithaintc Mar 25, 2026
9bcc653
use linalg.softmax
charithaintc Mar 27, 2026
3f5cbce
save work
charithaintc Mar 30, 2026
6204d6c
add inner dim tiling
charithaintc Mar 30, 2026
a8ca522
Merge branch 'main' into softmax_impl
charithaintc Mar 31, 2026
1feb0d4
save fused version
charithaintc Apr 1, 2026
a28cf4a
save work
charithaintc Apr 1, 2026
79e2f73
save work
charithaintc Apr 3, 2026
55c175c
save work
charithaintc Apr 3, 2026
bf3a8c6
save work
charithaintc Apr 3, 2026
81da73e
Merge branch 'softmax_impl' into softmax_doc
charithaintc Apr 3, 2026
b083887
save work
charithaintc Apr 3, 2026
c02b66b
fused version
charithaintc Apr 3, 2026
bce6260
tiled reduction doc
charithaintc Apr 3, 2026
2df0777
tiled reduction doc
charithaintc Apr 3, 2026
56687b7
tiled reduction doc
charithaintc Apr 3, 2026
b8f5616
Merge branch 'main' into softmax_impl
charithaintc Apr 15, 2026
d2d4c49
save work
charithaintc Apr 15, 2026
32a345a
save work
charithaintc Apr 16, 2026
eacc9d8
save work
charithaintc Apr 17, 2026
f02f599
Merge branch 'main' into softmax_reduction_tiling
charithaintc Apr 20, 2026
1313477
working version
charithaintc Apr 20, 2026
f1857aa
working version
charithaintc Apr 20, 2026
240cf08
add initial version
charithaintc Apr 20, 2026
3262e4a
add initial version
charithaintc Apr 20, 2026
2135de3
payload done
charithaintc Apr 21, 2026
361e069
tiled last matmul
charithaintc Apr 21, 2026
4d0827e
change to batch matmul
charithaintc Apr 21, 2026
e379b68
save work
charithaintc Apr 22, 2026
06e7ede
save initial softmax doc
charithaintc Apr 22, 2026
05e3e07
save work
charithaintc Apr 22, 2026
028b27d
save work
charithaintc Apr 22, 2026
cd93df6
add optimization note
charithaintc Apr 23, 2026
9bbb99c
save work
charithaintc Apr 23, 2026
8637269
add attention doc
charithaintc Apr 23, 2026
037193b
add optimization note
charithaintc Apr 23, 2026
32133cf
add optimization note
charithaintc Apr 23, 2026
3583e81
add attention doc
charithaintc Apr 23, 2026
06cf4f2
save work
charithaintc Apr 23, 2026
1caf9ce
save work
charithaintc Apr 23, 2026
d1e3c3f
add pdf slices
charithaintc Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
359 changes: 359 additions & 0 deletions examples/xegpu/fused_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,359 @@
# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s
# CHECK: module attributes {gpu.container_module} {

"""
XeGPU fused attention benchmark.
"""

import argparse
from typing import Optional
from functools import cached_property

import numpy as np
from mlir import ir

from lighthouse import dialects as lh_dialects
from lighthouse.execution.runner import Runner
from lighthouse.pipeline.driver import TransformDriver
from lighthouse.execution import GPUMemoryManager
from lighthouse.utils.numpy import mlir_to_numpy_dtype
from lighthouse.ingress.mlir_gen import get_mlir_elem_type
from lighthouse.ingress.mlir_gen.gpu_fused_attention_payload import (
generate_gpu_fused_attention_payload,
)
from lighthouse.schedule.xegpu.fused_attention_schedule import (
get_fused_attention_schedule_module,
)


def fused_attention_complexity(Z: int, H: int, n_ctx: int, n_head: int, nbytes: int):
    """
    Estimate the arithmetic and memory cost of one fused attention call.

    Only the two matrix multiplications are counted. Each one performs
    n_ctx^2 * n_head multiply-accumulates, i.e. 2 * n_ctx^2 * n_head FLOPs
    (one multiply plus one add per accumulation):

    - Q @ K^T:       2 * n_ctx^2 * n_head FLOPs
    - Attention @ V: 2 * n_ctx^2 * n_head FLOPs
    - Softmax:       O(n_ctx^2) operations, ignored as a lower-order term

    Total: 4 * n_ctx^2 * n_head FLOPs per batch and head.

    Args:
        Z: Batch size.
        H: Number of heads.
        n_ctx: Context length.
        n_head: Head dimension.
        nbytes: Size of one element in bytes.

    Returns:
        Tuple ``(flop_count, memory_reads, memory_writes)``; the memory
        figures are in bytes and assume Q, K and V are each read once and
        the output is written once.
    """
    # Two matmuls, each costing 2 FLOPs per multiply-accumulate.
    flops_per_matmul = 2 * n_ctx * n_ctx * n_head
    flop_count = Z * H * 2 * flops_per_matmul
    # Memory: read Q, K, V and write output
    tensor_bytes = Z * H * n_ctx * n_head * nbytes
    memory_reads = 3 * tensor_bytes
    memory_writes = tensor_bytes
    return flop_count, memory_reads, memory_writes


def check_correctness(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    output_arr: np.ndarray,
    verbose: int = 0,
) -> bool:
    """
    Validate a fused attention result against a NumPy reference.

    The reference is evaluated in float32, one (batch, head) slice at a time:
    - scores = Q @ K^T / sqrt(n_head)
    - attention_weights = softmax(scores) along the last dimension
    - output = attention_weights @ V

    Args:
        Q, K, V: Input arrays; Q must be 4-dimensional (Z, H, n_ctx, n_head).
        output_arr: Result to validate.
        verbose: Any non-zero value prints PASSED/FAILED (plus the maximum
            absolute difference on failure); values above 1 also print the
            first five rows of the first batch/head of both solutions.

    Returns:
        True when ``output_arr`` matches the reference within
        rtol=1e-3 / atol=1e-4, False otherwise.
    """
    # All reference arithmetic starts from float32 copies of the inputs.
    q32, k32, v32 = (arr.astype(np.float32) for arr in (Q, K, V))

    Z, H, _, n_head = Q.shape
    scale = 1.0 / np.sqrt(n_head)

    expected = np.zeros_like(q32)
    for idx in np.ndindex(Z, H):
        logits = q32[idx] @ k32[idx].T * scale

        # Row-wise softmax; the row maximum is subtracted before exp().
        shifted = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        weights = shifted / np.sum(shifted, axis=1, keepdims=True)

        expected[idx] = weights @ v32[idx]

    actual = output_arr.astype(np.float32)

    if verbose > 1:
        print("Reference solution (first batch, first head, first 5 rows):")
        print(expected[0, 0, :5])
        print("Computed solution (first batch, first head, first 5 rows):")
        print(actual[0, 0, :5])

    matches = np.allclose(actual, expected, rtol=1e-3, atol=1e-4)

    if verbose:
        if not matches:
            print("FAILED!")
            worst = np.abs(actual - expected).max()
            print(f" Values mismatch. Max abs diff: {worst:.6e}")
        else:
            print("PASSED")
    return matches


class XeGPUFusedAttention:
    """
    Fused attention workload on XeGPU.

    Computes fused attention:
        output = softmax(Q @ K^T / sqrt(n_head)) @ V

    All Q, K, V matrices have shape (Z, H, n_ctx, n_head) where:
    - Z: batch size
    - H: number of heads
    - n_ctx: context length
    - n_head: head dimension

    Only the "f32" element type is accepted.
    """

    def __init__(
        self,
        Z: int,
        H: int,
        n_ctx: int,
        n_head: int,
        dtype: str = "f32",
    ):
        """
        Store the problem dimensions and resolve the element type.

        Raises:
            ValueError: If ``dtype`` is anything other than "f32".
        """
        # Explicit check rather than `assert`, which is stripped under
        # `python -O` and would let unsupported types through.
        if dtype != "f32":
            raise ValueError("Only f32 type is supported for fused attention")
        self.Z = Z
        self.H = H
        self.n_ctx = n_ctx
        self.n_head = n_head
        # Q, K, V and the output all share this shape.
        self.shape = (Z, H, n_ctx, n_head)
        self.elem_type = get_mlir_elem_type(dtype)
        self.dtype = mlir_to_numpy_dtype(self.elem_type)
        self.memory_manager_class = GPUMemoryManager
        self.payload_function_name = "payload"

    @cached_property
    def _initial_host_arrays(self) -> tuple[np.ndarray, ...]:
        """
        Generate initial values on host with numpy.

        Returns ``(output, Q, K, V)``: a zero-filled output buffer followed
        by the three inputs, each drawn uniformly from [-0.5, 0.5). The
        result is computed once per instance and cached.
        """
        # A private generator seeded with 42 yields the same stream as
        # np.random.seed(42) + np.random.uniform, but leaves the global
        # NumPy random state untouched.
        rng = np.random.RandomState(42)
        # Initialize Q, K, V with small random values
        Q = rng.uniform(-0.5, 0.5, self.shape).astype(self.dtype)
        K = rng.uniform(-0.5, 0.5, self.shape).astype(self.dtype)
        V = rng.uniform(-0.5, 0.5, self.shape).astype(self.dtype)
        output_arr = np.zeros(self.shape, dtype=self.dtype)
        return (output_arr, Q, K, V)

    def get_complexity(self) -> tuple[int, int, int]:
        """Return (flop_count, memory_reads, memory_writes) for this problem size."""
        nbytes = np.dtype(self.dtype).itemsize
        return fused_attention_complexity(
            self.Z, self.H, self.n_ctx, self.n_head, nbytes
        )

    def payload_module(self) -> ir.Module:
        """Generate MLIR module for fused attention payload."""
        return generate_gpu_fused_attention_payload(
            func_name=self.payload_function_name,
            Z=self.Z,
            H=self.H,
            n_ctx=self.n_ctx,
            n_head=self.n_head,
            dtype=self.elem_type,
        )

    def schedule_modules(
        self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None
    ) -> list[ir.Module]:
        """
        Generate transform schedule for fused attention.

        Returns the benchmark-wrapper schedule for the payload function
        followed by the fused attention schedule; ``stop_at_stage`` and
        ``parameters`` are forwarded unchanged to the latter.
        """
        return [
            Runner.get_bench_wrapper_schedule(self.payload_function_name),
            get_fused_attention_schedule_module(
                stop_at_stage=stop_at_stage,
                parameters=parameters,
            ),
        ]

    def shared_libs(self) -> list[str]:
        """Return the shared libraries to hand to the Runner."""
        return ["libmlir_levelzero_runtime.so"]


def parse_cli(argv: Optional[list[str]] = None):
    """
    Parse command line options for the fused attention benchmark.

    Args:
        argv: Argument strings to parse. When None (the default) argparse
            reads ``sys.argv[1:]``, so existing ``parse_cli()`` calls are
            unaffected; passing a list allows programmatic use and testing.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Fused Attention using MLIR XeGPU",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Problem size.
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size (Z)",
    )
    parser.add_argument(
        "--num-heads",
        type=int,
        default=1,
        help="Number of attention heads (H)",
    )
    parser.add_argument(
        "--n-ctx",
        type=int,
        default=512,
        help="Context length (sequence length)",
    )
    parser.add_argument(
        "--n-head",
        type=int,
        default=64,
        help="Head dimension",
    )

    # Schedule tuning.
    parser.add_argument(
        "--wg-tile-size",
        type=int,
        default=64,
        help="Workgroup tile size for the collapsed batch dimension (Z*H*n_ctx)",
    )

    # Benchmarking and validation.
    parser.add_argument(
        "--nruns",
        type=int,
        default=1000,
        help="Number of runs to average the execution time.",
    )
    parser.add_argument(
        "--nwarmup",
        type=int,
        default=20,
        help="Number of warm-up iterations before benchmarking.",
    )
    parser.add_argument(
        "--check-result",
        action="store_true",
        help="Check the result of the fused attention computation.",
    )

    # Debug output.
    parser.add_argument(
        "--dump-kernel",
        type=str,
        choices=[
            "initial",
            "outer-tiled",
            "inner-tiled",
            "vectorized",
            "bufferized",
            "gpu-outlining",
            "xegpu-initial",
            "xegpu-wg",
            "final",
        ],
        help="Dump kernel IR at different stages of lowering and exit without "
        "executing the kernel.",
    )
    parser.add_argument(
        "--dump-schedule",
        action="store_true",
        help="Dump transform schedule.",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="count",
        default=0,
        help="Increase output verbosity (e.g. print reference and computed solutions).",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse options, lower the payload, then either dump
    # the requested IR/schedule or execute and benchmark the kernel.
    args = parse_cli()

    Z, H = args.batch_size, args.num_heads
    n_ctx, n_head = args.n_ctx, args.n_head
    dtype = "f32"

    # Parameters forwarded to the fused attention schedule.
    params = dict(
        batch_size=Z,
        num_heads=H,
        n_ctx=n_ctx,
        n_head=n_head,
        wg_tile_size=args.wg_tile_size,
    )

    with ir.Context(), ir.Location.unknown():
        lh_dialects.register_and_load()
        wload = XeGPUFusedAttention(Z=Z, H=H, n_ctx=n_ctx, n_head=n_head, dtype=dtype)

        # Both modes lower the payload the same way: args.dump_kernel is None
        # unless a stop stage was requested, which matches the default of
        # schedule_modules(stop_at_stage=None).
        pipeline = TransformDriver(
            wload.schedule_modules(stop_at_stage=args.dump_kernel, parameters=params)
        )
        payload = pipeline.apply(wload.payload_module())

        if args.dump_kernel or args.dump_schedule:
            # Dump-only mode: print what was asked for and do not execute.
            if args.dump_kernel:
                print(payload)
            if args.dump_schedule:
                for schedule_module in wload.schedule_modules(parameters=params):
                    print(schedule_module)
        else:
            runner = Runner(
                payload,
                mem_manager_cls=wload.memory_manager_class,
                shared_libs=wload.shared_libs(),
            )

            if args.check_result:
                # Setup callback function to copy result from device to host.
                host_result, copy_back = Runner.get_gpu_argument_access_callback(
                    wload.shape, wload.dtype
                )

                # Execute kernel once.
                runner.execute(
                    host_input_buffers=wload._initial_host_arrays,
                    payload_function_name=wload.payload_function_name,
                    argument_access_callback=copy_back,
                )

                # Compare against the reference solution computed on host.
                Q, K, V = wload._initial_host_arrays[1:4]
                if not check_correctness(Q, K, V, host_result, verbose=args.verbose):
                    raise ValueError("Result mismatch!")
                print("Result is correct. Proceeding to benchmark...")

            times = runner.benchmark(
                host_input_buffers=wload._initial_host_arrays,
                nruns=args.nruns,
                nwarmup=args.nwarmup,
            )
            times *= 1e6  # convert to microseconds
            elapsed = np.mean(times)
            flop_count, *_ = wload.get_complexity()
            gflops = flop_count / (elapsed * 1e-6) / 1e9

            summary = (
                f"batch-size={Z} "
                f"num-heads={H} "
                f"n-ctx={n_ctx} "
                f"n-head={n_head} "
                f"dt={dtype} "
                f"time(us): {elapsed:.2f} "
                f"GFLOPS: {gflops:.2f} "
            )
            print(summary)
Loading
Loading