diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,block_shape=[128,128].json new file mode 100644 index 000000000000..678c64cc7c54 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,block_shape=[128,128].json @@ -0,0 +1,69 @@ +{ + "triton_version": "3.5.1+rocm7.2.0.gita272dfa8", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..b03511a2cc72 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=AMD_Radeon_AI_PRO_R9700,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,57 @@ +{ + "triton_version": "3.5.1+rocm7.2.0.gita272dfa8", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 70adac711f5a..6a1a0eb1089c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1926,14 +1926,17 @@ def _supports_quant_scheme( ) -> bool: p = current_platform if p.is_rocm(): - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_gfx9, on_gfx12x is_rocm_on_gfx9 = on_gfx9() + is_rocm_on_gfx12x = on_gfx12x() else: is_rocm_on_gfx9 = False + is_rocm_on_gfx12x = False device_supports_fp8 = ( is_rocm_on_gfx9 + or is_rocm_on_gfx12x or (p.is_cuda() and p.has_device_capability((8, 9))) or p.is_xpu() ) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 76be83c0638a..0551586f1ef3 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -145,6 +145,7 @@ def _get_gcn_arch() -> str: _GCN_ARCH = _get_gcn_arch() _ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"]) +_ON_GFX12X = any(arch in _GCN_ARCH for arch in ["gfx12"]) _ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"]) _ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) _ON_GFX942 = "gfx942" in _GCN_ARCH @@ -226,6 +227,10 @@ def on_gfx1x() -> bool: return _ON_GFX1X +def on_gfx12x() -> bool: + return _ON_GFX12X + + def on_mi3xx() -> bool: return _ON_MI3XX