From 3e9d168228ad56f5e1198b9cccfcf922b70b7cb1 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Wed, 11 Mar 2026 03:28:50 +0000 Subject: [PATCH 1/8] add aiter gemm_a8w8_blockscale support for gfx1201 --- vllm/_aiter_ops.py | 4 ++-- vllm/platforms/rocm.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c8366ecce543..df90e59f0582 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -50,9 +50,9 @@ def is_aiter_found_and_supported() -> bool: VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery. """ if current_platform.is_rocm() and IS_AITER_FOUND: - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_gfx9, on_gfx12x - return on_gfx9() + return on_gfx9() or on_gfx12x() return False diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index f1fd3331802b..5606a85e6092 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -145,6 +145,7 @@ def _get_gcn_arch() -> str: _GCN_ARCH = _get_gcn_arch() _ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"]) +_ON_GFX12X = any(arch in _GCN_ARCH for arch in ["gfx12"]) _ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"]) _ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) _ON_GFX942 = "gfx942" in _GCN_ARCH @@ -225,6 +226,8 @@ def _capability_from_gcn_arch(gcn_arch: str) -> tuple[int, int] | None: def on_gfx1x() -> bool: return _ON_GFX1X +def on_gfx12x()-> bool: + return _ON_GFX12X def on_mi3xx() -> bool: return _ON_MI3XX From 0c8b931dbb617524461c70ac2936dcfd40bc1f4b Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Wed, 11 Mar 2026 09:31:42 +0000 Subject: [PATCH 2/8] use triton quant fp8 --- vllm/model_executor/layers/quantization/input_quant_fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py 
index 6fa85436dfc2..6974f874253f 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -137,7 +137,8 @@ def forward_hip( scale_ub: torch.Tensor | None = None, use_triton: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: - if self.is_group_quant and use_triton: + from vllm.platforms.rocm import on_gfx12x + if self.is_group_quant and use_triton and on_gfx12x() : assert scale is None, "Dynamic group quantization does not use scale" return torch.ops.vllm.triton_per_token_group_quant_fp8(x, self.group_size) From c4b46fd3127ea53e07427fcbdbd3cae8ba767ca0 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Wed, 11 Mar 2026 14:35:54 +0000 Subject: [PATCH 3/8] enable aiter quant fp8 for gfx1201 --- vllm/model_executor/layers/quantization/input_quant_fp8.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 6974f874253f..6fa85436dfc2 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -137,8 +137,7 @@ def forward_hip( scale_ub: torch.Tensor | None = None, use_triton: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: - from vllm.platforms.rocm import on_gfx12x - if self.is_group_quant and use_triton and on_gfx12x() : + if self.is_group_quant and use_triton: assert scale is None, "Dynamic group quantization does not use scale" return torch.ops.vllm.triton_per_token_group_quant_fp8(x, self.group_size) From cd54e644908307cbb644614dcba0d86dfeedd203 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Mon, 16 Mar 2026 06:35:13 +0000 Subject: [PATCH 4/8] fix formatting Signed-off-by: big-yellow-duck --- vllm/platforms/rocm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 
5606a85e6092..b6e86f8c9664 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -226,9 +226,11 @@ def _capability_from_gcn_arch(gcn_arch: str) -> tuple[int, int] | None: def on_gfx1x() -> bool: return _ON_GFX1X -def on_gfx12x()-> bool: + +def on_gfx12x() -> bool: return _ON_GFX12X + def on_mi3xx() -> bool: return _ON_MI3XX From d95a214f226db95c8b5981821d477251facc0dd4 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Mon, 16 Mar 2026 09:21:18 +0000 Subject: [PATCH 5/8] add conditional aiter_ops for gfx12x Signed-off-by: big-yellow-duck --- vllm/_aiter_ops.py | 97 +++++++++++++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 28 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index df90e59f0582..f43b27a0e4b3 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm.platforms import current_platform +from vllm.platforms.rocm import on_gfx12x, on_mi3xx from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( rocm_aiter_sparse_attn_indexer, @@ -56,19 +57,55 @@ def is_aiter_found_and_supported() -> bool: return False -def if_aiter_supported(func: Callable) -> Callable: - """Decorator that only executes the function if - ROCm AITER package is supported and enabled on gfx9 archs. +def if_aiter_supported( + arch_checks: list[Callable[[], bool]] | None = None, +): """ + Decorator factory that only executes the function if ROCm AITER package + is supported and any of the arch checks pass. + + The default check is always on_mi3xx (MI300 series: gfx942, gfx950). + When custom arch_checks are provided, they are combined with on_mi3xx + using OR logic (on_mi3xx OR any of the custom checks). + + Args: + arch_checks: List of callable predicates that return bool. + If None, defaults to [on_mi3xx] only. + If specified, combines with on_mi3xx: [on_mi3xx] + arch_checks. 
+ + Usage: + @if_aiter_supported() # Default: MI300 series only (gfx942, gfx950) + def is_enabled(): ... + + @if_aiter_supported([on_gfx12x]) # MI300 series OR gfx12x archs + def is_linear_enabled(): ... + + @if_aiter_supported([on_gfx9]) # MI300 series OR gfx9 (all CDNA) + def is_all_cdna(): ... + """ + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + # First check: platform and library availability + if not (current_platform.is_rocm() and IS_AITER_FOUND): + return None + + # Always include on_mi3xx as the base check + checks = [on_mi3xx] + if arch_checks is not None: + # Combine user-provided checks with default on_mi3xx + checks = checks + arch_checks + + # If any arch check passes, execute the function + if not any(check() for check in checks): + return None - @functools.wraps(func) - def wrapper(*args, **kwargs): - if is_aiter_found_and_supported(): return func(*args, **kwargs) - return None + return wrapper - return wrapper + return decorator def _rocm_aiter_fused_moe_impl( @@ -953,16 +990,20 @@ class rocm_aiter_ops: after monkey patching the env variables in the unit test. Check Functions: - All check functions (is_*_enabled) are decorated with @if_aiter_supported, - which verifies: (1) platform is ROCm, (2) device arch is gfx9, and - (3) aiter library is installed. The check function then also verifies + All check functions (is_*_enabled) are decorated with @if_aiter_supported(), + which verifies: (1) platform is ROCm, (2) device arch is MI300 series (gfx942, gfx950), + and (3) aiter library is installed. The check function then also verifies the corresponding environment variable is enabled. i.e. 
___ is_enabled() == current_platform.is_rocm() and | checked by - current_platform.is_on_gfx9() and | @if_aiter_supported + on_mi3xx() and | @if_aiter_supported() IS_AITER_FOUND and _______________| cls._AITER_ENABLED -----> Check by the logic in `is_enabled()` + Note: To enable a function for gfx12x or other archs, use: + @if_aiter_supported([on_gfx12x]) # MI300 series OR gfx12x + @if_aiter_supported([on_gfx9]) # MI300 series OR gfx9 (all CDNA) + Example: from vllm._aiter_ops import rocm_aiter_ops @@ -1089,86 +1130,86 @@ def get_aiter_quant_type(quant_type_str: str): return mapping.get(name) @classmethod - @if_aiter_supported + @if_aiter_supported([on_gfx12x]) def is_enabled(cls) -> bool: return cls._AITER_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported([on_gfx12x]) def is_linear_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._LINEAR_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported([on_gfx12x]) def is_linear_fp8_enabled(cls) -> bool: return cls.is_linear_enabled() @classmethod - @if_aiter_supported + @if_aiter_supported([on_gfx12x]) def is_rmsnorm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._RMSNORM_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported() def is_fused_moe_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FMOE_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported() def is_fusion_moe_shared_experts_enabled(cls) -> bool: return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported() def is_mla_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MLA_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported() def is_mha_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MHA_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported() def is_shuffle_kv_cache_enabled(cls) -> bool: return cls._SHUFFLE_KV_CACHE_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported([on_gfx12x]) 
def is_triton_unified_attn_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported() def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED @classmethod - @if_aiter_supported + @if_aiter_supported() def is_fp4bmm_enabled(cls) -> bool: from vllm.platforms.rocm import on_gfx950 return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950() @classmethod - @if_aiter_supported + @if_aiter_supported() def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: from vllm.platforms.rocm import on_gfx950 return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM and on_gfx950() @classmethod - @if_aiter_supported + @if_aiter_supported() def is_triton_rotary_embed_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED @classmethod - @if_aiter_supported + @if_aiter_supported() def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM @staticmethod - @if_aiter_supported + @if_aiter_supported([on_gfx12x]) def register_ops_once() -> None: global _OPS_REGISTERED if not _OPS_REGISTERED: From 998b366a481b90aba9cc00778e4512b0e5f98341 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Mon, 16 Mar 2026 15:05:57 +0000 Subject: [PATCH 6/8] change to explicit aiter support archs Signed-off-by: big-yellow-duck --- vllm/_aiter_ops.py | 77 ++++++++++++++++++++-------------------------- 1 file changed, 33 insertions(+), 44 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index f43b27a0e4b3..f04ccd0e44b4 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -57,31 +57,27 @@ def is_aiter_found_and_supported() -> bool: return False -def if_aiter_supported( - arch_checks: list[Callable[[], bool]] | None = None, +def aiter_ops_supported_arch( + arch_checks: list[Callable[[], bool]], ): """ Decorator factory that only executes the function if ROCm AITER package - is supported and any of the arch 
checks pass. - - The default check is always on_mi3xx (MI300 series: gfx942, gfx950). - When custom arch_checks are provided, they are combined with on_mi3xx - using OR logic (on_mi3xx OR any of the custom checks). + is supported and any of the specified arch checks pass. Args: arch_checks: List of callable predicates that return bool. - If None, defaults to [on_mi3xx] only. - If specified, combines with on_mi3xx: [on_mi3xx] + arch_checks. + Must explicitly specify which archs are supported. + Examples: + - [on_mi3xx]: MI300 series only (gfx942, gfx950) + - [on_mi3xx, on_gfx12x]: MI300 OR gfx12x + - [on_gfx9]: All gfx9 archs (gfx90a, gfx942, gfx950) Usage: - @if_aiter_supported() # Default: MI300 series only (gfx942, gfx950) + @aiter_ops_supported_arch([on_mi3xx]) # MI300 series only def is_enabled(): ... - @if_aiter_supported([on_gfx12x]) # MI300 series OR gfx12x archs + @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) # MI300 OR gfx12x def is_linear_enabled(): ... - - @if_aiter_supported([on_gfx9]) # MI300 series OR gfx9 (all CDNA) - def is_all_cdna(): ... """ def decorator(func: Callable) -> Callable: @@ -91,14 +87,8 @@ def wrapper(*args, **kwargs): if not (current_platform.is_rocm() and IS_AITER_FOUND): return None - # Always include on_mi3xx as the base check - checks = [on_mi3xx] - if arch_checks is not None: - # Combine user-provided checks with default on_mi3xx - checks = checks + arch_checks - - # If any arch check passes, execute the function - if not any(check() for check in checks): + # Execute if any arch check passes + if not any(check() for check in arch_checks): return None return func(*args, **kwargs) @@ -990,19 +980,19 @@ class rocm_aiter_ops: after monkey patching the env variables in the unit test. 
Check Functions: - All check functions (is_*_enabled) are decorated with @if_aiter_supported(), - which verifies: (1) platform is ROCm, (2) device arch is MI300 series (gfx942, gfx950), + All check functions (is_*_enabled) are decorated with @aiter_ops_supported_arch([...]), + which verifies: (1) platform is ROCm, (2) device arch matches the specified list, and (3) aiter library is installed. The check function then also verifies the corresponding environment variable is enabled. i.e. ___ is_enabled() == current_platform.is_rocm() and | checked by - on_mi3xx() and | @if_aiter_supported() + (on_mi3xx() or ...) and | @aiter_ops_supported_arch([...]) IS_AITER_FOUND and _______________| cls._AITER_ENABLED -----> Check by the logic in `is_enabled()` - Note: To enable a function for gfx12x or other archs, use: - @if_aiter_supported([on_gfx12x]) # MI300 series OR gfx12x - @if_aiter_supported([on_gfx9]) # MI300 series OR gfx9 (all CDNA) + Note: Arch checks must be explicitly specified: + @aiter_ops_supported_arch([on_mi3xx]) # MI300 series only + @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) # MI300 OR gfx12x Example: from vllm._aiter_ops import rocm_aiter_ops @@ -1130,86 +1120,85 @@ def get_aiter_quant_type(quant_type_str: str): return mapping.get(name) @classmethod - @if_aiter_supported([on_gfx12x]) + @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) def is_enabled(cls) -> bool: return cls._AITER_ENABLED @classmethod - @if_aiter_supported([on_gfx12x]) + @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) def is_linear_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._LINEAR_ENABLED @classmethod - @if_aiter_supported([on_gfx12x]) + @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) def is_linear_fp8_enabled(cls) -> bool: return cls.is_linear_enabled() @classmethod - @if_aiter_supported([on_gfx12x]) + @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) def is_rmsnorm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._RMSNORM_ENABLED @classmethod - 
@if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_fused_moe_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FMOE_ENABLED @classmethod - @if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_fusion_moe_shared_experts_enabled(cls) -> bool: return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED @classmethod - @if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_mla_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MLA_ENABLED @classmethod - @if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_mha_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MHA_ENABLED @classmethod - @if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_shuffle_kv_cache_enabled(cls) -> bool: return cls._SHUFFLE_KV_CACHE_ENABLED @classmethod - @if_aiter_supported([on_gfx12x]) + @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) def is_triton_unified_attn_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED @classmethod - @if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED @classmethod - @if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_fp4bmm_enabled(cls) -> bool: from vllm.platforms.rocm import on_gfx950 return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950() @classmethod - @if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: from vllm.platforms.rocm import on_gfx950 return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM and on_gfx950() @classmethod - @if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_triton_rotary_embed_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED @classmethod - @if_aiter_supported() + @aiter_ops_supported_arch([on_mi3xx]) def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED 
and cls._TRITON_UNQUANT_GEMM @staticmethod - @if_aiter_supported([on_gfx12x]) def register_ops_once() -> None: global _OPS_REGISTERED if not _OPS_REGISTERED: From 4afe02e68b9ad1fd87dff3745dc4dbf0ea527998 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Tue, 17 Mar 2026 04:45:11 +0000 Subject: [PATCH 7/8] check supported arch at register custom ops Signed-off-by: big-yellow-duck --- vllm/_aiter_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index f04ccd0e44b4..ee3c5d666657 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1199,6 +1199,7 @@ def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM @staticmethod + @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) def register_ops_once() -> None: global _OPS_REGISTERED if not _OPS_REGISTERED: From 63ae8d79849257f9361d3e6f71a6afeca71d0075 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Thu, 19 Mar 2026 06:35:47 +0000 Subject: [PATCH 8/8] fix aiter_ops dynamic check gpu arch Signed-off-by: big-yellow-duck --- vllm/_aiter_ops.py | 92 +++++++++++++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 30 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index ee3c5d666657..35492ca05455 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -8,7 +8,6 @@ import vllm.envs as envs from vllm.platforms import current_platform -from vllm.platforms.rocm import on_gfx12x, on_mi3xx from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( rocm_aiter_sparse_attn_indexer, @@ -58,25 +57,32 @@ def is_aiter_found_and_supported() -> bool: def aiter_ops_supported_arch( - arch_checks: list[Callable[[], bool]], + arch_checks: list[str], ): """ Decorator factory that only executes the function if ROCm AITER package is supported and any of the specified arch checks pass. Args: - arch_checks: List of callable predicates that return bool. 
+ arch_checks: List of arch check names (strings). Must explicitly specify which archs are supported. + Available options and their mappings: + - "gfx1x": gfx11, gfx12 + - "gfx12x": gfx12 + - "mi3xx": gfx942, gfx950 + - "gfx9": gfx90a, gfx942, gfx950 + - "gfx942": gfx942 + - "gfx950": gfx950 Examples: - - [on_mi3xx]: MI300 series only (gfx942, gfx950) - - [on_mi3xx, on_gfx12x]: MI300 OR gfx12x - - [on_gfx9]: All gfx9 archs (gfx90a, gfx942, gfx950) + - ["mi3xx"]: MI300 series only (gfx942, gfx950) + - ["mi3xx", "gfx12x"]: MI300 OR gfx12x + - ["gfx9"]: All gfx9 archs (gfx90a, gfx942, gfx950) Usage: - @aiter_ops_supported_arch([on_mi3xx]) # MI300 series only + @aiter_ops_supported_arch(["mi3xx"]) # MI300 series only def is_enabled(): ... - @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) # MI300 OR gfx12x + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) # MI300 OR gfx12x def is_linear_enabled(): ... """ @@ -87,11 +93,37 @@ def wrapper(*args, **kwargs): if not (current_platform.is_rocm() and IS_AITER_FOUND): return None + # Import arch check functions lazily only when on ROCm + from vllm.platforms.rocm import ( + on_gfx1x, + on_gfx9, + on_gfx12x, + on_gfx942, + on_gfx950, + on_mi3xx, + ) + + func_map = { + "gfx1x": on_gfx1x, + "gfx12x": on_gfx12x, + "mi3xx": on_mi3xx, + "gfx9": on_gfx9, + "gfx942": on_gfx942, + "gfx950": on_gfx950, + } + # Execute if any arch check passes - if not any(check() for check in arch_checks): - return None + for check_name in arch_checks: + check_func = func_map.get(check_name) + if check_func is None: + raise ValueError( + f"Unknown arch check: {check_name}. " + f"Available: {list(func_map.keys())}" + ) + if check_func(): + return func(*args, **kwargs) - return func(*args, **kwargs) + return None return wrapper @@ -986,13 +1018,13 @@ class rocm_aiter_ops: the corresponding environment variable is enabled. i.e. ___ is_enabled() == current_platform.is_rocm() and | checked by - (on_mi3xx() or ...) 
and | @aiter_ops_supported_arch([...]) + (arch_check passes) and | @aiter_ops_supported_arch([...]) IS_AITER_FOUND and _______________| cls._AITER_ENABLED -----> Check by the logic in `is_enabled()` Note: Arch checks must be explicitly specified: - @aiter_ops_supported_arch([on_mi3xx]) # MI300 series only - @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) # MI300 OR gfx12x + @aiter_ops_supported_arch(["mi3xx"]) # MI300 series only + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) # MI300 OR gfx12x Example: from vllm._aiter_ops import rocm_aiter_ops @@ -1120,86 +1152,86 @@ def get_aiter_quant_type(quant_type_str: str): return mapping.get(name) @classmethod - @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_enabled(cls) -> bool: return cls._AITER_ENABLED @classmethod - @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_linear_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._LINEAR_ENABLED @classmethod - @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_linear_fp8_enabled(cls) -> bool: return cls.is_linear_enabled() @classmethod - @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_rmsnorm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._RMSNORM_ENABLED @classmethod - @aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_fused_moe_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FMOE_ENABLED @classmethod - @aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_fusion_moe_shared_experts_enabled(cls) -> bool: return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED @classmethod - @aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_mla_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MLA_ENABLED @classmethod - 
@aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_mha_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._MHA_ENABLED @classmethod - @aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_shuffle_kv_cache_enabled(cls) -> bool: return cls._SHUFFLE_KV_CACHE_ENABLED @classmethod - @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def is_triton_unified_attn_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED @classmethod - @aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED @classmethod - @aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_fp4bmm_enabled(cls) -> bool: from vllm.platforms.rocm import on_gfx950 return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950() @classmethod - @aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: from vllm.platforms.rocm import on_gfx950 return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM and on_gfx950() @classmethod - @aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_triton_rotary_embed_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED @classmethod - @aiter_ops_supported_arch([on_mi3xx]) + @aiter_ops_supported_arch(["mi3xx"]) def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM @staticmethod - @aiter_ops_supported_arch([on_mi3xx, on_gfx12x]) + @aiter_ops_supported_arch(["mi3xx", "gfx12x"]) def register_ops_once() -> None: global _OPS_REGISTERED if not _OPS_REGISTERED: