Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# This is an optimizing pipeline for kernel_bench matmuls on bf16 types.
# This is basically a copy of the fp32 pipeline, with ONE CHANGE:
# - register_tiling.py -> reg_unroll_k=2 (instead of 1)
# Tested on x86_64 with AVX512, where it reaches good performance for simple KernelBench (KB) kernels.
# It may not apply to other workloads / extensions / architectures, so use with caution.
Pipeline:
- include: pack_and_tile.yaml
- include: ../pack_and_tile.yaml

## CPU-specific register tiling (depends on uArch & data type)
- schedule: "x86/register_tiling.py[gen=matmul_register_tiling]{target=linalg.contract reg_tile_batch=1 reg_tile_m=8 reg_tile_n=32 reg_tile_k=2}"
- schedule: "x86/register_tiling.py[gen=matmul_register_unroll]{target=linalg.contract batch=1 reg_tile_m=8 reg_tile_n=32 reg_tile_k=2 reg_unroll_m=1 reg_unroll_n=16 reg_unroll_k=2}"
- schedule: "tiling.py[gen=tile_ops]{target_op=linalg.fill tile_sizes=[1,1,1]}"
- schedule: "tiling.py[gen=tile_ops]{target_op=linalg.generic tile_sizes=[1,8]}"

- include: vectorize.yaml
- include: ../vectorize.yaml

- include: lower.yaml
- include: ../lower.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
# Tested on x86_64 with AVX512, where it reaches good performance for simple KernelBench (KB) kernels.
# It may not apply to other workloads / extensions / architectures, so use with caution.
Pipeline:
- include: pack_and_tile.yaml
- include: ../pack_and_tile.yaml

## CPU-specific register tiling (depends on uArch & data type)
- schedule: "x86/register_tiling.py[gen=matmul_register_tiling]{target=linalg.contract reg_tile_batch=1 reg_tile_m=8 reg_tile_n=32 reg_tile_k=2}"
- schedule: "x86/register_tiling.py[gen=matmul_register_unroll]{target=linalg.contract batch=1 reg_tile_m=8 reg_tile_n=32 reg_tile_k=2 reg_unroll_m=1 reg_unroll_n=16 reg_unroll_k=1}"
- schedule: "tiling.py[gen=tile_ops]{target_op=linalg.fill tile_sizes=[1,1,1]}"
- schedule: "tiling.py[gen=tile_ops]{target_op=linalg.generic tile_sizes=[1,8]}"

- include: vectorize.yaml
- include: ../vectorize.yaml

- include: lower.yaml
- include: ../lower.yaml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Tensor-level vectorization for matmul-like kernels on any type.
Pipeline:
## Tensor vectorization (for the left-over element-wise ops)
- schedule: "vectorization.py[gen=vectorize_linalg]"
- schedule: "hoisting.py[gen=hoist_loops]"
- schedule: "vectorization.py[gen=simplify_vector_ops]"
- schedule: "vectorization.py[gen=x86_vectorization]"
129 changes: 87 additions & 42 deletions examples/end-to-end/KernelBench/test_kernel_bench.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# RUN: python %s | FileCheck %s
# RUN: python %s --ci | FileCheck %s

# REQUIRES: torch
# REQUIRES: kernel_bench
Expand All @@ -9,60 +9,55 @@
import platform
from pathlib import Path

import yaml

script_path = Path(__file__).parent
project_root = script_path.parent.parent.parent
kb_program = project_root / "tools" / "kernel_bench"
kb_default_pipeline = kb_program.parent / "kernel_bench.yaml"
kb_path = project_root / "third_party" / "KernelBench" / "KernelBench"
yaml_path = script_path / "tests.yaml"


def get_pipeline_file(kernel_name: str, dtype: str) -> Path:
def get_pipeline_file(name: str, dtype: str) -> Path:
"""
Returns the appropriate pipeline file for a given kernel.
"""
arch = platform.machine()
if arch != "x86_64":
return kb_default_pipeline

# Level 1 matmuls should use the same pipelines
if kernel_name.startswith("level1") and "matrix_multiplication" in kernel_name:
pipeline = script_path / f"schedules/{arch}/matmul/{dtype}.yaml"
# If the pipeline file exists for the given name and dtype
if name:
pipeline = script_path / f"schedules/{arch}/{name}/{dtype}.yaml"
if pipeline.exists():
return pipeline

# Otherwise, just return the safe option
return kb_default_pipeline


tests = [
{
"kernel": "level1/1_Square_matrix_multiplication_.py",
"input_shapes": ["1024x1024", "1024x1024"],
"initializations": ["rnd", "id"],
"output_shape": "1024x1024",
"dtypes": ["f32", "bf16"],
"gflops": (1024 * 1024 * 1024 * 2) / 1e9,
},
{
"kernel": "level1/2_Standard_matrix_multiplication_.py",
"input_shapes": ["512x1024", "1024x512"],
"initializations": ["rnd", "rnd"],
"output_shape": "512x512",
"dtypes": ["f32", "bf16"],
"gflops": (512 * 1024 * 512 * 2) / 1e9,
},
]


def get_tests(args: argparse.Namespace) -> list[dict]:
"""
Returns the list of tests to be executed.
"""
if args.ci:
print(
"Running in CI mode: fewer tests, no bf16, no benchmarking for faster feedback"
)
args.bf16 = False # Disable bf16 tests in CI for faster feedback
args.benchmark = False # Disable benchmarking in CI for faster feedback

tests = []
with open(yaml_path) as f:
tests = yaml.safe_load(f)

test_list = []
for test in tests:
for dtype in test["dtypes"]:
if not args.bf16 and dtype == "bf16":
continue
# If a specific test is specified, only include that test
if args.test and not test["kernel"].startswith(args.test):
continue
test_list.append(
{
"kernel": test["kernel"],
Expand All @@ -76,9 +71,13 @@ def get_tests(args: argparse.Namespace) -> list[dict]:
"gflops": test["gflops"]
if "gflops" in test and args.benchmark
else None,
"pipeline": str(get_pipeline_file(test["kernel"], dtype)),
"pipeline": str(get_pipeline_file(test.get("pipeline", ""), dtype)),
"warning": test.get("warning", None),
}
)
# CI mode runs fewer tests for faster feedback
if args.ci and len(test_list) >= 5:
return test_list
return test_list


Expand All @@ -105,9 +104,33 @@ def get_flops_per_second(stdout: str, gflops: float) -> float:
action=argparse.BooleanOptionalAction,
help="Enable bf16 precision kernels.",
)
Parser.add_argument(
"--ci",
action=argparse.BooleanOptionalAction,
help="Enable CI mode (faster run, fewer kernels).",
)
Parser.add_argument(
"--test",
type=str,
help="Specify a particular test to run.",
)
Parser.add_argument(
"--print-mlir-after-all",
action=argparse.BooleanOptionalAction,
help="Whether to print the MLIR module after all stages. Default is False.",
)
args = Parser.parse_args()
tests = get_tests(args)
if len(tests) == 0:
if args.test:
print(
f"No tests found matching '{args.test}'. Please check your arguments."
)
else:
print("No tests to run. Please check your arguments.")
exit(0)

for test in get_tests(args):
for test in tests:
kb_kernel = kb_path / test["kernel"]
command_line = [
str(kb_program),
Expand All @@ -121,36 +144,58 @@ def get_flops_per_second(stdout: str, gflops: float) -> float:
"--print-tensor=1",
"--seed=42",
]
benchmark = test.get("gflops") is not None
benchmark = args.benchmark and test.get("gflops") is not None
if benchmark:
command_line += ["--benchmark"]
if args.print_mlir_after_all:
command_line += ["--print-mlir-after-all"]
if test.get("warning"):
print(f"WARNING: {test['warning']}")
print(f"Running command: {' '.join(command_line)}")

# While debugging kernels, it's useful to see the output as it comes.
# Note: GFLOPS can't be shown if the output is not captured.
capture_output = True
if args.print_mlir_after_all and not args.ci:
capture_output = False

result = subprocess.run(
command_line,
capture_output=True,
capture_output=capture_output,
text=True,
)

print("STDOUT:")
print(result.stdout)
if benchmark:
flops_per_second = get_flops_per_second(result.stdout, test["gflops"])
if flops_per_second > 0:
print(f"Performance: {flops_per_second:.2f} GFLOPS")
# If output is captured, print it out, including benchmark results if applicable.
if capture_output:
print("STDOUT:")
print(result.stdout)
if benchmark:
flops_per_second = get_flops_per_second(result.stdout, test["gflops"])
if flops_per_second > 0:
print(f"Performance: {flops_per_second:.2f} GFLOPS")

print("STDERR:")
print(result.stderr)

print("STDERR:")
print(result.stderr)
print(f"Return code: {result.returncode}")
assert result.returncode == 0, "Execution failed"

# CHECK: 1_Square_matrix_multiplication_.mlir
# CHECK: 0.3745{{.*}} 0.9507{{.*}} 0.7319{{.*}} ... 0.2973{{.*}} 0.9243{{.*}} 0.9710{{.*}}
# CHECK: 0.7201{{.*}} 0.9926{{.*}} 0.1208{{.*}} ... 0.1742{{.*}} 0.3485{{.*}} 0.6436{{.*}}

# CHECK-NOT: Execution failed

# CHECK: 2_Standard_matrix_multiplication_.mlir
# CHECK: 249.78{{.*}} 260.13{{.*}} 249.36{{.*}} ... 261.10{{.*}} 260.49{{.*}} 257.09{{.*}}
# CHECK: 243.56{{.*}} 250.91{{.*}} 252.38{{.*}} ... 260.40{{.*}} 261.56{{.*}} 256.24{{.*}}

# CHECK-NOT: Execution failed
# CHECK: 3_Batched_matrix_multiplication.mlir
# CHECK: 5.2403{{.*}} 7.7905{{.*}} 6.0769{{.*}} ... 7.8579{{.*}} 6.8890{{.*}} 6.6193{{.*}}
# CHECK: 9.0407{{.*}} 6.3299{{.*}} 5.2003{{.*}} ... 6.2594{{.*}} 6.2980{{.*}} 5.9807{{.*}}

# CHECK: 4_Matrix_vector_multiplication_.mlir
# CHECK: 264.86{{.*}}
# CHECK: 265.12{{.*}}

# CHECK: 5_Matrix_scalar_multiplication.mlir
# CHECK: 0.1750{{.*}} 0.4442{{.*}} 0.3420{{.*}} ... 0.1389{{.*}} 0.4319{{.*}} 0.4538{{.*}}
# CHECK: 0.3365{{.*}} 0.4638{{.*}} 0.0564{{.*}} ... 0.0814{{.*}} 0.1628{{.*}} 0.3007{{.*}}
Loading
Loading