Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 45 additions & 62 deletions examples/cpu/x86/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
import ml_dtypes
import numpy as np
from mlir import ir
from mlir.dialects import linalg, transform
from mlir.dialects.transform import tensor
from mlir.dialects import linalg

from lighthouse import dialects as lh_dialects
from lighthouse.execution.runner import Runner
Expand All @@ -38,8 +37,6 @@
import lighthouse.utils as lh_utils
from lighthouse import schedule as lh_schedule
import lighthouse.schedule.x86 as lh_schedule_x86
from lighthouse import transform as lh_transform
import lighthouse.transform.x86 as lh_transform_x86
import lighthouse.ingress.mlir_gen.utils as lh_mlir_utils
from functools import cached_property
from typing import Optional
Expand Down Expand Up @@ -133,7 +130,6 @@ def payload(A, B, C):
def get_pipeline(
self,
stop_at_stage: Optional[str] = None,
parameters: Optional[dict] = None,
) -> PipelineDriver:
scheds = PipelineDriver(self.context)

Expand All @@ -154,7 +150,9 @@ def get_pipeline(
rhs_transpose_inner_block=False,
)
)
scheds.add_transform(lh_schedule_x86.lower_packs_unpacks(self.tile_size))
scheds.add_transform(
lh_schedule_x86.lower_packs_unpacks(tile_size=self.tile_size)
)

# Convert to category ops for easier op matching.
scheds.add_pass(
Expand All @@ -164,66 +162,55 @@ def get_pipeline(
)
)

# GEMM cache tiling.
# Create memory friendly access pattern.
gemm_op = "linalg.contract"
Comment thread
rengolin marked this conversation as resolved.
with lh_schedule.schedule_boilerplate() as (sched, named_seq):
ops = lh_transform.match_op(named_seq.bodyTarget, gemm_op)
with lh_transform.foreach(ops) as op:
lh_transform_x86.matmul_cache_tiling(
op, tile_size=self.tile_size, fuse_producers=True
)
transform.yield_()
transform.yield_()
scheds.add_transform(sched)
# GEMM cache tiling, create memory friendly access pattern.
scheds.add_transform(
lh_schedule_x86.matmul_cache_tiling(
target="linalg.contract", tile_size=self.tile_size, fuse_producers=True
)
)

# Fold extra parallel outer unit dims before further tiling to help later
# vectorization rewrites to recognize ops.
scheds.add_transform(lh_schedule.linalg_contract_fold_unit_dims())

# GEMM register tiling.
# Ensure that computation can fit into vector registers.
# GEMM register tiling, ensure that computation can fit into vector registers.
Comment thread
rengolin marked this conversation as resolved.
reg_tile_batch = 1
reg_tile_m = 8
reg_tile_n = 32
reg_tile_k = 2
reg_peel_loops = []
assert self.tile_size % reg_tile_k == 0, "Invalid K dim register tiling"
if self.tile_size % reg_tile_n != 0:
reg_peel_loops.append(1)
if self.tile_size % reg_tile_m != 0:
reg_peel_loops.append(0)

