From b997bed36abd0b9f0c8046a2ada7762f7d0c4d1c Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Fri, 9 Jan 2026 13:50:56 +0000 Subject: [PATCH 01/11] [Test] Add EAGLE3 acceptance length regression tests with per-position validation Add parameterized pytest tests to detect acceptance length regressions in EAGLE3 speculative decoding. Tests run inference on MT-Bench dataset (80 prompts) and assert both mean and per-position acceptance lengths are within 2% tolerance of baseline. Models tested: - Llama-3.1-8B-Instruct (AL: 2.60) - Qwen3-8B (AL: 2.26) - GPT-OSS-20B (AL: 2.56) Signed-off-by: rahul-tuli --- .../v1/spec_decode/test_acceptance_length.py | 243 ++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 tests/v1/spec_decode/test_acceptance_length.py diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py new file mode 100644 index 000000000000..4f094abd138d --- /dev/null +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +EAGLE3 Acceptance Length Regression Tests. + +These tests verify that acceptance lengths for EAGLE3 speculative decoding +do not regress across vLLM commits. Each test runs inference on the MT-Bench +dataset and asserts that the mean acceptance length is within tolerance of +the expected baseline. 
+""" + +from dataclasses import dataclass, field +from types import SimpleNamespace + +import pytest +import torch + +from vllm import LLM, SamplingParams +from vllm.benchmarks.datasets import get_samples +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.inputs import TokensPrompt +from vllm.v1.metrics.reader import Counter, Vector + + +@dataclass +class Eagle3ModelConfig: + """Configuration for an EAGLE3 model pair.""" + + verifier: str + drafter: str + expected_acceptance_length: float + expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list) + id: str = "" + + +# Model configurations for EAGLE3 acceptance length tests. +# Expected acceptance lengths are determined by running the baseline script. +# See local/acceptance_length/run_baselines.md for commands. +EAGLE3_MODEL_CONFIGS = [ + Eagle3ModelConfig( + verifier="meta-llama/Llama-3.1-8B-Instruct", + drafter="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3", + expected_acceptance_length=2.60, + expected_acceptance_lengths_per_pos=[0.7296, 0.5208, 0.3545], + id="llama3-8b-eagle3", + ), + Eagle3ModelConfig( + verifier="Qwen/Qwen3-8B", + drafter="RedHatAI/Qwen3-8B-speculator.eagle3", + expected_acceptance_length=2.26, + expected_acceptance_lengths_per_pos=[0.6541, 0.3993, 0.2020], + id="qwen3-8b-eagle3", + ), + Eagle3ModelConfig( + verifier="openai/gpt-oss-20b", + drafter="RedHatAI/gpt-oss-20b-speculator.eagle3", + expected_acceptance_length=2.56, + expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.3337], + id="gpt-oss-20b-eagle3", + ), +] + +# Default test parameters +DEFAULT_NUM_SPEC_TOKENS = 3 +DEFAULT_NUM_PROMPTS = 80 +DEFAULT_OUTPUT_LEN = 256 +DEFAULT_MAX_MODEL_LEN = 16384 +DEFAULT_RTOL = 0.02 # 2% relative tolerance + + +def get_mt_bench_prompts(tokenizer, num_prompts: int = DEFAULT_NUM_PROMPTS): + """Load prompts from MT-Bench dataset.""" + args = SimpleNamespace( + dataset_name="hf", + dataset_path="philschmid/mt-bench", + num_prompts=num_prompts, + seed=42, + 
no_oversample=False, + endpoint_type="openai-chat", + input_len=None, + output_len=DEFAULT_OUTPUT_LEN, + sharegpt_output_len=DEFAULT_OUTPUT_LEN, + ) + samples = get_samples(args, tokenizer) + prompt_ids = [ + tokenizer.encode(sample.prompt, add_special_tokens=False) for sample in samples + ] + return prompt_ids + + +def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> dict: + """Extract acceptance length metrics from LLM metrics.""" + num_drafts = 0 + num_accepted_tokens = 0 + acceptance_counts = [0] * num_spec_tokens + + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens": + assert isinstance(metric, Counter) + num_accepted_tokens += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(min(len(metric.values), num_spec_tokens)): + acceptance_counts[pos] += metric.values[pos] + + # Calculate mean acceptance length + # Formula: 1 + (accepted_tokens / num_drafts) + acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1 + + # Calculate per-position acceptance lengths (contribution to total) + # Each position contributes: accepted_at_pos / num_drafts + acceptance_lengths_per_pos = [ + count / num_drafts if num_drafts > 0 else 0.0 for count in acceptance_counts + ] + + return { + "acceptance_length": acceptance_length, + "acceptance_lengths_per_pos": acceptance_lengths_per_pos, + "num_drafts": num_drafts, + "num_accepted_tokens": num_accepted_tokens, + } + + +@pytest.fixture(autouse=True) +def cleanup_after_test(): + """Clean up GPU memory after each test.""" + yield + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + torch._dynamo.reset() + + +@pytest.mark.parametrize( + "model_config", + [pytest.param(config, id=config.id) for config in EAGLE3_MODEL_CONFIGS], +) 
+@pytest.mark.parametrize("num_spec_tokens", [DEFAULT_NUM_SPEC_TOKENS]) +def test_eagle3_acceptance_length( + model_config: Eagle3ModelConfig, + num_spec_tokens: int, + monkeypatch: pytest.MonkeyPatch, +): + """ + Test EAGLE3 acceptance length does not regress. + + This test: + 1. Loads the MT-Bench dataset + 2. Runs inference with EAGLE3 speculative decoding + 3. Extracts acceptance length metrics + 4. Asserts the acceptance length is within tolerance of expected baseline + + Args: + model_config: Configuration for verifier/drafter model pair + num_spec_tokens: Number of speculative tokens to generate + monkeypatch: Pytest monkeypatch fixture + """ + # Skip if expected acceptance length is not set + if model_config.expected_acceptance_length <= 0: + pytest.skip( + f"Expected acceptance length not set for {model_config.id}. " + "Run baseline script to determine expected value." + ) + + # Allow insecure serialization for speculators + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + # Initialize LLM with speculative decoding + llm = LLM( + model=model_config.verifier, + speculative_config={ + "method": "eagle3", + "model": model_config.drafter, + "num_speculative_tokens": num_spec_tokens, + }, + tensor_parallel_size=1, + gpu_memory_utilization=0.9, + disable_log_stats=False, + max_model_len=DEFAULT_MAX_MODEL_LEN, + ) + + # Load MT-Bench prompts + tokenizer = llm.get_tokenizer() + prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS) + + # Run inference + sampling_params = SamplingParams( + temperature=0, + max_tokens=DEFAULT_OUTPUT_LEN, + ) + llm.generate( + [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids], + sampling_params=sampling_params, + ) + + # Extract and validate metrics + metrics = llm.get_metrics() + results = extract_acceptance_metrics(metrics, num_spec_tokens) + + actual_acceptance_length = results["acceptance_length"] + expected = model_config.expected_acceptance_length + actual_per_pos = 
results["acceptance_lengths_per_pos"] + expected_per_pos = model_config.expected_acceptance_lengths_per_pos + + # Calculate relative error for mean acceptance length + rel_error = abs(actual_acceptance_length - expected) / expected + + # Assert mean acceptance length within tolerance + assert rel_error <= DEFAULT_RTOL, ( + f"Acceptance length regression detected for {model_config.id}!\n" + f" Expected: {expected:.3f}\n" + f" Actual: {actual_acceptance_length:.3f}\n" + f" Relative error: {rel_error:.2%} (tolerance: {DEFAULT_RTOL:.2%})\n" + f" Drafts: {results['num_drafts']}, " + f"Accepted tokens: {results['num_accepted_tokens']}" + ) + + # Assert per-position acceptance lengths within tolerance (if expected values set) + if expected_per_pos and len(expected_per_pos) == len(actual_per_pos): + for pos, (actual, exp) in enumerate(zip(actual_per_pos, expected_per_pos)): + if exp > 0: + pos_rel_error = abs(actual - exp) / exp + assert pos_rel_error <= DEFAULT_RTOL, ( + f"Per-position acceptance length regression at pos {pos} " + f"for {model_config.id}!\n" + f" Expected: {exp:.3f}\n" + f" Actual: {actual:.3f}\n" + f" Relative error: {pos_rel_error:.2%} " + f"(tolerance: {DEFAULT_RTOL:.2%})" + ) + + print( + f"\n{model_config.id}: acceptance_length={actual_acceptance_length:.3f} " + f"(expected={expected:.3f}, rel_error={rel_error:.2%})" + ) + print(f" Per-position: {[f'{v:.3f}' for v in actual_per_pos]}") + if expected_per_pos: + print(f" Expected: {[f'{v:.3f}' for v in expected_per_pos]}") + + # Cleanup + del llm From f91a4d59f6244c676545caba441bf3f88aeb30bf Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Fri, 9 Jan 2026 14:21:41 +0000 Subject: [PATCH 02/11] Update: AL values Signed-off-by: rahul-tuli --- tests/v1/spec_decode/test_acceptance_length.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py index 4f094abd138d..f78d5ebacfe2 100644 --- 
a/tests/v1/spec_decode/test_acceptance_length.py +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -55,7 +55,7 @@ class Eagle3ModelConfig: verifier="openai/gpt-oss-20b", drafter="RedHatAI/gpt-oss-20b-speculator.eagle3", expected_acceptance_length=2.56, - expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.3337], + expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.322], id="gpt-oss-20b-eagle3", ), ] @@ -80,6 +80,13 @@ def get_mt_bench_prompts(tokenizer, num_prompts: int = DEFAULT_NUM_PROMPTS): input_len=None, output_len=DEFAULT_OUTPUT_LEN, sharegpt_output_len=DEFAULT_OUTPUT_LEN, + hf_name=None, + hf_split="train", + hf_subset=None, + hf_output_len=DEFAULT_OUTPUT_LEN, + no_stream=True, + disable_shuffle=False, + skip_chat_template=False, ) samples = get_samples(args, tokenizer) prompt_ids = [ From 2c5d50540f23e0b9e8bb9668542c6dd89a8b26dc Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Fri, 9 Jan 2026 14:41:26 +0000 Subject: [PATCH 03/11] Some more cleanups Signed-off-by: rahul-tuli --- tests/v1/spec_decode/__init__.py | 2 ++ .../v1/spec_decode/test_acceptance_length.py | 24 +++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 tests/v1/spec_decode/__init__.py diff --git a/tests/v1/spec_decode/__init__.py b/tests/v1/spec_decode/__init__.py new file mode 100644 index 000000000000..208f01a7cb5e --- /dev/null +++ b/tests/v1/spec_decode/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py index f78d5ebacfe2..324a0d67eb36 100644 --- a/tests/v1/spec_decode/test_acceptance_length.py +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -11,10 +11,12 @@ from dataclasses import dataclass, field from types import SimpleNamespace +from typing import TypedDict import pytest import torch +from tests.utils import large_gpu_mark from vllm 
import LLM, SamplingParams from vllm.benchmarks.datasets import get_samples from vllm.distributed import cleanup_dist_env_and_memory @@ -22,6 +24,15 @@ from vllm.v1.metrics.reader import Counter, Vector +class AcceptanceMetrics(TypedDict): + """Typed dict for acceptance length metrics.""" + + acceptance_length: float + acceptance_lengths_per_pos: list[float] + num_drafts: int + num_accepted_tokens: int + + @dataclass class Eagle3ModelConfig: """Configuration for an EAGLE3 model pair.""" @@ -34,8 +45,8 @@ class Eagle3ModelConfig: # Model configurations for EAGLE3 acceptance length tests. -# Expected acceptance lengths are determined by running the baseline script. -# See local/acceptance_length/run_baselines.md for commands. +# Expected acceptance lengths are determined by running baseline benchmarks +# using examples/offline_inference/spec_decode.py with the MT-Bench dataset. EAGLE3_MODEL_CONFIGS = [ Eagle3ModelConfig( verifier="meta-llama/Llama-3.1-8B-Instruct", @@ -55,7 +66,7 @@ class Eagle3ModelConfig: verifier="openai/gpt-oss-20b", drafter="RedHatAI/gpt-oss-20b-speculator.eagle3", expected_acceptance_length=2.56, - expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.322], + expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.3220], id="gpt-oss-20b-eagle3", ), ] @@ -68,7 +79,9 @@ class Eagle3ModelConfig: DEFAULT_RTOL = 0.02 # 2% relative tolerance -def get_mt_bench_prompts(tokenizer, num_prompts: int = DEFAULT_NUM_PROMPTS): +def get_mt_bench_prompts( + tokenizer, num_prompts: int = DEFAULT_NUM_PROMPTS +) -> list[list[int]]: """Load prompts from MT-Bench dataset.""" args = SimpleNamespace( dataset_name="hf", @@ -95,7 +108,7 @@ def get_mt_bench_prompts(tokenizer, num_prompts: int = DEFAULT_NUM_PROMPTS): return prompt_ids -def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> dict: +def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> AcceptanceMetrics: """Extract acceptance length metrics from LLM metrics.""" num_drafts = 0 
num_accepted_tokens = 0 @@ -140,6 +153,7 @@ def cleanup_after_test(): torch._dynamo.reset() +@large_gpu_mark(min_gb=40) @pytest.mark.parametrize( "model_config", [pytest.param(config, id=config.id) for config in EAGLE3_MODEL_CONFIGS], From 85f67f4fb9fb8ef86b7672a9d5f730a3ccbf8826 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Fri, 9 Jan 2026 14:55:32 +0000 Subject: [PATCH 04/11] Review comments Signed-off-by: rahul-tuli --- .../v1/spec_decode/test_acceptance_length.py | 116 +++++++++--------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py index 324a0d67eb36..63eb7d925b5e 100644 --- a/tests/v1/spec_decode/test_acceptance_length.py +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -202,63 +202,63 @@ def test_eagle3_acceptance_length( max_model_len=DEFAULT_MAX_MODEL_LEN, ) - # Load MT-Bench prompts - tokenizer = llm.get_tokenizer() - prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS) - - # Run inference - sampling_params = SamplingParams( - temperature=0, - max_tokens=DEFAULT_OUTPUT_LEN, - ) - llm.generate( - [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids], - sampling_params=sampling_params, - ) - - # Extract and validate metrics - metrics = llm.get_metrics() - results = extract_acceptance_metrics(metrics, num_spec_tokens) - - actual_acceptance_length = results["acceptance_length"] - expected = model_config.expected_acceptance_length - actual_per_pos = results["acceptance_lengths_per_pos"] - expected_per_pos = model_config.expected_acceptance_lengths_per_pos - - # Calculate relative error for mean acceptance length - rel_error = abs(actual_acceptance_length - expected) / expected - - # Assert mean acceptance length within tolerance - assert rel_error <= DEFAULT_RTOL, ( - f"Acceptance length regression detected for {model_config.id}!\n" - f" Expected: {expected:.3f}\n" - f" Actual: {actual_acceptance_length:.3f}\n" - f" 
Relative error: {rel_error:.2%} (tolerance: {DEFAULT_RTOL:.2%})\n" - f" Drafts: {results['num_drafts']}, " - f"Accepted tokens: {results['num_accepted_tokens']}" - ) + try: + # Load MT-Bench prompts + tokenizer = llm.get_tokenizer() + prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS) + + # Run inference + sampling_params = SamplingParams( + temperature=0, + max_tokens=DEFAULT_OUTPUT_LEN, + ) + llm.generate( + [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids], + sampling_params=sampling_params, + ) - # Assert per-position acceptance lengths within tolerance (if expected values set) - if expected_per_pos and len(expected_per_pos) == len(actual_per_pos): - for pos, (actual, exp) in enumerate(zip(actual_per_pos, expected_per_pos)): - if exp > 0: - pos_rel_error = abs(actual - exp) / exp - assert pos_rel_error <= DEFAULT_RTOL, ( - f"Per-position acceptance length regression at pos {pos} " - f"for {model_config.id}!\n" - f" Expected: {exp:.3f}\n" - f" Actual: {actual:.3f}\n" - f" Relative error: {pos_rel_error:.2%} " - f"(tolerance: {DEFAULT_RTOL:.2%})" - ) - - print( - f"\n{model_config.id}: acceptance_length={actual_acceptance_length:.3f} " - f"(expected={expected:.3f}, rel_error={rel_error:.2%})" - ) - print(f" Per-position: {[f'{v:.3f}' for v in actual_per_pos]}") - if expected_per_pos: - print(f" Expected: {[f'{v:.3f}' for v in expected_per_pos]}") + # Extract and validate metrics + metrics = llm.get_metrics() + results = extract_acceptance_metrics(metrics, num_spec_tokens) + + actual_acceptance_length = results["acceptance_length"] + expected = model_config.expected_acceptance_length + actual_per_pos = results["acceptance_lengths_per_pos"] + expected_per_pos = model_config.expected_acceptance_lengths_per_pos + + # Calculate relative error for mean acceptance length + rel_error = abs(actual_acceptance_length - expected) / expected + + # Assert mean acceptance length within tolerance + assert rel_error <= DEFAULT_RTOL, ( + f"Acceptance length 
regression detected for {model_config.id}!\n" + f" Expected: {expected:.3f}\n" + f" Actual: {actual_acceptance_length:.3f}\n" + f" Relative error: {rel_error:.2%} (tolerance: {DEFAULT_RTOL:.2%})\n" + f" Drafts: {results['num_drafts']}, " + f"Accepted tokens: {results['num_accepted_tokens']}" + ) - # Cleanup - del llm + # Assert per-position acceptance lengths within tolerance + if expected_per_pos and len(expected_per_pos) == len(actual_per_pos): + for pos, (actual, exp) in enumerate(zip(actual_per_pos, expected_per_pos)): + if exp > 0: + pos_rel_error = abs(actual - exp) / exp + assert pos_rel_error <= DEFAULT_RTOL, ( + f"Per-position acceptance length regression at pos {pos} " + f"for {model_config.id}!\n" + f" Expected: {exp:.3f}\n" + f" Actual: {actual:.3f}\n" + f" Relative error: {pos_rel_error:.2%} " + f"(tolerance: {DEFAULT_RTOL:.2%})" + ) + + print( + f"\n{model_config.id}: acceptance_length={actual_acceptance_length:.3f} " + f"(expected={expected:.3f}, rel_error={rel_error:.2%})" + ) + print(f" Per-position: {[f'{v:.3f}' for v in actual_per_pos]}") + if expected_per_pos: + print(f" Expected: {[f'{v:.3f}' for v in expected_per_pos]}") + finally: + del llm From a0a9efb847ef5610240ac6a6d0b90b488026b975 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Fri, 16 Jan 2026 15:01:52 +0000 Subject: [PATCH 05/11] [Test] Address review comments for EAGLE3 acceptance length tests - Use VllmRunner context manager instead of direct LLM instantiation - Use monkeypatch.context() for proper env var scoping - Use AcceptanceMetrics TypedDict in return statement - Remove docstrings from TypedDict and dataclass definitions - Remove inline comments from constants - Remove prototyping skip condition (all configs have baselines) - Fix gpt-oss-20b expected position 2 value (0.3220 -> 0.3337) Signed-off-by: rahul-tuli --- .../v1/spec_decode/test_acceptance_length.py | 201 +++++++----------- 1 file changed, 79 insertions(+), 122 deletions(-) diff --git 
a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py index 63eb7d925b5e..d61e6747a485 100644 --- a/tests/v1/spec_decode/test_acceptance_length.py +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -14,19 +14,16 @@ from typing import TypedDict import pytest -import torch +from tests.conftest import VllmRunner from tests.utils import large_gpu_mark -from vllm import LLM, SamplingParams +from vllm import SamplingParams from vllm.benchmarks.datasets import get_samples -from vllm.distributed import cleanup_dist_env_and_memory from vllm.inputs import TokensPrompt from vllm.v1.metrics.reader import Counter, Vector class AcceptanceMetrics(TypedDict): - """Typed dict for acceptance length metrics.""" - acceptance_length: float acceptance_lengths_per_pos: list[float] num_drafts: int @@ -35,8 +32,6 @@ class AcceptanceMetrics(TypedDict): @dataclass class Eagle3ModelConfig: - """Configuration for an EAGLE3 model pair.""" - verifier: str drafter: str expected_acceptance_length: float @@ -66,7 +61,7 @@ class Eagle3ModelConfig: verifier="openai/gpt-oss-20b", drafter="RedHatAI/gpt-oss-20b-speculator.eagle3", expected_acceptance_length=2.56, - expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.3220], + expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.3337], id="gpt-oss-20b-eagle3", ), ] @@ -76,7 +71,7 @@ class Eagle3ModelConfig: DEFAULT_NUM_PROMPTS = 80 DEFAULT_OUTPUT_LEN = 256 DEFAULT_MAX_MODEL_LEN = 16384 -DEFAULT_RTOL = 0.02 # 2% relative tolerance +DEFAULT_RTOL = 0.02 def get_mt_bench_prompts( @@ -136,21 +131,12 @@ def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> AcceptanceMetri count / num_drafts if num_drafts > 0 else 0.0 for count in acceptance_counts ] - return { - "acceptance_length": acceptance_length, - "acceptance_lengths_per_pos": acceptance_lengths_per_pos, - "num_drafts": num_drafts, - "num_accepted_tokens": num_accepted_tokens, - } - - -@pytest.fixture(autouse=True) -def cleanup_after_test(): 
- """Clean up GPU memory after each test.""" - yield - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() - torch._dynamo.reset() + return AcceptanceMetrics( + acceptance_length=acceptance_length, + acceptance_lengths_per_pos=acceptance_lengths_per_pos, + num_drafts=num_drafts, + num_accepted_tokens=num_accepted_tokens, + ) @large_gpu_mark(min_gb=40) @@ -164,101 +150,72 @@ def test_eagle3_acceptance_length( num_spec_tokens: int, monkeypatch: pytest.MonkeyPatch, ): - """ - Test EAGLE3 acceptance length does not regress. - - This test: - 1. Loads the MT-Bench dataset - 2. Runs inference with EAGLE3 speculative decoding - 3. Extracts acceptance length metrics - 4. Asserts the acceptance length is within tolerance of expected baseline - - Args: - model_config: Configuration for verifier/drafter model pair - num_spec_tokens: Number of speculative tokens to generate - monkeypatch: Pytest monkeypatch fixture - """ - # Skip if expected acceptance length is not set - if model_config.expected_acceptance_length <= 0: - pytest.skip( - f"Expected acceptance length not set for {model_config.id}. " - "Run baseline script to determine expected value." 
- ) - - # Allow insecure serialization for speculators - monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - - # Initialize LLM with speculative decoding - llm = LLM( - model=model_config.verifier, - speculative_config={ - "method": "eagle3", - "model": model_config.drafter, - "num_speculative_tokens": num_spec_tokens, - }, - tensor_parallel_size=1, - gpu_memory_utilization=0.9, - disable_log_stats=False, - max_model_len=DEFAULT_MAX_MODEL_LEN, - ) - - try: - # Load MT-Bench prompts - tokenizer = llm.get_tokenizer() - prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS) - - # Run inference - sampling_params = SamplingParams( - temperature=0, - max_tokens=DEFAULT_OUTPUT_LEN, - ) - llm.generate( - [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids], - sampling_params=sampling_params, - ) - - # Extract and validate metrics - metrics = llm.get_metrics() - results = extract_acceptance_metrics(metrics, num_spec_tokens) - - actual_acceptance_length = results["acceptance_length"] - expected = model_config.expected_acceptance_length - actual_per_pos = results["acceptance_lengths_per_pos"] - expected_per_pos = model_config.expected_acceptance_lengths_per_pos - - # Calculate relative error for mean acceptance length - rel_error = abs(actual_acceptance_length - expected) / expected - - # Assert mean acceptance length within tolerance - assert rel_error <= DEFAULT_RTOL, ( - f"Acceptance length regression detected for {model_config.id}!\n" - f" Expected: {expected:.3f}\n" - f" Actual: {actual_acceptance_length:.3f}\n" - f" Relative error: {rel_error:.2%} (tolerance: {DEFAULT_RTOL:.2%})\n" - f" Drafts: {results['num_drafts']}, " - f"Accepted tokens: {results['num_accepted_tokens']}" - ) - - # Assert per-position acceptance lengths within tolerance - if expected_per_pos and len(expected_per_pos) == len(actual_per_pos): - for pos, (actual, exp) in enumerate(zip(actual_per_pos, expected_per_pos)): - if exp > 0: - pos_rel_error = abs(actual - exp) / exp 
- assert pos_rel_error <= DEFAULT_RTOL, ( - f"Per-position acceptance length regression at pos {pos} " - f"for {model_config.id}!\n" - f" Expected: {exp:.3f}\n" - f" Actual: {actual:.3f}\n" - f" Relative error: {pos_rel_error:.2%} " - f"(tolerance: {DEFAULT_RTOL:.2%})" - ) - - print( - f"\n{model_config.id}: acceptance_length={actual_acceptance_length:.3f} " - f"(expected={expected:.3f}, rel_error={rel_error:.2%})" - ) - print(f" Per-position: {[f'{v:.3f}' for v in actual_per_pos]}") - if expected_per_pos: - print(f" Expected: {[f'{v:.3f}' for v in expected_per_pos]}") - finally: - del llm + """Test EAGLE3 acceptance length does not regress.""" + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + with VllmRunner( + model_name=model_config.verifier, + speculative_config={ + "method": "eagle3", + "model": model_config.drafter, + "num_speculative_tokens": num_spec_tokens, + }, + tensor_parallel_size=1, + gpu_memory_utilization=0.9, + disable_log_stats=False, + max_model_len=DEFAULT_MAX_MODEL_LEN, + ) as vllm_runner: + tokenizer = vllm_runner.llm.get_tokenizer() + prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS) + + sampling_params = SamplingParams( + temperature=0, + max_tokens=DEFAULT_OUTPUT_LEN, + ) + vllm_runner.llm.generate( + [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids], + sampling_params=sampling_params, + ) + + metrics = vllm_runner.llm.get_metrics() + results = extract_acceptance_metrics(metrics, num_spec_tokens) + + actual_acceptance_length = results["acceptance_length"] + expected = model_config.expected_acceptance_length + actual_per_pos = results["acceptance_lengths_per_pos"] + expected_per_pos = model_config.expected_acceptance_lengths_per_pos + + rel_error = abs(actual_acceptance_length - expected) / expected + + assert rel_error <= DEFAULT_RTOL, ( + f"Acceptance length regression detected for {model_config.id}!\n" + f" Expected: {expected:.3f}\n" + f" Actual: 
{actual_acceptance_length:.3f}\n" + f" Relative error: {rel_error:.2%} (tolerance: {DEFAULT_RTOL:.2%})\n" + f" Drafts: {results['num_drafts']}, " + f"Accepted tokens: {results['num_accepted_tokens']}" + ) + + if expected_per_pos and len(expected_per_pos) == len(actual_per_pos): + for pos, (actual, exp) in enumerate( + zip(actual_per_pos, expected_per_pos) + ): + if exp > 0: + pos_rel_error = abs(actual - exp) / exp + assert pos_rel_error <= DEFAULT_RTOL, ( + f"Per-position acceptance length regression at pos {pos} " + f"for {model_config.id}!\n" + f" Expected: {exp:.3f}\n" + f" Actual: {actual:.3f}\n" + f" Relative error: {pos_rel_error:.2%} " + f"(tolerance: {DEFAULT_RTOL:.2%})" + ) + + print( + f"\n{model_config.id}: acceptance_length={actual_acceptance_length:.3f}" + f" (expected={expected:.3f}, rel_error={rel_error:.2%})" + ) + print(f" Per-position: {[f'{v:.3f}' for v in actual_per_pos]}") + if expected_per_pos: + print(f" Expected: {[f'{v:.3f}' for v in expected_per_pos]}") From ec01d6c99b79d9ba221fe00df748021e147376e6 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Mon, 19 Jan 2026 14:11:06 +0000 Subject: [PATCH 06/11] Added: multiple tp and attention backends Signed-off-by: rahul-tuli --- .../v1/spec_decode/test_acceptance_length.py | 68 +++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py index d61e6747a485..207f1ed279f4 100644 --- a/tests/v1/spec_decode/test_acceptance_length.py +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -14,12 +14,15 @@ from typing import TypedDict import pytest +import torch from tests.conftest import VllmRunner from tests.utils import large_gpu_mark from vllm import SamplingParams from vllm.benchmarks.datasets import get_samples from vllm.inputs import TokensPrompt +from vllm.platforms import current_platform +from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.metrics.reader 
import Counter, Vector @@ -71,7 +74,58 @@ class Eagle3ModelConfig: DEFAULT_NUM_PROMPTS = 80 DEFAULT_OUTPUT_LEN = 256 DEFAULT_MAX_MODEL_LEN = 16384 -DEFAULT_RTOL = 0.02 +DEFAULT_RTOL = 0.05 + +# TP sizes to test +TP_SIZES = [1, 2, 4] + + +# Backends excluded from testing due to significantly different behavior +EXCLUDED_BACKENDS = {"FLEX_ATTENTION"} + + +def get_available_attention_backends() -> list[str]: + """Get list of available attention backends for the current platform.""" + if not hasattr(current_platform, "get_valid_backends"): + return ["FLASH_ATTN"] + + device_capability = current_platform.get_device_capability() + if device_capability is None: + return ["FLASH_ATTN"] + + attn_selector_config = AttentionSelectorConfig( + head_size=128, + dtype=torch.bfloat16, + kv_cache_dtype=None, + block_size=None, + use_mla=False, + has_sink=False, + use_sparse=False, + use_mm_prefix=False, + ) + + valid_backends, _ = current_platform.get_valid_backends( + device_capability=device_capability, + attn_selector_config=attn_selector_config, + ) + + return [ + backend.name + for backend, _ in valid_backends + if backend.name not in EXCLUDED_BACKENDS + ] + + +def get_attention_backend_params() -> list[pytest.param]: + """Generate pytest params for available attention backends.""" + available = get_available_attention_backends() + return [pytest.param(backend, id=backend.lower()) for backend in available] + + +def get_tp_size_params() -> list[pytest.param]: + """Generate pytest params for TP sizes based on available GPUs.""" + num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 + return [pytest.param(tp, id=f"tp{tp}") for tp in TP_SIZES if tp <= num_gpus] def get_mt_bench_prompts( @@ -145,14 +199,19 @@ def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> AcceptanceMetri [pytest.param(config, id=config.id) for config in EAGLE3_MODEL_CONFIGS], ) @pytest.mark.parametrize("num_spec_tokens", [DEFAULT_NUM_SPEC_TOKENS]) 
+@pytest.mark.parametrize("tp_size", get_tp_size_params()) +@pytest.mark.parametrize("attention_backend", get_attention_backend_params()) def test_eagle3_acceptance_length( model_config: Eagle3ModelConfig, num_spec_tokens: int, + tp_size: int, + attention_backend: str, monkeypatch: pytest.MonkeyPatch, ): """Test EAGLE3 acceptance length does not regress.""" with monkeypatch.context() as m: m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attention_backend) with VllmRunner( model_name=model_config.verifier, @@ -161,8 +220,8 @@ def test_eagle3_acceptance_length( "model": model_config.drafter, "num_speculative_tokens": num_spec_tokens, }, - tensor_parallel_size=1, - gpu_memory_utilization=0.9, + tensor_parallel_size=tp_size, + gpu_memory_utilization=0.7, disable_log_stats=False, max_model_len=DEFAULT_MAX_MODEL_LEN, ) as vllm_runner: @@ -213,7 +272,8 @@ def test_eagle3_acceptance_length( ) print( - f"\n{model_config.id}: acceptance_length={actual_acceptance_length:.3f}" + f"\n{model_config.id} [tp={tp_size}, backend={attention_backend}]: " + f"acceptance_length={actual_acceptance_length:.3f}" f" (expected={expected:.3f}, rel_error={rel_error:.2%})" ) print(f" Per-position: {[f'{v:.3f}' for v in actual_per_pos]}") From 5db9a9af31356a5a902d0a3b440426b253016020 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Mon, 19 Jan 2026 14:56:20 +0000 Subject: [PATCH 07/11] Cleanups Signed-off-by: rahul-tuli --- .../v1/spec_decode/test_acceptance_length.py | 47 +++++++++---------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py index 207f1ed279f4..0f908bee452c 100644 --- a/tests/v1/spec_decode/test_acceptance_length.py +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -11,7 +11,6 @@ from dataclasses import dataclass, field from types import SimpleNamespace -from typing import TypedDict import pytest import torch @@ -22,17 +21,11 
@@ from vllm.benchmarks.datasets import get_samples from vllm.inputs import TokensPrompt from vllm.platforms import current_platform +from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.metrics.reader import Counter, Vector -class AcceptanceMetrics(TypedDict): - acceptance_length: float - acceptance_lengths_per_pos: list[float] - num_drafts: int - num_accepted_tokens: int - - @dataclass class Eagle3ModelConfig: verifier: str @@ -40,6 +33,8 @@ class Eagle3ModelConfig: expected_acceptance_length: float expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list) id: str = "" + # Backends that are incompatible with this model (will be skipped) + excluded_backends: set[AttentionBackendEnum] = field(default_factory=set) # Model configurations for EAGLE3 acceptance length tests. @@ -66,6 +61,9 @@ class Eagle3ModelConfig: expected_acceptance_length=2.56, expected_acceptance_lengths_per_pos=[0.7165, 0.5120, 0.3337], id="gpt-oss-20b-eagle3", + # FLASHINFER incompatible: gpt-oss-20b uses sink attention which + # FLASHINFER does not support ("sink setting not supported") + excluded_backends={AttentionBackendEnum.FLASHINFER}, ), ] @@ -81,11 +79,10 @@ class Eagle3ModelConfig: # Backends excluded from testing due to significantly different behavior -EXCLUDED_BACKENDS = {"FLEX_ATTENTION"} +EXCLUDED_BACKENDS = {AttentionBackendEnum.FLEX_ATTENTION} def get_available_attention_backends() -> list[str]: - """Get list of available attention backends for the current platform.""" if not hasattr(current_platform, "get_valid_backends"): return ["FLASH_ATTN"] @@ -112,18 +109,15 @@ def get_available_attention_backends() -> list[str]: return [ backend.name for backend, _ in valid_backends - if backend.name not in EXCLUDED_BACKENDS + if backend not in EXCLUDED_BACKENDS ] -def get_attention_backend_params() -> list[pytest.param]: - """Generate pytest params for available attention 
backends.""" - available = get_available_attention_backends() - return [pytest.param(backend, id=backend.lower()) for backend in available] +def get_attention_backend_params() -> list[str]: + return get_available_attention_backends() def get_tp_size_params() -> list[pytest.param]: - """Generate pytest params for TP sizes based on available GPUs.""" num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 return [pytest.param(tp, id=f"tp{tp}") for tp in TP_SIZES if tp <= num_gpus] @@ -157,8 +151,7 @@ def get_mt_bench_prompts( return prompt_ids -def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> AcceptanceMetrics: - """Extract acceptance length metrics from LLM metrics.""" +def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> dict: num_drafts = 0 num_accepted_tokens = 0 acceptance_counts = [0] * num_spec_tokens @@ -185,12 +178,12 @@ def extract_acceptance_metrics(metrics, num_spec_tokens: int) -> AcceptanceMetri count / num_drafts if num_drafts > 0 else 0.0 for count in acceptance_counts ] - return AcceptanceMetrics( - acceptance_length=acceptance_length, - acceptance_lengths_per_pos=acceptance_lengths_per_pos, - num_drafts=num_drafts, - num_accepted_tokens=num_accepted_tokens, - ) + return { + "acceptance_length": acceptance_length, + "acceptance_lengths_per_pos": acceptance_lengths_per_pos, + "num_drafts": num_drafts, + "num_accepted_tokens": num_accepted_tokens, + } @large_gpu_mark(min_gb=40) @@ -208,7 +201,11 @@ def test_eagle3_acceptance_length( attention_backend: str, monkeypatch: pytest.MonkeyPatch, ): - """Test EAGLE3 acceptance length does not regress.""" + # Skip if this backend is incompatible with the model + backend_enum = AttentionBackendEnum[attention_backend] + if backend_enum in model_config.excluded_backends: + pytest.skip(f"{attention_backend} is incompatible with {model_config.id}") + with monkeypatch.context() as m: m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") m.setenv("VLLM_ATTENTION_BACKEND", 
attention_backend) From 9803cc798b7e78dcc7361aee2a31c57c0de43c04 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Mon, 19 Jan 2026 15:01:54 +0000 Subject: [PATCH 08/11] Cleanups Signed-off-by: rahul-tuli --- tests/v1/spec_decode/test_acceptance_length.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/v1/spec_decode/test_acceptance_length.py b/tests/v1/spec_decode/test_acceptance_length.py index 0f908bee452c..1a615878bb8b 100644 --- a/tests/v1/spec_decode/test_acceptance_length.py +++ b/tests/v1/spec_decode/test_acceptance_length.py @@ -125,7 +125,6 @@ def get_tp_size_params() -> list[pytest.param]: def get_mt_bench_prompts( tokenizer, num_prompts: int = DEFAULT_NUM_PROMPTS ) -> list[list[int]]: - """Load prompts from MT-Bench dataset.""" args = SimpleNamespace( dataset_name="hf", dataset_path="philschmid/mt-bench", From 13f6630a9ea78bee4bd80bb6e842e55e374eec9a Mon Sep 17 00:00:00 2001 From: YiSheng5 Date: Tue, 20 Jan 2026 22:27:24 +0800 Subject: [PATCH 09/11] [XPU]Support AgRsAll2AllManager on XPU device (#32654) Signed-off-by: yisheng --- .../device_communicators/xpu_communicator.py | 137 +++++++++++++++++- 1 file changed, 130 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index f3d9262d20cf..6bc26b6f3b1c 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -23,23 +23,146 @@ def __init__( ): super().__init__(cpu_group, device, device_group, unique_name) if self.use_all2all: - if self.all2all_backend != "naive": # type: ignore[has-type] - logger.warning( - "`%s` all2all manager is not supported on XPU. 
" - "Falling back to `naive` all2all manager for XPU.", - self.all2all_backend, # type: ignore[has-type] - ) - self.all2all_backend = "naive" if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) logger.info("Using naive all2all manager.") + elif self.all2all_backend == "allgather_reducescatter": + from .all2all import AgRsAll2AllManager + + self.all2all_manager = AgRsAll2AllManager(self.cpu_group) + logger.info("Using AgRs manager on XPU device.") + + else: # type: ignore[has-type] + logger.warning( + "`%s` all2all manager is not supported on XPU. " + "Falling back to AgRs manager for XPU, " + "which is the Default backend", + self.all2all_backend, # type: ignore[has-type] + ) + from .all2all import AgRsAll2AllManager + + self.all2all_manager = AgRsAll2AllManager(self.cpu_group) + logger.info("Using AgRs manager on XPU device.") + def all_reduce(self, input_) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + world_size = self.world_size + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size,) + input_tensor.shape[1:] + + output = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) + + dist.reduce_scatter_tensor(output, input_tensor) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + + def reduce_scatterv( + self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None + ): + world_size = self.world_size + + if dim < 0: + # Convert negative dim to positive. 
+ dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + if sizes is not None: + assert len(sizes) == world_size + assert input_tensor.shape[0] == sum(sizes) + chunk_size = sizes[self.rank_in_group] + else: + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size,) + input_tensor.shape[1:] + + output = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) + if sizes is not None and sizes.count(sizes[0]) != len(sizes): + # Per-rank sizes differ, so use the list-based reduce_scatter + input_splits = list(input_tensor.split(sizes, dim=0)) + dist.reduce_scatter(output, input_splits) + else: + dist.reduce_scatter_tensor(output, input_tensor) + # Reshape before returning + return output.movedim(0, dim).contiguous() + + def all_gatherv( + self, + input_: torch.Tensor | list[torch.Tensor], + dim: int = 0, + sizes: list[int] | None = None, + ): + if dim != 0: + raise NotImplementedError("only dim 0 all-gatherv is supported") + world_size = self.world_size + + # 'sizes' is not needed if all inputs in the same group have the same + # shape + if sizes is not None and all(s == sizes[0] for s in sizes): + sizes = None + + def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None): + input_size = input_.size() + if sizes is not None: + assert len(sizes) == world_size + assert input_.shape[dim] == sizes[self.rank_in_group], ( + f"{input_.shape[dim]} != {sizes[self.rank_in_group]}" + ) + output_size = (sum(sizes),) + input_size[1:] + else: + output_size = (input_size[0] * world_size,) + input_size[1:] + # Allocate output tensor. 
+ output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + + if sizes is not None: + all_gather_list = [] + for size in sizes: + all_gather_list.append( + torch.empty( + (size,) + input_.shape[1:], + dtype=input_.dtype, + device=input_.device, + ) + ) + dist.all_gather(all_gather_list, input_) + output_tensor = torch.cat(all_gather_list, dim=0) + else: + dist.all_gather([output_tensor], input_) + return output_tensor + + if isinstance(input_, torch.Tensor): + return _all_gather_single(input_, sizes) + + output_list = [] + for inp in input_: + output_list.append(_all_gather_single(inp, sizes=sizes)) + return output_list + def gather( self, input_: torch.Tensor, dst: int = 0, dim: int = -1 ) -> torch.Tensor | None: From 7901109ea5c9794a5e7e01481343235e3a042971 Mon Sep 17 00:00:00 2001 From: linhaifeng <1371675203@qq.com> Date: Wed, 21 Jan 2026 00:13:39 +0800 Subject: [PATCH 10/11] [Bugfix] Fix Off-by-one error in _num_tokens_to_min_blocks calculation (#32603) Signed-off-by: linhaifeng <1371675203@qq.com> --- tests/kernels/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ccdacf40c430..7763be0cb5bf 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -609,7 +609,7 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: Compute the minimum number of blocks required to hold num_tokens tokens, given block_size """ - return (num_tokens + block_size) // block_size + return (num_tokens + block_size - 1) // block_size def make_empty_slot_mapping_tensor(device: torch.device | str): @@ -694,7 +694,7 @@ def make_block_tables_slot_mapping( For a sequence with num_tokens tokens the minimum number of required KV cache blocks is - num_blocks = (num_tokens + block_size) // block_size + num_blocks = (num_tokens + block_size - 1) // block_size Then the minimum KV cache size in blocks is From 4ca62a0dbdba8b8aaf40438c99d20cdb56db8d5d Mon 
Sep 17 00:00:00 2001 From: whx <56632993+whx-sjtu@users.noreply.github.com> Date: Wed, 21 Jan 2026 00:19:21 +0800 Subject: [PATCH 11/11] [PluggableLayer][1/N] Define PluggableLayer (#32331) Signed-off-by: whx-sjtu <2952154980@qq.com> --- docs/design/custom_op.md | 9 -- .../model_executor/test_enabled_custom_ops.py | 10 +- vllm/model_executor/custom_op.py | 100 +++++++++++++++--- vllm/model_executor/layers/mla.py | 19 ++-- 4 files changed, 99 insertions(+), 39 deletions(-) diff --git a/docs/design/custom_op.md b/docs/design/custom_op.md index 13c2915abe8f..3f4934b15699 100644 --- a/docs/design/custom_op.md +++ b/docs/design/custom_op.md @@ -8,15 +8,6 @@ This document will introduce how CustomOp works in vLLM and how to implement a n `CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively. -??? code - - ```python - class CustomOp(nn.Module): - - op_registry: dict[str, type["CustomOp"]] = {} - op_registry_oot: dict[str, type["CustomOp"]] = {} - ``` - We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, We can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later. When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method. 
diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 8ee1b1a37ca6..316caf06b29c 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -11,7 +11,7 @@ get_cached_compilation_config, set_current_vllm_config, ) -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import CustomOp, op_registry from vllm.model_executor.layers.activation import ( GeluAndMul, ReLUSquaredActivation, @@ -98,17 +98,17 @@ def test_enabled_ops( ops_enabled = [bool(x) for x in ops_enabled] assert RMSNorm(1024).enabled() == ops_enabled[0] - assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + assert op_registry["rms_norm"].enabled() == ops_enabled[0] assert SiluAndMul().enabled() == ops_enabled[1] - assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] + assert op_registry["silu_and_mul"].enabled() == ops_enabled[1] assert GeluAndMul().enabled() == ops_enabled[2] - assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2] # If registered, subclasses should follow their own name assert Relu3().enabled() == ops_enabled[3] - assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] + assert op_registry["relu3"].enabled() == ops_enabled[3] # Unregistered subclass class SiluAndMul2(SiluAndMul): diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 81ba544b4813..6fe252fa27ee 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -11,6 +11,86 @@ logger = init_logger(__name__) +# Dictionary of all custom ops (classes, indexed by registered name). +# To check if an op with a name is enabled, call .enabled() on the class. 
+# Examples: +# - MyOp.enabled() +# - op_registry["my_op"].enabled() +op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} +op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} + + +class PluggableLayer(nn.Module): + """ + Base class for pluggable layers. + + A PluggableLayer is a *module-composing* abstraction: it may instantiate other + ``torch.nn.Module`` objects as sub-layers, and its functionality depends on + these sub-layers following a generalized invocation sequence. Also, it is stateful + and may hold parameters or buffers. + + Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform + ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement + of the entire layer class at instantiation time, allowing customized + initialization and submodule composition. + """ + + def __new__(cls, *args, **kwargs): + try: + layer_class_name = cls.__name__ + except AttributeError: + raise TypeError( + f"Cannot instantiate '{cls.__name__}': its 'name' attribute " + f"was not set, possibly because it was not decorated with " + f"@PluggableLayer.register, or it's the PluggableLayer itself." + ) from None + + if layer_class_name not in op_registry_oot: + layer_cls_to_instantiate = cls + else: + layer_cls_to_instantiate = op_registry_oot[layer_class_name] + logger.debug( + "Instantiating pluggable layer: %s using %s", + layer_class_name, + str(layer_cls_to_instantiate), + ) + return super().__new__(layer_cls_to_instantiate) + + # Decorator to register pluggable layers. + @classmethod + def register(cls, name: str): + def decorator(op_cls): + assert name not in op_registry, f"Duplicate op name: {name}" + op_cls.name = name + op_registry[name] = op_cls + return op_cls + + return decorator + + # Decorator to register out-of-tree(oot) pluggable layers. + # For OOT pluggable layers: + # if in-tree layer class is registered with an oot_custom_layer, + # the oot_custom_layer will be used instead. 
+ @classmethod + def register_oot(cls, _decorated_layer_cls=None, name: str | None = None): + def decorator(layer_cls): + reg_name = name if name is not None else cls.__name__ + assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}" + layer_cls.name = reg_name + op_registry_oot[reg_name] = layer_cls + return layer_cls + + if _decorated_layer_cls is None: + # Called with parentheses: @PluggableLayer.register_oot() + # or @PluggableLayer.register_oot(name="...") + return decorator + elif isinstance(_decorated_layer_cls, type): # Check if it's a class + # Called without parentheses: @PluggableLayer.register_oot + return decorator(_decorated_layer_cls) + else: + raise TypeError("Decorator can only be applied to classes.") + + class CustomOp(nn.Module): """ Base class for custom ops. @@ -27,10 +107,10 @@ def __new__(cls, *args, **kwargs): f"@CustomOp.register, or it's the CustomOp base class itself." ) from None - if op_name not in cls.op_registry_oot: + if op_name not in op_registry_oot: op_cls_to_instantiate = cls else: - op_cls_to_instantiate = cls.op_registry_oot[op_name] + op_cls_to_instantiate = op_registry_oot[op_name] logger.debug( "Instantiating custom op: %s using %s", op_name, @@ -150,21 +230,13 @@ def default_on() -> bool: return not count_none > 0 or count_all > 0 - # Dictionary of all custom ops (classes, indexed by registered name). - # To check if an op with a name is enabled, call .enabled() on the class. - # Examples: - # - MyOp.enabled() - # - op_registry["my_op"].enabled() - op_registry: dict[str, type["CustomOp"]] = {} - op_registry_oot: dict[str, type["CustomOp"]] = {} - # Decorator to register custom ops. 
@classmethod def register(cls, name: str): def decorator(op_cls): - assert name not in cls.op_registry, f"Duplicate op name: {name}" + assert name not in op_registry, f"Duplicate op name: {name}" op_cls.name = name - cls.op_registry[name] = op_cls + op_registry[name] = op_cls return op_cls return decorator @@ -182,9 +254,9 @@ def decorator(op_cls): def register_oot(cls, _decorated_op_cls=None, name: str | None = None): def decorator(op_cls): reg_name = name if name is not None else cls.__name__ - assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" + assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}" op_cls.name = reg_name - cls.op_registry_oot[reg_name] = op_cls + op_registry_oot[reg_name] = op_cls return op_cls if _decorated_op_cls is None: diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 65541d2a485a..2549f1221f36 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -6,7 +6,7 @@ from vllm.attention.layer import MLAAttention from vllm.config import CacheConfig -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.quantization import QuantizationConfig @@ -30,13 +30,13 @@ class MLAModules: # --8<-- [start:multi_head_latent_attention] -@CustomOp.register("multi_head_latent_attention") -class MultiHeadLatentAttentionWrapper(CustomOp): - """MLA layer registered as CustomOp to allow OOT backends to add +@PluggableLayer.register("multi_head_latent_attention") +class MultiHeadLatentAttentionWrapper(PluggableLayer): + """Pluggable MLA layer which allows OOT backends to add custom implementations of the outer MLA layer (including rope & o_proj). - Note that currently MLA ignores the enable/disable mechanism of CustomOp - because there is only one in-tree implementation in forward_native. - TODO: implement this with a new PluggableLayer mechanism. 
+ Note that currently oot platforms can still use CustomOp.register_oot to + replace MLA layer entirly, although we use PluggableLayer to register + this layer now. This class takes positions and hidden_states as input. The input tensors can either contain prefill tokens or decode tokens. @@ -110,7 +110,7 @@ def __init__( self.prefix = prefix - def forward_native( + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -174,6 +174,3 @@ def forward_native( ) return self.o_proj(attn_out)[0] - - def forward_cuda(self, *args, **kwargs): - return self.forward_native(*args, **kwargs)