Skip to content
125 changes: 94 additions & 31 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,84 @@ def is_aiter_found_and_supported() -> bool:
VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery.
"""
if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9
from vllm.platforms.rocm import on_gfx9, on_gfx12x

return on_gfx9()
return on_gfx9() or on_gfx12x()
return False


def aiter_ops_supported_arch(
    arch_checks: list[str],
):
    """
    Decorator factory that only executes the function if ROCm AITER package
    is supported and any of the specified arch checks pass.

    Args:
        arch_checks: List of arch check names (strings).
            Must explicitly specify which archs are supported.
            Available options and their mappings:
            - "gfx1x": gfx11, gfx12
            - "gfx12x": gfx12
            - "mi3xx": gfx942, gfx950
            - "gfx9": gfx90a, gfx942, gfx950
            - "gfx942": gfx942
            - "gfx950": gfx950
    Examples:
        - ["mi3xx"]: MI300 series only (gfx942, gfx950)
        - ["mi3xx", "gfx12x"]: MI300 OR gfx12x
        - ["gfx9"]: All gfx9 archs (gfx90a, gfx942, gfx950)

    Returns:
        A decorator. The wrapped function runs normally when the platform,
        library, and arch checks pass, and returns None otherwise.

    Raises:
        ValueError: if an unknown arch check name is supplied. Only raised
            at call time on ROCm with AITER installed, because the check
            map is consulted lazily inside the wrapper.

    Usage:
        @aiter_ops_supported_arch(["mi3xx"])  # MI300 series only
        def is_enabled(): ...

        @aiter_ops_supported_arch(["mi3xx", "gfx12x"])  # MI300 OR gfx12x
        def is_linear_enabled(): ...
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # First check: platform and library availability.
            # Short-circuits to None off-ROCm so module import stays cheap.
            if not (current_platform.is_rocm() and IS_AITER_FOUND):
                return None

            # Import arch check functions lazily only when on ROCm.
            from vllm.platforms.rocm import (
                on_gfx1x,
                on_gfx9,
                on_gfx12x,
                on_gfx942,
                on_gfx950,
                on_mi3xx,
            )

            func_map = {
                "gfx1x": on_gfx1x,
                "gfx12x": on_gfx12x,
                "mi3xx": on_mi3xx,
                "gfx9": on_gfx9,
                "gfx942": on_gfx942,
                "gfx950": on_gfx950,
            }

            # Execute if any arch check passes (logical OR over the list).
            for check_name in arch_checks:
                check_func = func_map.get(check_name)
                if check_func is None:
                    raise ValueError(
                        f"Unknown arch check: {check_name}. "
                        f"Available: {list(func_map.keys())}"
                    )
                if check_func():
                    return func(*args, **kwargs)

            # No arch check matched the current device.
            return None

        return wrapper

    return decorator


