diff --git a/tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py
index 0a403a29b0d2..a67215651389 100644
--- a/tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py
+++ b/tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
+from typing import Any
import pytest
@@ -38,6 +39,13 @@ def make_parser(tool_format: str) -> ToolParser:
)
+def make_parser_with_kwargs(chat_template_kwargs: dict[str, Any]) -> ToolParser:
+    """Instantiate the ``multi_format`` parser with caller-supplied kwargs.
+
+    The parser class registered as ``multi_format`` is built around a
+    ``FakeTokenizer`` and receives *chat_template_kwargs* unchanged, so tests
+    can exercise every key the parser reads from that mapping.
+    """
+    parser_cls = ToolParserManager.get_tool_parser("multi_format")
+    return parser_cls(FakeTokenizer(), chat_template_kwargs=chat_template_kwargs)
+
+
def make_request() -> ChatCompletionRequest:
return ChatCompletionRequest(
model="test-model",
@@ -45,12 +53,39 @@ def make_request() -> ChatCompletionRequest:
)
-def test_default_format_delegates_to_hermes():
- parser = make_parser("default")
+def make_schema_request() -> ChatCompletionRequest:
+    """Return a request that declares one typed ``study_args`` tool.
+
+    The tool schema declares a string, a boolean, an integer and an object
+    parameter so coercion tests can compare parsed values to declared types.
+    """
+    properties = {
+        "user_id": {"type": "string"},
+        "include_revoked": {"type": "boolean"},
+        "page": {"type": "integer"},
+        "filters": {"type": "object"},
+    }
+    study_args_tool = {
+        "type": "function",
+        "function": {
+            "name": "study_args",
+            "description": "Study argument coercion.",
+            "parameters": {"type": "object", "properties": properties},
+        },
+    }
+    return ChatCompletionRequest(
+        model="test-model",
+        messages=[],
+        tools=[study_args_tool],
+    )
+
+
+def test_missing_tool_format_defaults_to_xml():
+ parser = make_parser_with_kwargs({})
extracted = run_tool_extraction_nonstreaming(
parser,
- '\n{"name":"get_weather","arguments":{"city":"Tokyo"}}\n',
+ "get_weather"
+ "cityTokyo"
+ "",
make_request(),
)
@@ -93,6 +128,133 @@ def test_glm_format_matches_template_output():
}
+def test_ifm_json_format_uses_schema_type_coercion():
+    """JSON-format arguments come back with the types the schema declares."""
+    parser = make_parser_with_kwargs({"tool_call_format": "json"})
+
+    # Every argument is deliberately sent with a type that differs from the
+    # schema in make_schema_request(): a number for the string, strings for
+    # the boolean/integer, and a JSON-encoded string for the object.
+    extracted = run_tool_extraction_nonstreaming(
+        parser,
+        'Planning.\n\n'
+        '{"name":"study_args","arguments":{'
+        '"user_id":12345,'
+        '"include_revoked":"true",'
+        '"page":"2",'
+        '"filters":"{\\"unit\\":\\"celsius\\"}"'
+        "}}\n"
+        "",
+        make_schema_request(),
+    )
+
+    assert extracted.tools_called
+    # Text ahead of the tool call is returned as regular content.
+    assert extracted.content == "Planning.\n"
+    assert extracted.tool_calls[0].function.name == "study_args"
+    args = json.loads(extracted.tool_calls[0].function.arguments)
+    assert args == {
+        "user_id": "12345",
+        "include_revoked": True,
+        "page": 2,
+        "filters": {"unit": "celsius"},
+    }
+    # Dict equality alone would already fail for 12345 != "12345"; the
+    # explicit check documents that the string type is the point of the test.
+    assert isinstance(args["user_id"], str)
+
+
+def test_ifm_xml_format_uses_schema_type_coercion():
+    """XML-format argument text is coerced to the types the schema declares."""
+    parser = make_parser("xml")
+
+    # All values arrive as plain text; the schema from make_schema_request()
+    # decides which stay strings and which are deserialized.
+    extracted = run_tool_extraction_nonstreaming(
+        parser,
+        "Planning.\n\n"
+        "study_args\n"
+        "user_id\n"
+        "12345\n"
+        "include_revoked\n"
+        "true\n"
+        "page\n"
+        "2\n"
+        "filters\n"
+        '{"unit":"celsius"}\n'
+        "\n"
+        "",
+        make_schema_request(),
+    )
+
+    assert extracted.tools_called
+    args = json.loads(extracted.tool_calls[0].function.arguments)
+    assert args == {
+        "user_id": "12345",
+        "include_revoked": True,
+        "page": 2,
+        "filters": {"unit": "celsius"},
+    }
+    # "12345" must remain a string because the schema types user_id as string.
+    assert isinstance(args["user_id"], str)
+
+
+def test_ifm_xml_typed_format_uses_arg_type_without_schema():
+    """``xml_typed`` payloads carry their own per-argument type entries."""
+    parser = make_parser_with_kwargs({"tool_call_format": "xml_typed"})
+
+    # Each argument is a (name, type, value) triple in the payload.
+    # NOTE(review): presumably make_request() declares no schema for
+    # study_args, so the inline types are the only type source - confirm.
+    extracted = run_tool_extraction_nonstreaming(
+        parser,
+        "\n"
+        "study_args\n"
+        "user_id\n"
+        "string\n"
+        "12345\n"
+        "include_revoked\n"
+        "boolean\n"
+        "true\n"
+        "page\n"
+        "integer\n"
+        "2\n"
+        "\n"
+        "",
+        make_request(),
+    )
+
+    assert extracted.tools_called
+    args = json.loads(extracted.tool_calls[0].function.arguments)
+    assert args == {
+        "user_id": "12345",
+        "include_revoked": True,
+        "page": 2,
+    }
+    # The inline "string" type must keep the numeric-looking value as text.
+    assert isinstance(args["user_id"], str)
+
+
+@pytest.mark.parametrize(
+    "tool_format",
+    ["default", "typed_xml", "XML", "xllm_typed", "xml ", ""],
+)
+def test_tool_format_requires_exact_supported_value(tool_format: str):
+    """Unknown, differently-cased, padded or empty names are rejected."""
+    rejected_kwargs = {"tool_call_format": tool_format}
+    with pytest.raises(ValueError, match="Use one of these exact values"):
+        make_parser_with_kwargs(rejected_kwargs)
+
+
+def test_tool_format_must_be_a_string():
+    """A non-string ``tool_call_format`` is rejected with a type message."""
+    non_string_kwargs = {"tool_call_format": 123}
+    with pytest.raises(ValueError, match="must be a string"):
+        make_parser_with_kwargs(non_string_kwargs)
+
+
+def test_k2_v3_parser_alias_uses_ifm_formats():
+    """The parser registered as ``k2_v3`` handles ``xml_typed`` payloads."""
+    # Look the parser up by its registry name rather than via the helpers,
+    # so the registration itself is what is being exercised.
+    parser = ToolParserManager.get_tool_parser("k2_v3")(
+        FakeTokenizer(),
+        chat_template_kwargs={"tool_call_format": "xml_typed"},
+    )
+
+    extracted = run_tool_extraction_nonstreaming(
+        parser,
+        "study_args\n"
+        "user_id"
+        "string"
+        "12345"
+        "",
+        make_request(),
+    )
+
+    assert extracted.tools_called
+    # The inline "string" type keeps the numeric-looking value as text.
+    assert json.loads(extracted.tool_calls[0].function.arguments) == {
+        "user_id": "12345"
+    }
+
+
def test_minimax_format_extracts_inline_invokes():
parser = make_parser("minimax")
@@ -260,13 +422,12 @@ def test_custom_formats_do_not_stream_yet():
assert delta is None
-def test_readme_default_example():
- parser = make_parser("default")
+def test_readme_json_example():
+ parser = make_parser("json")
extracted = run_tool_extraction_nonstreaming(
parser,
- '\n'
- '{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}}\n'
- "",
+ '{"name": "get_weather", '
+ '"arguments": {"location": "San Francisco, CA"}}',
make_request(),
)
assert extracted.tools_called
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 0680d90b51e7..1c7ae7363654 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -66,6 +66,10 @@
"kimi_k2_tool_parser",
"KimiK2ToolParser",
),
+ "k2_v3": (
+ "multi_format_tool_parser",
+ "K2V3ToolParser",
+ ),
"llama3_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
diff --git a/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py
index 12fc2069e888..30b476315c78 100644
--- a/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py
@@ -24,7 +24,21 @@
class MultiFormatToolParser(ToolParser):
- """Tool parser that dispatches on ``chat_template_kwargs['tool_format']``."""
+ """Tool parser that dispatches on chat template tool-call format kwargs."""
+
+ _SUPPORTED_TOOL_FORMATS = frozenset(
+ {
+ "qwen3",
+ "minimax",
+ "dsv32",
+ "glm",
+ "gptoss",
+ "python",
+ "json",
+ "xml",
+ "xml_typed",
+ }
+ )
_MINIMAX_START_TOKEN = ""
_MINIMAX_BLOCK_REGEX = re.compile(
@@ -51,6 +65,20 @@ class MultiFormatToolParser(ToolParser):
r"(.*?)",
re.DOTALL,
)
+
+ _IFM_TOOL_CALLS_START_TOKEN = ""
+ _IFM_TOOL_CALL_START_TOKEN = ""
+ _IFM_BLOCK_REGEX = re.compile(
+ r"(.*?)",
+ re.DOTALL,
+ )
+ _IFM_ARG_REGEX = re.compile(
+ r"(.*?)\s*"
+ r"(?:(.*?)\s*)?"
+ r"(.*?)",
+ re.DOTALL,
+ )
+
_GLM_BLOCK_REGEX = re.compile(
r"(.*?)",
re.DOTALL,
@@ -67,24 +95,38 @@ def __init__(
):
super().__init__(tokenizer)
- self.tool_format = str(
- (chat_template_kwargs or {}).get("tool_format") or "default"
- )
+ chat_template_kwargs = chat_template_kwargs or {}
+ raw_tool_format = "xml"
+ for key in ("tool_call_format", "tool_calling_format", "tool_format"):
+ if key in chat_template_kwargs and chat_template_kwargs[key] is not None:
+ raw_tool_format = chat_template_kwargs[key]
+ break
+ self.tool_format = self._validate_tool_format(raw_tool_format)
self._delegate: ToolParser | None = None
- if self.tool_format == "default":
- from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
- Hermes2ProToolParser,
- )
-
- self._delegate = Hermes2ProToolParser(tokenizer)
- elif self.tool_format == "qwen3":
+ if self.tool_format == "qwen3":
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import (
Qwen3XMLToolParser,
)
self._delegate = Qwen3XMLToolParser(tokenizer)
+    @classmethod
+    def _validate_tool_format(cls, tool_format: Any) -> str:
+        """Return *tool_format* unchanged when it names a supported format.
+
+        Matching is exact: the value has to be a ``str`` that is a member of
+        ``_SUPPORTED_TOOL_FORMATS``; no case folding or stripping is applied.
+
+        Raises:
+            ValueError: If the value is not a string, or is a string that is
+                not one of the supported format names.
+        """
+        if isinstance(tool_format, str):
+            if tool_format in cls._SUPPORTED_TOOL_FORMATS:
+                return tool_format
+            # Sorted so the error message is stable across runs.
+            supported_formats = ", ".join(sorted(cls._SUPPORTED_TOOL_FORMATS))
+            raise ValueError(
+                f"Unsupported tool_format/tool_call_format '{tool_format}'. "
+                "Use one of these exact values: "
+                f"{supported_formats}."
+            )
+        raise ValueError(
+            "tool_format/tool_call_format must be a string. "
+            f"Got {type(tool_format).__name__}."
+        )
+
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if self._delegate is not None:
return self._delegate.adjust_request(request)
@@ -99,11 +141,17 @@ def extract_tool_calls(
return self._delegate.extract_tool_calls(model_output, request)
try:
+ if self.tool_format == "json":
+ return self._extract_ifm_json_tool_calls(model_output, request)
+ if self.tool_format in {"xml", "xml_typed"}:
+ return self._extract_ifm_xml_tool_calls(model_output, request)
if self.tool_format == "minimax":
return self._extract_minimax_tool_calls(model_output)
if self.tool_format == "dsv32":
return self._extract_dsv32_tool_calls(model_output)
if self.tool_format == "glm":
+ if self._IFM_TOOL_CALL_START_TOKEN in model_output:
+ return self._extract_ifm_xml_tool_calls(model_output, request)
return self._extract_glm_tool_calls(model_output, request)
if self.tool_format == "gptoss":
return self._extract_gptoss_tool_calls(model_output)
@@ -169,6 +217,214 @@ def _tool_call(function_name: str, arguments: dict[str, Any]) -> ToolCall:
),
)
+    @staticmethod
+    def _schema_arg_type(
+        tool_name: str,
+        arg_name: str,
+        tools: list[ChatCompletionToolsParam] | None,
+    ) -> Any | None:
+        """Look up the declared JSON-schema ``type`` of one tool argument.
+
+        The first tool named *tool_name* that has a ``parameters`` schema
+        decides the result; later tools with the same name are not consulted.
+
+        Args:
+            tool_name: Name of the function being called.
+            arg_name: Name of the argument whose type is wanted.
+            tools: Tools declared on the request, or ``None``.
+
+        Returns:
+            Whatever the schema stores under ``type`` for the argument (a
+            string or a list of strings in practice), or ``None`` when no
+            tools are given, the tool is unknown, or the schema does not
+            describe the argument.
+        """
+        if not tools:
+            return None
+        for tool in tools:
+            parameters = tool.function.parameters
+            if tool.function.name != tool_name or parameters is None:
+                continue
+            # Schemas are user-supplied: "properties" can be absent, null or
+            # malformed. Treat all of those as "untyped" instead of raising.
+            properties = parameters.get("properties")
+            if not isinstance(properties, dict):
+                return None
+            arg_spec = properties.get(arg_name)
+            if not isinstance(arg_spec, dict):
+                return None
+            return arg_spec.get("type")
+        return None
+
+    @staticmethod
+    def _arg_type_is_string(arg_type: Any | None) -> bool:
+        """Tell whether a schema ``type`` value permits a string.
+
+        A list of types counts as string-typed when it contains ``"string"``;
+        any value that is neither a ``str`` nor a ``list`` yields ``False``.
+        """
+        if isinstance(arg_type, list):
+            return "string" in arg_type
+        return isinstance(arg_type, str) and arg_type == "string"
+
+    @staticmethod
+    def _json_stringify(value: Any) -> str:
+        """Return *value* as text: strings as-is, anything else JSON-encoded.
+
+        Non-ASCII characters are kept verbatim rather than escaped.
+        """
+        if not isinstance(value, str):
+            return json.dumps(value, ensure_ascii=False)
+        return value
+
+    @classmethod
+    def _coerce_argument_value(
+        cls,
+        value: Any,
+        tool_name: str,
+        arg_name: str,
+        tools: list[ChatCompletionToolsParam] | None,
+        *,
+        arg_type: str | None = None,
+        from_text: bool = False,
+    ) -> Any:
+        """Coerce one parsed argument towards its declared type.
+
+        The type found in the request schema takes precedence; *arg_type* is
+        only used when the schema yields nothing.
+
+        Args:
+            value: The argument value as parsed from the model output.
+            tool_name: Name of the function being called.
+            arg_name: Name of the argument.
+            tools: Tools declared on the request, or ``None``.
+            arg_type: Fallback type used when the schema has none.
+            from_text: Whether *value* came from a text payload, in which
+                case strings are deserialized even without any known type.
+
+        Returns:
+            A string when the target type permits a string, a deserialized
+            value for other string inputs that qualify, else *value* as-is.
+        """
+        schema_type = cls._schema_arg_type(tool_name, arg_name, tools)
+        target_type = schema_type if schema_type else arg_type
+
+        # String-typed arguments are always delivered as text.
+        if cls._arg_type_is_string(target_type):
+            return cls._json_stringify(value)
+
+        should_deserialize = isinstance(value, str) and (
+            from_text or target_type is not None
+        )
+        return cls._deserialize_glm_value(value) if should_deserialize else value
+
+    @classmethod
+    def _coerce_arguments(
+        cls,
+        tool_name: str,
+        arguments: dict[str, Any],
+        tools: list[ChatCompletionToolsParam] | None,
+    ) -> dict[str, Any]:
+        """Return a new dict with every argument coerced via the schema.
+
+        Keys and their order are preserved; *arguments* is not modified.
+        """
+        coerced: dict[str, Any] = {}
+        for arg_name, arg_value in arguments.items():
+            coerced[arg_name] = cls._coerce_argument_value(
+                arg_value, tool_name, arg_name, tools
+            )
+        return coerced
+
+    @staticmethod
+    def _json_arguments_to_dict(arguments: Any) -> dict[str, Any]:
+        """Normalize a tool call's ``arguments`` field to a dict.
+
+        ``None`` and blank strings become an empty dict; other strings are
+        parsed as JSON; dicts are returned as-is.
+
+        Raises:
+            ValueError: If the (possibly parsed) value is not a dict. JSON
+                decoding errors from malformed strings propagate unchanged.
+        """
+        parsed = arguments
+        if parsed is None:
+            return {}
+        if isinstance(parsed, str):
+            parsed = {} if not parsed.strip() else json.loads(parsed)
+        if isinstance(parsed, dict):
+            return parsed
+        raise ValueError("Tool call arguments must be a JSON object.")
+
+    @classmethod
+    def _ifm_prefix_index(cls, model_output: str, first_match_index: int) -> int:
+        """Return the index at which tool-call markup starts in the output.
+
+        The position of the first ``_IFM_TOOL_CALLS_START_TOKEN`` wins when
+        the token is present; otherwise *first_match_index* is returned.
+        """
+        group_index = model_output.find(cls._IFM_TOOL_CALLS_START_TOKEN)
+        return first_match_index if group_index == -1 else group_index
+
+    def _extract_ifm_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        """Extract IFM tool calls, choosing JSON or XML from the first block.
+
+        When no tool-call block is found the output is returned as plain
+        content. Otherwise a first block starting with ``{`` or ``[`` routes
+        the whole output to the JSON extractor, anything else to the XML one.
+        """
+        first_match = self._IFM_BLOCK_REGEX.search(model_output)
+        if first_match is None:
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content=model_output,
+            )
+
+        looks_like_json = first_match.group(1).strip().startswith(("{", "["))
+        extractor = (
+            self._extract_ifm_json_tool_calls
+            if looks_like_json
+            else self._extract_ifm_xml_tool_calls
+        )
+        return extractor(model_output, request)
+
+    def _extract_ifm_json_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        """Extract tool calls from blocks whose body is JSON.
+
+        Each block holds one call object or a list of them. A call may be
+        given directly or nested under a ``"function"`` key. Arguments are
+        normalized to a dict and coerced against the request's tool schemas.
+
+        Returns:
+            The parsed calls plus the content ahead of the tool-call markup,
+            or the untouched output with ``tools_called=False`` when nothing
+            was extracted.
+
+        Raises:
+            ValueError: If a call has no function name or its arguments are
+                not a JSON object. JSON decoding errors propagate unchanged.
+        """
+        blocks = list(self._IFM_BLOCK_REGEX.finditer(model_output))
+
+        parsed_calls: list[ToolCall] = []
+        for block in blocks:
+            payload = json.loads(block.group(1).strip())
+            entries = payload if isinstance(payload, list) else [payload]
+            for entry in entries:
+                # Accept both {"name": ...} and {"function": {"name": ...}}.
+                function = entry.get("function", entry)
+                function_name = function.get("name")
+                if not function_name:
+                    raise ValueError("Tool call JSON is missing a function name.")
+                raw_arguments = self._json_arguments_to_dict(
+                    function.get("arguments", {})
+                )
+                coerced_arguments = self._coerce_arguments(
+                    function_name,
+                    raw_arguments,
+                    request.tools,
+                )
+                parsed_calls.append(
+                    self._tool_call(function_name, coerced_arguments)
+                )
+
+        # Covers both "no blocks at all" and "blocks that held no calls".
+        if not parsed_calls:
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content=model_output,
+            )
+
+        prefix_end = self._ifm_prefix_index(model_output, blocks[0].start())
+        return ExtractedToolCallInformation(
+            tools_called=True,
+            tool_calls=parsed_calls,
+            content=self._prefix_content(model_output, prefix_end),
+        )
+
+    def _extract_ifm_xml_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        """Extract tool calls from blocks that list arguments as markup.
+
+        Each block matched by ``_IFM_BLOCK_REGEX`` is split into a function
+        name (the text ahead of the first argument) and its arguments, which
+        ``_IFM_ARG_REGEX`` yields as ``(key, arg_type, value)`` triples. Every
+        value is coerced via ``_coerce_argument_value`` with ``from_text=True``.
+
+        Returns:
+            The parsed calls plus the content ahead of the tool-call markup,
+            or the untouched output with ``tools_called=False`` when no block
+            matched or no block produced a function name.
+        """
+        matches = list(self._IFM_BLOCK_REGEX.finditer(model_output))
+        if not matches:
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content=model_output,
+            )
+
+        tool_calls: list[ToolCall] = []
+        for match in matches:
+            block = match.group(1)
+            # NOTE(review): the search literal is empty as shown here, so
+            # find() returns 0 and the -1 branch below is never taken -
+            # confirm the intended argument-start marker.
+            first_arg_idx = block.find("")
+            if first_arg_idx == -1:
+                # No arguments: the whole block is the function name.
+                function_name = block.strip()
+                arguments: dict[str, Any] = {}
+            else:
+                function_name = block[:first_arg_idx].strip()
+                arg_block = block[first_arg_idx:]
+                arguments = {}
+                for key, arg_type, value in self._IFM_ARG_REGEX.findall(arg_block):
+                    arg_key = key.strip()
+                    arg_value = self._coerce_argument_value(
+                        value.strip(),
+                        function_name,
+                        arg_key,
+                        request.tools,
+                        # The type group is optional in the regex; findall
+                        # yields "" when it is absent, mapped to None here.
+                        arg_type=arg_type.strip() or None,
+                        from_text=True,
+                    )
+                    arguments[arg_key] = arg_value
+
+            # Blocks without a function name are dropped silently.
+            if function_name:
+                tool_calls.append(self._tool_call(function_name, arguments))
+
+        if not tool_calls:
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content=model_output,
+            )
+
+        return ExtractedToolCallInformation(
+            tools_called=True,
+            tool_calls=tool_calls,
+            content=self._prefix_content(
+                model_output,
+                self._ifm_prefix_index(model_output, matches[0].start()),
+            ),
+        )
+
def _extract_minimax_tool_calls(
self,
model_output: str,
@@ -290,25 +546,6 @@ def _deserialize_glm_value(value: str) -> Any:
return value
- @staticmethod
- def _glm_value_is_string(
- tool_name: str,
- arg_name: str,
- tools: list[ChatCompletionToolsParam] | None,
- ) -> bool:
- if tools is None:
- return False
- for tool in tools:
- if tool.function.name != tool_name or tool.function.parameters is None:
- continue
- arg_type = (
- tool.function.parameters.get("properties", {})
- .get(arg_name, {})
- .get("type")
- )
- return arg_type == "string"
- return False
-
def _extract_glm_tool_calls(
self,
model_output: str,
@@ -335,11 +572,13 @@ def _extract_glm_tool_calls(
arguments = {}
for key, value in self._GLM_ARG_REGEX.findall(arg_block):
arg_key = key.strip()
- arg_value = value.strip()
- if not self._glm_value_is_string(
- function_name, arg_key, request.tools
- ):
- arg_value = self._deserialize_glm_value(arg_value)
+ arg_value = self._coerce_argument_value(
+ value.strip(),
+ function_name,
+ arg_key,
+ request.tools,
+ from_text=True,
+ )
arguments[arg_key] = arg_value
if function_name:
@@ -448,3 +687,7 @@ def _extract_python_tool_calls(
model_output.find(""),
),
)
+
+
+class K2V3ToolParser(MultiFormatToolParser):
+    """K2-V3 alias for the IFM-aware multi-format parser.
+
+    The subclass adds no behavior of its own; it only gives
+    ``MultiFormatToolParser`` a second class name to be referenced by.
+    """
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index c37ecf9be66d..dd3149d63dad 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -65,6 +65,7 @@
_TEXT_GENERATION_MODELS = {
# [Decoder-only]
+ "XllmForCausalLM": ("xllm", "XllmForCausalLM"),
"AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
"ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
"AquilaModel": ("llama", "LlamaForCausalLM"),
@@ -173,7 +174,6 @@
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
- "XllmForCausalLM": ("xllm", "XllmForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
@@ -1181,4 +1181,4 @@ def _run() -> None:
if __name__ == "__main__":
- _run()
+ _run()
\ No newline at end of file
diff --git a/vllm/model_executor/models/xllm.py b/vllm/model_executor/models/xllm.py
index 7eb59a922af3..d70265322de5 100644
--- a/vllm/model_executor/models/xllm.py
+++ b/vllm/model_executor/models/xllm.py
@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Xllm model compatible with HuggingFace weights."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any
@@ -10,12 +11,25 @@
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (
+ get_ep_group,
+ get_pp_group,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_gather,
+ get_tensor_model_parallel_rank
+)
+from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import rms_norm, fused_add_rms_norm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
+ ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -29,62 +43,126 @@
default_weight_loader,
maybe_remap_kv_scale_name,
)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (
+from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
+from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
+ extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
+logger = init_logger(__name__)
-class GroupRMSNorm(nn.Module):
- """RMSNorm with per-group variance computation.
- Computes variance over groups of hidden_size/n_groups dimensions
- instead of the full hidden dimension.
- """
+def permute_to_xllm(x):
+    """Interleave the two halves of the last dimension of *x*.
+
+    The last dimension is viewed as ``(2, d / 2)``, transposed and flattened,
+    so ``[a0, a1, a2, b0, b1, b2]`` becomes ``[a0, b0, a1, b1, a2, b2]``.
+    Leading dimensions are left untouched.
+    """
+    leading_shape = x.shape[:-1]
+    halves = x.reshape(*leading_shape, 2, -1)
+    return halves.transpose(-1, -2).reshape(*leading_shape, -1)
- def __init__(
- self,
- hidden_size: int,
- n_groups: int = 1,
- eps: float = 1e-6,
- ) -> None:
- super().__init__()
- self.hidden_size = hidden_size
+
+def permute_to_hf(x):
+    """Split interleaved pairs of the last dimension of *x* into two halves.
+
+    The last dimension is viewed as ``(d / 2, 2)``, transposed and flattened,
+    so ``[a0, b0, a1, b1, a2, b2]`` becomes ``[a0, a1, a2, b0, b1, b2]``.
+    Leading dimensions are left untouched.
+    """
+    leading_shape = x.shape[:-1]
+    pairs = x.reshape(*leading_shape, -1, 2)
+    return pairs.transpose(-1, -2).reshape(*leading_shape, -1)
+
+
+@CustomOp.register("grouped_rms_norm")
+class XllmRMSNorm(RMSNorm):
+    """RMSNorm applied independently to ``n_groups`` slices of the last dim.
+
+    The input's last dimension is reshaped to ``(n_groups, -1)``, each group
+    is normalized on its own, the result is flattened back and multiplied by
+    ``tp_weight`` - this rank's slice of the full ``weight`` parameter.
+    """
+
+    def __init__(self,
+                 hidden_size: int,
+                 n_groups: int,
+                 tp_size: int,
+                 num_replicas: int = 1,
+                 eps=1e-6):
+        """Create the norm and select this rank's slice of the weight.
+
+        Args:
+            hidden_size: Size of the last dimension seen by this rank.
+            n_groups: Number of groups; must divide *hidden_size*.
+            tp_size: Tensor-parallel world size.
+            num_replicas: How many ranks share the same weight slice.
+            eps: Epsilon added to the variance.
+        """
+        super().__init__(hidden_size=hidden_size, eps=eps)
         self.n_groups = n_groups
-        self.variance_epsilon = eps
+        self.hidden_size = hidden_size
         assert hidden_size % n_groups == 0
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        # The parameter holds the slices of all distinct ranks back to back.
+        self.weight = nn.Parameter(torch.ones(hidden_size * tp_size // num_replicas))
+        self.variance_epsilon = eps
 
-    def forward(
-        self,
-        x: torch.Tensor,
-        residual: torch.Tensor | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        orig_dtype = x.dtype
-        x = x.to(torch.float32)
+        if tp_size > 1:
+            # Row ``rank // num_replicas`` of the stacked weight is the
+            # slice that belongs to this rank.
+            self.tp_weight = self.weight.reshape(
+                tp_size // num_replicas, self.hidden_size
+            )[get_tensor_model_parallel_rank() // num_replicas]
+        else:
+            self.tp_weight = self.weight
 
-        if residual is not None:
-            x = x + residual.to(torch.float32)
-            residual = x.to(orig_dtype)
+        assert self._forward_method in [self.forward_native, self.forward_cuda]
+        if get_tensor_model_parallel_rank() == 0:
+            # Logged rather than printed so library users control verbosity.
+            logger.debug("XllmRMSNorm forward method: %s", self._forward_method)
 
-        # Group RMSNorm: compute variance per group
-        x_grouped = x.reshape(*x.shape[:-1], self.n_groups, -1)
-        variance = x_grouped.pow(2).mean(-1, keepdim=True)
-        x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
-        x = x_grouped.reshape(*x.shape[:-1], self.hidden_size)
+    def forward_native(self, x, residual=None):
+        """Grouped RMSNorm via ``RMSNorm.forward_static``.
+
+        Returns the normalized tensor, or ``(normalized, residual)`` when a
+        residual is given; both have the shape of the ungrouped input.
+        """
+        x = x.reshape(*x.shape[:-1], self.n_groups, -1)
+        if residual is not None:
+            residual = residual.reshape(x.shape)
+
+        # NOTE(review): None is presumably the weight argument, so only the
+        # normalization happens here - confirm against forward_static.
+        x = self.forward_static(
+            x,
+            self.variance_epsilon,
+            self.hidden_size // self.n_groups,
+            x.dtype,
+            None,
+            residual,
+            self.variance_size_override,
+        )
+        if residual is not None:
+            x, residual = x
 
-        x = (self.weight * x).to(orig_dtype)
+        # Merge the group dimension back before applying the learned scale.
+        x = x.reshape(*x.shape[:-2], -1)
+        x = self.tp_weight.data * x
         if residual is None:
             return x
-        return x, residual
+        else:
+            residual = residual.reshape(x.shape)
+            return x, residual
+
+    def forward_cuda(self, hidden_states, residual=None):
+        """Grouped RMSNorm via the ``rms_norm`` / ``fused_add_rms_norm`` ops.
+
+        Returns the normalized tensor, or ``(normalized, residual)`` when a
+        residual is given; both have the shape of the ungrouped input.
+        """
+        hidden_states = hidden_states.reshape(*hidden_states.shape[:-1], self.n_groups, -1)
+        if residual is not None:
+            residual = residual.reshape(hidden_states.shape)
+
+        # A weight of ones is passed to the ops; the learned scale is applied
+        # below, once the group dimension has been merged back.
+        if residual is not None:
+            hidden_states, residual = fused_add_rms_norm(
+                x=hidden_states,
+                residual=residual,
+                weight=torch.ones(
+                    hidden_states.shape[-1],
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype),
+                variance_epsilon=self.variance_epsilon)
+        else:
+            hidden_states = rms_norm(
+                x=hidden_states,
+                weight=torch.ones(
+                    hidden_states.shape[-1],
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype),
+                variance_epsilon=self.variance_epsilon)
+
+        hidden_states = hidden_states.reshape(*hidden_states.shape[:-2], -1)
+        hidden_states = self.tp_weight * hidden_states
+
+        if residual is None:
+            return hidden_states
+        else:
+            residual = residual.reshape(hidden_states.shape)
+            return hidden_states, residual
class XllmMLP(nn.Module):
@@ -115,8 +193,7 @@ def __init__(
)
if hidden_act != "silu":
raise ValueError(
- f"Unsupported activation: {hidden_act}. "
- "Only silu is supported for now."
+ f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
@@ -127,38 +204,165 @@ def forward(self, x):
return x
+class XllmSparseMoeBlock(nn.Module):
+    """Sparse mixture-of-experts block: router gate, fused experts and an
+    optional shared-expert MLP whose output is added to the routed output.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ):
+        """Build the gate, the fused experts and the optional shared experts.
+
+        Args:
+            vllm_config: Supplies the HF text config, the parallel config
+                and the quantization config used to size the block.
+            prefix: Module-name prefix for the sub-layers.
+
+        Raises:
+            ValueError: If the tensor-parallel size exceeds the number of
+                experts.
+        """
+        super().__init__()
+
+        config = vllm_config.model_config.hf_text_config
+        parallel_config = vllm_config.parallel_config
+        quant_config = vllm_config.quant_config
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = get_ep_group().rank_in_group
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_experts
+
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}."
+            )
+
+        # Load balancing settings.
+        # NOTE(review): this rebinds the ``vllm_config`` parameter to the
+        # current global config; ``parallel_config`` above still comes from
+        # the argument - confirm both always refer to the same config.
+        vllm_config = get_current_vllm_config()
+        eplb_config = vllm_config.parallel_config.eplb_config
+        self.enable_eplb = parallel_config.enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
+        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+        self.physical_expert_end = (
+            self.physical_expert_start + self.n_local_physical_experts
+        )
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=config.moe_gate_bias,
+            skip_bias_add=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.experts = FusedMoE(
+            num_experts=self.n_routed_experts,
+            n_shared_experts=config.num_shared_experts,
+            top_k=config.num_experts_per_tok,
+            use_grouped_topk=True,
+            num_expert_group=1,
+            topk_group=1,
+            scoring_func=config.router_score_func,
+            # The gate skips its bias add; the bias is handed to the experts
+            # as the score-correction bias instead.
+            e_score_correction_bias=self.gate.bias,
+            routed_scaling_factor=config.router_scaling_factor,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=True,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel,
+            routing_method_type=None  # RoutingMethodType.Renormalize,
+        )
+
+        self.num_shared_experts = config.num_shared_experts
+        if config.num_shared_experts > 0:
+            self.shared_experts = XllmMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size * config.num_shared_experts,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=f"{prefix}.shared_experts")
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Route tokens through the experts and add the shared-expert output.
+
+        Args:
+            hidden_states: A 1D ``(hidden,)`` or 2D ``(tokens, hidden)``
+                tensor.
+
+        Returns:
+            A tensor with the same number of dimensions as the input.
+        """
+        assert hidden_states.dim() <= 2, (
+            "XllmSparseMoeBlock only supports 1D or 2D inputs"
+        )
+        is_input_1d = hidden_states.dim() == 1
+        # Take the sizes after flattening to 2D: unpacking ``.shape`` into
+        # two names directly would fail for the 1D inputs accepted above.
+        hidden_dim = hidden_states.shape[-1]
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        num_tokens = hidden_states.shape[0]
+
+        if self.is_sequence_parallel:
+            hidden_states = sequence_parallel_chunk(hidden_states)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        if self.num_shared_experts > 0:
+            final_hidden_states = (
+                final_hidden_states + self.shared_experts(hidden_states))
+
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0
+            )
+            # Trim back to the original token count after the gather.
+            final_hidden_states = final_hidden_states[:num_tokens]
+
+        # return to 1d if input is 1d
+        return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
+
+
class XllmAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
+ query_key_norm: bool,
rope_parameters: dict[str, Any],
max_position_embeddings: int = 8192,
head_dim: int | None = None,
+ rope_head_dim: int | None = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
+ dual_chunk_attention_config: dict[str, Any] | None = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
+
tp_size = get_tensor_model_parallel_world_size()
+ tp_rank = get_tensor_model_parallel_rank()
+ self.tp_size = tp_size
+ self.tp_rank = tp_rank
+
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
+ self.rope_head_dim = rope_head_dim or self.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
+ self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -169,20 +373,22 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
+ self.num_kv_head_replicas = self.qkv_proj.num_kv_head_replicas
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
- bias=False,
+ bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
- self.head_dim,
- rotary_dim=self.head_dim,
+ self.rope_head_dim,
+ rotary_dim=self.rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
+ dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
self.num_heads,
@@ -192,8 +398,30 @@ def __init__(
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
+ **{
+ "layer_idx": extract_layer_index(prefix),
+ "dual_chunk_attention_config": dual_chunk_attention_config,
+ }
+ if dual_chunk_attention_config
+ else {},
)
+ self.query_key_norm = query_key_norm
+ if self.query_key_norm:
+ # self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+ # self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+ self.q_norm = XllmRMSNorm(
+ hidden_size=self.num_heads * self.head_dim,
+ n_groups=self.num_heads,
+ tp_size=tp_size,
+ eps=rms_norm_eps)
+ self.k_norm = XllmRMSNorm(
+ hidden_size=self.num_kv_heads * self.head_dim,
+ n_groups=self.num_kv_heads,
+ tp_size=tp_size,
+ num_replicas=self.num_kv_head_replicas,
+ eps=rms_norm_eps)
+
def forward(
self,
positions: torch.Tensor,
@@ -201,7 +429,52 @@ def forward(
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
- q, k = self.rotary_emb(positions, q, k)
+
+ # Add qk-norm
+ if self.query_key_norm:
+ # q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+ # q_by_head = self.q_norm(q_by_head)
+ # q = q_by_head.view(q.shape)
+ #
+ # k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+ # k_by_head = self.k_norm(k_by_head)
+ # k = k_by_head.view(k.shape)
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ if self.rope_head_dim == self.head_dim:
+ q, k = self.rotary_emb(positions, q, k)
+ else:
+ tp_size, tp_rank = self.tp_size, self.tp_rank
+ q_ = tensor_model_parallel_all_gather(q.contiguous())
+ k_ = tensor_model_parallel_all_gather(k.contiguous())
+ q_ = q_.reshape(*q_.shape[:-1], self.total_num_heads, self.head_dim)
+ k_ = k_.reshape(
+ *k_.shape[:-1], self.total_num_kv_heads * self.num_kv_head_replicas, self.head_dim
+ )[..., ::self.num_kv_head_replicas, :]
+
+ q_rope, q_nope = torch.split(
+ permute_to_xllm(q_),
+ split_size_or_sections=[self.rope_head_dim, self.head_dim - self.rope_head_dim],
+ dim=-1)
+ k_rope, k_nope = torch.split(
+ permute_to_xllm(k_),
+ split_size_or_sections=[self.rope_head_dim, self.head_dim - self.rope_head_dim],
+ dim=-1)
+
+ q_rope, k_rope = self.rotary_emb(
+ positions, permute_to_hf(q_rope), permute_to_hf(k_rope))
+
+ q_ = permute_to_hf(torch.cat(
+ [permute_to_xllm(q_rope), q_nope], dim=-1)).reshape(*q_.shape[:-2], -1)
+ k_ = permute_to_hf(torch.cat(
+ [permute_to_xllm(k_rope), k_nope], dim=-1)).reshape(*k_.shape[:-2], -1)
+
+ q = q_.split(q_.shape[-1] // tp_size, dim=-1)[tp_rank]
+ k = k_.split(
+ k_.shape[-1] // (tp_size // self.num_kv_head_replicas), dim=-1
+ )[tp_rank // self.num_kv_head_replicas]
+
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -216,38 +489,61 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
- max_position_embeddings = getattr(
- config, "max_position_embeddings", 8192
+ max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+ dual_chunk_attention_config = getattr(
+ config, "dual_chunk_attention_config", None
)
self.self_attn = XllmAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
+ query_key_norm=config.query_key_norm,
rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
head_dim=getattr(config, "head_dim", None),
+ rope_head_dim=getattr(config, "rope_head_dim", None),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
+ dual_chunk_attention_config=dual_chunk_attention_config,
)
- self.mlp = XllmMLP(
- hidden_size=config.hidden_size,
- intermediate_size=config.intermediate_size,
- hidden_act=config.hidden_act,
- quant_config=quant_config,
- prefix=f"{prefix}.mlp",
- )
-
- n_groups = getattr(config, "layernorm_num_groups", 1)
- self.input_layernorm = GroupRMSNorm(
- config.hidden_size, n_groups=n_groups, eps=config.rms_norm_eps
- )
- self.post_attention_layernorm = GroupRMSNorm(
- config.hidden_size, n_groups=n_groups, eps=config.rms_norm_eps
+ # `mlp_only_layers` in the config.
+ layer_idx = extract_layer_index(prefix)
+ mlp_only_layers = (
+ [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
)
+ if (layer_idx not in mlp_only_layers) and (
+ config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+ ):
+ self.mlp = XllmSparseMoeBlock(
+ vllm_config=vllm_config, prefix=f"{prefix}.mlp"
+ )
+ else:
+ self.mlp = XllmMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ # self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ # self.post_attention_layernorm = RMSNorm(
+ # config.hidden_size, eps=config.rms_norm_eps
+ # )
+ assert config.hidden_size % config.layernorm_num_groups == 0
+ self.input_layernorm = XllmRMSNorm(
+ hidden_size=config.hidden_size,
+ n_groups=config.layernorm_num_groups,
+ tp_size=1,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = XllmRMSNorm(
+ hidden_size=config.hidden_size,
+ n_groups=config.layernorm_num_groups,
+ tp_size=1,
+ eps=config.rms_norm_eps)
def forward(
self,
@@ -255,21 +551,19 @@ def forward(
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
+ # Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
- hidden_states, residual = self.input_layernorm(
- hidden_states, residual
- )
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
- hidden_states, residual = self.post_attention_layernorm(
- hidden_states, residual
- )
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@@ -281,41 +575,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_text_config
quant_config = vllm_config.quant_config
+ parallel_config = vllm_config.parallel_config
+ eplb_config = parallel_config.eplb_config
+ self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
- if get_pp_group().is_first_rank or (
- config.tie_word_embeddings and get_pp_group().is_last_rank
- ):
- self.embed_tokens = VocabParallelEmbedding(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=f"{prefix}.embed_tokens",
- )
- else:
- self.embed_tokens = PPMissingLayer()
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens",
+ )
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
- lambda prefix: XllmDecoderLayer(
- vllm_config=vllm_config, prefix=prefix
- ),
+ lambda prefix: XllmDecoderLayer(vllm_config=vllm_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
+ # self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ assert config.hidden_size % config.layernorm_num_groups == 0
+ self.norm = XllmRMSNorm(
+ hidden_size=config.hidden_size,
+ n_groups=config.layernorm_num_groups,
+ tp_size=1,
+ eps=config.rms_norm_eps)
- n_groups = getattr(config, "layernorm_num_groups", 1)
- if get_pp_group().is_last_rank:
- self.norm = GroupRMSNorm(
- config.hidden_size, n_groups=n_groups, eps=config.rms_norm_eps
- )
- else:
- self.norm = PPMissingLayer()
- self.make_empty_intermediate_tensors = (
- make_empty_intermediate_tensors_factory(
- ["hidden_states", "residual"], config.hidden_size
- )
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size
)
+ # Track layers for auxiliary hidden state outputs (EAGLE3)
+ self.aux_hidden_state_layers: tuple[int, ...] = ()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ """Embed ``input_ids`` by passing them through ``self.embed_tokens``."""
return self.embed_tokens(input_ids)
@@ -326,7 +616,7 @@ def forward(
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
- ) -> torch.Tensor | IntermediateTensors:
+ ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
@@ -338,9 +628,17 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
- for layer in islice(
- self.layers, self.start_layer, self.end_layer
+ aux_hidden_states = []
+ for layer_idx, layer in enumerate(
+ islice(self.layers, self.start_layer, self.end_layer),
+ start=self.start_layer,
):
+ # Collect auxiliary hidden states if specified
+ if layer_idx in self.aux_hidden_state_layers:
+ aux_hidden_state = (
+ hidden_states + residual if residual is not None else hidden_states
+ )
+ aux_hidden_states.append(aux_hidden_state)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
@@ -348,10 +646,26 @@ def forward(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
+
+ # Return auxiliary hidden states if collected
+ if len(aux_hidden_states) > 0:
+ return hidden_states, aux_hidden_states
return hidden_states
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        """Build the checkpoint-name mapping for the expert parameters.
+
+        Delegates to ``FusedMoE.make_expert_params_mapping`` with the
+        ``gate_proj`` / ``down_proj`` / ``up_proj`` checkpoint names, the
+        expert count read from ``self.config.num_experts`` and this model's
+        ``num_redundant_experts``.
+
+        Returns:
+            Entries of the form (param_name, weight_name, expert_id, shard_id).
+        """
+        # Params for weights, fp8 weight scales, fp8 activation scales.
+        logical_expert_count = self.config.num_experts
+        redundant_expert_count = self.num_redundant_experts
+        expert_mapping = FusedMoE.make_expert_params_mapping(
+            num_experts=logical_expert_count,
+            num_redundant_experts=redundant_expert_count,
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_up_proj_name="up_proj",
+            ckpt_down_proj_name="down_proj",
+        )
+        return expert_mapping
+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
@@ -359,47 +673,148 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("gate_up_proj", "up_proj", 1),
]
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ ignore_suffixes = (
+ # ".bias",
+ # "_bias",
+ # ".k_scale",
+ # "_k_scale",
+ # ".v_scale",
+ # "_v_scale",
+ # ".weight_scale",
+ # "_weight_scale",
+ # ".input_scale",
+ # "_input_scale",
+ )
+
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
+ expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if "mlp.experts" in name:
+ continue
name = name.replace(weight_name, param_name)
+
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ if name.endswith(ignore_suffixes) and name not in params_dict:
+ continue
+
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ if name.endswith("scale"):
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
if name not in params_dict:
- break
+ continue
+
param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
break
else:
- if is_pp_missing_parameter(name, self):
- continue
- if name.endswith("kv_scale"):
- remapped = maybe_remap_kv_scale_name(name, params_dict)
- if remapped is None:
+ is_expert_weight = False
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
continue
- name = remapped
- if name not in params_dict:
- continue
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
+
+ # Anyway, this is an expert weight and should not be
+ # attempted to load as other weights later
+ is_expert_weight = True
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name_mapped, self):
+ continue
+
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ if (
+ name_mapped.endswith(ignore_suffixes)
+ and name_mapped not in params_dict
+ ):
+ continue
+
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or not
+ # here since otherwise we may skip experts with other
+ # available replicas.
+ weight_loader = typing.cast(
+ Callable[..., bool], param.weight_loader
+ )
+ success = weight_loader(
+ param,
+ loaded_weight,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True,
+ )
+ if success:
+ name = name_mapped
+ break
+ else:
+ if is_expert_weight:
+ # We've checked that this is an expert weight
+ # However it's not mapped locally to this rank
+ # So we simply skip it
+ continue
+
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ if name.endswith(ignore_suffixes) and name not in params_dict:
+ continue
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ # Remapping the name of FP8 kv-scale.
+ if name.endswith("kv_scale"):
+ remapped_kv_scale_name = name.replace(
+ ".kv_scale", ".attn.kv_scale"
+ )
+ if remapped_kv_scale_name not in params_dict:
+ logger.warning_once(
+ "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
+ name,
+ remapped_kv_scale_name,
+ )
+ continue
+ else:
+ name = remapped_kv_scale_name
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
-class XllmForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+class XllmForCausalLM(
+ nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
+):
packed_modules_mapping = {
- "qkv_proj": ["q_proj", "k_proj", "v_proj"],
- "gate_up_proj": ["gate_proj", "up_proj"],
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ]
}
fall_back_to_pt_during_load = False
@@ -410,26 +825,77 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
+ # Only perform the following mapping when XllmMLP exists
+ if getattr(config, "mlp_only_layers", []):
+ self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
self.model = XllmModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
- if get_pp_group().is_last_rank:
- if self.config.tie_word_embeddings:
- self.lm_head = self.model.embed_tokens
- else:
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=maybe_prefix(prefix, "lm_head"),
- )
- else:
- self.lm_head = PPMissingLayer()
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "lm_head"),
+ )
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
+ # Set MoE hyperparameters
+ self.expert_weights = []
+
+ self.moe_layers = []
+ example_layer = None
+ for layer in self.model.layers:
+ if isinstance(layer, PPMissingLayer):
+ continue
+
+ assert isinstance(layer, XllmDecoderLayer)
+ if isinstance(layer.mlp, XllmSparseMoeBlock):
+ example_layer = layer.mlp
+ self.moe_layers.append(layer.mlp.experts)
+
+ # if example_layer is None:
+ # raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
+
+ self.num_moe_layers = len(self.moe_layers)
+ self.num_expert_groups = 1
+ self.num_shared_experts = 0
+
+ if example_layer is not None:
+ self.num_logical_experts = example_layer.n_logical_experts
+ self.num_physical_experts = example_layer.n_physical_experts
+ self.num_local_physical_experts = example_layer.n_local_physical_experts
+ self.num_routed_experts = example_layer.n_routed_experts
+ self.num_redundant_experts = example_layer.n_redundant_experts
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        """Refresh the expert-count bookkeeping on the model and its MoE blocks.
+
+        Stores the new physical / local-physical expert counts on ``self``,
+        recomputes ``num_redundant_experts`` as
+        ``num_physical_experts - self.num_logical_experts``, then copies the
+        three values onto every ``XllmSparseMoeBlock`` found in
+        ``self.model.layers`` and calls ``experts.update_expert_map()`` on it.
+
+        Args:
+            num_physical_experts: New total number of physical experts.
+            num_local_physical_experts: Number of physical experts on this
+                rank; must equal the value already stored on ``self``.
+
+        Raises:
+            AssertionError: If ``num_local_physical_experts`` differs from
+                ``self.num_local_physical_experts``.
+        """
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for layer in self.model.layers:
+            # ``model.layers`` may hold PPMissingLayer placeholders; the
+            # constructor skips them before touching ``layer.mlp``, so this
+            # loop has to do the same.
+            if isinstance(layer, PPMissingLayer):
+                continue
+            moe = layer.mlp
+            if not isinstance(moe, XllmSparseMoeBlock):
+                continue
+            moe.n_local_physical_experts = num_local_physical_experts
+            moe.n_physical_experts = num_physical_experts
+            moe.n_redundant_experts = self.num_redundant_experts
+            moe.experts.update_expert_map()
+
+ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+ """Store ``layers`` unchanged as ``self.model.aux_hidden_state_layers``."""
+ self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        """Return three layer indices for EAGLE3 auxiliary hidden states.
+
+        With ``n = len(self.model.layers)`` the result is
+        ``(2, n // 2, n - 3)``.
+        """
+        layer_count = len(self.model.layers)
+        midpoint = layer_count // 2
+        third_from_last = layer_count - 3
+        return (2, midpoint, third_from_last)
+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ """Delegate token embedding to ``self.model.embed_input_ids``."""
return self.model.embed_input_ids(input_ids)
@@ -440,10 +906,10 @@ def forward(
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
- model_output = self.model(
+ hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
- return model_output
+ return hidden_states
def compute_logits(
self,
@@ -453,9 +919,8 @@ def compute_logits(
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ """Load ``weights`` through an ``AutoWeightsLoader`` built on this module.
+
+ Returns whatever ``loader.load_weights(weights)`` returns.
+ """
+ # NOTE(review): the code being replaced passed skip_prefixes=["lm_head."]
+ # when config.tie_word_embeddings was set; confirm that tied-embedding
+ # checkpoints still load as intended now that nothing is skipped.
- loader = AutoWeightsLoader(
- self,
- skip_prefixes=(["lm_head."]
- if self.config.tie_word_embeddings else None),
- )
+ loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+ """Return the expert parameter mapping produced by ``self.model.get_expert_mapping()``."""
+ return self.model.get_expert_mapping()
\ No newline at end of file