Skip to content

[ROCm] Enable Aiter ck_gemm_a8w8_blockscale for RDNA4 gpus. Qwen3.5-27B-FP8 tp=2, Qwen3-0.6B-FP8 tp=1 #77

Open
big-yellow-duck wants to merge 10 commits intomainfrom
rdna4-aiter-fp8-blockscale
Open

[ROCm] Enable Aiter ck_gemm_a8w8_blockscale for RDNA4 gpus. Qwen3.5-27B-FP8 tp=2, Qwen3-0.6B-FP8 tp=1 #77
big-yellow-duck wants to merge 10 commits intomainfrom
rdna4-aiter-fp8-blockscale

Conversation

@big-yellow-duck
Copy link

@big-yellow-duck big-yellow-duck commented Mar 12, 2026

Purpose

We enabled ck gemm_a8w8_blockscale in Aiter for gfx1201 but vllm has not enabled support yet. this pr enables aiter for gfx12xx for FP8 inference. the ck gemm_a8w8_blockscale from aiter provides better performance than the default untuned triton kernel in vllm.

enabled the aiter FP8 path for gfx12x card to use tuned ck gemm configs from aiter.

Test Plan

benchmark Qwen3.5-27B-FP8 with default vllm and vllm with Aiter enabled on 2x Radeon PRO 9700

default

VLLM_ROCM_USE_AITER=0 vllm serve Qwen/Qwen3.5-27B-FP8 -tp 2 --gpu-memory-utilization 0.98 --max-model-len 65536

using aiter ck gemm_a8w8_blockscale

VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER_RMSNORM=0   vllm serve Qwen/Qwen3.5-27B-FP8 -tp 2 --gpu-memory-utilization 0.98 --max-model-len 65536

Test Results

Benchmark using Qwen/Qwen3.5-27B-FP8

TTFT (ms)

ISL-OSL Default Aiter CK GEMM Speedup
512-512 7540.37 7723.20 -2.4%
1024-1024 3728.69 3984.64 -6.9%
2048-2048 5666.84 7198.22 -27.0%
4096-4096 10358.16 12424.68 -20.0%
8192-1024 21272.21 25200.08 -18.5%
16384-2048 49513.83 56871.82 -14.9%
Average 16346.68 18900.44 -15.6%

TPOT (ms)

ISL-OSL Default Aiter CK GEMM Speedup
512-512 62.51 42.66 +31.8%
1024-1024 65.42 44.28 +32.3%
2048-2048 65.82 46.89 +28.8%
4096-4096 69.03 49.62 +28.1%
8192-1024 108.26 96.39 +11.0%
16384-2048 125.64 113.84 +9.4%
Average 82.78 65.61 +20.7%

E2E Latency (ms)

ISL-OSL Default Aiter CK GEMM Speedup
512-512 39483.28 29524.24 +25.2%
1024-1024 70657.76 49286.97 +30.2%
2048-2048 140390.78 103172.85 +26.5%
4096-4096 293054.02 215632.43 +26.4%
8192-1024 132024.15 123804.59 +6.2%
16384-2048 306707.17 289907.97 +5.5%
Average 163719.53 135221.51 +17.4%

Accuracy checks

GSM8K Accuracy

Metric Default Aiter CK GEMM Diff
exact_match,strict-match 86.13% 85.22% -0.91%
exact_match,flexible-extract 87.79% 86.66% -1.13%

All accuracy differences are not statistically significant


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@big-yellow-duck big-yellow-duck changed the title [ROCm] Enable Aiter ck_gemm_a8w8_blockscale for RDNA4 gpus. [ROCm] Enable Aiter ck_gemm_a8w8_blockscale for RDNA4 gpus. Qwen3.5-27B-FP8 tp=2, Qwen3-0.6B-FP8 tp=1 Mar 12, 2026
@big-yellow-duck big-yellow-duck marked this pull request as ready for review March 16, 2026 04:50
Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>
Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>
@tjtanaa
Copy link
Member

tjtanaa commented Mar 16, 2026

@big-yellow-duck @BadrBasowid

We need to guard all of the other ops with additional condition on_mi3xx() so that users don't need to know which flag to switch off. On Radeon user can just do VLLM_ROCM_USE_AITER=1.

https://github.com/vllm-project/vllm/blob/8d3f8f485efc0b812f91ecf19a3a12232587550c/vllm/_aiter_ops.py#L1129-L1201

Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>
Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>
Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>
Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx12x, on_mi3xx
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not import here. It has to be lazy import in functions

Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants