Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 94 additions & 31 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,84 @@ def is_aiter_found_and_supported() -> bool:
VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery.
"""
if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9
from vllm.platforms.rocm import on_gfx9, on_gfx12x

return on_gfx9()
return on_gfx9() or on_gfx12x()
return False


def if_aiter_supported(func: Callable) -> Callable:
    """Run ``func`` only when ``is_aiter_found_and_supported()`` is true.

    Arch-agnostic gate kept so existing ``@if_aiter_supported`` call sites
    keep working; use ``aiter_ops_supported_arch`` to gate on specific archs.

    Args:
        func: The callable to guard.

    Returns:
        A wrapper that forwards to ``func`` when the check passes and
        returns ``None`` without calling it otherwise.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_aiter_found_and_supported():
            return func(*args, **kwargs)
        return None

    return wrapper


# Arch check name -> name of the predicate in ``vllm.platforms.rocm``.
# Single source of truth for both validation and dispatch below.
_AITER_ARCH_CHECKS = {
    "gfx1x": "on_gfx1x",
    "gfx12x": "on_gfx12x",
    "mi3xx": "on_mi3xx",
    "gfx9": "on_gfx9",
    "gfx942": "on_gfx942",
    "gfx950": "on_gfx950",
}


def aiter_ops_supported_arch(
    arch_checks: list[str],
):
    """
    Decorator factory that only executes the function if ROCm AITER package
    is supported and any of the specified arch checks pass.

    Args:
        arch_checks: List of arch check names (strings).
            Must explicitly specify which archs are supported.
            Available options and their mappings:
            - "gfx1x": gfx11, gfx12
            - "gfx12x": gfx12
            - "mi3xx": gfx942, gfx950
            - "gfx9": gfx90a, gfx942, gfx950
            - "gfx942": gfx942
            - "gfx950": gfx950
            Examples:
            - ["mi3xx"]: MI300 series only (gfx942, gfx950)
            - ["mi3xx", "gfx12x"]: MI300 OR gfx12x
            - ["gfx9"]: All gfx9 archs (gfx90a, gfx942, gfx950)

    Returns:
        A decorator. The decorated function returns ``None`` (without
        running its body) when the platform is not ROCm, AITER is not
        found, or none of the arch checks pass.

    Raises:
        ValueError: If ``arch_checks`` contains an unknown name. This is
            raised when the decorator is created, so a typo fails at
            import time on every platform instead of only surfacing on a
            ROCm machine at call time.

    Usage:
        @aiter_ops_supported_arch(["mi3xx"])  # MI300 series only
        def is_enabled(): ...

        @aiter_ops_supported_arch(["mi3xx", "gfx12x"])  # MI300 OR gfx12x
        def is_linear_enabled(): ...
    """
    # Snapshot the names so later mutation of the caller's list cannot
    # change the gating of already-decorated functions.
    checks = tuple(arch_checks)
    for check_name in checks:
        if check_name not in _AITER_ARCH_CHECKS:
            raise ValueError(
                f"Unknown arch check: {check_name}. "
                f"Available: {list(_AITER_ARCH_CHECKS.keys())}"
            )

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # First check: platform and library availability
            if not (current_platform.is_rocm() and IS_AITER_FOUND):
                return None

            # Import the ROCm platform module lazily, only when on ROCm
            from vllm.platforms import rocm as rocm_platform

            # Execute if any arch check passes (short-circuits in order)
            if any(
                getattr(rocm_platform, _AITER_ARCH_CHECKS[check_name])()
                for check_name in checks
            ):
                return func(*args, **kwargs)

            return None

        return wrapper

    return decorator


