diff --git a/tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py index 0a403a29b0d2..a67215651389 100644 --- a/tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +from typing import Any import pytest @@ -38,6 +39,13 @@ def make_parser(tool_format: str) -> ToolParser: ) +def make_parser_with_kwargs(chat_template_kwargs: dict[str, Any]) -> ToolParser: + return ToolParserManager.get_tool_parser("multi_format")( + FakeTokenizer(), + chat_template_kwargs=chat_template_kwargs, + ) + + def make_request() -> ChatCompletionRequest: return ChatCompletionRequest( model="test-model", @@ -45,12 +53,39 @@ def make_request() -> ChatCompletionRequest: ) -def test_default_format_delegates_to_hermes(): - parser = make_parser("default") +def make_schema_request() -> ChatCompletionRequest: + return ChatCompletionRequest( + model="test-model", + messages=[], + tools=[ + { + "type": "function", + "function": { + "name": "study_args", + "description": "Study argument coercion.", + "parameters": { + "type": "object", + "properties": { + "user_id": {"type": "string"}, + "include_revoked": {"type": "boolean"}, + "page": {"type": "integer"}, + "filters": {"type": "object"}, + }, + }, + }, + } + ], + ) + + +def test_missing_tool_format_defaults_to_xml(): + parser = make_parser_with_kwargs({}) extracted = run_tool_extraction_nonstreaming( parser, - '\n{"name":"get_weather","arguments":{"city":"Tokyo"}}\n', + "get_weather" + "cityTokyo" + "", make_request(), ) @@ -93,6 +128,133 @@ def test_glm_format_matches_template_output(): } +def test_ifm_json_format_uses_schema_type_coercion(): + parser = make_parser_with_kwargs({"tool_call_format": "json"}) + + extracted = run_tool_extraction_nonstreaming( + parser, + 'Planning.\n\n' + '{"name":"study_args","arguments":{' + '"user_id":12345,' + '"include_revoked":"true",' + '"page":"2",' + '"filters":"{\\"unit\\":\\"celsius\\"}"' + "}}\n" + "", + make_schema_request(), + ) + + assert extracted.tools_called + assert extracted.content == "Planning.\n" + assert extracted.tool_calls[0].function.name == "study_args" + args = json.loads(extracted.tool_calls[0].function.arguments) + assert args == { + "user_id": "12345", + "include_revoked": True, + "page": 2, + "filters": {"unit": "celsius"}, + } + assert isinstance(args["user_id"], str) + + +def test_ifm_xml_format_uses_schema_type_coercion(): + parser = make_parser("xml") + + extracted = run_tool_extraction_nonstreaming( + parser, + "Planning.\n\n" + "study_args\n" + "user_id\n" + "12345\n" + "include_revoked\n" + "true\n" + "page\n" + "2\n" + "filters\n" + '{"unit":"celsius"}\n' + "\n" + "", + make_schema_request(), + ) + + assert extracted.tools_called + args = json.loads(extracted.tool_calls[0].function.arguments) + assert args == { + "user_id": "12345", + "include_revoked": True, + "page": 2, + "filters": {"unit": "celsius"}, + } + assert isinstance(args["user_id"], str) + + +def test_ifm_xml_typed_format_uses_arg_type_without_schema(): + parser = make_parser_with_kwargs({"tool_call_format": "xml_typed"}) + + extracted = run_tool_extraction_nonstreaming( + parser, + "\n" + "study_args\n" + "user_id\n" + "string\n" + "12345\n" + "include_revoked\n" + "boolean\n" + "true\n" + "page\n" + "integer\n" + "2\n" + "\n" + "", + make_request(), + ) + + assert extracted.tools_called + args = json.loads(extracted.tool_calls[0].function.arguments) + assert args == { + "user_id": "12345", + "include_revoked": True, + "page": 2, + } + assert isinstance(args["user_id"], str) + + +@pytest.mark.parametrize( + "tool_format", + ["default", "typed_xml", "XML", "xllm_typed", "xml ", ""], +) +def test_tool_format_requires_exact_supported_value(tool_format: str): + with pytest.raises(ValueError, match="Use one of these exact values"): + make_parser_with_kwargs({"tool_call_format": tool_format}) + + +def test_tool_format_must_be_a_string(): + with pytest.raises(ValueError, match="must be a string"): + make_parser_with_kwargs({"tool_call_format": 123}) + + +def test_k2_v3_parser_alias_uses_ifm_formats(): + parser = ToolParserManager.get_tool_parser("k2_v3")( + FakeTokenizer(), + chat_template_kwargs={"tool_call_format": "xml_typed"}, + ) + + extracted = run_tool_extraction_nonstreaming( + parser, + "study_args\n" + "user_id" + "string" + "12345" + "", + make_request(), + ) + + assert extracted.tools_called + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "user_id": "12345" + } + + def test_minimax_format_extracts_inline_invokes(): parser = make_parser("minimax") @@ -260,13 +422,12 @@ def test_custom_formats_do_not_stream_yet(): assert delta is None -def test_readme_default_example(): - parser = make_parser("default") +def test_readme_json_example(): + parser = make_parser("json") extracted = run_tool_extraction_nonstreaming( parser, - '\n' - '{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}}\n' - "", + '{"name": "get_weather", ' + '"arguments": {"location": "San Francisco, CA"}}', make_request(), ) assert extracted.tools_called diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 0680d90b51e7..1c7ae7363654 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -66,6 +66,10 @@ "kimi_k2_tool_parser", "KimiK2ToolParser", ), + "k2_v3": ( + "multi_format_tool_parser", + "K2V3ToolParser", + ), "llama3_json": ( "llama_tool_parser", "Llama3JsonToolParser", diff --git a/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py index 12fc2069e888..30b476315c78 100644 --- a/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py @@ -24,7 +24,21 @@ class MultiFormatToolParser(ToolParser): - """Tool parser that dispatches on ``chat_template_kwargs['tool_format']``.""" + """Tool parser that dispatches on chat template tool-call format kwargs.""" + + _SUPPORTED_TOOL_FORMATS = frozenset( + { + "qwen3", + "minimax", + "dsv32", + "glm", + "gptoss", + "python", + "json", + "xml", + "xml_typed", + } + ) _MINIMAX_START_TOKEN = "" _MINIMAX_BLOCK_REGEX = re.compile( @@ -51,6 +65,20 @@ class MultiFormatToolParser(ToolParser): r"(.*?)", re.DOTALL, ) + + _IFM_TOOL_CALLS_START_TOKEN = "" + _IFM_TOOL_CALL_START_TOKEN = "" + _IFM_BLOCK_REGEX = re.compile( + r"(.*?)", + re.DOTALL, + ) + _IFM_ARG_REGEX = re.compile( + r"(.*?)\s*" + r"(?:(.*?)\s*)?" + r"(.*?)", + re.DOTALL, + ) + _GLM_BLOCK_REGEX = re.compile( r"(.*?)", re.DOTALL, @@ -67,24 +95,38 @@ def __init__( ): super().__init__(tokenizer) - self.tool_format = str( - (chat_template_kwargs or {}).get("tool_format") or "default" - ) + chat_template_kwargs = chat_template_kwargs or {} + raw_tool_format = "xml" + for key in ("tool_call_format", "tool_calling_format", "tool_format"): + if key in chat_template_kwargs and chat_template_kwargs[key] is not None: + raw_tool_format = chat_template_kwargs[key] + break + self.tool_format = self._validate_tool_format(raw_tool_format) self._delegate: ToolParser | None = None - if self.tool_format == "default": - from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import ( - Hermes2ProToolParser, - ) - - self._delegate = Hermes2ProToolParser(tokenizer) - elif self.tool_format == "qwen3": + if self.tool_format == "qwen3": from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import ( Qwen3XMLToolParser, ) self._delegate = Qwen3XMLToolParser(tokenizer) + @classmethod + def _validate_tool_format(cls, tool_format: Any) -> str: + if not isinstance(tool_format, str): + raise ValueError( + "tool_format/tool_call_format must be a string. " + f"Got {type(tool_format).__name__}." + ) + if tool_format not in cls._SUPPORTED_TOOL_FORMATS: + supported_formats = ", ".join(sorted(cls._SUPPORTED_TOOL_FORMATS)) + raise ValueError( + f"Unsupported tool_format/tool_call_format '{tool_format}'. " + "Use one of these exact values: " + f"{supported_formats}." + ) + return tool_format + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: if self._delegate is not None: return self._delegate.adjust_request(request) @@ -99,11 +141,17 @@ def extract_tool_calls( return self._delegate.extract_tool_calls(model_output, request) try: + if self.tool_format == "json": + return self._extract_ifm_json_tool_calls(model_output, request) + if self.tool_format in {"xml", "xml_typed"}: + return self._extract_ifm_xml_tool_calls(model_output, request) if self.tool_format == "minimax": return self._extract_minimax_tool_calls(model_output) if self.tool_format == "dsv32": return self._extract_dsv32_tool_calls(model_output) if self.tool_format == "glm": + if self._IFM_TOOL_CALL_START_TOKEN in model_output: + return self._extract_ifm_xml_tool_calls(model_output, request) return self._extract_glm_tool_calls(model_output, request) if self.tool_format == "gptoss": return self._extract_gptoss_tool_calls(model_output) @@ -169,6 +217,214 @@ def _tool_call(function_name: str, arguments: dict[str, Any]) -> ToolCall: ), ) + @staticmethod + def _schema_arg_type( + tool_name: str, + arg_name: str, + tools: list[ChatCompletionToolsParam] | None, + ) -> Any | None: + if tools is None: + return None + for tool in tools: + if tool.function.name != tool_name or tool.function.parameters is None: + continue + properties = tool.function.parameters.get("properties", {}) + arg_spec = properties.get(arg_name, {}) + if not isinstance(arg_spec, dict): + return None + return arg_spec.get("type") + return None + + @staticmethod + def _arg_type_is_string(arg_type: Any | None) -> bool: + if isinstance(arg_type, str): + return arg_type == "string" + if isinstance(arg_type, list): + return "string" in arg_type + return False + + @staticmethod + def _json_stringify(value: Any) -> str: + if isinstance(value, str): + return value + return json.dumps(value, ensure_ascii=False) + + @classmethod + def _coerce_argument_value( + cls, + value: Any, + tool_name: str, + arg_name: str, + tools: list[ChatCompletionToolsParam] | None, + *, + arg_type: str | None = None, + from_text: bool = False, + ) -> Any: + target_type = cls._schema_arg_type(tool_name, arg_name, tools) or arg_type + if cls._arg_type_is_string(target_type): + return cls._json_stringify(value) + + if isinstance(value, str) and (from_text or target_type is not None): + return cls._deserialize_glm_value(value) + return value + + @classmethod + def _coerce_arguments( + cls, + tool_name: str, + arguments: dict[str, Any], + tools: list[ChatCompletionToolsParam] | None, + ) -> dict[str, Any]: + return { + arg_name: cls._coerce_argument_value( + arg_value, + tool_name, + arg_name, + tools, + ) + for arg_name, arg_value in arguments.items() + } + + @staticmethod + def _json_arguments_to_dict(arguments: Any) -> dict[str, Any]: + if arguments is None: + return {} + if isinstance(arguments, str): + arguments = json.loads(arguments) if arguments.strip() else {} + if not isinstance(arguments, dict): + raise ValueError("Tool call arguments must be a JSON object.") + return arguments + + @classmethod + def _ifm_prefix_index(cls, model_output: str, first_match_index: int) -> int: + group_index = model_output.find(cls._IFM_TOOL_CALLS_START_TOKEN) + if group_index != -1: + return group_index + return first_match_index + + def _extract_ifm_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + matches = list(self._IFM_BLOCK_REGEX.finditer(model_output)) + if not matches: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + + first_block = matches[0].group(1).strip() + if first_block.startswith(("{", "[")): + return self._extract_ifm_json_tool_calls(model_output, request) + return self._extract_ifm_xml_tool_calls(model_output, request) + + def _extract_ifm_json_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + matches = list(self._IFM_BLOCK_REGEX.finditer(model_output)) + if not matches: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + + tool_calls: list[ToolCall] = [] + for match in matches: + raw_tool_call = json.loads(match.group(1).strip()) + raw_tool_calls = ( + raw_tool_call if isinstance(raw_tool_call, list) else [raw_tool_call] + ) + for tool_call in raw_tool_calls: + function = tool_call.get("function", tool_call) + function_name = function.get("name") + if not function_name: + raise ValueError("Tool call JSON is missing a function name.") + arguments = self._json_arguments_to_dict( + function.get("arguments", {}) + ) + arguments = self._coerce_arguments( + function_name, + arguments, + request.tools, + ) + tool_calls.append(self._tool_call(function_name, arguments)) + + if not tool_calls: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=self._prefix_content( + model_output, + self._ifm_prefix_index(model_output, matches[0].start()), + ), + ) + + def _extract_ifm_xml_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + matches = list(self._IFM_BLOCK_REGEX.finditer(model_output)) + if not matches: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + + tool_calls: list[ToolCall] = [] + for match in matches: + block = match.group(1) + first_arg_idx = block.find("") + if first_arg_idx == -1: + function_name = block.strip() + arguments: dict[str, Any] = {} + else: + function_name = block[:first_arg_idx].strip() + arg_block = block[first_arg_idx:] + arguments = {} + for key, arg_type, value in self._IFM_ARG_REGEX.findall(arg_block): + arg_key = key.strip() + arg_value = self._coerce_argument_value( + value.strip(), + function_name, + arg_key, + request.tools, + arg_type=arg_type.strip() or None, + from_text=True, + ) + arguments[arg_key] = arg_value + + if function_name: + tool_calls.append(self._tool_call(function_name, arguments)) + + if not tool_calls: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=self._prefix_content( + model_output, + self._ifm_prefix_index(model_output, matches[0].start()), + ), + ) + def _extract_minimax_tool_calls( self, model_output: str, @@ -290,25 +546,6 @@ def _deserialize_glm_value(value: str) -> Any: return value - @staticmethod - def _glm_value_is_string( - tool_name: str, - arg_name: str, - tools: list[ChatCompletionToolsParam] | None, - ) -> bool: - if tools is None: - return False - for tool in tools: - if tool.function.name != tool_name or tool.function.parameters is None: - continue - arg_type = ( - tool.function.parameters.get("properties", {}) - .get(arg_name, {}) - .get("type") - ) - return arg_type == "string" - return False - def _extract_glm_tool_calls( self, model_output: str, @@ -335,11 +572,13 @@ def _extract_glm_tool_calls( arguments = {} for key, value in self._GLM_ARG_REGEX.findall(arg_block): arg_key = key.strip() - arg_value = value.strip() - if not self._glm_value_is_string( - function_name, arg_key, request.tools - ): - arg_value = self._deserialize_glm_value(arg_value) + arg_value = self._coerce_argument_value( + value.strip(), + function_name, + arg_key, + request.tools, + from_text=True, + ) arguments[arg_key] = arg_value if function_name: @@ -448,3 +687,7 @@ def _extract_python_tool_calls( model_output.find(""), ), ) + + +class K2V3ToolParser(MultiFormatToolParser): + """K2-V3 alias for the IFM-aware multi-format parser.""" diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c37ecf9be66d..dd3149d63dad 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -65,6 +65,7 @@ _TEXT_GENERATION_MODELS = { # [Decoder-only] + "XllmForCausalLM": ("xllm", "XllmForCausalLM"), "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"), "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"), @@ -173,7 +174,6 @@ "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), - "XllmForCausalLM": ("xllm", "XllmForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"), "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), @@ -1181,4 +1181,4 @@ def _run() -> None: if __name__ == "__main__": - _run() + _run() \ No newline at end of file diff --git a/vllm/model_executor/models/xllm.py b/vllm/model_executor/models/xllm.py index 7eb59a922af3..d70265322de5 100644 --- a/vllm/model_executor/models/xllm.py +++ b/vllm/model_executor/models/xllm.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only Xllm model compatible with HuggingFace weights.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from itertools import islice from typing import Any @@ -10,12 +11,25 @@ from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + get_tensor_model_parallel_rank +) +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import rms_norm, fused_add_rms_norm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -29,62 +43,126 @@ default_weight_loader, maybe_remap_kv_scale_name, ) +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP -from .utils import ( +from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import ( AutoWeightsLoader, PPMissingLayer, + extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) +logger = init_logger(__name__) -class GroupRMSNorm(nn.Module): - """RMSNorm with per-group variance computation. - Computes variance over groups of hidden_size/n_groups dimensions - instead of the full hidden dimension. - """ +def permute_to_xllm(x): + return x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).reshape(*x.shape[:-1], -1) - def __init__( - self, - hidden_size: int, - n_groups: int = 1, - eps: float = 1e-6, - ) -> None: - super().__init__() - self.hidden_size = hidden_size + +def permute_to_hf(x): + return x.reshape(*x.shape[:-1], -1, 2).transpose(-1, -2).reshape(*x.shape[:-1], -1) + + +@CustomOp.register("grouped_rms_norm") +class XllmRMSNorm(RMSNorm): + def __init__(self, + hidden_size: int, + n_groups: int, + tp_size: int, + num_replicas: int = 1, + eps=1e-6): + """ + XllmRMSNorm is equivalent to T5LayerNorm + """ + super().__init__(hidden_size=hidden_size, eps=eps) self.n_groups = n_groups - self.variance_epsilon = eps + self.hidden_size = hidden_size assert hidden_size % n_groups == 0 - self.weight = nn.Parameter(torch.ones(hidden_size)) + self.weight = nn.Parameter(torch.ones(hidden_size * tp_size // num_replicas)) + self.variance_epsilon = eps - def forward( - self, - x: torch.Tensor, - residual: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - orig_dtype = x.dtype - x = x.to(torch.float32) + if tp_size > 1: + self.tp_weight = self.weight.reshape( + tp_size // num_replicas, self.hidden_size + )[get_tensor_model_parallel_rank() // num_replicas] + else: + self.tp_weight = self.weight - if residual is not None: - x = x + residual.to(torch.float32) - residual = x.to(orig_dtype) + # assert self._forward_method == self.forward_native + assert self._forward_method in [self.forward_native, self.forward_cuda] + if get_tensor_model_parallel_rank() == 0: + print(f'{self._forward_method=}') - # Group RMSNorm: compute variance per group - x_grouped = x.reshape(*x.shape[:-1], self.n_groups, -1) - variance = x_grouped.pow(2).mean(-1, keepdim=True) - x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon) - x = x_grouped.reshape(*x.shape[:-1], self.hidden_size) + def forward_native(self, x, residual=None): + x = x.reshape(*x.shape[:-1], self.n_groups, -1) + if residual is not None: + residual = residual.reshape(x.shape) + + x = self.forward_static( + x, + self.variance_epsilon, + self.hidden_size // self.n_groups, + x.dtype, + None, + residual, + self.variance_size_override, + ) + if residual is not None: + x, residual = x - x = (self.weight * x).to(orig_dtype) + x = x.reshape(*x.shape[:-2], -1) + x = self.tp_weight.data * x if residual is None: return x - return x, residual + else: + residual = residual.reshape(x.shape) + return x, residual + + def forward_cuda(self, hidden_states, residual=None): + # input_dtype = hidden_states.dtype + # hidden_states = hidden_states.to(torch.float32) + # if residual is not None: + # hidden_states = hidden_states + residual.to(torch.float32) + # residual = hidden_states.to(input_dtype) + + hidden_states = hidden_states.reshape(*hidden_states.shape[:-1], self.n_groups, -1) + if residual is not None: + residual = residual.reshape(hidden_states.shape) + # variance = hidden_states.pow(2).mean(-1, keepdim=True) + # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + if residual is not None: + hidden_states, residual = fused_add_rms_norm( + x=hidden_states, + residual=residual, + weight=torch.ones( + hidden_states.shape[-1], + device=hidden_states.device, + dtype=hidden_states.dtype), + variance_epsilon=self.variance_epsilon) + else: + hidden_states = rms_norm( + x=hidden_states, + weight=torch.ones( + hidden_states.shape[-1], + device=hidden_states.device, + dtype=hidden_states.dtype), + variance_epsilon=self.variance_epsilon) + + hidden_states = hidden_states.reshape(*hidden_states.shape[:-2], -1) + hidden_states = self.tp_weight * hidden_states + + if residual is None: + return hidden_states + else: + residual = residual.reshape(hidden_states.shape) + return hidden_states, residual class XllmMLP(nn.Module): @@ -115,8 +193,7 @@ def __init__( ) if hidden_act != "silu": raise ValueError( - f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now." + f"Unsupported activation: {hidden_act}. Only silu is supported for now." ) self.act_fn = SiluAndMul() @@ -127,38 +204,165 @@ def forward(self, x): return x +class XllmSparseMoeBlock(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + + config = vllm_config.model_config.hf_text_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}." + ) + + # Load balancing settings. + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=config.moe_gate_bias, + skip_bias_add=True, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + n_shared_experts=config.num_shared_experts, + top_k=config.num_experts_per_tok, + use_grouped_topk=True, + num_expert_group=1, + topk_group=1, + scoring_func=config.router_score_func, + e_score_correction_bias=self.gate.bias, + routed_scaling_factor=config.router_scaling_factor, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=True, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=None # RoutingMethodType.Renormalize, + ) + + self.num_shared_experts = config.num_shared_experts + if config.num_shared_experts > 0: + self.shared_experts = XllmMLP( + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size * config.num_shared_experts, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert hidden_states.dim() <= 2, ( + "XllmSparseMoeBlock only supports 1D or 2D inputs" + ) + is_input_1d = hidden_states.dim() == 1 + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + if self.num_shared_experts > 0: + final_hidden_states = ( + final_hidden_states + self.shared_experts(hidden_states)) + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + + # return to 1d if input is 1d + return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states + + class XllmAttention(nn.Module): def __init__( self, hidden_size: int, num_heads: int, num_kv_heads: int, + query_key_norm: bool, rope_parameters: dict[str, Any], max_position_embeddings: int = 8192, head_dim: int | None = None, + rope_head_dim: int | None = None, rms_norm_eps: float = 1e-06, qkv_bias: bool = False, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", + dual_chunk_attention_config: dict[str, Any] | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + self.tp_size = tp_size + self.tp_rank = tp_rank + self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. assert self.total_num_kv_heads % tp_size == 0 else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.rope_head_dim = rope_head_dim or self.head_dim self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.max_position_embeddings = max_position_embeddings + self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, @@ -169,20 +373,22 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) + self.num_kv_head_replicas = self.qkv_proj.num_kv_head_replicas self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, - bias=False, + bias=qkv_bias, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, + self.rope_head_dim, + rotary_dim=self.rope_head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, + dual_chunk_attention_config=dual_chunk_attention_config, ) self.attn = Attention( self.num_heads, @@ -192,8 +398,30 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } + if dual_chunk_attention_config + else {}, ) + self.query_key_norm = query_key_norm + if self.query_key_norm: + # self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + # self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.q_norm = XllmRMSNorm( + hidden_size=self.num_heads * self.head_dim, + n_groups=self.num_heads, + tp_size=tp_size, + eps=rms_norm_eps) + self.k_norm = XllmRMSNorm( + hidden_size=self.num_kv_heads * self.head_dim, + n_groups=self.num_kv_heads, + tp_size=tp_size, + num_replicas=self.num_kv_head_replicas, + eps=rms_norm_eps) + def forward( self, positions: torch.Tensor, @@ -201,7 +429,52 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) + + # Add qk-norm + if self.query_key_norm: + # q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + # q_by_head = self.q_norm(q_by_head) + # q = q_by_head.view(q.shape) + # + # k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + # k_by_head = self.k_norm(k_by_head) + # k = k_by_head.view(k.shape) + q = self.q_norm(q) + k = self.k_norm(k) + + if self.rope_head_dim == self.head_dim: + q, k = self.rotary_emb(positions, q, k) + else: + tp_size, tp_rank = self.tp_size, self.tp_rank + q_ = tensor_model_parallel_all_gather(q.contiguous()) + k_ = tensor_model_parallel_all_gather(k.contiguous()) + q_ = q_.reshape(*q_.shape[:-1], self.total_num_heads, self.head_dim) + k_ = k_.reshape( + *k_.shape[:-1], self.total_num_kv_heads * self.num_kv_head_replicas, self.head_dim + )[..., ::self.num_kv_head_replicas, :] + + q_rope, q_nope = torch.split( + permute_to_xllm(q_), + split_size_or_sections=[self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1) + k_rope, k_nope = torch.split( + permute_to_xllm(k_), + split_size_or_sections=[self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1) + + q_rope, k_rope = self.rotary_emb( + positions, permute_to_hf(q_rope), permute_to_hf(k_rope)) + + q_ = permute_to_hf(torch.cat( + [permute_to_xllm(q_rope), q_nope], dim=-1)).reshape(*q_.shape[:-2], -1) + k_ = permute_to_hf(torch.cat( + [permute_to_xllm(k_rope), k_nope], dim=-1)).reshape(*k_.shape[:-2], -1) + + q = q_.split(q_.shape[-1] // tp_size, dim=-1)[tp_rank] + k = k_.split( + k_.shape[-1] // (tp_size // self.num_kv_head_replicas), dim=-1 + )[tp_rank // self.num_kv_head_replicas] + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -216,38 +489,61 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: quant_config = vllm_config.quant_config self.hidden_size = config.hidden_size - max_position_embeddings = getattr( - config, "max_position_embeddings", 8192 + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None ) self.self_attn = XllmAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, + query_key_norm=config.query_key_norm, rope_parameters=config.rope_parameters, max_position_embeddings=max_position_embeddings, rms_norm_eps=config.rms_norm_eps, qkv_bias=getattr(config, "attention_bias", False), head_dim=getattr(config, "head_dim", None), + rope_head_dim=getattr(config, "rope_head_dim", None), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + dual_chunk_attention_config=dual_chunk_attention_config, ) - self.mlp = XllmMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - - n_groups = getattr(config, "layernorm_num_groups", 1) - self.input_layernorm = GroupRMSNorm( - config.hidden_size, n_groups=n_groups, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = GroupRMSNorm( - config.hidden_size, n_groups=n_groups, eps=config.rms_norm_eps + # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers ) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = XllmSparseMoeBlock( + vllm_config=vllm_config, prefix=f"{prefix}.mlp" + ) + else: + self.mlp = XllmMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + # self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.post_attention_layernorm = RMSNorm( + # config.hidden_size, eps=config.rms_norm_eps + # ) + assert config.hidden_size % config.layernorm_num_groups == 0 + self.input_layernorm = XllmRMSNorm( + hidden_size=config.hidden_size, + n_groups=config.layernorm_num_groups, + tp_size=1, + eps=config.rms_norm_eps) + self.post_attention_layernorm = XllmRMSNorm( + hidden_size=config.hidden_size, + n_groups=config.layernorm_num_groups, + tp_size=1, + eps=config.rms_norm_eps) def forward( self, @@ -255,21 +551,19 @@ def forward( hidden_states: torch.Tensor, residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual - ) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -281,41 +575,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_text_config quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config - if get_pp_group().is_first_rank or ( - config.tie_word_embeddings and get_pp_group().is_last_rank - ): - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) - else: - self.embed_tokens = PPMissingLayer() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: XllmDecoderLayer( - vllm_config=vllm_config, prefix=prefix - ), + lambda prefix: XllmDecoderLayer(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) + # self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + assert config.hidden_size % config.layernorm_num_groups == 0 + self.norm = XllmRMSNorm( + hidden_size=config.hidden_size, + n_groups=config.layernorm_num_groups, + tp_size=1, + eps=config.rms_norm_eps) - n_groups = getattr(config, "layernorm_num_groups", 1) - if get_pp_group().is_last_rank: - self.norm = GroupRMSNorm( - config.hidden_size, n_groups=n_groups, eps=config.rms_norm_eps - ) - else: - self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size ) + # Track layers for auxiliary hidden state outputs (EAGLE3) + self.aux_hidden_state_layers: tuple[int, ...] = () def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -326,7 +616,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -338,9 +628,17 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice( - self.layers, self.start_layer, self.end_layer + aux_hidden_states = [] + for layer_idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer), + start=self.start_layer, ): + # Collect auxiliary hidden states if specified + if layer_idx in self.aux_hidden_state_layers: + aux_hidden_state = ( + hidden_states + residual if residual is not None else hidden_states + ) + aux_hidden_states.append(aux_hidden_state) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: @@ -348,10 +646,26 @@ def forward( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) + + # Return auxiliary hidden states if collected + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ + # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), @@ -359,47 +673,148 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("gate_up_proj", "up_proj", 1), ] + # Skip loading extra parameters for GPTQ/modelopt models. + ignore_suffixes = ( + # ".bias", + # "_bias", + # ".k_scale", + # "_k_scale", + # ".v_scale", + # "_v_scale", + # ".weight_scale", + # "_weight_scale", + # ".input_scale", + # "_input_scale", + ) + params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue name = name.replace(weight_name, param_name) + + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue if name not in params_dict: - break + continue + param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) else: weight_loader(param, loaded_weight, shard_id) break else: - if is_pp_missing_parameter(name, self): - continue - if name.endswith("kv_scale"): - remapped = maybe_remap_kv_scale_name(name, params_dict) - if remapped is None: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: continue - name = remapped - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + # Skip loading extra parameters for GPTQ/modelopt models. + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale" + ) + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501 + name, + remapped_kv_scale_name, + ) + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class XllmForCausalLM(nn.Module, SupportsPP, SupportsLoRA): +class XllmForCausalLM( + nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts +): packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ] } fall_back_to_pt_during_load = False @@ -410,26 +825,77 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + # Only perform the following mapping when XllmMLP exists + if getattr(config, "mlp_only_layers", []): + self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"] self.model = XllmModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - if get_pp_group().is_last_rank: - if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - else: - self.lm_head = PPMissingLayer() + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, XllmDecoderLayer) + if isinstance(layer.mlp, XllmSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + # if example_layer is None: + # raise RuntimeError("No Qwen3MoE layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + + if example_layer is not None: + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.mlp, XllmSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) @@ -440,10 +906,10 @@ def forward( intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: - model_output = self.model( + hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) - return model_output + return hidden_states def compute_logits( self, @@ -453,9 +919,8 @@ def compute_logits( return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() \ No newline at end of file