Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
cdf5979
add triton matmul fusion
wtr0504 Apr 11, 2026
20054d1
add cute kernel
wtr0504 Apr 13, 2026
292e5cd
[Feat] Add CUTLASS matmul-epilogue fusion path for sm_120
wtr0504 Apr 28, 2026
ce3f7b4
add cutlass install in Dockerfile & update
wtr0504 Apr 29, 2026
4474bbd
add enable_mm_epilogue_fusion & chore
wtr0504 Apr 29, 2026
bd5a2e6
chore
wtr0504 Apr 29, 2026
2239be7
update .github/codestyle/copyright.hook
wtr0504 Apr 30, 2026
68ecbee
Fix: unify Alignment and padding D Tensor
wtr0504 May 7, 2026
efd5193
add more flexible align for matrix
wtr0504 May 9, 2026
0868864
refactor & add sm90 c++ code
wtr0504 May 15, 2026
0a3d89f
add sm90 multi-extra
wtr0504 May 18, 2026
0e77026
refactor & handle type conversion in epilogue
wtr0504 May 19, 2026
75004b4
fix static param handling in swiglu
wtr0504 May 19, 2026
7a0b3b5
refactor & fix sm90 ldd & chore
wtr0504 May 22, 2026
0b8082c
Improve cleanup handling for interrupted C++ compilation
wtr0504 May 23, 2026
4ea07f6
chore & add ci test
wtr0504 May 23, 2026
6535f96
Update Dockerfile
wtr0504 May 25, 2026
0f65438
Update README.md
wtr0504 May 25, 2026
3242d8d
fix matmul epilogue fusion correctness
wtr0504 May 27, 2026
16a1679
change cutlass root path
wtr0504 May 27, 2026
0ddef80
chore
wtr0504 May 27, 2026
ded0b34
chore
wtr0504 May 27, 2026
468e936
rm some tests
wtr0504 May 27, 2026
3e65cbd
rm some tests
wtr0504 May 27, 2026
97ddf50
chore
wtr0504 May 27, 2026
8b84c6f
rm some tests
wtr0504 May 28, 2026
d546b70
rm some tests
wtr0504 May 28, 2026
3587de6
chore
wtr0504 May 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/codestyle/copyright.hook
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_comment_mark(path):
if lang_type.search(path) is not None:
return "#"

lang_type=re.compile(r"\.(h|c|hpp|cc|cpp|cu|go|cuh|proto)$")
lang_type=re.compile(r"\.(h|c|hpp|hxx|cc|cpp|cxx|cu|go|cuh|proto)$")
if lang_type.search(path) is not None:
return "//"

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
name: copyright_checker
entry: python3 ./.github/codestyle/copyright.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$
files: \.(c|cc|cxx|cpp|cu|cuh|h|hpp|hxx|proto|py|sh)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
Expand Down
79 changes: 79 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,25 @@ FROM nvcr.io/nvidia/pytorch:25.10-py3

ARG FLASH_ATTENTION_COMMIT_ID="b613d9e2c8475945baff3fd68f2030af1b890acf"

# CUTLASS β€” source is always cloned (the magi_compiler EVT-fusion path
Comment thread
wtr0504 marked this conversation as resolved.
# JIT-includes its headers and our /usr/local/cutlass tree is the readable
# reference checkout). The CMake-driven profiler/library is compiled
# only for supported targets; every other arch gets headers only.
#
# Supported NVCC arch strings (CUTLASS_NVCC_ARCHS):
# 90a β€” Hopper (H100, compute_cap 9.x, WGMMA/TMA)
# 120a β€” consumer Blackwell (RTX 50 series, compute_cap 12.x)
#
# Override behaviour with build args:
# --build-arg CUTLASS_BUILD=yes|no|auto
# yes β€” force cmake configure (requires CUTLASS_NVCC_ARCHS or a GPU)
# no β€” skip cmake even if a supported GPU is present
# auto β€” (default) compile iff nvidia-smi reports 9.x or 12.x
# --build-arg CUTLASS_NVCC_ARCHS=90a|120a
ARG CUTLASS_COMMIT_ID="f74fea9ce35868d3ae9f8d1dce1969d7250d3f90"
ARG CUTLASS_BUILD="auto"
ARG CUTLASS_NVCC_ARCHS=""

ENV PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1 \
PYTHONDONTWRITEBYTECODE=1
Expand All @@ -18,6 +37,7 @@ RUN --mount=type=secret,id=http_proxy,required=false \
ca-certificates \
git \
build-essential \
cmake \
ninja-build && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
Expand All @@ -42,6 +62,65 @@ RUN --mount=type=secret,id=http_proxy,required=false \
cp /tmp/flash-attention/hopper/flash_attn_interface.py ${python_path}/flash_attn_3/ && \
rm -rf /tmp/flash-attention


RUN --mount=type=secret,id=http_proxy,required=false \
--mount=type=secret,id=https_proxy,required=false \
export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \
export https_proxy="$(cat /run/secrets/https_proxy 2>/dev/null || true)" && \
mkdir -p /usr/local/cutlass && \
cd /usr/local/cutlass && \
git init -q && \
git remote add origin https://github.com/NVIDIA/cutlass.git && \
git fetch origin ${CUTLASS_COMMIT_ID} --depth 1 && \
git checkout ${CUTLASS_COMMIT_ID} && \
(git submodule update --init --recursive --depth 1 --jobs 8 || \
git submodule update --init --recursive --depth 1 --jobs 1)


RUN set -eu; \
_cutlass_arch_from_gpu() { \
if ! command -v nvidia-smi >/dev/null 2>&1; then return 1; fi; \
cap="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -n1 | tr -d ' ')"; \
case "${cap}" in \
9.*) echo "90a" ;; \
12.*) echo "120a" ;; \
*) return 1 ;; \
esac; \
}; \
if [ -n "${CUTLASS_NVCC_ARCHS}" ]; then \
NVCC_ARCHS="${CUTLASS_NVCC_ARCHS}"; \
echo "[CUTLASS] Using CUTLASS_NVCC_ARCHS=${NVCC_ARCHS} (build-arg override)."; \
elif arch="$(_cutlass_arch_from_gpu)"; then \
NVCC_ARCHS="${arch}"; \
echo "[CUTLASS] nvidia-smi β†’ CUTLASS_NVCC_ARCHS=${NVCC_ARCHS}."; \
else \
NVCC_ARCHS=""; \
fi; \
case "${CUTLASS_BUILD}" in \
no) echo "[CUTLASS] CUTLASS_BUILD=no β€” skipping cmake configure."; exit 0 ;; \
yes) \
if [ -z "${NVCC_ARCHS}" ]; then \
echo "[CUTLASS] CUTLASS_BUILD=yes but no arch: set CUTLASS_NVCC_ARCHS=90a|120a or build on a 9.x/12.x GPU."; \
exit 1; \
fi; \
DO_BUILD=1 ;; \
auto) \
if [ -z "${NVCC_ARCHS}" ]; then \
echo "[CUTLASS] No sm_90/sm_120 GPU and no CUTLASS_NVCC_ARCHS β€” skipping cmake (headers still available)."; \
exit 0; \
fi; \
DO_BUILD=1 ;; \
*) echo "[CUTLASS] Unknown CUTLASS_BUILD=${CUTLASS_BUILD}"; exit 1 ;; \
esac; \
case "${NVCC_ARCHS}" in \
90a|120a) ;; \
*) echo "[CUTLASS] Unsupported CUTLASS_NVCC_ARCHS=${NVCC_ARCHS} (expected 90a or 120a)."; exit 1 ;; \
esac; \
[ -n "${DO_BUILD:-}" ] && cd /usr/local/cutlass && \
export CUDACXX="${CUDA_INSTALL_PATH:-${CUDA_HOME:-/usr/local/cuda}}/bin/nvcc" && \
mkdir -p build && cd build && \
cmake .. -DCUTLASS_NVCC_ARCHS="${NVCC_ARCHS}"

RUN --mount=type=secret,id=http_proxy,required=false \
--mount=type=secret,id=https_proxy,required=false \
export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \
Expand Down
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ pip install -r requirements.txt
# Step 4 β€” Install MagiCompiler (pick one)
pip install . # End users (recommended)
# pip install -e . --no-build-isolation --config-settings editable_mode=compat # Developer / editable