scheds.add_transform(
lh_schedule.tile_ops(
gemm_op,
tile_sizes=[reg_tile_batch, reg_tile_m, reg_tile_n, reg_tile_k],
tile_interchange=[1, 2, 0, 3],
peel_loops=reg_peel_loops,
lh_schedule_x86.matmul_register_tiling(
target="linalg.contract",
tile_size=self.tile_size,
reg_tile_batch=reg_tile_batch,
reg_tile_m=reg_tile_m,
reg_tile_n=reg_tile_n,
reg_tile_k=reg_tile_k,
)
)

# GEMM register unroll.
# Ensure that shapes are compatible with target hardware instructions.
reg_unroll_m = 1
reg_unroll_n = 16
# When VNNI can be used, tuples of 32-bit elements are needed.
reg_unroll_k = 2 if self.dtype == ml_dtypes.bfloat16 else 1
reg_unroll_factors = [
reg_tile_m // reg_unroll_m,
reg_tile_n // reg_unroll_n,
reg_tile_k // reg_unroll_k,
]
# GEMM register unroll, ensure that shapes are compatible with target hardware instructions.
scheds.add_transform(
lh_schedule.tile_ops(
gemm_op,
tile_sizes=[0, reg_unroll_m, reg_unroll_n, reg_unroll_k],
unroll_factors=reg_unroll_factors,
lh_schedule_x86.matmul_register_unroll(
target="linalg.contract",
reg_tile_m=reg_tile_m,
reg_tile_n=reg_tile_n,
reg_tile_k=reg_tile_k,
reg_unroll_m=1,
reg_unroll_n=16,
reg_unroll_k=2 if self.dtype == ml_dtypes.bfloat16 else 1,
batch=reg_tile_batch > 0,
)
)

# Further tiling into hardware-friendly sizes for vectorization.
scheds.add_transform(lh_schedule.tile_ops("linalg.fill", tile_sizes=[1, 1, 1]))
scheds.add_transform(lh_schedule.tile_ops("linalg.generic", tile_sizes=[1, 8]))
scheds.add_transform(
lh_schedule.tile_ops(target_op="linalg.fill", tile_sizes=[1, 1, 1])
)
scheds.add_transform(
lh_schedule.tile_ops(target_op="linalg.generic", tile_sizes=[1, 8])
)

if stop_at_stage == "tiled":
return scheds
Expand All @@ -232,14 +219,7 @@ def get_pipeline(
scheds.add_transform(lh_schedule.vectorize_linalg())
scheds.add_transform(lh_schedule.hoist_loops())

with lh_schedule.schedule_boilerplate() as (sched, named_seq):
with ir.InsertionPoint(
transform.ApplyPatternsOp(named_seq.bodyTarget).patterns
):
tensor.apply_patterns_tensor_fold_tensor_subset_ops_into_vector_transfers()
transform.apply_patterns_canonicalization()
transform.yield_()
scheds.add_transform(sched)
scheds.add_transform(lh_schedule.simplify_vector_ops())

# Rewrite vector ops into x86-specific sequences.
scheds.add_transform(lh_schedule.x86_vectorization())
Expand All @@ -254,11 +234,7 @@ def get_pipeline(
scheds.add_transform(lh_schedule.vectorize_all())

# Cleanup vector ops.
with lh_schedule.schedule_boilerplate() as (sched, named_seq):
lh_transform.flatten_vector_ops(named_seq.bodyTarget)
lh_transform.cleanup(named_seq.bodyTarget)
transform.yield_()
scheds.add_transform(sched)
scheds.add_transform(lh_schedule.flatten_vector_ops())

if stop_at_stage == "vectorized":
return scheds
Expand Down Expand Up @@ -331,6 +307,11 @@ def parse_cli():
action="store_true",
help="Dump transform schedule.",
)
parser.add_argument(
"--print-mlir-after-all",
action="store_true",
help="Dump MLIR after all transformations.",
)
args = parser.parse_args()

return args
Expand All @@ -354,7 +335,9 @@ def parse_cli():

wload = Matmul(*args.sizes, dtype=in_dtype, tile_size=args.tile_size)
pipeline = wload.get_pipeline(stop_at_stage=args.dump_kernel)
payload = pipeline.apply(wload.payload_module())
payload = pipeline.apply(
wload.payload_module(), print_after_all=args.print_mlir_after_all
)

if args.dump_kernel or args.dump_schedule:
if args.dump_kernel:
Expand Down
33 changes: 33 additions & 0 deletions examples/end-to-end/KernelBench/cpu_matmul.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# This is the default pipeline for kernel_bench, producing serial code.
# For optimized builds, pass the --pipeline option to the kernel_bench program.
Pipeline:
## Packing & Cache tiling (CPU generic)
- schedule: "packing.py[gen=block_pack_matmuls]{block_factors=32,32,32 rhs_transpose_outer_block=True rhs_transpose_inner_block=False}"
- schedule: "x86/pack_lowering.py[gen=lower_packs_unpacks]{tile_size=32}"
- pass: "linalg-morph-ops{named-to-category generic-to-category}"
- schedule: "x86/cache_tiling.py[gen=matmul_cache_tiling]{target=linalg.contract fuse_producers}"
- schedule: "linalg.py[gen=linalg_contract_fold_unit_dims]"

## CPU specific register tiling (depends on uArch & data type)
- schedule: "x86/register_tiling.py[gen=matmul_register_tiling]{target=linalg.contract}"
- schedule: "x86/register_tiling.py[gen=matmul_register_unroll]{target=linalg.contract}"
- schedule: "tiling.py[gen=tile_ops]{target_op=linalg.fill tile_sizes=[1,1,1]}"
Comment thread
rengolin marked this conversation as resolved.
- schedule: "tiling.py[gen=tile_ops]{target_op=linalg.generic tile_sizes=[1,8]}"

## Tensor vectorization (for the left-over element wise)
- schedule: "vectorization.py[gen=vectorize_linalg]"
- schedule: "vectorization.py[gen=simplify_vector_ops]"
- include: cleanup.yaml
- schedule: "vectorization.py[gen=x86_vectorization]"

## Bufferization
- include: bufferization.yaml
- include: bufferization_cleanup.yaml

## Buffer vectorization (gets rid of all the rest)
- schedule: "vectorization.py[gen=x86_vectorization]"
- schedule: "vectorization.py[gen=vectorize_all]"
- schedule: "vectorization.py[gen=flatten_vector_ops]"

## Lower to LLVM
- include: llvm_lowering.yaml
29 changes: 21 additions & 8 deletions examples/end-to-end/KernelBench/test_kernel_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,47 @@
# REQUIRES: kernel_bench

import subprocess
import platform
from pathlib import Path

script_path = Path(__file__).parent
project_root = script_path.parent.parent.parent
kb_program = project_root / "tools" / "kernel_bench"
kb_default_pipeline = kb_program.parent / "kernel_bench.yaml"
kb_path = project_root / "third_party" / "KernelBench" / "KernelBench"
arch = platform.machine()
Comment thread
adam-smnk marked this conversation as resolved.
tests = [
{
"kernel": "level1/1_Square_matrix_multiplication_.py",
"input_shapes": "32x32xf32xrnd,32x32xf32xid",
"output_shape": "32x32xf32x0",
"pipeline": f"{script_path}/cpu_matmul.yaml"
if arch == "x86_64"
else str(kb_default_pipeline),
},
{
"kernel": "level1/1_Square_matrix_multiplication_.py",
"input_shapes": "32x32xbf16xrnd,32x32xbf16xid",
"output_shape": "32x32xbf16x0",
"pipeline": str(kb_default_pipeline),
},
{
"kernel": "level1/2_Standard_matrix_multiplication_.py",
"input_shapes": "8x16xf32xrnd,16x8xf32xrnd",
"output_shape": "8x8xf32x0",
"pipeline": f"{script_path}/cpu_matmul.yaml"
if arch == "x86_64"
else str(kb_default_pipeline),
},
{
"kernel": "level1/2_Standard_matrix_multiplication_.py",
"input_shapes": "8x16xbf16xrnd,16x8xbf16xrnd",
"output_shape": "8x8xbf16x0",
"pipeline": str(kb_default_pipeline),
},
]

if __name__ == "__main__":
project_root = Path(__file__).parent.parent.parent.parent
kb_program = project_root / "tools" / "kernel_bench"
kb_path = project_root / "third_party" / "KernelBench" / "KernelBench"

for test in tests:
kb_kernel = kb_path / test["kernel"]
command_line = [
Expand All @@ -43,6 +54,8 @@
test["input_shapes"],
"--output-shape",
test["output_shape"],
"--pipeline",
test["pipeline"],
"--print-tensor=1",
"--seed=42",
]
Expand All @@ -61,8 +74,8 @@
assert result.returncode == 0, "Execution failed"

# CHECK: 1_Square_matrix_multiplication_.mlir
# CHECK 0.37454012 0.9507143 0.7319939 ... 0.04645041 0.60754484 0.17052412
# CHECK: 0.27214515 0.59023064 0.3609739 ... 0.297349 0.9243962 0.97105825
# CHECK: 0.3745{{.*}} 0.9507{{.*}} 0.7319{{.*}} ... 0.0464{{.*}} 0.6075{{.*}} 0.1705{{.*}}
# CHECK: 0.2721{{.*}} 0.5902{{.*}} 0.3609{{.*}} ... 0.2973{{.*}} 0.9243{{.*}} 0.9710{{.*}}

# CHECK-NOT: Execution failed

Expand All @@ -73,8 +86,8 @@
# CHECK-NOT: Execution failed

# CHECK: 2_Standard_matrix_multiplication_.mlir
# CHECK: 3.120935 3.7697 4.5365195 4.397648 4.4506536 3.2665431 3.5362916
# CHECK: 5.036752 5.312808 5.8109508 4.810084 4.7435184 4.35573 5.311559
# CHECK: 3.1209{{.*}} 3.7697{{.*}} 4.5365{{.*}} 4.3976{{.*}} 4.4506{{.*}} 3.2665{{.*}} 3.5362{{.*}}
# CHECK: 5.0367{{.*}} 5.3128{{.*}} 5.8109{{.*}} 4.8100{{.*}} 4.7435{{.*}} 4.3557{{.*}} 5.3115{{.*}}

# CHECK-NOT: Execution failed

Expand Down
16 changes: 14 additions & 2 deletions lighthouse/pipeline/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _normalize_include_path(self) -> str:
)

@staticmethod
def _string_to_type(value: str) -> str | int | float | bool:
def _string_to_type(value: str) -> str | int | float | bool | list:
value = str(value)
if value == "True":
return True
Expand All @@ -145,7 +145,19 @@ def _string_to_type(value: str) -> str | int | float | bool:
try:
return float(value)
except ValueError:
return value
# List of values, e.g. [val1,val2,...]
if value.startswith("[") and value.endswith("]"):
list_str = value[1:-1]
# Bare comma-separated list without brackets, e.g. val1,val2,...
elif value.find(",") != -1:
list_str = value
else:
# Something else entirely, return as string
return value
# Recursively parse the list elements
return [
Descriptor._string_to_type(v.strip()) for v in list_str.split(",")
]

@staticmethod
def _parse_csv(line: str, separator: str = ",") -> dict:
Expand Down
3 changes: 2 additions & 1 deletion lighthouse/pipeline/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def apply(self, module: ir.Module, print_after_all: bool = False) -> ir.Module:
if module.context != self.context:
raise ValueError("Module context does not match driver context.")
for stage in self.stages:
module = stage.apply(module)
with self.context:
module = stage.apply(module)
if print_after_all:
print(f"After stage {stage}:\n{module}")
return module
Expand Down
3 changes: 2 additions & 1 deletion lighthouse/pipeline/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,12 @@ def __init__(self, transform: Transform | ir.Module, context: ir.Context):
raise ValueError(
f"Transform module '{transform.filename}' does not define a '{transform.generator}' generator function."
)
# Get the generator function in the transform module.
self.generator = getattr(transform_module, transform.generator)

# Run the generator with the options dictionary unpacked as keyword arguments; it creates the named sequence.
with context, ir.Location.unknown():
self.module = self.generator(transform.options)
self.module = self.generator(**transform.options)
else:
raise ValueError(f"Unsupported transform type: {transform.type}")

Expand Down
4 changes: 4 additions & 0 deletions lighthouse/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from .linalg import linalg_contract_fold_unit_dims
from .packing import block_pack_matmuls
from .tiling import tile_ops
from .vectorization import flatten_vector_ops
from .vectorization import simplify_vector_ops
from .vectorization import vectorize_linalg
from .vectorization import vectorize_all
from .vectorization import x86_vectorization
Expand All @@ -18,10 +20,12 @@
"convert_function_results",
"create_named_sequence",
"create_schedule",
"flatten_vector_ops",
"hoist_loops",
"linalg_contract_fold_unit_dims",
"print_ir",
"schedule_boilerplate",
"simplify_vector_ops",
"tile_ops",
"vectorize_all",
"vectorize_linalg",
Expand Down
2 changes: 1 addition & 1 deletion lighthouse/schedule/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import lighthouse.transform as lh_transform


def linalg_contract_fold_unit_dims(options: dict = {}) -> ir.Module:
def linalg_contract_fold_unit_dims() -> ir.Module:
"""
Fold unit dims of linalg contract.

Expand Down
2 changes: 1 addition & 1 deletion lighthouse/schedule/packing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from mlir import ir
from mlir.dialects import transform

from .builders import schedule_boilerplate
from lighthouse.schedule.builders import schedule_boilerplate
import lighthouse.transform as lh_transform


Expand Down
Loading
Loading