def _rocm_aiter_fused_moe_impl(
Expand Down Expand Up @@ -953,16 +1012,20 @@ class rocm_aiter_ops:
after monkey patching the env variables in the unit test.

Check Functions:
All check functions (is_*_enabled) are decorated with
@aiter_ops_supported_arch([...]), which verifies: (1) platform is ROCm,
(2) device arch matches the specified list, and (3) aiter library is
installed. The check function then also verifies
the corresponding environment variable is enabled.
i.e.                                          ___
is_enabled() == current_platform.is_rocm() and  | checked by
                (arch_check passes) and         | @aiter_ops_supported_arch([...])
                IS_AITER_FOUND  ________________|
                cls._AITER_ENABLED -----> Check by the logic in `is_enabled()`

Note: Arch checks must be explicitly specified:
    @aiter_ops_supported_arch(["mi3xx"])  # MI300 series only
    @aiter_ops_supported_arch(["mi3xx", "gfx12x"])  # MI300 OR gfx12x

Example:
from vllm._aiter_ops import rocm_aiter_ops

Expand Down Expand Up @@ -1089,86 +1152,86 @@ def get_aiter_quant_type(quant_type_str: str):
return mapping.get(name)

@classmethod
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_enabled(cls) -> bool:
    """True when the AITER master env flag is set on a supported arch
    (MI300 series or gfx12x); decorator yields None off those archs.

    NOTE(review): the decorator returns None when arch checks fail, so the
    `bool` annotation is inexact — callers should treat the result as falsy.
    """
    return cls._AITER_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_linear_enabled(cls) -> bool:
    """True when AITER and its linear-layer path are both enabled
    (MI300 series or gfx12x)."""
    return cls._AITER_ENABLED and cls._LINEAR_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_linear_fp8_enabled(cls) -> bool:
    """Alias of is_linear_enabled(); the FP8 linear path shares its gate."""
    return cls.is_linear_enabled()

@classmethod
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_rmsnorm_enabled(cls) -> bool:
    """True when AITER and its RMSNorm kernel are both enabled
    (MI300 series or gfx12x)."""
    return cls._AITER_ENABLED and cls._RMSNORM_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_fused_moe_enabled(cls) -> bool:
    """True when AITER and its fused-MoE kernels are both enabled
    (MI300 series only)."""
    return cls._AITER_ENABLED and cls._FMOE_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_fusion_moe_shared_experts_enabled(cls) -> bool:
    """True when fused MoE is enabled AND the shared-experts fusion flag is
    set (MI300 series only)."""
    return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_mla_enabled(cls) -> bool:
    """True when AITER and its MLA attention path are both enabled
    (MI300 series only)."""
    return cls._AITER_ENABLED and cls._MLA_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_mha_enabled(cls) -> bool:
    """True when AITER and its MHA attention path are both enabled
    (MI300 series only)."""
    return cls._AITER_ENABLED and cls._MHA_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_shuffle_kv_cache_enabled(cls) -> bool:
    """True when the shuffle-KV-cache flag is set (MI300 series only).

    NOTE(review): unlike sibling checks this does not also gate on
    cls._AITER_ENABLED — confirm that is intentional.
    """
    return cls._SHUFFLE_KV_CACHE_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_triton_unified_attn_enabled(cls) -> bool:
    """True when AITER and its Triton unified-attention kernel are both
    enabled (MI300 series or gfx12x)."""
    return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_fp8bmm_enabled(cls) -> bool:
    """True when AITER and its FP8 batched-matmul kernel are both enabled
    (MI300 series only)."""
    return cls._AITER_ENABLED and cls._FP8BMM_ENABLED

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_fp4bmm_enabled(cls) -> bool:
    """True when AITER FP4 batched-matmul is enabled; requires gfx950."""
    from vllm.platforms.rocm import on_gfx950

    # Decorator narrows to MI300 series; FP4 BMM additionally requires
    # gfx950 (returns False, not None, on gfx942).
    return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950()

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
    """True when the ASM FP4 GEMM dynamic-quant path is enabled; requires
    gfx950."""
    from vllm.platforms.rocm import on_gfx950

    # Decorator narrows to MI300 series; this ASM kernel additionally
    # requires gfx950 (returns False, not None, on gfx942).
    return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM and on_gfx950()

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_triton_rotary_embed_enabled(cls) -> bool:
    """True when AITER and its Triton rotary-embedding kernel are both
    enabled (MI300 series only)."""
    return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED

@classmethod
@aiter_ops_supported_arch(["mi3xx"])
def is_triton_gemm_enabled(cls) -> bool:
    """True when AITER and its Triton unquantized-GEMM kernel are both
    enabled (MI300 series only)."""
    return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM

@staticmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def register_ops_once() -> None:
global _OPS_REGISTERED
if not _OPS_REGISTERED:
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/quantization/input_quant_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def forward_hip(
scale_ub: torch.Tensor | None = None,
use_triton: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.is_group_quant and use_triton:
from vllm.platforms.rocm import on_gfx12x

if self.is_group_quant and use_triton and on_gfx12x():
assert scale is None, "Dynamic group quantization does not use scale"

return torch.ops.vllm.triton_per_token_group_quant_fp8(x, self.group_size)
Expand Down
5 changes: 5 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def _get_gcn_arch() -> str:
_GCN_ARCH = _get_gcn_arch()

# Arch-family flags, computed once at import from the reported GCN arch
# string (substring match, so e.g. "gfx1100" matches "gfx11").
_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
# gfx12-only; a strict subset of _ON_GFX1X.
_ON_GFX12X = any(arch in _GCN_ARCH for arch in ["gfx12"])
_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
_ON_GFX942 = "gfx942" in _GCN_ARCH
Expand Down Expand Up @@ -226,6 +227,10 @@ def on_gfx1x() -> bool:
return _ON_GFX1X


def on_gfx12x() -> bool:
    # True when the device's GCN arch contains "gfx12" (set at import time).
    return _ON_GFX12X


def on_mi3xx() -> bool:
    # True on MI300-series devices (gfx942 or gfx950), set at import time.
    return _ON_MI3XX

Expand Down
Loading