From d69978d46c2f48e6cb909e4141391be794695c74 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Fri, 8 May 2026 16:41:50 +0100 Subject: [PATCH 1/8] Change matmul / kernel_bench tests to reuse CPU pipeline This PR moves all remaining schedules from CPU test `matmul.py` into their own schedules in Lighthouse, so they can be reused by the kernel_bench test for similar tests (L1, K1 & K2). The performance of `matmul.py` remains the same, while the perf for the Kernel Bench kernels has improved dramatically. This is only working for FP32 for now, so the BF16 tests in KB are still using the old (lower-to-loops) strategy. A follow up PR will fix the perf for BF16 and update the kernel bench testing to start tracking performance, just like `matmul.py`. assisted-by: GitHub Copilot --- examples/cpu/x86/matmul.py | 109 +++++++----------- .../end-to-end/KernelBench/cpu_matmul.yaml | 33 ++++++ .../KernelBench/test_kernel_bench.py | 19 ++- lighthouse/pipeline/descriptor.py | 18 ++- lighthouse/pipeline/driver.py | 3 +- lighthouse/schedule/__init__.py | 4 + lighthouse/schedule/hoisting.py | 10 +- lighthouse/schedule/packing.py | 26 ++--- lighthouse/schedule/tiling.py | 18 +-- lighthouse/schedule/vectorization.py | 42 ++++++- lighthouse/schedule/x86/__init__.py | 5 + lighthouse/schedule/x86/cache_tiling.py | 48 ++++++++ lighthouse/schedule/x86/pack_lowering.py | 46 ++++---- lighthouse/schedule/x86/register_tiling.py | 93 +++++++++++++++ lighthouse/transform/x86/__init__.py | 5 - lighthouse/transform/x86/cache_tiling.py | 45 -------- 16 files changed, 345 insertions(+), 179 deletions(-) create mode 100644 examples/end-to-end/KernelBench/cpu_matmul.yaml create mode 100644 lighthouse/schedule/x86/cache_tiling.py create mode 100644 lighthouse/schedule/x86/register_tiling.py delete mode 100644 lighthouse/transform/x86/__init__.py delete mode 100644 lighthouse/transform/x86/cache_tiling.py diff --git a/examples/cpu/x86/matmul.py b/examples/cpu/x86/matmul.py index 88d2e267..36dc3171 100644 --- a/examples/cpu/x86/matmul.py +++ b/examples/cpu/x86/matmul.py @@ -27,8 +27,7 @@ import ml_dtypes import numpy as np from mlir import ir -from mlir.dialects import linalg, transform -from mlir.dialects.transform import tensor +from mlir.dialects import linalg from lighthouse import dialects as lh_dialects from lighthouse.execution.runner import Runner @@ -38,8 +37,6 @@ import lighthouse.utils as lh_utils from lighthouse import schedule as lh_schedule import lighthouse.schedule.x86 as lh_schedule_x86 -from lighthouse import transform as lh_transform -import lighthouse.transform.x86 as lh_transform_x86 import lighthouse.ingress.mlir_gen.utils as lh_mlir_utils from functools import cached_property from typing import Optional @@ -149,12 +146,16 @@ def get_pipeline( # Create cache-friendly access pattern across matmul tiles. scheds.add_transform( lh_schedule.block_pack_matmuls( - block_factors=[self.tile_size, self.tile_size, self.tile_size], - rhs_transpose_outer_block=True, - rhs_transpose_inner_block=False, + { + "block_factors": [self.tile_size, self.tile_size, self.tile_size], + "rhs_transpose_outer_block": True, + "rhs_transpose_inner_block": False, + } ) ) - scheds.add_transform(lh_schedule_x86.lower_packs_unpacks(self.tile_size)) + scheds.add_transform( + lh_schedule_x86.lower_packs_unpacks({"tile_size": self.tile_size}) + ) # Convert to category ops for easier op matching. scheds.add_pass( @@ -166,64 +167,51 @@ def get_pipeline( # GEMM cache tiling. # Create memory friendly access pattern. - gemm_op = "linalg.contract" - with lh_schedule.schedule_boilerplate() as (sched, named_seq): - ops = lh_transform.match_op(named_seq.bodyTarget, gemm_op) - with lh_transform.foreach(ops) as op: - lh_transform_x86.matmul_cache_tiling( - op, tile_size=self.tile_size, fuse_producers=True - ) - transform.yield_() - transform.yield_() - scheds.add_transform(sched) + scheds.add_transform( + lh_schedule_x86.matmul_cache_tiling( + {"tile_size": self.tile_size, "fuse_producers": True} + ) + ) # Fold extra parallel outer unit dims before further tiling to help later # vectorization rewrites to recognize ops. scheds.add_transform(lh_schedule.linalg_contract_fold_unit_dims()) - # GEMM register tiling. - # Ensure that computation can fit into vector registers. - reg_tile_batch = 1 - reg_tile_m = 8 - reg_tile_n = 32 - reg_tile_k = 2 - reg_peel_loops = [] - assert self.tile_size % reg_tile_k == 0, "Invalid K dim register tiling" - if self.tile_size % reg_tile_n != 0: - reg_peel_loops.append(1) - if self.tile_size % reg_tile_m != 0: - reg_peel_loops.append(0) + # GEMM register tiling, ensure that computation can fit into vector registers. scheds.add_transform( - lh_schedule.tile_ops( - gemm_op, - tile_sizes=[reg_tile_batch, reg_tile_m, reg_tile_n, reg_tile_k], - tile_interchange=[1, 2, 0, 3], - peel_loops=reg_peel_loops, + lh_schedule_x86.matmul_register_tiling( + { + "tile_size": self.tile_size, + "reg_tile_m": 8, + "reg_tile_n": 32, + "reg_tile_k": 2, + "batch": True, + } ) ) - # GEMM register unroll. - # Ensure that shapes are compatible with target hardware instructions. - reg_unroll_m = 1 - reg_unroll_n = 16 - # When VNNI can be used, tuples of 32-bit elements are needed. - reg_unroll_k = 2 if self.dtype == ml_dtypes.bfloat16 else 1 - reg_unroll_factors = [ - reg_tile_m // reg_unroll_m, - reg_tile_n // reg_unroll_n, - reg_tile_k // reg_unroll_k, - ] + # GEMM register unroll, ensure that shapes are compatible with target hardware instructions. scheds.add_transform( - lh_schedule.tile_ops( - gemm_op, - tile_sizes=[0, reg_unroll_m, reg_unroll_n, reg_unroll_k], - unroll_factors=reg_unroll_factors, + lh_schedule_x86.matmul_register_unroll( + { + "reg_tile_m": 8, + "reg_tile_n": 32, + "reg_tile_k": 2, + "reg_unroll_m": 1, + "reg_unroll_n": 16, + "reg_unroll_k": 2 if self.dtype == ml_dtypes.bfloat16 else 1, + "batch": True, + } ) ) # Further tiling into hardware-friendly sizes for vectorization. - scheds.add_transform(lh_schedule.tile_ops("linalg.fill", tile_sizes=[1, 1, 1])) - scheds.add_transform(lh_schedule.tile_ops("linalg.generic", tile_sizes=[1, 8])) + scheds.add_transform( + lh_schedule.tile_ops({"target_op": "linalg.fill", "tile_sizes": [1, 1, 1]}) + ) + scheds.add_transform( + lh_schedule.tile_ops({"target_op": "linalg.generic", "tile_sizes": [1, 8]}) + ) if stop_at_stage == "tiled": return scheds @@ -232,14 +220,7 @@ def get_pipeline( scheds.add_transform(lh_schedule.vectorize_linalg()) scheds.add_transform(lh_schedule.hoist_loops()) - with lh_schedule.schedule_boilerplate() as (sched, named_seq): - with ir.InsertionPoint( - transform.ApplyPatternsOp(named_seq.bodyTarget).patterns - ): - tensor.apply_patterns_tensor_fold_tensor_subset_ops_into_vector_transfers() - transform.apply_patterns_canonicalization() - transform.yield_() - scheds.add_transform(sched) + scheds.add_transform(lh_schedule.fold_into_vector_transfer()) # Rewrite vector ops into x86-specific sequences. scheds.add_transform(lh_schedule.x86_vectorization()) @@ -254,11 +235,7 @@ def get_pipeline( scheds.add_transform(lh_schedule.vectorize_all()) # Cleanup vector ops. - with lh_schedule.schedule_boilerplate() as (sched, named_seq): - lh_transform.flatten_vector_ops(named_seq.bodyTarget) - lh_transform.cleanup(named_seq.bodyTarget) - transform.yield_() - scheds.add_transform(sched) + scheds.add_transform(lh_schedule.flatten_vector_ops()) if stop_at_stage == "vectorized": return scheds @@ -354,7 +331,7 @@ def parse_cli(): wload = Matmul(*args.sizes, dtype=in_dtype, tile_size=args.tile_size) pipeline = wload.get_pipeline(stop_at_stage=args.dump_kernel) - payload = pipeline.apply(wload.payload_module()) + payload = pipeline.apply(wload.payload_module(), print_after_all=True) if args.dump_kernel or args.dump_schedule: if args.dump_kernel: diff --git a/examples/end-to-end/KernelBench/cpu_matmul.yaml b/examples/end-to-end/KernelBench/cpu_matmul.yaml new file mode 100644 index 00000000..a3ca913a --- /dev/null +++ b/examples/end-to-end/KernelBench/cpu_matmul.yaml @@ -0,0 +1,33 @@ +# This is the default pipeline for kernel_bench, producing serial code. +# For optimized builds, pass the --pipeline to the kernel_bench program. +Pipeline: + ## Packing & Cache tiling (CPU generic) + - schedule: "packing.py[gen=block_pack_matmuls]{block_factors=32,32,32 rhs_transpose_outer_block=True rhs_transpose_inner_block=False}" + - schedule: "x86/pack_lowering.py[gen=lower_packs_unpacks]{tile_size=32}" + - pass: "linalg-morph-ops{named-to-category generic-to-category}" + - schedule: "x86/cache_tiling.py[gen=matmul_cache_tiling]{fuse_producers}" + - schedule: "linalg.py[gen=linalg_contract_fold_unit_dims]" + + ## CPU specific register tiling (depends on uArch & data type) + - schedule: "x86/register_tiling.py[gen=matmul_register_tiling]" + - schedule: "x86/register_tiling.py[gen=matmul_register_unroll]" + - schedule: "tiling.py[gen=tile_ops]{target_op=linalg.fill tile_sizes=[1,1,1]}" + - schedule: "tiling.py[gen=tile_ops]{target_op=linalg.generic tile_sizes=[1,8]}" + + ## Tensor vectorization (for the left-over element wise) + - schedule: "vectorization.py[gen=vectorize_linalg]" + - schedule: "vectorization.py[gen=fold_into_vector_transfer]" + - include: cleanup.yaml + - schedule: "vectorization.py[gen=x86_vectorization]" + + ## Bufferization + - include: bufferization.yaml + - include: bufferization_cleanup.yaml + + ## Buffer vectorization (gets rid of all the rest) + - schedule: "vectorization.py[gen=x86_vectorization]" + - schedule: "vectorization.py[gen=vectorize_all]" + - schedule: "vectorization.py[gen=flatten_vector_ops]" + + ## Lower to LLVM + - include: llvm_lowering.yaml diff --git a/examples/end-to-end/KernelBench/test_kernel_bench.py b/examples/end-to-end/KernelBench/test_kernel_bench.py index 733752a8..9f23b3b7 100755 --- a/examples/end-to-end/KernelBench/test_kernel_bench.py +++ b/examples/end-to-end/KernelBench/test_kernel_bench.py @@ -6,34 +6,39 @@ import subprocess from pathlib import Path +script_path = Path(__file__).parent +project_root = script_path.parent.parent.parent +kb_program = project_root / "tools" / "kernel_bench" +kb_default_pipeline = kb_program.parent / "kernel_bench.yaml" +kb_path = project_root / "third_party" / "KernelBench" / "KernelBench" tests = [ { "kernel": "level1/1_Square_matrix_multiplication_.py", "input_shapes": "32x32xf32xrnd,32x32xf32xid", "output_shape": "32x32xf32x0", + "pipeline": f"{script_path}/cpu_matmul.yaml", }, { "kernel": "level1/1_Square_matrix_multiplication_.py", "input_shapes": "32x32xbf16xrnd,32x32xbf16xid", "output_shape": "32x32xbf16x0", + "pipeline": str(kb_default_pipeline), }, { "kernel": "level1/2_Standard_matrix_multiplication_.py", "input_shapes": "8x16xf32xrnd,16x8xf32xrnd", "output_shape": "8x8xf32x0", + "pipeline": f"{script_path}/cpu_matmul.yaml", }, { "kernel": "level1/2_Standard_matrix_multiplication_.py", "input_shapes": "8x16xbf16xrnd,16x8xbf16xrnd", "output_shape": "8x8xbf16x0", + "pipeline": str(kb_default_pipeline), }, ] if __name__ == "__main__": - project_root = Path(__file__).parent.parent.parent.parent - kb_program = project_root / "tools" / "kernel_bench" - kb_path = project_root / "third_party" / "KernelBench" / "KernelBench" - for test in tests: kb_kernel = kb_path / test["kernel"] command_line = [ @@ -43,6 +48,8 @@ test["input_shapes"], "--output-shape", test["output_shape"], + "--pipeline", + test["pipeline"], "--print-tensor=1", "--seed=42", ] @@ -73,8 +80,8 @@ # CHECK-NOT: Execution failed # CHECK: 2_Standard_matrix_multiplication_.mlir -# CHECK: 3.120935 3.7697 4.5365195 4.397648 4.4506536 3.2665431 3.5362916 -# CHECK: 5.036752 5.312808 5.8109508 4.810084 4.7435184 4.35573 5.311559 +# CHECK: 3.1209347 3.7697 4.536519 4.397648 4.4506536 3.2665434 3.5362918 +# CHECK: 5.036752 5.3128085 5.8109508 4.810084 4.7435184 4.35573 5.311559 # CHECK-NOT: Execution failed diff --git a/lighthouse/pipeline/descriptor.py b/lighthouse/pipeline/descriptor.py index 1461438e..f34f258b 100644 --- a/lighthouse/pipeline/descriptor.py +++ b/lighthouse/pipeline/descriptor.py @@ -133,19 +133,33 @@ def _normalize_include_path(self) -> str: ) @staticmethod - def _string_to_type(value: str) -> str | int | float | bool: + def _string_to_type(value: str) -> str | int | float | bool | list: + # Boolean value = str(value) if value == "True": return True elif value == "False": return False + # Integer try: return int(value) except ValueError: + # Floating point try: return float(value) except ValueError: - return value + # List of values, e.g. [val1,val2,...] + if value.startswith("[") and value.endswith("]"): + list_str = value[1:-1] + # List of values, e.g. val1,val2,... + elif value.find(",") != -1: + list_str = value + else: + # Something else entirely, return as string + return value + return [ + Descriptor._string_to_type(v.strip()) for v in list_str.split(",") + ] @staticmethod def _parse_csv(line: str, separator: str = ",") -> dict: diff --git a/lighthouse/pipeline/driver.py b/lighthouse/pipeline/driver.py index 4fed7c3b..f8b7145a 100644 --- a/lighthouse/pipeline/driver.py +++ b/lighthouse/pipeline/driver.py @@ -75,7 +75,8 @@ def apply(self, module: ir.Module, print_after_all: bool = False) -> ir.Module: if module.context != self.context: raise ValueError("Module context does not match driver context.") for stage in self.stages: - module = stage.apply(module) + with self.context: + module = stage.apply(module) if print_after_all: print(f"After stage {stage}:\n{module}") return module diff --git a/lighthouse/schedule/__init__.py b/lighthouse/schedule/__init__.py index e82438ee..5dc04217 100644 --- a/lighthouse/schedule/__init__.py +++ b/lighthouse/schedule/__init__.py @@ -6,6 +6,8 @@ from .linalg import linalg_contract_fold_unit_dims from .packing import block_pack_matmuls from .tiling import tile_ops +from .vectorization import flatten_vector_ops +from .vectorization import fold_into_vector_transfer from .vectorization import vectorize_linalg from .vectorization import vectorize_all from .vectorization import x86_vectorization @@ -18,6 +20,8 @@ "convert_function_results", "create_named_sequence", "create_schedule", + "flatten_vector_ops", + "fold_into_vector_transfer", "hoist_loops", "linalg_contract_fold_unit_dims", "print_ir", diff --git a/lighthouse/schedule/hoisting.py b/lighthouse/schedule/hoisting.py index 43e820d7..33bc774f 100644 --- a/lighthouse/schedule/hoisting.py +++ b/lighthouse/schedule/hoisting.py @@ -6,11 +6,7 @@ import lighthouse.transform as lh_transform -def hoist_loops( - target_op: str - | list[str] - | MatchInterfaceEnum = MatchInterfaceEnum.LoopLikeInterface, -) -> ir.Module: +def hoist_loops(options: dict = {}) -> ir.Module: """ Apply loop hoisting to all matching ops. @@ -19,6 +15,10 @@ def hoist_loops( Returns: Schedule """ + target_op: str | list[str] | MatchInterfaceEnum = options.get( + "target_op", MatchInterfaceEnum.LoopLikeInterface + ) + with schedule_boilerplate() as (schedule, named_seq): ops = lh_transform.match_op(named_seq.bodyTarget, target_op) lh_transform.loop_hoisting(ops) diff --git a/lighthouse/schedule/packing.py b/lighthouse/schedule/packing.py index c487db97..7a109e80 100644 --- a/lighthouse/schedule/packing.py +++ b/lighthouse/schedule/packing.py @@ -1,17 +1,11 @@ from mlir import ir from mlir.dialects import transform -from .builders import schedule_boilerplate +from lighthouse.schedule.builders import schedule_boilerplate import lighthouse.transform as lh_transform -def block_pack_matmuls( - block_factors: tuple[int, int, int], - lhs_transpose_outer_block: bool = False, - lhs_transpose_inner_block: bool = False, - rhs_transpose_outer_block: bool = True, - rhs_transpose_inner_block: bool = True, -) -> ir.Module: +def block_pack_matmuls(options: dict) -> ir.Module: """ Block pack all matmuls. @@ -25,15 +19,21 @@ def block_pack_matmuls( and the (mb, nb, kb) are the minor blocks of their respective original 2D dimensions (M, N, K). - Args: + Options: block_factors: Block sizes (mb, nb, kb) - lhs_transpose_outer_block: A matrix MB x KB => KB x MB - lhs_transpose_inner_block: A matrix mb x kb => kb x mb - rhs_transpose_outer_block: B matrix KB x NB => NB x KB - rhs_transpose_inner_block: B matrix kb x nb => nb x kb + lhs_transpose_outer_block: True if A matrix MB x KB => KB x MB + lhs_transpose_inner_block: True if A matrix mb x kb => kb x mb + rhs_transpose_outer_block: True if B matrix KB x NB => NB x KB + rhs_transpose_inner_block: True if B matrix kb x nb => nb x kb Returns: Schedule """ + block_factors = options.get("block_factors") + lhs_transpose_outer_block = options.get("lhs_transpose_outer_block", False) + lhs_transpose_inner_block = options.get("lhs_transpose_inner_block", False) + rhs_transpose_outer_block = options.get("rhs_transpose_outer_block", True) + rhs_transpose_inner_block = options.get("rhs_transpose_inner_block", True) + if len(block_factors) != 3: raise ValueError(f"Expected 3 block factors but got {len(block_factors)}") diff --git a/lighthouse/schedule/tiling.py b/lighthouse/schedule/tiling.py index c16c22b2..49ab3b18 100644 --- a/lighthouse/schedule/tiling.py +++ b/lighthouse/schedule/tiling.py @@ -2,18 +2,11 @@ from mlir.dialects import transform from mlir.dialects.transform.structured import MatchInterfaceEnum -from .builders import schedule_boilerplate +from lighthouse.schedule.builders import schedule_boilerplate import lighthouse.transform as lh_transform -def tile_ops( - target_op: str | list[str] | MatchInterfaceEnum, - tile_sizes: list[int], - fuse_producers: bool = False, - tile_interchange: list[int] | None = None, - peel_loops: list[int] = [], - unroll_factors: list[int] = [], -) -> ir.Module: +def tile_ops(options: dict) -> ir.Module: """ Tile all matching op. @@ -39,6 +32,13 @@ def tile_ops( Returns: Schedule """ + target_op: str | list[str] | MatchInterfaceEnum = options["target_op"] + tile_sizes: list[int] = options["tile_sizes"] + fuse_producers: bool = options.get("fuse_producers", False) + tile_interchange: list[int] | None = options.get("tile_interchange", None) + peel_loops: list[int] = options.get("peel_loops", []) + unroll_factors: list[int] = options.get("unroll_factors", []) + with schedule_boilerplate() as (schedule, named_seq): ops = lh_transform.match_op(named_seq.bodyTarget, target_op) with lh_transform.foreach(ops) as op: diff --git a/lighthouse/schedule/vectorization.py b/lighthouse/schedule/vectorization.py index df79f3ab..f10db178 100644 --- a/lighthouse/schedule/vectorization.py +++ b/lighthouse/schedule/vectorization.py @@ -1,13 +1,12 @@ from mlir import ir from mlir.dialects import transform -from mlir.dialects.transform import structured -from mlir.dialects.transform import vector +from mlir.dialects.transform import structured, vector, tensor -from .builders import schedule_boilerplate +from lighthouse.schedule.builders import schedule_boilerplate import lighthouse.transform as lh_transform -def vectorize_linalg() -> ir.Module: +def vectorize_linalg(options: dict = {}) -> ir.Module: """ Vectorize all linalg ops. @@ -33,7 +32,7 @@ def vectorize_linalg() -> ir.Module: return schedule -def vectorize_all() -> ir.Module: +def vectorize_all(options: dict = {}) -> ir.Module: """ Vectorize all ops. @@ -51,7 +50,7 @@ def vectorize_all() -> ir.Module: return schedule -def x86_vectorization() -> ir.Module: +def x86_vectorization(options: dict = {}) -> ir.Module: """ Apply x86-specific vector rewrites. @@ -64,3 +63,34 @@ def x86_vectorization() -> ir.Module: transform.yield_() return schedule + + +def fold_into_vector_transfer(options: dict = {}) -> ir.Module: + """ + Fold vector.contract into vector.transfer_read and vector.transfer_write. + + Returns: + Schedule + """ + with schedule_boilerplate() as (schedule, named_seq): + with ir.InsertionPoint( + transform.ApplyPatternsOp(named_seq.bodyTarget).patterns + ): + tensor.apply_patterns_tensor_fold_tensor_subset_ops_into_vector_transfers() + transform.apply_patterns_canonicalization() + transform.yield_() + return schedule + + +def flatten_vector_ops(options: dict = {}) -> ir.Module: + """ + Flatten vector ops to 1D. + + Returns: + Schedule + """ + with schedule_boilerplate() as (schedule, named_seq): + lh_transform.flatten_vector_ops(named_seq.bodyTarget) + lh_transform.cleanup(named_seq.bodyTarget) + transform.yield_() + return schedule diff --git a/lighthouse/schedule/x86/__init__.py b/lighthouse/schedule/x86/__init__.py index f38af651..4e10d28c 100644 --- a/lighthouse/schedule/x86/__init__.py +++ b/lighthouse/schedule/x86/__init__.py @@ -1,6 +1,11 @@ +from .cache_tiling import matmul_cache_tiling from .pack_lowering import lower_packs_unpacks +from .register_tiling import matmul_register_tiling, matmul_register_unroll __all__ = [ "lower_packs_unpacks", + "matmul_cache_tiling", + "matmul_register_tiling", + "matmul_register_unroll", "tile_and_vector_matmul", ] diff --git a/lighthouse/schedule/x86/cache_tiling.py b/lighthouse/schedule/x86/cache_tiling.py new file mode 100644 index 00000000..fe69e6ad --- /dev/null +++ b/lighthouse/schedule/x86/cache_tiling.py @@ -0,0 +1,48 @@ +from mlir.dialects import transform +from mlir.dialects.transform import structured + +from lighthouse.dialects.transform import transform_ext +from lighthouse.schedule.builders import schedule_boilerplate +import lighthouse.transform as lh_transform + + +def matmul_cache_tiling(options: dict) -> transform.TransformOpInterface: + """ + Applies cache tiling to the target matmul operation. + Creates a forall loop on successful rewrite. + + This tiling step improves computation's memory access pattern + and exposes parallelism. + + Optionally, fusion can be performed after tiling to minimize + data transfers. + + Args: + target: Handle to target operation. + tile_size: Target size for tile dimensions. + fuse_producers: Apply extra producer ops fusion after tiling. + """ + target = options.get("target", "linalg.contract") + tile_size: int = options.get("tile_size", 32) + fuse_producers: bool = options.get("fuse_producers", False) + + with schedule_boilerplate() as (sched, named_seq): + ops = lh_transform.match_op(named_seq.bodyTarget, target) + with lh_transform.foreach(ops) as op: + tiles = transform_ext.get_tiling_sizes(op, tile_dim=tile_size) + if fuse_producers: + # Tile the target and greedily fuse its producers. + structured.FuseOp( + op, + tile_sizes=tiles, + apply_cleanup=True, + use_forall=True, + ).results + else: + # Only tile the target. + structured.TileUsingForallOp(op, tile_sizes=tiles).results + transform.yield_() + transform.yield_() + + # TODO: Fuse elementwise consumers. + return sched diff --git a/lighthouse/schedule/x86/pack_lowering.py b/lighthouse/schedule/x86/pack_lowering.py index 4133593f..8c86c4a8 100644 --- a/lighthouse/schedule/x86/pack_lowering.py +++ b/lighthouse/schedule/x86/pack_lowering.py @@ -9,12 +9,7 @@ from lighthouse import transform as lh_transform -def lower_packs_for_vectorization( - pack_ops, - pack_tile_sizes: list[int], - vector_tile_sizes: list[int] | None = None, - vector_unroll_factors: list[int] = [], -): +def lower_packs_for_vectorization(options: dict): """ Lower packs into hardware-friendly operations. @@ -24,6 +19,11 @@ def lower_packs_for_vectorization( vector_tile_sizes: Target vector shapes vector_unroll_factors: Unroll factors for each vector loop. """ + pack_ops = options["pack_ops"] + pack_tile_sizes = options["pack_tile_sizes"] + vector_tile_sizes = options.get("vector_tile_sizes", None) + vector_unroll_factors = options.get("vector_unroll_factors", []) + with lh_transform.foreach(pack_ops) as pack_op: tiled_pack = structured.TileUsingForOp( pack_op, sizes=pack_tile_sizes @@ -44,11 +44,7 @@ def lower_packs_for_vectorization( transform.yield_() -def lower_unpacks_for_vectorization( - unpack_ops, - unpack_tile_sizes: list[int], - vector_tile_sizes: list[int] | None = None, -): +def lower_unpacks_for_vectorization(options: dict): """ Lower unpacks into hardware-friendly operations. @@ -57,6 +53,10 @@ def lower_unpacks_for_vectorization( unpack_tile_sizes: Unpack sub-tiling sizes vector_tile_sizes: Target vector shapes """ + unpack_ops = options["unpack_ops"] + unpack_tile_sizes = options["unpack_tile_sizes"] + vector_tile_sizes = options.get("vector_tile_sizes", None) + with lh_transform.foreach(unpack_ops) as unpack_op: tiled_unpack = structured.TileUsingForOp( unpack_op, sizes=unpack_tile_sizes @@ -77,7 +77,7 @@ def lower_unpacks_for_vectorization( transform.yield_() -def lower_packs_unpacks(tile_size: int) -> ir.Module: +def lower_packs_unpacks(options: dict) -> ir.Module: """ Lower pack and unpack ops into hardware-friendly shapes. @@ -86,25 +86,29 @@ def lower_packs_unpacks(tile_size: int) -> ir.Module: Returns: Schedule """ + tile_size = options["tile_size"] + with schedule_boilerplate() as (schedule, named_seq): pack_unpack_vector_m = max(8, tile_size) pack_unpack_vector_n = min(64, tile_size) packs = lh_transform.match_op(named_seq.bodyTarget, "linalg.pack") lower_packs_for_vectorization( - packs, - pack_tile_sizes=[1, 1], - vector_tile_sizes=[1, 1, pack_unpack_vector_m, pack_unpack_vector_n], - vector_unroll_factors=[ - tile_size // pack_unpack_vector_n, - ], + { + "pack_ops": packs, + "pack_tile_sizes": [1, 1], + "vector_tile_sizes": [1, 1, pack_unpack_vector_m, pack_unpack_vector_n], + "vector_unroll_factors": [tile_size // pack_unpack_vector_n], + } ) lh_transform.cleanup(named_seq.bodyTarget) unpacks = lh_transform.match_op(named_seq.bodyTarget, "linalg.unpack") lower_unpacks_for_vectorization( - unpacks, - unpack_tile_sizes=[tile_size, tile_size], - vector_tile_sizes=[1], + { + "unpack_ops": unpacks, + "unpack_tile_sizes": [tile_size, tile_size], + "vector_tile_sizes": [1], + } ) transposes = lh_transform.match_op(named_seq.bodyTarget, "linalg.transpose") with lh_transform.foreach(transposes) as tranpose: diff --git a/lighthouse/schedule/x86/register_tiling.py b/lighthouse/schedule/x86/register_tiling.py new file mode 100644 index 00000000..35a590f1 --- /dev/null +++ b/lighthouse/schedule/x86/register_tiling.py @@ -0,0 +1,93 @@ +from mlir.dialects import transform + +import lighthouse.schedule as lh_schedule + +# FIXME: These functions should receive u-Arch and data type parameters and come +# up with the appropriate tiling and unrolling factors based on the target hardware + + +def matmul_register_tiling(options: dict) -> transform.TransformOpInterface: + """ + Applies register tiling to the target matmul operation. + + This tiling step prepares the IR for the x86 vectorization passes. + + Args: + target: Target operation. + tile_size: Tile size used in the previous cache tiling step. + reg_tile_batch: Target size for batch dimension tile. + reg_tile_m: Target size for M dimension tile. + reg_tile_n: Target size for N dimension tile. + reg_tile_k: Target size for K dimension tile. + batch: True is the input has batch dimension. + """ + target = options.get("target", "linalg.contract") + tile_size: int = options.get("tile_size", 32) + reg_tile_batch: int = options.get("reg_tile_batch", 1) + reg_tile_m: int = options.get("reg_tile_m", 8) + reg_tile_n: int = options.get("reg_tile_n", 32) + reg_tile_k: int = options.get("reg_tile_k", 2) + batch: bool = options.get("batch", False) + + tile_sizes = [reg_tile_m, reg_tile_n, reg_tile_k] + if batch: + tile_sizes = [reg_tile_batch] + tile_sizes + + reg_peel_loops = [] + assert tile_size % reg_tile_k == 0, "Invalid K dim register tiling" + if tile_size % reg_tile_n != 0: + reg_peel_loops.append(1) + if tile_size % reg_tile_m != 0: + reg_peel_loops.append(0) + return lh_schedule.tile_ops( + { + "target_op": target, + "tile_sizes": tile_sizes, + "tile_interchange": [1, 2, 0, 3], + "peel_loops": reg_peel_loops, + } + ) + + +def matmul_register_unroll(options: dict) -> transform.TransformOpInterface: + """ + Applies register unrolling to the target matmul operation. + + This unrolling step prepares the IR for the x86 vectorization passes. + Ensure that shapes are compatible with target hardware instructions. + + Args: + target: Target operation. + tile_size: Tile size used in the previous cache tiling step. + reg_tile_m: Target size for M dimension tile. + reg_tile_n: Target size for N dimension tile. + reg_tile_k: Target size for K dimension tile. + reg_unroll_m: Unroll M dimension after tiling. + reg_unroll_n: Unroll N dimension after tiling. + batch: True is the input has batch dimension. + """ + target = options.get("target", "linalg.contract") + reg_tile_m: int = options.get("reg_tile_m", 8) + reg_tile_n: int = options.get("reg_tile_n", 32) + reg_tile_k: int = options.get("reg_tile_k", 2) + reg_unroll_m: int = options.get("reg_unroll_m", 1) + reg_unroll_n: int = options.get("reg_unroll_n", 16) + reg_unroll_k: int = options.get("reg_unroll_k", 1) + batch: bool = options.get("batch", False) + + tile_sizes = [reg_unroll_m, reg_unroll_n, reg_unroll_k] + if batch: + tile_sizes = [0] + tile_sizes + + reg_unroll_factors = [ + reg_tile_m // reg_unroll_m, + reg_tile_n // reg_unroll_n, + reg_tile_k // reg_unroll_k, + ] + return lh_schedule.tile_ops( + { + "target_op": target, + "tile_sizes": tile_sizes, + "unroll_factors": reg_unroll_factors, + } + ) diff --git a/lighthouse/transform/x86/__init__.py b/lighthouse/transform/x86/__init__.py deleted file mode 100644 index 115a0d47..00000000 --- a/lighthouse/transform/x86/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .cache_tiling import matmul_cache_tiling - -__all__ = [ - "matmul_cache_tiling", -] diff --git a/lighthouse/transform/x86/cache_tiling.py b/lighthouse/transform/x86/cache_tiling.py deleted file mode 100644 index 3063b3ec..00000000 --- a/lighthouse/transform/x86/cache_tiling.py +++ /dev/null @@ -1,45 +0,0 @@ -from mlir import ir -from mlir.dialects.transform import structured - -from lighthouse.dialects.transform import transform_ext - - -def matmul_cache_tiling( - target, - tile_size: int = 32, - fuse_producers: bool = False, -) -> tuple[ir.Value, ir.Value]: - """ - Applies cache tiling to the target matmul operation. - Creates a forall loop on successful rewrite. - - This tiling step improves computation's memory access pattern - and exposes parallelism. - - Optionally, fusion can be performed after tiling to minimize - data transfers. - - Args: - target: Handle to target operation. - tile_size: Target size for tile dimensions. - fuse_producers: Apply extra producer ops fusion after tiling. - Returns: - Handles to the tiled op and created loop - """ - tiles = transform_ext.get_tiling_sizes(target, tile_dim=tile_size) - if fuse_producers: - # Tile the target and greedily fuse its producers. - tiled_op, forall_op = structured.FuseOp( - target, - tile_sizes=tiles, - apply_cleanup=True, - use_forall=True, - ).results - else: - # Only tile the target. - tiled_op, forall_op = structured.TileUsingForallOp( - target, tile_sizes=tiles - ).results - # TODO: Fuse elementwise consumers. - - return tiled_op, forall_op From 03627b0405ad04d948b5aab51475ad2c44c274a4 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Mon, 11 May 2026 15:08:33 +0100 Subject: [PATCH 2/8] only use optimized pipeline for x86_64 --- examples/end-to-end/KernelBench/test_kernel_bench.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/end-to-end/KernelBench/test_kernel_bench.py b/examples/end-to-end/KernelBench/test_kernel_bench.py index 9f23b3b7..9e5708e2 100755 --- a/examples/end-to-end/KernelBench/test_kernel_bench.py +++ b/examples/end-to-end/KernelBench/test_kernel_bench.py @@ -4,6 +4,7 @@ # REQUIRES: kernel_bench import subprocess +import platform from pathlib import Path script_path = Path(__file__).parent @@ -11,12 +12,15 @@ kb_program = project_root / "tools" / "kernel_bench" kb_default_pipeline = kb_program.parent / "kernel_bench.yaml" kb_path = project_root / "third_party" / "KernelBench" / "KernelBench" +arch = platform.machine() tests = [ { "kernel": "level1/1_Square_matrix_multiplication_.py", "input_shapes": "32x32xf32xrnd,32x32xf32xid", "output_shape": "32x32xf32x0", - "pipeline": f"{script_path}/cpu_matmul.yaml", + "pipeline": f"{script_path}/cpu_matmul.yaml" + if arch == "x86_64" + else str(kb_default_pipeline), }, { "kernel": "level1/1_Square_matrix_multiplication_.py", @@ -28,7 +32,9 @@ "kernel": "level1/2_Standard_matrix_multiplication_.py", "input_shapes": "8x16xf32xrnd,16x8xf32xrnd", "output_shape": "8x8xf32x0", - "pipeline": f"{script_path}/cpu_matmul.yaml", + "pipeline": f"{script_path}/cpu_matmul.yaml" + if arch == "x86_64" + else str(kb_default_pipeline), }, { "kernel": "level1/2_Standard_matrix_multiplication_.py", From c96d2cb9a8ea89067451f24c786c9d7c6c2f76a3 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Mon, 11 May 2026 15:11:57 +0100 Subject: [PATCH 3/8] allow for imprecision between unoptimized/optimized pipelines --- examples/end-to-end/KernelBench/test_kernel_bench.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/end-to-end/KernelBench/test_kernel_bench.py b/examples/end-to-end/KernelBench/test_kernel_bench.py index 9e5708e2..1598634d 100755 --- a/examples/end-to-end/KernelBench/test_kernel_bench.py +++ b/examples/end-to-end/KernelBench/test_kernel_bench.py @@ -86,8 +86,8 @@ # CHECK-NOT: Execution failed # CHECK: 2_Standard_matrix_multiplication_.mlir -# CHECK: 3.1209347 3.7697 4.536519 4.397648 4.4506536 3.2665434 3.5362918 -# CHECK: 5.036752 5.3128085 5.8109508 4.810084 4.7435184 4.35573 5.311559 +# CHECK: 3.12093{{.*}} 3.7697 4.53651{{.*}} 4.397648 4.4506536 3.26654{{.*}} 3.53629{{.*}} +# CHECK: 5.036752 5.31280{{.*}} 5.8109508 4.810084 4.7435184 4.35573 5.311559 # CHECK-NOT: Execution failed From e78644752646f9de8b8eb92a549be66a4b8c88a0 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Tue, 12 May 2026 11:13:46 +0100 Subject: [PATCH 4/8] Refactor arguments back to original, no more dictionary nonsense --- examples/cpu/x86/matmul.py | 59 ++++++++++--------- .../end-to-end/KernelBench/cpu_matmul.yaml | 6 +- .../KernelBench/test_kernel_bench.py | 8 +-- lighthouse/pipeline/stage.py | 3 +- lighthouse/schedule/hoisting.py | 10 ++-- lighthouse/schedule/linalg.py | 2 +- lighthouse/schedule/packing.py | 14 ++--- lighthouse/schedule/tiling.py | 16 ++--- lighthouse/schedule/vectorization.py | 10 ++-- lighthouse/schedule/x86/cache_tiling.py | 8 +-- lighthouse/schedule/x86/pack_lowering.py | 39 +++++------- lighthouse/schedule/x86/register_tiling.py | 57 +++++++++--------- .../schedule/x86/tile_and_vector_matmul.py | 1 - test/opt/transforms/pipeline-check.py | 4 +- 14 files changed, 112 insertions(+), 125 deletions(-) diff --git a/examples/cpu/x86/matmul.py b/examples/cpu/x86/matmul.py index 36dc3171..5d8b4b51 100644 --- a/examples/cpu/x86/matmul.py +++ b/examples/cpu/x86/matmul.py @@ -146,15 +146,13 @@ def get_pipeline( # Create cache-friendly access pattern across matmul tiles. scheds.add_transform( lh_schedule.block_pack_matmuls( - { - "block_factors": [self.tile_size, self.tile_size, self.tile_size], - "rhs_transpose_outer_block": True, - "rhs_transpose_inner_block": False, - } + block_factors=[self.tile_size, self.tile_size, self.tile_size], + rhs_transpose_outer_block=True, + rhs_transpose_inner_block=False, ) ) scheds.add_transform( - lh_schedule_x86.lower_packs_unpacks({"tile_size": self.tile_size}) + lh_schedule_x86.lower_packs_unpacks(tile_size=self.tile_size) ) # Convert to category ops for easier op matching. @@ -165,11 +163,10 @@ def get_pipeline( ) ) - # GEMM cache tiling. - # Create memory friendly access pattern. + # GEMM cache tiling, create memory friendly access pattern. scheds.add_transform( lh_schedule_x86.matmul_cache_tiling( - {"tile_size": self.tile_size, "fuse_producers": True} + target="linalg.contract", tile_size=self.tile_size, fuse_producers=True ) ) @@ -180,37 +177,36 @@ def get_pipeline( # GEMM register tiling, ensure that computation can fit into vector registers. scheds.add_transform( lh_schedule_x86.matmul_register_tiling( - { - "tile_size": self.tile_size, - "reg_tile_m": 8, - "reg_tile_n": 32, - "reg_tile_k": 2, - "batch": True, - } + target="linalg.contract", + tile_size=self.tile_size, + reg_tile_m=8, + reg_tile_n=32, + reg_tile_k=2, + batch=True, ) ) # GEMM register unroll, ensure that shapes are compatible with target hardware instructions. scheds.add_transform( lh_schedule_x86.matmul_register_unroll( - { - "reg_tile_m": 8, - "reg_tile_n": 32, - "reg_tile_k": 2, - "reg_unroll_m": 1, - "reg_unroll_n": 16, - "reg_unroll_k": 2 if self.dtype == ml_dtypes.bfloat16 else 1, - "batch": True, - } + target="linalg.contract", + tile_size=self.tile_size, + reg_tile_m=8, + reg_tile_n=32, + reg_tile_k=2, + reg_unroll_m=1, + reg_unroll_n=16, + reg_unroll_k=2 if self.dtype == ml_dtypes.bfloat16 else 1, + batch=True, ) ) # Further tiling into hardware-friendly sizes for vectorization. scheds.add_transform( - lh_schedule.tile_ops({"target_op": "linalg.fill", "tile_sizes": [1, 1, 1]}) + lh_schedule.tile_ops(target_op="linalg.fill", tile_sizes=[1, 1, 1]) ) scheds.add_transform( - lh_schedule.tile_ops({"target_op": "linalg.generic", "tile_sizes": [1, 8]}) + lh_schedule.tile_ops(target_op="linalg.generic", tile_sizes=[1, 8]) ) if stop_at_stage == "tiled": @@ -308,6 +304,11 @@ def parse_cli(): action="store_true", help="Dump transform schedule.", ) + parser.add_argument( + "--print-mlir-after-all", + action="store_true", + help="Dump MLIR after all transformations.", + ) args = parser.parse_args() return args @@ -331,7 +332,9 @@ def parse_cli(): wload = Matmul(*args.sizes, dtype=in_dtype, tile_size=args.tile_size) pipeline = wload.get_pipeline(stop_at_stage=args.dump_kernel) - payload = pipeline.apply(wload.payload_module(), print_after_all=True) + payload = pipeline.apply( + wload.payload_module(), print_after_all=args.print_mlir_after_all + ) if args.dump_kernel or args.dump_schedule: if args.dump_kernel: diff --git a/examples/end-to-end/KernelBench/cpu_matmul.yaml b/examples/end-to-end/KernelBench/cpu_matmul.yaml index a3ca913a..ee40ffef 100644 --- a/examples/end-to-end/KernelBench/cpu_matmul.yaml +++ b/examples/end-to-end/KernelBench/cpu_matmul.yaml @@ -5,12 +5,12 @@ Pipeline: - schedule: "packing.py[gen=block_pack_matmuls]{block_factors=32,32,32 rhs_transpose_outer_block=True rhs_transpose_inner_block=False}" - schedule: "x86/pack_lowering.py[gen=lower_packs_unpacks]{tile_size=32}" - pass: "linalg-morph-ops{named-to-category generic-to-category}" - - schedule: "x86/cache_tiling.py[gen=matmul_cache_tiling]{fuse_producers}" + - schedule: "x86/cache_tiling.py[gen=matmul_cache_tiling]{target=linalg.contract fuse_producers}" - schedule: "linalg.py[gen=linalg_contract_fold_unit_dims]" ## CPU specific register tiling (depends on uArch & data type) - - schedule: "x86/register_tiling.py[gen=matmul_register_tiling]" - - schedule: "x86/register_tiling.py[gen=matmul_register_unroll]" + - schedule: "x86/register_tiling.py[gen=matmul_register_tiling]{target=linalg.contract}" + - schedule: "x86/register_tiling.py[gen=matmul_register_unroll]{target=linalg.contract}" - schedule: "tiling.py[gen=tile_ops]{target_op=linalg.fill tile_sizes=[1,1,1]}" - schedule: "tiling.py[gen=tile_ops]{target_op=linalg.generic tile_sizes=[1,8]}" diff --git a/examples/end-to-end/KernelBench/test_kernel_bench.py b/examples/end-to-end/KernelBench/test_kernel_bench.py index 1598634d..7daa98e5 100755 --- a/examples/end-to-end/KernelBench/test_kernel_bench.py +++ b/examples/end-to-end/KernelBench/test_kernel_bench.py @@ -74,8 +74,8 @@ assert result.returncode == 0, "Execution failed" # CHECK: 1_Square_matrix_multiplication_.mlir -# CHECK 0.37454012 0.9507143 0.7319939 ... 0.04645041 0.60754484 0.17052412 -# CHECK: 0.27214515 0.59023064 0.3609739 ... 0.297349 0.9243962 0.97105825 +# CHECK 0.3745{{.*}} 0.9507{{.*}} 0.7319{{.*}} ... 0.0464{{.*}} 0.6075{{.*}} 0.1705{{.*}} +# CHECK: 0.2721{{.*}} 0.5902{{.*}} 0.3609{{.*}} ... 0.2973{{.*}} 0.9243{{.*}} 0.9710{{.*}} # CHECK-NOT: Execution failed @@ -86,8 +86,8 @@ # CHECK-NOT: Execution failed # CHECK: 2_Standard_matrix_multiplication_.mlir -# CHECK: 3.12093{{.*}} 3.7697 4.53651{{.*}} 4.397648 4.4506536 3.26654{{.*}} 3.53629{{.*}} -# CHECK: 5.036752 5.31280{{.*}} 5.8109508 4.810084 4.7435184 4.35573 5.311559 +# CHECK: 3.1209{{.*}} 3.7697{{.*}} 4.5365{{.*}} 4.3976{{.*}} 4.4506{{.*}} 3.2665{{.*}} 3.5362{{.*}} +# CHECK: 5.0367{{.*}} 5.3128{{.*}} 5.8109{{.*}} 4.8100{{.*}} 4.7435{{.*}} 4.3557{{.*}} 5.3115{{.*}} # CHECK-NOT: Execution failed diff --git a/lighthouse/pipeline/stage.py b/lighthouse/pipeline/stage.py index 558f9905..00a5c660 100644 --- a/lighthouse/pipeline/stage.py +++ b/lighthouse/pipeline/stage.py @@ -180,11 +180,12 @@ def __init__(self, transform: Transform | ir.Module, context: ir.Context): raise ValueError( f"Transform module '{transform.filename}' does not define a '{transform.generator}' generator function." ) + # Get the generator function in the transform module. self.generator = getattr(transform_module, transform.generator) # Run the function with the dictionary as the options that will create the named sequence. with context, ir.Location.unknown(): - self.module = self.generator(transform.options) + self.module = self.generator(**transform.options) else: raise ValueError(f"Unsupported transform type: {transform.type}") diff --git a/lighthouse/schedule/hoisting.py b/lighthouse/schedule/hoisting.py index 33bc774f..43e820d7 100644 --- a/lighthouse/schedule/hoisting.py +++ b/lighthouse/schedule/hoisting.py @@ -6,7 +6,11 @@ import lighthouse.transform as lh_transform -def hoist_loops(options: dict = {}) -> ir.Module: +def hoist_loops( + target_op: str + | list[str] + | MatchInterfaceEnum = MatchInterfaceEnum.LoopLikeInterface, +) -> ir.Module: """ Apply loop hoisting to all matching ops. @@ -15,10 +19,6 @@ def hoist_loops(options: dict = {}) -> ir.Module: Returns: Schedule """ - target_op: str | list[str] | MatchInterfaceEnum = options.get( - "target_op", MatchInterfaceEnum.LoopLikeInterface - ) - with schedule_boilerplate() as (schedule, named_seq): ops = lh_transform.match_op(named_seq.bodyTarget, target_op) lh_transform.loop_hoisting(ops) diff --git a/lighthouse/schedule/linalg.py b/lighthouse/schedule/linalg.py index 4d7790ec..bebe3aa3 100644 --- a/lighthouse/schedule/linalg.py +++ b/lighthouse/schedule/linalg.py @@ -6,7 +6,7 @@ import lighthouse.transform as lh_transform -def linalg_contract_fold_unit_dims(options: dict = {}) -> ir.Module: +def linalg_contract_fold_unit_dims() -> ir.Module: """ Fold unit dims of linalg contract. diff --git a/lighthouse/schedule/packing.py b/lighthouse/schedule/packing.py index 7a109e80..ac8c5158 100644 --- a/lighthouse/schedule/packing.py +++ b/lighthouse/schedule/packing.py @@ -5,7 +5,13 @@ import lighthouse.transform as lh_transform -def block_pack_matmuls(options: dict) -> ir.Module: +def block_pack_matmuls( + block_factors, + lhs_transpose_outer_block=False, + lhs_transpose_inner_block=False, + rhs_transpose_outer_block=True, + rhs_transpose_inner_block=True, +) -> ir.Module: """ Block pack all matmuls. @@ -28,12 +34,6 @@ def block_pack_matmuls(options: dict) -> ir.Module: Returns: Schedule """ - block_factors = options.get("block_factors") - lhs_transpose_outer_block = options.get("lhs_transpose_outer_block", False) - lhs_transpose_inner_block = options.get("lhs_transpose_inner_block", False) - rhs_transpose_outer_block = options.get("rhs_transpose_outer_block", True) - rhs_transpose_inner_block = options.get("rhs_transpose_inner_block", True) - if len(block_factors) != 3: raise ValueError(f"Expected 3 block factors but got {len(block_factors)}") diff --git a/lighthouse/schedule/tiling.py b/lighthouse/schedule/tiling.py index 49ab3b18..f28e0475 100644 --- a/lighthouse/schedule/tiling.py +++ b/lighthouse/schedule/tiling.py @@ -6,7 +6,14 @@ import lighthouse.transform as lh_transform -def tile_ops(options: dict) -> ir.Module: +def tile_ops( + target_op: str | list[str] | MatchInterfaceEnum, + tile_sizes: list[int], + fuse_producers: bool = False, + tile_interchange: list[int] | None = None, + peel_loops: list[int] = [], + unroll_factors: list[int] = [], +) -> ir.Module: """ Tile all matching op. @@ -32,13 +39,6 @@ def tile_ops(options: dict) -> ir.Module: Returns: Schedule """ - target_op: str | list[str] | MatchInterfaceEnum = options["target_op"] - tile_sizes: list[int] = options["tile_sizes"] - fuse_producers: bool = options.get("fuse_producers", False) - tile_interchange: list[int] | None = options.get("tile_interchange", None) - peel_loops: list[int] = options.get("peel_loops", []) - unroll_factors: list[int] = options.get("unroll_factors", []) - with schedule_boilerplate() as (schedule, named_seq): ops = lh_transform.match_op(named_seq.bodyTarget, target_op) with lh_transform.foreach(ops) as op: diff --git a/lighthouse/schedule/vectorization.py b/lighthouse/schedule/vectorization.py index f10db178..125366a2 100644 --- a/lighthouse/schedule/vectorization.py +++ b/lighthouse/schedule/vectorization.py @@ -6,7 +6,7 @@ import lighthouse.transform as lh_transform -def vectorize_linalg(options: dict = {}) -> ir.Module: +def vectorize_linalg() -> ir.Module: """ Vectorize all linalg ops. @@ -32,7 +32,7 @@ def vectorize_linalg(options: dict = {}) -> ir.Module: return schedule -def vectorize_all(options: dict = {}) -> ir.Module: +def vectorize_all() -> ir.Module: """ Vectorize all ops. @@ -50,7 +50,7 @@ def vectorize_all(options: dict = {}) -> ir.Module: return schedule -def x86_vectorization(options: dict = {}) -> ir.Module: +def x86_vectorization() -> ir.Module: """ Apply x86-specific vector rewrites. @@ -65,7 +65,7 @@ def x86_vectorization(options: dict = {}) -> ir.Module: return schedule -def fold_into_vector_transfer(options: dict = {}) -> ir.Module: +def fold_into_vector_transfer() -> ir.Module: """ Fold vector.contract into vector.transfer_read and vector.transfer_write. @@ -82,7 +82,7 @@ def fold_into_vector_transfer(options: dict = {}) -> ir.Module: return schedule -def flatten_vector_ops(options: dict = {}) -> ir.Module: +def flatten_vector_ops() -> ir.Module: """ Flatten vector ops to 1D. diff --git a/lighthouse/schedule/x86/cache_tiling.py b/lighthouse/schedule/x86/cache_tiling.py index fe69e6ad..ad3ddd36 100644 --- a/lighthouse/schedule/x86/cache_tiling.py +++ b/lighthouse/schedule/x86/cache_tiling.py @@ -6,7 +6,9 @@ import lighthouse.transform as lh_transform -def matmul_cache_tiling(options: dict) -> transform.TransformOpInterface: +def matmul_cache_tiling( + target: str, tile_size: int = 32, fuse_producers: bool = False +) -> transform.TransformOpInterface: """ Applies cache tiling to the target matmul operation. Creates a forall loop on successful rewrite. @@ -22,10 +24,6 @@ def matmul_cache_tiling(options: dict) -> transform.TransformOpInterface: tile_size: Target size for tile dimensions. fuse_producers: Apply extra producer ops fusion after tiling. """ - target = options.get("target", "linalg.contract") - tile_size: int = options.get("tile_size", 32) - fuse_producers: bool = options.get("fuse_producers", False) - with schedule_boilerplate() as (sched, named_seq): ops = lh_transform.match_op(named_seq.bodyTarget, target) with lh_transform.foreach(ops) as op: diff --git a/lighthouse/schedule/x86/pack_lowering.py b/lighthouse/schedule/x86/pack_lowering.py index 8c86c4a8..d2f16c39 100644 --- a/lighthouse/schedule/x86/pack_lowering.py +++ b/lighthouse/schedule/x86/pack_lowering.py @@ -9,7 +9,9 @@ from lighthouse import transform as lh_transform -def lower_packs_for_vectorization(options: dict): +def lower_packs_for_vectorization( + pack_ops, pack_tile_sizes, vector_tile_sizes=None, vector_unroll_factors=[] +): """ Lower packs into hardware-friendly operations. @@ -19,11 +21,6 @@ def lower_packs_for_vectorization(options: dict): vector_tile_sizes: Target vector shapes vector_unroll_factors: Unroll factors for each vector loop. """ - pack_ops = options["pack_ops"] - pack_tile_sizes = options["pack_tile_sizes"] - vector_tile_sizes = options.get("vector_tile_sizes", None) - vector_unroll_factors = options.get("vector_unroll_factors", []) - with lh_transform.foreach(pack_ops) as pack_op: tiled_pack = structured.TileUsingForOp( pack_op, sizes=pack_tile_sizes @@ -44,7 +41,9 @@ def lower_packs_for_vectorization(options: dict): transform.yield_() -def lower_unpacks_for_vectorization(options: dict): +def lower_unpacks_for_vectorization( + unpack_ops, unpack_tile_sizes, vector_tile_sizes=None +): """ Lower unpacks into hardware-friendly operations. @@ -53,10 +52,6 @@ def lower_unpacks_for_vectorization(options: dict): unpack_tile_sizes: Unpack sub-tiling sizes vector_tile_sizes: Target vector shapes """ - unpack_ops = options["unpack_ops"] - unpack_tile_sizes = options["unpack_tile_sizes"] - vector_tile_sizes = options.get("vector_tile_sizes", None) - with lh_transform.foreach(unpack_ops) as unpack_op: tiled_unpack = structured.TileUsingForOp( unpack_op, sizes=unpack_tile_sizes @@ -77,7 +72,7 @@ def lower_unpacks_for_vectorization(options: dict): transform.yield_() -def lower_packs_unpacks(options: dict) -> ir.Module: +def lower_packs_unpacks(tile_size: int) -> ir.Module: """ Lower pack and unpack ops into hardware-friendly shapes. @@ -86,29 +81,23 @@ def lower_packs_unpacks(options: dict) -> ir.Module: Returns: Schedule """ - tile_size = options["tile_size"] - with schedule_boilerplate() as (schedule, named_seq): pack_unpack_vector_m = max(8, tile_size) pack_unpack_vector_n = min(64, tile_size) packs = lh_transform.match_op(named_seq.bodyTarget, "linalg.pack") lower_packs_for_vectorization( - { - "pack_ops": packs, - "pack_tile_sizes": [1, 1], - "vector_tile_sizes": [1, 1, pack_unpack_vector_m, pack_unpack_vector_n], - "vector_unroll_factors": [tile_size // pack_unpack_vector_n], - } + pack_ops=packs, + pack_tile_sizes=[1, 1], + vector_tile_sizes=[1, 1, pack_unpack_vector_m, pack_unpack_vector_n], + vector_unroll_factors=[tile_size // pack_unpack_vector_n], ) lh_transform.cleanup(named_seq.bodyTarget) unpacks = lh_transform.match_op(named_seq.bodyTarget, "linalg.unpack") lower_unpacks_for_vectorization( - { - "unpack_ops": unpacks, - "unpack_tile_sizes": [tile_size, tile_size], - "vector_tile_sizes": [1], - } + unpack_ops=unpacks, + unpack_tile_sizes=[tile_size, tile_size], + vector_tile_sizes=[1], ) transposes = lh_transform.match_op(named_seq.bodyTarget, "linalg.transpose") with lh_transform.foreach(transposes) as tranpose: diff --git a/lighthouse/schedule/x86/register_tiling.py b/lighthouse/schedule/x86/register_tiling.py index 35a590f1..2310274a 100644 --- a/lighthouse/schedule/x86/register_tiling.py +++ b/lighthouse/schedule/x86/register_tiling.py @@ -6,7 +6,15 @@ # up with the appropriate tiling and unrolling factors based on the target hardware -def matmul_register_tiling(options: dict) -> transform.TransformOpInterface: +def matmul_register_tiling( + target: str, + tile_size: int = 32, + reg_tile_batch: int = 1, + reg_tile_m: int = 8, + reg_tile_n: int = 32, + reg_tile_k: int = 2, + batch: bool = False, +) -> transform.TransformOpInterface: """ Applies register tiling to the target matmul operation. @@ -21,14 +29,6 @@ def matmul_register_tiling(options: dict) -> transform.TransformOpInterface: reg_tile_k: Target size for K dimension tile. batch: True is the input has batch dimension. """ - target = options.get("target", "linalg.contract") - tile_size: int = options.get("tile_size", 32) - reg_tile_batch: int = options.get("reg_tile_batch", 1) - reg_tile_m: int = options.get("reg_tile_m", 8) - reg_tile_n: int = options.get("reg_tile_n", 32) - reg_tile_k: int = options.get("reg_tile_k", 2) - batch: bool = options.get("batch", False) - tile_sizes = [reg_tile_m, reg_tile_n, reg_tile_k] if batch: tile_sizes = [reg_tile_batch] + tile_sizes @@ -40,16 +40,24 @@ def matmul_register_tiling(options: dict) -> transform.TransformOpInterface: if tile_size % reg_tile_m != 0: reg_peel_loops.append(0) return lh_schedule.tile_ops( - { - "target_op": target, - "tile_sizes": tile_sizes, - "tile_interchange": [1, 2, 0, 3], - "peel_loops": reg_peel_loops, - } + target_op=target, + tile_sizes=tile_sizes, + tile_interchange=[1, 2, 0, 3], + peel_loops=reg_peel_loops, ) -def matmul_register_unroll(options: dict) -> transform.TransformOpInterface: +def matmul_register_unroll( + target: str, + tile_size: int = 32, + reg_tile_m: int = 8, + reg_tile_n: int = 32, + reg_tile_k: int = 2, + reg_unroll_m: int = 1, + reg_unroll_n: int = 16, + reg_unroll_k: int = 1, + batch: bool = False, +) -> transform.TransformOpInterface: """ Applies register unrolling to the target matmul operation. @@ -66,15 +74,6 @@ def matmul_register_unroll(options: dict) -> transform.TransformOpInterface: reg_unroll_n: Unroll N dimension after tiling. batch: True is the input has batch dimension. """ - target = options.get("target", "linalg.contract") - reg_tile_m: int = options.get("reg_tile_m", 8) - reg_tile_n: int = options.get("reg_tile_n", 32) - reg_tile_k: int = options.get("reg_tile_k", 2) - reg_unroll_m: int = options.get("reg_unroll_m", 1) - reg_unroll_n: int = options.get("reg_unroll_n", 16) - reg_unroll_k: int = options.get("reg_unroll_k", 1) - batch: bool = options.get("batch", False) - tile_sizes = [reg_unroll_m, reg_unroll_n, reg_unroll_k] if batch: tile_sizes = [0] + tile_sizes @@ -85,9 +84,7 @@ def matmul_register_unroll(options: dict) -> transform.TransformOpInterface: reg_tile_k // reg_unroll_k, ] return lh_schedule.tile_ops( - { - "target_op": target, - "tile_sizes": tile_sizes, - "unroll_factors": reg_unroll_factors, - } + target_op=target, + tile_sizes=tile_sizes, + unroll_factors=reg_unroll_factors, ) diff --git a/lighthouse/schedule/x86/tile_and_vector_matmul.py b/lighthouse/schedule/x86/tile_and_vector_matmul.py index 4a8910e1..fb75c609 100644 --- a/lighthouse/schedule/x86/tile_and_vector_matmul.py +++ b/lighthouse/schedule/x86/tile_and_vector_matmul.py @@ -8,7 +8,6 @@ def create_schedule( - options: dict = {}, tile_sizes: tuple[int, int] = [32, 32], register_tile: tuple[int, int, int] = [8, 32, 1], matmul_op: str = "linalg.matmul", diff --git a/test/opt/transforms/pipeline-check.py b/test/opt/transforms/pipeline-check.py index 961e120c..274c9860 100644 --- a/test/opt/transforms/pipeline-check.py +++ b/test/opt/transforms/pipeline-check.py @@ -4,7 +4,7 @@ from lighthouse.pipeline.stage import apply_bundle -def create_schedule(options: dict = {}) -> ir.Module: +def create_schedule(skip_llvm=False) -> ir.Module: """Creates a Transform Schedule for the test's optimization pipeline.""" schedule_module = ir.Module.create() @@ -32,7 +32,7 @@ def create_schedule(options: dict = {}) -> ir.Module: mod = apply_registered_pass(mod, "convert-linalg-to-loops") mod = apply_bundle(mod, "cleanup.yaml") - if not options.get("skip_llvm", False): + if not skip_llvm: mod = apply_bundle(mod, "llvm_lowering.yaml") transform.YieldOp() From 57a75fb69de20588fe635cee0804fcf6589c4976 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Tue, 12 May 2026 11:16:39 +0100 Subject: [PATCH 5/8] comments --- lighthouse/pipeline/descriptor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lighthouse/pipeline/descriptor.py b/lighthouse/pipeline/descriptor.py index f34f258b..ea1e9f46 100644 --- a/lighthouse/pipeline/descriptor.py +++ b/lighthouse/pipeline/descriptor.py @@ -134,17 +134,14 @@ def _normalize_include_path(self) -> str: @staticmethod def _string_to_type(value: str) -> str | int | float | bool | list: - # Boolean value = str(value) if value == "True": return True elif value == "False": return False - # Integer try: return int(value) except ValueError: - # Floating point try: return float(value) except ValueError: @@ -157,6 +154,7 @@ def _string_to_type(value: str) -> str | int | float | bool | list: else: # Something else entirely, return as string return value + # Recursesively parse the list elements return [ Descriptor._string_to_type(v.strip()) for v in list_str.split(",") ] From 08ee28431d4b6cd99b8d4f65a361332551c835d7 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Tue, 12 May 2026 12:42:30 +0100 Subject: [PATCH 6/8] address all comments --- examples/cpu/x86/matmul.py | 24 +++++++++++-------- .../end-to-end/KernelBench/cpu_matmul.yaml | 2 +- lighthouse/schedule/__init__.py | 4 ++-- lighthouse/schedule/packing.py | 18 +++++++------- lighthouse/schedule/vectorization.py | 7 ++++-- lighthouse/schedule/x86/cache_tiling.py | 5 +++- lighthouse/schedule/x86/register_tiling.py | 16 ++++++------- 7 files changed, 43 insertions(+), 33 deletions(-) diff --git a/examples/cpu/x86/matmul.py b/examples/cpu/x86/matmul.py index 5d8b4b51..1992d561 100644 --- a/examples/cpu/x86/matmul.py +++ b/examples/cpu/x86/matmul.py @@ -175,14 +175,19 @@ def get_pipeline( scheds.add_transform(lh_schedule.linalg_contract_fold_unit_dims()) # GEMM register tiling, ensure that computation can fit into vector registers. + reg_tile_batch = 1 + reg_tile_m = 8 + reg_tile_n = 32 + reg_tile_k = 2 + scheds.add_transform( lh_schedule_x86.matmul_register_tiling( target="linalg.contract", tile_size=self.tile_size, - reg_tile_m=8, - reg_tile_n=32, - reg_tile_k=2, - batch=True, + reg_tile_batch=reg_tile_batch, + reg_tile_m=reg_tile_m, + reg_tile_n=reg_tile_n, + reg_tile_k=reg_tile_k, ) ) @@ -190,14 +195,13 @@ def get_pipeline( scheds.add_transform( lh_schedule_x86.matmul_register_unroll( target="linalg.contract", - tile_size=self.tile_size, - reg_tile_m=8, - reg_tile_n=32, - reg_tile_k=2, + reg_tile_m=reg_tile_m, + reg_tile_n=reg_tile_n, + reg_tile_k=reg_tile_k, reg_unroll_m=1, reg_unroll_n=16, reg_unroll_k=2 if self.dtype == ml_dtypes.bfloat16 else 1, - batch=True, + batch=reg_tile_batch > 0, ) ) @@ -216,7 +220,7 @@ def get_pipeline( scheds.add_transform(lh_schedule.vectorize_linalg()) scheds.add_transform(lh_schedule.hoist_loops()) - scheds.add_transform(lh_schedule.fold_into_vector_transfer()) + scheds.add_transform(lh_schedule.simplify_vector_ops()) # Rewrite vector ops into x86-specific sequences. scheds.add_transform(lh_schedule.x86_vectorization()) diff --git a/examples/end-to-end/KernelBench/cpu_matmul.yaml b/examples/end-to-end/KernelBench/cpu_matmul.yaml index ee40ffef..766b34f9 100644 --- a/examples/end-to-end/KernelBench/cpu_matmul.yaml +++ b/examples/end-to-end/KernelBench/cpu_matmul.yaml @@ -16,7 +16,7 @@ Pipeline: ## Tensor vectorization (for the left-over element wise) - schedule: "vectorization.py[gen=vectorize_linalg]" - - schedule: "vectorization.py[gen=fold_into_vector_transfer]" + - schedule: "vectorization.py[gen=simplify_vector_ops]" - include: cleanup.yaml - schedule: "vectorization.py[gen=x86_vectorization]" diff --git a/lighthouse/schedule/__init__.py b/lighthouse/schedule/__init__.py index 5dc04217..b1204476 100644 --- a/lighthouse/schedule/__init__.py +++ b/lighthouse/schedule/__init__.py @@ -7,7 +7,7 @@ from .packing import block_pack_matmuls from .tiling import tile_ops from .vectorization import flatten_vector_ops -from .vectorization import fold_into_vector_transfer +from .vectorization import simplify_vector_ops from .vectorization import vectorize_linalg from .vectorization import vectorize_all from .vectorization import x86_vectorization @@ -21,11 +21,11 @@ "create_named_sequence", "create_schedule", "flatten_vector_ops", - "fold_into_vector_transfer", "hoist_loops", "linalg_contract_fold_unit_dims", "print_ir", "schedule_boilerplate", + "simplify_vector_ops", "tile_ops", "vectorize_all", "vectorize_linalg", diff --git a/lighthouse/schedule/packing.py b/lighthouse/schedule/packing.py index ac8c5158..932a3fd5 100644 --- a/lighthouse/schedule/packing.py +++ b/lighthouse/schedule/packing.py @@ -6,11 +6,11 @@ def block_pack_matmuls( - block_factors, - lhs_transpose_outer_block=False, - lhs_transpose_inner_block=False, - rhs_transpose_outer_block=True, - rhs_transpose_inner_block=True, + block_factors: tuple[int, int, int], + lhs_transpose_outer_block: bool = False, + lhs_transpose_inner_block: bool = False, + rhs_transpose_outer_block: bool = True, + rhs_transpose_inner_block: bool = True, ) -> ir.Module: """ Block pack all matmuls. @@ -27,10 +27,10 @@ def block_pack_matmuls( Options: block_factors: Block sizes (mb, nb, kb) - lhs_transpose_outer_block: True if A matrix MB x KB => KB x MB - lhs_transpose_inner_block: True if A matrix mb x kb => kb x mb - rhs_transpose_outer_block: True if B matrix KB x NB => NB x KB - rhs_transpose_inner_block: True if B matrix kb x nb => nb x kb + lhs_transpose_outer_block: A matrix MB x KB => KB x MB + lhs_transpose_inner_block: A matrix mb x kb => kb x mb + rhs_transpose_outer_block: B matrix KB x NB => NB x KB + rhs_transpose_inner_block: B matrix kb x nb => nb x kb Returns: Schedule """ diff --git a/lighthouse/schedule/vectorization.py b/lighthouse/schedule/vectorization.py index 125366a2..23988d7f 100644 --- a/lighthouse/schedule/vectorization.py +++ b/lighthouse/schedule/vectorization.py @@ -65,9 +65,9 @@ def x86_vectorization() -> ir.Module: return schedule -def fold_into_vector_transfer() -> ir.Module: +def simplify_vector_ops() -> ir.Module: """ - Fold vector.contract into vector.transfer_read and vector.transfer_write. + Apply simplification patterns to vector operations. Returns: Schedule @@ -76,6 +76,9 @@ def fold_into_vector_transfer() -> ir.Module: with ir.InsertionPoint( transform.ApplyPatternsOp(named_seq.bodyTarget).patterns ): + # FIXME: This transform breaks AVX512 FMA recognition, + # but it's in the other sub-schedule of the same name. + # vector.apply_patterns_vector_cast_away_vector_leading_one_dim() tensor.apply_patterns_tensor_fold_tensor_subset_ops_into_vector_transfers() transform.apply_patterns_canonicalization() transform.yield_() diff --git a/lighthouse/schedule/x86/cache_tiling.py b/lighthouse/schedule/x86/cache_tiling.py index ad3ddd36..53e27aff 100644 --- a/lighthouse/schedule/x86/cache_tiling.py +++ b/lighthouse/schedule/x86/cache_tiling.py @@ -1,3 +1,4 @@ +from mlir import ir from mlir.dialects import transform from mlir.dialects.transform import structured @@ -8,7 +9,7 @@ def matmul_cache_tiling( target: str, tile_size: int = 32, fuse_producers: bool = False -) -> transform.TransformOpInterface: +) -> ir.Module: """ Applies cache tiling to the target matmul operation. Creates a forall loop on successful rewrite. @@ -23,6 +24,8 @@ def matmul_cache_tiling( target: Handle to target operation. tile_size: Target size for tile dimensions. fuse_producers: Apply extra producer ops fusion after tiling. + Returns: + Schedule module """ with schedule_boilerplate() as (sched, named_seq): ops = lh_transform.match_op(named_seq.bodyTarget, target) diff --git a/lighthouse/schedule/x86/register_tiling.py b/lighthouse/schedule/x86/register_tiling.py index 2310274a..793e4d1f 100644 --- a/lighthouse/schedule/x86/register_tiling.py +++ b/lighthouse/schedule/x86/register_tiling.py @@ -1,4 +1,4 @@ -from mlir.dialects import transform +from mlir import ir import lighthouse.schedule as lh_schedule @@ -9,12 +9,11 @@ def matmul_register_tiling( target: str, tile_size: int = 32, - reg_tile_batch: int = 1, + reg_tile_batch: int = 0, reg_tile_m: int = 8, reg_tile_n: int = 32, reg_tile_k: int = 2, - batch: bool = False, -) -> transform.TransformOpInterface: +) -> ir.Module: """ Applies register tiling to the target matmul operation. @@ -30,8 +29,10 @@ def matmul_register_tiling( batch: True is the input has batch dimension. """ tile_sizes = [reg_tile_m, reg_tile_n, reg_tile_k] - if batch: + tile_interchange = [] + if reg_tile_batch: tile_sizes = [reg_tile_batch] + tile_sizes + tile_interchange = [1, 2, 0, 3] reg_peel_loops = [] assert tile_size % reg_tile_k == 0, "Invalid K dim register tiling" @@ -42,14 +43,13 @@ def matmul_register_tiling( return lh_schedule.tile_ops( target_op=target, tile_sizes=tile_sizes, - tile_interchange=[1, 2, 0, 3], + tile_interchange=tile_interchange, peel_loops=reg_peel_loops, ) def matmul_register_unroll( target: str, - tile_size: int = 32, reg_tile_m: int = 8, reg_tile_n: int = 32, reg_tile_k: int = 2, @@ -57,7 +57,7 @@ def matmul_register_unroll( reg_unroll_n: int = 16, reg_unroll_k: int = 1, batch: bool = False, -) -> transform.TransformOpInterface: +) -> ir.Module: """ Applies register unrolling to the target matmul operation. From eadcc2b5c3621d6575a8b91a1740b03de171abac Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Tue, 12 May 2026 14:42:21 +0100 Subject: [PATCH 7/8] unnecessary parameter in matmyl.py --- examples/cpu/x86/matmul.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/cpu/x86/matmul.py b/examples/cpu/x86/matmul.py index 1992d561..7248ceee 100644 --- a/examples/cpu/x86/matmul.py +++ b/examples/cpu/x86/matmul.py @@ -130,7 +130,6 @@ def payload(A, B, C): def get_pipeline( self, stop_at_stage: Optional[str] = None, - parameters: Optional[dict] = None, ) -> PipelineDriver: scheds = PipelineDriver(self.context) From 24e64e854be144d9ef89af5e4a6aa95cf38435f2 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Tue, 12 May 2026 14:51:12 +0100 Subject: [PATCH 8/8] docstrings --- lighthouse/schedule/packing.py | 2 +- lighthouse/schedule/x86/register_tiling.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/lighthouse/schedule/packing.py b/lighthouse/schedule/packing.py index 932a3fd5..59ba895d 100644 --- a/lighthouse/schedule/packing.py +++ b/lighthouse/schedule/packing.py @@ -25,7 +25,7 @@ def block_pack_matmuls( and the (mb, nb, kb) are the minor blocks of their respective original 2D dimensions (M, N, K). - Options: + Args: block_factors: Block sizes (mb, nb, kb) lhs_transpose_outer_block: A matrix MB x KB => KB x MB lhs_transpose_inner_block: A matrix mb x kb => kb x mb diff --git a/lighthouse/schedule/x86/register_tiling.py b/lighthouse/schedule/x86/register_tiling.py index 793e4d1f..70730e67 100644 --- a/lighthouse/schedule/x86/register_tiling.py +++ b/lighthouse/schedule/x86/register_tiling.py @@ -26,7 +26,9 @@ def matmul_register_tiling( reg_tile_m: Target size for M dimension tile. reg_tile_n: Target size for N dimension tile. reg_tile_k: Target size for K dimension tile. - batch: True is the input has batch dimension. + batch: True if the input has batch dimension. + Returns: + Schedule """ tile_sizes = [reg_tile_m, reg_tile_n, reg_tile_k] tile_interchange = [] @@ -72,7 +74,10 @@ def matmul_register_unroll( reg_tile_k: Target size for K dimension tile. reg_unroll_m: Unroll M dimension after tiling. reg_unroll_n: Unroll N dimension after tiling. - batch: True is the input has batch dimension. + reg_unroll_k: Unroll K dimension after tiling. + batch: True if the input has batch dimension. + Returns: + Schedule """ tile_sizes = [reg_unroll_m, reg_unroll_n, reg_unroll_k] if batch: