
[gfx1201] Enable RMSNorm support for gfx1201#2352

Open
vllmellm wants to merge 2 commits into ROCm:main from EmbeddedLLM:rdna4_rmsnorm-support

Conversation

@vllmellm
Contributor

Motivation

The RMSNorm kernels in csrc/kernels/rmsnorm_quant_kernels.cu use CDNA-specific inline assembly instructions that are not supported on RDNA4 (gfx1201) architecture. This prevents the RMSNorm operation from working on gfx1201 GPUs. This PR aims to enable RMSNorm support on gfx1201 by replacing unsupported assembly instructions with portable HIP/C++ alternatives.
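For context, the core computation these kernels implement can be sketched on the host as a minimal reference (this is an illustrative sketch, not the kernel code; the `eps` value is illustrative, and the fused residual-add and quantization steps the actual kernels perform are omitted):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference RMSNorm: y[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i].
// The fused kernels add a residual and quantize the output on top of this.
std::vector<float> rmsnorm(const std::vector<float>& x,
                           const std::vector<float>& weight,
                           float eps = 1e-6f) {
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;
    const float scale = 1.0f / std::sqrt(sum_sq / x.size() + eps);

    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = x[i] * scale * weight[i];
    return y;
}
```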

Technical Details

Changes Overview

Modified csrc/kernels/rmsnorm_quant_kernels.cu to replace CDNA-specific inline assembly with portable implementations for gfx11/gfx12:

  1. Replaced v_pk_mul_f32 inline assembly (lines 146-151, 196-201)

    • Changed from: asm volatile("v_pk_mul_f32 %0, %1, %2" ...)
    • Changed to: Standard float multiplication with __gfx11__ || __gfx12__ guard
    • This instruction is not supported on RDNA4 (gfx12xx) architecture
  2. Replaced bf16 unpacking inline assembly (lines 162-176)

    • Changed from: v_lshlrev_b32_e32 and v_and_b32_e32 instructions
    • Changed to: ck_tile::bit_cast with shift operations for unpacking bf16 values
    • Provides equivalent functionality using portable HIP/C++ code
  3. Replaced fp16 unpacking inline assembly (lines 180-194)

    • Changed from: v_cvt_f32_f16_e32 and v_cvt_f32_f16_sdwa instructions
    • Changed to: ck_tile::bit_cast with shift operations for unpacking fp16 values
    • SDWA (Sub-Dword Addressing) instruction modifiers are not available on gfx11/gfx12
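The bf16 unpacking in item 2 can be illustrated with a host-side sketch (helper names here are hypothetical, and `std::memcpy` stands in for `ck_tile::bit_cast`): a bf16 value is the upper 16 bits of a float32 bit pattern, so the two lanes of a packed 32-bit word can be unpacked with a shift and a mask, mirroring what `v_lshlrev_b32_e32` and `v_and_b32_e32` did in assembly.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Stand-in for ck_tile::bit_cast on the host.
static float bit_cast_f32(uint32_t u) {
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Unpack the low and high bf16 lanes of `packed` into floats.
// A bf16 value occupies the top 16 bits of the equivalent f32 pattern.
void unpack_bf16x2(uint32_t packed, float* lo, float* hi) {
    *lo = bit_cast_f32(packed << 16);          // replaces v_lshlrev_b32_e32
    *hi = bit_cast_f32(packed & 0xFFFF0000u);  // replaces v_and_b32_e32
}
```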

Compatibility

  • CDNA (gfx90a, gfx942): no functional change; continues to use the optimized inline assembly
  • RDNA4 (gfx1201): Now uses portable HIP/C++ implementation
  • All changes are guarded by preprocessor conditions (#if defined(__gfx11__) || defined(__gfx12__)) to ensure optimal performance on both architectures
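The dispatch pattern described above can be sketched as follows (a sketch only: the function name is hypothetical, the `__gfx11__`/`__gfx12__` macros follow the PR description and are defined by the HIP compiler for the matching targets, and the CDNA inline assembly is shown schematically as a comment with a placeholder so this host-side sketch compiles):

```cpp
#include <cassert>

// Multiply two pairs of floats, as v_pk_mul_f32 does in one packed op.
inline void pk_mul_f32(float& d0, float& d1,
                       float a0, float a1, float b0, float b1) {
#if defined(__gfx11__) || defined(__gfx12__)
    // Portable path for RDNA3/RDNA4: plain scalar multiplies.
    d0 = a0 * b0;
    d1 = a1 * b1;
#else
    // CDNA path keeps the original packed inline assembly, schematically:
    //   asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(d) : "v"(a), "v"(b));
    d0 = a0 * b0;  // placeholder so this host sketch compiles everywhere
    d1 = a1 * b1;
#endif
}
```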

Test Plan

Run the RMSNorm test suites to validate the changes:

# Test fused RMSNorm with add and quantization (FP8)
python op_tests/test_rmsnorm2dFusedAddQuant.py --mode 7 -q fp8

# Test standard RMSNorm operations
python op_tests/test_rmsnorm2d.py

Tests cover:

  • RMSNorm with residual addition
  • FP8 quantization paths
  • Various hidden dimension sizes (up to 8192)
  • Both bf16 and fp16 input data types
  • Per-token and per-channel quantization modes

Test Result

| m | n | quant_type | add_residual | dtype | quant_dtype | smoothquant | torch (us) | hip (us) | hip err | hip bw (GB/s) |
|---|---|---|---|---|---|---|---|---|---|---|
| 8 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 31.8614 | 2.18096 | 0.0289307 | 25.4712 |
| 256 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 78.5019 | 5.69223 | 0.0311623 | 301.907 |
| 2048 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 156.285 | 19.365 | 0.030992 | 709.261 |
| 2560 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 138.565 | 23.4243 | 0.0310837 | 732.917 |
| 32768 | 1024 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 2394.75 | 390.314 | 0.0309762 | 562.953 |
| 8 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 27.368 | 2.29151 | 0.0303955 | 48.4846 |
| 256 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 79.2884 | 13.0944 | 0.0310841 | 262.482 |
| 2048 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 195.881 | 36.4202 | 0.0309489 | 754.243 |
| 2560 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 236.859 | 45.099 | 0.0309561 | 761.349 |
| 32768 | 2048 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 4646.04 | 776.304 | 0.0310346 | 566.089 |
| 8 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 37.3548 | 2.4614 | 0.0296936 | 90.2764 |
| 256 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 86.6443 | 12.8225 | 0.0306196 | 536.097 |
| 2048 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 402.692 | 101.938 | 0.0309554 | 538.948 |
| 2560 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 587.155 | 126.984 | 0.0310359 | 540.795 |
| 32768 | 4096 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 9209.33 | 1579.47 | 0.0309953 | 556.461 |
| 8 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 44.9283 | 2.94468 | 0.0310822 | 150.92 |
| 256 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 114.853 | 31.5449 | 0.0310688 | 435.829 |
| 2048 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 1130.84 | 201.211 | 0.0310543 | 546.085 |
| 2560 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 1470.88 | 251.325 | 0.0310502 | 546.482 |
| 32768 | 8192 | 4 | True | torch.bfloat16 | torch.float8_e4m3fn | False | 18529.6 | 3153.58 | 0.0310005 | 557.408 |

@vllmellm vllmellm requested a review from a team March 19, 2026 08:53
