diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c8366ecce543..35492ca05455 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -50,25 +50,84 @@ def is_aiter_found_and_supported() -> bool: VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery. """ if current_platform.is_rocm() and IS_AITER_FOUND: - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_gfx9, on_gfx12x - return on_gfx9() + return on_gfx9() or on_gfx12x() return False -def if_aiter_supported(func: Callable) -> Callable: - """Decorator that only executes the function if - ROCm AITER package is supported and enabled on gfx9 archs. +def aiter_ops_supported_arch( + arch_checks: list[str], +): """ + Decorator factory that only executes the function if ROCm AITER package + is supported and any of the specified arch checks pass. + + Args: + arch_checks: List of arch check names (strings). + Must explicitly specify which archs are supported. + Available options and their mappings: + - "gfx1x": gfx11, gfx12 + - "gfx12x": gfx12 + - "mi3xx": gfx942, gfx950 + - "gfx9": gfx90a, gfx942, gfx950 + - "gfx942": gfx942 + - "gfx950": gfx950 + Examples: + - ["mi3xx"]: MI300 series only (gfx942, gfx950) + - ["mi3xx", "gfx12x"]: MI300 OR gfx12x + - ["gfx9"]: All gfx9 archs (gfx90a, gfx942, gfx950) + + Usage: + @aiter_ops_supported_arch(["mi3xx"]) # MI300 series only + def is_enabled(): ... + + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) # MI300 OR gfx12x + def is_linear_enabled(): ... + """ + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + # First check: platform and library availability + if not (current_platform.is_rocm() and IS_AITER_FOUND): + return None + + # Import arch check functions lazily only when on ROCm + from vllm.platforms.rocm import ( + on_gfx1x, + on_gfx9, + on_gfx12x, + on_gfx942, + on_gfx950, + on_mi3xx, + ) - @functools.wraps(func) - def wrapper(*args, **kwargs): - if is_aiter_found_and_supported(): - return func(*args, **kwargs) + func_map = { + "gfx1x": on_gfx1x, + "gfx12x": on_gfx12x, + "mi3xx": on_mi3xx, + "gfx9": on_gfx9, + "gfx942": on_gfx942, + "gfx950": on_gfx950, + } + + # Execute if any arch check passes + for check_name in arch_checks: + check_func = func_map.get(check_name) + if check_func is None: + raise ValueError( + f"Unknown arch check: {check_name}. " + f"Available: {list(func_map.keys())}" + ) + if check_func(): + return func(*args, **kwargs) - return None + return None + + return wrapper - return wrapper + return decorator def _rocm_aiter_fused_moe_impl( @@ -953,16 +1012,20 @@ class rocm_aiter_ops: after monkey patching the env variables in the unit test. Check Functions: - All check functions (is_*_enabled) are decorated with @if_aiter_supported, - which verifies: (1) platform is ROCm, (2) device arch is gfx9, and - (3) aiter library is installed. The check function then also verifies + All check functions (is_*_enabled) are decorated with @aiter_ops_supported_arch([...]), + which verifies: (1) platform is ROCm, (2) device arch matches the specified list, + and (3) aiter library is installed. The check function then also verifies the corresponding environment variable is enabled. i.e. ___ is_enabled() == current_platform.is_rocm() and | checked by - current_platform.is_on_gfx9() and | @if_aiter_supported + (arch_check passes) and | @aiter_ops_supported_arch([...]) IS_AITER_FOUND and _______________| cls._AITER_ENABLED -----> Check by the logic in `is_enabled()` + Note: Arch checks must be explicitly specified: + @aiter_ops_supported_arch(["mi3xx"]) # MI300 series only + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) # MI300 OR gfx12x + Example: from vllm._aiter_ops import rocm_aiter_ops @@ -1089,86 +1152,86 @@ def get_aiter_quant_type(quant_type_str: str): return mapping.get(name) @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_enabled(cls) -> bool: return cls._AITER_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_linear_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._LINEAR_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_linear_fp8_enabled(cls) -> bool: return cls.is_linear_enabled() @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_rmsnorm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._RMSNORM_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_fused_moe_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FMOE_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_fusion_moe_shared_experts_enabled(cls) -> bool: return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_mla_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MLA_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_mha_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MHA_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_shuffle_kv_cache_enabled(cls) -> bool: return cls._SHUFFLE_KV_CACHE_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_triton_unified_attn_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_fp4bmm_enabled(cls) -> bool: from vllm.platforms.rocm import on_gfx950 return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950() @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: from vllm.platforms.rocm import on_gfx950 return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM and on_gfx950() @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_triton_rotary_embed_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED @classmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx"]) def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM @staticmethod - @if_aiter_supported + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def register_ops_once() -> None: global _OPS_REGISTERED if not _OPS_REGISTERED: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index f1fd3331802b..b6e86f8c9664 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -145,6 +145,7 @@ def _get_gcn_arch() -> str: _GCN_ARCH = _get_gcn_arch() _ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"]) +_ON_GFX12X = any(arch in _GCN_ARCH for arch in ["gfx12"]) _ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"]) _ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) _ON_GFX942 = "gfx942" in _GCN_ARCH @@ -226,6 +227,10 @@ def on_gfx1x() -> bool: return _ON_GFX1X +def on_gfx12x() -> bool: + return _ON_GFX12X + + def on_mi3xx() -> bool: return _ON_MI3XX