
[2/N] Add support for WMMA kernels on RDNA4 (GFX1201) #250

Open
vivienfanghuagood wants to merge 3 commits into ROCm:main from vivienfanghuagood:fx1201-wmma-gemm

Conversation


@vivienfanghuagood vivienfanghuagood commented Mar 20, 2026

Motivation

Add WMMA GEMM kernels for RDNA4 (gfx12xx)

Technical Details

Three FP8/BF16 GEMM kernels targeting different workload profiles on RDNA4:

  1. wmma_gemm: BF16 GEMM using LDS double-buffered A+B tiles with
    ping-pong pipeline. SCF K-loop with barrier-based synchronization.
    Achieves 91-99% of torch.mm on compute-bound shapes (M>=1024).

  2. wmma_fp8_gemm: FP8 GEMM with no LDS. A loaded directly from raw
    [M,K] layout, B preshuffled for zero-copy GMEM-to-register path.
    Software-pipelined K-loop carries A/B prefetch state for load-compute
    overlap. Per-token/per-channel rowwise scaling. Best for small-M
    decode (M<128), achieving 97-116% of torch._scaled_mm.

  3. wmma_fp8_gemm_lds: FP8 GEMM with LDS-buffered A (ping-pong) and
    preshuffled B (direct GMEM). Combines coalesced A loads via LDS with
    zero-copy B loads. SCF K-loop carries only accumulators, eliminating
    register rename overhead. Best for large-M compute-bound shapes
    (M>=128), achieving 88-105% of torch._scaled_mm.
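The per-token/per-channel rowwise scaling used by both FP8 kernels has simple math: each output element is the integer-domain dot product rescaled by its row's activation scale and its column's weight scale. A minimal host-side reference sketch of those semantics (the helper name `scaled_gemm_ref` is illustrative, not from the kernels):

```python
def scaled_gemm_ref(a_q, b_q, scale_a, scale_b):
    """Reference for rowwise-scaled GEMM semantics:
    C[m][n] = (sum_k A_q[m][k] * B_q[k][n]) * scale_a[m] * scale_b[n],
    i.e. per-token scale_a[M] and per-channel scale_b[N]."""
    m, k = len(a_q), len(a_q[0])
    n = len(b_q[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += a_q[i][p] * b_q[p][j]  # fp8 values, accumulated in f32
            out[i][j] = acc * scale_a[i] * scale_b[j]  # rowwise rescale
    return out
```

The kernels fuse the two scale multiplies into the epilogue after the WMMA accumulation, so the K-loop itself stays scale-free.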

Also includes wave32 adaptations for layernorm/rmsnorm/softmax kernels,
benchmark integration in benchmark_common.py, and correctness tests.

Test Plan

Test Result

| Kernel | Implementation | Scenarios | vs torch |
|---|---|---|---|
| wmma_gemm.py | bf16 GEMM (LDS A+B) | compute-bound | 91-99% |
| wmma_fp8_gemm.py | FP8 GEMM (no LDS) | small-M decode (M<128) | 97-116% |
| wmma_fp8_gemm_lds.py | FP8 GEMM (LDS A) | large-M compute-bound (M>=128) | 88-105% |

Submission Checklist

New WMMA kernels (ported to @flyc.kernel API):
- wmma_gemm.py: f16/bf16 GEMM using WMMA 16x16x16, double-buffered LDS
- wmma_fp8_gemm.py: FP8 preshuffle GEMM using WMMA fp8 instructions

Kernels call ODS WMMA classes directly (e.g. rocdl.wmma_f32_16x16x16_bf16),
no wrapper layer needed.
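Numerically, a single rocdl.wmma_f32_16x16x16_bf16 issue computes D = A @ B + C over one 16x16x16 tile with f32 accumulation. A hedged host-side sketch of that tile contract (the function name is illustrative; lane-to-element register mapping is omitted):

```python
def wmma_tile_ref(a, b, c):
    """Reference semantics of one WMMA 16x16x16 tile op:
    D = A(16x16) @ B(16x16) + C(16x16), accumulating in f32.
    a/b hold bf16-representable values; c/d are f32."""
    n = 16
    d = [[c[i][j] for j in range(n)] for i in range(n)]  # start from accumulator C
    for i in range(n):
        for j in range(n):
            for k in range(n):
                d[i][j] += a[i][k] * b[k][j]
    return d
```

The real instruction distributes these 16x16 fragments across the 32 lanes of a wave; only the per-tile math is shown here.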

Wave32 adaptations for existing kernels:
- layernorm_kernel.py: use get_warp_size() for wave32/wave64
- rmsnorm_kernel.py: use get_warp_size() for wave32/wave64
- softmax_kernel.py: use get_warp_size() for wave32/wave64

Tests (following repo conventions):
- test_wmma_gemm.py: bf16/f16 correctness, f32 output, benchmark
- test_wmma_fp8_gemm.py: fp8 correctness, preshuffle, quantize tests

wmma_fp8_gemm_lds.py: FP8 WMMA GEMM with A through LDS, B preshuffled.

Architecture:
  - A: raw [M,K] fp8 -> coalesced GMEM load -> LDS ping-pong -> WMMA
  - B: preshuffled fp8 -> GMEM direct -> WMMA (no LDS needed)
  - Rowwise scaling: per-token scale_a[M], per-channel scale_b[N]

Key design decisions:
  - Only A uses LDS (12KB), B bypasses LDS via preshuffle (saves 50% LDS)
  - SCF loop carries only accumulators (zero v_dual_mov register rename)
  - Ping-pong double buffer hides A load latency behind WMMA compute
  - K-padding (a_k_pad=16) avoids LDS bank conflicts
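The ping-pong schedule above can be sketched as a host-side pipeline skeleton (names are illustrative, not the kernel's API): while WMMA consumes the tile in one LDS slot, the next A tile is prefetched into the other, so the GMEM load latency hides behind compute.

```python
def pingpong_pipeline(num_tiles, load, compute):
    """Sketch of a double-buffered K-loop: buf[k % 2] is consumed
    while buf[(k + 1) % 2] is being filled. Only the accumulator
    is carried across iterations, mirroring the SCF loop design."""
    buf = [None, None]
    buf[0] = load(0)                          # prologue: fill first buffer
    acc = 0
    for k in range(num_tiles):
        if k + 1 < num_tiles:
            buf[(k + 1) % 2] = load(k + 1)    # prefetch next tile
        acc = compute(acc, buf[k % 2])        # consume current tile
        # (an LDS barrier separates the phases on real hardware)
    return acc
```

With `load = lambda k: k + 1` and `compute = lambda acc, t: acc + t`, four tiles reduce to 10, confirming every tile is loaded once and consumed once.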

Performance (gfx1201):
  4096x4096x4096:  205 TFLOPS (90% of torch._scaled_mm)
  4096x4096x8192:  195 TFLOPS (105% of torch, exceeds reference)

Complementary to wmma_fp8_gemm.py (no-LDS version):
  - wmma_fp8_gemm.py: best for small-M decode (M<128), 97-116% of torch
  - wmma_fp8_gemm_lds.py: best for large-M compute-bound (M>=128), 88-105%

FlyDSL's AST rewriter intercepts 'while' and tries to lower it to
scf.while, but the loop condition (a Python bool from 'sh >= 1') has no
MLIR .owner attribute. Using range_constexpr(log2(WARP_SIZE)) instead
unrolls the loop at compile time and produces the same shuffle sequence.
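The unrolled sequence is the standard xor-butterfly reduction: log2(WARP_SIZE) shuffle steps with halving strides, after which every lane holds the full sum. A host-side emulation of that effect (assumed sketch; `warp_reduce_sum` is not a kernel function):

```python
def warp_reduce_sum(lanes):
    """Emulates the butterfly (shfl_xor-style) reduction that
    range_constexpr(log2(WARP_SIZE)) unrolls in the kernel."""
    n = len(lanes)                 # WARP_SIZE: 32 in wave32, 64 in wave64
    steps = n.bit_length() - 1     # log2(WARP_SIZE)
    vals = list(lanes)
    for s in range(steps):         # mirrors range_constexpr(log2(WARP_SIZE))
        sh = n >> (s + 1)          # strides 16, 8, 4, 2, 1 for n = 32
        # every lane adds its xor-partner's value in lockstep
        vals = [vals[i] + vals[i ^ sh] for i in range(n)]
    return vals
```

Because the trip count is a compile-time constant, the same source works for both wave32 and wave64 via get_warp_size().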