diff --git a/docs/design/custom_op.md b/docs/design/custom_op.md index 13c2915abe8f..3f4934b15699 100644 --- a/docs/design/custom_op.md +++ b/docs/design/custom_op.md @@ -8,15 +8,6 @@ This document will introduce how CustomOp works in vLLM and how to implement a n `CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively. -??? code - - ```python - class CustomOp(nn.Module): - - op_registry: dict[str, type["CustomOp"]] = {} - op_registry_oot: dict[str, type["CustomOp"]] = {} - ``` - We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, We can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later. When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method. 
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ccdacf40c430..7763be0cb5bf 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -609,7 +609,7 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: Compute the minimum number of blocks required to hold num_tokens tokens, given block_size """ - return (num_tokens + block_size) // block_size + return (num_tokens + block_size - 1) // block_size def make_empty_slot_mapping_tensor(device: torch.device | str): @@ -694,7 +694,7 @@ def make_block_tables_slot_mapping( For a sequence with num_tokens tokens the minimum number of required KV cache blocks is - num_blocks = (num_tokens + block_size) // block_size + num_blocks = (num_tokens + block_size - 1) // block_size Then the minimum KV cache size in blocks is diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 8ee1b1a37ca6..316caf06b29c 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -11,7 +11,7 @@ get_cached_compilation_config, set_current_vllm_config, ) -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import CustomOp, op_registry from vllm.model_executor.layers.activation import ( GeluAndMul, ReLUSquaredActivation, @@ -98,17 +98,17 @@ def test_enabled_ops( ops_enabled = [bool(x) for x in ops_enabled] assert RMSNorm(1024).enabled() == ops_enabled[0] - assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + assert op_registry["rms_norm"].enabled() == ops_enabled[0] assert SiluAndMul().enabled() == ops_enabled[1] - assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] + assert op_registry["silu_and_mul"].enabled() == ops_enabled[1] assert GeluAndMul().enabled() == ops_enabled[2] - assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2] # If 
registered, subclasses should follow their own name assert Relu3().enabled() == ops_enabled[3] - assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] + assert op_registry["relu3"].enabled() == ops_enabled[3] # Unregistered subclass class SiluAndMul2(SiluAndMul): diff --git a/tests/v1/spec_decode/__init__.py b/tests/v1/spec_decode/__init__.py new file mode 100644 index 000000000000..208f01a7cb5e --- /dev/null +++ b/tests/v1/spec_decode/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py new file mode 100644 index 000000000000..1a615878bb8b --- /dev/null +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +EAGLE3 Acceptance Length Regression Tests. + +These tests verify that acceptance lengths for EAGLE3 speculative decoding +do not regress across vLLM commits. Each test runs inference on the MT-Bench +dataset and asserts that the mean acceptance length is within tolerance of +the expected baseline. 
+""" + +from dataclasses import dataclass, field +from types import SimpleNamespace + +import pytest +import torch + +from tests.conftest import VllmRunner +from tests.utils import large_gpu_mark +from vllm import SamplingParams +from vllm.benchmarks.datasets import get_samples +from vllm.inputs import TokensPrompt +from vllm.platforms import current_platform +from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.attention.selector import AttentionSelectorConfig +from vllm.v1.metrics.reader import Counter, Vector + + +@dataclass +class Eagle3ModelConfig: + verifier: str + drafter: str + expected_acceptance_length: float + expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list) + id: str = "" + # Backends that are incompatible with this model (will be skipped) + excluded_backends: set[AttentionBackendEnum] = field(default_factory=set) + + +# Model configurations for EAGLE3 acceptance length tests. +# Expected acceptance lengths are determined by running baseline benchmarks +# using examples/offline_inference/spec_decode.py with the MT-Bench dataset. 
+EAGLE3_MODEL_CONFIGS = [ + Eagle3ModelConfig( + verifier="meta-llama/Llama-3.1-8B-Instruct", + drafter="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3", + expected_acceptance_length=2.60, + expected_acceptance_lengths_per_pos=[0.7296, 0.5208, 0.3545], + id="llama3-8b-eagle3", + ), + Eagle3ModelConfig( + verifier="Qwen/Qwen3-8B", + drafter="RedHatAI/Qwen3-8B-speculator.eagle3", + expected_acceptance_length=2.26, + expected_acceptance_lengths_per_pos=[0.6541, 0.3993, 0.2020], + id="qwen3-8b-eagle3", + ), + Eagle3ModelConfig( + verifier="openai/gpt-oss-20b", + drafter="RedHatAI/gpt-oss-20b-speculator.eagle3", + expected_acceptance_length=2.56, + expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.3337], + id="gpt-oss-20b-eagle3", + # FLASHINFER incompatible: gpt-oss-20b uses sink attention which + # FLASHINFER does not support ("sink setting not supported") + excluded_backends={AttentionBackendEnum.FLASHINFER}, + ), +] + +# Default test parameters +DEFAULT_NUM_SPEC_TOKENS = 3 +DEFAULT_NUM_PROMPTS = 80 +DEFAULT_OUTPUT_LEN = 256 +DEFAULT_MAX_MODEL_LEN = 16384 +DEFAULT_RTOL = 0.05 + +# TP sizes to test +TP_SIZES = [1, 2, 4] + + +# Backends excluded from testing due to significantly different behavior +EXCLUDED_BACKENDS = {AttentionBackendEnum.FLEX_ATTENTION} + + +def get_available_attention_backends() -> list[str]: + if not hasattr(current_platform, "get_valid_backends"): + return ["FLASH_ATTN"] + + device_capability = current_platform.get_device_capability() + if device_capability is None: + return ["FLASH_ATTN"] + + attn_selector_config = AttentionSelectorConfig( + head_size=128, + dtype=torch.bfloat16, + kv_cache_dtype=None, + block_size=None, + use_mla=False, + has_sink=False, + use_sparse=False, + use_mm_prefix=False, + ) + + valid_backends, _ = current_platform.get_valid_backends( + device_capability=device_capability, + attn_selector_config=attn_selector_config, + ) + + return [ + backend.name + for backend, _ in valid_backends + if backend not in 
EXCLUDED_BACKENDS + ] + + +def get_attention_backend_params() -> list[str]: + return get_available_attention_backends() + + +def get_tp_size_params() -> list[pytest.param]: + num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 + return [pytest.param(tp, id=f"tp{tp}") for tp in TP_SIZES if tp <= num_gpus] + + +def get_mt_bench_prompts( + tokenizer, num_prompts: int = DEFAULT_NUM_PROMPTS +) -> list[list[int]]: + args = SimpleNamespace( + dataset_name="hf", + dataset_path="philschmid/mt-bench", + num_prompts=num_prompts, + seed=42, + no_oversample=False, + endpoint_type="openai-chat", + input_len=None, + output_len=DEFAULT_OUTPUT_LEN, + sharegpt_output_len=DEFAULT_OUTPUT_LEN, + hf_name=None, + hf_split="train", + hf_subset=None, + hf_output_len=DEFAULT_OUTPUT_LEN, + no_stream=True, + disable_shuffle=False, + skip_chat_template=False, + ) + samples = get_samples(args, tokenizer) + prompt_ids = [ + tokenizer.encode(sample.prompt, add_special_tokens=False) for sample in samples + ] + return prompt_ids + + +def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> dict: + num_drafts = 0 + num_accepted_tokens = 0 + acceptance_counts = [0] * num_spec_tokens + + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + assert isinstance(metric, Counter) + num_accepted_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(min(len(metric.values), num_spec_tokens)): + acceptance_counts[pos] += metric.values[pos] + + # Calculate mean acceptance length + # Formula: 1 + (accepted_tokens / num_drafts) + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + + # Calculate per-position acceptance lengths (contribution to total) + # Each position contributes: accepted_at_pos / num_drafts + 
acceptance_lengths_per_pos = [ + count / num_drafts if num_drafts > 0 else 0.0 for count in acceptance_counts + ] + + return { + "acceptance_length": acceptance_length, + "acceptance_lengths_per_pos": acceptance_lengths_per_pos, + "num_drafts": num_drafts, + "num_accepted_tokens": num_accepted_tokens, + } + + +@large_gpu_mark(min_gb=40) +@pytest.mark.parametrize( + "model_config", + [pytest.param(config, id=config.id) for config in EAGLE3_MODEL_CONFIGS], +) +@pytest.mark.parametrize("num_spec_tokens", [DEFAULT_NUM_SPEC_TOKENS]) +@pytest.mark.parametrize("tp_size", get_tp_size_params()) +@pytest.mark.parametrize("attention_backend", get_attention_backend_params()) +def test_eagle3_acceptance_length( + model_config: Eagle3ModelConfig, + num_spec_tokens: int, + tp_size: int, + attention_backend: str, + monkeypatch: pytest.MonkeyPatch, +): + # Skip if this backend is incompatible with the model + backend_enum = AttentionBackendEnum[attention_backend] + if backend_enum in model_config.excluded_backends: + pytest.skip(f"{attention_backend} is incompatible with {model_config.id}") + + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attention_backend) + + with VllmRunner( + model_name=model_config.verifier, + speculative_config={ + "method": "eagle3", + "model": model_config.drafter, + "num_speculative_tokens": num_spec_tokens, + }, + tensor_parallel_size=tp_size, + gpu_memory_utilization=0.7, + disable_log_stats=False, + max_model_len=DEFAULT_MAX_MODEL_LEN, + ) as vllm_runner: + tokenizer = vllm_runner.llm.get_tokenizer() + prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS) + + sampling_params = SamplingParams( + temperature=0, + max_tokens=DEFAULT_OUTPUT_LEN, + ) + vllm_runner.llm.generate( + [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids], + sampling_params=sampling_params, + ) + + metrics = vllm_runner.llm.get_metrics() + results = extract_acceptance_metrics(metrics, 
num_spec_tokens) + + actual_acceptance_length = results["acceptance_length"] + expected = model_config.expected_acceptance_length + actual_per_pos = results["acceptance_lengths_per_pos"] + expected_per_pos = model_config.expected_acceptance_lengths_per_pos + + rel_error = abs(actual_acceptance_length - expected) / expected + + assert rel_error <= DEFAULT_RTOL, ( + f"Acceptance length regression detected for {model_config.id}!\n" + f" Expected: {expected:.3f}\n" + f" Actual: {actual_acceptance_length:.3f}\n" + f" Relative error: {rel_error:.2%} (tolerance: {DEFAULT_RTOL:.2%})\n" + f" Drafts: {results['num_drafts']}, " + f"Accepted tokens: {results['num_accepted_tokens']}" + ) + + if expected_per_pos and len(expected_per_pos) == len(actual_per_pos): + for pos, (actual, exp) in enumerate( + zip(actual_per_pos, expected_per_pos) + ): + if exp > 0: + pos_rel_error = abs(actual - exp) / exp + assert pos_rel_error <= DEFAULT_RTOL, ( + f"Per-position acceptance length regression at pos {pos} " + f"for {model_config.id}!\n" + f" Expected: {exp:.3f}\n" + f" Actual: {actual:.3f}\n" + f" Relative error: {pos_rel_error:.2%} " + f"(tolerance: {DEFAULT_RTOL:.2%})" + ) + + print( + f"\n{model_config.id} [tp={tp_size}, backend={attention_backend}]: " + f"acceptance_length={actual_acceptance_length:.3f}" + f" (expected={expected:.3f}, rel_error={rel_error:.2%})" + ) + print(f" Per-position: {[f'{v:.3f}' for v in actual_per_pos]}") + if expected_per_pos: + print(f" Expected: {[f'{v:.3f}' for v in expected_per_pos]}") diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index f3d9262d20cf..6bc26b6f3b1c 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -23,23 +23,146 @@ def __init__( ): super().__init__(cpu_group, device, device_group, unique_name) if self.use_all2all: - if self.all2all_backend != "naive": # type: 
ignore[has-type] - logger.warning( - "`%s` all2all manager is not supported on XPU. " - "Falling back to `naive` all2all manager for XPU.", - self.all2all_backend, # type: ignore[has-type] - ) - self.all2all_backend = "naive" if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) logger.info("Using naive all2all manager.") + elif self.all2all_backend == "allgather_reducescatter": + from .all2all import AgRsAll2AllManager + + self.all2all_manager = AgRsAll2AllManager(self.cpu_group) + logger.info("Using AgRs manager on XPU device.") + + else: # type: ignore[has-type] + logger.warning( + "`%s` all2all manager is not supported on XPU. " + "Falling back to AgRs manager for XPU, " + "which is the Default backend", + self.all2all_backend, # type: ignore[has-type] + ) + from .all2all import AgRsAll2AllManager + + self.all2all_manager = AgRsAll2AllManager(self.cpu_group) + logger.info("Using AgRs manager on XPU device.") + def all_reduce(self, input_) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + world_size = self.world_size + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? 
+ input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size,) + input_tensor.shape[1:] + + output = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) + + dist.reduce_scatter_tensor(output, input_tensor) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + + def reduce_scatterv( + self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None + ): + world_size = self.world_size + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + if sizes is not None: + assert len(sizes) == world_size + assert input_tensor.shape[0] == sum(sizes) + chunk_size = sizes[self.rank_in_group] + else: + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size,) + input_tensor.shape[1:] + + output = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) + if sizes is not None and sizes.count(sizes[0]) != len(sizes): + # if inputs shape in different ranks is not the same using reduce_scatter + input_splits = list(input_tensor.split(sizes, dim=0)) + dist.reduce_scatter(output, input_splits) + else: + dist.reduce_scatter_tensor(output, input_tensor) + # Reshape before returning + return output.movedim(0, dim).contiguous() + + def all_gatherv( + self, + input_: torch.Tensor | list[torch.Tensor], + dim: int = 0, + sizes: list[int] | None = None, + ): + if dim != 0: + raise NotImplementedError("only dim 0 all-gatherv is supported") + world_size = self.world_size + + # 'sizes' is not needed if all inputs in the same group have the same + # shape + if sizes is not None and all(s == sizes[0] 
for s in sizes): + sizes = None + + def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None): + input_size = input_.size() + if sizes is not None: + assert len(sizes) == world_size + assert input_.shape[dim] == sizes[self.rank_in_group], ( + f"{input_.shape[dim]} != {sizes[self.rank_in_group]}" + ) + output_size = (sum(sizes),) + input_size[1:] + else: + output_size = (input_size[0] * world_size,) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + + if sizes is not None: + all_gather_list = [] + for size in sizes: + all_gather_list.append( + torch.empty( + (size,) + input_.shape[1:], + dtype=input_.dtype, + device=input_.device, + ) + ) + dist.all_gather(all_gather_list, input_) + output_tensor = torch.cat(all_gather_list, dim=0) + else: + dist.all_gather([output_tensor], input_) + return output_tensor + + if isinstance(input_, torch.Tensor): + return _all_gather_single(input_, sizes) + + output_list = [] + for inp in input_: + output_list.append(_all_gather_single(inp, sizes=sizes)) + return output_list + def gather( self, input_: torch.Tensor, dst: int = 0, dim: int = -1 ) -> torch.Tensor | None: diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 81ba544b4813..6fe252fa27ee 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -11,6 +11,86 @@ logger = init_logger(__name__) +# Dictionary of all custom ops (classes, indexed by registered name). +# To check if an op with a name is enabled, call .enabled() on the class. +# Examples: +# - MyOp.enabled() +# - op_registry["my_op"].enabled() +op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} +op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} + + +class PluggableLayer(nn.Module): + """ + Base class for pluggable layers. 
+ + A PluggableLayer is a *module-composing* abstraction: it may instantiate other + ``torch.nn.Module`` objects as sub-layers, and its functionality depends on + these sub-layers following a generalized invocation sequence. Also, it is stateful + and may hold parameters or buffers. + + Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform + ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement + of the entire layer class at instantiation time, allowing customized + initialization and submodule composition. + """ + + def __new__(cls, *args, **kwargs): + try: + layer_class_name = cls.__name__ + except AttributeError: + raise TypeError( + f"Cannot instantiate '{cls.__name__}': its 'name' attribute " + f"was not set, possibly because it was not decorated with " + f"@PluggableLayer.register, or it's the PluggableLayer itself." + ) from None + + if layer_class_name not in op_registry_oot: + layer_cls_to_instantiate = cls + else: + layer_cls_to_instantiate = op_registry_oot[layer_class_name] + logger.debug( + "Instantiating pluggable layer: %s using %s", + layer_class_name, + str(layer_cls_to_instantiate), + ) + return super().__new__(layer_cls_to_instantiate) + + # Decorator to register pluggable layers. + @classmethod + def register(cls, name: str): + def decorator(op_cls): + assert name not in op_registry, f"Duplicate op name: {name}" + op_cls.name = name + op_registry[name] = op_cls + return op_cls + + return decorator + + # Decorator to register out-of-tree(oot) pluggable layers. + # For OOT pluggable layers: + # if in-tree layer class is registered with an oot_custom_layer, + # the oot_custom_layer will be used instead. 
+ @classmethod + def register_oot(cls, _decorated_layer_cls=None, name: str | None = None): + def decorator(layer_cls): + reg_name = name if name is not None else cls.__name__ + assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}" + layer_cls.name = reg_name + op_registry_oot[reg_name] = layer_cls + return layer_cls + + if _decorated_layer_cls is None: + # Called with parentheses: @PluggableLayer.register_oot() + # or @PluggableLayer.register_oot(name="...") + return decorator + elif isinstance(_decorated_layer_cls, type): # Check if it's a class + # Called without parentheses: @PluggableLayer.register_oot + return decorator(_decorated_layer_cls) + else: + raise TypeError("Decorator can only be applied to classes.") + + class CustomOp(nn.Module): """ Base class for custom ops. @@ -27,10 +107,10 @@ def __new__(cls, *args, **kwargs): f"@CustomOp.register, or it's the CustomOp base class itself." ) from None - if op_name not in cls.op_registry_oot: + if op_name not in op_registry_oot: op_cls_to_instantiate = cls else: - op_cls_to_instantiate = cls.op_registry_oot[op_name] + op_cls_to_instantiate = op_registry_oot[op_name] logger.debug( "Instantiating custom op: %s using %s", op_name, @@ -150,21 +230,13 @@ def default_on() -> bool: return not count_none > 0 or count_all > 0 - # Dictionary of all custom ops (classes, indexed by registered name). - # To check if an op with a name is enabled, call .enabled() on the class. - # Examples: - # - MyOp.enabled() - # - op_registry["my_op"].enabled() - op_registry: dict[str, type["CustomOp"]] = {} - op_registry_oot: dict[str, type["CustomOp"]] = {} - # Decorator to register custom ops. 
@classmethod def register(cls, name: str): def decorator(op_cls): - assert name not in cls.op_registry, f"Duplicate op name: {name}" + assert name not in op_registry, f"Duplicate op name: {name}" op_cls.name = name - cls.op_registry[name] = op_cls + op_registry[name] = op_cls return op_cls return decorator @@ -182,9 +254,9 @@ def decorator(op_cls): def register_oot(cls, _decorated_op_cls=None, name: str | None = None): def decorator(op_cls): reg_name = name if name is not None else cls.__name__ - assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" + assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}" op_cls.name = reg_name - cls.op_registry_oot[reg_name] = op_cls + op_registry_oot[reg_name] = op_cls return op_cls if _decorated_op_cls is None: diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 65541d2a485a..2549f1221f36 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -6,7 +6,7 @@ from vllm.attention.layer import MLAAttention from vllm.config import CacheConfig -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.quantization import QuantizationConfig @@ -30,13 +30,13 @@ class MLAModules: # --8<-- [start:multi_head_latent_attention] -@CustomOp.register("multi_head_latent_attention") -class MultiHeadLatentAttentionWrapper(CustomOp): - """MLA layer registered as CustomOp to allow OOT backends to add +@PluggableLayer.register("multi_head_latent_attention") +class MultiHeadLatentAttentionWrapper(PluggableLayer): + """Pluggable MLA layer which allows OOT backends to add custom implementations of the outer MLA layer (including rope & o_proj). - Note that currently MLA ignores the enable/disable mechanism of CustomOp - because there is only one in-tree implementation in forward_native. - TODO: implement this with a new PluggableLayer mechanism. 
+ Note that currently OOT platforms can still use CustomOp.register_oot to + replace MLA layer entirely, although we use PluggableLayer to register + this layer now. This class takes positions and hidden_states as input. The input tensors can either contain prefill tokens or decode tokens. @@ -110,7 +110,7 @@ def __init__( self.prefix = prefix - def forward_native( + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -174,6 +174,3 @@ def forward_native( ) return self.o_proj(attn_out)[0] - - def forward_cuda(self, *args, **kwargs): - return self.forward_native(*args, **kwargs)