Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/end-to-end/KernelBench/cpu_matmul.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ Pipeline:
- schedule: "linalg.py[gen=linalg_contract_fold_unit_dims]"

## CPU specific register tiling (depends on uArch & data type)
- schedule: "x86/register_tiling.py[gen=matmul_register_tiling]{target=linalg.contract}"
- schedule: "x86/register_tiling.py[gen=matmul_register_unroll]{target=linalg.contract}"
- schedule: "x86/register_tiling.py[gen=matmul_register_tiling]{target=linalg.contract reg_tile_batch=1 reg_tile_m=8 reg_tile_n=32 reg_tile_k=2}"
- schedule: "x86/register_tiling.py[gen=matmul_register_unroll]{target=linalg.contract batch=1 reg_tile_m=8 reg_tile_n=32 reg_tile_k=2 reg_unroll_m=1 reg_unroll_n=16 reg_unroll_k=1}"
- schedule: "tiling.py[gen=tile_ops]{target_op=linalg.fill tile_sizes=[1,1,1]}"
- schedule: "tiling.py[gen=tile_ops]{target_op=linalg.generic tile_sizes=[1,8]}"

## Tensor vectorization (for the left-over element wise)
- schedule: "vectorization.py[gen=vectorize_linalg]"
- schedule: "hoisting.py[gen=hoist_loops]"
- schedule: "vectorization.py[gen=simplify_vector_ops]"
- include: cleanup.yaml
- schedule: "vectorization.py[gen=x86_vectorization]"

## Bufferization
Expand Down
40 changes: 31 additions & 9 deletions examples/end-to-end/KernelBench/test_kernel_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# REQUIRES: torch
# REQUIRES: kernel_bench

import re
import subprocess
import platform
from pathlib import Path
Expand All @@ -16,8 +17,9 @@
tests = [
{
"kernel": "level1/1_Square_matrix_multiplication_.py",
"input_shapes": "32x32xf32xrnd,32x32xf32xid",
"output_shape": "32x32xf32x0",
"input_shapes": "1024x1024xf32xrnd,1024x1024xf32xid",
"output_shape": "1024x1024xf32x0",
"gflops": (1024 * 1024 * 1024 * 2) / 1e9,
"pipeline": f"{script_path}/cpu_matmul.yaml"
if arch == "x86_64"
else str(kb_default_pipeline),
Expand All @@ -30,8 +32,9 @@
},
{
"kernel": "level1/2_Standard_matrix_multiplication_.py",
"input_shapes": "8x16xf32xrnd,16x8xf32xrnd",
"output_shape": "8x8xf32x0",
"input_shapes": "512x1024xf32xrnd,1024x512xf32xrnd",
"output_shape": "512x512xf32x0",
"gflops": (512 * 1024 * 512 * 2) / 1e9,
"pipeline": f"{script_path}/cpu_matmul.yaml"
if arch == "x86_64"
else str(kb_default_pipeline),
Expand All @@ -44,6 +47,16 @@
},
]


def get_flops_per_second(stdout: str, gflops: float) -> float:
for line in stdout.splitlines():
match = re.search(r"([0-9.e-]+) seconds", line)
if match:
seconds = float(match.group(1))
return gflops / seconds
return 0.0


if __name__ == "__main__":
for test in tests:
kb_kernel = kb_path / test["kernel"]
Expand All @@ -59,6 +72,8 @@
"--print-tensor=1",
"--seed=42",
]
if "gflops" in test:
command_line += ["--benchmark"]
print(f"Running command: {' '.join(command_line)}")
result = subprocess.run(
command_line,
Expand All @@ -68,26 +83,33 @@

print("STDOUT:")
print(result.stdout)
if "gflops" in test:
flops_per_second = get_flops_per_second(result.stdout, test["gflops"])
if flops_per_second > 0:
print(f"Performance: {flops_per_second:.2f} GFLOPS")

print("STDERR:")
print(result.stderr)
print(f"Return code: {result.returncode}")
assert result.returncode == 0, "Execution failed"

# CHECK: 1_Square_matrix_multiplication_.mlir
# CHECK 0.3745{{.*}} 0.9507{{.*}} 0.7319{{.*}} ... 0.0464{{.*}} 0.6075{{.*}} 0.1705{{.*}}
# CHECK: 0.2721{{.*}} 0.5902{{.*}} 0.3609{{.*}} ... 0.2973{{.*}} 0.9243{{.*}} 0.9710{{.*}}
# CHECK: 0.3745{{.*}} 0.9507{{.*}} 0.7319{{.*}} ... 0.2973{{.*}} 0.9243{{.*}} 0.9710{{.*}}
# CHECK: 0.7201{{.*}} 0.9926{{.*}} 0.1208{{.*}} ... 0.1742{{.*}} 0.3485{{.*}} 0.6436{{.*}}
# CHECK: Performance: {{.*}} GFLOPS

# CHECK-NOT: Execution failed

# CHECK: 1_Square_matrix_multiplication_.mlir
# CHECK 0.375 0.949219 0.730469 ... 0.0463867 0.609375 0.170898
# CHECK: 0.375 0.949219 0.730469 ... 0.0463867 0.609375 0.170898
# CHECK: 0.271484 0.589844 0.361328 ... 0.296875 0.925781 0.972656

# CHECK-NOT: Execution failed

# CHECK: 2_Standard_matrix_multiplication_.mlir
# CHECK: 3.1209{{.*}} 3.7697{{.*}} 4.5365{{.*}} 4.3976{{.*}} 4.4506{{.*}} 3.2665{{.*}} 3.5362{{.*}}
# CHECK: 5.0367{{.*}} 5.3128{{.*}} 5.8109{{.*}} 4.8100{{.*}} 4.7435{{.*}} 4.3557{{.*}} 5.3115{{.*}}
# CHECK: 249.78{{.*}} 260.13{{.*}} 249.36{{.*}} ... 261.10{{.*}} 260.49{{.*}} 257.09{{.*}}
# CHECK: 243.56{{.*}} 250.91{{.*}} 252.38{{.*}} ... 260.40{{.*}} 261.56{{.*}} 256.24{{.*}}
# CHECK: Performance: {{.*}} GFLOPS

# CHECK-NOT: Execution failed

Expand Down
2 changes: 1 addition & 1 deletion lighthouse/schedule/hoisting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from mlir.dialects import transform
from mlir.dialects.transform.structured import MatchInterfaceEnum

from .builders import schedule_boilerplate
from lighthouse.schedule.builders import schedule_boilerplate
import lighthouse.transform as lh_transform


Expand Down
Loading