From c13cb7898bbdaf7a4923e57fe3e5366f4add6ccb Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Tue, 9 Sep 2025 00:31:13 +0000 Subject: [PATCH] feat: overlap shared experts with send/recv --- .../device_communicators/all2all.py | 9 +- .../base_device_communicator.py | 5 +- .../fused_moe/deepep_ht_prepare_finalize.py | 129 ++++++++---- .../fused_moe/deepep_ll_prepare_finalize.py | 50 ++++- .../flashinfer_cutlass_prepare_finalize.py | 4 +- .../layers/fused_moe/fused_batched_moe.py | 4 +- aphrodite/modeling/layers/fused_moe/layer.py | 193 ++++++++++++++---- .../layers/fused_moe/modular_kernel.py | 138 ++++++++++--- .../layers/fused_moe/pplx_prepare_finalize.py | 77 ++++++- .../layers/fused_moe/prepare_finalize.py | 4 +- .../layers/shared_fused_moe/__init__.py | 4 + .../shared_fused_moe/shared_fused_moe.py | 54 +++++ aphrodite/modeling/models/deepseek_v2.py | 102 +++++---- aphrodite/modeling/models/glm4_moe.py | 2 + aphrodite/modeling/models/llama4.py | 28 +-- aphrodite/quantization/awq_marlin.py | 4 +- aphrodite/quantization/bitsandbytes.py | 2 +- .../compressed_tensors_moe.py | 12 +- aphrodite/quantization/experts_int8.py | 4 +- aphrodite/quantization/fp8.py | 4 +- aphrodite/quantization/gguf.py | 5 +- aphrodite/quantization/gptq_marlin.py | 2 +- aphrodite/quantization/modelopt.py | 4 +- aphrodite/quantization/moe_wna16.py | 4 +- aphrodite/quantization/mxfp4.py | 5 +- aphrodite/quantization/quark/quark_moe.py | 7 +- aphrodite/quantization/rtn.py | 4 +- aphrodite/v1/worker/gpu_worker.py | 3 +- 28 files changed, 641 insertions(+), 222 deletions(-) create mode 100644 aphrodite/modeling/layers/shared_fused_moe/__init__.py create mode 100644 aphrodite/modeling/layers/shared_fused_moe/shared_fused_moe.py diff --git a/aphrodite/distributed/device_communicators/all2all.py b/aphrodite/distributed/device_communicators/all2all.py index 07dc10e3b9..f3447b72ea 100644 --- a/aphrodite/distributed/device_communicators/all2all.py +++ 
b/aphrodite/distributed/device_communicators/all2all.py @@ -1,19 +1,14 @@ -from typing import TYPE_CHECKING, Any +from typing import Any import torch import torch.distributed as dist from loguru import logger -from aphrodite.utils import has_deep_ep, has_pplx from aphrodite.forward_context import get_forward_context +from aphrodite.utils import has_deep_ep, has_pplx from .base_device_communicator import All2AllManagerBase, Cache -if TYPE_CHECKING: - from aphrodite.modeling.layers.fused_moe.layer import FusedMoE -else: - FusedMoE = None - class NaiveAll2AllManager(All2AllManagerBase): """ diff --git a/aphrodite/distributed/device_communicators/base_device_communicator.py b/aphrodite/distributed/device_communicators/base_device_communicator.py index eea7249b7a..6d712494fe 100644 --- a/aphrodite/distributed/device_communicators/base_device_communicator.py +++ b/aphrodite/distributed/device_communicators/base_device_communicator.py @@ -250,7 +250,10 @@ def prepare_communication_buffer_for_model(self, moe_modules = [ module for module in model.modules() - if module.__class__.__name__ == "FusedMoE" + # TODO: Should use isinstance but can't. Maybe search for + # presence of quant_method.init_prepare_finalize? 
+ if (module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE") ] for module in moe_modules: module.quant_method.init_prepare_finalize(module) diff --git a/aphrodite/modeling/layers/fused_moe/deepep_ht_prepare_finalize.py b/aphrodite/modeling/layers/fused_moe/deepep_ht_prepare_finalize.py index 6833ea54f8..38d83b3a67 100644 --- a/aphrodite/modeling/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/aphrodite/modeling/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional, Union import deep_ep import torch @@ -22,6 +22,7 @@ def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, self.num_dispatchers_ = num_dispatchers self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset + self.async_prepare = True # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. @@ -53,10 +54,16 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]: return None return deep_ep.Buffer.get_combine_config(self.dp_size) - def _do_dispatch(self, tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], - rank_topk_ids: torch.Tensor, - rank_topk_weights: torch.Tensor, num_experts: int): + def _do_dispatch( + self, + tokens: torch.Tensor, + token_scales: Optional[torch.Tensor], + rank_topk_ids: torch.Tensor, + rank_topk_weights: torch.Tensor, + num_experts: int, + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> Callable: has_scales = token_scales is not None @@ -90,9 +97,36 @@ def _do_dispatch(self, tokens: torch.Tensor, expert_alignment=1, config=self._get_dispatch_config(), previous_event=None, - async_finish=False, + async_finish=self.async_prepare, allocate_on_comm_stream=False) + return lambda: self._receiver( + event, + has_scales, + token_data, + expert_topk_ids, + num_experts, + expert_num_tokens_per_expert_list, + 
expert_topk_weights, + a1_scale, + quant_config, + ) + + def _receiver( + self, + event: deep_ep.EventOverlap, + has_scales: bool, + token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], + expert_topk_ids: Optional[torch.Tensor], + num_experts: int, + expert_num_tokens_per_expert_list: list[int], + expert_topk_weights: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + if self.async_prepare: + event.current_stream_wait() + if has_scales: expert_x, expert_x_scale = token_data else: @@ -109,6 +143,7 @@ def _do_dispatch(self, tokens: torch.Tensor, # DeepEP's topk_ids output refers to the local experts directly. Offset # the topk_ids to move it back to the global experts space so it aligns # with existing vLLM interfaces. + assert expert_topk_ids is not None expert_topk_ids = torch.where( expert_topk_ids == -1, num_experts - 1 if self.rank_expert_offset == 0 else 0, @@ -120,10 +155,28 @@ def _do_dispatch(self, tokens: torch.Tensor, expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( expert_num_tokens_per_expert_list, device=expert_x.device) + # Dispatch and Quant + # DeepEP kernels only support dispatching block-quantized + # activation scales. + # Dispatch in bfloat16 and quantize afterwards + if not quant_config.is_block_quantized: + # Quantize after dispatch. 
+ expert_x_scale = None + if expert_x.numel() != 0: + expert_x, expert_x_scale = moe_kernel_quantize_input( + expert_x, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=False, + block_shape=quant_config.block_shape) + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def prepare( + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -134,9 +187,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> Callable: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -156,37 +207,37 @@ def prepare( ) if a1q_scale is not None and a1q_scale.numel() == 1: a1q_scale = a1q_scale.view(1, 1) - (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) + a1_post_scale = None else: - # Dispatch and Quant - # DeepEP kernels only support dispatching block-quantized - # activation scales. - # Dispatch in bfloat16 - (expert_x, _, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1, - token_scales=None, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) - # Quantize after dispatch. 
- expert_x_scale = None - if expert_x.numel() != 0: - expert_x, expert_x_scale = moe_kernel_quantize_input( - expert_x, - a1_scale, - quant_dtype=quant_config.quant_dtype, - per_act_token_quant=False, - block_shape=quant_config.block_shape) + a1q = a1 + a1q_scale = None + a1_post_scale = a1_scale - return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) + return self._do_dispatch(tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config) + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, + topk_ids, num_experts, expert_map, + apply_router_weight_on_input, + quant_config) + return receiver() def finalize( self, diff --git a/aphrodite/modeling/layers/fused_moe/deepep_ll_prepare_finalize.py b/aphrodite/modeling/layers/fused_moe/deepep_ll_prepare_finalize.py index b530378c13..0b2186c47d 100644 --- a/aphrodite/modeling/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/aphrodite/modeling/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Callable, Optional, Union import deep_ep import torch @@ -73,7 +73,6 @@ def _do_quant( self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], a1_dtype: torch.dtype, quant_dtype: Union[torch.dtype, str, None], per_act_token_quant: bool, @@ -108,7 +107,10 @@ def _do_quant( return x, x_scales - def prepare( + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: 
torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -119,9 +121,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.ReceiverType: hidden_size = a1.size(1) assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ @@ -153,16 +153,48 @@ def prepare( num_experts, use_fp8=self.use_fp8_dispatch, async_finish=False, - return_recv_hook=False) + return_recv_hook=True) + + return lambda: self._receiver(hook, expert_x, expert_num_tokens, + a1_scale, a1.dtype, quant_config) + + def _receiver( + self, + hook: Callable, + expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + expert_num_tokens: torch.Tensor, + a1_scale, + a1_dtype, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + hook() expert_x, expert_x_scale = self._do_quant( - expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, + expert_x, a1_scale, a1_dtype, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) expert_tokens_meta = mk.ExpertTokensMetadata( expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) - return (expert_x, expert_x_scale, expert_tokens_meta, None, None) + return expert_x, expert_x_scale, expert_tokens_meta, None, None + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, + topk_ids, num_experts, expert_map, + apply_router_weight_on_input, + quant_config) + return receiver() def finalize( self, diff --git 
a/aphrodite/modeling/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/aphrodite/modeling/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 800a759650..59ee02151f 100644 --- a/aphrodite/modeling/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/aphrodite/modeling/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -53,9 +53,7 @@ def prepare( apply_router_weight_on_input: bool, # TODO(bnell): use quant_config + scales instead of ctor args quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) diff --git a/aphrodite/modeling/layers/fused_moe/fused_batched_moe.py b/aphrodite/modeling/layers/fused_moe/fused_batched_moe.py index e12632a096..55622a4a09 100644 --- a/aphrodite/modeling/layers/fused_moe/fused_batched_moe.py +++ b/aphrodite/modeling/layers/fused_moe/fused_batched_moe.py @@ -503,9 +503,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.size(0) == a1.size(0) diff --git a/aphrodite/modeling/layers/fused_moe/layer.py b/aphrodite/modeling/layers/fused_moe/layer.py index 2424a5520f..e59ecd92cb 100644 --- a/aphrodite/modeling/layers/fused_moe/layer.py +++ b/aphrodite/modeling/layers/fused_moe/layer.py @@ -1,7 +1,7 @@ from abc import abstractmethod from collections.abc import Iterable from enum import Enum -from typing import Callable, Literal, Optional, overload +from typing import Callable, Literal, Optional, Union, overload import torch import torch.nn.functional as F @@ -9,7 +9,6 @@ from torch.nn.parameter import 
UninitializedParameter import aphrodite.common.envs as envs -from aphrodite.common.logger import log_once from aphrodite.config import get_current_aphrodite_config from aphrodite.distributed import (get_dp_group, get_ep_group, get_tensor_model_parallel_world_size, @@ -211,6 +210,7 @@ def init_prepare_finalize(self, layer: torch.nn.Module): self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, + layer.shared_experts, ) def select_gemm_impl( @@ -248,7 +248,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: raise NotImplementedError @@ -405,7 +405,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None @@ -457,7 +457,7 @@ def forward_cuda( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -543,7 +543,7 @@ def forward_cpu( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -590,7 +590,7 @@ def forward_xpu( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: 
Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -629,7 +629,7 @@ def forward_tpu( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert not use_grouped_topk assert num_expert_group is None assert topk_group is None @@ -803,7 +803,7 @@ def __init__( # we padding globally so EP buffer allocation works if quant_config and quant_config.get_name() == "mxfp4": - from aphrodite.quantization.mxfp4 import ( # noqa: E501 + from aphrodite.quantization.mxfp4 import ( # noqa: E501 should_use_flashinfer_mxfp4) if current_platform.is_rocm() or should_use_flashinfer_mxfp4(): hidden_size = round_up(hidden_size, 256) @@ -833,8 +833,7 @@ def __init__( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) - log_once( - "INFO", + logger.info_once( "[EP Rank {}/{}] Expert parallelism is enabled. Local/global" " number of experts: {}/{}. Experts local to global index map:" " {}.", self.ep_rank, self.ep_size, self.local_num_experts, @@ -897,7 +896,7 @@ def __init__( self.quant_method = quant_method if self.enable_eplb: - from aphrodite.quantization.fp8 import Fp8MoEMethod + from aphrodite.quantization.fp8 import Fp8MoEMethod if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)): # TODO: Add support for additional quantization methods. 
@@ -944,6 +943,10 @@ def __init__( dtype=moe.in_dtype, device=torch.cuda.current_device()) + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return None + @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -1328,7 +1331,7 @@ def weight_loader(self, # load the weight scales and zp based on the quantization scheme # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported - # TODO @dsikka: once hardened, refactor to use vLLM Parameters + # TODO @dsikka: once hardened, refactor to use Aphrodite Parameters # specific to each case quant_method = getattr(param, "quant_method", None) if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: @@ -1396,6 +1399,7 @@ def get_expert_weights(self) -> Iterable[torch.Tensor]: return [ weight.view(self.local_num_experts, -1) for name, weight in weights if name not in NON_EXPERT_WEIGHTS + and not name.startswith("_shared_experts.") ] def set_eplb_state( @@ -1578,25 +1582,45 @@ def maybe_all_reduce_tensor_model_parallel( else: return tensor_model_parallel_all_reduce(final_hidden_states) - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: og_hidden_states = hidden_states.shape[-1] if self.hidden_size != og_hidden_states: hidden_states = F.pad(hidden_states, (0, self.hidden_size - og_hidden_states), mode='constant', value=0.0) - # TODO: Once the OOM issue for the TPU backend is resolved, we will - # switch to using the moe_forward custom op. - if current_platform.is_tpu(): - return self.forward_impl(hidden_states, router_logits) + + if self.shared_experts is None: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. 
+ fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) + else: + fused_output = torch.ops.aphrodite.moe_forward( + hidden_states, router_logits, self.layer_name) + return fused_output[..., :og_hidden_states] else: - return torch.ops.aphrodite.moe_forward( - hidden_states, router_logits, - self.layer_name)[..., :og_hidden_states] + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits) + else: + shared_output, fused_output = torch.ops.aphrodite.moe_forward_shared( + hidden_states, router_logits, self.layer_name) + return (shared_output[..., :og_hidden_states], + fused_output[..., :og_hidden_states]) - def forward_impl_chunked(self, full_hidden_states: torch.Tensor, - full_router_logits: torch.Tensor): + def forward_impl_chunked( + self, + full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.batched_hidden_states is not None assert self.batched_router_logits is not None assert self.batched_hidden_states.dtype == full_hidden_states.dtype @@ -1607,7 +1631,10 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, assert ( self.batched_router_logits.size(-1) == full_router_logits.size(-1)) - full_final_hidden_states = torch.empty_like(full_hidden_states) + full_fused_final_hidden_states = torch.empty_like(full_hidden_states) + if self.shared_experts is not None: + full_shared_final_hidden_states = torch.empty_like( + full_hidden_states) def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_size = chunk_end - chunk_start @@ -1648,9 +1675,21 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): logical_replica_count=self.logical_replica_count, ) + assert self.shared_experts is None or isinstance( + 
final_hidden_states, tuple) + if not skip_result_store: - full_final_hidden_states[chunk_start:chunk_end, :].copy_( - final_hidden_states, non_blocking=True) + if self.shared_experts is None: + full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states, + non_blocking=True) + else: + full_shared_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[0], + non_blocking=True) + full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[1], + non_blocking=True) ctx = get_forward_context() # flashinfer_cutlass_kernels can handle: optional DP + TP/EP @@ -1671,10 +1710,17 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_end, skip_result_store=chunk_start_ >= num_tokens) - return full_final_hidden_states + if self.shared_experts is None: + return full_fused_final_hidden_states + else: + return (full_shared_final_hidden_states, + full_fused_final_hidden_states) - def forward_impl(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_impl( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.quant_method is not None # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. @@ -1694,6 +1740,15 @@ def forward_impl(self, hidden_states: torch.Tensor, hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) + # If there are shared experts but we are not using a modular kernel, the + # shared experts must be called here + if (not isinstance(self.quant_method.fused_experts, + FusedMoEModularKernel) + and self.shared_experts is not None): + shared_output = self.shared_experts(hidden_states) + else: + shared_output = None + # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, @@ -1718,14 +1773,30 @@ def forward_impl(self, hidden_states: torch.Tensor, logical_replica_count=self.logical_replica_count, ) - if do_naive_dispatch_combine: - final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs. - final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( - final_hidden_states) + if shared_output is not None: + assert not isinstance(final_hidden_states, tuple) + assert self.shared_experts is not None + final_hidden_states = ( + shared_output, + final_hidden_states, + ) - return final_hidden_states + def reduce_output(states: torch.Tensor) -> torch.Tensor: + if do_naive_dispatch_combine: + states = get_ep_group().combine(states) + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + states = self.maybe_all_reduce_tensor_model_parallel(states) + + return states + + if self.shared_experts is None: + return reduce_output(final_hidden_states) + else: + return ( + reduce_output(final_hidden_states[0]), + reduce_output(final_hidden_states[1]), + ) @classmethod def make_expert_params_mapping( @@ -1780,17 +1851,22 @@ def extra_repr(self) -> str: return s -def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - assert self.quant_method is not None - + assert self.shared_experts is None return self.forward_impl(hidden_states, router_logits) -def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + 
layer_name: str, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1803,6 +1879,37 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, tags=(torch.Tag.needs_fixed_stride_order, ), ) + +def moe_forward_shared( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.shared_experts is not None + return self.forward_impl(hidden_states, router_logits) + + +def moe_forward_shared_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + shared_out = torch.empty_like(hidden_states) + fused_out = torch.empty_like(hidden_states) + return shared_out, fused_out + + +direct_register_custom_op( + op_name="moe_forward_shared", + op_func=moe_forward_shared, + mutates_args=["hidden_states"], + fake_impl=moe_forward_shared_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + # Mark the FusedMoE weight_loader as supporting MoE-specific parameters # to avoid expensive runtime reflection in model loading code FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] diff --git a/aphrodite/modeling/layers/fused_moe/modular_kernel.py b/aphrodite/modeling/layers/fused_moe/modular_kernel.py index a5133d0c97..9321f101ad 100644 --- a/aphrodite/modeling/layers/fused_moe/modular_kernel.py +++ b/aphrodite/modeling/layers/fused_moe/modular_kernel.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import Enum from math import prod -from typing import Optional, final +from typing import Callable, Optional, Union, final import torch @@ -139,6 +139,29 @@ def apply(self, output: Optional[torch.Tensor], raise NotImplementedError +# +# PrepareResultType is a tuple of: +# - quantized + dispatched a. 
+# - quantized + dispatched a1_scales. +# - Optional ExpertTokensMetadata containing gpu/cpu tensors +# as big as the number of local experts with the information about the +# number of tokens assigned to each local expert. +# - Optional dispatched expert topk IDs +# - Optional dispatched expert topk weight +# +# See `prepare` method below. +# +PrepareResultType = tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], +] + +ReceiverType = Callable[[], PrepareResultType] + + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ @@ -158,16 +181,9 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[ExpertTokensMetadata], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: + ) -> PrepareResultType: """ - Perform any quantization (and/or) dispatching needed - for this kernel. + Perform any quantization (and/or) dispatching needed for this kernel. - a1: The (unquantized) input to the MoE layer. - a1_scale: Optional scales for a1 - a2_scale: Optional scales for the second MoE gemm. Required to make @@ -191,6 +207,47 @@ def prepare( """ raise NotImplementedError + def supports_async(self) -> bool: + """ + Indicates whether or not this class implements prepare_async. + """ + return False + + def prepare_async( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> ReceiverType: + """ + Perform any quantization (and/or) dispatching needed for this kernel + but do not wait for results from other workers. + - a1: The (unquantized) input to the MoE layer. 
+ - a1_scale: Optional scales for a1 + - a2_scale: Optional scales for the second MoE gemm. Required to make + sure the quantization is consistent for both gemms. + - topk_ids: The topk ids. + - topk_weights: The topk weights. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. + Returns a callback that when invoked waits for results from other + workers and has the same return signature as `prepare`, e.g. + receiver = obj.prepare_async(...) + a, a_scales, expert_meta, topk_ids, topk_weights = receiver() + is equivalent to: + a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...) + """ + raise NotImplementedError + @abstractmethod def finalize( self, @@ -451,10 +508,12 @@ def __init__( self, prepare_finalize: FusedMoEPrepareAndFinalize, fused_experts: FusedMoEPermuteExpertsUnpermute, + shared_experts: Optional[torch.nn.Module] = None, ): super().__init__() self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts + self.shared_experts = shared_experts assert prepare_finalize.activation_format == \ fused_experts.activation_formats[0], ( f"{prepare_finalize.__class__.__name__}." @@ -690,7 +749,7 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. 
@@ -734,18 +793,46 @@ def forward( if global_num_experts == -1: global_num_experts = local_num_experts - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = self.prepare_finalize.prepare( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) + shared_output: torch.Tensor + + if (not self.prepare_finalize.supports_async() + or self.shared_experts is None): + + # Run shared experts serially with dispatch. + if self.shared_experts is not None: + shared_output = self.shared_experts(a1) + + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = self.prepare_finalize.prepare( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + else: + # Overlap shared expert compute with all2all dispatch. + receiver = self.prepare_finalize.prepare_async( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + + assert self.shared_experts is not None + shared_output = self.shared_experts(a1) + + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = receiver() # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids @@ -793,4 +880,7 @@ def forward( self.fused_experts.finalize_weight_and_reduce_impl(), ) - return output + if self.shared_experts is None: + return output + else: + return shared_output, output diff --git a/aphrodite/modeling/layers/fused_moe/pplx_prepare_finalize.py b/aphrodite/modeling/layers/fused_moe/pplx_prepare_finalize.py index 21cf818208..8592447698 100644 --- a/aphrodite/modeling/layers/fused_moe/pplx_prepare_finalize.py +++ b/aphrodite/modeling/layers/fused_moe/pplx_prepare_finalize.py @@ -81,12 +81,15 @@ def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - return torch.int32 + return torch.uint32 def num_dispatchers(self) -> int: return self.num_dispatchers_ - def prepare( + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -97,9 +100,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.ReceiverType: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K @@ -136,6 +137,8 @@ def prepare( _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape) + orig_a_scale_block_shape: Optional[int] = None + if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -203,8 +206,45 @@ def prepare( out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=topk_ids.view(dtype=torch.uint32), + indices=topk_ids, + bound_m=bound_m, + do_send=True, + do_recv=False, + ) + + return lambda: self._receiver( + expert_num_tokens, + expert_x, + expert_x_scale, + a1q, + a1q_scale, + topk_ids, + bound_m, + orig_a_scale_block_shape, + ) + + def _receiver( + self, + 
expert_num_tokens: torch.Tensor, + expert_x: torch.Tensor, + expert_x_scale: Optional[torch.Tensor], + a1q: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + bound_m: Optional[torch.Tensor], + orig_a_scale_block_shape: Optional[int], + ) -> mk.PrepareResultType: + + self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=topk_ids, bound_m=bound_m, + do_send=False, + do_recv=True, ) if expert_x_scale is not None: @@ -216,6 +256,31 @@ def prepare( return expert_x, expert_x_scale, expert_tokens_meta, None, None + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) + return receiver() + def finalize( self, output: torch.Tensor, diff --git a/aphrodite/modeling/layers/fused_moe/prepare_finalize.py b/aphrodite/modeling/layers/fused_moe/prepare_finalize.py index 29c7264977..b41d3ff3e4 100644 --- a/aphrodite/modeling/layers/fused_moe/prepare_finalize.py +++ b/aphrodite/modeling/layers/fused_moe/prepare_finalize.py @@ -36,9 +36,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) diff --git a/aphrodite/modeling/layers/shared_fused_moe/__init__.py 
b/aphrodite/modeling/layers/shared_fused_moe/__init__.py new file mode 100644 index 0000000000..2d01c02f18 --- /dev/null +++ b/aphrodite/modeling/layers/shared_fused_moe/__init__.py @@ -0,0 +1,4 @@ +from aphrodite.modeling.layers.shared_fused_moe.shared_fused_moe import ( + SharedFusedMoE) + +__all__ = ["SharedFusedMoE"] diff --git a/aphrodite/modeling/layers/shared_fused_moe/shared_fused_moe.py b/aphrodite/modeling/layers/shared_fused_moe/shared_fused_moe.py new file mode 100644 index 0000000000..bee6962083 --- /dev/null +++ b/aphrodite/modeling/layers/shared_fused_moe/shared_fused_moe.py @@ -0,0 +1,54 @@ +from typing import Optional + +import torch + +from aphrodite.distributed import tensor_model_parallel_all_reduce +from aphrodite.modeling.layers.fused_moe.layer import FusedMoE + + +# TODO: Add shared + fused combo function? e.g. + +class SharedFusedMoE(FusedMoE): + """ + A FusedMoE operation that also computes the results of shared experts. + If an all2all communicator is being used the shared expert computation + can be interleaved with the fused all2all dispatch communication step. + """ + + def __init__( + self, + shared_experts: torch.nn.Module, + use_overlapped: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._shared_experts = shared_experts + self.use_overlapped = use_overlapped + + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return self._shared_experts if self.use_overlapped else None + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not self.use_overlapped: + shared_out = self._shared_experts(hidden_states) + + # Reduce outputs if necessary, since the MLP should + # have been created with reduce_results=False. 
+ if (self.reduce_results and self.tp_size > 1 + and self.must_reduce_shared_expert_outputs()): + shared_out = tensor_model_parallel_all_reduce(shared_out) + + fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + else: + shared_out, fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out diff --git a/aphrodite/modeling/models/deepseek_v2.py b/aphrodite/modeling/models/deepseek_v2.py index ed5df70ffc..a9677fcab9 100644 --- a/aphrodite/modeling/models/deepseek_v2.py +++ b/aphrodite/modeling/models/deepseek_v2.py @@ -46,6 +46,7 @@ RowParallelLinear) from aphrodite.modeling.layers.logits_processor import LogitsProcessor from aphrodite.modeling.layers.rotary_embedding import get_rope +from aphrodite.modeling.layers.shared_fused_moe import SharedFusedMoE from aphrodite.modeling.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from aphrodite.modeling.model_loader.weight_utils import ( @@ -144,26 +145,27 @@ def __init__( self.physical_expert_end = (self.physical_expert_start + self.n_local_physical_experts) - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - # we do scaling outside, set factor to 1.0 to avoid double mul - routed_scaling_factor=1.0, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - - if config.n_shared_experts is not None: + if config.n_shared_experts is None: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + 
top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + self.shared_experts = None + else: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) self.shared_experts = DeepseekV2MLP( @@ -171,36 +173,56 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + reduce_results=False, prefix=f"{prefix}.shared_experts", ) + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) # 
router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if hidden_states.dtype != torch.float16: - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor + fused_moe_out = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if shared_output is not None: - if hidden_states.dtype != torch.float16: - final_hidden_states = final_hidden_states + shared_output - else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = final_hidden_states + shared_output \ - * (1. / self.routed_scaling_factor) + shared_output = None + final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= (1. 
/ self.routed_scaling_factor) + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output if self.tp_size > 1: final_hidden_states = ( diff --git a/aphrodite/modeling/models/glm4_moe.py b/aphrodite/modeling/models/glm4_moe.py index 8ba8be92cc..b287be7817 100644 --- a/aphrodite/modeling/models/glm4_moe.py +++ b/aphrodite/modeling/models/glm4_moe.py @@ -179,6 +179,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) + else: + shared_output = None router_logits = self.gate(hidden_states.to(dtype=torch.float32)) final_hidden_states = self.experts( hidden_states=hidden_states, diff --git a/aphrodite/modeling/models/llama4.py b/aphrodite/modeling/models/llama4.py index 81c187e6ef..c8f921a23b 100644 --- a/aphrodite/modeling/models/llama4.py +++ b/aphrodite/modeling/models/llama4.py @@ -34,6 +34,7 @@ ReplicatedLinear, RowParallelLinear) from aphrodite.modeling.layers.rotary_embedding import get_rope +from aphrodite.modeling.layers.shared_fused_moe import SharedFusedMoE from aphrodite.modeling.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from aphrodite.quantization import QuantizationConfig @@ -72,7 +73,18 @@ def __init__(self, quant_config=None, prefix=f"{prefix}.router") - self.experts = FusedMoE( + self.shared_expert = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size_moe, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.shared_expert", + reduce_results=False, + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -82,22 +94,12 @@ def __init__(self, reduce_results=False, renormalize=False, quant_config=quant_config, - prefix=f"{prefix}.experts") - - self.shared_expert = LlamaMLP( - 
hidden_size=config.hidden_size, - intermediate_size=intermediate_size_moe, - hidden_act="silu", - quant_config=quant_config, - bias=False, - prefix=f"{prefix}.shared_expert", - reduce_results=self.experts.must_reduce_shared_expert_outputs(), + prefix=f"{prefix}.experts", ) def forward(self, hidden_states): router_logits, _ = self.router(hidden_states) - shared_out = self.shared_expert(hidden_states) - routed_out = self.experts( + shared_out, routed_out = self.experts( hidden_states=hidden_states, router_logits=router_logits, ) diff --git a/aphrodite/quantization/awq_marlin.py b/aphrodite/quantization/awq_marlin.py index 3c67d804e1..319f80217e 100644 --- a/aphrodite/quantization/awq_marlin.py +++ b/aphrodite/quantization/awq_marlin.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from loguru import logger @@ -503,7 +503,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( diff --git a/aphrodite/quantization/bitsandbytes.py b/aphrodite/quantization/bitsandbytes.py index eb1fce8645..21609d93bd 100644 --- a/aphrodite/quantization/bitsandbytes.py +++ b/aphrodite/quantization/bitsandbytes.py @@ -470,7 +470,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: from aphrodite.modeling.layers.fused_moe import fused_experts assert self.fused_experts is None diff --git a/aphrodite/quantization/compressed_tensors/compressed_tensors_moe.py b/aphrodite/quantization/compressed_tensors/compressed_tensors_moe.py index 
dba37ea240..6a434b868b 100644 --- a/aphrodite/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/aphrodite/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,6 +1,6 @@ import enum from enum import Enum -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch from compressed_tensors import CompressionFormat @@ -352,7 +352,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -814,7 +814,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -1064,7 +1064,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -1370,7 +1370,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -1603,7 +1603,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is 
None if enable_eplb: diff --git a/aphrodite/quantization/experts_int8.py b/aphrodite/quantization/experts_int8.py index d03c8ec333..cfcb8c87f4 100644 --- a/aphrodite/quantization/experts_int8.py +++ b/aphrodite/quantization/experts_int8.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch @@ -125,7 +125,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/aphrodite/quantization/fp8.py b/aphrodite/quantization/fp8.py index 05b3650ddb..a963232562 100644 --- a/aphrodite/quantization/fp8.py +++ b/aphrodite/quantization/fp8.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -989,7 +989,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None diff --git a/aphrodite/quantization/gguf.py b/aphrodite/quantization/gguf.py index 8ed2fc5c58..79d235fa5e 100644 --- a/aphrodite/quantization/gguf.py +++ b/aphrodite/quantization/gguf.py @@ -1,9 +1,8 @@ -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import gguf import torch from gguf import GGMLQuantizationType as WeightType -from loguru import logger from torch.nn.parameter import Parameter, UninitializedParameter from aphrodite import _custom_ops as ops @@ -538,7 +537,7 @@ def apply( expert_load_view: 
Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/aphrodite/quantization/gptq_marlin.py b/aphrodite/quantization/gptq_marlin.py index 89743cc846..68f1eec8a3 100644 --- a/aphrodite/quantization/gptq_marlin.py +++ b/aphrodite/quantization/gptq_marlin.py @@ -652,7 +652,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/aphrodite/quantization/modelopt.py b/aphrodite/quantization/modelopt.py index 6af836d7df..8a56b56981 100644 --- a/aphrodite/quantization/modelopt.py +++ b/aphrodite/quantization/modelopt.py @@ -491,7 +491,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet.") @@ -1368,7 +1368,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") diff --git a/aphrodite/quantization/moe_wna16.py b/aphrodite/quantization/moe_wna16.py index 55d4ecc16c..662219e6b8 100644 --- a/aphrodite/quantization/moe_wna16.py +++ b/aphrodite/quantization/moe_wna16.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional 
+from typing import Any, Callable, Optional, Union import torch @@ -298,7 +298,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/aphrodite/quantization/mxfp4.py b/aphrodite/quantization/mxfp4.py index 7e38fda0ba..04b4b54150 100644 --- a/aphrodite/quantization/mxfp4.py +++ b/aphrodite/quantization/mxfp4.py @@ -1,7 +1,6 @@ -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch -from loguru import logger from torch.nn.parameter import Parameter from aphrodite.common import envs @@ -553,7 +552,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") diff --git a/aphrodite/quantization/quark/quark_moe.py b/aphrodite/quantization/quark/quark_moe.py index 52b9a3cc87..5e493db0de 100644 --- a/aphrodite/quantization/quark/quark_moe.py +++ b/aphrodite/quantization/quark/quark_moe.py @@ -1,7 +1,6 @@ -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch -from loguru import logger from aphrodite import _custom_ops as ops from aphrodite.common.logger import log_once @@ -222,7 +221,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -388,7 +387,7 @@ def apply( expert_load_view: 
Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/aphrodite/quantization/rtn.py b/aphrodite/quantization/rtn.py index d0be4441da..663fcc8627 100644 --- a/aphrodite/quantization/rtn.py +++ b/aphrodite/quantization/rtn.py @@ -1,7 +1,7 @@ # Copyright © 2025, Oracle and/or its affiliates. import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -287,7 +287,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/aphrodite/v1/worker/gpu_worker.py b/aphrodite/v1/worker/gpu_worker.py index e07f42da7b..6bbc48068e 100644 --- a/aphrodite/v1/worker/gpu_worker.py +++ b/aphrodite/v1/worker/gpu_worker.py @@ -532,7 +532,8 @@ def _reconfigure_moe(self, old_ep_size: int, parallel_config = self.aphrodite_config.parallel_config moe_modules = [ module for module in self.model_runner.model.modules() - if module.__class__.__name__ == "FusedMoE" + if (module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE") ] num_local_experts = moe_modules[0].moe_config.num_local_experts assert all(module.moe_config.num_local_experts == num_local_experts