[2/N] add support wmma kernels for RDNA4 (GFX1201) #250
Open
vivienfanghuagood wants to merge 3 commits into ROCm:main from
Conversation
New WMMA kernels (ported to the @flyc.kernel API):
- wmma_gemm.py: f16/bf16 GEMM using WMMA 16x16x16, double-buffered LDS
- wmma_fp8_gemm.py: FP8 preshuffle GEMM using WMMA fp8 instructions

Kernels call the ODS WMMA classes directly (e.g. rocdl.wmma_f32_16x16x16_bf16); no wrapper layer is needed.

Wave32 adaptations for existing kernels:
- layernorm_kernel.py: use get_warp_size() for wave32/wave64
- rmsnorm_kernel.py: use get_warp_size() for wave32/wave64
- softmax_kernel.py: use get_warp_size() for wave32/wave64

Tests (following repo conventions):
- test_wmma_gemm.py: bf16/f16 correctness, f32 output, benchmark
- test_wmma_fp8_gemm.py: fp8 correctness, preshuffle, quantize tests
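A minimal sketch of the idea behind the get_warp_size() adaptation. The function names here are illustrative, not the FlyDSL API: the point is that a butterfly (XOR-shuffle) reduction takes log2(warp_size) rounds, so the same kernel source covers wave32 (5 rounds) and wave64 (6 rounds) once the warp size is queried instead of hard-coded.

```python
import math

def warp_reduce_sum(lane_vals, warp_size):
    """Simulate a butterfly (XOR-shuffle) sum reduction across one wave.

    The round count is log2(warp_size), so the same source works for
    wave32 (5 rounds, RDNA4) and wave64 (6 rounds, CDNA) once the warp
    size is a compile-time query rather than a hard-coded 64.
    """
    vals = list(lane_vals)
    for step in range(int(math.log2(warp_size))):
        offset = 1 << step
        # Each lane adds the value from its XOR partner lane.
        vals = [vals[lane] + vals[lane ^ offset] for lane in range(warp_size)]
    return vals  # after the last round, every lane holds the full sum

print(warp_reduce_sum(range(32), 32)[0])  # 496 == sum(0..31)
print(warp_reduce_sum(range(64), 64)[0])  # 2016 == sum(0..63)
```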
wmma_fp8_gemm_lds.py: FP8 WMMA GEMM with A through LDS, B preshuffled.

Architecture:
- A: raw [M,K] fp8 -> coalesced GMEM load -> LDS ping-pong -> WMMA
- B: preshuffled fp8 -> GMEM direct -> WMMA (no LDS needed)
- Rowwise scaling: per-token scale_a[M], per-channel scale_b[N]

Key design decisions:
- Only A uses LDS (12KB); B bypasses LDS via preshuffle (saves 50% LDS)
- SCF loop carries only accumulators (zero v_dual_mov register renames)
- Ping-pong double buffering hides A load latency behind WMMA compute
- K-padding (a_k_pad=16) avoids LDS bank conflicts

Performance (gfx1201):
- 4096x4096x4096: 205 TFLOPS (90% of torch._scaled_mm)
- 4096x4096x8192: 195 TFLOPS (105% of torch, exceeds reference)

Complementary to wmma_fp8_gemm.py (the no-LDS version):
- wmma_fp8_gemm.py: best for small-M decode (M<128), 97-116% of torch
- wmma_fp8_gemm_lds.py: best for large-M compute-bound shapes (M>=128), 88-105%
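A small sketch of why the K-padding (a_k_pad=16) matters. It assumes RDNA-style LDS with 32 banks of 4 bytes and a row-major fp8 tile whose row width K is a multiple of 128 bytes (the exact tile shape in the kernel may differ); without padding, the first element of every row maps to the same bank.

```python
# Assumed LDS geometry: 32 banks, 4 bytes per bank (bank = (addr // 4) % 32).
BANKS, BANK_BYTES = 32, 4

def bank_of(addr):
    return (addr // BANK_BYTES) % BANKS

def column_banks(k_elems, pad_elems, elem_bytes=1, rows=32):
    """Bank hit by element 0 of each row of a row-major fp8 tile in LDS."""
    stride = (k_elems + pad_elems) * elem_bytes
    return [bank_of(r * stride) for r in range(rows)]

# K=128 fp8 bytes per row: column 0 of all 32 rows lands in bank 0,
# a 32-way conflict when lanes read down a column.
print(sorted(set(column_banks(128, 0))))   # [0]

# With a 16-byte pad the stride is 144 bytes, spreading the rows
# across eight distinct banks instead of one.
print(len(set(column_banks(128, 16))))     # 8
```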
FlyDSL's AST rewriter intercepts 'while' and tries to lower it to scf.while, but the condition (a plain Python bool from 'sh >= 1') has no MLIR .owner attribute. Use range_constexpr(log2(WARP_SIZE)) instead, which unrolls at compile time and produces the same shuffle sequence.
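A sketch of the rewrite being described. range_constexpr here is a local stand-in (FlyDSL's real helper is not assumed available); the point is that the iteration count log2(WARP_SIZE) is known at compile time, so the loop can unroll into the same fixed sequence of shuffle offsets the 'while' form would have produced.

```python
import math

WARP_SIZE = 32  # wave32 on RDNA4

def range_constexpr(n):
    # Stand-in for a compile-time-unrolled range; a plain Python range
    # over a constant bound is exactly what the rewriter can unroll.
    return range(n)

# Before (breaks the rewriter: 'sh >= 1' yields a Python bool, not an
# MLIR value, so scf.while lowering fails):
#   sh = WARP_SIZE // 2
#   while sh >= 1:
#       val += shuffle_xor(val, sh)
#       sh //= 2
#
# After: same shuffle offsets, fully unrolled at compile time.
def reduce_offsets():
    offsets = []
    for i in range_constexpr(int(math.log2(WARP_SIZE))):
        offsets.append(WARP_SIZE >> (i + 1))
    return offsets

print(reduce_offsets())  # [16, 8, 4, 2, 1]
```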
Motivation
Add WMMA GEMM kernels for RDNA4 (gfx12xx)
Technical Details
Three FP8/BF16 GEMM kernels targeting different workload profiles on RDNA4:
- wmma_gemm: BF16 GEMM using LDS double-buffered A+B tiles with a ping-pong pipeline. SCF K-loop with barrier-based synchronization. Achieves 91-99% of torch.mm on compute-bound shapes (M>=1024).
- wmma_fp8_gemm: FP8 GEMM with no LDS. A is loaded directly from the raw [M,K] layout; B is preshuffled for a zero-copy GMEM-to-register path. The software-pipelined K-loop carries A/B prefetch state for load-compute overlap. Per-token/per-channel rowwise scaling. Best for small-M decode (M<128), achieving 97-116% of torch._scaled_mm.
- wmma_fp8_gemm_lds: FP8 GEMM with LDS-buffered A (ping-pong) and preshuffled B (direct GMEM). Combines coalesced A loads via LDS with zero-copy B loads. The SCF K-loop carries only accumulators, eliminating register-rename overhead. Best for large-M compute-bound shapes (M>=128), achieving 88-105% of torch._scaled_mm.
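The crossover above suggests a simple dispatch rule. This is a hypothetical helper, not part of the PR; it just encodes the stated M>=128 boundary between the two FP8 kernels.

```python
def pick_fp8_gemm_kernel(m):
    """Hypothetical dispatch heuristic from the measured crossover:
    the no-LDS kernel wins on small-M decode shapes, while the
    LDS-buffered variant wins once the GEMM is compute-bound (M >= 128)."""
    return "wmma_fp8_gemm" if m < 128 else "wmma_fp8_gemm_lds"

print(pick_fp8_gemm_kernel(64))    # wmma_fp8_gemm
print(pick_fp8_gemm_kernel(4096))  # wmma_fp8_gemm_lds
```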
Also includes wave32 adaptations for layernorm/rmsnorm/softmax kernels,
benchmark integration in benchmark_common.py, and correctness tests.
Test Plan
Test Result
Submission Checklist