[Feat] Add CUTLASS matmul-epilogue fusion path#30
Open
wtr0504 wants to merge 28 commits into
Open
Conversation
23 tasks
jiahy0825
reviewed
May 8, 2026
23 tasks
23 tasks
jiahy0825
reviewed
May 8, 2026
548ae66 to
7a3bed3
Compare
b92c6e7 to
d546b70
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
🗂️ PR Category
📝 Description
Add a CUTLASS-based matmul + epilogue fusion pass that fuses aten.mm followed by elementwise chains (activations, scalar ops, bias-add, residual loads) into a single GPU kernel, eliminating intermediate global memory round-trips.
Two fusion backends
Generic EVT (Epilogue Visitor Tree) — builds an IR tree from the FX epilogue chain and JIT-compiles a CUTLASS kernel templated on that tree. Supports unary activations (SiLU, Sigmoid, GeLU, ReLU, etc.), scalar arithmetic, 1-D bias (RowBroadcast), column scaling (ColBroadcast), and full (M, N) auxiliary loads. Codegen renders to CUTLASS 3.x
Sm90EVTon H100 (sm_90) and CUTLASS 2.xSm80EVTon RTX 5090 / Blackwell consumer (sm_120).SwiGLU DualGemm — pattern-matches the canonical SwiGLU recipe (
slice-stride-2 → dual clamp → scaled SiLU → multiply) and dispatches to a vendored CUTLASSDualGemmkernel that runs both GEMMs (A @ W_gate.TandA @ W_linear.T) in the same threadblock, sharing A's SMEM stages. Writes(M, N/2)directly. Routes to SM80cp.asyncmultistage on sm_120 and to SM90 TMA + WGMMA warp-specialized path on sm_90. Static constants (alpha,limit,one) are captured dynamically from the FX graph.Key design points
compute_dtype: the IR walker tracks_to_copy/convert_element_typenodes and stamps eachComputenode with the precision active at that point (float32,bfloat16, orfloat16). Codegen emits per-node element types inVisitorCompute/Sm90Compute._track_build/_untrack_build+atexit+ SIGTERM/SIGINT/SIGHUP handlers ensure interruptedcpp_extension.loadcalls don't leave stale lock files. Warm-cache fast path (_try_dlopen_prebuilt) skipscpp_extension.loadentirely when the.soalready exists, preventing multi-rank FileBaton hangs.(M, N)auxiliary tensors are loaded via inlineld.global(Sm90AuxLoad<0>) — no SMEM staging needed — enabling patterns likemm + R1 + R2 + R3.Files added / changed
evt_ir.pymatmul_epilogue_fusion.pyevt_runtime.pytorch.libraryop, JIT compile cache, DualGemm loader, dispatch fast-cache, build cleanupsm80/evt_codegen.py.curenderersm90/evt_codegen.py.curenderer with Sm90Compute + Sm90AuxLoadsm80/cutlass_kernels/swiglu_one_stage.cusm90/cutlass_kernels/swiglu_one_stage.cusm90/cutlass_kernels/hopper_dual_gemm/common/cutlass_kernels/swiglu_combine.h,common/codegen_shared.pyconfig.pyenable_mm_epilogue_fusionflag in PassConfigDockerfile,.pre-commit-config.yamltest_matmul_epilogue_fusion.pytest_build_cleanup.py_track_build/_untrack_build+ signal-handler cleanup mechanism