[Cutlass SM90] Per-group aux TMA descriptor update for grouped GEMM + Gated-SwiGLU example #3256
Open
Butterfingrz wants to merge 1 commit into
Motivation
SM90 persistent grouped GEMM is already a critical high-performance path for MoE workloads: PyTorch exposes BF16 `grouped_mm` for CUDA grouped GEMM, and TorchTitan's Llama/MoE path depends directly on `torch._grouped_mm`. In practice this makes CUTLASS SM90 grouped GEMM a natural backend to extend rather than bypass.

However, the existing CUTLASS epilogue templates assume a 1:1 accumulator-to-output mapping, and the stock aux-store path is tied to a static, shape-preserving descriptor. This prevents fusing shape-changing activations such as SwiGLU, where the GEMM produces `[M, 2I]` but the activation output is `[M, I]`, especially when a persistent grouped kernel moves across groups with different pointers and problem shapes.

This PR adds a generic per-group aux TMA descriptor update path to the SM90 persistent grouped GEMM TMA warp-specialized epilogue, and demonstrates it with a Hopper grouped GEMM + Gated-SwiGLU end-to-end example. The framework-level changes are activation-agnostic: SwiGLU is simply the first consumer of the mechanism and lives entirely under `examples/57_hopper_grouped_gemm/`.

Summary
- A per-group aux TMA descriptor update path that follows the `C` / `D` per-group descriptor lifecycle and is gated on a compile-time SFINAE trait, so that existing instantiations see zero behavioral and code-size impact.
- A `Sm90GatedSwiGLUStoreTma` EVT root node and a complete `57_hopper_bias_swiglu_grouped_gemm_bf16` driver demonstrating Bias + SwiGLU fused into grouped GEMM (bf16), validated on both cooperative and pingpong schedules.
Framework changes (4 include files, +203 / −5)
- `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp` (+25)
  - `ConsumerStoreArgs` gains an `aux_store_tensormap` field that threads the per-CTA × per-warpgroup aux gmem tensormap slot pointer through to every EVT callback.
  - `Sm90TreeVisitor` adds `get_aux_tma_descriptor()` and `aux_tensormaps_replace()` that forward to the root `NodeOp`, so the collective epilogue can drive aux descriptor updates without dissecting the EVT.
- `include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp` (+128 / −5)
  - `detail::has_aux_tma_store_v<FusionCallbacks>` recursively scans the EVT tree; the aux pipeline activates whenever any node declares `static constexpr bool HasAuxTmaStore = true`.
  - `TensorMapStorage` adds `smem_tensormap_aux[NumEpilogueWarpGroups]`, and the gmem workspace is extended in lockstep (`aux_base = NumEpilogueWarpGroups + (is_void_v<ElementC> ? 0 : 1)`; `get_workspace_size()` and `to_underlying_arguments()` updated accordingly).
  - Four aux methods follow the `C` / `D` "init / replace / publish / synchronize" choreography, reusing the existing `cute::tma_descriptor_*` primitives together with `tma_desc_commit_group` / `wait_group`:
    - `aux_store_init`
    - `aux_tensormaps_perform_update`
    - `aux_tensormaps_cp_fence_release`
    - `aux_tensormaps_fence_acquire`
  - `store()` takes an additional `aux_store_tensormap` parameter and forwards it into `ConsumerStoreArgs`, allowing the application visitor to inject the runtime descriptor in `tma_store()` via `.with(runtime_desc)`.
- `include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp` (+25) and `..._pingpong.hpp` (+25)
  - At the `did_batch_change` boundary, the `D`-path `init / perform_update / cp_fence_release / fence_acquire / store` call sequence is duplicated for the aux path.
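To make the per-group call sequence easier to picture, here is a minimal host-only sketch of the control flow around a `did_batch_change` boundary. It is not CUTLASS code: `MockEpilogue`, `run_store_loop` and the logged step names are hypothetical stand-ins that only echo the step names listed above, and the exact ordering inside the real kernels may differ.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical host-side recorder. The step names mirror the list above;
// they are not the CUTLASS signatures.
struct MockEpilogue {
  std::vector<std::string> log;
  void step(const std::string& path, const std::string& what) {
    log.push_back(path + "." + what);
  }
};

// Walks a sequence of work tiles, each tagged with its group index, and
// records which descriptor steps would run for the D path and the aux path.
template <bool HasAux>
std::vector<std::string> run_store_loop(const std::vector<int>& tile_groups) {
  MockEpilogue epi;
  epi.step("D", "init");
  if constexpr (HasAux) epi.step("aux", "init");

  int prev_group = -1;  // forces a descriptor update on the first tile
  for (int group : tile_groups) {
    assert(group >= 0);
    const bool did_batch_change = (group != prev_group);
    if (did_batch_change) {
      // Rewrite the descriptor for the new group, then publish it.
      epi.step("D", "perform_update");
      epi.step("D", "cp_fence_release");
      if constexpr (HasAux) {
        epi.step("aux", "perform_update");
        epi.step("aux", "cp_fence_release");
      }
      // The storing threads must observe the published descriptors.
      epi.step("D", "fence_acquire");
      if constexpr (HasAux) epi.step("aux", "fence_acquire");
    }
    // A single store call; in the PR it also receives the aux tensormap.
    epi.step("epilogue", "store");
    prev_group = group;
  }
  return epi.log;
}
```

Calling `run_store_loop<true>({0, 0, 1})` records descriptor updates only for the first and third tiles, while `run_store_loop<false>` records no `aux.*` steps at all.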
Zero-overhead guarantee
When no EVT node declares `HasAuxTmaStore = true`, `NumAuxTmaTensors == 0` and the entire aux pipeline is short-circuited by `if constexpr`. Existing grouped GEMM instantiations observe zero behavioral difference and zero additional code/state.
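The compile-time gating can be illustrated with a small standalone C++17 sketch. It assumes a simplified tree type; `Tree`, `declares_aux`, `tree_has_aux_v`, `num_aux_slots` and the three node structs are hypothetical names for illustration, not the PR's `detail::has_aux_tma_store_v` implementation.

```cpp
#include <cassert>
#include <type_traits>

// Detects whether a node declares `static constexpr bool HasAuxTmaStore`.
template <class Node, class = void>
struct declares_aux : std::false_type {};

// Chosen by SFINAE only when Node::HasAuxTmaStore is a valid expression.
template <class Node>
struct declares_aux<Node, std::void_t<decltype(Node::HasAuxTmaStore)>>
    : std::bool_constant<Node::HasAuxTmaStore> {};

// Hypothetical stand-in for an EVT tree: a root op plus child nodes.
template <class NodeOp, class... Children>
struct Tree {};

// Leaf case: a plain node is checked directly.
template <class T>
struct tree_has_aux : declares_aux<T> {};

// Recursive case: true if the root op or any child subtree declares the flag.
template <class NodeOp, class... Children>
struct tree_has_aux<Tree<NodeOp, Children...>>
    : std::bool_constant<declares_aux<NodeOp>::value ||
                         (tree_has_aux<Children>::value || ...)> {};

template <class T>
inline constexpr bool tree_has_aux_v = tree_has_aux<T>::value;

// Illustrative nodes: only GatedStore opts in.
struct PlainStore {};
struct BiasAdd {};
struct GatedStore { static constexpr bool HasAuxTmaStore = true; };

// The aux branch is discarded at compile time when no node opts in.
template <class Fusion>
int num_aux_slots(int epilogue_warp_groups) {
  assert(epilogue_warp_groups >= 0);
  if constexpr (tree_has_aux_v<Fusion>) {
    return epilogue_warp_groups;  // one aux tensormap slot per warp group
  } else {
    return 0;                     // no aux state
  }
}

static_assert(!tree_has_aux_v<Tree<PlainStore, BiasAdd>>);
static_assert(tree_has_aux_v<Tree<PlainStore, Tree<BiasAdd, GatedStore>>>);
```

Because the branch is selected with `if constexpr`, a fusion type without the flag never instantiates the aux code at all, which is the property the guarantee above relies on.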
Application example (`examples/57_hopper_grouped_gemm/`, +1265)
- `sm90_gated_swiglu_store_tma.hpp` (+428): custom EVT root `Sm90GatedSwiGLUStoreTma`
  - Declares `HasAuxTmaStore = true` and implements `get_aux_tma_descriptor()` and `aux_tensormaps_replace()`. This node takes over the stock `Sm90AuxStore` role, additionally handling the `N → N/2` shape change.
  - `visit()`: full-width register accumulation.
  - `postreduce()`: SwiGLU computed with `tanh.approx.bf16x2` PTX, scattered into half-width SMEM.
  - `tma_store()`: emitted via `.with(kernel_swiglu_tensormap)`, matching the `D`-path `tma_store_d.with(store_tensormap)` idiom.
  - Activation with `α = 1.702` and `limit = 7.0`:
    - `g_clip = min(gate, limit)`
    - `u_clip = clamp(up, ±limit)`
    - `s = 0.5 * (1 + tanh(α · g_clip / 2))`
    - `out = g_clip · s · u_clip + g_clip · s`
- `57_hopper_bias_swiglu_grouped_gemm_bf16.cu` (+821): single-file driver
  - `EpiTile` / `SmemLayoutAtomD` / `CopyOpR2S` / `StagesD`; phase 2 assembles the real EVT tree `Sm90EVT<SwiGLUStore, ...>`.
  - The `D` path carries the bias-added pre-activation `[M_i, 2I]` (usable as a backward-pass input); the aux TMA path carries the SwiGLU output `[M_i, I]`. The two stores do not block each other.
  - `{Cooperative, Pingpong} × host_problem_shapes_available ∈ {true, false}`: 4 configurations.
- `CMakeLists.txt` (+16): registers the new target and ctest case.

Design notes
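For reference, the activation formula given for the example node can be written as a scalar host-side function. This is a sketch in `float` using `std::tanh`, not the `tanh.approx.bf16x2` device path; `GatedSwiGLUParams` and `gated_swiglu_reference` are hypothetical names. How gate and up columns are arranged inside the `2I` dimension is not stated here, so the function takes them as separate scalars.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Constants taken from the description: alpha = 1.702, limit = 7.0.
struct GatedSwiGLUParams {
  float alpha = 1.702f;
  float limit = 7.0f;
};

// Scalar reference for one (gate, up) pair:
//   g_clip = min(gate, limit)
//   u_clip = clamp(up, -limit, +limit)
//   s      = 0.5 * (1 + tanh(alpha * g_clip / 2))   // equals sigmoid(alpha * g_clip)
//   out    = g_clip * s * u_clip + g_clip * s
inline float gated_swiglu_reference(float gate, float up,
                                    GatedSwiGLUParams p = {}) {
  assert(p.limit > 0.0f);
  const float g_clip = std::min(gate, p.limit);
  const float u_clip = std::clamp(up, -p.limit, p.limit);
  const float s = 0.5f * (1.0f + std::tanh(p.alpha * g_clip * 0.5f));
  return g_clip * s * u_clip + g_clip * s;
}
```

The last line is algebraically `g_clip · s · (u_clip + 1)`, so an input with `up = -1` yields zero regardless of the gate value, which makes a convenient sanity check when comparing against device output.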
- The framework-facing hooks are `HasAuxTmaStore`, `get_aux_tma_descriptor()`, and `aux_tensormaps_replace()`. All SwiGLU-specific semantics (`N → N/2`, `tanh.approx.bf16x2`, the `ptr_swiglu_out_array` indexing scheme) live in the example. Other shape-changing outputs (softmax + top-k mask, per-output quantization scales, etc.) can reuse the same hooks unchanged.
- `Sm90AuxStore` is left untouched. Its `Arguments` use a single pointer + single stride and assume that the input shape and the aux output shape match; SwiGLU violates that precondition. Introducing `Sm90GatedSwiGLUStoreTma` as a new node keeps the existing class clean.

Testing
    ./build/release/examples/57_hopper_grouped_gemm/57_hopper_bias_swiglu_grouped_gemm \
      --m=1024 --n=4096 --k=2048 --groups=8 \
      2>&1 | grep -E 'schedule|Disposition|Avg runtime|TFLOPS'

Result: all four schedule variants pass correctness verification. TFLOPS (GEMM) counts only the grouped GEMM (2 · Σ M·N·K); the fused bias-add and SwiGLU activation are included in the timing but intentionally not counted in the flop budget, matching the convention used by the upstream 57_hopper_grouped_gemm example.

    ┌─────────────┬─────────────────────┬─────────────┬──────────────────┬───────────────┐
    │ Schedule    │ Host problem shapes │ Disposition │ Avg runtime (ms) │ TFLOPS (GEMM) │
    ├─────────────┼─────────────────────┼─────────────┼──────────────────┼───────────────┤
    │ Cooperative │ available           │ Passed      │         0.204138 │       673.266 │
    ├─────────────┼─────────────────────┼─────────────┼──────────────────┼───────────────┤
    │ Cooperative │ unavailable         │ Passed      │         0.206483 │       665.618 │
    ├─────────────┼─────────────────────┼─────────────┼──────────────────┼───────────────┤
    │ Pingpong    │ available           │ Passed      │         0.197277 │       696.681 │
    ├─────────────┼─────────────────────┼─────────────┼──────────────────┼───────────────┤
    │ Pingpong    │ unavailable         │ Passed      │         0.196608 │       699.051 │
    └─────────────┴─────────────────────┴─────────────┴──────────────────┴───────────────┘

    *** Cooperative schedule (bias + SwiGLU) ***
      Disposition   : Passed
      Avg runtime   : 0.204138 ms
      TFLOPS (GEMM) : 673.266
    *** Cooperative schedule (host problem shapes unavailable) ***
      Disposition   : Passed
      Avg runtime   : 0.206483 ms
      TFLOPS (GEMM) : 665.618
    *** Pingpong schedule (bias + SwiGLU) ***
      Disposition   : Passed
      Avg runtime   : 0.197277 ms
      TFLOPS (GEMM) : 696.681
    *** Pingpong schedule (host problem shapes unavailable) ***
      Disposition   : Passed
      Avg runtime   : 0.196608 ms
      TFLOPS (GEMM) : 699.051
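The TFLOPS column can be reproduced from the stated convention, 2 · Σ M·N·K divided by the average runtime. The sketch below assumes all 8 groups share the command-line sizes (M=1024, N=4096, K=2048), an assumption that matches the reported figures to within rounding; `ProblemShape`, `grouped_gemm_flops` and `gemm_tflops` are hypothetical helper names, not functions from the example driver.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// One GEMM problem per group.
struct ProblemShape {
  std::int64_t m, n, k;
};

// Flop budget of the grouped GEMM only: 2 * sum(M * N * K).
// The fused bias-add and activation are deliberately not counted.
inline double grouped_gemm_flops(const std::vector<ProblemShape>& groups) {
  double flops = 0.0;
  for (const auto& p : groups) {
    flops += 2.0 * static_cast<double>(p.m) * static_cast<double>(p.n) *
             static_cast<double>(p.k);
  }
  return flops;
}

// Converts a flop count and a runtime in milliseconds to TFLOPS.
inline double gemm_tflops(double flops, double runtime_ms) {
  assert(runtime_ms > 0.0);
  return flops / (runtime_ms * 1e-3) / 1e12;
}
```

With eight groups of 1024 × 4096 × 2048 the budget is 2^37 flops, and dividing by the pingpong runtime of 0.196608 ms gives roughly 699.05 TFLOPS, in line with the last row of the table.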