[Feat] Add CUTLASS matmul-epilogue fusion path #30

Open
wtr0504 wants to merge 28 commits into SandAI-org:main from wtr0504:feat/matmul_epilogue

wtr0504 (Collaborator) commented Apr 28, 2026

🗂️ PR Category

  • ✨ New Feature
  • 🚀 Optimization (performance, memory, etc.)
  • 💥 Breaking Change
  • 🐛 Bug Fix
  • 🛠️ Development / Refactoring
  • 📚 Documentation
  • 🧹 Chore (Dependencies, CI/CD, Configuration, etc.)
  • 🧪 Testing

📝 Description

Add a CUTLASS-based matmul + epilogue fusion pass that fuses aten.mm followed by elementwise chains (activations, scalar ops, bias-add, residual loads) into a single GPU kernel, eliminating intermediate global memory round-trips.

Two fusion backends

  • Generic EVT (Epilogue Visitor Tree): builds an IR tree from the FX epilogue chain and JIT-compiles a CUTLASS kernel templated on that tree. Supports unary activations (SiLU, Sigmoid, GeLU, ReLU, etc.), scalar arithmetic, 1-D bias (RowBroadcast), column scaling (ColBroadcast), and full (M, N) auxiliary loads. Codegen renders CUTLASS 3.x Sm90EVT on H100 (sm_90) and CUTLASS 2.x Sm80EVT on consumer Blackwell such as RTX 5090 (sm_120).

  • SwiGLU DualGemm: pattern-matches the canonical SwiGLU recipe (slice-stride-2 → dual clamp → scaled SiLU → multiply) and dispatches to a vendored CUTLASS DualGemm kernel that runs both GEMMs (A @ W_gate.T and A @ W_linear.T) in the same threadblock, sharing A's SMEM stages. The kernel writes the (M, N/2) output directly. It routes to the SM80 cp.async multistage path on sm_120 and to the SM90 TMA + WGMMA warp-specialized path on sm_90. The scalar constants (alpha, limit, one) are not hard-coded; they are captured from the matched FX graph.
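
For reviewers who want the matched SwiGLU recipe spelled out, here is a minimal pure-Python reference for one row of the mm output. It is a sketch under assumptions, not the kernel's code: it assumes even columns are the gate and odd columns the linear branch, that the gate is clamped from above only while the linear branch is clamped on both sides, and that `one` is added to the linear branch before the multiply. The function name and the default values for `alpha` and `limit` are hypothetical; the pass reads the real values from the graph.

```python
import math


def swiglu_reference(row, alpha=1.702, limit=7.0, one=1.0):
    """Reference SwiGLU over one row of length N; returns N/2 values.

    Assumed layout: gate and linear columns are interleaved (stride-2 slices).
    """
    if len(row) % 2 != 0:
        raise ValueError("N must be even to split into gate/linear halves")
    gate = row[0::2]    # stride-2 slice, even columns (assumed gate)
    linear = row[1::2]  # stride-2 slice, odd columns (assumed linear)
    out = []
    for g, l in zip(gate, linear):
        g = min(g, limit)               # dual clamp, part 1: gate capped above
        l = max(-limit, min(l, limit))  # dual clamp, part 2: linear two-sided
        scaled_silu = g / (1.0 + math.exp(-alpha * g))  # g * sigmoid(alpha * g)
        out.append(scaled_silu * (l + one))
    return out
```

The fused kernel never materializes the full (M, N) row: the two GEMMs produce the gate and linear values separately, so a reference like this is only useful for checking numerics against the unfused graph.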

Key design points

  • Per-node compute_dtype: the IR walker tracks _to_copy / convert_element_type nodes and stamps each Compute node with the precision active at that point (float32, bfloat16, or float16). Codegen emits per-node element types in VisitorCompute / Sm90Compute.
  • Greedy alignment: tries 128-bit loads first, falls back to 64-bit when K or N only meets 8-byte alignment. The runtime pads D's row stride to 16-byte boundaries for TMA compatibility.
  • Autotune: both DualGemm paths register multiple (TileShape, Stages) candidates and time them at first call per shape bucket; the winner is cached for all subsequent calls.
  • Build robustness: _track_build / _untrack_build + atexit + SIGTERM/SIGINT/SIGHUP handlers ensure interrupted cpp_extension.load calls don't leave stale lock files. Warm-cache fast path (_try_dlopen_prebuilt) skips cpp_extension.load entirely when the .so already exists, preventing multi-rank FileBaton hangs.
  • SM90 multi-AuxLoad: multiple (M, N) auxiliary tensors are loaded via inline ld.global (Sm90AuxLoad<0>) — no SMEM staging needed — enabling patterns like mm + R1 + R2 + R3.
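
The greedy alignment and D-stride padding rules above reduce to a little byte arithmetic. The sketch below is illustrative only: the helper names are hypothetical, a 2-byte element (bf16/fp16) is assumed by default, and returning `None` for shapes that meet neither alignment is an assumption about how the misaligned case is declined.

```python
def pick_alignment_elems(k, n, elem_bytes=2):
    """Return the vector width in elements, or None if neither width fits.

    Tries 128-bit (16-byte) accesses first, then 64-bit (8-byte).
    """
    for vec_bytes in (16, 8):
        k_ok = (k * elem_bytes) % vec_bytes == 0
        n_ok = (n * elem_bytes) % vec_bytes == 0
        if k_ok and n_ok:
            return vec_bytes // elem_bytes
    return None  # assumed: the pass skips fusion for such shapes


def padded_row_stride(n, elem_bytes=2, boundary=16):
    """Round D's row stride up to a multiple of `boundary` bytes.

    Returns the stride in elements; assumes boundary % elem_bytes == 0.
    """
    row_bytes = n * elem_bytes
    padded_bytes = -(-row_bytes // boundary) * boundary  # ceiling division
    return padded_bytes // elem_bytes
```

For example, with bf16 and N = 4100 the row is 8200 bytes, which is 8-byte but not 16-byte aligned, so this sketch selects 4-element vectors and pads the row stride to 4104 elements.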

Files added / changed

| Area | Files | What |
| --- | --- | --- |
| IR | evt_ir.py | Dataclass IR (Accum, Compute, Store, RowBroadcast, ColBroadcast, AuxLoad) + canonical JSON serialization |
| FX pass | matmul_epilogue_fusion.py | Graph walker, epilogue chain absorption, SwiGLU pattern matching, B-layout classification |
| Runtime | evt_runtime.py | torch.library op, JIT compile cache, DualGemm loader, dispatch fast-cache, build cleanup |
| SM80 codegen | sm80/evt_codegen.py | CUTLASS 2.x Sm80EVT .cu renderer |
| SM90 codegen | sm90/evt_codegen.py | CUTLASS 3.x Sm90EVT .cu renderer with Sm90Compute + Sm90AuxLoad |
| SM80 DualGemm | sm80/cutlass_kernels/swiglu_one_stage.cu | Vendored DualGemm + SwigluCombine epilogue + autotune runner |
| SM90 DualGemm | sm90/cutlass_kernels/swiglu_one_stage.cu | TMA + WGMMA DualGemm + SwigluCombine + autotune runner |
| SM90 device wrapper | sm90/cutlass_kernels/hopper_dual_gemm/ | Vendored Sm90DualGemm with LayoutTraits for 2.x→3.x layout translation |
| Shared | common/cutlass_kernels/swiglu_combine.h, common/codegen_shared.py | SwigluCombine functor, shared codegen utilities |
| Config | config.py | enable_mm_epilogue_fusion flag in PassConfig |
| Infra | Dockerfile, .pre-commit-config.yaml | CUTLASS install, copyright hook |
| Tests | test_matmul_epilogue_fusion.py | 30+ tests: positive (activations, scalar ops, bias, AuxLoad, SwiGLU), negative (escape, misalign, bare mm), out_dtype matrix, compute_dtype, D-stride padding, SM90 parity |
| Tests | test_build_cleanup.py | 5 tests for the _track_build/_untrack_build + signal-handler cleanup mechanism |
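
To illustrate what a dataclass IR with canonical JSON serialization can look like, here is a small stdlib-only sketch. Only the class names Accum, Compute, Store, and RowBroadcast come from the table; every field name, the `kind` tag, and the idea of hashing the JSON into a compile-cache key are assumptions made for illustration, not the contents of evt_ir.py.

```python
import hashlib
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class Accum:
    kind: str = "accum"  # the GEMM accumulator (leaf node)


@dataclass(frozen=True)
class RowBroadcast:
    name: str  # hypothetical: name of the 1-D bias argument
    kind: str = "row_broadcast"


@dataclass(frozen=True)
class Compute:
    op: str        # e.g. "add", "silu"
    inputs: tuple  # child nodes
    compute_dtype: str = "float32"  # per-node precision
    kind: str = "compute"


@dataclass(frozen=True)
class Store:
    value: object  # root of the epilogue expression
    out_dtype: str = "bfloat16"
    kind: str = "store"


def canonical_json(node):
    """Serialize a tree deterministically: sorted keys, no whitespace."""
    return json.dumps(asdict(node), sort_keys=True, separators=(",", ":"))


def cache_key(node):
    """Hypothetical compile-cache key derived from the canonical JSON."""
    return hashlib.sha256(canonical_json(node).encode()).hexdigest()[:16]


# silu(mm + bias), stored as bfloat16
tree = Store(Compute("silu", (Compute("add", (Accum(), RowBroadcast("bias"))),)))
```

Because the serialization is deterministic, two structurally identical epilogues map to the same key, while changing a single node's `compute_dtype` yields a different one. That property is what makes a canonical form suitable for keying a JIT cache.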

Comment thread magi_compiler/passes/full_graph/remove_useless_ops.py Outdated
Comment thread magi_compiler/utils/device.py
Comment thread Dockerfile
Comment thread magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py Outdated
jiahy0825 linked an issue on May 8, 2026 that may be closed by this pull request
23 tasks
Comment thread magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py Outdated
Comment thread magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py Outdated
wtr0504 force-pushed the feat/matmul_epilogue branch from 548ae66 to 7a3bed3 on May 19, 2026 06:57
wtr0504 force-pushed the feat/matmul_epilogue branch from b92c6e7 to d546b70 on May 28, 2026 07:31
wtr0504 changed the title from "[Feat] Add CUTLASS matmul-epilogue fusion path for sm_120" to "[Feat] Add CUTLASS matmul-epilogue fusion path" on May 28, 2026
2 participants