diff --git a/examples/cpu/x86/matmul.py b/examples/cpu/x86/matmul.py index 88d2e267..7248ceee 100644 --- a/examples/cpu/x86/matmul.py +++ b/examples/cpu/x86/matmul.py @@ -27,8 +27,7 @@ import ml_dtypes import numpy as np from mlir import ir -from mlir.dialects import linalg, transform -from mlir.dialects.transform import tensor +from mlir.dialects import linalg from lighthouse import dialects as lh_dialects from lighthouse.execution.runner import Runner @@ -38,8 +37,6 @@ import lighthouse.utils as lh_utils from lighthouse import schedule as lh_schedule import lighthouse.schedule.x86 as lh_schedule_x86 -from lighthouse import transform as lh_transform -import lighthouse.transform.x86 as lh_transform_x86 import lighthouse.ingress.mlir_gen.utils as lh_mlir_utils from functools import cached_property from typing import Optional @@ -133,7 +130,6 @@ def payload(A, B, C): def get_pipeline( self, stop_at_stage: Optional[str] = None, - parameters: Optional[dict] = None, ) -> PipelineDriver: scheds = PipelineDriver(self.context) @@ -154,7 +150,9 @@ def get_pipeline( rhs_transpose_inner_block=False, ) ) - scheds.add_transform(lh_schedule_x86.lower_packs_unpacks(self.tile_size)) + scheds.add_transform( + lh_schedule_x86.lower_packs_unpacks(tile_size=self.tile_size) + ) # Convert to category ops for easier op matching. scheds.add_pass( @@ -164,66 +162,55 @@ def get_pipeline( ) ) - # GEMM cache tiling. - # Create memory friendly access pattern. - gemm_op = "linalg.contract" - with lh_schedule.schedule_boilerplate() as (sched, named_seq): - ops = lh_transform.match_op(named_seq.bodyTarget, gemm_op) - with lh_transform.foreach(ops) as op: - lh_transform_x86.matmul_cache_tiling( - op, tile_size=self.tile_size, fuse_producers=True - ) - transform.yield_() - transform.yield_() - scheds.add_transform(sched) + # GEMM cache tiling, create memory friendly access pattern. 
+ scheds.add_transform( + lh_schedule_x86.matmul_cache_tiling( + target="linalg.contract", tile_size=self.tile_size, fuse_producers=True + ) + ) # Fold extra parallel outer unit dims before further tiling to help later # vectorization rewrites to recognize ops. scheds.add_transform(lh_schedule.linalg_contract_fold_unit_dims()) - # GEMM register tiling. - # Ensure that computation can fit into vector registers. + # GEMM register tiling, ensure that computation can fit into vector registers. reg_tile_batch = 1 reg_tile_m = 8 reg_tile_n = 32 reg_tile_k = 2 - reg_peel_loops = [] - assert self.tile_size % reg_tile_k == 0, "Invalid K dim register tiling" - if self.tile_size % reg_tile_n != 0: - reg_peel_loops.append(1) - if self.tile_size % reg_tile_m != 0: - reg_peel_loops.append(0) + scheds.add_transform( - lh_schedule.tile_ops( - gemm_op, - tile_sizes=[reg_tile_batch, reg_tile_m, reg_tile_n, reg_tile_k], - tile_interchange=[1, 2, 0, 3], - peel_loops=reg_peel_loops, + lh_schedule_x86.matmul_register_tiling( + target="linalg.contract", + tile_size=self.tile_size, + reg_tile_batch=reg_tile_batch, + reg_tile_m=reg_tile_m, + reg_tile_n=reg_tile_n, + reg_tile_k=reg_tile_k, ) ) - # GEMM register unroll. - # Ensure that shapes are compatible with target hardware instructions. - reg_unroll_m = 1 - reg_unroll_n = 16 - # When VNNI can be used, tuples of 32-bit elements are needed. - reg_unroll_k = 2 if self.dtype == ml_dtypes.bfloat16 else 1 - reg_unroll_factors = [ - reg_tile_m // reg_unroll_m, - reg_tile_n // reg_unroll_n, - reg_tile_k // reg_unroll_k, - ] + # GEMM register unroll, ensure that shapes are compatible with target hardware instructions. 
scheds.add_transform( - lh_schedule.tile_ops( - gemm_op, - tile_sizes=[0, reg_unroll_m, reg_unroll_n, reg_unroll_k], - unroll_factors=reg_unroll_factors, + lh_schedule_x86.matmul_register_unroll( + target="linalg.contract", + reg_tile_m=reg_tile_m, + reg_tile_n=reg_tile_n, + reg_tile_k=reg_tile_k, + reg_unroll_m=1, + reg_unroll_n=16, + reg_unroll_k=2 if self.dtype == ml_dtypes.bfloat16 else 1, + batch=reg_tile_batch > 0, ) ) # Further tiling into hardware-friendly sizes for vectorization. - scheds.add_transform(lh_schedule.tile_ops("linalg.fill", tile_sizes=[1, 1, 1])) - scheds.add_transform(lh_schedule.tile_ops("linalg.generic", tile_sizes=[1, 8])) + scheds.add_transform( + lh_schedule.tile_ops(target_op="linalg.fill", tile_sizes=[1, 1, 1]) + ) + scheds.add_transform( + lh_schedule.tile_ops(target_op="linalg.generic", tile_sizes=[1, 8]) + ) if stop_at_stage == "tiled": return scheds @@ -232,14 +219,7 @@ def get_pipeline( scheds.add_transform(lh_schedule.vectorize_linalg()) scheds.add_transform(lh_schedule.hoist_loops()) - with lh_schedule.schedule_boilerplate() as (sched, named_seq): - with ir.InsertionPoint( - transform.ApplyPatternsOp(named_seq.bodyTarget).patterns - ): - tensor.apply_patterns_tensor_fold_tensor_subset_ops_into_vector_transfers() - transform.apply_patterns_canonicalization() - transform.yield_() - scheds.add_transform(sched) + scheds.add_transform(lh_schedule.simplify_vector_ops()) # Rewrite vector ops into x86-specific sequences. scheds.add_transform(lh_schedule.x86_vectorization()) @@ -254,11 +234,7 @@ def get_pipeline( scheds.add_transform(lh_schedule.vectorize_all()) # Cleanup vector ops. 
- with lh_schedule.schedule_boilerplate() as (sched, named_seq): - lh_transform.flatten_vector_ops(named_seq.bodyTarget) - lh_transform.cleanup(named_seq.bodyTarget) - transform.yield_() - scheds.add_transform(sched) + scheds.add_transform(lh_schedule.flatten_vector_ops()) if stop_at_stage == "vectorized": return scheds @@ -331,6 +307,11 @@ def parse_cli(): action="store_true", help="Dump transform schedule.", ) + parser.add_argument( + "--print-mlir-after-all", + action="store_true", + help="Dump MLIR after all transformations.", + ) args = parser.parse_args() return args @@ -354,7 +335,9 @@ def parse_cli(): wload = Matmul(*args.sizes, dtype=in_dtype, tile_size=args.tile_size) pipeline = wload.get_pipeline(stop_at_stage=args.dump_kernel) - payload = pipeline.apply(wload.payload_module()) + payload = pipeline.apply( + wload.payload_module(), print_after_all=args.print_mlir_after_all + ) if args.dump_kernel or args.dump_schedule: if args.dump_kernel: diff --git a/examples/end-to-end/KernelBench/cpu_matmul.yaml b/examples/end-to-end/KernelBench/cpu_matmul.yaml new file mode 100644 index 00000000..766b34f9 --- /dev/null +++ b/examples/end-to-end/KernelBench/cpu_matmul.yaml @@ -0,0 +1,33 @@ +# This is the default pipeline for kernel_bench, producing serial code. +# For optimized builds, pass the --pipeline to the kernel_bench program. 
+Pipeline: + ## Packing & Cache tiling (CPU generic) + - schedule: "packing.py[gen=block_pack_matmuls]{block_factors=32,32,32 rhs_transpose_outer_block=True rhs_transpose_inner_block=False}" + - schedule: "x86/pack_lowering.py[gen=lower_packs_unpacks]{tile_size=32}" + - pass: "linalg-morph-ops{named-to-category generic-to-category}" + - schedule: "x86/cache_tiling.py[gen=matmul_cache_tiling]{target=linalg.contract fuse_producers}" + - schedule: "linalg.py[gen=linalg_contract_fold_unit_dims]" + + ## CPU specific register tiling (depends on uArch & data type) + - schedule: "x86/register_tiling.py[gen=matmul_register_tiling]{target=linalg.contract}" + - schedule: "x86/register_tiling.py[gen=matmul_register_unroll]{target=linalg.contract}" + - schedule: "tiling.py[gen=tile_ops]{target_op=linalg.fill tile_sizes=[1,1,1]}" + - schedule: "tiling.py[gen=tile_ops]{target_op=linalg.generic tile_sizes=[1,8]}" + + ## Tensor vectorization (for the left-over element wise) + - schedule: "vectorization.py[gen=vectorize_linalg]" + - schedule: "vectorization.py[gen=simplify_vector_ops]" + - include: cleanup.yaml + - schedule: "vectorization.py[gen=x86_vectorization]" + + ## Bufferization + - include: bufferization.yaml + - include: bufferization_cleanup.yaml + + ## Buffer vectorization (gets rid of all the rest) + - schedule: "vectorization.py[gen=x86_vectorization]" + - schedule: "vectorization.py[gen=vectorize_all]" + - schedule: "vectorization.py[gen=flatten_vector_ops]" + + ## Lower to LLVM + - include: llvm_lowering.yaml diff --git a/examples/end-to-end/KernelBench/test_kernel_bench.py b/examples/end-to-end/KernelBench/test_kernel_bench.py index 733752a8..7daa98e5 100755 --- a/examples/end-to-end/KernelBench/test_kernel_bench.py +++ b/examples/end-to-end/KernelBench/test_kernel_bench.py @@ -4,36 +4,47 @@ # REQUIRES: kernel_bench import subprocess +import platform from pathlib import Path +script_path = Path(__file__).parent +project_root = script_path.parent.parent.parent 
+kb_program = project_root / "tools" / "kernel_bench" +kb_default_pipeline = kb_program.parent / "kernel_bench.yaml" +kb_path = project_root / "third_party" / "KernelBench" / "KernelBench" +arch = platform.machine() tests = [ { "kernel": "level1/1_Square_matrix_multiplication_.py", "input_shapes": "32x32xf32xrnd,32x32xf32xid", "output_shape": "32x32xf32x0", + "pipeline": f"{script_path}/cpu_matmul.yaml" + if arch == "x86_64" + else str(kb_default_pipeline), }, { "kernel": "level1/1_Square_matrix_multiplication_.py", "input_shapes": "32x32xbf16xrnd,32x32xbf16xid", "output_shape": "32x32xbf16x0", + "pipeline": str(kb_default_pipeline), }, { "kernel": "level1/2_Standard_matrix_multiplication_.py", "input_shapes": "8x16xf32xrnd,16x8xf32xrnd", "output_shape": "8x8xf32x0", + "pipeline": f"{script_path}/cpu_matmul.yaml" + if arch == "x86_64" + else str(kb_default_pipeline), }, { "kernel": "level1/2_Standard_matrix_multiplication_.py", "input_shapes": "8x16xbf16xrnd,16x8xbf16xrnd", "output_shape": "8x8xbf16x0", + "pipeline": str(kb_default_pipeline), }, ] if __name__ == "__main__": - project_root = Path(__file__).parent.parent.parent.parent - kb_program = project_root / "tools" / "kernel_bench" - kb_path = project_root / "third_party" / "KernelBench" / "KernelBench" - for test in tests: kb_kernel = kb_path / test["kernel"] command_line = [ @@ -43,6 +54,8 @@ test["input_shapes"], "--output-shape", test["output_shape"], + "--pipeline", + test["pipeline"], "--print-tensor=1", "--seed=42", ] @@ -61,8 +74,8 @@ assert result.returncode == 0, "Execution failed" # CHECK: 1_Square_matrix_multiplication_.mlir -# CHECK 0.37454012 0.9507143 0.7319939 ... 0.04645041 0.60754484 0.17052412 -# CHECK: 0.27214515 0.59023064 0.3609739 ... 0.297349 0.9243962 0.97105825 +# CHECK 0.3745{{.*}} 0.9507{{.*}} 0.7319{{.*}} ... 0.0464{{.*}} 0.6075{{.*}} 0.1705{{.*}} +# CHECK: 0.2721{{.*}} 0.5902{{.*}} 0.3609{{.*}} ... 
0.2973{{.*}} 0.9243{{.*}} 0.9710{{.*}} # CHECK-NOT: Execution failed @@ -73,8 +86,8 @@ # CHECK-NOT: Execution failed # CHECK: 2_Standard_matrix_multiplication_.mlir -# CHECK: 3.120935 3.7697 4.5365195 4.397648 4.4506536 3.2665431 3.5362916 -# CHECK: 5.036752 5.312808 5.8109508 4.810084 4.7435184 4.35573 5.311559 +# CHECK: 3.1209{{.*}} 3.7697{{.*}} 4.5365{{.*}} 4.3976{{.*}} 4.4506{{.*}} 3.2665{{.*}} 3.5362{{.*}} +# CHECK: 5.0367{{.*}} 5.3128{{.*}} 5.8109{{.*}} 4.8100{{.*}} 4.7435{{.*}} 4.3557{{.*}} 5.3115{{.*}} # CHECK-NOT: Execution failed diff --git a/lighthouse/pipeline/descriptor.py b/lighthouse/pipeline/descriptor.py index 1461438e..ea1e9f46 100644 --- a/lighthouse/pipeline/descriptor.py +++ b/lighthouse/pipeline/descriptor.py @@ -133,7 +133,7 @@ def _normalize_include_path(self) -> str: ) @staticmethod - def _string_to_type(value: str) -> str | int | float | bool: + def _string_to_type(value: str) -> str | int | float | bool | list: value = str(value) if value == "True": return True @@ -145,7 +145,19 @@ def _string_to_type(value: str) -> str | int | float | bool: try: return float(value) except ValueError: - return value + # List of values, e.g. [val1,val2,...] + if value.startswith("[") and value.endswith("]"): + list_str = value[1:-1] + # List of values, e.g. val1,val2,... 
+ elif value.find(",") != -1: + list_str = value + else: + # Something else entirely, return as string + return value + # Recursively parse the list elements + return [ + Descriptor._string_to_type(v.strip()) for v in list_str.split(",") + ] @staticmethod def _parse_csv(line: str, separator: str = ",") -> dict: diff --git a/lighthouse/pipeline/driver.py b/lighthouse/pipeline/driver.py index 4fed7c3b..f8b7145a 100644 --- a/lighthouse/pipeline/driver.py +++ b/lighthouse/pipeline/driver.py @@ -75,7 +75,8 @@ def apply(self, module: ir.Module, print_after_all: bool = False) -> ir.Module: if module.context != self.context: raise ValueError("Module context does not match driver context.") for stage in self.stages: - module = stage.apply(module) + with self.context: + module = stage.apply(module) if print_after_all: print(f"After stage {stage}:\n{module}") return module diff --git a/lighthouse/pipeline/stage.py b/lighthouse/pipeline/stage.py index 558f9905..00a5c660 100644 --- a/lighthouse/pipeline/stage.py +++ b/lighthouse/pipeline/stage.py @@ -180,11 +180,12 @@ def __init__(self, transform: Transform | ir.Module, context: ir.Context): raise ValueError( f"Transform module '{transform.filename}' does not define a '{transform.generator}' generator function." ) + # Get the generator function in the transform module. self.generator = getattr(transform_module, transform.generator) # Run the function with the dictionary as the options that will create the named sequence. 
with context, ir.Location.unknown(): - self.module = self.generator(transform.options) + self.module = self.generator(**transform.options) else: raise ValueError(f"Unsupported transform type: {transform.type}") diff --git a/lighthouse/schedule/__init__.py b/lighthouse/schedule/__init__.py index e82438ee..b1204476 100644 --- a/lighthouse/schedule/__init__.py +++ b/lighthouse/schedule/__init__.py @@ -6,6 +6,8 @@ from .linalg import linalg_contract_fold_unit_dims from .packing import block_pack_matmuls from .tiling import tile_ops +from .vectorization import flatten_vector_ops +from .vectorization import simplify_vector_ops from .vectorization import vectorize_linalg from .vectorization import vectorize_all from .vectorization import x86_vectorization @@ -18,10 +20,12 @@ "convert_function_results", "create_named_sequence", "create_schedule", + "flatten_vector_ops", "hoist_loops", "linalg_contract_fold_unit_dims", "print_ir", "schedule_boilerplate", + "simplify_vector_ops", "tile_ops", "vectorize_all", "vectorize_linalg", diff --git a/lighthouse/schedule/linalg.py b/lighthouse/schedule/linalg.py index 4d7790ec..bebe3aa3 100644 --- a/lighthouse/schedule/linalg.py +++ b/lighthouse/schedule/linalg.py @@ -6,7 +6,7 @@ import lighthouse.transform as lh_transform -def linalg_contract_fold_unit_dims(options: dict = {}) -> ir.Module: +def linalg_contract_fold_unit_dims() -> ir.Module: """ Fold unit dims of linalg contract. 
diff --git a/lighthouse/schedule/packing.py b/lighthouse/schedule/packing.py index c487db97..59ba895d 100644 --- a/lighthouse/schedule/packing.py +++ b/lighthouse/schedule/packing.py @@ -1,7 +1,7 @@ from mlir import ir from mlir.dialects import transform -from .builders import schedule_boilerplate +from lighthouse.schedule.builders import schedule_boilerplate import lighthouse.transform as lh_transform diff --git a/lighthouse/schedule/tiling.py b/lighthouse/schedule/tiling.py index c16c22b2..f28e0475 100644 --- a/lighthouse/schedule/tiling.py +++ b/lighthouse/schedule/tiling.py @@ -2,7 +2,7 @@ from mlir.dialects import transform from mlir.dialects.transform.structured import MatchInterfaceEnum -from .builders import schedule_boilerplate +from lighthouse.schedule.builders import schedule_boilerplate import lighthouse.transform as lh_transform diff --git a/lighthouse/schedule/vectorization.py b/lighthouse/schedule/vectorization.py index df79f3ab..23988d7f 100644 --- a/lighthouse/schedule/vectorization.py +++ b/lighthouse/schedule/vectorization.py @@ -1,9 +1,8 @@ from mlir import ir from mlir.dialects import transform -from mlir.dialects.transform import structured -from mlir.dialects.transform import vector +from mlir.dialects.transform import structured, vector, tensor -from .builders import schedule_boilerplate +from lighthouse.schedule.builders import schedule_boilerplate import lighthouse.transform as lh_transform @@ -64,3 +63,37 @@ def x86_vectorization() -> ir.Module: transform.yield_() return schedule + + +def simplify_vector_ops() -> ir.Module: + """ + Apply simplification patterns to vector operations. + + Returns: + Schedule + """ + with schedule_boilerplate() as (schedule, named_seq): + with ir.InsertionPoint( + transform.ApplyPatternsOp(named_seq.bodyTarget).patterns + ): + # FIXME: This transform breaks AVX512 FMA recognition, + # but it's in the other sub-schedule of the same name. 
+ # vector.apply_patterns_vector_cast_away_vector_leading_one_dim() + tensor.apply_patterns_tensor_fold_tensor_subset_ops_into_vector_transfers() + transform.apply_patterns_canonicalization() + transform.yield_() + return schedule + + +def flatten_vector_ops() -> ir.Module: + """ + Flatten vector ops to 1D. + + Returns: + Schedule + """ + with schedule_boilerplate() as (schedule, named_seq): + lh_transform.flatten_vector_ops(named_seq.bodyTarget) + lh_transform.cleanup(named_seq.bodyTarget) + transform.yield_() + return schedule diff --git a/lighthouse/schedule/x86/__init__.py b/lighthouse/schedule/x86/__init__.py index f38af651..4e10d28c 100644 --- a/lighthouse/schedule/x86/__init__.py +++ b/lighthouse/schedule/x86/__init__.py @@ -1,6 +1,11 @@ +from .cache_tiling import matmul_cache_tiling from .pack_lowering import lower_packs_unpacks +from .register_tiling import matmul_register_tiling, matmul_register_unroll __all__ = [ "lower_packs_unpacks", + "matmul_cache_tiling", + "matmul_register_tiling", + "matmul_register_unroll", "tile_and_vector_matmul", ] diff --git a/lighthouse/schedule/x86/cache_tiling.py b/lighthouse/schedule/x86/cache_tiling.py new file mode 100644 index 00000000..53e27aff --- /dev/null +++ b/lighthouse/schedule/x86/cache_tiling.py @@ -0,0 +1,49 @@ +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured + +from lighthouse.dialects.transform import transform_ext +from lighthouse.schedule.builders import schedule_boilerplate +import lighthouse.transform as lh_transform + + +def matmul_cache_tiling( + target: str, tile_size: int = 32, fuse_producers: bool = False +) -> ir.Module: + """ + Applies cache tiling to the target matmul operation. + Creates a forall loop on successful rewrite. + + This tiling step improves computation's memory access pattern + and exposes parallelism. + + Optionally, fusion can be performed after tiling to minimize + data transfers. 
+ + Args: + target: Name of the target operation to match. + tile_size: Target size for tile dimensions. + fuse_producers: Apply extra producer ops fusion after tiling. + Returns: + Schedule module + """ + with schedule_boilerplate() as (sched, named_seq): + ops = lh_transform.match_op(named_seq.bodyTarget, target) + with lh_transform.foreach(ops) as op: + tiles = transform_ext.get_tiling_sizes(op, tile_dim=tile_size) + if fuse_producers: + # Tile the target and greedily fuse its producers. + structured.FuseOp( + op, + tile_sizes=tiles, + apply_cleanup=True, + use_forall=True, + ).results + else: + # Only tile the target. + structured.TileUsingForallOp(op, tile_sizes=tiles).results + transform.yield_() + transform.yield_() + + # TODO: Fuse elementwise consumers. + return sched diff --git a/lighthouse/schedule/x86/pack_lowering.py b/lighthouse/schedule/x86/pack_lowering.py index 4133593f..d2f16c39 100644 --- a/lighthouse/schedule/x86/pack_lowering.py +++ b/lighthouse/schedule/x86/pack_lowering.py @@ -10,10 +10,7 @@ def lower_packs_for_vectorization( - pack_ops, - pack_tile_sizes: list[int], - vector_tile_sizes: list[int] | None = None, - vector_unroll_factors: list[int] = [], + pack_ops, pack_tile_sizes, vector_tile_sizes=None, vector_unroll_factors=[] ): """ Lower packs into hardware-friendly operations. @@ -45,9 +42,7 @@ def lower_packs_for_vectorization( def lower_unpacks_for_vectorization( - unpack_ops, - unpack_tile_sizes: list[int], - vector_tile_sizes: list[int] | None = None, + unpack_ops, unpack_tile_sizes, vector_tile_sizes=None ): """ Lower unpacks into hardware-friendly operations. 
@@ -91,18 +86,16 @@ def lower_packs_unpacks(tile_size: int) -> ir.Module: pack_unpack_vector_n = min(64, tile_size) packs = lh_transform.match_op(named_seq.bodyTarget, "linalg.pack") lower_packs_for_vectorization( - packs, + pack_ops=packs, pack_tile_sizes=[1, 1], vector_tile_sizes=[1, 1, pack_unpack_vector_m, pack_unpack_vector_n], - vector_unroll_factors=[ - tile_size // pack_unpack_vector_n, - ], + vector_unroll_factors=[tile_size // pack_unpack_vector_n], ) lh_transform.cleanup(named_seq.bodyTarget) unpacks = lh_transform.match_op(named_seq.bodyTarget, "linalg.unpack") lower_unpacks_for_vectorization( - unpacks, + unpack_ops=unpacks, unpack_tile_sizes=[tile_size, tile_size], vector_tile_sizes=[1], ) diff --git a/lighthouse/schedule/x86/register_tiling.py b/lighthouse/schedule/x86/register_tiling.py new file mode 100644 index 00000000..70730e67 --- /dev/null +++ b/lighthouse/schedule/x86/register_tiling.py @@ -0,0 +1,95 @@ +from mlir import ir + +import lighthouse.schedule as lh_schedule + +# FIXME: These functions should receive u-Arch and data type parameters and come +# up with the appropriate tiling and unrolling factors based on the target hardware + + +def matmul_register_tiling( + target: str, + tile_size: int = 32, + reg_tile_batch: int = 0, + reg_tile_m: int = 8, + reg_tile_n: int = 32, + reg_tile_k: int = 2, +) -> ir.Module: + """ + Applies register tiling to the target matmul operation. + + This tiling step prepares the IR for the x86 vectorization passes. + + Args: + target: Target operation. + tile_size: Tile size used in the previous cache tiling step. + reg_tile_batch: Target size for batch dimension tile. + reg_tile_m: Target size for M dimension tile. + reg_tile_n: Target size for N dimension tile. + reg_tile_k: Target size for K dimension tile. + batch: True if the input has batch dimension. 
+ Returns: + Schedule + """ + tile_sizes = [reg_tile_m, reg_tile_n, reg_tile_k] + tile_interchange = [] + if reg_tile_batch: + tile_sizes = [reg_tile_batch] + tile_sizes + tile_interchange = [1, 2, 0, 3] + + reg_peel_loops = [] + assert tile_size % reg_tile_k == 0, "Invalid K dim register tiling" + if tile_size % reg_tile_n != 0: + reg_peel_loops.append(1) + if tile_size % reg_tile_m != 0: + reg_peel_loops.append(0) + return lh_schedule.tile_ops( + target_op=target, + tile_sizes=tile_sizes, + tile_interchange=tile_interchange, + peel_loops=reg_peel_loops, + ) + + +def matmul_register_unroll( + target: str, + reg_tile_m: int = 8, + reg_tile_n: int = 32, + reg_tile_k: int = 2, + reg_unroll_m: int = 1, + reg_unroll_n: int = 16, + reg_unroll_k: int = 1, + batch: bool = False, +) -> ir.Module: + """ + Applies register unrolling to the target matmul operation. + + This unrolling step prepares the IR for the x86 vectorization passes. + Ensure that shapes are compatible with target hardware instructions. + + Args: + target: Target operation. + tile_size: Tile size used in the previous cache tiling step. + reg_tile_m: Target size for M dimension tile. + reg_tile_n: Target size for N dimension tile. + reg_tile_k: Target size for K dimension tile. + reg_unroll_m: Unroll M dimension after tiling. + reg_unroll_n: Unroll N dimension after tiling. + reg_unroll_k: Unroll K dimension after tiling. + batch: True if the input has batch dimension. 
+ Returns: + Schedule + """ + tile_sizes = [reg_unroll_m, reg_unroll_n, reg_unroll_k] + if batch: + tile_sizes = [0] + tile_sizes + + reg_unroll_factors = [ + reg_tile_m // reg_unroll_m, + reg_tile_n // reg_unroll_n, + reg_tile_k // reg_unroll_k, + ] + return lh_schedule.tile_ops( + target_op=target, + tile_sizes=tile_sizes, + unroll_factors=reg_unroll_factors, + ) diff --git a/lighthouse/schedule/x86/tile_and_vector_matmul.py b/lighthouse/schedule/x86/tile_and_vector_matmul.py index 4a8910e1..fb75c609 100644 --- a/lighthouse/schedule/x86/tile_and_vector_matmul.py +++ b/lighthouse/schedule/x86/tile_and_vector_matmul.py @@ -8,7 +8,6 @@ def create_schedule( - options: dict = {}, tile_sizes: tuple[int, int] = [32, 32], register_tile: tuple[int, int, int] = [8, 32, 1], matmul_op: str = "linalg.matmul", diff --git a/lighthouse/transform/x86/__init__.py b/lighthouse/transform/x86/__init__.py deleted file mode 100644 index 115a0d47..00000000 --- a/lighthouse/transform/x86/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .cache_tiling import matmul_cache_tiling - -__all__ = [ - "matmul_cache_tiling", -] diff --git a/lighthouse/transform/x86/cache_tiling.py b/lighthouse/transform/x86/cache_tiling.py deleted file mode 100644 index 3063b3ec..00000000 --- a/lighthouse/transform/x86/cache_tiling.py +++ /dev/null @@ -1,45 +0,0 @@ -from mlir import ir -from mlir.dialects.transform import structured - -from lighthouse.dialects.transform import transform_ext - - -def matmul_cache_tiling( - target, - tile_size: int = 32, - fuse_producers: bool = False, -) -> tuple[ir.Value, ir.Value]: - """ - Applies cache tiling to the target matmul operation. - Creates a forall loop on successful rewrite. - - This tiling step improves computation's memory access pattern - and exposes parallelism. - - Optionally, fusion can be performed after tiling to minimize - data transfers. - - Args: - target: Handle to target operation. - tile_size: Target size for tile dimensions. 
- fuse_producers: Apply extra producer ops fusion after tiling. - Returns: - Handles to the tiled op and created loop - """ - tiles = transform_ext.get_tiling_sizes(target, tile_dim=tile_size) - if fuse_producers: - # Tile the target and greedily fuse its producers. - tiled_op, forall_op = structured.FuseOp( - target, - tile_sizes=tiles, - apply_cleanup=True, - use_forall=True, - ).results - else: - # Only tile the target. - tiled_op, forall_op = structured.TileUsingForallOp( - target, tile_sizes=tiles - ).results - # TODO: Fuse elementwise consumers. - - return tiled_op, forall_op diff --git a/test/opt/transforms/pipeline-check.py b/test/opt/transforms/pipeline-check.py index 961e120c..274c9860 100644 --- a/test/opt/transforms/pipeline-check.py +++ b/test/opt/transforms/pipeline-check.py @@ -4,7 +4,7 @@ from lighthouse.pipeline.stage import apply_bundle -def create_schedule(options: dict = {}) -> ir.Module: +def create_schedule(skip_llvm=False) -> ir.Module: """Creates a Transform Schedule for the test's optimization pipeline.""" schedule_module = ir.Module.create() @@ -32,7 +32,7 @@ def create_schedule(options: dict = {}) -> ir.Module: mod = apply_registered_pass(mod, "convert-linalg-to-loops") mod = apply_bundle(mod, "cleanup.yaml") - if not options.get("skip_llvm", False): + if not skip_llvm: mod = apply_bundle(mod, "llvm_lowering.yaml") transform.YieldOp()