[Cutlass SM90] Per-group aux TMA descriptor update for grouped GEMM + Gated-SwiGLU example #3256

Open: Butterfingrz wants to merge 1 commit into NVIDIA:main from Butterfingrz:kernel/sm90-grouped-tma-desc
@Butterfingrz Butterfingrz commented May 21, 2026

Motivation

SM90 persistent grouped GEMM is already a critical high-performance path for MoE workloads: PyTorch exposes BF16 grouped_mm for CUDA grouped GEMM, and TorchTitan's Llama/MoE path directly depends on torch._grouped_mm. In practice this makes CUTLASS SM90 grouped GEMM a natural backend to extend rather than bypass.

However, the existing CUTLASS epilogue templates assume a 1:1 accumulator-to-output mapping, and the stock aux-store path is tied to a static, shape-preserving descriptor. This prevents fusing shape-changing activations such as SwiGLU, where the GEMM produces [M, 2I] but the activation output is [M, I] — especially when a persistent grouped kernel moves across groups with different pointers and problem shapes.

This PR adds a generic per-group Aux TMA descriptor update path to the SM90 persistent grouped GEMM TMA warp-specialized epilogue, and demonstrates it with a Hopper Grouped GEMM + Gated-SwiGLU end-to-end example. The framework-level changes are activation-agnostic: SwiGLU is simply the first consumer of the mechanism and lives entirely under examples/57_hopper_grouped_gemm/.

Summary

  • Framework: A generic per-group Aux TMA descriptor update pipeline inside the SM90 grouped TMA warp-specialized epilogue, mirroring the existing C / D per-group descriptor
    lifecycle and gated on a compile-time SFINAE trait so that existing instantiations see zero behavioral and code-size impact.
  • Application: A new Sm90GatedSwiGLUStoreTma EVT root node and a complete 57_hopper_bias_swiglu_grouped_gemm_bf16 driver demonstrating Bias + SwiGLU fused into grouped GEMM
    bf16, validated on both cooperative and pingpong schedules.

Framework changes (4 include files, +203 / −5)

include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp (+25)

  • ConsumerStoreArgs gains an aux_store_tensormap field that threads the per-CTA × per-warpgroup aux gmem tensormap slot pointer through to every EVT callback.
  • Sm90TreeVisitor adds get_aux_tma_descriptor() and aux_tensormaps_replace() that forward to the root NodeOp, so the collective epilogue can drive aux descriptor updates
    without dissecting the EVT.

include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp (+128 / −5)

  • A SFINAE probe detail::has_aux_tma_store_v<FusionCallbacks> recursively scans the EVT tree; the aux pipeline activates whenever any node declares static constexpr bool HasAuxTmaStore = true.
  • TensorMapStorage adds smem_tensormap_aux[NumEpilogueWarpGroups], and the gmem workspace is extended in lockstep (aux_base = NumEpilogueWarpGroups + (is_void_v<ElementC> ? 0 : 1); get_workspace_size() and to_underlying_arguments() updated accordingly).
  • Four new helpers that strictly mirror the C / D "init / replace / publish / synchronize" choreography, reusing the existing cute::tma_descriptor_* primitives together with
    tma_desc_commit_group / wait_group:
    • aux_store_init
    • aux_tensormaps_perform_update
    • aux_tensormaps_cp_fence_release
    • aux_tensormaps_fence_acquire
  • store() takes an additional aux_store_tensormap parameter and forwards it into ConsumerStoreArgs, allowing the application visitor to inject the runtime descriptor in
    tma_store() via .with(runtime_desc).
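
The workspace extension described above can be sketched as a small host-side layout helper. This is an illustration only, not the code in the PR: the names ToyTmaDescriptor and ToyDescriptorLayout are hypothetical, and the slot order (D slots first, then the optional C slot, then the aux slots) is an assumption derived from the aux_base expression quoted in the list.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a TMA descriptor; a CUtensorMap occupies 128 bytes.
struct ToyTmaDescriptor { std::uint8_t opaque[128]; };

// Illustrative layout only. Assumed slot order per SM: D, optional C, then aux.
template <int NumEpilogueWarpGroups, bool HasC, bool HasAux>
struct ToyDescriptorLayout {
  // First aux slot, following the aux_base expression from the PR description.
  static constexpr int aux_base = NumEpilogueWarpGroups + (HasC ? 1 : 0);

  // Aux adds one slot per epilogue warp group, and nothing when disabled.
  static constexpr int slots_per_sm =
      aux_base + (HasAux ? NumEpilogueWarpGroups : 0);

  // Index of the aux descriptor owned by one (SM, warp group) pair.
  static constexpr int aux_slot(int sm_idx, int warp_group_idx) {
    return sm_idx * slots_per_sm + aux_base + warp_group_idx;
  }

  // Bytes of global-memory workspace needed for all descriptors.
  static constexpr std::size_t workspace_bytes(int sm_count) {
    return sizeof(ToyTmaDescriptor) *
           static_cast<std::size_t>(slots_per_sm) *
           static_cast<std::size_t>(sm_count);
  }
};
```

With two epilogue warp groups and a non-void C, the sketch places the aux descriptors in slots 3 and 4 of each SM's block. Disabling aux leaves the per-SM slot count at its previous value of 3.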

include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp (+25) and ..._pingpong.hpp (+25)

  • In the consumer warp group prologue and at every did_batch_change boundary, the D-path init / perform_update / cp_fence_release / fence_acquire / store call sequence is
    duplicated for the aux path.
  • No new pipeline stages or barrier semantics are introduced; the existing synchronization skeleton is reused as-is.
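
The ordering described above can be illustrated with a host-side mock that only records which step runs when. It is a sketch under stated assumptions, not the kernel code: run_consumer, count_prefix, and the log strings are hypothetical, and the prologue update is folded into the first loop iteration for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical mock: walks a list of tiles (each tagged with its group index)
// and logs the descriptor-update steps for the D path and, optionally, aux.
inline std::vector<std::string> run_consumer(const std::vector<int>& tile_groups,
                                             bool has_aux) {
  std::vector<std::string> log;
  if (tile_groups.empty()) return log;

  // One-time init of the shared-memory descriptor copies.
  log.push_back("d_store_init");
  if (has_aux) log.push_back("aux_store_init");

  int current_group = -1;  // forces an update on the first tile (the prologue)
  for (int g : tile_groups) {
    const bool did_batch_change = (g != current_group);
    if (did_batch_change) {
      // Replace pointer/shape fields, then publish the descriptor.
      log.push_back("d_perform_update:" + std::to_string(g));
      log.push_back("d_cp_fence_release");
      if (has_aux) {
        log.push_back("aux_perform_update:" + std::to_string(g));
        log.push_back("aux_cp_fence_release");
      }
      // Acquire the freshly published descriptors before they are used.
      log.push_back("d_fence_acquire");
      if (has_aux) log.push_back("aux_fence_acquire");
      current_group = g;
    }
    log.push_back("store:" + std::to_string(g));
  }
  return log;
}

// Counts log entries that start with the given prefix.
inline std::size_t count_prefix(const std::vector<std::string>& log,
                                const std::string& prefix) {
  std::size_t n = 0;
  for (const auto& entry : log) {
    if (entry.compare(0, prefix.size(), prefix) == 0) ++n;
  }
  return n;
}
```

For the tile sequence {0, 0, 1} the mock performs two aux updates (one for the first tile, one at the group boundary) and three stores, and it emits no aux steps at all when has_aux is false.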

Zero-overhead guarantee

When no EVT node declares HasAuxTmaStore = true, NumAuxTmaTensors == 0 and the entire aux pipeline is short-circuited by if constexpr. Existing grouped GEMM instantiations
observe zero behavioral difference and zero additional code/state.
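
A minimal sketch of how such a compile-time gate could look is shown below. The types ToyEVT, PlainStore, BiasAdd, and ToySwiGLUStore are hypothetical, and the detection idiom is one plausible way to implement a probe like detail::has_aux_tma_store_v; the PR's actual trait may be written differently.

```cpp
#include <cassert>
#include <type_traits>

// Leaf probe: true only if T declares a static HasAuxTmaStore that is true.
template <class T, class = void>
struct declares_aux_tma_store : std::false_type {};

template <class T>
struct declares_aux_tma_store<T, std::void_t<decltype(T::HasAuxTmaStore)>>
    : std::bool_constant<T::HasAuxTmaStore> {};

// Hypothetical tree node standing in for an EVT: a root op plus children.
template <class NodeOp, class... Children>
struct ToyEVT {};

// Recursive scan: a plain node is checked directly...
template <class T>
struct has_aux_tma_store : declares_aux_tma_store<T> {};

// ...and a tree is true if its root or any child subtree is true.
template <class NodeOp, class... Children>
struct has_aux_tma_store<ToyEVT<NodeOp, Children...>>
    : std::bool_constant<(has_aux_tma_store<NodeOp>::value || ... ||
                          has_aux_tma_store<Children>::value)> {};

template <class T>
inline constexpr bool has_aux_tma_store_v = has_aux_tma_store<T>::value;

// The aux path compiles away entirely when the trait is false.
template <class Fusion>
constexpr int num_aux_tma_tensors() {
  if constexpr (has_aux_tma_store_v<Fusion>) {
    return 1;
  } else {
    return 0;
  }
}

// Hypothetical nodes used only for illustration.
struct PlainStore {};
struct BiasAdd {};
struct ToySwiGLUStore { static constexpr bool HasAuxTmaStore = true; };

static_assert(!has_aux_tma_store_v<ToyEVT<PlainStore, BiasAdd>>);
static_assert(has_aux_tma_store_v<ToyEVT<ToySwiGLUStore, BiasAdd>>);
static_assert(has_aux_tma_store_v<ToyEVT<PlainStore, ToyEVT<ToySwiGLUStore>>>);
```

Because the branch is selected with if constexpr, a fusion tree without the marker never instantiates the aux code, which is the property the zero-overhead claim relies on.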

Application example (examples/57_hopper_grouped_gemm/, +1265)

sm90_gated_swiglu_store_tma.hpp (+428) — custom EVT root Sm90GatedSwiGLUStoreTma

  • Framework hooks: declares HasAuxTmaStore = true and implements get_aux_tma_descriptor() and aux_tensormaps_replace(). This node takes over the stock Sm90AuxStore role,
    additionally handling the N → N/2 shape change.
  • Three-stage pipeline:
    • visit() — full-width register accumulation,
    • postreduce() — SwiGLU computed with tanh.approx.bf16x2 PTX, scattered into half-width SMEM,
    • tma_store() — emitted via .with(kernel_swiglu_tensormap), matching the D-path tma_store_d.with(store_tensormap) idiom.
  • SwiGLU formula (OpenAI / GPT-OSS mxfp4 variant): with α = 1.702 and limit = 7.0,
    • g_clip = min(gate, limit)
    • u_clip = clamp(up, ±limit)
    • s = 0.5 · (1 + tanh(α · g_clip / 2)), which equals sigmoid(α · g_clip)
    • out = g_clip · s · u_clip + g_clip · s
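
The formula above can be written as a scalar host reference, which may be useful when checking device results. This is a sketch: gated_swiglu_ref is a hypothetical name, it computes in fp32 with std::tanh rather than the tanh.approx.bf16x2 path, and it makes no assumption about how gate and up columns are laid out within the 2I-wide GEMM output.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Scalar fp32 reference for the gated SwiGLU variant described above.
// The device path uses an approximate bf16 tanh, so results would be
// expected to agree only within a bf16-level tolerance.
inline float gated_swiglu_ref(float gate, float up,
                              float alpha = 1.702f, float limit = 7.0f) {
  assert(limit > 0.0f);
  const float g_clip = std::min(gate, limit);            // one-sided clip
  const float u_clip = std::clamp(up, -limit, limit);    // symmetric clamp
  // 0.5 * (1 + tanh(x / 2)) is the logistic sigmoid of x, here x = alpha * g_clip.
  const float s = 0.5f * (1.0f + std::tanh(alpha * g_clip * 0.5f));
  return g_clip * s * u_clip + g_clip * s;                // = g_clip * s * (u_clip + 1)
}
```

With up = 0 the result reduces to g_clip · sigmoid(α · g_clip), and inputs beyond the limit give the same output as inputs exactly at the limit.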

57_hopper_bias_swiglu_grouped_gemm_bf16.cu (+821) — single-file driver

  • Two-phase CollectiveBuilder: phase 1 uses the stock fusion to probe EpiTile / SmemLayoutAtomD / CopyOpR2S / StagesD; phase 2 assembles the real EVT tree Sm90EVT<SwiGLUStore, ...>.
  • Parallel dual outputs: the D path carries the bias-added pre-activation [M_i, 2I] (usable as a backward-pass input); the aux TMA path carries the SwiGLU output [M_i, I].
    The two stores do not block each other.
  • Test matrix: {Cooperative, Pingpong} × host_problem_shapes_available ∈ {true, false} — 4 configurations.

CMakeLists.txt (+16) — registers the new target and ctest case.

Design notes

  • Generic hook in the framework, semantics in the application. The library only exposes three contract points — HasAuxTmaStore, get_aux_tma_descriptor(),
    aux_tensormaps_replace(). All SwiGLU-specific semantics (N → N/2, tanh.approx.bf16x2, the ptr_swiglu_out_array indexing scheme) live in the example. Other shape-changing
    outputs (softmax + top-k mask, per-output quantization scales, etc.) can reuse the same hooks unchanged.
  • The stock Sm90AuxStore is left untouched. Its Arguments use a single pointer + single stride and assume that the input shape and the aux output shape match — SwiGLU violates
    that precondition. Introducing Sm90GatedSwiGLUStoreTma as a new node keeps the existing class clean.

Testing

./build/release/examples/57_hopper_grouped_gemm/57_hopper_bias_swiglu_grouped_gemm \
    --m=1024 --n=4096 --k=2048 --groups=8 \
    2>&1 | grep -E 'schedule|Disposition|Avg runtime|TFLOPS'

Result: all four configurations (two schedules, each with host problem shapes available and unavailable) pass correctness verification.
TFLOPS (GEMM) counts only the grouped GEMM (2 · Σ M·N·K). The fused bias-add and SwiGLU activation are included in the timing but intentionally excluded from the FLOP count,
matching the convention used by the upstream 57_hopper_grouped_gemm example.
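
The FLOP accounting can be reproduced with a few lines of host code. The helper below is a sketch: grouped_gemm_tflops is a hypothetical name, and it assumes every group has the same M, N, and K, which is how the command above is read here.

```cpp
#include <cassert>
#include <cstdint>

// GEMM-only throughput: 2 * M * N * K per group, summed over equal-shaped groups.
// Bias-add and SwiGLU work is deliberately not counted, per the convention above.
inline double grouped_gemm_tflops(std::int64_t m, std::int64_t n, std::int64_t k,
                                  int groups, double avg_runtime_ms) {
  assert(avg_runtime_ms > 0.0);
  const double flops = 2.0 * static_cast<double>(m) * static_cast<double>(n) *
                       static_cast<double>(k) * static_cast<double>(groups);
  const double seconds = avg_runtime_ms * 1e-3;
  return flops / seconds / 1e12;
}
```

For --m=1024 --n=4096 --k=2048 --groups=8 and the reported 0.204138 ms, this gives roughly 673 TFLOPS, in line with the Cooperative row of the table below.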

┌─────────────┬─────────────────────┬─────────────┬──────────────────┬───────────────┐
│  Schedule   │ Host problem shapes │ Disposition │ Avg runtime (ms) │ TFLOPS (GEMM) │
├─────────────┼─────────────────────┼─────────────┼──────────────────┼───────────────┤
│ Cooperative │      available      │   Passed    │         0.204138 │       673.266 │
├─────────────┼─────────────────────┼─────────────┼──────────────────┼───────────────┤
│ Cooperative │     unavailable     │   Passed    │         0.206483 │       665.618 │
├─────────────┼─────────────────────┼─────────────┼──────────────────┼───────────────┤
│ Pingpong    │      available      │   Passed    │         0.197277 │       696.681 │
├─────────────┼─────────────────────┼─────────────┼──────────────────┼───────────────┤
│ Pingpong    │     unavailable     │   Passed    │         0.196608 │       699.051 │
└─────────────┴─────────────────────┴─────────────┴──────────────────┴───────────────┘

*** Cooperative schedule (bias + SwiGLU) ***
  Disposition   : Passed
  Avg runtime   : 0.204138 ms
  TFLOPS (GEMM) : 673.266
*** Cooperative schedule (host problem shapes unavailable) ***
  Disposition   : Passed
  Avg runtime   : 0.206483 ms
  TFLOPS (GEMM) : 665.618
*** Pingpong schedule (bias + SwiGLU) ***
  Disposition   : Passed
  Avg runtime   : 0.197277 ms
  TFLOPS (GEMM) : 696.681
*** Pingpong schedule (host problem shapes unavailable) ***
  Disposition   : Passed
  Avg runtime   : 0.196608 ms
  TFLOPS (GEMM) : 699.051
