diff --git a/.github/codestyle/copyright.hook b/.github/codestyle/copyright.hook index 484ada0..3479940 100644 --- a/.github/codestyle/copyright.hook +++ b/.github/codestyle/copyright.hook @@ -43,7 +43,7 @@ def _get_comment_mark(path): if lang_type.search(path) is not None: return "#" - lang_type=re.compile(r"\.(h|c|hpp|cc|cpp|cu|go|cuh|proto)$") + lang_type=re.compile(r"\.(h|c|hpp|hxx|cc|cpp|cxx|cu|go|cuh|proto)$") if lang_type.search(path) is not None: return "//" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c16f79..a460928 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: name: copyright_checker entry: python3 ./.github/codestyle/copyright.hook language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$ + files: \.(c|cc|cxx|cpp|cu|cuh|h|hpp|hxx|proto|py|sh)$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: diff --git a/Dockerfile b/Dockerfile index 476ad3f..df1fde2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,25 @@ FROM nvcr.io/nvidia/pytorch:25.10-py3 ARG FLASH_ATTENTION_COMMIT_ID="b613d9e2c8475945baff3fd68f2030af1b890acf" +# CUTLASS — source is always cloned (the magi_compiler EVT-fusion path +# JIT-includes its headers and our /usr/local/cutlass tree is the readable +# reference checkout). The CMake-driven profiler/library is compiled +# only for supported targets; every other arch gets headers only. +# +# Supported NVCC arch strings (CUTLASS_NVCC_ARCHS): +# 90a — Hopper (H100, compute_cap 9.x, WGMMA/TMA) +# 120a — consumer Blackwell (RTX 50 series, compute_cap 12.x) +# +# Override behaviour with build args: +# --build-arg CUTLASS_BUILD=yes|no|auto +# yes — force cmake configure (requires CUTLASS_NVCC_ARCHS or a GPU) +# no — skip cmake even if a supported GPU is present +# auto — (default) compile iff nvidia-smi reports 9.x or 12.x +# --build-arg CUTLASS_NVCC_ARCHS=90a|120a +ARG CUTLASS_COMMIT_ID="f74fea9ce35868d3ae9f8d1dce1969d7250d3f90" +ARG CUTLASS_BUILD="auto" +ARG CUTLASS_NVCC_ARCHS="" + ENV PIP_NO_CACHE_DIR=1 \ PIP_DISABLE_PIP_VERSION_CHECK=1 \ PYTHONDONTWRITEBYTECODE=1 @@ -18,6 +37,7 @@ RUN --mount=type=secret,id=http_proxy,required=false \ ca-certificates \ git \ build-essential \ + cmake \ ninja-build && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean @@ -42,6 +62,65 @@ RUN --mount=type=secret,id=http_proxy,required=false \ cp /tmp/flash-attention/hopper/flash_attn_interface.py ${python_path}/flash_attn_3/ && \ rm -rf /tmp/flash-attention + +RUN --mount=type=secret,id=http_proxy,required=false \ + --mount=type=secret,id=https_proxy,required=false \ + export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \ + export https_proxy="$(cat /run/secrets/https_proxy 2>/dev/null || true)" && \ + mkdir -p /usr/local/cutlass && \ + cd /usr/local/cutlass && \ + git init -q && \ + git remote add origin https://github.com/NVIDIA/cutlass.git && \ + git fetch origin ${CUTLASS_COMMIT_ID} --depth 1 && \ + git checkout ${CUTLASS_COMMIT_ID} && \ + (git submodule update --init --recursive --depth 1 --jobs 8 || \ + git submodule update --init --recursive --depth 1 --jobs 1) + + +RUN set -eu; \ + _cutlass_arch_from_gpu() { \ + if ! command -v nvidia-smi >/dev/null 2>&1; then return 1; fi; \ + cap="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -n1 | tr -d ' ')"; \ + case "${cap}" in \ + 9.*) echo "90a" ;; \ + 12.*) echo "120a" ;; \ + *) return 1 ;; \ + esac; \ + }; \ + if [ -n "${CUTLASS_NVCC_ARCHS}" ]; then \ + NVCC_ARCHS="${CUTLASS_NVCC_ARCHS}"; \ + echo "[CUTLASS] Using CUTLASS_NVCC_ARCHS=${NVCC_ARCHS} (build-arg override)."; \ + elif arch="$(_cutlass_arch_from_gpu)"; then \ + NVCC_ARCHS="${arch}"; \ + echo "[CUTLASS] nvidia-smi → CUTLASS_NVCC_ARCHS=${NVCC_ARCHS}."; \ + else \ + NVCC_ARCHS=""; \ + fi; \ + case "${CUTLASS_BUILD}" in \ + no) echo "[CUTLASS] CUTLASS_BUILD=no — skipping cmake configure."; exit 0 ;; \ + yes) \ + if [ -z "${NVCC_ARCHS}" ]; then \ + echo "[CUTLASS] CUTLASS_BUILD=yes but no arch: set CUTLASS_NVCC_ARCHS=90a|120a or build on a 9.x/12.x GPU."; \ + exit 1; \ + fi; \ + DO_BUILD=1 ;; \ + auto) \ + if [ -z "${NVCC_ARCHS}" ]; then \ + echo "[CUTLASS] No sm_90/sm_120 GPU and no CUTLASS_NVCC_ARCHS — skipping cmake (headers still available)."; \ + exit 0; \ + fi; \ + DO_BUILD=1 ;; \ + *) echo "[CUTLASS] Unknown CUTLASS_BUILD=${CUTLASS_BUILD}"; exit 1 ;; \ + esac; \ + case "${NVCC_ARCHS}" in \ + 90a|120a) ;; \ + *) echo "[CUTLASS] Unsupported CUTLASS_NVCC_ARCHS=${NVCC_ARCHS} (expected 90a or 120a)."; exit 1 ;; \ + esac; \ + [ -n "${DO_BUILD:-}" ] && cd /usr/local/cutlass && \ + export CUDACXX="${CUDA_INSTALL_PATH:-${CUDA_HOME:-/usr/local/cuda}}/bin/nvcc" && \ + mkdir -p build && cd build && \ + cmake .. -DCUTLASS_NVCC_ARCHS="${NVCC_ARCHS}" + RUN --mount=type=secret,id=http_proxy,required=false \ --mount=type=secret,id=https_proxy,required=false \ export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \ diff --git a/README.md b/README.md index dd191a3..69c6e05 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,18 @@ pip install -r requirements.txt # Step 4 — Install MagiCompiler (pick one) pip install . # End users (recommended) # pip install -e . --no-build-isolation --config-settings editable_mode=compat # Developer / editable + +# Step 5 (optional) — Install CUTLASS for matmul epilogue fusion +# Required for the CUTLASS-based matmul + epilogue fusion pass (sm_90 / sm_120). +# Without CUTLASS the compiler still works but skips this optimization. +git clone --depth 1 https://github.com/NVIDIA/cutlass.git /usr/local/cutlass +# Or specify a custom path: +# git clone --depth 1 https://github.com/NVIDIA/cutlass.git /your/path +# export MAGI_CUTLASS_ROOT=/your/path +export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc +mkdir /usr/local/cutlass/build && cd /usr/local/cutlass/build +cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GPU architecture +# cmake .. -DCUTLASS_NVCC_ARCHS=120a # compiles for NVIDIA consumer Blackwell (RTX 50 series) ``` --- diff --git a/magi_compiler/config.py b/magi_compiler/config.py index c5edf38..a303093 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -64,6 +64,18 @@ class PassConfig(BaseModel): # TODO: Add sequence parallelism pass and async TP pass. # TODO: Add Ulysses overlap pass. enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.") + enable_mm_epilogue_fusion: bool = Field( + False, + description=( + "Whether to enable the matmul + elementwise epilogue fusion pass. " + "On RTX 5090 (sm_120) this lowers fused chains to a CUTLASS Sm80EVT " + "kernel via the fusion.MatmulEvtEpilogueFusionPass; on H100 " + "(sm_90) the swiglu sub-path additionally uses the native Sm90 " + "TMA + WGMMA DualGemm. The pass is a no-op on older architectures " + "regardless of this flag, but the flag still controls whether it " + "is registered at all." + ), + ) @property def hash(self) -> str: @@ -141,6 +153,14 @@ class OffloadConfig(BaseModel): bandwidth_safety_factor: float = Field(0.9, description="The safety factor for the H2D bandwidth.") +def _find_cutlass_root() -> str: + """Return the CUTLASS source root, or empty string if not found.""" + path = os.environ.get("MAGI_CUTLASS_ROOT", "/usr/local/cutlass") + if os.path.isdir(path): + return path + return "" + + class CompileConfig(BaseSettings): """Top-level configuration consumed by ``magi_compile`` and the MagiCompiler backend. @@ -172,6 +192,10 @@ class CompileConfig(BaseSettings): default=os.path.expanduser("~/.cache/magi_compiler"), description="Root directory for persisting compiled artifacts and debug dumps.", ) + cutlass_root: str = Field( + default_factory=_find_cutlass_root, + description="Path to the CUTLASS source tree. Default: $MAGI_CUTLASS_ROOT or /usr/local/cutlass.", + ) # ---- Compilation mode ---- aot: bool = Field( @@ -234,6 +258,10 @@ class CompileConfig(BaseSettings): ), ) + @property + def has_cutlass(self) -> bool: + return bool(self.cutlass_root) + @property def hash(self) -> str: return compute_hash(self.model_dump(mode="json")) diff --git a/magi_compiler/passes/full_graph/full_graph_pass_mgr.py b/magi_compiler/passes/full_graph/full_graph_pass_mgr.py index 502d190..7f04183 100644 --- a/magi_compiler/passes/full_graph/full_graph_pass_mgr.py +++ b/magi_compiler/passes/full_graph/full_graph_pass_mgr.py @@ -16,6 +16,7 @@ from ...magi_depyf.timeline import observe_lifecycle from .remove_item import RemoveItemPass +from .remove_useless_ops import EliminateIdentityViewCastPass from .replace_sage_atten import ReplaceSageAttentionPass @@ -30,6 +31,7 @@ def __init__(self, pass_config): if self.pass_config.enable_sage_attn: self.passes.append(ReplaceSageAttentionPass()) self.passes.append(RemoveItemPass()) + self.passes.append(EliminateIdentityViewCastPass()) @observe_lifecycle("full_graph_manager") def __call__(self, gm: torch.fx.GraphModule): diff --git a/magi_compiler/passes/full_graph/remove_useless_ops.py b/magi_compiler/passes/full_graph/remove_useless_ops.py new file mode 100644 index 0000000..3863e7b --- /dev/null +++ b/magi_compiler/passes/full_graph/remove_useless_ops.py @@ -0,0 +1,116 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch._inductor.fx_passes.pre_grad + +from ...magi_depyf.timeline import emit_pass_lifecycle +from ..pass_base import MagiInductorPass + + +class EliminateIdentityViewCastPass(MagiInductorPass): + """ + Remove useless convert, view, reshape operations. + When their input already has the target type and shape, these operations are redundant. + """ + + TARGET_METHODS = { + "view", + "reshape", + "to", + "type", + "contiguous", + "flatten", + "permute", + "transpose", + "t", + "unsqueeze", + "squeeze", + "expand", + "repeat", + "bfloat16", + "float", + "half", + "int", + "long", + "short", + "double", + "bool", + "byte", + } + + @staticmethod + def _get_tensor_info(node: torch.fx.Node): + # Get tensor info from example_value + if "example_value" in node.meta: + val = node.meta["example_value"] + if isinstance(val, torch.Tensor): + return val.shape, val.dtype, val.stride() + elif isinstance(val, (list, tuple)) and len(val) > 0 and isinstance(val[0], torch.Tensor): + return val[0].shape, val[0].dtype, val[0].stride() + + return None, None, None + + def is_applicable(self, graph: torch.fx.Graph, shape: int | None = None) -> bool: + for node in graph.nodes: + if node.op == "call_method" and node.target in self.TARGET_METHODS: + return True + return False + + @emit_pass_lifecycle + def __call__(self, graph: torch.fx.Graph): + nodes_to_remove = [] + + for node in graph.nodes: + is_target_method = node.op == "call_method" and node.target in self.TARGET_METHODS + if not is_target_method: + continue + + # Need at least one argument (the input tensor) + if not node.args or not isinstance(node.args[0], torch.fx.Node): + continue + + input_node = node.args[0] + + node_shape, node_dtype, node_stride = self._get_tensor_info(node) + input_shape, input_dtype, input_stride = self._get_tensor_info(input_node) + if node_shape is None or input_shape is None: + continue + if node_dtype is None or input_dtype is None: + continue + # Some ops or metadata might not have stride properly captured, + # but if they do, we should require them to match to be totally safe against contiguous-forcing ops. + if node_stride is not None and input_stride is not None and node_stride != input_stride: + continue + + # Check if shape and dtype match exactly + if node_shape == input_shape and node_dtype == input_dtype: + # For _to_copy, ensure we are not changing memory format or device or other properties implicitly, + # but typically in full graph if shape and dtype match, and it's on the same device, it's safe. + # Let's also check device just in case if it's available. + def get_device(n): + if "example_value" in n.meta and isinstance(n.meta["example_value"], torch.Tensor): + return n.meta["example_value"].device + + node_device = get_device(node) + input_device = get_device(input_node) + if node_device is not None and input_device is not None and node_device != input_device: + continue + + # Replace uses + node.replace_all_uses_with(input_node) + nodes_to_remove.append(node) + + for node in nodes_to_remove: + graph.erase_node(node) diff --git a/magi_compiler/passes/piecewise_graph/fusion/__init__.py b/magi_compiler/passes/piecewise_graph/fusion/__init__.py new file mode 100644 index 0000000..3eaa44a --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/magi_compiler/passes/piecewise_graph/fusion/common/__init__.py b/magi_compiler/passes/piecewise_graph/fusion/common/__init__.py new file mode 100644 index 0000000..6fa72d4 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/common/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2026 SandAI. All Rights Reserved. diff --git a/magi_compiler/passes/piecewise_graph/fusion/common/codegen_shared.py b/magi_compiler/passes/piecewise_graph/fusion/common/codegen_shared.py new file mode 100644 index 0000000..b572fb6 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/common/codegen_shared.py @@ -0,0 +1,120 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Arch-agnostic codegen helpers shared by the SM80 and SM90 EVT codegens.""" + +from __future__ import annotations + +import textwrap + +_DTYPE_TO_CUTLASS = {"bfloat16": "cutlass::bfloat16_t", "float16": "cutlass::half_t", "float32": "float"} + +_DTYPE_TO_AT = {"bfloat16": "at::kBFloat16", "float16": "at::kHalf", "float32": "at::kFloat"} + +_DTYPE_TO_AT_CPP = {"bfloat16": "at::BFloat16", "float16": "at::Half", "float32": "float"} + + +# IR op name → CUTLASS template name (arch-agnostic, works on both Sm80EVT and Sm90EVT). +_BUILTIN_FN_TEMPLATE = { + "add": "cutlass::plus", + "sub": "cutlass::minus", + "mul": "cutlass::multiplies", + "div": "cutlass::divides", + "max": "cutlass::maximum", + "min": "cutlass::minimum", + "neg": "cutlass::negate", + "sigmoid": "cutlass::epilogue::thread::Sigmoid", + "silu": "cutlass::epilogue::thread::SiLu", + "tanh": "cutlass::epilogue::thread::Tanh", + "relu": "cutlass::epilogue::thread::ReLu", + "abs": "cutlass::absolute_value_op", +} + +# Custom functor bodies: ``T`` = element type, ``x`` = input value. +_CUSTOM_UNARY_BODY = { + "square": "return x * x;", + "exp": "return cutlass::fast_exp(x);", + "log": "return cutlass::fast_log(x);", + "sqrt": "return cutlass::fast_sqrt(x);", + "rsqrt": "return cutlass::fast_rsqrt(x);", + "erf": "return T(erff(float(x)));", + "gelu_erf": "return T(0.5f) * x * (T(1.0f) + T(erff(float(x) * 0.70710678118654752f)));", + "gelu_tanh": ( + "float v = float(x);" " return T(0.5f * v * (1.0f + tanhf(" "0.7978845608028654f * (v + 0.044715f * v * v * v))));" + ), +} + +# Scalar-baked: body uses ``x`` and ``c`` (compile-time constant). +_CUSTOM_SCALAR_BODY = { + "add_scalar": "return x + c;", + "sub_scalar": "return x - c;", + "mul_scalar": "return x * c;", + "div_scalar": "return x / c;", + "rsub_scalar": "return c - x;", + "clamp_min_c": "return x < c ? c : x;", + "clamp_max_c": "return x < c ? x : c;", + # scaled_silu_alpha(x, alpha) = x * sigmoid(alpha * x). Used by GELU7. + "scaled_silu_alpha": ( + "T t = c * x;" " T one = T(1.0f);" " T sig = one / (one + cutlass::fast_exp(-t));" " return x * sig;" + ), + # pow_scalar(x, c) – emit as repeated multiplies for small int c. + # Otherwise fall back to powf. + "pow_scalar": "return T(powf(float(x), float(c)));", +} + + +_VALID_ALIGN_BITS = (128, 64) + + +def _scalar_literal_T(value: float) -> str: + return f"T({float(value)!r}f)" + + +def _emit_custom_functor(name: str, op: str, scalar=None) -> str: + """Emit a unary CUTLASS-compatible functor with scalar + Array specialisation.""" + if op in _CUSTOM_UNARY_BODY: + body = _CUSTOM_UNARY_BODY[op] + scalar_decl = "" + elif op in _CUSTOM_SCALAR_BODY: + if scalar is None: + raise ValueError(f"Scalar op {op!r} needs a baked constant") + body = _CUSTOM_SCALAR_BODY[op] + scalar_decl = f" const T c = {_scalar_literal_T(scalar)};\n" + else: + raise ValueError(f"No custom functor body for op {op!r}") + return textwrap.dedent( + f"""\ + template + struct {name} {{ + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE + T operator()(T const& x) const {{ + {scalar_decl} {body} + }} + }}; + + template + struct {name}> {{ + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE + cutlass::Array operator()(cutlass::Array const& v) const {{ + {name} op; + cutlass::Array out; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) out[i] = op(v[i]); + return out; + }} + }}; + """ + ) diff --git a/magi_compiler/passes/piecewise_graph/fusion/common/cutlass_kernels/swiglu_combine.h b/magi_compiler/passes/piecewise_graph/fusion/common/cutlass_kernels/swiglu_combine.h new file mode 100644 index 0000000..53b2ab9 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/common/cutlass_kernels/swiglu_combine.h @@ -0,0 +1,153 @@ +// Copyright (c) 2026 SandAI. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary epilogue combine functor for the swiglu DualGemm fusion. +// +// D = silu_alpha( clamp(lhs, max=limit) ) * ( clamp(rhs, -limit, limit) + 1 ) +// +// silu_alpha(x) = x * sigmoid(alpha * x) default: alpha = 1.702, limit = 7.0 +// +// `lhs` is the gate-path output fragment (Op0 applied to A @ W_gate.T), +// `rhs` is the linear-path output fragment (Op1 applied to A @ W_linear.T). +// Both arrive as ElementOutput (bf16) fragments — this is dictated by the +// dual-epilogue call site (examples/45_dual_gemm/threadblock/dual_epilogue.h:413 +// passes `output_frag_ptr[0][i]` and `[1][i]`, which are post-conversion +// output-type fragments, not raw accumulator fragments). The combine upcasts +// to ElementCompute (fp32) internally, evaluates the swiglu expression, and +// converts back to bf16. +// +// Note on precision: the gate/linear matmuls accumulate in fp32 inside the +// MMAs. Op0/Op1 (LinearCombination, ScaleType::Nothing) downcast those fp32 +// accumulators to bf16 before this combine runs. The swiglu math itself +// stays in fp32 here, so the only extra precision loss vs the two-stage EVT +// version is the single fp32→bf16 round-trip on each accumulator at the +// epilogue boundary. Empirically this is well within the bf16 noise floor. +// +// Modelled on cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h — +// same interface contract: ElementOutput / ElementAccumulator / ElementCompute +// typedefs, kCount fragment width, empty Params, two operator() overloads +// (fragment + scalar), is_source_needed() returning true. + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +namespace cutlass { +namespace epilogue { +namespace thread { + +template < + typename ElementOutput_, + int Count, + typename ElementAccumulator_ = ElementOutput_, + typename ElementCompute_ = ElementOutput_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class SwigluCombine { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params { + ElementCompute alpha; + ElementCompute limit; + ElementCompute one; + + CUTLASS_HOST_DEVICE + Params() : alpha(ElementCompute(1.702f)), + limit(ElementCompute(7.0f)), + one(ElementCompute(1.0f)) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha_, ElementCompute limit_, ElementCompute one_) + : alpha(alpha_), limit(limit_), one(one_) {} + }; + +public: + + CUTLASS_HOST_DEVICE + SwigluCombine(Params const& p) : alpha_(p.alpha), limit_(p.limit), one_(p.one) {} + + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + CUTLASS_HOST_DEVICE + void set_k_partition(int /*k_partition*/, int /*k_partition_count*/) { + // swiglu cannot be split-K-reduced (non-linear epilogue). + assert(false); + } + + // Fragment-level. lhs = gate output fragment (bf16, post Op0), + // rhs = linear output fragment (bf16, post Op1). + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentOutput const& lhs, + FragmentOutput const& rhs) const { + NumericArrayConverter in2c; + NumericArrayConverter c2o; + + ComputeFragment gate = in2c(lhs); + ComputeFragment lin = in2c(rhs); + ComputeFragment out; + + Sigmoid sig; + ElementCompute const nlimit = -limit_; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + ElementCompute g = gate[i] < limit_ ? gate[i] : limit_; + ElementCompute r = lin[i] < nlimit ? nlimit + : (lin[i] > limit_ ? limit_ : lin[i]); + ElementCompute silu_g = g * sig(alpha_ * g); + out[i] = silu_g * (r + one_); + } + return c2o(out); + } + + // Scalar overload — required by the DualGemm epilogue boilerplate. + CUTLASS_HOST_DEVICE + ElementOutput operator()(ElementOutput const& lhs, + ElementOutput const& rhs) const { + ElementCompute g(lhs), r(rhs); + ElementCompute const nlimit = -limit_; + + Sigmoid sig; + + g = g < limit_ ? g : limit_; + r = r < nlimit ? nlimit : (r > limit_ ? limit_ : r); + ElementCompute silu_g = g * sig(alpha_ * g); + return ElementOutput(silu_g * (r + one_)); + } + +private: + ElementCompute alpha_; + ElementCompute limit_; + ElementCompute one_; +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass diff --git a/magi_compiler/passes/piecewise_graph/fusion/evt_ir.py b/magi_compiler/passes/piecewise_graph/fusion/evt_ir.py new file mode 100644 index 0000000..11c1935 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/evt_ir.py @@ -0,0 +1,210 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""EVT (Epilogue Visitor Tree) intermediate representation. + +Dataclass IR built by the FX pass from ``aten.mm`` consumers, consumed by +``evt_codegen.py`` to render a CUTLASS .cu. Canonicalised to deterministic +JSON for the JIT module cache key. Adding a new op requires updating both +this file and the codegen. +""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass +from typing import List, Optional, Union + +UNARY_OPS = frozenset( + {"neg", "sigmoid", "silu", "gelu_erf", "gelu_tanh", "tanh", "relu", "square", "erf", "exp", "log", "sqrt", "rsqrt", "abs"} +) + +BINARY_OPS = frozenset({"add", "sub", "mul", "div", "max", "min"}) + +SCALAR_UNARY_OPS = frozenset( + { + "add_scalar", # x + c + "sub_scalar", # x - c + "mul_scalar", # x * c + "div_scalar", # x / c + "rsub_scalar", # c - x + "clamp_min_c", # max(x, c) + "clamp_max_c", # min(x, c) + "scaled_silu_alpha", # x * sigmoid(alpha * x), used by gelu7 + "pow_scalar", # x ** c (only sensible for small integer c) + } +) + +ALL_OPS = UNARY_OPS | BINARY_OPS | SCALAR_UNARY_OPS + +# Strings (not torch.dtype) so the IR is JSON-serialisable. +DTYPES = frozenset({"bfloat16", "float16", "float32"}) + +# Hardware-native ALU compute types supported by the EVT epilogue. +# +# Floating-point: FP32, FP16, BF16 are full-speed on both H100 (sm_90) and +# RTX 5090 (sm_120). FP64 is full-speed on H100 but extremely slow on 5090, +# so we exclude it from the EVT path. +# +# Integer: INT64, INT32, INT16, INT8 are ALU-supported on both architectures, +# but CUTLASS VisitorCompute / Sm90Compute templates are only instantiated +# for floating-point types, so integer compute_dtype is not valid here. +COMPUTE_DTYPES = frozenset({"bfloat16", "float16", "float32"}) + + +@dataclass(frozen=True) +class Accum: + """The fp32 GEMM accumulator. Always the unique starting leaf of the IR.""" + + kind: str = "accum" + + +@dataclass(frozen=True) +class RowBroadcast: + """1-D (N,) tensor broadcast along M. ``input_idx`` indexes the runtime extras list.""" + + input_idx: int + dtype: str + kind: str = "row_bcast" + + +@dataclass(frozen=True) +class ColBroadcast: + """1-D (M,) tensor broadcast along N.""" + + input_idx: int + dtype: str + kind: str = "col_bcast" + + +@dataclass(frozen=True) +class AuxLoad: + """2-D (M, N) row-major aux tensor. stride[1] must be 1, stride[0] 16-byte aligned.""" + + input_idx: int + dtype: str + kind: str = "aux_load" + + +@dataclass(frozen=True) +class Compute: + """An interior elementwise op over EVT subtrees. + + ``compute_dtype`` controls the precision of this node's VisitorCompute / + Sm90Compute template instantiation. Defaults to ``"float32"`` (the GEMM + accumulator's native precision). A preceding ``to(bf16)`` in the FX + chain sets it to ``"bfloat16"`` so the kernel runs that op in bf16. + """ + + op: str + children: tuple + scalar: Optional[float] = None + compute_dtype: str = "float32" + kind: str = "compute" + + def __post_init__(self): + if self.op not in ALL_OPS: + raise ValueError(f"Unknown EVT op: {self.op!r}") + if self.compute_dtype not in COMPUTE_DTYPES: + raise ValueError(f"Unsupported compute_dtype {self.compute_dtype!r} for EVT. " f"Valid: {sorted(COMPUTE_DTYPES)}") + if self.op in UNARY_OPS: + if len(self.children) != 1 or self.scalar is not None: + raise ValueError(f"UNARY op {self.op!r} requires 1 child, no scalar") + elif self.op in BINARY_OPS: + if len(self.children) != 2 or self.scalar is not None: + raise ValueError(f"BINARY op {self.op!r} requires 2 children, no scalar") + elif self.op in SCALAR_UNARY_OPS: + if len(self.children) != 1 or self.scalar is None: + raise ValueError(f"SCALAR_UNARY op {self.op!r} requires 1 child + scalar") + + +@dataclass(frozen=True) +class Store: + """Root of the IR. Casts the fp32 result to ``out_dtype`` and writes D.""" + + child: object # any IR node + out_dtype: str + kind: str = "store" + + def __post_init__(self): + if self.out_dtype not in DTYPES: + raise ValueError(f"Unknown out_dtype {self.out_dtype!r}") + + +IRNode = Union[Accum, RowBroadcast, ColBroadcast, AuxLoad, Compute, Store] + + +def to_dict(node) -> dict: + """Recursively convert an IR tree into a JSON-friendly dict for stable hashing.""" + if isinstance(node, Accum): + return {"kind": "accum"} + if isinstance(node, RowBroadcast): + return {"kind": "row_bcast", "input_idx": node.input_idx, "dtype": node.dtype} + if isinstance(node, ColBroadcast): + return {"kind": "col_bcast", "input_idx": node.input_idx, "dtype": node.dtype} + if isinstance(node, AuxLoad): + return {"kind": "aux_load", "input_idx": node.input_idx, "dtype": node.dtype} + if isinstance(node, Compute): + d = {"kind": "compute", "op": node.op, "children": [to_dict(c) for c in node.children]} + if node.scalar is not None: + d["scalar"] = repr(float(node.scalar)) + if node.compute_dtype != "float32": + d["compute_dtype"] = node.compute_dtype + return d + if isinstance(node, Store): + return {"kind": "store", "out_dtype": node.out_dtype, "child": to_dict(node.child)} + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +def to_canonical_json(node) -> str: + """Deterministic JSON string for an IR tree. Same IR ⇒ same string.""" + return json.dumps(to_dict(node), sort_keys=True, separators=(",", ":")) + + +def cache_key(node, a_dtype: str, b_dtype: str) -> str: + """SHA-256 hash of (IR JSON, A dtype, B dtype). Used as the JIT module key.""" + payload = {"ir": to_dict(node), "a": a_dtype, "b": b_dtype, "version": 1} + blob = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(blob).hexdigest() + + +def walk_leaves(node) -> List: + """Return all leaf nodes in left-to-right pre-order.""" + out: list = [] + + def _go(n): + if isinstance(n, (Accum, RowBroadcast, ColBroadcast, AuxLoad)): + out.append(n) + elif isinstance(n, Compute): + for c in n.children: + _go(c) + elif isinstance(n, Store): + _go(n.child) + else: + raise TypeError(f"Unknown IR node type: {type(n).__name__}") + + _go(node) + return out + + +def is_trivial(node) -> bool: + """Store(Accum) — no compute; FX pass should refuse to emit these.""" + return isinstance(node, Store) and isinstance(node.child, Accum) + + +def num_extras(node) -> int: + """Maximum input_idx + 1 across all non-Accum leaves, or 0 if none.""" + indices: list = [leaf.input_idx for leaf in walk_leaves(node) if not isinstance(leaf, Accum)] + return max(indices) + 1 if indices else 0 diff --git a/magi_compiler/passes/piecewise_graph/fusion/evt_runtime.py b/magi_compiler/passes/piecewise_graph/fusion/evt_runtime.py new file mode 100644 index 0000000..7ecf6b1 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/evt_runtime.py @@ -0,0 +1,681 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runtime side of the EVT fusion: torch.library op + JIT loader + dispatch. + +This file owns: + * The ``magi_epilogue::matmul_fused_epilogue`` torch.library op + fake impl. + * A process-level cache mapping IR JSON → compiled cpp_extension module. + * Dispatch to one of two backends: + - ``kind == "evt"`` → JIT-compiled CUTLASS Sm80EVT kernel. + - ``kind == "swiglu_dual"`` → vendored DualGemm one-stage kernel. + Routes to the SM80 cp.async multistage path on sm_120 (RTX 5090) and + to the SM90 TMA + WGMMA path on sm_90 (H100). Both expose the same + ``swiglu_dual_matmul_out(A, B, D)`` PYBIND callable, so the + dispatcher is arch-agnostic. + +The kernel build directory uses the IR cache key + arch tag as its name so +re-runs and multi-process Inductor compile workers all hit the same on-disk +cache, and so a binary built for one arch never gets reused on another. +""" + +from __future__ import annotations + +import hashlib +import json +import os +import shutil +import threading +from typing import Optional + +import torch + +from magi_compiler.config import get_compile_config + +from .evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store +from .sm80.evt_codegen import render_evt_cu as _render_evt_cu_sm80 +from .sm90.evt_codegen import render_evt_cu as _render_evt_cu_sm90 + +# ── torch.library op definition ─────────────────────────────────────────────── +# Reuse the existing ``magi_epilogue`` library so all our custom matmul ops +# live under one namespace. Defining a fresh op here is harmless even if +# ``matmul_epilogue_fusion.py`` has already initialised the library. +_LIB = torch.library.Library("magi_epilogue", "FRAGMENT") +_LIB.define( + "matmul_fused_epilogue(Tensor A, Tensor B, Tensor[] extras, str ir_json," + " str kind, int n_out, int out_dtype_id) -> Tensor" +) + + +# ── Output-dtype encoding (must round-trip through torch.library int args) ──── +_OUT_DTYPE_ID = {torch.bfloat16: 0, torch.float16: 1, torch.float32: 2} +_ID_TO_DTYPE = {v: k for k, v in _OUT_DTYPE_ID.items()} +_DTYPE_TO_STR = {torch.bfloat16: "bfloat16", torch.float16: "float16", torch.float32: "float32"} + + +def out_dtype_id(dt: torch.dtype) -> int: + """Encode a torch.dtype as a small int for inclusion in op args.""" + if dt not in _OUT_DTYPE_ID: + raise ValueError(f"Unsupported EVT output dtype {dt}") + return _OUT_DTYPE_ID[dt] + + +def out_dtype_from_id(i: int) -> torch.dtype: + return _ID_TO_DTYPE[i] + + +# Greedy alignment: 128 bits when divisible, 64-bit fallback. +GREEDY_ALIGN_BITS = (128, 64) + + +def _runtime_align_bits(dim: int, dtype: torch.dtype) -> int: + n_int = int(dim) + for bits in GREEDY_ALIGN_BITS: + align_elems = max(1, bits // (dtype.itemsize * 8)) + if n_int % align_elems == 0: + return bits + raise ValueError(f"dim={n_int} not even {GREEDY_ALIGN_BITS[-1]}-bit-aligned for dtype={dtype}") + + +def _aligned_n_stride(n_out: int, dtype: torch.dtype) -> int: + """Round n_out up to a 16-byte element count. + + 16 bytes is the minimum stride alignment required by both SM80 + (``AlignmentC = 128 / sizeof_bits`` = 8 bf16 elements) + and SM90 TMA (``cudaTensorMapEncodeTiled`` requires globalStrides + to be multiples of 16 bytes). + + Bytes-based formula keeps this dtype-agnostic: + bf16 / fp16 → 8 element pad boundary + fp32 → 4 element pad boundary + fp8 → 16 element pad boundary + """ + align_bytes = 16 + align = max(1, align_bytes // dtype.itemsize) + n = int(n_out) + return ((n + align - 1) // align) * align + + +# ── Compile cache + per-key build lock ──────────────────────────────────────── +_MODULE_CACHE: dict = {} +# Fast cache keyed by hashable tuple — skips json.dumps + sha256 on hot path. +_MODULE_FAST_CACHE: dict = {} +_MODULE_LOCKS: dict = {} +_MODULE_LOCKS_GLOBAL = threading.Lock() +_SWIGLU_LOCK = threading.Lock() + + +def _device_gencode_flags() -> list[str]: + """Return nvcc -gencode flags for the live device. + + sm_90 needs the ``a`` variant for WGMMA/TMA support. + Override with MAGI_EVT_GENCODE (semicolon-separated). + """ + override = os.environ.get("MAGI_EVT_GENCODE") + if override: + return [a for a in override.split(";") if a] + cap = torch.cuda.get_device_capability() + arch = f"{cap[0]}{cap[1]}" # "90" for H100, "120" for RTX 5090, "80" for A100 + # Use the wgmma-enabled "a" variant on Hopper; all other arches stay plain. + arch_for_code = f"{arch}a" if arch == "90" else arch + return [ + f"-gencode=arch=compute_{arch_for_code},code=sm_{arch_for_code}", + # Embed PTX of the same arch so a slightly newer driver / different + # minor revision JITs cleanly without rebuilding. + f"-gencode=arch=compute_{arch_for_code},code=compute_{arch_for_code}", + ] + + +def _device_arch_tag() -> str: + """Short tag for the live device (e.g. ``sm90``), folded into build_dir.""" + cap = torch.cuda.get_device_capability() + return f"sm{cap[0]}{cap[1]}" + + +def _evt_build_dir(key: str) -> str: + cache_root = get_compile_config().cache_root_dir + return os.path.join(cache_root, "evt_kernels", _device_arch_tag(), key) + + +def _per_key_lock(key: str) -> threading.Lock: + """Return the per-key build lock; coalesces concurrent compile requests.""" + with _MODULE_LOCKS_GLOBAL: + lock = _MODULE_LOCKS.get(key) + if lock is None: + lock = threading.Lock() + _MODULE_LOCKS[key] = lock + return lock + + +# Two-pronged hardening on top of ``cpp_extension.load``: +# +# (1) Warm-cache fast path. If the .so for this build_dir is already on +# disk, dlopen it directly — skip cpp_extension.load (and therefore +# FileBaton) entirely. After the first successful build, no run ever +# touches the lock file again, so multi-rank warm starts cannot hang. +# +# (2) Interruption cleanup. We only care about the on-disk lock during the +# call to cpp_extension.load. ``_track_build`` registers the build_dir +# before the call, ``_untrack_build`` un-registers it right after. +# atexit + SIGTERM/SIGINT/SIGHUP handlers fire only if we are still +# inside that window — they wipe the entire build_dir, eliminating the +# lock and any half-written ninja/nvcc artifacts so the next run +# starts from a clean slate. +# +# SIGKILL/OOM/power-loss leak the build_dir: signal handlers physically +# cannot run for those. Recovery there is "rm -rf the build_dir" by hand. +# Deliberately does NOT use fcntl.flock — multi-rank workloads on certain +# filesystems reject blocking flock with EAGAIN. + + +# Build_dirs whose cpp_extension.load is currently in flight. Touched only +# by _track_build / _untrack_build and the atexit / signal callbacks. +_PENDING_BUILD_DIRS: "set[str]" = set() +_PENDING_LOCK = threading.Lock() +_SIGNAL_HANDLERS_INSTALLED = False + + +def _cleanup_pending_build_dirs() -> None: + """Wipe every build_dir registered by an in-flight cpp_extension.load. + + Called from ``atexit`` and from SIGTERM/SIGINT/SIGHUP handlers. Removes + the whole directory — lock, ninja files, half-baked .cuda.o, partial + .so — so the next run rebuilds from scratch instead of inheriting + inconsistent state. Idempotent; never raises. + """ + with _PENDING_LOCK: + dirs = list(_PENDING_BUILD_DIRS) + _PENDING_BUILD_DIRS.clear() + for d in dirs: + shutil.rmtree(d, ignore_errors=True) + + +def _install_exit_cleanup_once() -> None: + """Install ``atexit`` and forwarding signal handlers exactly once per + process. Signal handlers chain to whatever was previously registered + so we don't interfere with torchrun / app-level signal handling.""" + global _SIGNAL_HANDLERS_INSTALLED + if _SIGNAL_HANDLERS_INSTALLED: + return + _SIGNAL_HANDLERS_INSTALLED = True + + import atexit + import signal + + atexit.register(_cleanup_pending_build_dirs) + + def _make_handler(signum: int): + prev = signal.getsignal(signum) + + def _handler(sn, frame, _prev=prev, _sig=signum): + try: + _cleanup_pending_build_dirs() + finally: + # Chain to whatever was installed before us; otherwise fall + # back to the signal's default action (terminate). + if callable(_prev) and _prev not in (signal.SIG_DFL, signal.SIG_IGN): + _prev(sn, frame) + elif _prev == signal.SIG_IGN: + return + else: + signal.signal(_sig, signal.SIG_DFL) + os.kill(os.getpid(), _sig) + + return _handler + + for sig_name in ("SIGTERM", "SIGINT", "SIGHUP"): + sig = getattr(signal, sig_name, None) + if sig is None: + continue + try: + signal.signal(sig, _make_handler(sig)) + except (ValueError, OSError): + # ValueError: not in main thread; OSError: invalid in this env. + pass + + +def _track_build(build_dir: str) -> None: + """Register ``build_dir`` for cleanup-on-exit. Pair with ``_untrack_build`` + on the success path so completed builds aren't wiped.""" + _install_exit_cleanup_once() + with _PENDING_LOCK: + _PENDING_BUILD_DIRS.add(build_dir) + + +def _untrack_build(build_dir: str) -> None: + """Unregister a build_dir after cpp_extension.load returns. The module is + already dlopen'd at this point so even if a signal beats us to the + discard, the in-memory module keeps working.""" + with _PENDING_LOCK: + _PENDING_BUILD_DIRS.discard(build_dir) + + +def _try_dlopen_prebuilt(build_dir: str, mod_name: str): + """Fast path: if the .so for this build_dir already exists, import it + directly without going through cpp_extension.load (which would try to + acquire FileBaton). Returns None on any miss / failure so the caller + falls back to the full compile path.""" + so_path = os.path.join(build_dir, f"{mod_name}.so") + if not os.path.isfile(so_path): + return None + try: + import importlib.util + + spec = importlib.util.spec_from_file_location(mod_name, so_path) + if spec is None or spec.loader is None: + return None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + except Exception: + return None + + +def _compile_evt_module( + ir_json: str, + a_dtype: torch.dtype, + b_dtype: torch.dtype, + b_layout: str = "row", + m_bucket: str = "medium", + N: int = 0, + K: int = 0, + alignment_a_bits: int = 128, + alignment_b_bits: int = 128, + alignment_c_bits: int = 128, +): + """Render + JIT-compile the EVT kernel for ``ir_json``. Process-level cached. + + Each distinct (N, K) gets its own module so autotune state is isolated. + """ + arch = _device_arch_tag() + fast_key = ( + ir_json, + a_dtype, + b_dtype, + b_layout, + m_bucket, + N, + K, + alignment_a_bits, + alignment_b_bits, + alignment_c_bits, + arch, + ) + cached = _MODULE_FAST_CACHE.get(fast_key) + if cached is not None: + return cached + + if b_layout not in ("row", "col"): + raise ValueError(f"b_layout must be 'row' or 'col', got {b_layout!r}") + a_str = _DTYPE_TO_STR[a_dtype] + b_str = _DTYPE_TO_STR[b_dtype] + extended = json.dumps( + { + "ir": ir_json, + "a": a_str, + "b": b_str, + "b_layout": b_layout, + "m_bucket": m_bucket, + "N": int(N), + "K": int(K), + "alignA_bits": int(alignment_a_bits), + "alignB_bits": int(alignment_b_bits), + "alignC_bits": int(alignment_c_bits), + "arch": arch, + "version": 10, + }, + sort_keys=True, + ).encode("utf-8") + key = hashlib.sha256(extended).hexdigest() + + cached = _MODULE_CACHE.get(key) + if cached is not None: + _MODULE_FAST_CACHE[fast_key] = cached + return cached + + lock = _per_key_lock(key) + with lock: + cached = _MODULE_CACHE.get(key) + if cached is not None: + _MODULE_FAST_CACHE[fast_key] = cached + return cached + + # sm_90 → Sm90EVT (TMA+WGMMA); else → Sm80EVT (cp.async). + ir = _ir_from_json(ir_json) + render_fn = _render_evt_cu_sm90 if arch == "sm90" else _render_evt_cu_sm80 + src = render_fn( + ir, + a_str, + b_str, + cache_key_str=key, + b_layout=b_layout, + m_bucket=m_bucket, + alignment_a_bits=alignment_a_bits, + alignment_b_bits=alignment_b_bits, + alignment_c_bits=alignment_c_bits, + arch=arch, + ) + + build_dir = _evt_build_dir(key) + os.makedirs(build_dir, exist_ok=True) + mod_name = f"magi_evt_{key[:12]}" + + # Warm-cache fast path: if a previous run already produced the .so + # for this exact key, dlopen it directly and skip cpp_extension.load + # (and its FileBaton) entirely. Makes repeated runs / multi-rank + # warm starts immune to any lock-file hang. + prebuilt = _try_dlopen_prebuilt(build_dir, mod_name) + if prebuilt is not None: + _MODULE_CACHE[key] = prebuilt + _MODULE_FAST_CACHE[fast_key] = prebuilt + return prebuilt + + src_path = os.path.join(build_dir, "evt.cu") + # Atomic write: tmp + rename to avoid half-written files across ranks. + tmp_path = f"{src_path}.{os.getpid()}.tmp" + with open(tmp_path, "w") as f: + f.write(src) + os.replace(tmp_path, src_path) + + cutlass_root = get_compile_config().cutlass_root + from torch.utils.cpp_extension import load + + # SM90 needs extra cflags for warp-specialized collectives + extended MMA. + sm90_specific_cflags = ( + ["--expt-extended-lambda", "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED=1"] if arch == "sm90" else [] + ) + + # -fvisibility=hidden gives each .so its own copy of CUTLASS template + # static members like GemmUniversalBase::device_ordinal_. + # Without this, two .so files that instantiate the same GemmKernel + # type (e.g. medium and large m-bucket modules sharing the same EVT + # chain + tile shape) collide on the static symbol — the first .so + # to call init_device_props() poisons the cache for all later .so + # files: their kernels never get cudaFuncSetAttribute called, so any + # launch above the default 48 KB dynamic SMEM fails with cudaError- + # InvalidValue ("invalid argument"). + _track_build(build_dir) + try: + module = load( + name=mod_name, + sources=[src_path], + extra_include_paths=[ + os.path.join(cutlass_root, "include"), + os.path.join(cutlass_root, "tools", "util", "include"), + ], + extra_cflags=["-O3", "-std=c++17", "-fvisibility=hidden", "-fvisibility-inlines-hidden"], + extra_cuda_cflags=( + [ + "-std=c++17", + "-O3", + "--expt-relaxed-constexpr", + "-Xcompiler=-fvisibility=hidden", + "-Xcompiler=-fvisibility-inlines-hidden", + ] + + sm90_specific_cflags + + _device_gencode_flags() + ), + build_directory=build_dir, + verbose=False, + ) + finally: + _untrack_build(build_dir) + _MODULE_CACHE[key] = module + _MODULE_FAST_CACHE[fast_key] = module + return module + + +# ── IR (de)serialisation ───────────────────────────────────────────────────── + + +def to_ir_json(node) -> str: + from .evt_ir import to_canonical_json + + return to_canonical_json(node) + + +def _ir_from_json(s: str): + """Inverse of ``to_canonical_json``. Used only at codegen time.""" + d = json.loads(s) + return _node_from_dict(d) + + +def _node_from_dict(d): + kind = d["kind"] + if kind == "accum": + return Accum() + if kind == "row_bcast": + return RowBroadcast(input_idx=d["input_idx"], dtype=d["dtype"]) + if kind == "col_bcast": + return ColBroadcast(input_idx=d["input_idx"], dtype=d["dtype"]) + if kind == "aux_load": + return AuxLoad(input_idx=d["input_idx"], dtype=d["dtype"]) + if kind == "compute": + scalar = d.get("scalar") + scalar_val: Optional[float] = float(scalar) if scalar is not None else None + compute_dtype = d.get("compute_dtype", "float32") + return Compute( + op=d["op"], + children=tuple(_node_from_dict(c) for c in d["children"]), + scalar=scalar_val, + compute_dtype=compute_dtype, + ) + if kind == "store": + return Store(child=_node_from_dict(d["child"]), out_dtype=d["out_dtype"]) + raise ValueError(f"Unknown IR kind {kind!r}") + + +# Per-(m_bucket, N, K, align) cache — separate modules so each runner has its +# own autotune state (best_idx_). +_SWIGLU_FAST_CACHE: dict = {} +_SWIGLU_BUILD_LOCKS: dict = {} + + +def _compile_swiglu_dual( + m_bucket: str, N: int, K: int, alignment_a_bits: int = 128, alignment_b_bits: int = 128, alignment_c_bits: int = 128 +): + """Lazy-load a per-(bucket, N, K, align) DualGemm kernel module.""" + fast_key = (m_bucket, int(N), int(K), int(alignment_a_bits), int(alignment_b_bits), int(alignment_c_bits)) + cached = _SWIGLU_FAST_CACHE.get(fast_key) + if cached is not None: + return cached + + with _SWIGLU_LOCK: + lock = _SWIGLU_BUILD_LOCKS.get(fast_key) + if lock is None: + lock = threading.Lock() + _SWIGLU_BUILD_LOCKS[fast_key] = lock + with lock: + cached = _SWIGLU_FAST_CACHE.get(fast_key) + if cached is not None: + return cached + + cutlass_root = get_compile_config().cutlass_root + here = os.path.dirname(os.path.abspath(__file__)) + # sm_90 → TMA+WGMMA DualGemm; else → SM80 multistage path. + arch_tag = _device_arch_tag() + arch_subdir = "sm90" if arch_tag == "sm90" else "sm80" + src = os.path.join(here, arch_subdir, "cutlass_kernels", "swiglu_one_stage.cu") + if not os.path.exists(src): + raise FileNotFoundError(f"vendored swiglu source not found: {src}") + cache_root = get_compile_config().cache_root_dir + # Build dir embeds (arch, bucket, N, K, align) — stale cross-arch + # binaries cause cudaErrorInvalidDeviceFunction. + build_tag = f"{m_bucket}_N{N}_K{K}" f"_aA{alignment_a_bits}_aB{alignment_b_bits}_aC{alignment_c_bits}" + build_dir = os.path.join(cache_root, "evt_kernels", arch_tag, f"swiglu_dual_{build_tag}") + os.makedirs(build_dir, exist_ok=True) + mod_name = f"magi_swiglu_dual_{build_tag}" + + # Warm-cache fast path — see _compile_evt_module for rationale. + prebuilt = _try_dlopen_prebuilt(build_dir, mod_name) + if prebuilt is not None: + _SWIGLU_FAST_CACHE[fast_key] = prebuilt + return prebuilt + + from torch.utils.cpp_extension import load + + # SM90 needs extra cflags for WGMMA + warp-specialized collective. + sm90_specific_cflags = ( + ["--expt-extended-lambda", "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED=1"] if arch_tag == "sm90" else [] + ) + + sm90_include_paths = [os.path.join(here, "sm90", "cutlass_kernels")] if arch_tag == "sm90" else [] + + # -fvisibility=hidden — see _compile_evt_module above for rationale. + _track_build(build_dir) + try: + module = load( + name=mod_name, + sources=[src], + extra_include_paths=[ + os.path.join(cutlass_root, "include"), + os.path.join(cutlass_root, "tools", "util", "include"), + os.path.join(cutlass_root, "examples"), + os.path.join(here, "common", "cutlass_kernels"), + *sm90_include_paths, + ], + extra_cflags=["-O3", "-std=c++17", "-fvisibility=hidden", "-fvisibility-inlines-hidden"], + extra_cuda_cflags=[ + "-std=c++17", + "-O3", + "--expt-relaxed-constexpr", + "-Xcompiler=-fvisibility=hidden", + "-Xcompiler=-fvisibility-inlines-hidden", + *sm90_specific_cflags, + *_device_gencode_flags(), + f"-DMAGI_SWIGLU_ALIGN_A_BITS={int(alignment_a_bits)}", + f"-DMAGI_SWIGLU_ALIGN_B_BITS={int(alignment_b_bits)}", + f"-DMAGI_SWIGLU_ALIGN_C_BITS={int(alignment_c_bits)}", + ], + build_directory=build_dir, + verbose=False, + ) + finally: + _untrack_build(build_dir) + _SWIGLU_FAST_CACHE[fast_key] = module + return module + + +# ── Dispatch fast-cache ────────────────────────────────────────────────────── +# Collapses out_dtype_from_id → _m_bucket → _compile_* → mod.attr-lookup +# into a single dict.get(). Keyed by (kind, ir_json, dtypes, N, K, m_bucket, +# out_dtype); reaches steady state after the first call per (site, bucket). +class _DispatchEntry: + __slots__ = ("kernel_call", "is_evt", "out_dtype") + + def __init__(self, kernel_call, is_evt, out_dtype): + self.kernel_call = kernel_call + self.is_evt = is_evt + self.out_dtype = out_dtype + + +_DISPATCH_CACHE: dict = {} + + +def _resolve_dispatch(kind, ir_json, a_dtype, b_dtype, N_w, K_w, m_bucket, out_dtype): + """Slow-path resolver — compiles the .cu module and binds the kernel callable.""" + n_out_for_c = (N_w // 2) if kind == "swiglu_dual" else N_w + ldd = _aligned_n_stride(n_out_for_c, out_dtype) + alignment_c_bits = _runtime_align_bits(ldd, out_dtype) + + if kind == "swiglu_dual": + # K alignment also covers ldB=2K. + align_bits = _runtime_align_bits(K_w, a_dtype) + mod = _compile_swiglu_dual( + m_bucket, N_w, K_w, alignment_a_bits=align_bits, alignment_b_bits=align_bits, alignment_c_bits=alignment_c_bits + ) + sw7 = json.loads(ir_json) if ir_json else {} + sw7_alpha = float(sw7.get("alpha", 1.702)) + sw7_limit = float(sw7.get("limit", 7.0)) + sw7_one = float(sw7.get("one", 1.0)) + kernel_fn = mod.swiglu_dual_matmul_out + + def _sw7_call(A, B, D, _fn=kernel_fn, _a=sw7_alpha, _l=sw7_limit, _o=sw7_one): + return _fn(A, B, D, _a, _l, _o) + + return _DispatchEntry(_sw7_call, False, out_dtype) + if kind == "evt_row" or kind == "evt": + b_layout = "row" + elif kind == "evt_col": + b_layout = "col" + else: + raise ValueError(f"Unknown EVT kind {kind!r}") + alignment_a_bits = _runtime_align_bits(K_w, a_dtype) + b_lead_dim = N_w if b_layout == "row" else K_w + alignment_b_bits = _runtime_align_bits(b_lead_dim, b_dtype) + mod = _compile_evt_module( + ir_json, + a_dtype, + b_dtype, + b_layout=b_layout, + m_bucket=m_bucket, + N=N_w, + K=K_w, + alignment_a_bits=alignment_a_bits, + alignment_b_bits=alignment_b_bits, + alignment_c_bits=alignment_c_bits, + ) + return _DispatchEntry(mod.evt_matmul_out, True, out_dtype) + + +@torch.library.impl(_LIB, "matmul_fused_epilogue", "CUDA") +def _matmul_fused_epilogue_cuda(A, B, extras, ir_json, kind, n_out, out_dtype_id_): + """Runtime entry point. Do NOT call .contiguous() on B — the FX pass + controls the layout (evt_row=RowMajor, evt_col/swiglu=ColumnMajor).""" + # B.size(0)/size(1) avoids the Python tuple construction of .shape. + B_size0 = B.size(0) + B_size1 = B.size(1) + M = A.size(0) + if M <= 256: + m_bucket = "small" + elif M <= 2048: + m_bucket = "medium" + else: + m_bucket = "large" + out_dtype = _ID_TO_DTYPE[out_dtype_id_] + a_dtype = A.dtype + b_dtype_ = B.dtype + fast_key = (kind, ir_json, a_dtype, b_dtype_, B_size0, B_size1, m_bucket, out_dtype) + entry = _DISPATCH_CACHE.get(fast_key) + if entry is None: + # Map B sizes to (N_w, K_w) in the layout the compile path expects. + if kind == "evt_row": + K_w, N_w = B_size0, B_size1 + else: + # evt_col / swiglu_dual: B is (N, K) underlying weight. + N_w, K_w = B_size0, B_size1 + entry = _resolve_dispatch(kind, ir_json, a_dtype, b_dtype_, N_w, K_w, m_bucket, out_dtype) + _DISPATCH_CACHE[fast_key] = entry + + n_pad = _aligned_n_stride(n_out, out_dtype) + D_pad = torch.empty((M, n_pad), device=A.device, dtype=out_dtype) + D = D_pad[:, :n_out] if n_pad != n_out else D_pad + + kernel_call = entry.kernel_call + if entry.is_evt: + kernel_call(A, B, extras, D) + else: + # swiglu_dual: extras is always [] here (FX pass guarantees). + kernel_call(A, B, D) + return D + + +@torch.library.register_fake("magi_epilogue::matmul_fused_epilogue") +def _matmul_fused_epilogue_fake(A, B, extras, ir_json, kind, n_out, out_dtype_id_): + out_dtype = out_dtype_from_id(out_dtype_id_) + # Strided (M, n_out) view of an (M, n_pad) buffer — must match the + # stride layout the CUDA impl actually returns, otherwise Inductor's + # downstream view metadata desyncs from the real tensor. + n_pad = _aligned_n_stride(n_out, out_dtype) + return A.new_empty_strided((A.shape[0], n_out), (n_pad, 1), dtype=out_dtype) diff --git a/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py b/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py new file mode 100644 index 0000000..ee4d2aa --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py @@ -0,0 +1,911 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FX pass that fuses aten.mm + elementwise epilogue into a CUTLASS EVT call. + +Two backends: + * Generic EVT — for the 6 non-swiglu activations and 1-D bias/scale variants. + Builds an IR tree (see ``evt_ir.py``), serialises to JSON, replaces the + matched chain with a single ``torch.ops.magi_epilogue.matmul_fused_epilogue`` + call. The runtime renders + JIT-compiles a CUTLASS Sm80EVT kernel keyed by + the IR hash (see ``evt_runtime.py``). + * swiglu — pattern-matches the canonical recipe (slice-stride-2 + dual + clamps + scaled SiLU) and dispatches to a vendored DualGemm one-stage + kernel that writes (M, N/2) directly. + +Eligibility gates (alignment, B layout, dtype) are checked up-front. Anything +not eligible stays as ``aten.mm`` for cuBLAS to handle. We do NOT fall back to +the Triton fusion path on sm120; per user decision, EVT replaces it entirely. +""" + +from __future__ import annotations + +import json +import operator +from typing import List, Optional, Tuple + +import torch +import torch.fx as fx + +from magi_compiler.passes.pass_base import MagiInductorPass +from magi_compiler.utils.device import device_capability_major + +from . import evt_runtime # ensures torch.library op + fake impl are registered +from .evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store, is_trivial, num_extras, to_canonical_json +from .evt_runtime import GREEDY_ALIGN_BITS + +# ── Op tables ──────────────────────────────────────────────────────────────── +# Pure passthrough — no value or dtype change; alias the same IR node. +_PASSTHROUGH_OPS = frozenset({torch.ops.aten.clone.default, torch.ops.aten.contiguous.default, torch.ops.aten.alias.default}) + +# Dtype-conversion ops update current_compute_dtype so downstream Compute nodes +# use the target precision (e.g. to(bf16) → subsequent ops run in bf16). +_TYPE_CONV_OPS = frozenset({torch.ops.prims.convert_element_type.default, torch.ops.aten._to_copy.default}) + +# Unary ops with a direct EVT IR equivalent. +_UNARY_OPS = { + torch.ops.aten.neg.default: "neg", + torch.ops.aten.sigmoid.default: "sigmoid", + torch.ops.aten.tanh.default: "tanh", + torch.ops.aten.silu.default: "silu", + torch.ops.aten.relu.default: "relu", + torch.ops.aten.square.default: "square", + torch.ops.aten.erf.default: "erf", + torch.ops.aten.exp.default: "exp", + torch.ops.aten.log.default: "log", + torch.ops.aten.sqrt.default: "sqrt", + torch.ops.aten.rsqrt.default: "rsqrt", + torch.ops.aten.abs.default: "abs", +} + +_BINARY_OPS = { + torch.ops.aten.add.Tensor: "add", + torch.ops.aten.sub.Tensor: "sub", + torch.ops.aten.mul.Tensor: "mul", + torch.ops.aten.div.Tensor: "div", + torch.ops.aten.maximum.default: "max", + torch.ops.aten.minimum.default: "min", + operator.add: "add", + operator.sub: "sub", + operator.mul: "mul", + operator.truediv: "div", +} + +# Scalar binary ops → SCALAR_UNARY_OPS in IR. +_SCALAR_BINARY_TO_SCALAR_UNARY = { + torch.ops.aten.add.Scalar: "add_scalar", + torch.ops.aten.sub.Scalar: "sub_scalar", + torch.ops.aten.mul.Scalar: "mul_scalar", + torch.ops.aten.div.Scalar: "div_scalar", +} + + +# Output-dtype encode helper (mirrors evt_runtime). +_DTYPE_TO_STR = {torch.bfloat16: "bfloat16", torch.float16: "float16", torch.float32: "float32"} + + +def _val_dtype(node) -> Optional[torch.dtype]: + val = node.meta.get("val") if isinstance(node, fx.Node) else None + return val.dtype if val is not None else None + + +def _val_shape(node) -> Optional[Tuple]: + val = node.meta.get("val") if isinstance(node, fx.Node) else None + return tuple(val.shape) if val is not None else None + + +def _val_stride(node) -> Optional[Tuple]: + val = node.meta.get("val") if isinstance(node, fx.Node) else None + try: + return tuple(val.stride()) if val is not None else None + except Exception: + return None + + +def _is_static_int(x) -> bool: + return type(x) is int + + +# Greedy alignment: try 128-bit first, fall back to 64-bit. CUTLASS only needs +# the leading dim divisible by AlignmentX, so picking the largest power-of-2 +# that fits gets us vectorised loads when shapes allow but doesn't lock out +# 64-bit-only shapes (e.g. K=12 for bf16 → 4-elem-aligned). +def _largest_pow2_align_bits(n, dtype: torch.dtype) -> Optional[int]: + """Return the largest bit-width in (128, 64) that divides ``n * itemsize_bits``. + + For dynamic ``n`` (SymInt) we conservatively return the smallest candidate + (64) — runtime is the authoritative gate; we just need to admit the fusion + here. Returns None when even the smallest candidate doesn't fit, in which + case the caller must abort fusion. + """ + if not _is_static_int(n): + return GREEDY_ALIGN_BITS[-1] + n_int = int(n) + for bits in GREEDY_ALIGN_BITS: + align_elems = max(1, bits // (dtype.itemsize * 8)) + if n_int % align_elems == 0: + return bits + return None + + +def _is_transpose_node(n) -> bool: + """True iff ``n`` is a 2-D transpose (aten.t / transpose(0,1) / permute([1,0])).""" + if not isinstance(n, fx.Node) or n.op != "call_function": + return False + if n.target is torch.ops.aten.t.default: + return True + if n.target is torch.ops.aten.transpose.int: + # transpose(x, dim0, dim1) — accept (0, 1) on a 2D tensor. + if len(n.args) >= 3: + d0, d1 = n.args[1], n.args[2] + return {d0, d1} == {0, 1} + return False + if n.target is torch.ops.aten.permute.default: + # permute(x, [1, 0]) on a 2D tensor. + if len(n.args) >= 2: + perm = n.args[1] + return list(perm) == [1, 0] + return False + return False + + +def _b_layout_kind(B_node): + """Classify B for the EVT generic path. + + Returns (b_layout, underlying_b_node, n_dim) where: + * b_layout = "row" : B is (K, N) row-major contiguous; pass B as-is. + * b_layout = "col" : B is a stride-transpose of a contiguous (N, K) + tensor; pass the underlying tensor; kernel uses + LayoutB=ColumnMajor. + * (None, None, None) : B is not in a supported layout. + """ + shape = _val_shape(B_node) + stride = _val_stride(B_node) + if shape is None or stride is None or len(shape) != 2: + return None, None, None + K_or_N0, N_or_K1 = shape[0], shape[1] + # Contiguous (K, N): row layout. N = shape[1]. + if stride == (N_or_K1, 1): + return "row", B_node, N_or_K1 + # Stride-transposed (K, N) view of a contig (N, K) weight: stride == (1, K). + # Only accept an explicit t/transpose/permute([1,0]) so we can pass the + # underlying (N, K) row-major weight to the runtime. A bare stride-only + # view would keep the (K, N) logical shape, causing the runtime to swap + # N_w and K_w (it assumes B.size(0)=N for evt_col). + if _is_transpose_node(B_node): + weight = B_node.args[0] + w_shape = _val_shape(weight) if isinstance(weight, fx.Node) else None + w_stride = _val_stride(weight) if isinstance(weight, fx.Node) else None + if w_shape is not None and len(w_shape) == 2 and w_stride == (w_shape[1], 1): + # weight is (N, K) row-major contig; N = w_shape[0]. + return "col", weight, w_shape[0] + return None, None, None + + +# ── swiglu structural validation ─────────────────────────────────────────── +def _validate_swiglu_structure(chain_nodes: List[fx.Node], mm_node: fx.Node) -> Optional[Tuple[float, float, float]]: + """Strictly validate the decomposed swiglu pattern and extract constants. + + The canonical decomposition is:: + + mm → _to_copy(fp32) + → slice(dim=1, start=0, step=2) [gate] + → slice(dim=1, start=1, step=2) [linear] + → clamp(gate, None, limit) + → clamp(linear, -limit, limit) + → mul(gate_clamp, alpha) → sigmoid → mul(gate_clamp, sigmoid) + → add(linear_clamp, one) → mul(gate_silu, linear_offset) + → _to_copy(out_dtype) + + Returns ``(alpha, limit, one)`` on match, ``None`` on structural mismatch. + """ + + # ── Phase 1: classify nodes into roles ────────────────────────────────── + gate_slice: Optional[fx.Node] = None + linear_slice: Optional[fx.Node] = None + gate_clamp: Optional[fx.Node] = None + linear_clamp: Optional[fx.Node] = None + alpha_mul: Optional[fx.Node] = None + sigmoid_node: Optional[fx.Node] = None + gate_silu: Optional[fx.Node] = None + linear_add: Optional[fx.Node] = None + final_mul: Optional[fx.Node] = None + + limit: Optional[float] = None + alpha: Optional[float] = None + one: Optional[float] = None + + _clamp_targets = {torch.ops.aten.clamp.default, torch.ops.aten.clamp_max.default, torch.ops.aten.clamp_min.default} + _mul_targets = {torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar} + _add_targets = {torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar} + + linear_clamp_min: Optional[fx.Node] = None + linear_clamp_min_val: Optional[float] = None + + for n in chain_nodes: + t = n.target + + # ── stride-2 slices ───────────────────────────────────────────── + if t == torch.ops.aten.slice.Tensor: + if len(n.args) >= 4 and n.args[1] == 1 and (len(n.args) < 5 or n.args[4] == 2): + step = n.args[4] if len(n.args) >= 5 else 1 + if step != 2: + continue + start = n.args[2] + if start == 0 and gate_slice is None: + gate_slice = n + elif start == 1 and linear_slice is None: + linear_slice = n + + # ── clamp ─────────────────────────────────────────────────────── + elif t in _clamp_targets: + if t == torch.ops.aten.clamp_min.default: + # clamp_min(linear_slice, -limit) — first half of decomposed + # linear clamp: clamp(x, -limit, limit) → clamp_min → clamp_max + if ( + len(n.args) >= 2 + and isinstance(n.args[0], fx.Node) + and isinstance(n.args[1], (int, float)) + and n.args[0] is linear_slice + and linear_clamp_min is None + ): + linear_clamp_min = n + linear_clamp_min_val = float(n.args[1]) + + elif t == torch.ops.aten.clamp_max.default: + if len(n.args) >= 2 and isinstance(n.args[0], fx.Node) and isinstance(n.args[1], (int, float)): + if n.args[0] is gate_slice and gate_clamp is None: + gate_clamp = n + limit = float(n.args[1]) + elif n.args[0] is linear_clamp_min and linear_clamp is None: + linear_clamp = n + else: + # clamp.default(x, min_val, max_val) + if len(n.args) >= 3 and isinstance(n.args[0], fx.Node): + min_val = n.args[1] + max_val = n.args[2] + if ( + isinstance(max_val, (int, float)) + and n.args[0] is gate_slice + and (min_val is None) + and gate_clamp is None + ): + gate_clamp = n + limit = float(max_val) + elif ( + isinstance(min_val, (int, float)) + and isinstance(max_val, (int, float)) + and n.args[0] is linear_slice + and linear_clamp is None + ): + linear_clamp = n + + # ── sigmoid ───────────────────────────────────────────────────── + elif t == torch.ops.aten.sigmoid.default: + if sigmoid_node is None: + sigmoid_node = n + + # ── mul / add ─────────────────────────────────────────────────── + elif t in _mul_targets: + if ( + len(n.args) >= 2 + and isinstance(n.args[1], (int, float)) + and any(u.target == torch.ops.aten.sigmoid.default for u in n.users) + ): + alpha_mul = n + alpha = float(n.args[1]) + # Other muls are classified in Phase 2 (need sigmoid_node first). + + elif t in _add_targets: + if len(n.args) >= 2 and isinstance(n.args[0], fx.Node) and isinstance(n.args[1], (int, float)): + if n.args[0] is linear_clamp and linear_add is None: + linear_add = n + one = float(n.args[1]) + + # ── Phase 2: classify mul nodes that depend on sigmoid ────────────────── + for n in chain_nodes: + t = n.target + if t not in _mul_targets: + continue + if n is alpha_mul: + continue + if len(n.args) < 2: + continue + a0, a1 = n.args[0], n.args[1] + if not (isinstance(a0, fx.Node) and isinstance(a1, fx.Node)): + continue + # gate_silu = mul(gate_clamp, sigmoid) + if ( + gate_silu is None + and {a0, a1} == {gate_clamp, sigmoid_node} + and gate_clamp is not None + and sigmoid_node is not None + ): + gate_silu = n + # final_mul = mul(gate_silu, linear_add) + elif final_mul is None and gate_silu is not None and linear_add is not None and {a0, a1} == {gate_silu, linear_add}: + final_mul = n + + # ── Phase 3: validate all required roles are present ──────────────────── + if any( + x is None + for x in ( + gate_slice, + linear_slice, + gate_clamp, + linear_clamp, + alpha_mul, + sigmoid_node, + gate_silu, + linear_add, + final_mul, + ) + ): + return None + + if alpha is None or limit is None or one is None: + return None + + # ── Phase 4: cross-validate data-flow edges ───────────────────────────── + + # Both slices must share the same source (the _to_copy of mm). + if gate_slice.args[0] is not linear_slice.args[0]: + return None + + # Linear clamp limits must be consistent: min == -limit, max == limit. + # Two forms: + # (a) clamp.default(x, -limit, limit) — single op + # (b) clamp_min(x, -limit) → clamp_max(_, limit) — decomposed pair + if linear_clamp.target == torch.ops.aten.clamp.default: + lin_min = linear_clamp.args[1] + lin_max = linear_clamp.args[2] + if not (isinstance(lin_min, (int, float)) and float(lin_min) == -limit): + return None + if not (isinstance(lin_max, (int, float)) and float(lin_max) == limit): + return None + elif linear_clamp.target == torch.ops.aten.clamp_max.default and linear_clamp_min is not None: + if not (isinstance(linear_clamp_min_val, (int, float)) and float(linear_clamp_min_val) == -limit): + return None + lin_max_val = linear_clamp.args[1] + if not (isinstance(lin_max_val, (int, float)) and float(lin_max_val) == limit): + return None + else: + return None + + # sigmoid input must be alpha_mul. + if sigmoid_node.args[0] is not alpha_mul: + return None + + # alpha_mul input must be gate_clamp. + if alpha_mul.args[0] is not gate_clamp: + return None + + return (alpha, limit, one) + + +# ── swiglu weight / chain validation ────────────────────────────────────── + + +_SWIGLU7_CHAIN_OPS = frozenset( + { + torch.ops.aten.slice.Tensor, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp_min.default, + torch.ops.aten.clamp_max.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.add.Tensor, + torch.ops.aten.add.Scalar, + torch.ops.aten.mul.Scalar, + torch.ops.prims.convert_element_type.default, + torch.ops.aten._to_copy.default, + torch.ops.aten.clone.default, + torch.ops.aten.contiguous.default, + torch.ops.aten.alias.default, + torch.ops.aten.view.default, + torch.ops.aten.reshape.default, + torch.ops.aten._unsafe_view.default, + } +) + + +def _validate_swiglu_weight(mm_node: fx.Node) -> Optional[Tuple[fx.Node, fx.Node, int, int]]: + """Check B's underlying data is contiguous (N, K) bf16 with N even. + + K alignment and A/B dtype-compatibility are guaranteed by the caller + (``_try_fuse_evt``). This validates swiglu-specific constraints only. + + Requires an explicit transpose node (``t(weight)``) so we can extract the + underlying ``weight`` with shape (N, K). The runtime reads ``B.size(0)`` + as N, so the tensor passed to the kernel must be (N, K)-shaped. + + Returns ``(B_node, weight_node, N, K)`` on success, ``None`` on failure. + """ + B_node = mm_node.args[1] + if not isinstance(B_node, fx.Node) or not _is_transpose_node(B_node): + return None + weight_node = B_node.args[0] + if not isinstance(weight_node, fx.Node): + return None + w_shape = _val_shape(weight_node) + w_stride = _val_stride(weight_node) + if w_shape is None or len(w_shape) != 2 or w_stride is None: + return None + N, K = w_shape + if w_stride != (K, 1): + return None + if not (_is_static_int(N) and N % 2 == 0): + return None + if _val_dtype(mm_node.args[0]) != torch.bfloat16 or _val_dtype(weight_node) != torch.bfloat16: + return None + # SM90 TMA requires K * sizeof(elem) % 16 == 0. + if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9: + elem_bytes = torch.bfloat16.itemsize + if _is_static_int(K) and (int(K) * elem_bytes) % 16 != 0: + return None + return B_node, weight_node, N, K + + +def _validate_swiglu_chain(mm_node: fx.Node, N: int) -> Optional[Tuple[List[fx.Node], fx.Node, torch.dtype, str]]: + """Collect the epilogue chain, validate shape/escape/structure, extract constants. + + Returns ``(chain_nodes, last_chain_node, out_dt, sw7_json)`` on success, + ``None`` on failure. + """ + chain_nodes: List[fx.Node] = [] + chain_set: set = {mm_node} + last_chain_node: Optional[fx.Node] = None + curr = mm_node.next + while curr is not None and curr.op != "output": + uses_chain = any(isinstance(a, fx.Node) and a in chain_set for a in curr.args) + if not uses_chain: + curr = curr.next + continue + if curr.target not in _SWIGLU7_CHAIN_OPS: + break + chain_nodes.append(curr) + chain_set.add(curr) + last_chain_node = curr + curr = curr.next + + if last_chain_node is None: + return None + out_dt = _val_dtype(last_chain_node) or torch.bfloat16 + out_shape = _val_shape(last_chain_node) + if out_shape is None or len(out_shape) != 2: + return None + if not _is_static_int(out_shape[1]) or out_shape[1] != N // 2: + return None + # Refuse if any intermediate escapes the fused region. + for n in chain_nodes[:-1]: + for u in n.users: + if u not in chain_set: + return None + constants = _validate_swiglu_structure(chain_nodes, mm_node) + if constants is None: + return None + sw7_alpha, sw7_limit, sw7_one = constants + sw7_json = json.dumps({"alpha": sw7_alpha, "limit": sw7_limit, "one": sw7_one}, sort_keys=True) + return chain_nodes, last_chain_node, out_dt, sw7_json + + +# ── Shared graph-rewrite helper ──────────────────────────────────────────── + + +def _emit_and_replace( + graph: fx.Graph, + last_node: fx.Node, + op_args: tuple, + nodes_to_erase: List[fx.Node], + extra_dead: Optional[List[fx.Node]] = None, +) -> fx.Node: + """Insert ``matmul_fused_epilogue``, propagate meta, replace uses, erase dead nodes.""" + with graph.inserting_after(last_node): + new_node = graph.call_function(torch.ops.magi_epilogue.matmul_fused_epilogue.default, args=op_args) + val_last = last_node.meta.get("val") + if val_last is not None: + try: + n_pad = evt_runtime._aligned_n_stride(int(val_last.shape[-1]), val_last.dtype) + except (TypeError, ValueError): + n_pad = None + if n_pad is not None: + new_node.meta["val"] = val_last.new_empty_strided(val_last.shape, (n_pad, 1)) + + last_node.replace_all_uses_with(new_node) + for n in reversed(nodes_to_erase): + if len(n.users) == 0 and n is not new_node: + graph.erase_node(n) + if extra_dead: + for n in extra_dead: + if isinstance(n, fx.Node) and len(n.users) == 0: + graph.erase_node(n) + return new_node + + +# ── Pass ───────────────────────────────────────────────────────────────────── + + +class MatmulEvtEpilogueFusionPass(MagiInductorPass): + """Fuse aten.mm + elementwise chain into a CUTLASS EVT call. + + Active on: + * sm_90 (Hopper / H100) — lowers via CUTLASS 3.x Sm90EVT codegen. + * sm_120+ (Blackwell consumer) — lowers via CUTLASS 2.x Sm80EVT codegen. + + The codegen renderer is picked inside ``evt_runtime._compile_evt_module`` + based on the live device's arch tag. + """ + + def __init__(self, allow_extras: bool = True) -> None: + major = device_capability_major() + self._enabled = major == 9 or major >= 12 + self.allow_extras = allow_extras + + def __call__(self, graph: fx.Graph) -> bool: + if not self._enabled: + return False + fused = 0 + for node in list(graph.nodes): + if node.op != "call_function": + continue + if node.target not in (torch.ops.aten.mm.default, torch.ops.aten.mm): + continue + if self._try_fuse_evt(graph, node): + fused += 1 + if fused: + graph.eliminate_dead_code() + return fused > 0 + + # ── Generic EVT chain walker ────────────────────────────────────────────── + + def _try_fuse_evt(self, graph: fx.Graph, mm_node: fx.Node) -> bool: + A, B = mm_node.args[0], mm_node.args[1] + if not isinstance(A, fx.Node) or not isinstance(B, fx.Node): + return False + a_dtype = _val_dtype(A) + b_dtype = _val_dtype(B) + if a_dtype not in (torch.bfloat16, torch.float16) or a_dtype != b_dtype: + return False + a_shape = _val_shape(A) + b_shape = _val_shape(B) + if a_shape is None or b_shape is None or len(a_shape) != 2 or len(b_shape) != 2: + return False + K = a_shape[1] + if _largest_pow2_align_bits(K, a_dtype) is None: + return False + # SM90 TMA requires globalStride to be 16-byte aligned. A is + # RowMajor (M, K) so stride_A[0] = K; need K * elem_bytes % 16 == 0. + # (For bf16 this reduces to K % 8 == 0.) + if ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() == (9, 0) + and _is_static_int(K) + and (int(K) * a_dtype.itemsize) % 16 != 0 + ): + return False + a_stride = _val_stride(A) + if a_stride is None or a_stride != (a_shape[1], 1): + return False + + node_to_ir: dict = {mm_node: Accum()} + fused_nodes: List[fx.Node] = [mm_node] + walk_seen: List[fx.Node] = [mm_node] + extras_nodes: List[fx.Node] = [] + saw_slice = False + current_compute_dtype = "float32" + last_node = mm_node + last_ir = node_to_ir[mm_node] + + # ── Walker-local helpers ── + curr = mm_node.next + + def _absorb(ir): + nonlocal last_node, last_ir, curr + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + + def _alias(existing_ir): + nonlocal last_node, last_ir, curr + node_to_ir[curr] = existing_ir + walk_seen.append(curr) + last_node = curr + last_ir = existing_ir + curr = curr.next + + # Walk consumers in source order, greedily absorbing supported ops. + while curr is not None and curr.op != "output": + if not any(isinstance(a, fx.Node) and a in node_to_ir for a in curr.args): + curr = curr.next + continue + + target = curr.target + + if target in _PASSTHROUGH_OPS: + _alias(node_to_ir[curr.args[0]]) + continue + + if target in _TYPE_CONV_OPS: + target_dtype = _val_dtype(curr) + if target_dtype is not None and target_dtype in _DTYPE_TO_STR: + current_compute_dtype = _DTYPE_TO_STR[target_dtype] + _alias(node_to_ir[curr.args[0]]) + continue + + if target in (torch.ops.aten.view.default, torch.ops.aten.reshape.default, torch.ops.aten._unsafe_view.default): + if _val_shape(curr.args[0]) == _val_shape(curr): + _alias(node_to_ir[curr.args[0]]) + continue + break + + if target is torch.ops.aten.slice.Tensor: + step = curr.args[4] if len(curr.args) > 4 else curr.kwargs.get("step", 1) + if step == 2: + saw_slice = True + break + + if target in _UNARY_OPS: + _absorb(Compute(_UNARY_OPS[target], (node_to_ir[curr.args[0]],), compute_dtype=current_compute_dtype)) + continue + + if target is torch.ops.aten.gelu.default: + op_name = "gelu_tanh" if curr.kwargs.get("approximate", "none") == "tanh" else "gelu_erf" + _absorb(Compute(op_name, (node_to_ir[curr.args[0]],), compute_dtype=current_compute_dtype)) + continue + + if target in _SCALAR_BINARY_TO_SCALAR_UNARY: + if not isinstance(curr.args[1], (int, float)): + break + scalar_val = float(curr.args[1]) + if target in (torch.ops.aten.add.Scalar, torch.ops.aten.sub.Scalar): + alpha = curr.kwargs.get("alpha", 1) + if not isinstance(alpha, (int, float)): + break + scalar_val = float(alpha) * scalar_val + _absorb( + Compute( + _SCALAR_BINARY_TO_SCALAR_UNARY[target], + (node_to_ir[curr.args[0]],), + scalar=scalar_val, + compute_dtype=current_compute_dtype, + ) + ) + continue + + if target in (torch.ops.aten.clamp.default, torch.ops.aten.clamp_min.default, torch.ops.aten.clamp_max.default): + child_ir = node_to_ir[curr.args[0]] + if target is torch.ops.aten.clamp_min.default: + lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min") + hi = None + elif target is torch.ops.aten.clamp_max.default: + lo = None + hi = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("max") + else: + lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min") + hi = curr.args[2] if len(curr.args) > 2 else curr.kwargs.get("max") + if (lo is not None and not isinstance(lo, (int, float))) or ( + hi is not None and not isinstance(hi, (int, float)) + ): + break + ir_now = child_ir + if lo is not None: + ir_now = Compute("clamp_min_c", (ir_now,), scalar=float(lo), compute_dtype=current_compute_dtype) + if hi is not None: + ir_now = Compute("clamp_max_c", (ir_now,), scalar=float(hi), compute_dtype=current_compute_dtype) + _absorb(ir_now) + continue + + if target is torch.ops.aten.pow.Tensor_Scalar: + exp = curr.args[1] if len(curr.args) > 1 else None + child_ir = node_to_ir[curr.args[0]] + if exp == 2 or exp == 2.0: + _absorb(Compute("square", (child_ir,), compute_dtype=current_compute_dtype)) + elif isinstance(exp, (int, float)): + _absorb(Compute("pow_scalar", (child_ir,), scalar=float(exp), compute_dtype=current_compute_dtype)) + else: + break + continue + + if target in _BINARY_OPS: + ir = self._try_lower_binary(curr, target, node_to_ir, extras_nodes, A, B, current_compute_dtype) + if ir is None: + break + _absorb(ir) + continue + + break + + if saw_slice: + return self._try_fuse_swiglu(graph, mm_node) + + result = self._validate_evt_epilogue( + B, b_dtype, mm_node, node_to_ir, fused_nodes, walk_seen, last_node, last_ir, extras_nodes + ) + if result is None: + return False + ir_json, b_underlying, n_out, out_dt_id, kind = result + + _emit_and_replace(graph, last_node, (A, b_underlying, extras_nodes, ir_json, kind, n_out, out_dt_id), walk_seen) + return True + + # ── Post-walk EVT validation ────────────────────────────────────────────── + + def _validate_evt_epilogue( + self, B, b_dtype, mm_node, node_to_ir, fused_nodes, walk_seen, last_node, last_ir, extras_nodes + ): + """Post-walk eligibility gates for the generic EVT path. + + Returns ``(ir_json, b_underlying, n_out, out_dt_id, kind)`` on success, + ``None`` on any gate failure. + """ + if last_ir is node_to_ir[mm_node]: + return None + + fused_set = set(fused_nodes) | set(walk_seen) + for n in walk_seen[:-1]: + for u in n.users: + if u not in fused_set: + return None + + b_layout, b_underlying, n_dim = _b_layout_kind(B) + if b_layout is None: + return None + if b_layout == "row" and _largest_pow2_align_bits(n_dim, b_dtype) is None: + return None + + out_dt = _val_dtype(last_node) or torch.bfloat16 + if out_dt not in _DTYPE_TO_STR: + return None + + if torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0) and _is_static_int(n_dim): + n_int = int(n_dim) + if (n_int * out_dt.itemsize) % 16 != 0: + return None + + ir_root = Store(child=last_ir, out_dtype=_DTYPE_TO_STR[out_dt]) + if is_trivial(ir_root): + return None + if not self.allow_extras and num_extras(ir_root) > 0: + return None + ir_json = to_canonical_json(ir_root) + kind = "evt_row" if b_layout == "row" else "evt_col" + return ir_json, b_underlying, n_dim, evt_runtime.out_dtype_id(out_dt), kind + + # ── Binary op lowering ──────────────────────────────────────────────────── + + def _try_lower_binary(self, curr, target, node_to_ir, extras_nodes, A, B, compute_dtype): + """Try to lower a binary op to IR. Returns an IR node or None (caller breaks).""" + op_name = _BINARY_OPS[target] + lhs_raw, rhs_raw = curr.args[0], curr.args[1] + + # aten.add.Tensor / aten.sub.Tensor carry an ``alpha`` kwarg: + # add(self, other, alpha=a) → self + a * other + # sub(self, other, alpha=a) → self - a * other + # operator.add/sub and mul/div/max/min have no alpha. + has_alpha = target in (torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor) + alpha = 1 + if has_alpha: + alpha = curr.kwargs.get("alpha", 1) + if not isinstance(alpha, (int, float)): + return None + + if isinstance(rhs_raw, (int, float)) and isinstance(lhs_raw, fx.Node) and lhs_raw in node_to_ir: + scalar_op = {"add": "add_scalar", "sub": "sub_scalar", "mul": "mul_scalar", "div": "div_scalar"}.get(op_name) + if scalar_op is None: + return None + scalar_val = float(alpha) * float(rhs_raw) if has_alpha else float(rhs_raw) + return Compute(scalar_op, (node_to_ir[lhs_raw],), scalar=scalar_val, compute_dtype=compute_dtype) + + if isinstance(lhs_raw, (int, float)) and isinstance(rhs_raw, fx.Node) and rhs_raw in node_to_ir: + rhs_ir = node_to_ir[rhs_raw] + if has_alpha and alpha != 1: + rhs_ir = Compute("mul_scalar", (rhs_ir,), scalar=float(alpha), compute_dtype=compute_dtype) + if op_name in ("add", "mul"): + scalar_op = "add_scalar" if op_name == "add" else "mul_scalar" + return Compute(scalar_op, (rhs_ir,), scalar=float(lhs_raw), compute_dtype=compute_dtype) + if op_name == "sub": + return Compute("rsub_scalar", (rhs_ir,), scalar=float(lhs_raw), compute_dtype=compute_dtype) + return None + + lhs_ir = self._ir_for_arg(lhs_raw, node_to_ir, extras_nodes, A, B) + rhs_ir = self._ir_for_arg(rhs_raw, node_to_ir, extras_nodes, A, B) + if lhs_ir is None or rhs_ir is None: + return None + if has_alpha and alpha != 1: + rhs_ir = Compute("mul_scalar", (rhs_ir,), scalar=float(alpha), compute_dtype=compute_dtype) + return Compute(op_name, (lhs_ir, rhs_ir), compute_dtype=compute_dtype) + + # ── External operand classification ─────────────────────────────────────── + + def _ir_for_arg(self, arg, node_to_ir, extras_nodes, A_node, B_node): + """Classify operand: internal → existing IR; external → leaf node; None ⇒ abort.""" + if not isinstance(arg, fx.Node): + return None + if arg in node_to_ir: + return node_to_ir[arg] + if not self.allow_extras: + return None + a_shape = _val_shape(A_node) + b_shape = _val_shape(B_node) + if a_shape is None or b_shape is None: + return None + M = a_shape[0] + N = b_shape[1] + shape = _val_shape(arg) + stride = _val_stride(arg) + dt = _val_dtype(arg) + if shape is None or dt is None: + return None + dt_str = _DTYPE_TO_STR.get(dt) + if dt_str is None: + return None + if len(shape) == 1: + n0 = shape[0] + m_is_static = _is_static_int(M) + n_is_static = _is_static_int(N) + if n_is_static and n0 == N: + if m_is_static and n0 == M: + return None + idx = self._add_extra(extras_nodes, arg) + return RowBroadcast(input_idx=idx, dtype=dt_str) + if m_is_static and n0 == M: + idx = self._add_extra(extras_nodes, arg) + return ColBroadcast(input_idx=idx, dtype=dt_str) + return None + if len(shape) == 2: + if shape[0] == 1 and shape[1] == N: + idx = self._add_extra(extras_nodes, arg) + return RowBroadcast(input_idx=idx, dtype=dt_str) + if shape[1] == 1 and shape[0] == M: + idx = self._add_extra(extras_nodes, arg) + return ColBroadcast(input_idx=idx, dtype=dt_str) + if shape[0] == M and shape[1] == N and stride is not None and stride[1] == 1: + idx = self._add_extra(extras_nodes, arg) + return AuxLoad(input_idx=idx, dtype=dt_str) + return None + + def _add_extra(self, extras_nodes, arg) -> int: + for i, e in enumerate(extras_nodes): + if e is arg: + return i + extras_nodes.append(arg) + return len(extras_nodes) - 1 + + # ── swiglu special-case ────────────────────────────────────────────────── + + def _try_fuse_swiglu(self, graph: fx.Graph, mm_node: fx.Node) -> bool: + """Match the canonical swiglu epilogue and dispatch to DualGemm.""" + wt = _validate_swiglu_weight(mm_node) + if wt is None: + return False + B_node, weight_node, N, K = wt + + ch = _validate_swiglu_chain(mm_node, N) + if ch is None: + return False + chain_nodes, last_chain_node, out_dt, sw7_json = ch + + out_dt_id = evt_runtime.out_dtype_id(out_dt) + n_out = N // 2 + _emit_and_replace( + graph, + last_chain_node, + (mm_node.args[0], weight_node, [], sw7_json, "swiglu_dual", n_out, out_dt_id), + chain_nodes, + extra_dead=[mm_node, B_node], + ) + return True diff --git a/magi_compiler/passes/piecewise_graph/fusion/sm80/__init__.py b/magi_compiler/passes/piecewise_graph/fusion/sm80/__init__.py new file mode 100644 index 0000000..6fa72d4 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/sm80/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2026 SandAI. All Rights Reserved. diff --git a/magi_compiler/passes/piecewise_graph/fusion/sm80/cutlass_kernels/swiglu_one_stage.cu b/magi_compiler/passes/piecewise_graph/fusion/sm80/cutlass_kernels/swiglu_one_stage.cu new file mode 100644 index 0000000..a4f01d4 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/sm80/cutlass_kernels/swiglu_one_stage.cu @@ -0,0 +1,427 @@ +// Copyright (c) 2026 SandAI. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-kernel fully-fused swiglu — SM80 multistage path. +// +// Routes from sm_80 / sm_86 / sm_89 / sm_120 (Blackwell GeForce). The +// Hopper (sm_90) native TMA + WGMMA implementation lives at +// ../../sm90/cutlass_kernels/swiglu_one_stage.cu and is selected by +// _compile_swiglu_dual in evt_runtime.py per device compute capability. +// +// D = swiglu(A @ B.T) +// +// A : (M, K) bf16 row-major +// B : (N, K) bf16 row-major (torch.nn.Linear weight convention; N even) +// D : (M, N/2) bf16 row-major (strided view of (M, ldd) host-padded buffer) +// +// Implementation uses cutlass::gemm::device::DualGemm — the two GEMMs +// A @ W_gate.T and A @ W_linear.T run in the same threadblock sharing A's +// smem stages; their accumulators stay in registers and a custom +// SwigluCombine epilogue functor combines them and writes only D. +// +// AUTOTUNE: at first call per (M, N, K) tuple the runner times every +// registered (TileShape, WarpShape, Stages) candidate and caches the +// fastest one. Candidate set is sized to the sm_120 / Ada SMEM budget +// (~96 KB per CTA); see SwAutoTuneRunner for SMEM math. + +#include +#include + +#include +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/util/host_tensor.h" + +#include "45_dual_gemm/device/dual_gemm.h" +#include "swiglu_combine.h" + +//////////////////////////////////////////////////////////////////////////////// +// Data types +//////////////////////////////////////////////////////////////////////////////// + +using ElementA = cutlass::bfloat16_t; +using ElementB = cutlass::bfloat16_t; +using ElementC = cutlass::bfloat16_t; +using ElementAcc = float; +using ElementCompute = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB0 = cutlass::layout::ColumnMajor; // strided ldB = 2K view +using LayoutB1 = cutlass::layout::ColumnMajor; // strided ldB = 2K view +using LayoutC = cutlass::layout::RowMajor; + +// AlignmentA / AlignmentB / AlignmentC are picked greedily on the Python side +// (128 → 64 bits) and passed in via -D at JIT time, so weights/activations +// whose K only divides 64 bits (e.g. K = 12 for bf16) still fuse onto this +// kernel instead of falling back to cuBLAS. AlignmentC normally stays at 128 +// because the host pads D's row stride to a full cache line, but exposing it +// keeps the parity with A/B and lets a smaller-pad mode drop to 64 without +// editing this file. Defaults preserve the prior 128-bit behaviour for +// callers that don't override. +#ifndef MAGI_SWIGLU_ALIGN_A_BITS +#define MAGI_SWIGLU_ALIGN_A_BITS 128 +#endif +#ifndef MAGI_SWIGLU_ALIGN_B_BITS +#define MAGI_SWIGLU_ALIGN_B_BITS 128 +#endif +#ifndef MAGI_SWIGLU_ALIGN_C_BITS +#define MAGI_SWIGLU_ALIGN_C_BITS 128 +#endif +constexpr int AlignmentA = MAGI_SWIGLU_ALIGN_A_BITS / cutlass::sizeof_bits::value; +constexpr int AlignmentB = MAGI_SWIGLU_ALIGN_B_BITS / cutlass::sizeof_bits::value; +// Output vector store width = ldd's alignment expressed in elements. Host-side +// padding (see _aligned_n_stride in evt_runtime.py) normally guarantees 128 +// bits / 8 elements for bf16 — kept tunable here for parity with A/B. +constexpr int EpilogueVecCount = MAGI_SWIGLU_ALIGN_C_BITS / cutlass::sizeof_bits::value; + +using ArchTag = cutlass::arch::Sm80; +using OperatorClass = cutlass::arch::OpClassTensorOp; +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + +constexpr auto kScaleType = cutlass::epilogue::thread::ScaleType::Nothing; +constexpr bool kSplitKSerial = false; +constexpr bool kStoreD0 = false; +constexpr bool kStoreD1 = false; + +//////////////////////////////////////////////////////////////////////////////// +// Per-tile DualGemm wrapper. The DualGemm device type is templated on +// (TileShape, WarpShape, Stages) — every autotune candidate instantiates the +// full kernel for its tuple. Compile time grows linearly with candidate count +// but DualGemm Sm80 is much cheaper to compile than the EVT path (no visitor +// tree), so we can afford 8–10 candidates. +//////////////////////////////////////////////////////////////////////////////// + +template +struct DualGemmConfig { + using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute, kScaleType>; + using EpilogueOp1 = cutlass::epilogue::thread::LinearCombination< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute, kScaleType>; + using EpilogueOp2 = cutlass::epilogue::thread::SwigluCombine< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute>; + + using Gemm = cutlass::gemm::device::DualGemm< + ElementA, LayoutA, + ElementB, LayoutB0, LayoutB1, + ElementC, LayoutC, + ElementAcc, + OperatorClass, ArchTag, + TbShape, WaShape, InstructionShape, + EpilogueOp0, EpilogueOp1, EpilogueOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + Stages, + kStoreD0, kStoreD1, kSplitKSerial, + AlignmentA, AlignmentB>; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Type-erased runner concept; one instance per autotune candidate. +//////////////////////////////////////////////////////////////////////////////// + +struct SwArgs { + int M; // activations rows + int N_out; // = N/2 (output cols) + int K; + void* ptr_A; + void* ptr_B; // (N, K) row-major weight; gate/linear interleaved + void* ptr_D; // (M, N_out) — strided view of an (M, ldd) padded buffer + int64_t ldd; // D's row stride in elements; >= N_out, multiple of EpilogueVecCount + float alpha; // silu_alpha scaling: x * sigmoid(alpha * x) + float limit; // clamp bound: clamp(gate, max=limit), clamp(linear, -limit, limit) + float one; // additive offset: (x_linear + one) +}; + +class SwConcept { + public: + virtual ~SwConcept() = default; + virtual size_t get_workspace_size(const SwArgs&) = 0; + virtual cutlass::Status initialize(const SwArgs&, void* ws, cudaStream_t) = 0; + virtual cutlass::Status run(cudaStream_t stream) = 0; + virtual const char* name() const = 0; +}; + +template +class SwImpl : public SwConcept { + public: + using GemmType = typename Cfg::Gemm; + using EpilogueOp0 = typename Cfg::EpilogueOp0; + using EpilogueOp1 = typename Cfg::EpilogueOp1; + using EpilogueOp2 = typename Cfg::EpilogueOp2; + + explicit SwImpl(const char* name) : name_(name) {} + + typename GemmType::Arguments make_args(const SwArgs& a) { + auto ptrA = reinterpret_cast(a.ptr_A); + auto ptrB = reinterpret_cast(a.ptr_B); + auto ptrD = reinterpret_cast(a.ptr_D); + int const M = a.M, N_out = a.N_out, K = a.K; + + int64_t const ldB_strided = static_cast(2) * K; + LayoutB0 layoutB_gate(ldB_strided); + LayoutB1 layoutB_linear(ldB_strided); + LayoutC layoutC(a.ldd); + + using TensorRefA = cutlass::TensorRef; + using TensorRefB0 = cutlass::TensorRef; + using TensorRefB1 = cutlass::TensorRef; + using TensorRefCi = cutlass::TensorRef; + using TensorRefDo = cutlass::TensorRef; + + TensorRefA ref_A0(ptrA, LayoutA(static_cast(K))); + TensorRefB0 ref_B0(ptrB, layoutB_gate); + TensorRefCi ref_C0(nullptr, LayoutC(0)); + TensorRefDo ref_D0(nullptr, LayoutC(0)); + TensorRefB1 ref_B1(ptrB + K, layoutB_linear); + TensorRefCi ref_C1(nullptr, LayoutC(0)); + TensorRefDo ref_D1(nullptr, LayoutC(0)); + TensorRefDo ref_D2(ptrD, layoutC); + + typename EpilogueOp0::Params epi0{ElementCompute(1.0f), ElementCompute(0.0f)}; + typename EpilogueOp1::Params epi1{ElementCompute(1.0f), ElementCompute(0.0f)}; + typename EpilogueOp2::Params epi2{ + ElementCompute(a.alpha), ElementCompute(a.limit), ElementCompute(a.one)}; + + cutlass::gemm::GemmCoord problem{M, N_out, K}; + typename GemmType::Arguments args( + cutlass::gemm::DualGemmMode::kGemm, + problem, + ref_A0, + ref_B0, ref_C0, ref_D0, + ref_B1, ref_C1, ref_D1, + ref_D2, + epi0, epi1, epi2, + /*split_k_slices=*/1, + /*batch_count=*/1, + /*batch_stride_A=*/0, + /*batch_stride_B0=*/0, + /*batch_stride_B1=*/0, + /*batch_stride_C=*/0, + /*batch_stride_D=*/0); + return args; + } + + size_t get_workspace_size(const SwArgs& a) override { + return GemmType::get_workspace_size(make_args(a)); + } + cutlass::Status initialize(const SwArgs& a, void* ws, cudaStream_t s) override { + return gemm_.initialize(make_args(a), ws, s); + } + cutlass::Status run(cudaStream_t stream) override { + return gemm_.run(stream); + } + const char* name() const override { return name_; } + + private: + GemmType gemm_; + const char* name_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// AutoTune runner — first call per (M, N_out, K) shape times all candidates. +//////////////////////////////////////////////////////////////////////////////// + +#define SW_TILE(tb_m, tb_n, tb_k, wa_m, wa_n, wa_k, stages, label) \ + configs_.push_back(std::make_unique< \ + SwImpl, \ + cutlass::gemm::GemmShape, \ + stages>>>(label)) + +class SwAutoTuneRunner { + public: + SwAutoTuneRunner() { + // SMEM cost for DualGemm = (BM + 2*BN) * BK * 2B * stages because both + // B operands live in smem simultaneously. Budget cap ~96 KB matches + // sm_120's per-SM SMEM (also fits sm_80 / sm_86 / sm_89). + // + // Bucket of M doesn't drive a separate .cu here — DualGemm compiles + // fast enough that one runner with all candidates handles every M, and + // the per-shape cache picks the best for whatever M it sees. + + // Small / decode-friendly tiles + SW_TILE(64, 64, 32, 32, 32, 32, 4, "T<64,64,32>_S4"); // 36 KB + SW_TILE(64, 64, 64, 32, 32, 64, 3, "T<64,64,64>_S3"); // 72 KB + SW_TILE(64, 128, 32, 32, 64, 32, 3, "T<64,128,32>_S3"); // 60 KB + SW_TILE(64, 128, 32, 32, 64, 32, 4, "T<64,128,32>_S4"); // 80 KB + + // Medium tiles (CUTLASS bf16 reference defaults) + SW_TILE(128, 64, 32, 64, 32, 32, 3, "T<128,64,32>_S3"); // 48 KB + SW_TILE(128, 64, 32, 64, 32, 32, 4, "T<128,64,32>_S4"); // 64 KB + SW_TILE(128, 64, 64, 64, 32, 64, 3, "T<128,64,64>_S3"); // 96 KB + SW_TILE(128, 128, 32, 64, 64, 32, 3, "T<128,128,32>_S3"); // 72 KB + SW_TILE(128, 128, 32, 64, 64, 32, 4, "T<128,128,32>_S4"); // 96 KB + + // Large prefill tiles + SW_TILE(256, 64, 32, 64, 32, 32, 3, "T<256,64,32>_S3"); // 72 KB + // (256, 128, 32)*3 = 96 KB exact-budget, prone to SMEM alloc fail; omitted. + // (128, 256, 32)*3 = 120 KB > 96 — omitted. + // (64, 256, 32)*3 = 108 KB > 96 — omitted. + } + + void operator()(at::Tensor A, at::Tensor B, at::Tensor D, + float alpha, float limit, float one) { + TORCH_CHECK(A.is_cuda() && B.is_cuda() && D.is_cuda(), + "all inputs must be CUDA tensors"); + TORCH_CHECK(A.scalar_type() == at::kBFloat16 && B.scalar_type() == at::kBFloat16 + && D.scalar_type() == at::kBFloat16, + "all inputs must be bf16"); + TORCH_CHECK(A.dim() == 2 && B.dim() == 2 && D.dim() == 2, "A, B, D must be 2D"); + TORCH_CHECK(A.size(1) == B.size(1), "K mismatch (A.size(1) vs B.size(1))"); + // Stride-based contiguity instead of A.is_contiguous() / B.is_contiguous(): + // Inductor's reinterpret_tensor often hands us a tensor with the right + // strides but tripped is_contiguous() (e.g. bigger storage than sizes + // would imply). The kernel only cares that A's innermost is K-stride 1 + // and B's innermost is K-stride 1 (B is the (N, K) row-major weight, + // CUTLASS reads it via ColumnMajor + ldB=2K). + TORCH_CHECK(A.stride(1) == 1, "A innermost stride must be 1; got ", A.stride(1)); + TORCH_CHECK(A.stride(0) >= A.size(1), + "A row stride must be >= K; got stride(0)=", A.stride(0), ", K=", A.size(1)); + TORCH_CHECK(B.stride(1) == 1, "B innermost stride must be 1; got ", B.stride(1)); + TORCH_CHECK(B.stride(0) >= B.size(1), + "B row stride must be >= K; got stride(0)=", B.stride(0), ", K=", B.size(1)); + + int const M = static_cast(A.size(0)); + int const K = static_cast(A.size(1)); + int const N = static_cast(B.size(0)); + TORCH_CHECK((N % 2) == 0, "N must be even, got ", N); + int const N_out = N / 2; + TORCH_CHECK(D.size(0) == M && D.size(1) == N_out, + "D must be (M, N/2) = (", M, ",", N_out, ")"); + // D may be a strided view of a host-padded (M, ldd) buffer. + TORCH_CHECK(D.stride(1) == 1, "D innermost stride must be 1; got ", D.stride(1)); + TORCH_CHECK(D.stride(0) >= N_out, + "D row stride must be >= N_out; got stride(0)=", D.stride(0), ", N_out=", N_out); + + SwArgs ea; + ea.M = M; ea.N_out = N_out; ea.K = K; + ea.ptr_A = A.data_ptr(); + ea.ptr_B = B.data_ptr(); + ea.ptr_D = D.data_ptr(); + ea.ldd = static_cast(D.stride(0)); + ea.alpha = alpha; ea.limit = limit; ea.one = one; + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index()).stream(); + + // Single autotune per module. The .cu is compiled per (M-bucket, N, K) + // on the Python side — every distinct weight (N, K) gets its own .cu, + // so this runner instance hosts exactly one (N, K) and one bucket. The + // first call autotunes; all subsequent calls (any M in the bucket) + // reuse `best_idx_`. + if (best_idx_ < 0) { + best_idx_ = autotune(ea, stream); + } + int idx = best_idx_; + + auto& gemm = configs_[idx]; + size_t ws_sz = gemm->get_workspace_size(ea); + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) { + ws_ = at::empty({(int64_t)ws_sz + 1}, + at::TensorOptions().dtype(at::kByte).device(A.device())); + } + auto st = gemm->initialize(ea, ws_sz > 0 ? ws_.data_ptr() : nullptr, stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "DualGemm init failed (", gemm->name(), "): ", + cutlassGetStatusString(st)); + st = gemm->run(stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "DualGemm run failed (", gemm->name(), "): ", + cutlassGetStatusString(st)); + } + + int num_configs() const { return (int)configs_.size(); } + + private: + int autotune(const SwArgs& ea, cudaStream_t stream) { + int best_idx = -1; + float best_time = 1e30f; + cudaEvent_t s, e; + cudaEventCreate(&s); cudaEventCreate(&e); + + for (size_t i = 0; i < configs_.size(); ++i) { + auto& g = configs_[i]; + size_t ws_sz = 0; + try { ws_sz = g->get_workspace_size(ea); } + catch (...) { continue; } + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) { + ws_ = at::empty({(int64_t)ws_sz + 1}, + at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + } + void* ws_ptr = ws_sz > 0 ? ws_.data_ptr() : nullptr; + if (g->initialize(ea, ws_ptr, stream) != cutlass::Status::kSuccess) { + continue; + } + + // Warmup — 10 iters so the L2 / instruction cache settle. With only + // 3 warmups (the original count) the first timed iter sees a cold L2 + // and inflates the average, sometimes flipping the best-config choice. + for (int w = 0; w < 10; ++w) g->run(stream); + cudaStreamSynchronize(stream); + + // Time — 50 iters keeps timing noise to <1% so 2–3 % perf gaps + // between candidates are distinguishable. + cudaEventRecord(s, stream); + int iters = 50; + for (int p = 0; p < iters; ++p) g->run(stream); + cudaEventRecord(e, stream); + cudaEventSynchronize(e); + float ms = 0; + cudaEventElapsedTime(&ms, s, e); + float avg = ms / iters; + if (avg < best_time) { best_time = avg; best_idx = (int)i; } + } + cudaEventDestroy(s); cudaEventDestroy(e); + TORCH_CHECK(best_idx >= 0, + "swiglu AutoTune: no candidate succeeded for (M,N_out,K)=(", + ea.M, ",", ea.N_out, ",", ea.K, ")"); + return best_idx; + } + + std::vector> configs_; + int best_idx_ = -1; // -1 = not yet autotuned; sticky after first call. + at::Tensor ws_; +}; + +static SwAutoTuneRunner& runner() { + static SwAutoTuneRunner R; + return R; +} + +void swiglu_dual_matmul_out(at::Tensor A, at::Tensor B, at::Tensor D, + float alpha, float limit, float one) { + runner()(std::move(A), std::move(B), std::move(D), alpha, limit, one); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "CUTLASS DualGemm fully-fused swiglu (bf16) on sm_120 — autotune"; + m.def("swiglu_dual_matmul_out", + &swiglu_dual_matmul_out, + "D = swiglu(A @ B.T) in a single fused kernel; " + "A:(M,K) bf16, B:(N,K) bf16 (N even), D:(M,N/2) bf16", + pybind11::arg("A"), + pybind11::arg("B"), + pybind11::arg("D"), + pybind11::arg("alpha") = 1.702f, + pybind11::arg("limit") = 7.0f, + pybind11::arg("one") = 1.0f); + m.def("num_configs", []() { return runner().num_configs(); }); +} diff --git a/magi_compiler/passes/piecewise_graph/fusion/sm80/evt_codegen.py b/magi_compiler/passes/piecewise_graph/fusion/sm80/evt_codegen.py new file mode 100644 index 0000000..a4e8471 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/sm80/evt_codegen.py @@ -0,0 +1,764 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Render a CUTLASS 2.x Sm80EVT .cu source from an EVT IR tree. + +Used on sm_120 (RTX 5090) and all non-sm_90 arches. The H100 path is +``../sm90/evt_codegen.py``, selected by ``evt_runtime`` on sm_90 devices. +""" + +from __future__ import annotations + +from typing import Dict, List, Tuple + +from ..common.codegen_shared import ( + _BUILTIN_FN_TEMPLATE, + _DTYPE_TO_AT, + _DTYPE_TO_AT_CPP, + _DTYPE_TO_CUTLASS, + _VALID_ALIGN_BITS, + _emit_custom_functor, +) +from ..evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store, walk_leaves + +# (BM, BN, BK, WM, WN, WK, NumStages, label). +# RTX 5090: 170 SMs, 100 KB SMEM / SM; tile×stages kept inside that envelope. +_TILE_CANDIDATES_SM120: dict = { + "small": [ + (64, 64, 32, 32, 32, 32, 4, "T<64,64,32>_S4"), + (64, 64, 64, 32, 32, 64, 3, "T<64,64,64>_S3"), + (64, 128, 32, 32, 64, 32, 3, "T<64,128,32>_S3"), + (64, 128, 32, 32, 64, 32, 4, "T<64,128,32>_S4"), + (64, 128, 64, 32, 64, 64, 3, "T<64,128,64>_S3"), + (64, 256, 32, 32, 64, 32, 3, "T<64,256,32>_S3"), + (128, 64, 32, 64, 32, 32, 3, "T<128,64,32>_S3"), + (128, 64, 32, 64, 32, 32, 4, "T<128,64,32>_S4"), + ], + "medium": [ + (128, 128, 32, 64, 64, 32, 3, "T<128,128,32>_S3"), + (128, 128, 32, 64, 64, 32, 4, "T<128,128,32>_S4"), + (128, 128, 64, 64, 64, 64, 3, "T<128,128,64>_S3"), + (128, 256, 32, 64, 64, 32, 3, "T<128,256,32>_S3"), + (256, 128, 32, 64, 64, 32, 3, "T<256,128,32>_S3"), + (128, 64, 64, 64, 32, 64, 4, "T<128,64,64>_S4"), + (64, 128, 64, 32, 64, 64, 4, "T<64,128,64>_S4"), + ], + "large": [ + (128, 256, 32, 64, 64, 32, 3, "T<128,256,32>_S3"), + (256, 128, 32, 64, 64, 32, 3, "T<256,128,32>_S3"), + (128, 128, 32, 64, 64, 32, 4, "T<128,128,32>_S4"), + (128, 128, 64, 64, 64, 64, 3, "T<128,128,64>_S3"), + ], +} + +# Backward-compat alias: some external callers still reference this name. +_TILE_CANDIDATES_5090 = _TILE_CANDIDATES_SM120 + + +def _emit_tile_candidates(m_bucket: str) -> str: + """Emit C++ EVT_TILE_CANDIDATE(...) statements for the given M bucket.""" + candidates = _TILE_CANDIDATES_SM120.get(m_bucket, _TILE_CANDIDATES_SM120["medium"]) + lines = [] + for bm, bn, bk, wm, wn, wk, stages, label in candidates: + lines.append(f' EVT_TILE_CANDIDATE({bm}, {bn}, {bk}, {wm}, {wn}, {wk}, ' f'{stages}, "{label}");') + return "\n".join(lines) + + +class _EvtEmitter: + """Bottom-up walker that emits typedef chains + leaf placeholders.""" + + def __init__(self, root: Store): + self.root = root + self.typedef_lines: List[str] = [] + self.functor_decls: List[str] = [] + self._emitted_functors: Dict[Tuple[str, str], str] = {} + self._tmp_counter = 0 + self.leaf_typedefs: List[Tuple[str, str, "int | None", str]] = [] + self.scalar_functor_counter = 0 + + def _new_name(self, prefix: str) -> str: + self._tmp_counter += 1 + return f"{prefix}_{self._tmp_counter}" + + def _functor_name_for(self, op: str, scalar) -> str: + key = (op, repr(scalar) if scalar is not None else "") + if key in self._emitted_functors: + return self._emitted_functors[key] + scalar_tag = "" + if scalar is not None: + self.scalar_functor_counter += 1 + scalar_tag = f"_v{self.scalar_functor_counter}" + name = f"Magi_{op}{scalar_tag}" + self._emitted_functors[key] = name + self.functor_decls.append(_emit_custom_functor(name, op, scalar)) + return name + + def _compute_op_template(self, node: Compute) -> str: + if node.op in _BUILTIN_FN_TEMPLATE and node.scalar is None: + return _BUILTIN_FN_TEMPLATE[node.op] + # Custom functor — either scalar-baked or unary-no-builtin (e.g. erf). + return self._functor_name_for(node.op, node.scalar) + + def emit(self) -> str: + """Walk the IR; return the typedef name of the root EVT type.""" + body_root = self._emit_node(self.root.child) + # The store leaf itself is the StoreD typedef wrapping body_root. + store_name = self._new_name("StoreD") + self.typedef_lines.append( + "using {name} = cutlass::epilogue::threadblock::VisitorAuxStore<\n" + " OutputTileThreadMap, ElementC,\n" + " cutlass::FloatRoundStyle::round_to_nearest,\n" + " cute::Stride>;".format(name=store_name) + ) + evt_d = self._new_name("EVT_D") + self.typedef_lines.append( + f"using {evt_d} = cutlass::epilogue::threadblock::Sm80EVT<\n" f" {store_name}, {body_root}>;" + ) + # Track the StoreD leaf metadata so the launcher knows where to bind D. + self.leaf_typedefs.append((store_name, "store", None, self.root.out_dtype)) + return evt_d + + def _emit_node(self, node) -> str: + if isinstance(node, Accum): + name = self._new_name("Accum") + self.typedef_lines.append(f"using {name} = cutlass::epilogue::threadblock::VisitorAccFetch;") + return name + if isinstance(node, RowBroadcast): + name = self._new_name("RowBcast") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::threadblock::VisitorRowBroadcast<\n" + f" OutputTileThreadMap, {elem},\n" + f" cute::Stride<_0, _1, int32_t>>;" + ) + self.leaf_typedefs.append((name, "row_bcast", node.input_idx, node.dtype)) + return name + if isinstance(node, ColBroadcast): + name = self._new_name("ColBcast") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::threadblock::VisitorColBroadcast<\n" + f" OutputTileThreadMap, {elem},\n" + f" cute::Stride<_1, _0, int32_t>>;" + ) + self.leaf_typedefs.append((name, "col_bcast", node.input_idx, node.dtype)) + return name + if isinstance(node, AuxLoad): + name = self._new_name("Aux") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::threadblock::VisitorAuxLoad<\n" + f" OutputTileThreadMap, {elem},\n" + f" cute::Stride>;" + ) + self.leaf_typedefs.append((name, "aux_load", node.input_idx, node.dtype)) + return name + if isinstance(node, Compute): + child_names = [self._emit_node(c) for c in node.children] + compute_name = self._new_name(f"Cmp_{node.op}") + fn_template = self._compute_op_template(node) + elem_compute = _DTYPE_TO_CUTLASS[node.compute_dtype] + self.typedef_lines.append( + f"using {compute_name} = cutlass::epilogue::threadblock::VisitorCompute<\n" + f" {fn_template}, {elem_compute}, {elem_compute},\n" + f" cutlass::FloatRoundStyle::round_to_nearest>;" + ) + evt_name = self._new_name(f"EVT_{node.op}") + child_typedef_list = ", ".join(child_names) + self.typedef_lines.append( + f"using {evt_name} = cutlass::epilogue::threadblock::Sm80EVT<\n" f" {compute_name}, {child_typedef_list}>;" + ) + return evt_name + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +def _emit_args_tree(node, leaf_args: Dict[int, str], indent: int = 4) -> str: + """Emit the nested-brace runtime args literal matching the EVT typedef tree.""" + pad = " " * indent + if isinstance(node, Accum): + return f"{pad}{{}}" + if isinstance(node, (RowBroadcast, ColBroadcast, AuxLoad)): + return f"{pad}{leaf_args[node.input_idx]}" + if isinstance(node, Compute): + children_str = ",\n".join(_emit_args_tree(c, leaf_args, indent + 2) for c in node.children) + return f"{pad}{{\n" f"{children_str},\n" f"{pad} {{}}\n" f"{pad}}}" + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +_KERNEL_PREAMBLE = """\ +// AUTO-GENERATED by magi_compiler/passes/piecewise_graph/fusion/sm80/evt_codegen.py +// Do not edit by hand. Regenerate by re-running the FX pass. +// +// IR cache key: {cache_key} + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +using cute::_0; +using cute::_1; + +//////////////////////////////////////////////////////////////////////////////// +// Custom functors (one per unique scalar-baked op or non-builtin unary). +//////////////////////////////////////////////////////////////////////////////// +{functor_decls} + +//////////////////////////////////////////////////////////////////////////////// +// Data types and layouts +//////////////////////////////////////////////////////////////////////////////// + +using ElementA = {a_elem}; +using ElementB = {b_elem}; +using ElementC = {c_elem}; +using ElementAcc = float; +using ElementCompute = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::{b_layout}; +using LayoutC = cutlass::layout::RowMajor; + +// AlignmentA / AlignmentB / AlignmentC are baked from the (greedy) bit-width +// chosen at runtime to match the actual K, N, and ldd divisibility — 128 +// bits when shapes allow vector loads, 64 bits as a fallback for shapes that +// only meet 8-byte alignment (e.g. K = 12 for bf16). For C the host already +// over-pads D's row stride to a full cache line (see ``_aligned_n_stride`` +// in evt_runtime.py), so AlignmentC = 128 is almost always achievable — +// keeping it tunable lets a smaller-padding mode drop to 64 without a +// CUTLASS template rebuild from scratch. +constexpr int AlignmentA = {alignment_a_bits} / cutlass::sizeof_bits::value; +constexpr int AlignmentB = {alignment_b_bits} / cutlass::sizeof_bits::value; +constexpr int AlignmentC = {alignment_c_bits} / cutlass::sizeof_bits::value; + +using ArchTag = cutlass::arch::Sm80; +using OperatorClass = cutlass::arch::OpClassTensorOp; +using InstructionShape = cutlass::gemm::GemmShape< 16, 8, 16>; +constexpr int EVTEpilogueStages = 1; + +//////////////////////////////////////////////////////////////////////////////// +// Per-tile-config GEMM type. The OutputTileThreadMap depends on +// ThreadblockShape/WarpShape, which forces every EVT typedef to be re-built +// per tile. We package the whole tree inside a template struct keyed on the +// tile/warp/stages parameters so each autotune candidate is a distinct type. +//////////////////////////////////////////////////////////////////////////////// + +template +struct EvtConfig {{ + using TheTbShape = TbShape; + using TheWarpShape = WarpShape; + + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + TbShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>; + + //////////////////////////////////////////////////////////////////////////// + // EVT (Epilogue Visitor Tree) typedefs — generated from the IR tree. + //////////////////////////////////////////////////////////////////////////// +{typedef_block} + + //////////////////////////////////////////////////////////////////////////// + // GemmKernel / DeviceGemm + //////////////////////////////////////////////////////////////////////////// + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementC, LayoutC, AlignmentC, + ElementAcc, + ElementCompute, + OperatorClass, + ArchTag, + TbShape, + WarpShape, + InstructionShape, + {evt_root_name}, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + NumStages, + cutlass::arch::OpMultiplyAdd, + EVTEpilogueStages>::GemmKernel; + + using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter; +}}; + +//////////////////////////////////////////////////////////////////////////////// +// Autotune runner — one candidate per tile/warp/stages combination; first call +// at a new (M, N, K) tuple times every candidate and caches the winner. +//////////////////////////////////////////////////////////////////////////////// + +struct EvtArgs {{ + int M; + int N; + int K; + void* ptr_A; + void* ptr_B; + void* ptr_D; + int64_t lda; + int64_t ldb; + int64_t ldd; + // Extras pointers, in IR-leaf order. + std::vector ptr_extras; + // Row strides for AuxLoad extras (stride(0) in elements). Indexed in + // the same order as ptr_extras; RowBroadcast/ColBroadcast entries are + // unused but still present so indices stay aligned. + std::vector stride_extras; +}}; + +class EvtConcept {{ + public: + virtual ~EvtConcept() = default; + virtual size_t get_workspace_size(const EvtArgs&) = 0; + virtual cutlass::Status initialize(const EvtArgs&, void* ws, cudaStream_t s) = 0; + virtual cutlass::Status run(cudaStream_t stream) = 0; + virtual const char* name() const = 0; +}}; + +template +class EvtImpl : public EvtConcept {{ + public: + using GemmType = typename Cfg::DeviceGemm; + using EvtRoot = typename Cfg::{evt_root_name}; + + explicit EvtImpl(const char* name) : name_(name) {{}} + + typename GemmType::Arguments make_args(const EvtArgs& a) {{ + auto ptrA = reinterpret_cast(a.ptr_A); + auto ptrB = reinterpret_cast(a.ptr_B); + auto ptrD = reinterpret_cast(a.ptr_D); + int const M = a.M; + int const N = a.N; + int const K = a.K; + int64_t const MN = static_cast(M) * static_cast(N); + // ldd = D's row stride in elements; padded by host to satisfy AlignmentC. + int64_t const ldd = a.ldd; + int64_t const stride_d_total = static_cast(M) * ldd; + + typename EvtRoot::Arguments callback_args{{ +{args_tree} + , + {{ptrD, {{ldd, _1{{}}, stride_d_total}}}} + }}; + + cutlass::gemm::GemmCoord problem{{M, N, K}}; + int64_t const lda = a.lda; + int64_t const ldb = a.ldb; + typename GemmType::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + /*batch_count=*/1, + callback_args, + ptrA, ptrB, + /*ptr_C=*/nullptr, /*ptr_D=*/nullptr, + /*batch_stride_A=*/static_cast(M) * lda, + /*batch_stride_B=*/static_cast(N) * ldb, + /*batch_stride_C=*/0, /*batch_stride_D=*/0, + /*stride_a=*/lda, + /*stride_b=*/ldb, + /*stride_c=*/0, /*stride_d=*/0); + return args; + }} + + size_t get_workspace_size(const EvtArgs& a) override {{ + auto args = make_args(a); + return GemmType::get_workspace_size(args); + }} + cutlass::Status initialize(const EvtArgs& a, void* ws, cudaStream_t s) override {{ + auto args = make_args(a); + return gemm_.initialize(args, ws, s); + }} + cutlass::Status run(cudaStream_t stream) override {{ + return gemm_.run(stream); + }} + const char* name() const override {{ return name_; }} + + private: + GemmType gemm_; + const char* name_; +}}; + +//////////////////////////////////////////////////////////////////////////////// +// Python-facing launcher +//////////////////////////////////////////////////////////////////////////////// +""" + + +_LAUNCHER_TEMPLATE = """\ +//////////////////////////////////////////////////////////////////////////////// +// Tile candidate registration. Each AutoConfigBuilder invocation instantiates +// the full EVT typedef tree + GemmKernel for that (TileShape, WarpShape, +// NumStages) tuple. Compile time grows linearly with the candidate count, so +// keep the list small and shape-relevant. +//////////////////////////////////////////////////////////////////////////////// + +#define EVT_TILE_CANDIDATE(tb_m, tb_n, tb_k, wa_m, wa_n, wa_k, stages, label) \\ + configs_.push_back(std::make_unique, \\ + cutlass::gemm::GemmShape, \\ + stages>>>(label)) + +class EvtAutoTuneRunner {{ + public: + EvtAutoTuneRunner() {{ +{tile_candidate_block} + }} + + void operator()(at::Tensor A, at::Tensor B, + std::vector extras, at::Tensor D) {{ + TORCH_CHECK(A.is_cuda() && B.is_cuda() && D.is_cuda(), + "evt_matmul_out: A/B/D must be CUDA tensors"); + TORCH_CHECK(A.scalar_type() == {a_at_dtype}, "A must be {a_dtype}"); + TORCH_CHECK(B.scalar_type() == {b_at_dtype}, "B must be {b_dtype}"); + TORCH_CHECK(D.scalar_type() == {c_at_dtype}, "D must be {c_dtype}"); + TORCH_CHECK(A.dim() == 2 && B.dim() == 2 && D.dim() == 2, "A, B, D must be 2D"); + // A is always row-major (M, K), so its innermost (K) stride must be 1. + // We don't require A.is_contiguous() because Inductor often hands us a + // reinterpret_tensor that has the right strides but trips that check. + TORCH_CHECK(A.stride(1) == 1, "A innermost stride must be 1; got ", A.stride(1)); + TORCH_CHECK(A.stride(0) >= A.size(1), + "A row stride must be >= K; got stride(0)=", A.stride(0), ", K=", A.size(1)); + // B's stride contract depends on b_layout (substituted at codegen time): + // row: B is (K, N) row-major → B.stride(1) == 1, B.stride(0) >= N + // col: B is the underlying (N, K) → B.stride(1) == 1, B.stride(0) >= K + // row-major weight read as + // ColumnMajor (K, N) by CUTLASS + {b_stride_check} + + int const M = static_cast(A.size(0)); + int const K = static_cast(A.size(1)); + int const N = static_cast({n_dim_expr}); + + TORCH_CHECK(D.size(0) == M && D.size(1) == N, + "D must be (M, N); got ", D.sizes()); + // D may be a strided view of a host-padded (M, n_padded) buffer: inner + // stride must be 1, row stride (ldd) must be >= N. + TORCH_CHECK(D.stride(1) == 1, "D innermost stride must be 1; got ", D.stride(1)); + TORCH_CHECK(D.stride(0) >= N, + "D row stride must be >= N; got stride(0)=", D.stride(0), ", N=", N); + TORCH_CHECK(extras.size() == {n_extras}, "expected {n_extras} extra tensors, got ", extras.size()); + +{extras_validation} + + EvtArgs ea; + ea.M = M; ea.N = N; ea.K = K; + ea.ptr_A = A.data_ptr<{a_at_cpp}>(); + ea.ptr_B = B.data_ptr<{b_at_cpp}>(); + ea.ptr_D = D.data_ptr<{c_at_cpp}>(); + // Real strides from the at::Tensor — handles Inductor reinterpret_tensor + // cases where lda > K or ldb > size(1). Both stride(0) values are in + // elements since stride(1) == 1 was just validated above. + ea.lda = static_cast(A.stride(0)); + ea.ldb = static_cast(B.stride(0)); + ea.ldd = static_cast(D.stride(0)); + ea.ptr_extras.reserve({n_extras}); +{extras_ptrs} + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index()).stream(); + + // Single autotune per module. The .cu is compiled per (IR, M-bucket, + // b_layout, N, K) on the Python side — every distinct weight (N, K) + // gets its own .cu, so this runner instance hosts exactly one (N, K) + // and one bucket of M values. Autotune once on the first call; all + // subsequent calls (any M inside the bucket) reuse `best_idx_`. + if (best_idx_ < 0) {{ + best_idx_ = autotune(ea, stream); + }} + int idx = best_idx_; + + auto& gemm = configs_[idx]; + size_t ws_sz = gemm->get_workspace_size(ea); + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) {{ + ws_ = at::empty({{(int64_t)ws_sz + 1}}, + at::TensorOptions().dtype(at::kByte).device(A.device())); + }} + auto st = gemm->initialize(ea, ws_sz > 0 ? ws_.data_ptr() : nullptr, stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "CUTLASS init failed (", gemm->name(), "): ", cutlassGetStatusString(st)); + st = gemm->run(stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "CUTLASS run failed (", gemm->name(), "): ", cutlassGetStatusString(st)); + }} + + int num_configs() const {{ return (int)configs_.size(); }} + + private: + int autotune(const EvtArgs& ea, cudaStream_t stream) {{ + int best_idx = -1; + float best_time = 1e30f; + cudaEvent_t s, e; + cudaEventCreate(&s); cudaEventCreate(&e); + + // Drain any pre-existing CUDA error so we don't blame our first candidate + // for an upstream failure. + (void)cudaGetLastError(); + + for (size_t i = 0; i < configs_.size(); ++i) {{ + auto& g = configs_[i]; + size_t ws_sz = 0; + try {{ ws_sz = g->get_workspace_size(ea); }} + catch (...) {{ (void)cudaGetLastError(); continue; }} + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) {{ + ws_ = at::empty({{(int64_t)ws_sz + 1}}, + at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + }} + void* ws_ptr = ws_sz > 0 ? ws_.data_ptr() : nullptr; + // initialize() can fail synchronously (e.g. cudaFuncSetAttribute returns + // cudaErrorInvalidValue for tiles whose SharedStorage exceeds the + // device opt-in cap). Clear the sticky CUDA error before moving on — + // otherwise the next launch (or post-autotune user run) inherits it + // and surfaces a misleading "Error Internal" against an unrelated tile. + if (g->initialize(ea, ws_ptr, stream) != cutlass::Status::kSuccess) {{ + (void)cudaGetLastError(); + continue; + }} + + // Warmup — 10 iters so L2 / inst caches settle (3 was too few — first + // timed iter saw a cold L2 and biased the choice towards smaller tiles). + // Capture run() status and sync return codes so an async launch failure + // (e.g. invalid grid, latent SMEM issue) disqualifies the tile cleanly. + bool tile_ok = true; + for (int w = 0; w < 10; ++w) {{ + if (g->run(stream) != cutlass::Status::kSuccess) {{ tile_ok = false; break; }} + }} + if (tile_ok && cudaStreamSynchronize(stream) != cudaSuccess) {{ + tile_ok = false; + }} + if (!tile_ok) {{ + (void)cudaGetLastError(); + continue; + }} + + // Time — 20 iters for ~1% timing noise, matching torch.compile defaults. + cudaEventRecord(s, stream); + int iters = 20; + for (int p = 0; p < iters; ++p) {{ + if (g->run(stream) != cutlass::Status::kSuccess) {{ tile_ok = false; break; }} + }} + cudaEventRecord(e, stream); + if (cudaEventSynchronize(e) != cudaSuccess) tile_ok = false; + if (!tile_ok) {{ + (void)cudaGetLastError(); + continue; + }} + float ms = 0; + cudaEventElapsedTime(&ms, s, e); + float avg = ms / iters; + if (avg < best_time) {{ best_time = avg; best_idx = (int)i; }} + }} + cudaEventDestroy(s); cudaEventDestroy(e); + TORCH_CHECK(best_idx >= 0, + "EVT AutoTune: no candidate succeeded for (M,N,K)=(", + ea.M, ",", ea.N, ",", ea.K, ")"); + return best_idx; + }} + + std::vector> configs_; + int best_idx_ = -1; // -1 = not yet autotuned; sticky after first call. + at::Tensor ws_; +}}; + +static EvtAutoTuneRunner& runner() {{ + static EvtAutoTuneRunner R; + return R; +}} + +void evt_matmul_out(at::Tensor A, at::Tensor B, + std::vector extras, + at::Tensor D) {{ + runner()(std::move(A), std::move(B), std::move(extras), std::move(D)); +}} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {{ + m.doc() = "Magi compiler EVT-fused matmul (auto-generated, autotune)"; + m.def("evt_matmul_out", &evt_matmul_out, + "Fused EVT matmul: D = epilogue(A @ B, extras...)", + pybind11::arg("A"), pybind11::arg("B"), + pybind11::arg("extras"), pybind11::arg("D")); + m.def("num_configs", []() {{ return runner().num_configs(); }}); +}} +""" + + +def render_evt_cu( + ir: Store, + a_dtype: str, + b_dtype: str, + cache_key_str: str = "", + b_layout: str = "row", + m_bucket: str = "medium", + alignment_a_bits: int = 128, + alignment_b_bits: int = 128, + alignment_c_bits: int = 128, + arch: str = "sm120", +) -> str: + """Render a complete .cu source for the given EVT IR. + + ``b_layout``: "row" = B is (K, N) RowMajor; "col" = underlying (N, K) weight + read as ColumnMajor. ``m_bucket`` selects the tile-candidate set for autotune. + ``alignment_*_bits``: greedy-picked 128 or 64 to match actual shape divisibility. + ``arch`` accepted for signature parity with sm90 renderer; ignored here. + """ + if b_layout not in ("row", "col"): + raise ValueError(f"b_layout must be 'row' or 'col', got {b_layout!r}") + if m_bucket not in _TILE_CANDIDATES_SM120: + raise ValueError(f"unknown m_bucket {m_bucket!r}; " f"expected one of {list(_TILE_CANDIDATES_SM120)}") + if ( + alignment_a_bits not in _VALID_ALIGN_BITS + or alignment_b_bits not in _VALID_ALIGN_BITS + or alignment_c_bits not in _VALID_ALIGN_BITS + ): + raise ValueError( + f"alignment_*_bits must be one of {_VALID_ALIGN_BITS}; " + f"got A={alignment_a_bits}, B={alignment_b_bits}, C={alignment_c_bits}" + ) + if not isinstance(ir, Store): + raise TypeError("render_evt_cu expects a Store node as root") + del arch + tile_candidate_block = _emit_tile_candidates(m_bucket) + + a_elem = _DTYPE_TO_CUTLASS[a_dtype] + b_elem = _DTYPE_TO_CUTLASS[b_dtype] + c_elem = _DTYPE_TO_CUTLASS[ir.out_dtype] + + emitter = _EvtEmitter(ir) + evt_root = emitter.emit() + + leaves = walk_leaves(ir) + leaf_args: Dict[int, str] = {} + for leaf in leaves: + if not isinstance(leaf, (RowBroadcast, ColBroadcast, AuxLoad)): + continue + elem = _DTYPE_TO_CUTLASS[leaf.dtype] + ptr_expr = f"reinterpret_cast<{elem}*>(a.ptr_extras[{leaf.input_idx}])" + if isinstance(leaf, RowBroadcast): + leaf_args[leaf.input_idx] = f"{{{ptr_expr}, {elem}(0), {{_0{{}}, _1{{}}, int32_t(N)}}}}" + elif isinstance(leaf, ColBroadcast): + leaf_args[leaf.input_idx] = f"{{{ptr_expr}, {elem}(0), {{_1{{}}, _0{{}}, int32_t(M)}}}}" + else: # AuxLoad + stride_expr = f"a.stride_extras[{leaf.input_idx}]" + mn_expr = f"(static_cast(M) * {stride_expr})" + leaf_args[leaf.input_idx] = f"{{{ptr_expr}, {elem}(0), {{{stride_expr}, _1{{}}, {mn_expr}}}}}" + + args_tree = _emit_args_tree(ir.child, leaf_args, indent=8) + + # Dedup by input_idx — same tensor may appear at multiple IR leaves. + extras_validation_lines = [] + extras_ptr_lines = [] + seen_extras: set = set() + extra_leaves = [n for n in leaves if not isinstance(n, Accum)] + n_extras = max((leaf.input_idx for leaf in extra_leaves), default=-1) + 1 + for leaf in extra_leaves: + i = leaf.input_idx + if i in seen_extras: + continue + seen_extras.add(i) + at_dtype = _DTYPE_TO_AT[leaf.dtype] + at_cpp = _DTYPE_TO_AT_CPP[leaf.dtype] + if isinstance(leaf, RowBroadcast): + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].numel() == N, "extras[{i}] must have N elements");') + elif isinstance(leaf, ColBroadcast): + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].numel() == M, "extras[{i}] must have M elements");') + elif isinstance(leaf, AuxLoad): + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].size(0) == M && extras[{i}].size(1) == N,' f' "extras[{i}] must be (M,N)");' + ) + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].stride(1) == 1 && extras[{i}].stride(0) >= N,' + f' "extras[{i}] must be row-major with stride(1)==1 and stride(0)>=N");' + ) + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].scalar_type() == {at_dtype},' f' "extras[{i}] must be {leaf.dtype}");' + ) + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].is_cuda(), "extras[{i}] must be CUDA");') + extras_ptr_lines.append(f" ea.ptr_extras.push_back(static_cast(" f"extras[{i}].data_ptr<{at_cpp}>()));") + extras_ptr_lines.append(f" ea.stride_extras.push_back(static_cast(extras[{i}].stride(0)));") + + extras_validation = "\n".join(extras_validation_lines) if extras_validation_lines else " // no extras" + extras_ptrs = "\n".join(extras_ptr_lines) if extras_ptr_lines else "" + + functor_decls = "\n".join(emitter.functor_decls) if emitter.functor_decls else "// (no custom functors)" + typedef_block = "\n".join(" " + l if l.strip() else l for l in "\n".join(emitter.typedef_lines).split("\n")) + + cutlass_b_layout = "RowMajor" if b_layout == "row" else "ColumnMajor" + if b_layout == "row": + n_dim_expr = "B.size(1)" + stride_b_expr = "N" + b_stride_check = ( + 'TORCH_CHECK(B.stride(1) == 1, "B innermost stride must be 1; got ", B.stride(1));\n' + ' TORCH_CHECK(B.stride(0) >= B.size(1),\n' + ' "B row stride must be >= N; got stride(0)=", B.stride(0), ", N=", B.size(1));' + ) + else: + n_dim_expr = "B.size(0)" + stride_b_expr = "K" + b_stride_check = ( + 'TORCH_CHECK(B.stride(1) == 1, "B innermost stride must be 1; got ", B.stride(1));\n' + ' TORCH_CHECK(B.stride(0) >= B.size(1),\n' + ' "B row stride must be >= K; got stride(0)=", B.stride(0), ", K=", B.size(1));' + ) + + preamble = _KERNEL_PREAMBLE.format( + cache_key=cache_key_str, + functor_decls=functor_decls, + a_elem=a_elem, + b_elem=b_elem, + c_elem=c_elem, + typedef_block=typedef_block, + evt_root_name=evt_root, + b_layout=cutlass_b_layout, + alignment_a_bits=alignment_a_bits, + alignment_b_bits=alignment_b_bits, + alignment_c_bits=alignment_c_bits, + args_tree=args_tree, + stride_b_expr=stride_b_expr, + ) + launcher = _LAUNCHER_TEMPLATE.format( + evt_root_name=evt_root, + args_tree=args_tree, + a_dtype=a_dtype, + b_dtype=b_dtype, + c_dtype=ir.out_dtype, + a_at_dtype=_DTYPE_TO_AT[a_dtype], + b_at_dtype=_DTYPE_TO_AT[b_dtype], + c_at_dtype=_DTYPE_TO_AT[ir.out_dtype], + a_at_cpp=_DTYPE_TO_AT_CPP[a_dtype], + b_at_cpp=_DTYPE_TO_AT_CPP[b_dtype], + c_at_cpp=_DTYPE_TO_AT_CPP[ir.out_dtype], + n_extras=n_extras, + extras_validation=extras_validation, + extras_ptrs=extras_ptrs, + n_dim_expr=n_dim_expr, + stride_b_expr=stride_b_expr, + b_stride_check=b_stride_check, + tile_candidate_block=tile_candidate_block, + ) + return preamble + launcher diff --git a/magi_compiler/passes/piecewise_graph/fusion/sm90/__init__.py b/magi_compiler/passes/piecewise_graph/fusion/sm90/__init__.py new file mode 100644 index 0000000..6fa72d4 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/sm90/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2026 SandAI. All Rights Reserved. diff --git a/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/hopper_dual_gemm/device/sm90_dual_gemm.h b/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/hopper_dual_gemm/device/sm90_dual_gemm.h new file mode 100644 index 0000000..cca6fac --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/hopper_dual_gemm/device/sm90_dual_gemm.h @@ -0,0 +1,507 @@ +// Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// VENDORED from upstream CUTLASS examples on 2026-05-09: +// examples/49_hopper_dual_gemm/device/sm90_dual_gemm.h +// To resync, copy the upstream file verbatim over this one. Don't edit +// in-tree — the swiglu path on top of it is in +// magi_compiler/passes/piecewise_graph/fusion/cutlass_fusion/sm90/ +// cutlass_kernels/swiglu_one_stage.cu and works around any contract quirks +// at the host side, leaving this file as a drop-in upstream copy. +// +// Sm90 DualGemm — device-level wrapper. +// +// Public API mirrors examples/45_dual_gemm/device/dual_gemm.h as closely as +// the SM90 idiom permits, so existing call sites that build on +// `cutlass::gemm::device::DualGemm<...>` migrate to +// `cutlass::gemm::device::Sm90DualGemm<...>` with only the template-parameter +// list changing (TileShape/ClusterShape replace ThreadblockShape/WarpShape/ +// InstructionShape; ArchTag is implicit). +// +// Functional contract: +// +// D2 = epilogue2( A @ B0, A @ B1 ) +// +// Both matmuls accumulate in fp32 (or whatever ElementAccumulator the user +// picks), the binary `epilogue2` (e.g. cutlass::epilogue::thread::SwigluCombine) +// fuses them into a single ElementC output. D0 / D1 are not stored — the +// only currently supported mode is StoreD0 = StoreD1 = false (the same mode +// used by the Sm80 swiglu one-stage example). +// +// Hardware: requires sm_90a (Hopper WGMMA + TMA). The kernel uses a single +// 128-thread warpgroup per CTA, no cluster, non-persistent grid. + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm_coord.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/cluster_launch.hpp" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/arch/mma_sm90_gmma.hpp" +#include "cute/atom/mma_traits_sm90_gmma.hpp" + +#include "../kernel/sm90_dual_gemm_kernel.hpp" +// VENDORED CHANGE: upstream points at "../../45_dual_gemm/dual_gemm_common.h" +// (examples-relative). We co-located the file under our 49_hopper_dual_gemm/ +// to make the vendored tree self-contained. Resync: leave this `#include` as +// `"../dual_gemm_common.h"` even if upstream changes its path. +#include "../dual_gemm_common.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +namespace sm90_dual_gemm_detail { + +using namespace cute; + +// --------------------------------------------------------------------------- +// CUTLASS 2.x layout tag → cute Major / stride for SM90 GMMA. +// +// CUTLASS 2.x convention (operand-aware): +// A (M×K): RowMajor → K contig (TN-A) ColMajor → M contig (NT-A) +// B (K×N): RowMajor → N contig (NT-B) ColMajor → K contig (TN-B) +// C/D (M×N): RowMajor → N contig ColMajor → M contig +// +// The SM90 kernel views B as cute shape (N, K) (CUTLASS 3.x convention), +// so for operand B the relationship between the layout tag and which mode +// is contiguous is *flipped* relative to A and C. +// +// The Tag below selects between operand semantics; for each operand we +// derive a uniform cute Stride pair (int64_t, _1) (or (_1, int64_t)) plus +// the corresponding GMMA::Major. +// --------------------------------------------------------------------------- + +enum class Operand { A, B, C }; + +// Which mode of the (mode0, mode1) cute tensor is contiguous? +// K_contig=true → cute stride = (int64_t, _1) (K contiguous, GMMA::Major::K) +// K_contig=false → cute stride = (_1, int64_t) (MN contiguous, GMMA::Major::MN) +// +// For A (M, K): RowMajor=K_contig=true, ColMajor=K_contig=false +// For B (N, K): RowMajor=K_contig=false, ColMajor=K_contig=true (flipped — see above) +// For C (M, N): treat the K-contig flag as N-contig → RowMajor=true, ColMajor=false +template +struct LayoutTraits; + +// ---- A operand +template <> +struct LayoutTraits { + using Stride = cute::Stride; + static constexpr cute::GMMA::Major Major = cute::GMMA::Major::K; + template CUTE_HOST_DEVICE static + Stride make(T ld) { return cute::make_stride(int64_t(ld), cute::_1{}); } +}; +template <> +struct LayoutTraits { + using Stride = cute::Stride; + static constexpr cute::GMMA::Major Major = cute::GMMA::Major::MN; + template CUTE_HOST_DEVICE static + Stride make(T ld) { return cute::make_stride(cute::_1{}, int64_t(ld)); } +}; + +// ---- B operand (note: layout-tag sense is flipped vs A because +// cute view is (N, K) but CUTLASS-2.x tag is "B as K×N") +template <> +struct LayoutTraits { + // CUTLASS 2.x "RowMajor B" = N contig in (K, N) = MN-contig in our (N, K) + using Stride = cute::Stride; + static constexpr cute::GMMA::Major Major = cute::GMMA::Major::MN; + template CUTE_HOST_DEVICE static + Stride make(T ld) { return cute::make_stride(cute::_1{}, int64_t(ld)); } +}; +template <> +struct LayoutTraits { + // CUTLASS 2.x "ColumnMajor B" = K contig in (K, N) = K-contig in our (N, K) + using Stride = cute::Stride; + static constexpr cute::GMMA::Major Major = cute::GMMA::Major::K; + template CUTE_HOST_DEVICE static + Stride make(T ld) { return cute::make_stride(int64_t(ld), cute::_1{}); } +}; + +// ---- C/D operand (M, N): same mapping as A but interpreting "K-contig" as N-contig. +template <> +struct LayoutTraits { + using Stride = cute::Stride; + static constexpr cute::GMMA::Major Major = cute::GMMA::Major::K; // unused for C + template CUTE_HOST_DEVICE static + Stride make(T ld) { return cute::make_stride(int64_t(ld), cute::_1{}); } +}; +template <> +struct LayoutTraits { + using Stride = cute::Stride; + static constexpr cute::GMMA::Major Major = cute::GMMA::Major::MN; // unused for C + template CUTE_HOST_DEVICE static + Stride make(T ld) { return cute::make_stride(cute::_1{}, int64_t(ld)); } +}; + +} // namespace sm90_dual_gemm_detail + +//////////////////////////////////////////////////////////////////////////////// +// Sm90DualGemm — public template +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB0_, + typename LayoutB1_, + typename ElementC_, + typename LayoutC_, + typename ElementAccumulator_, + /// CTA tile shape: cute::Shape<_M, _N, _K> (e.g. <_128,_128,_64>) + typename TileShape_, + /// Per-GEMM linear-combination ops (only used when StoreD0/D1 are true). + typename EpilogueOutputOp0_, + typename EpilogueOutputOp1_, + /// Binary combine functor (e.g. cutlass::epilogue::thread::SwigluCombine). + typename EpilogueOutputOp2_, + /// Pipeline stages. Defaults to 3 — bumping higher needs more dyn-smem. + int Stages = 3, + /// Reserved for parity with the Sm80 DualGemm — must be false today. + bool StoreD0 = false, + bool StoreD1 = false, + /// Reserved for parity with the Sm80 DualGemm — must be false today. + bool SplitKSerial = false, + int AlignmentA = 8, + int AlignmentB = 8> +class Sm90DualGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB0 = LayoutB0_; + using LayoutB1 = LayoutB1_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementAccumulator = ElementAccumulator_; + using TileShape = TileShape_; + using EpilogueOutputOp0 = EpilogueOutputOp0_; + using EpilogueOutputOp1 = EpilogueOutputOp1_; + using EpilogueOutputOp2 = EpilogueOutputOp2_; + + static constexpr int kStages = Stages; + static constexpr bool kStoreD0 = StoreD0; + static constexpr bool kStoreD1 = StoreD1; + static constexpr bool kSplitKSerial = SplitKSerial; + static constexpr int kAlignmentA = AlignmentA; + static constexpr int kAlignmentB = AlignmentB; + + static_assert(!StoreD0, "Sm90DualGemm: StoreD0=true is not yet implemented (D0 is consumed in registers)."); + static_assert(!StoreD1, "Sm90DualGemm: StoreD1=true is not yet implemented (D1 is consumed in registers)."); + static_assert(!SplitKSerial, "Sm90DualGemm: split-K is not yet implemented."); + + // Same TensorRef typedefs as the Sm80 DualGemm wrapper for API parity. + using TensorRefA = TensorRef; + using TensorRefB0 = TensorRef; + using TensorRefB1 = TensorRef; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + static_assert(cute::is_static::value, "TileShape must be a static cute::Shape."); + static constexpr int kBlockM = cute::size<0>(TileShape{}); + static constexpr int kBlockN = cute::size<1>(TileShape{}); + static constexpr int kBlockK = cute::size<2>(TileShape{}); + + static_assert(kBlockM % 64 == 0, "BLK_M must be a multiple of 64 (WGMMA constraint)."); + + // ---------------------- cute-side type setup ---------------------- + private: + + using TraitsA = sm90_dual_gemm_detail::LayoutTraits; + using TraitsB0 = sm90_dual_gemm_detail::LayoutTraits; + using TraitsB1 = sm90_dual_gemm_detail::LayoutTraits; + using TraitsC = sm90_dual_gemm_detail::LayoutTraits; + + static constexpr cute::GMMA::Major kMajorA = TraitsA::Major; + static constexpr cute::GMMA::Major kMajorB0 = TraitsB0::Major; + static constexpr cute::GMMA::Major kMajorB1 = TraitsB1::Major; + static_assert(kMajorB0 == kMajorB1, + "B0 and B1 must share the same Major (= same K-major / MN-major orientation)."); + + using StrideA = typename TraitsA::Stride; + using StrideB = typename TraitsB0::Stride; + using StrideD = typename TraitsC::Stride; + + // Cooperative warpgroup count. Splits the BLK_M dim of each CTA tile across + // this many consumer warpgroups (each runs 128 threads), so a 128x128 tile + // with 2 wgs has each wg owning 64x128 of the accumulator. This caps the + // dual-acc per-thread register pressure regardless of BLK_M. + static constexpr int kNumConsumerWgs = + (kBlockM >= 128) ? 2 : 1; // M ≥ 128 ⇒ cooperative (64 M per wg) + + // The cute SS atom selector picks the WGMMA atom for the *single-wg view* + // of the tile: it expects size<0>(TileShape) == kBlockM / kNumConsumerWgs + // (the per-wg M sub-tile). We construct a synthetic per-wg tile shape for + // the selector, then re-tile across wgs via the TiledMma layout below. + using PerWgTileShape = cute::Shape< + cute::Int, cute::Int, cute::Int>; + using GmmaAtom = decltype(cute::SM90::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, PerWgTileShape, kMajorA, kMajorB0>()); + // Cooperative TiledMma: replicate the atom kNumConsumerWgs× along M. + using TiledMma = decltype(cute::make_tiled_mma( + GmmaAtom{}, + cute::Layout, cute::_1, cute::_1>>{})); + + // Smem layout atoms — per-Major canonical SW128 atoms. + using SmemLayoutAtomA = cute::conditional_t< + kMajorA == cute::GMMA::Major::K, + cute::GMMA::Layout_K_SW128_Atom, + cute::GMMA::Layout_MN_SW128_Atom>; + using SmemLayoutAtomB = cute::conditional_t< + kMajorB0 == cute::GMMA::Major::K, + cute::GMMA::Layout_K_SW128_Atom, + cute::GMMA::Layout_MN_SW128_Atom>; + + using PipeStages_ = cute::Int; + using SmemLayoutA = decltype(cute::tile_to_shape( + SmemLayoutAtomA{}, + cute::make_shape(cute::Int{}, cute::Int{}, PipeStages_{}))); + using SmemLayoutB = decltype(cute::tile_to_shape( + SmemLayoutAtomB{}, + cute::make_shape(cute::Int{}, cute::Int{}, PipeStages_{}))); + + // TMA atom decltypes — the actual TMA atoms have to be constructed on host + // (they bake the gmem tensor's runtime shape into a copy descriptor), so + // we only use these for the `decltype(...) const` kernel-template parameter. + using TmaA = decltype(cute::make_tma_atom( + cute::SM90_TMA_LOAD{}, + cute::make_tensor(static_cast(nullptr), + cute::make_shape(int(0), int(0)), + StrideA{}), + SmemLayoutA{}(cute::_, cute::_, 0), + cute::make_shape(cute::Int{}, cute::Int{}))); + using TmaB = decltype(cute::make_tma_atom( + cute::SM90_TMA_LOAD{}, + cute::make_tensor(static_cast(nullptr), + cute::make_shape(int(0), int(0)), + StrideB{}), + SmemLayoutB{}(cute::_, cute::_, 0), + cute::make_shape(cute::Int{}, cute::Int{}))); + + using SharedStorage = kernel::sm90_dual_gemm_detail::DualGemmSharedStorage< + ElementA, ElementB, SmemLayoutA, SmemLayoutB>; + + static constexpr int kSmemBytes = static_cast(sizeof(SharedStorage)); + + public: + + // -------------------------- Arguments -------------------------- + struct Arguments { + DualGemmMode mode; + GemmCoord problem_size; + + TensorRefA ref_A0; + TensorRefB0 ref_B0; + TensorRefC ref_C0; + TensorRefD ref_D0; + TensorRefB1 ref_B1; + TensorRefC ref_C1; + TensorRefD ref_D1; + TensorRefD ref_D2; + + typename EpilogueOutputOp0::Params epilogue0; + typename EpilogueOutputOp1::Params epilogue1; + typename EpilogueOutputOp2::Params epilogue2; + + int split_k_slices = 1; + int batch_count = 1; + int64_t batch_stride_A = 0; + int64_t batch_stride_B0 = 0; + int64_t batch_stride_B1 = 0; + int64_t batch_stride_C = 0; + int64_t batch_stride_D = 0; + + CUTLASS_HOST_DEVICE Arguments() : problem_size(0, 0, 0) {} + + CUTLASS_HOST_DEVICE Arguments( + DualGemmMode mode_, + GemmCoord problem_size_, + TensorRefA ref_A0_, + TensorRefB0 ref_B0_, + TensorRefC ref_C0_, + TensorRefD ref_D0_, + TensorRefB1 ref_B1_, + TensorRefC ref_C1_, + TensorRefD ref_D1_, + TensorRefD ref_D2_, + typename EpilogueOutputOp0::Params epilogue0_ = typename EpilogueOutputOp0::Params(), + typename EpilogueOutputOp1::Params epilogue1_ = typename EpilogueOutputOp1::Params(), + typename EpilogueOutputOp2::Params epilogue2_ = typename EpilogueOutputOp2::Params(), + int split_k_slices_ = 1, + int batch_count_ = 1, + int64_t batch_stride_A_ = 0, + int64_t batch_stride_B0_ = 0, + int64_t batch_stride_B1_ = 0, + int64_t batch_stride_C_ = 0, + int64_t batch_stride_D_ = 0) + : mode(mode_), problem_size(problem_size_), + ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_), ref_D0(ref_D0_), + ref_B1(ref_B1_), ref_C1(ref_C1_), ref_D1(ref_D1_), ref_D2(ref_D2_), + epilogue0(epilogue0_), epilogue1(epilogue1_), epilogue2(epilogue2_), + split_k_slices(split_k_slices_), + batch_count(batch_count_), + batch_stride_A(batch_stride_A_), + batch_stride_B0(batch_stride_B0_), + batch_stride_B1(batch_stride_B1_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_) {} + }; + + private: + // Captured inside `initialize` for `run` to use later. + Arguments args_{}; + bool initialized_ = false; + + public: + + Sm90DualGemm() = default; + + static Status can_implement(Arguments const& args) { + if (args.mode != DualGemmMode::kGemm) { + return Status::kErrorInvalidProblem; + } + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + if (args.batch_count != 1) { + return Status::kErrorInvalidProblem; + } + if (args.problem_size.m() <= 0 || args.problem_size.n() <= 0 || args.problem_size.k() <= 0) { + return Status::kErrorInvalidProblem; + } + if (args.ref_D2.data() == nullptr) { + return Status::kErrorInvalidProblem; + } + // D0/D1 must be null when StoreD0/D1 is false (matches Sm80 DualGemm contract). + if ((kStoreD0 != (args.ref_D0.data() != nullptr)) || + (kStoreD1 != (args.ref_D1.data() != nullptr))) { + return Status::kErrorInvalidProblem; + } + // K alignment: must be a multiple of TMA's 128-bit minimum (= 8 bf16 elts). + constexpr int min_k_align = 128 / cutlass::sizeof_bits::value; + if (args.problem_size.k() % min_k_align != 0) { + return Status::kErrorInvalidProblem; + } + return Status::kSuccess; + } + + static size_t get_workspace_size(Arguments const& /*args*/) { + return 0; + } + + Status initialize(Arguments const& args, void* /*workspace*/ = nullptr, + cudaStream_t /*stream*/ = nullptr) { + Status s = can_implement(args); + if (s != Status::kSuccess) return s; + args_ = args; + initialized_ = true; + return Status::kSuccess; + } + + Status update(Arguments const& args, void* /*workspace*/ = nullptr) { + Status s = can_implement(args); + if (s != Status::kSuccess) return s; + args_ = args; + return Status::kSuccess; + } + + Status run(cudaStream_t stream = nullptr) { + if (!initialized_) return Status::kErrorInternal; + + int const M = args_.problem_size.m(); + int const N = args_.problem_size.n(); + int const K = args_.problem_size.k(); + + // Stride conversion: TensorRef<...,LayoutX>::layout().stride() carries the + // leading dim, which is what cute needs. + auto dA = TraitsA ::make(args_.ref_A0.stride(0)); + auto dB0 = TraitsB0::make(args_.ref_B0.stride(0)); + auto dB1 = TraitsB1::make(args_.ref_B1.stride(0)); + auto dD2 = TraitsC ::make(args_.ref_D2.stride(0)); + + auto* ptrA = args_.ref_A0.data(); + auto* ptrB0 = args_.ref_B0.data(); + auto* ptrB1 = args_.ref_B1.data(); + auto* ptrD2 = args_.ref_D2.data(); + + // Build TMA atoms host-side (they capture the full gmem-shape descriptor). + auto mA = cute::make_tensor(ptrA, cute::make_shape(M, K), dA ); + auto mB0 = cute::make_tensor(ptrB0, cute::make_shape(N, K), dB0); + auto mB1 = cute::make_tensor(ptrB1, cute::make_shape(N, K), dB1); + + auto tmaA = cute::make_tma_atom(cute::SM90_TMA_LOAD{}, mA, + SmemLayoutA{}(cute::_, cute::_, 0), + cute::make_shape(cute::Int{}, cute::Int{})); + auto tmaB0 = cute::make_tma_atom(cute::SM90_TMA_LOAD{}, mB0, + SmemLayoutB{}(cute::_, cute::_, 0), + cute::make_shape(cute::Int{}, cute::Int{})); + auto tmaB1 = cute::make_tma_atom(cute::SM90_TMA_LOAD{}, mB1, + SmemLayoutB{}(cute::_, cute::_, 0), + cute::make_shape(cute::Int{}, cute::Int{})); + + typename EpilogueOutputOp2::Params op2_params = args_.epilogue2; + EpilogueOutputOp2 combine_op(op2_params); + + auto cta_tiler = TileShape{}; + auto prob_shape = cute::make_shape(M, N, K); + + auto* kernel_ptr = &kernel::sm90_dual_gemm_detail::sm90_dual_gemm_device< + decltype(prob_shape), TileShape, + ElementA, SmemLayoutA, decltype(tmaA), + ElementB, SmemLayoutB, decltype(tmaB0), + ElementC, decltype(dD2), + TiledMma, EpilogueOutputOp2>; + + cudaError_t err = cudaFuncSetAttribute( + reinterpret_cast(kernel_ptr), + cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes); + if (err != cudaSuccess) return Status::kErrorInternal; + + dim3 grid(static_cast((M + kBlockM - 1) / kBlockM), + static_cast((N + kBlockN - 1) / kBlockN), + 1); + // 1 producer warpgroup (128 threads, only lane 0 of warp 0 is live) + // + kNumConsumerWgs consumer warpgroups (128 threads each). + dim3 block(static_cast(128 * (kNumConsumerWgs + 1)), 1, 1); + dim3 cluster(1, 1, 1); + + cutlass::ClusterLaunchParams launch_params{grid, block, cluster, kSmemBytes, stream}; + cutlass::Status st = cutlass::launch_kernel_on_cluster( + launch_params, + reinterpret_cast(kernel_ptr), + prob_shape, cta_tiler, + ptrA, tmaA, + ptrB0, tmaB0, + ptrB1, tmaB1, + ptrD2, dD2, + TiledMma{}, + combine_op); + return st; + } + + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + Status operator()(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr) { + Status s = initialize(args, workspace, stream); + if (s == Status::kSuccess) s = run(stream); + return s; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass diff --git a/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/hopper_dual_gemm/dual_gemm_common.h b/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/hopper_dual_gemm/dual_gemm_common.h new file mode 100644 index 0000000..25a083a --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/hopper_dual_gemm/dual_gemm_common.h @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines common types used for all DualGemm operators. + + VENDORED from upstream CUTLASS examples on 2026-05-09: + examples/45_dual_gemm/dual_gemm_common.h + Co-located with the Sm90DualGemm headers in this directory because the + upstream sm90_dual_gemm.h transitively includes it. To resync, copy the + upstream file verbatim over this one. +*/ +#pragma once + +namespace cutlass { +namespace gemm { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class DualGemmMode { + kGemm, + kBatched, + kInvalid +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/hopper_dual_gemm/kernel/sm90_dual_gemm_kernel.hpp b/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/hopper_dual_gemm/kernel/sm90_dual_gemm_kernel.hpp new file mode 100644 index 0000000..8100588 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/hopper_dual_gemm/kernel/sm90_dual_gemm_kernel.hpp @@ -0,0 +1,389 @@ +// Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// VENDORED from upstream CUTLASS examples on 2026-05-09: +// examples/49_hopper_dual_gemm/kernel/sm90_dual_gemm_kernel.hpp +// To resync, copy the upstream file verbatim over this one. +// +// Sm90 DualGemm kernel — fused dual-WGMMA producer/consumer pipeline, +// warp-specialized. +// +// Computes (in a single kernel launch): +// +// Acc0 = A @ B0 +// Acc1 = A @ B1 +// D2 = combine(Acc0, Acc1) +// +// A is loaded once and consumed by both WGMMA chains in the same K-stage, +// so the gate / linear matmuls share A's smem traffic — the whole point +// of DualGemm. Neither D0 nor D1 ever spills to HBM. +// +// Architecture +// ------------ +// Three warpgroups per CTA (1 producer + N consumer), no clusters, +// non-persistent grid: +// +// * Producer warpgroup (warps 0-3, threads 0-127): only lane 0 of warp 0 +// is "live"; the rest call setmaxnreg.dec<40> and exit. The live thread +// issues TMA loads for A + B0 + B1 of the next K-stage and arrives on +// a per-stage producer barrier. Reg-deallocated to <=40 to free SM +// registers for the consumers. +// +// * Consumer warpgroups (warps 4..N+3, threads 128..128*(N+1)-1): each +// wg does setmaxnreg.inc<240> and runs two WGMMA chains that share +// the same A smem buffer (the TiledMma's _N_-warpgroup M-tiling splits +// A's M dim between them). Each wg owns its own accumulator pair +// (acc0, acc1) and emits its M-sub-tile of D2 via predicated STG. +// +// The number of consumer warpgroups is determined by the TiledMma's +// thread-count: `NumConsumerWgs = size(TiledMma{}) / 128`. The user +// configures this on the host side via the cooperative make_tiled_mma +// (e.g. `Layout<_2,_1,_1>` doubles M-side compute per CTA). +// +// K-pipeline +// ---------- +// Two barriers per stage: +// +// producer_mbar[s] : ClusterTransactionBarrier +// Producer arrives once after `cp.async.bulk` issue +// (3 TMAs share one barrier, transaction-bytes count +// all three). Consumer waits before issuing WGMMA. +// +// consumer_mbar[s] : ClusterBarrier +// Consumer arrives 128× after `warpgroup_wait` releases +// the stage. Producer waits before issuing the next +// TMA into the same stage. +// +// Pipelining is across K-tiles: the consumer issues a new WGMMA batch +// then immediately calls `warpgroup_wait()` which keeps +// K_PIPE_MMAS batches in flight. With K_PIPE_MMAS=1 the loop-carried +// chain is kept full and the next-stage barrier wait + next WGMMA can +// overlap with the trailing WGMMA's tensor-core latency. +// +// Bounds +// ------ +// M and N can be arbitrary. TMA naturally zero-fills out-of-bound loads +// (so accumulators stay correct), and stores are predicated per (m, n) +// coordinate. + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/device_kernel.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/arch/mma_sm90.h" + +#include "cute/tensor.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/algorithm/functional.hpp" + +namespace cutlass { +namespace gemm { +namespace kernel { + +namespace sm90_dual_gemm_detail { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////// +// SharedStorage for one Sm90 dual-GEMM CTA. +// +// Three pipelined smem buffers (A, B0, B1), one producer barrier per stage +// (TMA-arrival), one consumer barrier per stage (MMA-completion-release). +//////////////////////////////////////////////////////////////////////////////// + +template +struct DualGemmSharedStorage { + static constexpr int K_PIPE_MAX = size<2>(SmemLayoutA{}); + + alignas(128) cute::ArrayEngine> sA; + alignas(128) cute::ArrayEngine> sB0; + alignas(128) cute::ArrayEngine> sB1; + + alignas(16) uint64_t producer_mbar[K_PIPE_MAX]; + alignas(16) uint64_t consumer_mbar[K_PIPE_MAX]; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Kernel +// +// Threading: 256 threads / CTA = 2 warpgroups +// - wg 0 (threads 0-127): producer (only lane 0 of warp 0 is live) +// - wg 1 (threads 128-255): consumer (full WGMMA + epilogue) +//////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape, + class CtaTiler, + class ElementA, class SmemLayoutA, class TmaA, + class ElementB, class SmemLayoutB, class TmaB, + class ElementC, class CStride, class TiledMma, + class CombineOp> +__global__ +__launch_bounds__(/*MaxThreads=*/(decltype(size(TiledMma{}))::value + 128), 1) +void +sm90_dual_gemm_device( + ProblemShape shape_MNK, + CtaTiler cta_tiler, + ElementA const* /*ptr_A — only here so TMA atom can be constructed host-side*/, + CUTLASS_GRID_CONSTANT TmaA const tma_a, + ElementB const* /*ptr_B0*/, + CUTLASS_GRID_CONSTANT TmaB const tma_b0, + ElementB const* /*ptr_B1*/, + CUTLASS_GRID_CONSTANT TmaB const tma_b1, + ElementC* ptr_D2, CStride dD2, + TiledMma mma, + CombineOp combine_op) +{ +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + using namespace cute; + + // ---------- preconditions ---------- + CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(decltype(size(TiledMma{}))::value % 128 == 0, + "Sm90 dual gemm: TiledMma thread-count must be a multiple of " + "128 (one consumer warpgroup per 128 threads)."); + + constexpr int kNumConsumerWgs = decltype(size(TiledMma{}))::value / 128; + constexpr int kConsumerThreads = 128 * kNumConsumerWgs; + constexpr int kProducerThreads = 128; + constexpr int kBarrierArvCount = kConsumerThreads; + + // ---------- gmem tensors ---------- + auto [M, N, K] = shape_MNK; + Tensor mA = tma_a .get_tma_tensor(make_shape(M, K)); + Tensor mB0 = tma_b0.get_tma_tensor(make_shape(N, K)); + Tensor mB1 = tma_b1.get_tma_tensor(make_shape(N, K)); + Tensor mD2 = make_tensor(make_gmem_ptr(ptr_D2), make_shape(M, N), dD2); + + auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); + Tensor gB0 = local_tile(mB0, cta_tiler, cta_coord, Step< X,_1,_1>{}); + Tensor gB1 = local_tile(mB1, cta_tiler, cta_coord, Step< X,_1,_1>{}); + Tensor gD2 = local_tile(mD2, cta_tiler, cta_coord, Step<_1,_1, X>{}); + + // ---------- smem tensors ---------- + extern __shared__ char smem_buf[]; + using Storage = DualGemmSharedStorage; + Storage& storage = *reinterpret_cast(smem_buf); + + Tensor sA = make_tensor(make_smem_ptr(storage.sA .begin()), SmemLayoutA{}); + Tensor sB0 = make_tensor(make_smem_ptr(storage.sB0.begin()), SmemLayoutB{}); + Tensor sB1 = make_tensor(make_smem_ptr(storage.sB1.begin()), SmemLayoutB{}); + + // ---------- TMA partitioning ---------- + auto [tAgA, tAsA ] = tma_partition(tma_a , Int<0>{}, Layout<_1>{}, + group_modes<0,2>(sA ), group_modes<0,2>(gA )); + auto [tBgB0, tBsB0] = tma_partition(tma_b0, Int<0>{}, Layout<_1>{}, + group_modes<0,2>(sB0), group_modes<0,2>(gB0)); + auto [tBgB1, tBsB1] = tma_partition(tma_b1, Int<0>{}, Layout<_1>{}, + group_modes<0,2>(sB1), group_modes<0,2>(gB1)); + + constexpr uint32_t tma_transaction_bytes = + static_cast(sizeof(make_tensor_like(tensor<0>(tAsA))) + + sizeof(make_tensor_like(tensor<0>(tBsB0))) + + sizeof(make_tensor_like(tensor<0>(tBsB1)))); + + constexpr int K_PIPE_MAX = Storage::K_PIPE_MAX; + constexpr int K_PIPE_MMAS = 1; + + int k_tile_count = size<1>(tAgA); + + // ---------- warpgroup role ---------- + int thr_idx = threadIdx.x; + int warp_idx = cutlass::canonical_warp_idx_sync(); + // wg_idx == 0 → producer warpgroup + // wg_idx == 1..N → consumer warpgroup #(wg_idx-1) of the cooperative pair/triple/... + int wg_idx = thr_idx / 128; + int cons_thr_idx = thr_idx - 128; // [0, kConsumerThreads) for consumer wgs + + using ProducerBar = cutlass::arch::ClusterTransactionBarrier; + using ConsumerBar = cutlass::arch::ClusterBarrier; + + // ---------- barrier init (one thread total) ---------- + if (warp_idx == 0 && cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < K_PIPE_MAX; ++p) { + ProducerBar::init(&storage.producer_mbar[p], 1); + ConsumerBar::init(&storage.consumer_mbar[p], kBarrierArvCount); + } + } + // Make barrier inits visible to all threads in the CTA before they start + // consuming them. + __syncthreads(); + + // ============================================================================ + // Producer warpgroup + // ============================================================================ + if (wg_idx == 0) { + cutlass::arch::warpgroup_reg_dealloc<40>(); + + // Inactive lanes / warps in the producer wg exit early after reg-dealloc. + // Only lane 0 of warp 0 issues TMAs. + if (warp_idx != 0) return; + if (!cute::elect_one_sync()) return; + + // Prefetch up to K_PIPE_MAX stages without waiting — those are the + // initial fills that the consumer hasn't yet reached. State advance is + // done implicitly by issuing into stages 0..prefetch_count-1. + int const prefetch_count = + (k_tile_count < K_PIPE_MAX) ? k_tile_count : K_PIPE_MAX; + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < prefetch_count; ++p) { + ProducerBar::arrive_and_expect_tx(&storage.producer_mbar[p], + tma_transaction_bytes); + copy(tma_a .with(storage.producer_mbar[p]), + tAgA (_, p), tAsA (_, p)); + copy(tma_b0.with(storage.producer_mbar[p]), + tBgB0(_, p), tBsB0(_, p)); + copy(tma_b1.with(storage.producer_mbar[p]), + tBgB1(_, p), tBsB1(_, p)); + } + + // Steady-state main loop. Each iteration: wait for consumer to release + // the next stage, then re-arm the producer barrier and issue a fresh + // TMA into it. write_phase starts at 0 (matching the initial parity + // of consumer_mbar) and flips on every wrap of write_pipe. + int write_pipe = 0; + uint32_t write_phase = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (int k = K_PIPE_MAX; k < k_tile_count; ++k) { + ConsumerBar::wait(&storage.consumer_mbar[write_pipe], write_phase); + + ProducerBar::arrive_and_expect_tx(&storage.producer_mbar[write_pipe], + tma_transaction_bytes); + copy(tma_a .with(storage.producer_mbar[write_pipe]), + tAgA (_, k), tAsA (_, write_pipe)); + copy(tma_b0.with(storage.producer_mbar[write_pipe]), + tBgB0(_, k), tBsB0(_, write_pipe)); + copy(tma_b1.with(storage.producer_mbar[write_pipe]), + tBgB1(_, k), tBsB1(_, write_pipe)); + + ++write_pipe; + if (write_pipe == K_PIPE_MAX) { + write_pipe = 0; + write_phase ^= 1; + } + } + return; + } + + // ============================================================================ + // Consumer warpgroup(s) — cooperative when kNumConsumerWgs > 1 + // ============================================================================ + // Register budget: SM has 64K regs total. 1 producer wg × 40 + N consumer + // wgs × R must satisfy 40 + N·R ≤ 65536 / 128 = 512. + // N=1 ⇒ R ≤ 472, pick 240 (matches CUTLASS pingpong) + // N=2 ⇒ R ≤ 236, pick 232 (cooperative; matches CUTLASS cooperative) + if constexpr (kNumConsumerWgs == 1) { + cutlass::arch::warpgroup_reg_alloc<240>(); + } else { + cutlass::arch::warpgroup_reg_alloc<232>(); + } + + // For a cooperative TiledMma whose layout spans multiple warpgroups, the + // thread slice must be queried with the *flattened* index across the math + // wgs (0 .. kConsumerThreads-1). Each math wg's threads naturally cover + // its sub-tile of the (BLK_M, BLK_N) accumulator. + ThrMMA thr_mma = mma.get_thread_slice(cons_thr_idx); + Tensor tCsA = thr_mma.partition_A(sA ); + Tensor tCsB0 = thr_mma.partition_B(sB0); + Tensor tCsB1 = thr_mma.partition_B(sB1); + + Tensor tCgC = thr_mma.partition_C(gD2); + Tensor tCrC0 = thr_mma.make_fragment_C(tCgC); + Tensor tCrC1 = thr_mma.make_fragment_C(tCgC); + clear(tCrC0); + clear(tCrC1); + + Tensor tCrA = thr_mma.make_fragment_A(tCsA); + Tensor tCrB0 = thr_mma.make_fragment_B(tCsB0); + Tensor tCrB1 = thr_mma.make_fragment_B(tCsB1); + + int read_pipe = 0; + uint32_t read_phase = 0; + int release_pipe = 0; + uint32_t release_phase = 0; + + // ---------- Prologue: queue K_PIPE_MMAS WGMMA batches without releasing ---- + int prologue_count = (k_tile_count < K_PIPE_MMAS) ? k_tile_count : K_PIPE_MMAS; + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < prologue_count; ++p) { + ProducerBar::wait(&storage.producer_mbar[read_pipe], read_phase); + + cute::warpgroup_arrive(); + cute::gemm(mma, tCrA(_,_,_,read_pipe), tCrB0(_,_,_,read_pipe), tCrC0); + cute::gemm(mma, tCrA(_,_,_,read_pipe), tCrB1(_,_,_,read_pipe), tCrC1); + cute::warpgroup_commit_batch(); + + ++read_pipe; + if (read_pipe == K_PIPE_MAX) { read_pipe = 0; read_phase ^= 1; } + } + + // ---------- Mainloop: issue, wait for K_PIPE_MMAS-old batch, release -------- + int mainloop_count = k_tile_count - prologue_count; + CUTLASS_PRAGMA_NO_UNROLL + for (int k = 0; k < mainloop_count; ++k) { + ProducerBar::wait(&storage.producer_mbar[read_pipe], read_phase); + + cute::warpgroup_arrive(); + cute::gemm(mma, tCrA(_,_,_,read_pipe), tCrB0(_,_,_,read_pipe), tCrC0); + cute::gemm(mma, tCrA(_,_,_,read_pipe), tCrB1(_,_,_,read_pipe), tCrC1); + cute::warpgroup_commit_batch(); + + cute::warpgroup_wait(); + + ConsumerBar::arrive(&storage.consumer_mbar[release_pipe]); + + ++read_pipe; + if (read_pipe == K_PIPE_MAX) { read_pipe = 0; read_phase ^= 1; } + ++release_pipe; + if (release_pipe == K_PIPE_MAX) { release_pipe = 0; release_phase ^= 1; } + } + + // ---------- Drain remaining in-flight WGMMAs and release their stages ------ + cute::warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < prologue_count; ++p) { + ConsumerBar::arrive(&storage.consumer_mbar[release_pipe]); + ++release_pipe; + if (release_pipe == K_PIPE_MAX) { release_pipe = 0; release_phase ^= 1; } + } + + // ---------- Epilogue: combine (acc0, acc1) and predicate-store -------------- + Tensor cD2 = make_identity_tensor(make_shape(size<0>(gD2), size<1>(gD2))); + Tensor tCcD = thr_mma.partition_C(cD2); + + int const m_offset = blockIdx.x * size<0>(gD2); + int const n_offset = blockIdx.y * size<1>(gD2); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrC0); ++i) { + auto coord = tCcD(i); + int m_g = m_offset + get<0>(coord); + int n_g = n_offset + get<1>(coord); + if (m_g < M && n_g < N) { + ElementC c0 = static_cast(tCrC0(i)); + ElementC c1 = static_cast(tCrC1(i)); + tCgC(i) = combine_op(c0, c1); + } + } +#endif // CUTLASS_ARCH_MMA_SM90_SUPPORTED +} + +} // namespace sm90_dual_gemm_detail + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/swiglu_one_stage.cu b/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/swiglu_one_stage.cu new file mode 100644 index 0000000..cb2b907 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/sm90/cutlass_kernels/swiglu_one_stage.cu @@ -0,0 +1,417 @@ +// Copyright (c) 2026 SandAI. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-kernel fully-fused swiglu on Hopper (sm_90a) using the vendored +// Sm90 DualGemm (TMA + WGMMA, warp-specialized cooperative consumer +// warpgroups). User contract is byte-for-byte identical to the SM80 +// sibling at ../../sm80/cutlass_kernels/swiglu_one_stage.cu — same Python +// signature, same B gate/linear interleaved layout (ldB = 2K col-major +// view), same SwArgs shape, same stride-based input checks. +// +// D = swiglu(A @ B.T) +// +// A : (M, K) bf16 row-major +// B : (N, K) bf16 row-major (torch.nn.Linear weight convention; N even) +// D : (M, N/2) bf16 row-major (strided view of (M, ldd) host-padded buffer) +// +// AUTOTUNE: at first call per (M, N, K) tuple the runner times every +// registered (TileShape, Stages) candidate and caches the fastest one. The +// candidate set targets H100's ~228 KiB dynamic-smem budget; per-stage smem +// for Sm90DualGemm = (BM + 2*BN) * BK * 2 (bf16) * stages. +// +// Built by magi_compiler/passes/piecewise_graph/fusion/cutlass_fusion/ +// evt_runtime.py::_compile_swiglu_dual when the live device's compute +// capability is sm_90; everything else routes to the SM80 sibling. + +#include +#include +#include + +#include +#include + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" + +// Vendored at cutlass_kernels/hopper_dual_gemm/. Resolved by adding +// cutlass_kernels/ itself to nvcc's extra_include_paths in evt_runtime.py. +#include "hopper_dual_gemm/device/sm90_dual_gemm.h" +#include "swiglu_combine.h" + +//////////////////////////////////////////////////////////////////////////////// +// Data types +//////////////////////////////////////////////////////////////////////////////// + +using ElementA = cutlass::bfloat16_t; +using ElementB = cutlass::bfloat16_t; +using ElementC = cutlass::bfloat16_t; +using ElementAcc = float; +using ElementCompute = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB0 = cutlass::layout::ColumnMajor; // strided ldB = 2K view +using LayoutB1 = cutlass::layout::ColumnMajor; // strided ldB = 2K view +using LayoutC = cutlass::layout::RowMajor; + +// Greedy-picked on the host side via -DMAGI_SWIGLU_ALIGN_*_BITS — same macro +// plumbing as the sm_80 path. Defaults give 128-bit (8 elem for bf16) loads / +// stores; the host can drop to 64-bit when a shape only meets 8B alignment. +#ifndef MAGI_SWIGLU_ALIGN_A_BITS +#define MAGI_SWIGLU_ALIGN_A_BITS 128 +#endif +#ifndef MAGI_SWIGLU_ALIGN_B_BITS +#define MAGI_SWIGLU_ALIGN_B_BITS 128 +#endif +#ifndef MAGI_SWIGLU_ALIGN_C_BITS +#define MAGI_SWIGLU_ALIGN_C_BITS 128 +#endif +constexpr int AlignmentA = MAGI_SWIGLU_ALIGN_A_BITS / cutlass::sizeof_bits::value; +constexpr int AlignmentB = MAGI_SWIGLU_ALIGN_B_BITS / cutlass::sizeof_bits::value; +constexpr int EpilogueVecCount = MAGI_SWIGLU_ALIGN_C_BITS / cutlass::sizeof_bits::value; + +constexpr auto kScaleType = cutlass::epilogue::thread::ScaleType::Nothing; +constexpr bool kSplitKSerial = false; +constexpr bool kStoreD0 = false; +constexpr bool kStoreD1 = false; + +//////////////////////////////////////////////////////////////////////////////// +// Per-tile Sm90DualGemm wrapper. Each autotune candidate instantiates the +// full kernel for its (TileShape, Stages) tuple. Compile time grows linearly +// with the candidate count — keep the set small and shape-relevant. +//////////////////////////////////////////////////////////////////////////////// + +template +struct DualGemmConfigSm90 { + using TileShape = TileShape_; + static constexpr int kStages = Stages_; + + using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute, kScaleType>; + using EpilogueOp1 = cutlass::epilogue::thread::LinearCombination< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute, kScaleType>; + using EpilogueOp2 = cutlass::epilogue::thread::SwigluCombine< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute>; + + using Gemm = cutlass::gemm::device::Sm90DualGemm< + ElementA, LayoutA, + ElementB, LayoutB0, LayoutB1, + ElementC, LayoutC, + ElementAcc, + TileShape, + EpilogueOp0, EpilogueOp1, EpilogueOp2, + kStages, + kStoreD0, kStoreD1, kSplitKSerial, + AlignmentA, AlignmentB>; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Type-erased runner concept; one instance per autotune candidate. +// Same SwArgs layout as the sm_80 path — keeps the host wrapper identical. +//////////////////////////////////////////////////////////////////////////////// + +struct SwArgs { + int M; // activations rows + int N_out; // = N/2 (output cols) + int K; + void* ptr_A; + void* ptr_B; // (N, K) row-major weight; gate/linear interleaved + void* ptr_D; // (M, N_out) — strided view of an (M, ldd) padded buffer + int64_t ldd; // D's row stride in elements; >= N_out, multiple of EpilogueVecCount + float alpha; // silu_alpha scaling: x * sigmoid(alpha * x) + float limit; // clamp bound: clamp(gate, max=limit), clamp(linear, -limit, limit) + float one; // additive offset: (x_linear + one) +}; + +class SwSm90Concept { + public: + virtual ~SwSm90Concept() = default; + virtual size_t get_workspace_size(const SwArgs&) = 0; + virtual cutlass::Status initialize(const SwArgs&, void* ws, cudaStream_t) = 0; + virtual cutlass::Status run(cudaStream_t stream) = 0; + virtual const char* name() const = 0; +}; + +template +class SwSm90Impl : public SwSm90Concept { + public: + using GemmType = typename Cfg::Gemm; + using EpilogueOp0 = typename Cfg::EpilogueOp0; + using EpilogueOp1 = typename Cfg::EpilogueOp1; + using EpilogueOp2 = typename Cfg::EpilogueOp2; + + explicit SwSm90Impl(const char* name) : name_(name) {} + + typename GemmType::Arguments make_args(const SwArgs& a) { + auto ptrA = reinterpret_cast(a.ptr_A); + auto ptrB = reinterpret_cast(a.ptr_B); + auto ptrD = reinterpret_cast(a.ptr_D); + int const M = a.M, N_out = a.N_out, K = a.K; + + int64_t const ldB_strided = static_cast(2) * K; + LayoutB0 layoutB_gate(ldB_strided); + LayoutB1 layoutB_linear(ldB_strided); + // ldd carries the host-padded row stride; Sm90DualGemm reads it via + // ref_D2.stride(0) at run() time, so a strided D view works without + // touching the vendored device/kernel headers. + LayoutC layoutC(a.ldd); + + using TensorRefA = cutlass::TensorRef; + using TensorRefB0 = cutlass::TensorRef; + using TensorRefB1 = cutlass::TensorRef; + using TensorRefCi = cutlass::TensorRef; + using TensorRefDo = cutlass::TensorRef; + + TensorRefA ref_A0(ptrA, LayoutA(static_cast(K))); + TensorRefB0 ref_B0(ptrB, layoutB_gate); // W_gate (even rows) + TensorRefCi ref_C0(nullptr, LayoutC(0)); + TensorRefDo ref_D0(nullptr, LayoutC(0)); + TensorRefB1 ref_B1(ptrB + K, layoutB_linear); // W_linear (odd rows) + TensorRefCi ref_C1(nullptr, LayoutC(0)); + TensorRefDo ref_D1(nullptr, LayoutC(0)); + TensorRefDo ref_D2(ptrD, layoutC); // output + + typename EpilogueOp0::Params epi0{ElementCompute(1.0f), ElementCompute(0.0f)}; + typename EpilogueOp1::Params epi1{ElementCompute(1.0f), ElementCompute(0.0f)}; + typename EpilogueOp2::Params epi2{ + ElementCompute(a.alpha), ElementCompute(a.limit), ElementCompute(a.one)}; + + cutlass::gemm::GemmCoord problem{M, N_out, K}; + + typename GemmType::Arguments args( + cutlass::gemm::DualGemmMode::kGemm, + problem, + ref_A0, + ref_B0, ref_C0, ref_D0, + ref_B1, ref_C1, ref_D1, + ref_D2, + epi0, epi1, epi2, + /*split_k_slices=*/1, + /*batch_count=*/1); + return args; + } + + size_t get_workspace_size(const SwArgs& a) override { + return GemmType::get_workspace_size(make_args(a)); + } + cutlass::Status initialize(const SwArgs& a, void* ws, cudaStream_t s) override { + return gemm_.initialize(make_args(a), ws, s); + } + cutlass::Status run(cudaStream_t stream) override { + return gemm_.run(stream); + } + const char* name() const override { return name_; } + + private: + GemmType gemm_; + const char* name_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// AutoTune runner — first call per (M, N_out, K) shape times all candidates. +//////////////////////////////////////////////////////////////////////////////// + +#define SW_SM90_TILE(bm, bn, bk, stages, label) \ + configs_.push_back(std::make_unique< \ + SwSm90Impl, cute::Int, cute::Int>, \ + stages>>>(label)) + +class SwSm90AutoTuneRunner { + public: + SwSm90AutoTuneRunner() { + // Tile candidates for H100 (sm_90a, ~228 KiB dynamic SMEM/SM, 132 SMs). + // + // SMEM cost = (BM + 2*BN) * BK * 2 (bf16) * stages. Stay under ~200 KiB + // to leave room for barriers and TMA descriptors. Sm90DualGemm requires + // BM >= 128 to enable cooperative dual consumer warpgroups (the perf + // sweet spot); smaller BM falls back to a single-wg path. + // + // Candidates intentionally span small/medium/large M; the runner picks + // the best one per (M, N_out, K) tuple at first call. + + // ── Reference / prefill sweet spot ─────────────────────────────────────── + SW_SM90_TILE(128, 128, 64, 4, "Sm90<128,128,64>_S4"); // 192 KiB + SW_SM90_TILE(128, 128, 64, 3, "Sm90<128,128,64>_S3"); // 144 KiB + + // ── Decode-style small M ───────────────────────────────────────────────── + SW_SM90_TILE(64, 128, 64, 4, "Sm90<64,128,64>_S4"); // 160 KiB + SW_SM90_TILE(64, 64, 64, 4, "Sm90<64,64,64>_S4"); // 96 KiB + + // ── Alternate small-N ──────────────────────────────────────────────────── + SW_SM90_TILE(128, 64, 64, 4, "Sm90<128,64,64>_S4"); // 128 KiB + + // ── Large prefill ──────────────────────────────────────────────────────── + SW_SM90_TILE(256, 128, 64, 2, "Sm90<256,128,64>_S2"); // 128 KiB + } + + void operator()(at::Tensor A, at::Tensor B, at::Tensor D, + float alpha, float limit, float one) { + TORCH_CHECK(A.is_cuda() && B.is_cuda() && D.is_cuda(), + "all inputs must be CUDA tensors"); + TORCH_CHECK(A.scalar_type() == at::kBFloat16 && B.scalar_type() == at::kBFloat16 + && D.scalar_type() == at::kBFloat16, + "all inputs must be bf16"); + TORCH_CHECK(A.dim() == 2 && B.dim() == 2 && D.dim() == 2, "A, B, D must be 2D"); + TORCH_CHECK(A.size(1) == B.size(1), "K mismatch (A.size(1) vs B.size(1))"); + // Stride-based contiguity check (mirrors sm_80 path) — Inductor's + // reinterpret_tensor often hands us a tensor with the right strides but + // tripped is_contiguous() (e.g. larger storage than sizes would imply). + TORCH_CHECK(A.stride(1) == 1, "A innermost stride must be 1; got ", A.stride(1)); + TORCH_CHECK(A.stride(0) >= A.size(1), + "A row stride must be >= K; got stride(0)=", A.stride(0), ", K=", A.size(1)); + TORCH_CHECK(B.stride(1) == 1, "B innermost stride must be 1; got ", B.stride(1)); + TORCH_CHECK(B.stride(0) >= B.size(1), + "B row stride must be >= K; got stride(0)=", B.stride(0), ", K=", B.size(1)); + + int const M = static_cast(A.size(0)); + int const K = static_cast(A.size(1)); + int const N = static_cast(B.size(0)); + TORCH_CHECK((N % 2) == 0, "N must be even, got ", N); + // Sm90DualGemm uses TMA for A/B loads; TMA requires the innermost stride + // **in bytes** to be a multiple of 16 (cudaTensorMapEncodeTiled's hard + // constraint, also enforced by sm90_dual_gemm.h's can_implement via + // constexpr int min_k_align = 128 / sizeof_bits; + // if (problem_size.k() % min_k_align != 0) return kErrorInvalidProblem; + // ). Express in bytes so a future fp8 / fp32 swiglu path inherits the + // gate without a one-line dtype change. For bf16 (sizeof = 2) this + // reduces to K % 8 == 0; for fp32 (sizeof = 4) → K % 4; for fp8 → K % 16. + constexpr int kMinKAlignBytes = 16; + constexpr int kElemBytes = sizeof(ElementA); + constexpr int kMinKAlignElems = kMinKAlignBytes / kElemBytes; + TORCH_CHECK((K % kMinKAlignElems) == 0, + "Sm90 swiglu requires K * sizeof(elem) % 16 == 0 (TMA's 128-bit " + "alignment in bytes); got K=", K, ", elem_bytes=", kElemBytes, + ", required K % ", kMinKAlignElems, + " == 0. This shape is fusion-eligible only on the sm_80/sm_120 path."); + int const N_out = N / 2; + TORCH_CHECK(D.size(0) == M && D.size(1) == N_out, + "D must be (M, N/2) = (", M, ",", N_out, ")"); + TORCH_CHECK(D.stride(1) == 1, "D innermost stride must be 1; got ", D.stride(1)); + TORCH_CHECK(D.stride(0) >= N_out, + "D row stride must be >= N_out; got stride(0)=", D.stride(0), ", N_out=", N_out); + + SwArgs ea; + ea.M = M; ea.N_out = N_out; ea.K = K; + ea.ptr_A = A.data_ptr(); + ea.ptr_B = B.data_ptr(); + ea.ptr_D = D.data_ptr(); + ea.ldd = static_cast(D.stride(0)); + ea.alpha = alpha; ea.limit = limit; ea.one = one; + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index()).stream(); + + // Single autotune per module. The .cu is compiled per (m_bucket, N, K, + // alignA, alignB, alignC) on the Python side — every distinct shape + // bucket gets its own runner instance with isolated `best_idx_`. + if (best_idx_ < 0) { + best_idx_ = autotune(ea, stream); + } + int idx = best_idx_; + + auto& gemm = configs_[idx]; + size_t ws_sz = gemm->get_workspace_size(ea); + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) { + ws_ = at::empty({(int64_t)ws_sz + 1}, + at::TensorOptions().dtype(at::kByte).device(A.device())); + } + auto st = gemm->initialize(ea, ws_sz > 0 ? ws_.data_ptr() : nullptr, stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "Sm90DualGemm init failed (", gemm->name(), "): ", + cutlassGetStatusString(st)); + st = gemm->run(stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "Sm90DualGemm run failed (", gemm->name(), "): ", + cutlassGetStatusString(st)); + } + + int num_configs() const { return (int)configs_.size(); } + + private: + int autotune(const SwArgs& ea, cudaStream_t stream) { + int best_idx = -1; + float best_time = 1e30f; + cudaEvent_t s, e; + cudaEventCreate(&s); cudaEventCreate(&e); + + for (size_t i = 0; i < configs_.size(); ++i) { + auto& g = configs_[i]; + size_t ws_sz = 0; + try { ws_sz = g->get_workspace_size(ea); } + catch (...) { continue; } + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) { + ws_ = at::empty({(int64_t)ws_sz + 1}, + at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + } + void* ws_ptr = ws_sz > 0 ? ws_.data_ptr() : nullptr; + if (g->initialize(ea, ws_ptr, stream) != cutlass::Status::kSuccess) { + continue; + } + + // Warmup — 10 iters so the L2 / instruction cache settle. + for (int w = 0; w < 10; ++w) g->run(stream); + cudaStreamSynchronize(stream); + + // Time — 50 iters keeps timing noise to <1%. + cudaEventRecord(s, stream); + int iters = 50; + for (int p = 0; p < iters; ++p) g->run(stream); + cudaEventRecord(e, stream); + cudaEventSynchronize(e); + float ms = 0; + cudaEventElapsedTime(&ms, s, e); + float avg = ms / iters; + if (avg < best_time) { best_time = avg; best_idx = (int)i; } + } + cudaEventDestroy(s); cudaEventDestroy(e); + TORCH_CHECK(best_idx >= 0, + "Sm90DualGemm AutoTune: no candidate succeeded for (M,N_out,K)=(", + ea.M, ",", ea.N_out, ",", ea.K, ")"); + return best_idx; + } + + std::vector> configs_; + int best_idx_ = -1; // -1 = not yet autotuned; sticky after first call. + at::Tensor ws_; +}; + +static SwSm90AutoTuneRunner& runner() { + static SwSm90AutoTuneRunner R; + return R; +} + +void swiglu_dual_matmul_out(at::Tensor A, at::Tensor B, at::Tensor D, + float alpha, float limit, float one) { + runner()(std::move(A), std::move(B), std::move(D), alpha, limit, one); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "CUTLASS Sm90 DualGemm fully-fused swiglu (bf16) on sm_90a — autotune"; + m.def("swiglu_dual_matmul_out", + &swiglu_dual_matmul_out, + "D = swiglu(A @ B.T) in a single fused Sm90 (TMA+WGMMA) kernel; " + "A:(M,K) bf16, B:(N,K) bf16 (N even), D:(M,N/2) bf16 (strided ok)", + pybind11::arg("A"), + pybind11::arg("B"), + pybind11::arg("D"), + pybind11::arg("alpha") = 1.702f, + pybind11::arg("limit") = 7.0f, + pybind11::arg("one") = 1.0f); + m.def("num_configs", []() { return runner().num_configs(); }); +} diff --git a/magi_compiler/passes/piecewise_graph/fusion/sm90/evt_codegen.py b/magi_compiler/passes/piecewise_graph/fusion/sm90/evt_codegen.py new file mode 100644 index 0000000..f14a7c8 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/sm90/evt_codegen.py @@ -0,0 +1,803 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Render a CUTLASS 3.x Sm90EVT .cu source from an EVT IR tree — H100 path. + +Uses TMA + WGMMA via warp-specialized collective builders; ~1.6-2x faster +than the SM80 path on H100. Selected by ``evt_runtime`` when arch == sm_90. + +All AuxLoad nodes use ``Sm90AuxLoad<0>`` (inline ld.global, no SMEM +staging). The C-operand TMA channel is left unused (ptr_C = nullptr). +The same ``AuxLoad.input_idx`` may appear at multiple positions in the +EVT tree (matching SM80 behaviour); the leaf-args dict produces +identical expressions for the same index so the overwrite is harmless. +""" + +from __future__ import annotations + +from typing import Dict, List, Tuple + +from ..common.codegen_shared import ( + _BUILTIN_FN_TEMPLATE, + _DTYPE_TO_AT, + _DTYPE_TO_AT_CPP, + _DTYPE_TO_CUTLASS, + _VALID_ALIGN_BITS, + _emit_custom_functor, +) +from ..evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store, walk_leaves + +# (TM, TN, TK, CM, CN, CK, schedule, label). +# Cluster_M=1 → Pingpong; Cluster_M>=2 → Cooperative. Mismatched combos +# fail at can_implement and are skipped by autotune. +# H100: 132 SMs, 228 KB SMEM / SM. + +_TILE_CANDIDATES_SM90: dict = { + "small": [ + (64, 128, 64, 1, 1, 1, "pingpong", "T<64,128,64>_Cl<1,1,1>_PP"), + (64, 256, 64, 1, 1, 1, "pingpong", "T<64,256,64>_Cl<1,1,1>_PP"), + (128, 128, 64, 1, 1, 1, "pingpong", "T<128,128,64>_Cl<1,1,1>_PP"), + (128, 256, 64, 1, 1, 1, "pingpong", "T<128,256,64>_Cl<1,1,1>_PP"), + (64, 128, 128, 1, 1, 1, "pingpong", "T<64,128,128>_Cl<1,1,1>_PP"), + (64, 256, 128, 1, 1, 1, "pingpong", "T<64,256,128>_Cl<1,1,1>_PP"), + ], + "medium": [ + (128, 128, 64, 1, 1, 1, "pingpong", "T<128,128,64>_Cl<1,1,1>_PP"), + (128, 256, 64, 1, 1, 1, "pingpong", "T<128,256,64>_Cl<1,1,1>_PP"), + (128, 128, 64, 2, 1, 1, "cooperative", "T<128,128,64>_Cl<2,1,1>_CO"), + (128, 256, 64, 2, 1, 1, "cooperative", "T<128,256,64>_Cl<2,1,1>_CO"), + (256, 128, 64, 2, 1, 1, "cooperative", "T<256,128,64>_Cl<2,1,1>_CO"), + (256, 256, 64, 2, 1, 1, "cooperative", "T<256,256,64>_Cl<2,1,1>_CO"), + ], + "large": [ + (128, 256, 64, 2, 1, 1, "cooperative", "T<128,256,64>_Cl<2,1,1>_CO"), + (256, 128, 64, 2, 1, 1, "cooperative", "T<256,128,64>_Cl<2,1,1>_CO"), + (256, 256, 64, 2, 1, 1, "cooperative", "T<256,256,64>_Cl<2,1,1>_CO"), + (128, 256, 64, 2, 2, 1, "cooperative", "T<128,256,64>_Cl<2,2,1>_CO"), + (256, 128, 64, 2, 2, 1, "cooperative", "T<256,128,64>_Cl<2,2,1>_CO"), + (256, 256, 64, 2, 2, 1, "cooperative", "T<256,256,64>_Cl<2,2,1>_CO"), + ], +} + + +_SCHEDULE_TYPES = { + "pingpong": ("cutlass::gemm::KernelTmaWarpSpecializedPingpong", "cutlass::epilogue::TmaWarpSpecialized"), + "cooperative": ("cutlass::gemm::KernelTmaWarpSpecializedCooperative", "cutlass::epilogue::TmaWarpSpecializedCooperative"), +} + + +def _emit_tile_candidates(m_bucket: str) -> str: + """Emit C++ EVT_TILE_CANDIDATE(...) statements for the given M bucket.""" + candidates = _TILE_CANDIDATES_SM90.get(m_bucket, _TILE_CANDIDATES_SM90["medium"]) + lines = [] + for tm, tn, tk, cm, cn, ck, schedule, label in candidates: + kernel_sched, epi_sched = _SCHEDULE_TYPES[schedule] + lines.append( + f" EVT_TILE_CANDIDATE(" f"{tm}, {tn}, {tk}, {cm}, {cn}, {ck}, " f"{kernel_sched}, {epi_sched}, " f'"{label}");' + ) + return "\n".join(lines) + + +class _Sm90EvtEmitter: + """Bottom-up walker emitting Sm90EVT typedef chains. + + Unlike SM80, there is no Store wrapper — the CollectiveEpilogue owns + the store; the EVT root is the topmost compute node. + """ + + def __init__(self, root: Store): + self.root = root + self.typedef_lines: List[str] = [] + self.functor_decls: List[str] = [] + self._emitted_functors: Dict[Tuple[str, str], str] = {} + self._tmp_counter = 0 + self.leaf_typedefs: List[Tuple[str, str, "int | None", str]] = [] + self.scalar_functor_counter = 0 + + def _new_name(self, prefix: str) -> str: + self._tmp_counter += 1 + return f"{prefix}_{self._tmp_counter}" + + def _functor_name_for(self, op: str, scalar) -> str: + key = (op, repr(scalar) if scalar is not None else "") + if key in self._emitted_functors: + return self._emitted_functors[key] + scalar_tag = "" + if scalar is not None: + self.scalar_functor_counter += 1 + scalar_tag = f"_v{self.scalar_functor_counter}" + name = f"Magi_{op}{scalar_tag}" + self._emitted_functors[key] = name + self.functor_decls.append(_emit_custom_functor(name, op, scalar)) + return name + + def _compute_op_template(self, node: Compute) -> str: + if node.op in _BUILTIN_FN_TEMPLATE and node.scalar is None: + return _BUILTIN_FN_TEMPLATE[node.op] + return self._functor_name_for(node.op, node.scalar) + + def emit(self) -> str: + """Walk the IR and return the typedef name of the EVT root.""" + return self._emit_node(self.root.child) + + def _emit_node(self, node) -> str: + if isinstance(node, Accum): + name = self._new_name("AccFetch") + self.typedef_lines.append(f"using {name} = cutlass::epilogue::fusion::Sm90AccFetch;") + return name + if isinstance(node, RowBroadcast): + name = self._new_name("RowBcast") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::fusion::Sm90RowBroadcast<\n" + f" /*Stages=*/0, TileShape, {elem}, ElementCompute>;" + ) + self.leaf_typedefs.append((name, "row_bcast", node.input_idx, node.dtype)) + return name + if isinstance(node, ColBroadcast): + name = self._new_name("ColBcast") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::fusion::Sm90ColBroadcast<\n" + f" /*Stages=*/0, TileShape, {elem}, ElementCompute>;" + ) + self.leaf_typedefs.append((name, "col_bcast", node.input_idx, node.dtype)) + return name + if isinstance(node, AuxLoad): + elem = _DTYPE_TO_CUTLASS[node.dtype] + name = self._new_name("AuxLoad") + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::fusion::Sm90AuxLoad<\n" + f" /*Stages=*/0, /*EpilogueTile=*/void, {elem},\n" + f" cutlass::layout::RowMajor, /*SmemLayoutAtom=*/void, /*CopyOpS2R=*/void>;" + ) + self.leaf_typedefs.append((name, "aux_load_inline", node.input_idx, node.dtype)) + return name + if isinstance(node, Compute): + child_names = [self._emit_node(c) for c in node.children] + compute_name = self._new_name(f"Cmp_{node.op}") + fn_template = self._compute_op_template(node) + elem_compute = _DTYPE_TO_CUTLASS[node.compute_dtype] + self.typedef_lines.append( + f"using {compute_name} = cutlass::epilogue::fusion::Sm90Compute<\n" + f" {fn_template}, {elem_compute}, {elem_compute},\n" + f" cutlass::FloatRoundStyle::round_to_nearest>;" + ) + evt_name = self._new_name(f"EVT_{node.op}") + child_typedef_list = ", ".join(child_names) + self.typedef_lines.append( + f"using {evt_name} = cutlass::epilogue::fusion::Sm90EVT<\n" f" {compute_name}, {child_typedef_list}>;" + ) + return evt_name + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +def _emit_args_tree(node, leaf_args: Dict[int, str], indent: int = 8) -> str: + """Emit the nested-brace runtime args literal mirroring the Sm90EVT tree.""" + pad = " " * indent + if isinstance(node, Accum): + return f"{pad}{{}}" + if isinstance(node, (AuxLoad, RowBroadcast, ColBroadcast)): + return f"{pad}{leaf_args[node.input_idx]}" + if isinstance(node, Compute): + children_str = ",\n".join(_emit_args_tree(c, leaf_args, indent + 2) for c in node.children) + return f"{pad}{{\n" f"{children_str},\n" f"{pad} {{}}\n" f"{pad}}}" # this Sm90Compute's op args (always empty) + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +_KERNEL_PREAMBLE_SM90 = """\ +// AUTO-GENERATED by magi_compiler/passes/piecewise_graph/fusion/sm90/evt_codegen.py +// Do not edit by hand. Regenerate by re-running the FX pass. +// +// IR cache key: {cache_key} + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/fast_math.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cutlass/util/packed_stride.hpp" + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////// +// Custom functors (one per unique scalar-baked op or non-builtin unary). +//////////////////////////////////////////////////////////////////////////////// +{functor_decls} + +//////////////////////////////////////////////////////////////////////////////// +// Data types and layouts +//////////////////////////////////////////////////////////////////////////////// + +using ElementA = {a_elem}; +using ElementB = {b_elem}; +// C-operand TMA channel is unused (all AuxLoad nodes use Sm90AuxLoad<0> +// which loads via ld.global). ElementC = ElementD; ptr_C = nullptr. +using ElementC = {c_elem}; +using ElementD = {d_elem}; +using ElementAccumulator = float; +using ElementCompute = float; + +using LayoutATag = cutlass::layout::RowMajor; +using LayoutBTag = cutlass::layout::{b_layout}; +using LayoutCTag = cutlass::layout::RowMajor; +using LayoutDTag = cutlass::layout::RowMajor; + +constexpr int AlignmentA = {alignment_a_bits} / cutlass::sizeof_bits::value; +constexpr int AlignmentB = {alignment_b_bits} / cutlass::sizeof_bits::value; +constexpr int AlignmentC = {alignment_c_bits} / cutlass::sizeof_bits::value; +constexpr int AlignmentD = {alignment_c_bits} / cutlass::sizeof_bits::value; + +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + +//////////////////////////////////////////////////////////////////////////////// +// Per-tile-config GEMM type. The Sm90 EVT typedefs reference TileShape (each +// Sm90RowBroadcast / Sm90ColBroadcast bakes the tile dims into its on-the-fly +// loader), and CollectiveBuilder consumes (TileShape, ClusterShape, Schedule) +// — so every autotune candidate must re-instantiate the entire EVT chain + +// CollectiveEpilogue + CollectiveMainloop + GemmKernel. We package the whole +// tree inside a template struct keyed on the four tile/cluster/schedule +// parameters so each candidate is a distinct C++ type that can live side-by- +// side in ``configs_``. +//////////////////////////////////////////////////////////////////////////////// + +template +struct EvtConfig {{ + using TileShape = TileShape_; + using ClusterShape = ClusterShape_; + using KernelSchedule = KernelSchedule_; + using EpilogueSchedule = EpilogueSchedule_; + + //////////////////////////////////////////////////////////////////////////// + // EVT (Sm90 Epilogue Visitor Tree) typedefs — generated from the IR. + // No outermost StoreD wrapper — the CollectiveEpilogue owns the store; the + // EVT root is the topmost compute / leaf node. + //////////////////////////////////////////////////////////////////////////// +{typedef_block} + + using FusionCallbacks = {evt_root_name}; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + // AutoCarveout picks the max stages that fit in the actual epilogue's + // SharedStorage footprint for the target arch. On H100 this lands on ~6-7 + // stages for typical TileShape<128,128,64>; bigger tiles automatically get + // fewer stages. Aggressive choice is safe because this codegen is sm_90- + // only (the runtime dispatcher routes other arches to sm80/evt_codegen.py). + using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + StageCountType, + KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; +}}; + +//////////////////////////////////////////////////////////////////////////////// +// Autotune runner — one candidate per (TileShape, ClusterShape, Schedule) +// tuple; first call at a new (M, N, K) tuple times every candidate that +// can_implement accepts and caches the winner. +//////////////////////////////////////////////////////////////////////////////// + +struct EvtArgs {{ + int M; + int N; + int K; + void* ptr_A; + void* ptr_B; + void* ptr_D; + int64_t lda; + int64_t ldb; + int64_t ldd; + // Extras pointers, in IR-leaf order. Each AuxLoad / RowBroadcast / + // ColBroadcast looks up its pointer from this vector by its IR + // input_idx baked into the launcher. + std::vector ptr_extras; + // Row strides for AuxLoad extras (stride(0) in elements). Indexed in + // the same order as ptr_extras; RowBroadcast/ColBroadcast entries are + // unused but still present so indices stay aligned. + std::vector stride_extras; +}}; + +class EvtConcept {{ + public: + virtual ~EvtConcept() = default; + virtual size_t get_workspace_size(const EvtArgs&) = 0; + virtual cutlass::Status can_implement(const EvtArgs&) = 0; + virtual cutlass::Status initialize(const EvtArgs&, void* ws, cudaStream_t s) = 0; + virtual cutlass::Status run(cudaStream_t stream) = 0; + virtual const char* name() const = 0; +}}; + +template +class EvtImpl : public EvtConcept {{ + public: + using GemmType = typename Cfg::Gemm; + using StrideA = typename Cfg::StrideA; + using StrideB = typename Cfg::StrideB; + using StrideC = typename Cfg::StrideC; + using StrideD = typename Cfg::StrideD; + + explicit EvtImpl(const char* name) : name_(name) {{}} + + typename GemmType::Arguments make_args(const EvtArgs& a) {{ + auto ptrA = reinterpret_cast(a.ptr_A); + auto ptrB = reinterpret_cast(a.ptr_B); + auto ptrD = reinterpret_cast (a.ptr_D); + int const M = a.M; + int const N = a.N; + int const K = a.K; + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{{}}, cute::make_shape(M, K, 1)); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{{}}, cute::make_shape(N, K, 1)); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{{}}, cute::make_shape(M, N, 1)); + // D's row stride comes from the actual tensor (ea.ldd = D.stride(0)), + // which may be larger than N when the runtime pads the output buffer to + // a 16-byte boundary. Using N here would give TMA a wrong + // globalStride, corrupting every row after the first. + auto stride_D = cutlass::make_cute_packed_stride(StrideD{{}}, cute::make_shape(M, static_cast(a.ldd), 1)); + // Per-AuxLoad strides — each extra may have a different row stride + // (e.g. padded buffers where stride(0) > N). Emitted unconditionally; + // nvcc -O3 drops unused variables. +{aux_stride_decls} + + // C-operand TMA channel unused — all AuxLoad nodes use Sm90AuxLoad<0> + // (inline ld.global). ptr_C is nullptr; no node reports + // is_C_load_needed()=true so CollectiveEpilogue skips the C TMA load. + auto ptrC = {ptr_C_expr_in_make_args}; + + typename GemmType::Arguments args{{ + cutlass::gemm::GemmUniversalMode::kGemm, + {{M, N, K, 1}}, + {{ ptrA, stride_A, ptrB, stride_B }}, + {{ // epilogue args = ( FusionCallbacks_args, ptr_C, stride_C, ptr_D, stride_D ) +{args_tree}, + ptrC, stride_C, + ptrD, stride_D + }} + }}; + return args; + }} + + size_t get_workspace_size(const EvtArgs& a) override {{ + auto args = make_args(a); + return GemmType::get_workspace_size(args); + }} + cutlass::Status can_implement(const EvtArgs& a) override {{ + auto args = make_args(a); + return gemm_.can_implement(args); + }} + cutlass::Status initialize(const EvtArgs& a, void* ws, cudaStream_t s) override {{ + auto args = make_args(a); + return gemm_.initialize(args, ws, s); + }} + cutlass::Status run(cudaStream_t stream) override {{ + return gemm_.run(stream); + }} + const char* name() const override {{ return name_; }} + + private: + GemmType gemm_; + const char* name_; +}}; + +//////////////////////////////////////////////////////////////////////////////// +// Python-facing launcher — same evt_matmul_out signature as the SM80 path +// so the dispatcher in evt_runtime.py picks up the same attribute name. +//////////////////////////////////////////////////////////////////////////////// +""" + + +_LAUNCHER_TEMPLATE_SM90 = """\ +//////////////////////////////////////////////////////////////////////////////// +// Tile candidate registration. Each EVT_TILE_CANDIDATE invocation instantiates +// the full EvtConfig — EVT typedef tree + CollectiveEpilogue + CollectiveMain- +// loop + GemmKernel — for that (TileShape, ClusterShape, Schedule) tuple. +// Compile time grows linearly with the candidate count; bucket lists are kept +// at ~6 candidates each. Mismatched (schedule, cluster) combos compile fine +// but die at can_implement and are skipped silently by autotune(). +//////////////////////////////////////////////////////////////////////////////// + +#define EVT_TILE_CANDIDATE(tm, tn, tk, cm, cn, ck, kernel_sched, epi_sched, label) \\ + configs_.push_back(std::make_unique, Int, Int>, \\ + Shape, Int, Int>, \\ + kernel_sched, epi_sched>>>(label)) + +class EvtAutoTuneRunner {{ + public: + EvtAutoTuneRunner() {{ +{tile_candidate_block} + }} + + void operator()(at::Tensor A, at::Tensor B, + std::vector extras, at::Tensor D) {{ + TORCH_CHECK(A.is_cuda() && B.is_cuda() && D.is_cuda(), + "evt_matmul_out: A/B/D must be CUDA tensors"); + TORCH_CHECK(A.scalar_type() == {a_at_dtype}, "A must be {a_dtype}"); + TORCH_CHECK(B.scalar_type() == {b_at_dtype}, "B must be {b_dtype}"); + TORCH_CHECK(D.scalar_type() == {d_at_dtype}, "D must be {d_dtype}"); + TORCH_CHECK(A.dim() == 2 && B.dim() == 2 && D.dim() == 2, "A, B, D must be 2D"); + // Stride-based contiguity (Inductor's reinterpret_tensor often trips + // .is_contiguous() with the "right" strides). + TORCH_CHECK(A.stride(1) == 1, "A innermost stride must be 1; got ", A.stride(1)); + TORCH_CHECK(A.stride(0) >= A.size(1), + "A row stride must be >= K; got stride(0)=", A.stride(0), ", K=", A.size(1)); + {b_stride_check} + + int const M = static_cast(A.size(0)); + int const K = static_cast(A.size(1)); + int const N = static_cast({n_dim_expr}); + + TORCH_CHECK(D.size(0) == M && D.size(1) == N, + "D must be (M, N); got ", D.sizes()); + TORCH_CHECK(D.stride(1) == 1, "D innermost stride must be 1; got ", D.stride(1)); + TORCH_CHECK(D.stride(0) >= N, + "D row stride must be >= N; got stride(0)=", D.stride(0), ", N=", N); + TORCH_CHECK(extras.size() == {n_extras}, "expected {n_extras} extra tensors, got ", extras.size()); + +{extras_validation} + + const c10::cuda::CUDAGuard guard(A.device()); + auto stream = at::cuda::getCurrentCUDAStream(A.device().index()).stream(); + + EvtArgs ea; + ea.M = M; ea.N = N; ea.K = K; + ea.ptr_A = A.data_ptr<{a_at_cpp}>(); + ea.ptr_B = B.data_ptr<{b_at_cpp}>(); + ea.ptr_D = D.data_ptr<{d_at_cpp}>(); + ea.lda = static_cast(A.stride(0)); + ea.ldb = static_cast(B.stride(0)); + ea.ldd = static_cast(D.stride(0)); + ea.ptr_extras.reserve({n_extras}); +{extras_ptrs} + + // Single autotune per module. The .cu is compiled per (IR, M-bucket, + // b_layout, N, K) on the Python side — every distinct weight (N, K) + // gets its own .cu, so this runner instance hosts exactly one (N, K) + // and one bucket of M values. Autotune once on the first call; all + // subsequent calls (any M inside the bucket) reuse `best_idx_`. + if (best_idx_ < 0) {{ + best_idx_ = autotune(ea, stream); + }} + int idx = best_idx_; + + auto& gemm = configs_[idx]; + size_t ws_sz = gemm->get_workspace_size(ea); + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) {{ + ws_ = at::empty({{(int64_t)ws_sz + 1}}, + at::TensorOptions().dtype(at::kByte).device(A.device())); + }} + auto st = gemm->initialize(ea, ws_sz > 0 ? ws_.data_ptr() : nullptr, stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "Sm90 EVT init failed (", gemm->name(), "): ", cutlassGetStatusString(st)); + st = gemm->run(stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "Sm90 EVT run failed (", gemm->name(), "): ", cutlassGetStatusString(st)); + }} + + int num_configs() const {{ return (int)configs_.size(); }} + + private: + int autotune(const EvtArgs& ea, cudaStream_t stream) {{ + int best_idx = -1; + float best_time = 1e30f; + cudaEvent_t s, e; + cudaEventCreate(&s); cudaEventCreate(&e); + + // Drain any pre-existing CUDA error so we don't blame our first candidate + // for an upstream failure. + (void)cudaGetLastError(); + + for (size_t i = 0; i < configs_.size(); ++i) {{ + auto& g = configs_[i]; + // can_implement gates illegal (schedule, cluster) combos and shapes + // that don't satisfy the kernel's M/N/K divisibility — these would + // crash at initialize() otherwise. + if (g->can_implement(ea) != cutlass::Status::kSuccess) continue; + size_t ws_sz = 0; + try {{ ws_sz = g->get_workspace_size(ea); }} + catch (...) {{ (void)cudaGetLastError(); continue; }} + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) {{ + ws_ = at::empty({{(int64_t)ws_sz + 1}}, + at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + }} + void* ws_ptr = ws_sz > 0 ? ws_.data_ptr() : nullptr; + // initialize() can fail synchronously (e.g. cudaFuncSetAttribute returns + // cudaErrorInvalidValue for tiles whose SharedStorage exceeds the + // device opt-in cap). Clear the sticky CUDA error before moving on — + // otherwise the next launch (or post-autotune user run) inherits it + // and surfaces a misleading "Error Internal" against an unrelated tile. + if (g->initialize(ea, ws_ptr, stream) != cutlass::Status::kSuccess) {{ + (void)cudaGetLastError(); + continue; + }} + + // Warmup — 10 iters so L2 / inst caches settle (3 was too few — first + // timed iter saw a cold L2 and biased the choice towards smaller tiles). + // Capture run() status and sync return codes so an async launch failure + // (e.g. invalid grid, latent SMEM issue) disqualifies the tile cleanly. + bool tile_ok = true; + for (int w = 0; w < 10; ++w) {{ + if (g->run(stream) != cutlass::Status::kSuccess) {{ tile_ok = false; break; }} + }} + if (tile_ok && cudaStreamSynchronize(stream) != cudaSuccess) {{ + tile_ok = false; + }} + if (!tile_ok) {{ + (void)cudaGetLastError(); + continue; + }} + + // Time — 20 iters for ~1% timing noise, matching torch.compile defaults. + cudaEventRecord(s, stream); + int iters = 20; + for (int p = 0; p < iters; ++p) {{ + if (g->run(stream) != cutlass::Status::kSuccess) {{ tile_ok = false; break; }} + }} + cudaEventRecord(e, stream); + if (cudaEventSynchronize(e) != cudaSuccess) tile_ok = false; + if (!tile_ok) {{ + (void)cudaGetLastError(); + continue; + }} + float ms = 0; + cudaEventElapsedTime(&ms, s, e); + float avg = ms / iters; + if (avg < best_time) {{ best_time = avg; best_idx = (int)i; }} + }} + cudaEventDestroy(s); cudaEventDestroy(e); + TORCH_CHECK(best_idx >= 0, + "Sm90 EVT AutoTune: no candidate succeeded for (M,N,K)=(", + ea.M, ",", ea.N, ",", ea.K, ")"); + return best_idx; + }} + + std::vector> configs_; + int best_idx_ = -1; // -1 = not yet autotuned; sticky after first call. + at::Tensor ws_; +}}; + +static EvtAutoTuneRunner& runner() {{ + static EvtAutoTuneRunner R; + return R; +}} + +void evt_matmul_out(at::Tensor A, at::Tensor B, + std::vector extras, at::Tensor D) {{ + runner()(std::move(A), std::move(B), std::move(extras), std::move(D)); +}} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {{ + m.doc() = "Magi compiler EVT-fused matmul (Sm90 TMA + WGMMA, autotune)"; + m.def("evt_matmul_out", &evt_matmul_out, + "Fused EVT matmul: D = epilogue(A @ B, extras...)", + pybind11::arg("A"), pybind11::arg("B"), + pybind11::arg("extras"), pybind11::arg("D")); + m.def("num_configs", []() {{ return runner().num_configs(); }}); +}} +""" + + +def render_evt_cu( + ir: Store, + a_dtype: str, + b_dtype: str, + cache_key_str: str = "", + b_layout: str = "row", + m_bucket: str = "medium", + alignment_a_bits: int = 128, + alignment_b_bits: int = 128, + alignment_c_bits: int = 128, + arch: str = "sm90", +) -> str: + """Render the SM90 .cu source for ``ir``.""" + if b_layout not in ("row", "col"): + raise ValueError(f"b_layout must be 'row' or 'col', got {b_layout!r}") + if m_bucket not in _TILE_CANDIDATES_SM90: + raise ValueError(f"unknown m_bucket {m_bucket!r}; " f"expected one of {list(_TILE_CANDIDATES_SM90)}") + if ( + alignment_a_bits not in _VALID_ALIGN_BITS + or alignment_b_bits not in _VALID_ALIGN_BITS + or alignment_c_bits not in _VALID_ALIGN_BITS + ): + raise ValueError( + f"alignment_*_bits must be one of {_VALID_ALIGN_BITS}; " + f"got A={alignment_a_bits}, B={alignment_b_bits}, C={alignment_c_bits}" + ) + if not isinstance(ir, Store): + raise TypeError("render_evt_cu (sm90) expects a Store node as root") + del arch + + a_elem = _DTYPE_TO_CUTLASS[a_dtype] + b_elem = _DTYPE_TO_CUTLASS[b_dtype] + d_elem = _DTYPE_TO_CUTLASS[ir.out_dtype] + + emitter = _Sm90EvtEmitter(ir) + evt_root = emitter.emit() + + # No Sm90SrcFetch — the C-operand TMA channel is unused (ptr_C = nullptr). + # ElementC must still be a concrete type for the CollectiveBuilder template. + c_elem = d_elem + + leaves = walk_leaves(ir) + leaf_args: Dict[int, str] = {} + aux_stride_decl_lines: List[str] = [] + extras_validation_lines: List[str] = [] + extras_ptr_lines: List[str] = [] + seen_extras: set = set() + extra_leaves = [n for n in leaves if not isinstance(n, Accum)] + n_extras = max((leaf.input_idx for leaf in extra_leaves), default=-1) + 1 + for leaf in extra_leaves: + i = leaf.input_idx + elem = _DTYPE_TO_CUTLASS[leaf.dtype] + if isinstance(leaf, RowBroadcast): + ptr_expr = f"reinterpret_cast<{elem} const*>(a.ptr_extras[{i}])" + leaf_args[i] = f"{{ {ptr_expr} }}" + elif isinstance(leaf, ColBroadcast): + ptr_expr = f"reinterpret_cast<{elem} const*>(a.ptr_extras[{i}])" + leaf_args[i] = f"{{ {ptr_expr} }}" + elif isinstance(leaf, AuxLoad): + ptr_expr = f"reinterpret_cast<{elem} const*>(a.ptr_extras[{i}])" + stride_var = f"stride_aux_{i}" + leaf_args[i] = f"{{ {ptr_expr}, {elem}(0), {stride_var} }}" + if i not in seen_extras: + aux_stride_decl_lines.append( + f" auto {stride_var} = cutlass::make_cute_packed_stride(\n" + f" cute::Stride{{}},\n" + f" cute::make_shape(M, static_cast(a.stride_extras[{i}]), 1));" + ) + + if i in seen_extras: + continue + seen_extras.add(i) + at_dtype = _DTYPE_TO_AT[leaf.dtype] + at_cpp = _DTYPE_TO_AT_CPP[leaf.dtype] + if isinstance(leaf, RowBroadcast): + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].numel() == N, "extras[{i}] must have N elements");') + elif isinstance(leaf, ColBroadcast): + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].numel() == M, "extras[{i}] must have M elements");') + elif isinstance(leaf, AuxLoad): + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].size(0) == M && extras[{i}].size(1) == N,' f' "extras[{i}] must be (M,N)");' + ) + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].stride(1) == 1 && extras[{i}].stride(0) >= N,' + f' "extras[{i}] must be row-major with stride(1)==1 and stride(0)>=N");' + ) + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].scalar_type() == {at_dtype},' f' "extras[{i}] must be {leaf.dtype}");' + ) + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].is_cuda(), "extras[{i}] must be CUDA");') + extras_ptr_lines.append(f" ea.ptr_extras.push_back(static_cast(" f"extras[{i}].data_ptr<{at_cpp}>()));") + extras_ptr_lines.append(f" ea.stride_extras.push_back(static_cast(extras[{i}].stride(0)));") + + args_tree = _emit_args_tree(ir.child, leaf_args, indent=8) + + ptr_C_expr_in_make_args = "static_cast(nullptr)" + + extras_validation = "\n".join(extras_validation_lines) if extras_validation_lines else " // no extras" + extras_ptrs = "\n".join(extras_ptr_lines) if extras_ptr_lines else "" + aux_stride_decls = "\n".join(aux_stride_decl_lines) if aux_stride_decl_lines else " // (no AuxLoad strides)" + + functor_decls = "\n".join(emitter.functor_decls) if emitter.functor_decls else "// (no custom functors)" + typedef_block = "\n".join(" " + l if l.strip() else l for l in "\n".join(emitter.typedef_lines).split("\n")) + + cutlass_b_layout = "RowMajor" if b_layout == "row" else "ColumnMajor" + if b_layout == "row": + n_dim_expr = "B.size(1)" + b_stride_check = ( + 'TORCH_CHECK(B.stride(1) == 1, "B innermost stride must be 1; got ", B.stride(1));\n' + ' TORCH_CHECK(B.stride(0) >= B.size(1),\n' + ' "B row stride must be >= N; got stride(0)=", B.stride(0), ", N=", B.size(1));' + ) + else: + n_dim_expr = "B.size(0)" + b_stride_check = ( + 'TORCH_CHECK(B.stride(1) == 1, "B innermost stride must be 1; got ", B.stride(1));\n' + ' TORCH_CHECK(B.stride(0) >= B.size(1),\n' + ' "B row stride must be >= K; got stride(0)=", B.stride(0), ", K=", B.size(1));' + ) + + tile_candidate_block = _emit_tile_candidates(m_bucket) + + preamble = _KERNEL_PREAMBLE_SM90.format( + cache_key=cache_key_str, + functor_decls=functor_decls, + a_elem=a_elem, + b_elem=b_elem, + c_elem=c_elem, + d_elem=d_elem, + b_layout=cutlass_b_layout, + alignment_a_bits=alignment_a_bits, + alignment_b_bits=alignment_b_bits, + alignment_c_bits=alignment_c_bits, + typedef_block=typedef_block, + evt_root_name=evt_root, + ptr_C_expr_in_make_args=ptr_C_expr_in_make_args, + args_tree=args_tree, + aux_stride_decls=aux_stride_decls, + ) + launcher = _LAUNCHER_TEMPLATE_SM90.format( + a_dtype=a_dtype, + b_dtype=b_dtype, + d_dtype=ir.out_dtype, + a_at_dtype=_DTYPE_TO_AT[a_dtype], + b_at_dtype=_DTYPE_TO_AT[b_dtype], + d_at_dtype=_DTYPE_TO_AT[ir.out_dtype], + a_at_cpp=_DTYPE_TO_AT_CPP[a_dtype], + b_at_cpp=_DTYPE_TO_AT_CPP[b_dtype], + d_at_cpp=_DTYPE_TO_AT_CPP[ir.out_dtype], + n_dim_expr=n_dim_expr, + b_stride_check=b_stride_check, + n_extras=n_extras, + extras_validation=extras_validation, + extras_ptrs=extras_ptrs, + tile_candidate_block=tile_candidate_block, + ) + return preamble + launcher diff --git a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py index f6441e0..3323ea3 100644 --- a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py +++ b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py @@ -17,7 +17,7 @@ from torch import fx as fx from torch._inductor.custom_graph_pass import CustomGraphPass -from ...config import PassConfig +from ...config import PassConfig, get_compile_config from ...utils import magi_logger, set_env_var from ...utils.envs import MAGI_PATTERN_MATCH_DEBUG from ..pass_base import InductorPass, get_pass_context @@ -80,7 +80,17 @@ def __call__(self, graph: fx.Graph): def configure(self, pass_config: PassConfig): self.pass_config = pass_config - # TODO: Register custom passes here (fusion, noop elimination, sequence parallelism, async TP, Ulysses overlap). + if pass_config.enable_mm_epilogue_fusion: + compile_config = get_compile_config() + if compile_config.has_cutlass: + from .fusion.matmul_epilogue_fusion import MatmulEvtEpilogueFusionPass + + self.add(MatmulEvtEpilogueFusionPass()) + else: + magi_logger.warning( + "Skipping matmul epilogue fusion because CUTLASS is unavailable. " + "Set MAGI_CUTLASS_ROOT or compile_config.cutlass_root to a valid CUTLASS source tree." + ) # needs a functional graph self.post_cleanup = PostCleanupPass() diff --git a/magi_compiler/utils/__init__.py b/magi_compiler/utils/__init__.py index e944e58..c93a742 100644 --- a/magi_compiler/utils/__init__.py +++ b/magi_compiler/utils/__init__.py @@ -15,6 +15,7 @@ from ._utils import * from .compile_counter import compilation_counter +from .device import device_capability, device_capability_major from .envs import set_env_var from .hash import compute_code_hash, compute_code_hash_with_content, compute_hash from .logger import logger, magi_logger @@ -34,4 +35,6 @@ "SingletonMeta", "instrument_nvtx", "add_nvtx_event", + "device_capability", + "device_capability_major", ] diff --git a/magi_compiler/utils/device.py b/magi_compiler/utils/device.py new file mode 100644 index 0000000..ebcd246 --- /dev/null +++ b/magi_compiler/utils/device.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU device introspection helpers. + +Centralised so that pass-manager / FX passes / runtime modules don't all +re-implement the same try/except dance around ``torch.cuda``. +""" + +from typing import Tuple + + +def device_capability(device: int = 0) -> Tuple[int, int]: + """Return ``(major, minor)`` for the given CUDA device. + + Falls back to ``(0, 0)`` when CUDA is unavailable / not initialised / + raises any error during introspection — callers compare against a + minimum cap so a zero pair always means "feature unsupported", which + is the safe behaviour on CPU-only hosts and during static analysis. + """ + try: + import torch as _torch + + if _torch.cuda.is_available(): + return _torch.cuda.get_device_capability(device) + except Exception: + pass + return (0, 0) + + +def device_capability_major(device: int = 0) -> int: + """Convenience wrapper: just the major-capability int (0 if no CUDA).""" + return device_capability(device)[0] diff --git a/tests/feature_tests/test_build_cleanup.py b/tests/feature_tests/test_build_cleanup.py new file mode 100644 index 0000000..11998ef --- /dev/null +++ b/tests/feature_tests/test_build_cleanup.py @@ -0,0 +1,152 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the build-directory cleanup mechanism in evt_runtime.py. + +The _track_build / _untrack_build + signal-handler machinery ensures that +interrupted cpp_extension.load calls leave no stale lock files or partial +build artifacts on disk. These tests exercise that mechanism directly +(no GPU needed). +""" + +import os +import signal +import subprocess +import sys +import textwrap + +import pytest + + +@pytest.fixture(autouse=True) +def _isolate_pending_set(): + """Reset _PENDING_BUILD_DIRS before and after each test.""" + from magi_compiler.passes.piecewise_graph.fusion import evt_runtime as rt + + saved = rt._PENDING_BUILD_DIRS.copy() + rt._PENDING_BUILD_DIRS.clear() + yield + rt._PENDING_BUILD_DIRS.clear() + rt._PENDING_BUILD_DIRS.update(saved) + + +def test_track_untrack_basic(tmp_path): + """Normal success path: track → present → untrack → absent.""" + from magi_compiler.passes.piecewise_graph.fusion import evt_runtime as rt + + build_dir = str(tmp_path / "build_ok") + os.makedirs(build_dir) + + rt._track_build(build_dir) + assert build_dir in rt._PENDING_BUILD_DIRS + + rt._untrack_build(build_dir) + assert build_dir not in rt._PENDING_BUILD_DIRS + + +def test_cleanup_pending_removes_tracked_dirs(tmp_path): + """_cleanup_pending_build_dirs wipes every tracked directory.""" + from magi_compiler.passes.piecewise_graph.fusion import evt_runtime as rt + + build_dir = str(tmp_path / "build_interrupted") + os.makedirs(build_dir) + # Simulate partial build artifacts + (tmp_path / "build_interrupted" / "lock").touch() + (tmp_path / "build_interrupted" / "kernel.cuda.o").touch() + (tmp_path / "build_interrupted" / "build.ninja").touch() + + rt._track_build(build_dir) + assert os.path.isdir(build_dir) + + rt._cleanup_pending_build_dirs() + + assert not os.path.exists(build_dir) + assert len(rt._PENDING_BUILD_DIRS) == 0 + + +def test_untracked_build_not_cleaned(tmp_path): + """A directory that was tracked then untracked must survive cleanup.""" + from magi_compiler.passes.piecewise_graph.fusion import evt_runtime as rt + + build_dir = str(tmp_path / "build_completed") + os.makedirs(build_dir) + (tmp_path / "build_completed" / "module.so").touch() + + rt._track_build(build_dir) + rt._untrack_build(build_dir) + + rt._cleanup_pending_build_dirs() + + assert os.path.isdir(build_dir) + assert (tmp_path / "build_completed" / "module.so").exists() + + +def test_cleanup_on_signal_in_subprocess(tmp_path): + """A subprocess that tracks a build_dir and receives SIGTERM must clean it up.""" + build_dir = str(tmp_path / "build_signal") + + script = textwrap.dedent( + f"""\ + import os, sys, time + sys.path.insert(0, {str((tmp_path / '..').resolve().parent)!r}) + + build_dir = {build_dir!r} + os.makedirs(build_dir, exist_ok=True) + with open(os.path.join(build_dir, "lock"), "w") as f: + f.write("locked") + with open(os.path.join(build_dir, "partial.o"), "w") as f: + f.write("junk") + + from magi_compiler.passes.piecewise_graph.fusion import evt_runtime as rt + rt._track_build(build_dir) + + # Signal parent that we're ready + sys.stdout.write("READY\\n") + sys.stdout.flush() + # Sleep long enough for parent to send signal + time.sleep(60) + """ + ) + + proc = subprocess.Popen( + [sys.executable, "-c", script], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=str(tmp_path) + ) + + try: + line = proc.stdout.readline() + assert line.strip() == "READY", f"Subprocess didn't become ready, got: {line!r}" + + proc.send_signal(signal.SIGTERM) + proc.wait(timeout=10) + except Exception: + proc.kill() + proc.wait() + raise + + assert not os.path.exists(build_dir), f"build_dir {build_dir} should have been cleaned up by SIGTERM handler" + + +def test_cleanup_idempotent(tmp_path): + """Calling _cleanup_pending_build_dirs twice is harmless.""" + from magi_compiler.passes.piecewise_graph.fusion import evt_runtime as rt + + build_dir = str(tmp_path / "build_double") + os.makedirs(build_dir) + rt._track_build(build_dir) + + rt._cleanup_pending_build_dirs() + assert not os.path.exists(build_dir) + + # Second call: no-op, no exception. + rt._cleanup_pending_build_dirs() diff --git a/tests/feature_tests/test_matmul_epilogue_fusion.py b/tests/feature_tests/test_matmul_epilogue_fusion.py new file mode 100644 index 0000000..5426715 --- /dev/null +++ b/tests/feature_tests/test_matmul_epilogue_fusion.py @@ -0,0 +1,1069 @@ +# Copyright (c) 2025 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for CUTLASS EVT matmul–epilogue fusion (``MatmulEvtEpilogueFusionPass``). + +Architecture routing (see ``matmul_epilogue_fusion.py`` / ``evt_runtime.py``): + + * sm_90 (Hopper / H100) — CUTLASS 3.x ``Sm90EVT``; TMA+WGMMA. + * sm_120+ (Blackwell consumer, e.g. RTX 5090) — CUTLASS 2.x ``Sm80EVT``; + cp.async multistage. + +Most tests use ``@_EVT_CAPABLE`` (runs on whichever GPU is present). +``@_SM120_ONLY`` is reserved for SM80-path-specific edge cases (e.g. 64-bit +alignment that SM90 TMA cannot handle). + +Three families of checks: + + 1. Positive numerical equivalence: every supported epilogue must match + eager within dtype-appropriate tolerance. + 2. Fusion-actually-fired: the emitted graph must contain a + ``magi_epilogue.matmul_fused_epilogue`` node. + 3. Negative fallback: shapes / dtypes / chains the EVT pass does NOT + support must keep the original ``aten.mm`` and run through cuBLAS. +""" + +from typing import Optional + +import pytest +import torch +import torch.fx as fx +import torch.nn as nn +import torch.nn.functional as F + +from magi_compiler.api import magi_compile +from magi_compiler.config import get_compile_config + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + +_SM120_ONLY = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 12, + reason="CUTLASS EVT path targets sm_120 (Blackwell consumer)", +) + +_SM90_ONLY = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability() != (9, 0), reason="SM90 EVT path targets Hopper (H100)" +) + +_EVT_CAPABLE = pytest.mark.skipif( + not torch.cuda.is_available() + or (torch.cuda.get_device_capability() != (9, 0) and torch.cuda.get_device_capability()[0] < 12), + reason="EVT path targets sm_90 (Hopper) or sm_120+ (Blackwell)", +) + + +_TEST_RNG_SEED = 123 + + +@pytest.fixture(autouse=True) +def _enable_mm_epilogue_fusion(): + config = get_compile_config() + old_value = config.pass_config.enable_mm_epilogue_fusion + config.pass_config.enable_mm_epilogue_fusion = True + yield + config.pass_config.enable_mm_epilogue_fusion = old_value + + +@pytest.fixture(autouse=True) +def _fixed_rng_seed(): + """Make low-precision random numerical tests reproducible.""" + cpu_state = torch.random.get_rng_state() + cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None + torch.manual_seed(_TEST_RNG_SEED) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(_TEST_RNG_SEED) + yield + torch.random.set_rng_state(cpu_state) + if cuda_states is not None: + torch.cuda.set_rng_state_all(cuda_states) + + +# ── Activations from athena/performer_v16/activation.py (verbatim) ──────────── + + +def high_precision_silu(x, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + return F.silu(x.to(torch.float32)).to(out_dtype) + + +def swiglu(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + x_glu, x_linear = x[..., ::2], x[..., 1::2] + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return (out_glu * (x_linear + 1)).to(out_dtype) + + +def gelu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + x_glu = x.clamp(min=None, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu.to(out_dtype) + + +# ── Compile + fusion-side instrumentation ──────────────────────────────────── + + +class _FusionStats: + """Records what the EVT pass did to the graph during one ``magi_compile``.""" + + def __init__(self) -> None: + self.mm_before = 0 + self.mm_after = 0 + self.fused_count = 0 + self.kinds: list = [] + self.out_dtype_ids: list = [] + self.ir_jsons: list = [] + self.call_function_targets_after: list = [] + + +def _install_pass_instrument(): + """Returns (stats, restore_fn). Wraps the FX pass to record per-call deltas.""" + from magi_compiler.passes.piecewise_graph.fusion import matmul_epilogue_fusion as P + + stats = _FusionStats() + original = P.MatmulEvtEpilogueFusionPass.__call__ + evt_op = torch.ops.magi_epilogue.matmul_fused_epilogue.default + mm_targets = (torch.ops.aten.mm.default, torch.ops.aten.mm) + + def _instrumented(self, graph: fx.Graph): + before = sum(1 for n in graph.nodes if n.op == "call_function" and n.target in mm_targets) + result = original(self, graph) + after = sum(1 for n in graph.nodes if n.op == "call_function" and n.target in mm_targets) + emitted_kinds = [] + emitted_out_dtype_ids = [] + emitted_ir_jsons = [] + call_function_targets_after = [] + for n in graph.nodes: + if n.op == "call_function": + call_function_targets_after.append(n.target) + if n.op == "call_function" and n.target is evt_op: + if len(n.args) >= 4: + emitted_ir_jsons.append(n.args[3]) + if len(n.args) >= 5: + emitted_kinds.append(n.args[4]) + if len(n.args) >= 7: + emitted_out_dtype_ids.append(n.args[6]) + stats.mm_before += before + stats.mm_after += after + stats.fused_count += len(emitted_kinds) + stats.kinds.extend(emitted_kinds) + stats.out_dtype_ids.extend(emitted_out_dtype_ids) + stats.ir_jsons.extend(emitted_ir_jsons) + stats.call_function_targets_after.extend(call_function_targets_after) + return result + + P.MatmulEvtEpilogueFusionPass.__call__ = _instrumented + + def restore(): + P.MatmulEvtEpilogueFusionPass.__call__ = original + + return stats, restore + + +def _compile_and_check( + model: nn.Module, + inputs, + *, + atol: float = 0.5, + rtol: float = 0.0, + expect_fused: int = -1, + expect_kinds: Optional[list] = None, + expect_out_dtype: Optional[torch.dtype] = None, + expect_actual_dtype: Optional[torch.dtype] = None, + dynamic_arg_dims=None, + cast_model_to_bf16: bool = True, +): + """Compile ``model``, run it on ``inputs``, compare against eager.""" + if dynamic_arg_dims is None: + import inspect + + params = list(inspect.signature(model.forward).parameters) + if not params: + dynamic_arg_dims = {} + else: + dynamic_arg_dims = {params[0]: 0} + + model = model.cuda() + if cast_model_to_bf16 and any(p.dtype.is_floating_point for p in model.parameters()): + model = model.bfloat16() + for p in model.parameters(): + p.requires_grad_(False) + + with torch.no_grad(): + expected = model(*inputs) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled_model = magi_compile(model, dynamic_arg_dims=dynamic_arg_dims) + with torch.no_grad(): + actual = compiled_model(*inputs) + finally: + restore() + + if expect_fused >= 0: + assert stats.fused_count == expect_fused, ( + f"Expected {expect_fused} fused mm sites, got {stats.fused_count}. " + f"mm_before={stats.mm_before} mm_after={stats.mm_after} " + f"emitted kinds={stats.kinds}" + ) + if expect_fused > 0: + evt_op = torch.ops.magi_epilogue.matmul_fused_epilogue.default + assert stats.call_function_targets_after == [evt_op] * expect_fused, ( + "Expected the final fused subgraph to contain only matmul_fused_epilogue " + f"call_function nodes, got {stats.call_function_targets_after}" + ) + + # Skip the numerical accuracy check when fusion was explicitly expected NOT + # to fire. The unfused path goes through vanilla torch.compile → Inductor, + # which has a known upstream bf16 mm bug: when the output dimension N is not + # 16-byte aligned (N % 8 != 0 for bf16), the compiled mm produces + # systematically wrong results (max |diff| ≈ 1.0). We still check fusion + # correctness above; the accuracy assertion is only meaningful when the EVT + # path is active. + if expect_fused == 0: + return + + abs_diff = (actual - expected).abs() + tol = atol + rtol * expected.abs() + max_violation = (abs_diff - tol).max().item() + assert max_violation <= 0, ( + f"Fused result outside tolerance: " + f"max(|diff| - tol) = {max_violation:.4f}, " + f"max |diff| = {abs_diff.max().item():.4f}, " + f"fusion stats: fused={stats.fused_count} kinds={stats.kinds}" + ) + if expect_kinds is not None: + assert sorted(stats.kinds) == sorted(expect_kinds), ( + f"Expected emitted kinds {sorted(expect_kinds)}, " f"got {sorted(stats.kinds)}" + ) + if expect_out_dtype is not None: + from magi_compiler.passes.piecewise_graph.fusion.evt_runtime import out_dtype_from_id + + assert stats.out_dtype_ids, ( + f"expect_out_dtype={expect_out_dtype} but no fusion fired " f"(out_dtype_ids list is empty)" + ) + decoded = [out_dtype_from_id(i) for i in stats.out_dtype_ids] + for got in decoded: + assert got == expect_out_dtype, ( + f"Emitted out_dtype mismatch: expected {expect_out_dtype}, " f"got {got} (full list: {decoded})" + ) + if expect_actual_dtype is not None: + assert actual.dtype == expect_actual_dtype, ( + f"Runtime result dtype mismatch: expected {expect_actual_dtype}, " f"got {actual.dtype}" + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Common helpers +# ───────────────────────────────────────────────────────────────────────────── + + +class _Bf16MmModel(nn.Module): + """bf16 mm followed by an epilogue fn that returns bf16.""" + + def __init__(self, k: int, n: int, epilogue): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + self._epi = epilogue + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return self._epi(y, out_dtype=torch.bfloat16) + + +_M, _K, _N = 1024, 1024, 1024 + + +def _input_a(): + return torch.randn(_M, _K, device="cuda", dtype=torch.bfloat16) + + +def _parse_ir_compute_dtypes(ir_json_str: str) -> list: + """Extract all compute_dtype values from Compute nodes in an IR JSON string.""" + import json + + dtypes = [] + + def _walk(d): + if not isinstance(d, dict): + return + if d.get("kind") == "compute": + dtypes.append(d.get("compute_dtype", "float32")) + for c in d.get("children", []): + _walk(c) + elif d.get("kind") == "store": + _walk(d.get("child")) + + _walk(json.loads(ir_json_str)) + return dtypes + + +# ───────────────────────────────────────────────────────────────────────────── +# Positive tests — unary activations, SwiGLU, scalar ops, bias, AuxLoad +# ───────────────────────────────────────────────────────────────────────────── + + +@_EVT_CAPABLE +@pytest.mark.parametrize("epi_name,epi_fn,atol,rtol", [("silu", high_precision_silu, 0.5, 0.0), ("gelu7", gelu7, 0.5, 0.0)]) +def test_evt_unary_activations_fuse(epi_name, epi_fn, atol, rtol): + """Representative unary activations must fuse to a single ``evt_col`` op.""" + model = _Bf16MmModel(_K, _N, epi_fn) + _compile_and_check(model, (_input_a(),), atol=atol, rtol=rtol, expect_fused=1, expect_kinds=["evt_col"]) + + +@_EVT_CAPABLE +def test_evt_relu_native(): + """Plain ``aten.relu`` variants must fuse and preserve emitted output dtype.""" + + class Fp32Relu(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return torch.relu(torch.mm(a, self.weight.permute(1, 0)).float()) + + _compile_and_check( + Fp32Relu(), + (_input_a(),), + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.float32, + expect_actual_dtype=torch.float32, + ) + + +@_EVT_CAPABLE +def test_evt_swiglu_constants_roundtrip_in_ir_json(): + """Verify that swiglu constant values are captured in ir_json.""" + import json as _json + + def swiglu_custom(x, out_dtype=None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + x_glu, x_linear = x[..., ::2], x[..., 1::2] + x_glu = x_glu.clamp(max=3.0) + x_linear = x_linear.clamp(min=-3.0, max=3.0) + out_glu = x_glu * torch.sigmoid(1.5 * x_glu) + return (out_glu * (x_linear + 1)).to(out_dtype) + + model = _Bf16MmModel(_K, _N, swiglu_custom).cuda().bfloat16() + for p in model.parameters(): + p.requires_grad_(False) + + a = _input_a() + with torch.no_grad(): + expected = model(a) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled = magi_compile(model, dynamic_arg_dims={"a": 0}) + with torch.no_grad(): + actual = compiled(a) + finally: + restore() + + diff = (actual.float() - expected.float()).abs().max().item() + assert diff <= 0.5, f"swiglu custom constants max|diff|={diff}" + + assert stats.fused_count == 1 + assert stats.kinds == ["swiglu_dual"] + assert len(stats.ir_jsons) == 1 + sw7 = _json.loads(stats.ir_jsons[0]) + assert sw7["alpha"] == 1.5, f"Expected alpha=1.5, got {sw7['alpha']}" + assert sw7["limit"] == 3.0, f"Expected limit=3.0, got {sw7['limit']}" + assert sw7["one"] == 1.0, f"Expected one=1.0, got {sw7['one']}" + + +# ── alpha parameter tests for aten.add/sub ──────────────────────────────────── + + +@_EVT_CAPABLE +@pytest.mark.parametrize( + "case_name,op,other_kind,alpha", + [("add_scalar_alpha2", torch.add, "scalar", 2.0), ("sub_tensor_alpha2", torch.sub, "tensor", 2.0)], +) +def test_evt_mm_add_sub_with_alpha(case_name, op, other_kind, alpha): + """aten.add/sub with alpha must fuse and produce numerically correct results. + + Tensor-operand cases use ``silu(mm(...))`` as the base so that PyTorch's + FX decomposition does not merge ``mm + alpha*bias`` into ``aten.addmm`` + (which would hide the mm node from our EVT pass). + """ + + class ScalarModel(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return op(y, 0.5, alpha=alpha).to(torch.bfloat16) + + class TensorModel(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + self.bias = nn.Parameter(torch.randn(_N)) + + def forward(self, a): + y = F.silu(torch.mm(a, self.weight.permute(1, 0)).to(torch.float32)) + return op(y, self.bias, alpha=alpha).to(torch.bfloat16) + + model = ScalarModel() if other_kind == "scalar" else TensorModel() + _compile_and_check( + model, + (_input_a(),), + atol=1.5, + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.bfloat16, + expect_actual_dtype=torch.bfloat16, + ) + + +@_EVT_CAPABLE +def test_evt_mm_plus_1d_bias(): + """``silu(mm + bias_N)`` — 1-D bias as RowBroadcast extras.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + self.bias = nn.Parameter(torch.randn(_N)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + self.bias + return high_precision_silu(y, out_dtype=torch.bfloat16) + + _compile_and_check( + M(), + (_input_a(),), + atol=1.5, + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.bfloat16, + expect_actual_dtype=torch.bfloat16, + ) + + +@_EVT_CAPABLE +def test_evt_aux_load_padded_stride(): + """AuxLoad with padded row stride (stride(0) > N) must fuse and read correctly.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a, gate): + y = torch.mm(a, self.weight.permute(1, 0)) * gate + return y.to(torch.bfloat16) + + a = _input_a() + N_padded = _N + 64 + gate_buf = torch.randn(_M, N_padded, device="cuda", dtype=torch.bfloat16) + gate = gate_buf[:, :_N] # shape (_M, _N), stride (N_padded, 1) + assert gate.stride() == (N_padded, 1), f"Expected padded stride, got {gate.stride()}" + _compile_and_check( + M(), (a, gate), atol=0.0, rtol=0.1, expect_fused=1, expect_kinds=["evt_col"], dynamic_arg_dims={"a": 0, "gate": 0} + ) + + +@_EVT_CAPABLE +def test_evt_multiple_and_repeated_aux_loads_fuse(): + """Multiple AuxLoad extras, with one tensor reused at multiple EVT positions.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a, gate, r1, r2): + y = torch.mm(a, self.weight.permute(1, 0)) + return (y * gate + gate + r1 + r2).to(torch.bfloat16) + + a = _input_a() + gate = torch.randn(_M, _N, device="cuda", dtype=torch.bfloat16) + r1 = torch.randn(_M, _N, device="cuda", dtype=torch.bfloat16) + r2 = torch.randn(_M, _N, device="cuda", dtype=torch.bfloat16) + _compile_and_check( + M(), + (a, gate, r1, r2), + atol=4.0, + rtol=0.1, + expect_fused=1, + expect_kinds=["evt_col"], + dynamic_arg_dims={"a": 0, "gate": 0, "r1": 0, "r2": 0}, + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# RowMajor B layout — weight stored as (K, N), used directly without permute +# ───────────────────────────────────────────────────────────────────────────── + + +@_EVT_CAPABLE +def test_evt_row_b_layout_fuses(): + """B is (K, N) row-major (no permute). LayoutB=RowMajor, kind=evt_row. + + CuTe stride for RowMajor B: (_1, N, N*K) — N is contiguous. + TMA globalStride = N * sizeof(elem); N=1024 is 16B-aligned for bf16. + """ + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(k, n)) + + def forward(self, a): + y = torch.mm(a, self.weight) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + _compile_and_check(M(_K, _N), (_input_a(),), expect_fused=1, expect_kinds=["evt_row"]) + + +# ───────────────────────────────────────────────────────────────────────────── +# Negative tests — fusion must NOT fire, cuBLAS fallback +# ───────────────────────────────────────────────────────────────────────────── + + +@_EVT_CAPABLE +def test_evt_no_fuse_intermediate_escapes(): + """Attention → residual → RMSNorm: intermediate value escapes the fused + chain. The pass MUST refuse.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(5120, _K)) + self.gamma = nn.Parameter(torch.randn(5120)) + + def forward(self, a, residual): + y = torch.mm(a, self.weight.permute(1, 0)).float() + x = residual + y + var = x.pow(2).mean(-1, keepdim=True) + rsqrt = torch.rsqrt(var + 1e-6) + return (x * rsqrt * (self.gamma + 1)).to(torch.bfloat16) + + a = _input_a() + residual = torch.randn(_M, 5120, device="cuda", dtype=torch.float32) + _compile_and_check(M(), (a, residual), atol=2.0, rtol=0.1, expect_fused=0, dynamic_arg_dims={"a": 0, "residual": 0}) + + +@_EVT_CAPABLE +def test_evt_no_fuse_bare_mm(): + """Bare ``mm`` — Store(Accum) is trivial, pass must skip.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return torch.mm(a, self.weight.permute(1, 0)) + + _compile_and_check(M(), (_input_a(),), atol=0.5, expect_fused=0) + + +@_EVT_CAPABLE +def test_evt_no_fuse_k_misaligned(): + """K below 64-bit alignment (bf16: K % 4 != 0) — pass aborts. + + K=1022: 1022 % 4 = 2 → no valid AlignmentA on either arch. + """ + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1022 + N = 1024 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) + + +@_SM90_ONLY +def test_evt_sm90_no_fuse_k_not_16byte_aligned(): + """K=1020: K % 4 == 0 (64-bit aligned) but K * 2 % 16 != 0. + + SM90 TMA requires globalStride to be 16-byte aligned. A is RowMajor + (M, K) so stride_A = K, giving K * sizeof(bf16) = 2040 bytes, which + is not 16-byte aligned (2040 % 16 = 8). The pass must refuse. + On SM120 this fuses fine (64-bit alignment is sufficient). + """ + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1020 + N = 1024 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) + + +@_SM90_ONLY +def test_evt_sm90_no_fuse_n_not_16byte_aligned(): + """N=1026: N * sizeof(bf16) = 2052 bytes, not 16-byte aligned. + + SM90 CollectiveEpilogue (TMA store) requires problem N % AlignmentD + == 0, where AlignmentD = 16 / sizeof(bf16) = 8. 1026 % 8 = 2 ≠ 0 + so all tile candidates fail can_implement. The pass must refuse. + On SM120 this fuses fine (runtime pads ldd). + """ + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1024 + N = 1026 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) + + +@_SM90_ONLY +def test_evt_sm90_no_fuse_row_b_n_not_16byte_aligned(): + """RowMajor B with N=1020: N * sizeof(bf16) = 2040, not 16B-aligned. + + CuTe stride for RowMajor B is (_1, N, ...) so TMA globalStride = + N * sizeof(elem) = 2040 bytes, 2040 % 16 = 8 ≠ 0. + N=1020 passes the 64-bit check (1020 % 4 == 0) but fails the SM90 + 16B TMA constraint. The pass must refuse on SM90. + On SM120 this fuses fine (64-bit alignment is sufficient). + """ + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(k, n)) + + def forward(self, a): + y = torch.mm(a, self.weight) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1024 + N = 1020 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) + + +@_EVT_CAPABLE +def test_evt_no_fuse_fp32_mm(): + """fp32 mm — pass requires bf16 or fp16; fp32 must skip.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return F.silu(y) + + a = torch.randn(_M, _K, device="cuda", dtype=torch.float32) + + model = M().cuda() + with torch.no_grad(): + expected = model(a) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled_model = magi_compile(model, dynamic_arg_dims={"a": 0}) + with torch.no_grad(): + actual = compiled_model(a) + finally: + restore() + + diff = (actual - expected).abs().max().item() + assert diff <= 1.0, f"fp32 mm result diverged: {diff}" + assert stats.fused_count == 0, ( + f"fp32 mm should NOT fuse, but pass emitted {stats.fused_count} ops " f"(kinds={stats.kinds})" + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Alignment edge cases and D stride padding +# ───────────────────────────────────────────────────────────────────────────── + + +@_SM120_ONLY +def test_evt_col_n_misaligned_still_fuses(): + """N=1026: not 128-bit aligned for bf16, runtime pads D stride. Still fuses. + + SM120-only: SM80 (CUTLASS 2.x) threadblock epilogue only requires ldd to + be aligned, so _aligned_n_stride(1026)=1032 suffices. SM90 (CUTLASS 3.x) + TMA CollectiveBuilder requires problem N % AlignmentD == 0, and 1026 % 8 + != 0 — all tile candidates fail can_implement. + """ + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1024 + N = 1026 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=1) + + +@_SM120_ONLY +def test_evt_swiglu_small_n_still_fuses(): + """N=12: n_out=6, not 128-bit aligned. Runtime pads, fusion fires. + + SM120-only: same reason as col_n_misaligned — SM90 TMA requires + N % AlignmentD == 0 and 12 % 8 != 0. + """ + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return swiglu(y, out_dtype=torch.bfloat16) + + K = 1024 + N = 12 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=1) + + +@_SM120_ONLY +def test_evt_row_b_n_64bit_aligned_fuses_on_sm120(): + """RowMajor B, N=1020: N % 4 == 0 (64-bit) but N*2 % 16 != 0. + + SM120-only: SM80 codegen accepts 64-bit alignment for B. + SM90 TMA rejects because globalStride = 1020 * 2 = 2040, 2040 % 16 ≠ 0. + """ + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(k, n)) + + def forward(self, a): + y = torch.mm(a, self.weight) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1024 + N = 1020 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=1) + + +@_EVT_CAPABLE +def test_evt_d_stride_padding_silu(): + """D stride padding regression: N=1032, not 128-byte aligned for bf16. + Runtime pads D to n_pad=1088.""" + K = 1024 + N = 1032 + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(N, K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(), (a,), atol=0.5, expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_k_64bit_aligned_fuses_on_sm120(): + """K=1020: K % 4 == 0 (64-bit aligned) but K % 8 != 0 (not 128-bit). + + On SM120 (RTX 5090), the SM80 codegen accepts AlignmentA=4 (64-bit) + and fusion proceeds normally. This exercises the 64-bit fallback path + in ``_largest_pow2_align_bits`` / ``_runtime_align_bits``. + """ + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1020 + N = 1024 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=1, expect_kinds=["evt_col"]) + + +# ───────────────────────────────────────────────────────────────────────────── +# IR / cache key invariants +# ───────────────────────────────────────────────────────────────────────────── + + +@_EVT_CAPABLE +def test_evt_ir_canonical_determinism(): + """Same IR built twice → identical canonical JSON.""" + from magi_compiler.passes.piecewise_graph.fusion.evt_ir import Accum, Compute, Store, cache_key, to_canonical_json + + a = Store(Compute("silu", (Compute("add", (Accum(), Accum())),)), "bfloat16") + b = Store(Compute("silu", (Compute("add", (Accum(), Accum())),)), "bfloat16") + assert to_canonical_json(a) == to_canonical_json(b) + assert cache_key(a, "bfloat16", "bfloat16") == cache_key(b, "bfloat16", "bfloat16") + + +# ───────────────────────────────────────────────────────────────────────────── +# Per-node compute_dtype +# ───────────────────────────────────────────────────────────────────────────── + + +@_EVT_CAPABLE +def test_evt_mixed_compute_dtype_chain(): + """mm → to(fp32) → silu → to(bf16) → add_scalar(0.5). + + silu must have compute_dtype=float32, add_scalar must have bfloat16. + """ + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + y = y.float() + y = F.silu(y) + y = y.bfloat16() + y = y + 0.5 + return y + + model = M().cuda().bfloat16() + for p in model.parameters(): + p.requires_grad_(False) + a = _input_a() + + with torch.no_grad(): + expected = model(a) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled = magi_compile(model, dynamic_arg_dims={"a": 0}) + with torch.no_grad(): + actual = compiled(a) + finally: + restore() + + diff = (actual.float() - expected.float()).abs().max().item() + assert diff <= 1.5, f"Mixed compute_dtype chain max|diff|={diff}" + assert stats.fused_count == 1, f"Expected 1 fusion, got {stats.fused_count}" + + assert len(stats.ir_jsons) == 1, f"Expected 1 ir_json, got {len(stats.ir_jsons)}" + compute_dtypes = _parse_ir_compute_dtypes(stats.ir_jsons[0]) + assert "bfloat16" in compute_dtypes, f"Expected at least one bfloat16 compute_dtype in IR, " f"got {compute_dtypes}" + assert "float32" in compute_dtypes, f"Expected at least one float32 compute_dtype in IR, " f"got {compute_dtypes}" + + +# ───────────────────────────────────────────────────────────────────────────── +# No-GPU tests: codegen, IR invariants +# ───────────────────────────────────────────────────────────────────────────── + + +def test_sm90_codegen_repeated_aux_idx(): + """SM90 codegen produces valid C++ with repeated AuxLoad input_idx.""" + import re + + from magi_compiler.passes.piecewise_graph.fusion.evt_ir import Accum, AuxLoad, Compute, Store + from magi_compiler.passes.piecewise_graph.fusion.sm90.evt_codegen import render_evt_cu + + ir = Store( + child=Compute( + op="add", + children=( + Compute(op="mul", children=(Accum(), AuxLoad(input_idx=0, dtype="bfloat16"))), + AuxLoad(input_idx=0, dtype="bfloat16"), + ), + ), + out_dtype="bfloat16", + ) + src = render_evt_cu(ir, "bfloat16", "bfloat16") + + aux_load_defs = re.findall(r"using\s+\w+\s*=\s*cutlass::epilogue::fusion::Sm90AuxLoad<", src) + assert len(aux_load_defs) == 2, f"Expected 2 Sm90AuxLoad typedefs, found {len(aux_load_defs)}" + assert len(re.findall(r"ptr_extras\[0\]", src)) >= 1 + assert "expected 1 extra tensors" in src + + +def test_sm90_codegen_repeated_aux_idx_mixed_with_distinct(): + """SM90 codegen: repeated input_idx=0 + distinct input_idx=1.""" + import re + + from magi_compiler.passes.piecewise_graph.fusion.evt_ir import Accum, AuxLoad, Compute, Store + from magi_compiler.passes.piecewise_graph.fusion.sm90.evt_codegen import render_evt_cu + + ir = Store( + child=Compute( + op="add", + children=( + Compute( + op="add", + children=( + Compute(op="mul", children=(Accum(), AuxLoad(input_idx=0, dtype="bfloat16"))), + AuxLoad(input_idx=0, dtype="bfloat16"), + ), + ), + AuxLoad(input_idx=1, dtype="bfloat16"), + ), + ), + out_dtype="bfloat16", + ) + src = render_evt_cu(ir, "bfloat16", "bfloat16") + + aux_load_defs = re.findall(r"using\s+\w+\s*=\s*cutlass::epilogue::fusion::Sm90AuxLoad<", src) + assert len(aux_load_defs) == 3, f"Expected 3 Sm90AuxLoad typedefs, found {len(aux_load_defs)}" + assert "expected 2 extra tensors" in src + + +def test_evt_ir_compute_dtype_roundtrip(): + """Compute with non-default compute_dtype serialises and round-trips.""" + import json + + from magi_compiler.passes.piecewise_graph.fusion.evt_ir import Accum, Compute, Store, to_canonical_json + from magi_compiler.passes.piecewise_graph.fusion.evt_runtime import _ir_from_json + + ir_bf16 = Store(Compute("silu", (Accum(),), compute_dtype="bfloat16"), "bfloat16") + j_bf16 = to_canonical_json(ir_bf16) + parsed = json.loads(j_bf16) + assert parsed["child"]["compute_dtype"] == "bfloat16" + + ir_default = Store(Compute("silu", (Accum(),)), "bfloat16") + j_default = to_canonical_json(ir_default) + assert "compute_dtype" not in j_default + + restored = _ir_from_json(j_bf16) + assert restored.child.compute_dtype == "bfloat16" + restored_default = _ir_from_json(j_default) + assert restored_default.child.compute_dtype == "float32" + + ir_mixed = Store( + Compute( + "add", + (Compute("silu", (Accum(),), compute_dtype="float32"), Compute("neg", (Accum(),), compute_dtype="bfloat16")), + compute_dtype="bfloat16", + ), + "bfloat16", + ) + j_mixed = to_canonical_json(ir_mixed) + p = json.loads(j_mixed) + assert p["child"]["compute_dtype"] == "bfloat16" + assert "compute_dtype" not in p["child"]["children"][0] + assert p["child"]["children"][1]["compute_dtype"] == "bfloat16" + + +def test_evt_ir_compute_dtype_cache_key_differs(): + """Different compute_dtype MUST produce different cache keys.""" + from magi_compiler.passes.piecewise_graph.fusion.evt_ir import Accum, Compute, Store, to_canonical_json + + ir_fp32 = Store(Compute("silu", (Accum(),), compute_dtype="float32"), "bfloat16") + ir_bf16 = Store(Compute("silu", (Accum(),), compute_dtype="bfloat16"), "bfloat16") + assert to_canonical_json(ir_fp32) != to_canonical_json(ir_bf16) + + +def test_evt_ir_compute_dtype_valid_types(): + """All floating-point ALU types are accepted as compute_dtype.""" + from magi_compiler.passes.piecewise_graph.fusion.evt_ir import Accum, Compute + + for dt in ("float32", "float16", "bfloat16"): + node = Compute("silu", (Accum(),), compute_dtype=dt) + assert node.compute_dtype == dt + + +def test_evt_ir_compute_dtype_rejects_unsupported(): + """Unsupported compute_dtype values must raise.""" + from magi_compiler.passes.piecewise_graph.fusion.evt_ir import Accum, Compute + + for bad_dt in ("float64", "int8", "int16", "int32", "int64"): + with pytest.raises(ValueError, match="Unsupported compute_dtype"): + Compute("silu", (Accum(),), compute_dtype=bad_dt) + + +def test_evt_codegen_sm80_per_node_compute_dtype(): + """SM80 codegen emits per-node element types in VisitorCompute.""" + from magi_compiler.passes.piecewise_graph.fusion.evt_ir import Accum, Compute, Store + from magi_compiler.passes.piecewise_graph.fusion.sm80.evt_codegen import render_evt_cu + + ir = Store( + Compute( + "add", + (Compute("silu", (Accum(),), compute_dtype="float32"), Compute("neg", (Accum(),), compute_dtype="bfloat16")), + compute_dtype="bfloat16", + ), + "bfloat16", + ) + src = render_evt_cu(ir, "bfloat16", "bfloat16") + assert "VisitorCompute<" in src + assert "cutlass::bfloat16_t, cutlass::bfloat16_t" in src + assert "float, float" in src + + +def test_evt_codegen_sm90_per_node_compute_dtype(): + """SM90 codegen emits per-node element types in Sm90Compute.""" + from magi_compiler.passes.piecewise_graph.fusion.evt_ir import Accum, Compute, Store + from magi_compiler.passes.piecewise_graph.fusion.sm90.evt_codegen import render_evt_cu + + ir = Store( + Compute( + "add", + (Compute("silu", (Accum(),), compute_dtype="float32"), Compute("neg", (Accum(),), compute_dtype="bfloat16")), + compute_dtype="bfloat16", + ), + "bfloat16", + ) + src = render_evt_cu(ir, "bfloat16", "bfloat16") + assert "Sm90Compute<" in src + assert "cutlass::bfloat16_t, cutlass::bfloat16_t" in src + assert "float, float" in src + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])