def _rocm_aiter_fused_moe_impl(
Expand Down Expand Up @@ -953,16 +1012,20 @@ class rocm_aiter_ops:
after monkey patching the env variables in the unit test.

Check Functions:
All check functions (is_*_enabled) are decorated with @if_aiter_supported,
which verifies: (1) platform is ROCm, (2) device arch is gfx9, and
(3) aiter library is installed. The check function then also verifies
All check functions (is_*_enabled) are decorated with @aiter_ops_supported_arch([...]),
which verifies: (1) platform is ROCm, (2) device arch matches the specified list,
and (3) aiter library is installed. The check function then also verifies
the corresponding environment variable is enabled.
i.e. ___
is_enabled() == current_platform.is_rocm() and | checked by
current_platform.is_on_gfx9() and | @if_aiter_supported
(arch_check passes) and | @aiter_ops_supported_arch([...])
IS_AITER_FOUND and _______________|
cls._AITER_ENABLED -----> Check by the logic in `is_enabled()`

Note: Arch checks must be explicitly specified:
@aiter_ops_supported_arch(["mi3xx"]) # MI300 series only
@aiter_ops_supported_arch(["mi3xx", "gfx12x"]) # MI300 OR gfx12x

Example:
from vllm._aiter_ops import rocm_aiter_ops

Expand Down Expand Up @@ -1089,86 +1152,86 @@ def get_aiter_quant_type(quant_type_str: str):
return mapping.get(name)

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_enabled(cls) -> bool:
    """Return the class-level ``_AITER_ENABLED`` flag.

    This body only runs when the gating decorator on this method passes its
    ROCm / AITER-found / arch check; otherwise the decorator returns
    ``None`` to the caller.
    """
    return cls._AITER_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_linear_enabled(cls) -> bool:
    """Return ``cls._AITER_ENABLED and cls._LINEAR_ENABLED``.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    return cls._AITER_ENABLED and cls._LINEAR_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_linear_fp8_enabled(cls) -> bool:
    """Return the result of ``cls.is_linear_enabled()``.

    No separate FP8 flag is consulted here. Callers receive ``None``
    instead when the gating decorator's check fails.
    """
    return cls.is_linear_enabled()

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_rmsnorm_enabled(cls) -> bool:
    """Return ``cls._AITER_ENABLED and cls._RMSNORM_ENABLED``.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    return cls._AITER_ENABLED and cls._RMSNORM_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_fused_moe_enabled(cls) -> bool:
    """Return ``cls._AITER_ENABLED and cls._FMOE_ENABLED``.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    return cls._AITER_ENABLED and cls._FMOE_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_fusion_moe_shared_experts_enabled(cls) -> bool:
    """Return ``cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED``.

    Because of ``and`` short-circuiting, a falsy (possibly ``None``) result
    from ``is_fused_moe_enabled()`` is returned as-is. Callers receive
    ``None`` instead when this method's own gating decorator check fails.
    """
    return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_mla_enabled(cls) -> bool:
    """Return ``cls._AITER_ENABLED and cls._MLA_ENABLED``.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    return cls._AITER_ENABLED and cls._MLA_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_mha_enabled(cls) -> bool:
    """Return ``cls._AITER_ENABLED and cls._MHA_ENABLED``.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    return cls._AITER_ENABLED and cls._MHA_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_shuffle_kv_cache_enabled(cls) -> bool:
    """Return the class-level ``_SHUFFLE_KV_CACHE_ENABLED`` flag.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    # NOTE(review): unlike the neighbouring checks, ``_AITER_ENABLED`` is not
    # consulted here -- confirm that omission is intentional.
    return cls._SHUFFLE_KV_CACHE_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def is_triton_unified_attn_enabled(cls) -> bool:
    """Return ``cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED``.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_fp8bmm_enabled(cls) -> bool:
    """Return ``cls._AITER_ENABLED and cls._FP8BMM_ENABLED``.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    return cls._AITER_ENABLED and cls._FP8BMM_ENABLED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_fp4bmm_enabled(cls) -> bool:
    """Return ``_AITER_ENABLED and _FP4BMM_ENABLED and on_gfx950()``.

    The body adds its own ``on_gfx950()`` requirement on top of whatever
    the gating decorator allows. Callers receive ``None`` instead when the
    gating decorator's ROCm / AITER-found / arch check fails.
    """
    # Function-scope import, like the other ROCm platform imports in this file.
    from vllm.platforms.rocm import on_gfx950

    return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950()

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
    """Return ``_AITER_ENABLED and _FP4_GEMM_DYNAMIC_QUANT_ASM and on_gfx950()``.

    The body adds its own ``on_gfx950()`` requirement on top of whatever
    the gating decorator allows. Callers receive ``None`` instead when the
    gating decorator's ROCm / AITER-found / arch check fails.
    """
    # Function-scope import, like the other ROCm platform imports in this file.
    from vllm.platforms.rocm import on_gfx950

    return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM and on_gfx950()

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_triton_rotary_embed_enabled(cls) -> bool:
    """Return ``cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED``.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED

@classmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx"])
def is_triton_gemm_enabled(cls) -> bool:
    """Return ``cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM``.

    Callers receive ``None`` instead when the gating decorator's
    ROCm / AITER-found / arch check fails.
    """
    return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM

@staticmethod
@if_aiter_supported
@aiter_ops_supported_arch(["mi3xx", "gfx12x"])
def register_ops_once() -> None:
global _OPS_REGISTERED
if not _OPS_REGISTERED:
Expand Down
5 changes: 5 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def _get_gcn_arch() -> str:
# Arch string obtained once from ``_get_gcn_arch()``; every flag below is a
# plain substring test against it.
_GCN_ARCH = _get_gcn_arch()

# Family flags: true when any listed arch prefix occurs in ``_GCN_ARCH``.
_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
# gfx12-only subset of the gfx1x family above.
_ON_GFX12X = any(arch in _GCN_ARCH for arch in ["gfx12"])
_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
# Superset of the mi3xx list: additionally matches gfx90a.
_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
_ON_GFX942 = "gfx942" in _GCN_ARCH
Expand Down Expand Up @@ -226,6 +227,10 @@ def on_gfx1x() -> bool:
return _ON_GFX1X


def on_gfx12x() -> bool:
    """Return the precomputed ``_ON_GFX12X`` flag ("gfx12" found in ``_GCN_ARCH``)."""
    return _ON_GFX12X


def on_mi3xx() -> bool:
    """Return the precomputed ``_ON_MI3XX`` flag ("gfx942" or "gfx950" found in ``_GCN_ARCH``)."""
    return _ON_MI3XX

Expand Down
Loading