# Step 5 (optional) β€” Install CUTLASS for matmul epilogue fusion
# Required for the CUTLASS-based matmul + epilogue fusion pass (sm_90 / sm_120).
# Without CUTLASS the compiler still works but skips this optimization.
git clone --depth 1 https://github.com/NVIDIA/cutlass.git /usr/local/cutlass
# Or specify a custom path:
# git clone --depth 1 https://github.com/NVIDIA/cutlass.git /your/path
# export MAGI_CUTLASS_ROOT=/your/path
export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
mkdir /usr/local/cutlass/build && cd /usr/local/cutlass/build
cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GPU architecture
# cmake .. -DCUTLASS_NVCC_ARCHS=120a # compiles for NVIDIA consumer Blackwell (RTX 50 series)
```

---
Expand Down
28 changes: 28 additions & 0 deletions magi_compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ class PassConfig(BaseModel):
# TODO: Add sequence parallelism pass and async TP pass.
# TODO: Add Ulysses overlap pass.
enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.")
enable_mm_epilogue_fusion: bool = Field(
False,
description=(
"Whether to enable the matmul + elementwise epilogue fusion pass. "
"On RTX 5090 (sm_120) this lowers fused chains to a CUTLASS Sm80EVT "
"kernel via the fusion.MatmulEvtEpilogueFusionPass; on H100 "
"(sm_90) the swiglu sub-path additionally uses the native Sm90 "
"TMA + WGMMA DualGemm. The pass is a no-op on older architectures "
"regardless of this flag, but the flag still controls whether it "
"is registered at all."
),
)

@property
def hash(self) -> str:
Expand Down Expand Up @@ -141,6 +153,14 @@ class OffloadConfig(BaseModel):
bandwidth_safety_factor: float = Field(0.9, description="The safety factor for the H2D bandwidth.")


def _find_cutlass_root() -> str:
"""Return the CUTLASS source root, or empty string if not found."""
path = os.environ.get("MAGI_CUTLASS_ROOT", "/usr/local/cutlass")
if os.path.isdir(path):
return path
return ""


class CompileConfig(BaseSettings):
"""Top-level configuration consumed by ``magi_compile`` and the MagiCompiler backend.

Expand Down Expand Up @@ -172,6 +192,10 @@ class CompileConfig(BaseSettings):
default=os.path.expanduser("~/.cache/magi_compiler"),
description="Root directory for persisting compiled artifacts and debug dumps.",
)
cutlass_root: str = Field(
default_factory=_find_cutlass_root,
description="Path to the CUTLASS source tree. Default: $MAGI_CUTLASS_ROOT or /usr/local/cutlass.",
)

# ---- Compilation mode ----
aot: bool = Field(
Expand Down Expand Up @@ -234,6 +258,10 @@ class CompileConfig(BaseSettings):
),
)

@property
def has_cutlass(self) -> bool:
return bool(self.cutlass_root)

@property
def hash(self) -> str:
return compute_hash(self.model_dump(mode="json"))
Expand Down
2 changes: 2 additions & 0 deletions magi_compiler/passes/full_graph/full_graph_pass_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from ...magi_depyf.timeline import observe_lifecycle
from .remove_item import RemoveItemPass
from .remove_useless_ops import EliminateIdentityViewCastPass
from .replace_sage_atten import ReplaceSageAttentionPass


Expand All @@ -30,6 +31,7 @@ def __init__(self, pass_config):
if self.pass_config.enable_sage_attn:
self.passes.append(ReplaceSageAttentionPass())
self.passes.append(RemoveItemPass())
self.passes.append(EliminateIdentityViewCastPass())

@observe_lifecycle("full_graph_manager")
def __call__(self, gm: torch.fx.GraphModule):
Expand Down
116 changes: 116 additions & 0 deletions magi_compiler/passes/full_graph/remove_useless_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch._inductor.fx_passes.pre_grad

from ...magi_depyf.timeline import emit_pass_lifecycle
from ..pass_base import MagiInductorPass


class EliminateIdentityViewCastPass(MagiInductorPass):
"""
Remove useless convert, view, reshape operations.
When their input already has the target type and shape, these operations are redundant.
"""

TARGET_METHODS = {
"view",
"reshape",
"to",
"type",
"contiguous",
"flatten",
"permute",
"transpose",
"t",
"unsqueeze",
"squeeze",
"expand",
"repeat",
"bfloat16",
"float",
"half",
"int",
"long",
"short",
"double",
"bool",
"byte",
}

@staticmethod
def _get_tensor_info(node: torch.fx.Node):
# Get tensor info from example_value
if "example_value" in node.meta:
val = node.meta["example_value"]
if isinstance(val, torch.Tensor):
return val.shape, val.dtype, val.stride()
elif isinstance(val, (list, tuple)) and len(val) > 0 and isinstance(val[0], torch.Tensor):
return val[0].shape, val[0].dtype, val[0].stride()

return None, None, None

def is_applicable(self, graph: torch.fx.Graph, shape: int | None = None) -> bool:
for node in graph.nodes:
if node.op == "call_method" and node.target in self.TARGET_METHODS:
return True
return False

@emit_pass_lifecycle
def __call__(self, graph: torch.fx.Graph):
nodes_to_remove = []

for node in graph.nodes:
is_target_method = node.op == "call_method" and node.target in self.TARGET_METHODS
if not is_target_method:
continue

# Need at least one argument (the input tensor)
if not node.args or not isinstance(node.args[0], torch.fx.Node):
continue

input_node = node.args[0]

node_shape, node_dtype, node_stride = self._get_tensor_info(node)
input_shape, input_dtype, input_stride = self._get_tensor_info(input_node)
if node_shape is None or input_shape is None:
continue
if node_dtype is None or input_dtype is None:
continue
# Some ops or metadata might not have stride properly captured,
# but if they do, we should require them to match to be totally safe against contiguous-forcing ops.
if node_stride is not None and input_stride is not None and node_stride != input_stride:
continue

# Check if shape and dtype match exactly
if node_shape == input_shape and node_dtype == input_dtype:
# For _to_copy, ensure we are not changing memory format or device or other properties implicitly,
# but typically in full graph if shape and dtype match, and it's on the same device, it's safe.
# Let's also check device just in case if it's available.
def get_device(n):
if "example_value" in n.meta and isinstance(n.meta["example_value"], torch.Tensor):
return n.meta["example_value"].device

node_device = get_device(node)
input_device = get_device(input_node)
if node_device is not None and input_device is not None and node_device != input_device:
continue

# Replace uses
node.replace_all_uses_with(input_node)
nodes_to_remove.append(node)

for node in nodes_to_remove:
graph.erase_node(node)
13 changes: 13 additions & 0 deletions magi_compiler/passes/piecewise_graph/fusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 changes: 15 additions & 0 deletions magi_compiler/passes/piecewise_graph/fusion/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) 2026 SandAI. All Rights Reserved.
Loading