diff --git a/.gitignore b/.gitignore index 30372720..86e1584b 100644 --- a/.gitignore +++ b/.gitignore @@ -197,3 +197,6 @@ docs/superpowers/ # User-specific local dev configs; do not commit CLAUDE.local.md + +# Generated dataset cache (created by Dataset.get_dataloader()) +dataset_cache/ diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py index 50a3163d..0ef46e2a 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py @@ -17,6 +17,7 @@ import argparse import asyncio +import logging from contextlib import AbstractContextManager, nullcontext from pathlib import Path @@ -91,10 +92,16 @@ async def main() -> None: # Using ternary operator causes errors in MyPy object type coalescing # (coalesces to 'object' not 'AbstractContextManager[TokenizePool | None]') + pool_cm: AbstractContextManager[TokenizePool | None] if args.tokenizer: - pool_cm: AbstractContextManager[TokenizePool | None] = TokenizePool( - args.tokenizer, n_workers=args.tokenizer_workers - ) + try: + pool_cm = TokenizePool(args.tokenizer, n_workers=args.tokenizer_workers) + except Exception as e: + logging.warning( + f"Failed to load tokenizer '{args.tokenizer}': {e}. " + "ISL/OSL/TPOT token metrics will be unavailable." + ) + pool_cm = nullcontext() else: pool_cm = nullcontext() diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 73c3427f..9641c7a8 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -27,11 +27,13 @@ import json import logging import platform +import random import shutil import signal import tempfile import uuid from dataclasses import dataclass, field +from dataclasses import replace as dataclass_replace from datetime import datetime from pathlib import Path from typing import Any @@ -70,7 +72,7 @@ TestType, ) from inference_endpoint.core.types import QueryResult -from inference_endpoint.dataset_manager.dataset import Dataset +from inference_endpoint.dataset_manager.dataset import Dataset, SaltedDataset from inference_endpoint.dataset_manager.factory import DataLoaderFactory from inference_endpoint.endpoint_client.cpu_affinity import AffinityPlan, pin_loadgen from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient @@ -347,6 +349,33 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: """Build the phase list from BenchmarkContext.""" phases: list[PhaseConfig] = [] + # Warmup phase (optional, before performance) + warmup_cfg = ctx.config.settings.warmup + if warmup_cfg.enabled: + warmup_dataset: Dataset = ( + SaltedDataset(ctx.dataloader) if warmup_cfg.salt else ctx.dataloader + ) + warmup_rt = dataclass_replace( + ctx.rt_settings, + min_duration_ms=0, + max_duration_ms=None, + n_samples_from_dataset=ctx.dataloader.num_samples(), + n_samples_to_issue=warmup_cfg.n_requests, + min_sample_count=1, + rng_sched=random.Random(warmup_cfg.warmup_random_seed), + rng_sample_index=random.Random(warmup_cfg.warmup_random_seed + 1), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phases.append( + PhaseConfig( + "warmup", + warmup_rt, + warmup_dataset, + PhaseType.WARMUP, + drain_after=warmup_cfg.drain, + ) + ) + # Performance phase phases.append( PhaseConfig( @@ -525,12 +554,31 @@ async def 
_run_benchmark_async( phases = _build_phases(ctx) report: Report | None = None + # Global wall-clock timeout covers warmup + performance + accuracy phases + # combined, and bounds the warmup drain so a dropped request can't hang forever. + global_timeout_handle = None + max_duration_ms = ctx.rt_settings.max_duration_ms + if max_duration_ms is not None: + + def _on_global_timeout() -> None: + logger.warning( + "Global experiment timeout reached (%d ms); stopping session.", + max_duration_ms, + ) + session.stop() + + global_timeout_handle = loop.call_later( + max_duration_ms / 1000.0, _on_global_timeout + ) + loop.add_signal_handler(signal.SIGINT, session.stop) try: result = await session.run(phases) except Exception as e: raise ExecutionError(f"Benchmark execution failed: {e}") from e finally: + if global_timeout_handle is not None: + global_timeout_handle.cancel() loop.remove_signal_handler(signal.SIGINT) logger.info("Cleaning up...") try: diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 6a1884b4..c53e6e44 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -392,6 +392,30 @@ def _validate_completeness(self) -> Self: return self +class WarmupConfig(BaseModel): + """Warmup phase configuration. Runs before the performance phase; results are not recorded.""" + + model_config = ConfigDict(extra="forbid", frozen=True) + + enabled: bool = Field( + False, description="Enable warmup phase before performance run" + ) + n_requests: int | None = Field( + None, gt=0, description="Warmup request count (None = full dataset once)" + ) + salt: bool = Field( + False, description="Prepend a unique random hex salt to each warmup prompt" + ) + drain: bool = Field( + False, + description="Drain in-flight warmup requests before starting the performance phase", + ) + warmup_random_seed: int = Field( + 42, + description="RNG seed for warmup scheduling and sample ordering", + ) + + @cyclopts.Parameter(name="*") class Settings(BaseModel): """Test settings.""" @@ -401,6 +425,7 @@ class Settings(BaseModel): runtime: RuntimeConfig = Field(default_factory=RuntimeConfig) load_pattern: LoadPattern = Field(default_factory=LoadPattern) client: HTTPClientConfig = Field(default_factory=HTTPClientConfig) + warmup: WarmupConfig = Field(default_factory=WarmupConfig) class OfflineSettings(Settings): diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index 3a8e004f..f39db1a7 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -68,6 +68,12 @@ settings: max_idle_time: 4.0 # Discard connections idle longer than this (seconds) min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system + warmup: + enabled: false # Enable warmup phase before performance run + n_requests: null # Warmup request count (None = full dataset once) + salt: false # Prepend a unique random hex salt to each warmup prompt + drain: false # Drain in-flight warmup requests before starting the performance phase + warmup_random_seed: 42 # RNG seed for warmup scheduling and sample ordering endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. 
- http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index faabffde..71d4efea 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -68,6 +68,12 @@ settings: max_idle_time: 4.0 # Discard connections idle longer than this (seconds) min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system + warmup: + enabled: false # Enable warmup phase before performance run + n_requests: null # Warmup request count (None = full dataset once) + salt: false # Prepend a unique random hex salt to each warmup prompt + drain: false # Drain in-flight warmup requests before starting the performance phase + warmup_random_seed: 42 # RNG seed for warmup scheduling and sample ordering endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index e9b7a673..83d79c4e 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -68,6 +68,12 @@ settings: max_idle_time: 4.0 # Discard connections idle longer than this (seconds) min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system + warmup: + enabled: false # Enable warmup phase before performance run + n_requests: null # Warmup request count (None = full dataset once) + salt: false # Prepend a unique random hex salt to each warmup prompt + drain: false # Drain in-flight warmup requests before starting the performance phase + warmup_random_seed: 42 # RNG seed for warmup scheduling and sample ordering endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/dataset_manager/dataset.py b/src/inference_endpoint/dataset_manager/dataset.py index 93e3cfa6..bd5f12f0 100644 --- a/src/inference_endpoint/dataset_manager/dataset.py +++ b/src/inference_endpoint/dataset_manager/dataset.py @@ -276,11 +276,12 @@ class Dataset: def __init_subclass__( cls, dataset_id: str | None = None, + register: bool = True, **kwargs, ): super().__init_subclass__(**kwargs) - if not inspect.isabstract(cls): + if register and not inspect.isabstract(cls): if dataset_id is None: dataset_id = cls.__name__ cls.DATASET_ID = dataset_id @@ -411,7 +412,7 @@ def num_samples(self) -> int: @classmethod def get_dataloader( cls, - datasets_dir: Path = Path("datasets"), + datasets_dir: Path = Path("dataset_cache"), num_repeats: int = 1, transforms: list[Transform] | None = None, force_regenerate: bool = False, @@ -431,6 +432,74 @@ def get_dataloader( return cls(df, transforms=transforms, repeats=num_repeats) +class SaltedDataset(Dataset, register=False): + """Wraps a loaded Dataset, prepending a unique random salt to each prompt on load_sample(). + + Each call to load_sample() generates a fresh salt, so reused samples (when + n_requests > dataset size) each receive a distinct salt. 
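+
+    Example (illustrative sketch; assumes ``pandas as pd`` is imported, and
+    the 16-hex-char salt differs on every call)::
+
+        inner = Dataset(pd.DataFrame({"prompt": ["hello"]}))
+        inner.load()
+        salted = SaltedDataset(inner)
+        salted.load_sample(0)["prompt"]  # e.g. "[3f9c2a1b0d4e5678] hello"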
+ """ + + def __init__(self, inner: Dataset) -> None: + # Skip Dataset.__init__ — all state is delegated to inner + self._inner = inner + self.dataframe = None + self.transforms = None + self.repeats = inner.repeats + self.logger = getLogger(__name__) + + @property # type: ignore[override] + def data(self) -> list[Any] | None: + return self._inner.data + + @data.setter + def data(self, value: list[Any] | None) -> None: + self._inner.data = value + + def load( + self, + adapter: "HttpRequestAdapter | None" = None, + api_type: APIType | None = None, + model_params: ModelParams | None = None, + force: bool = False, + ) -> None: + pass # Inner dataset already loaded + + def load_sample(self, index: int) -> Any: + data = self._inner.load_sample(index) + if not isinstance(data, dict): + return data + if "input_tokens" in data and "prompt" not in data: + self.logger.warning( + "SaltedDataset: sample has 'input_tokens' but no 'prompt' — " + "salt cannot be applied to pre-tokenized input; KV-cache reuse may not be prevented" + ) + return data + if "prompt" not in data: + return data + prompt = data["prompt"] + salt = os.urandom(8).hex() + if isinstance(prompt, str): + return {**data, "prompt": f"[{salt}] {prompt}"} + if isinstance(prompt, list) and prompt: + # Find the first text part at any index (image-first prompts place text at index 1+) + for i, part in enumerate(prompt): + if isinstance(part, dict) and part.get("type") == "text": + salted_parts = [ + *prompt[:i], + {**part, "text": f"[{salt}] {part['text']}"}, + *prompt[i + 1 :], + ] + return {**data, "prompt": salted_parts} + self.logger.warning( + "SaltedDataset: multimodal prompt has no text part — " + "salt cannot be applied; KV-cache reuse may not be prevented" + ) + return data # unsupported prompt type — skip salting + + def num_samples(self) -> int: + return self._inner.num_samples() + + class EmptyDataset(Dataset): """Empty dataset to be used as performance dataset when running only accuracy tests.""" diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 1c8ad992..c324f976 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -22,7 +22,6 @@ import asyncio import logging -import os import time import uuid from collections.abc import Callable @@ -44,8 +43,6 @@ logger = logging.getLogger(__name__) -_WARMUP_ENABLED = os.environ.get("ENABLE_WARMUP") == "1" - # --------------------------------------------------------------------------- # Phase configuration @@ -68,6 +65,7 @@ class PhaseConfig: runtime_settings: RuntimeSettings dataset: Dataset phase_type: PhaseType = PhaseType.PERFORMANCE + drain_after: bool = True # --------------------------------------------------------------------------- @@ -242,6 +240,7 @@ def __init__( self._stop_requested = False self._done = False self._current_phase_issuer: PhaseIssuer | None = None + self._current_phase_type: PhaseType | None = None self._current_strategy: LoadStrategy | None = None self._recv_task: asyncio.Task | None = None self._strategy_task: asyncio.Task | None = None @@ -274,12 +273,6 @@ async def run(self, phases: list[PhaseConfig]) -> SessionResult: for phase in phases: if self._stop_requested: break - if phase.phase_type == PhaseType.WARMUP and not _WARMUP_ENABLED: - logger.info( - "Skipping warmup phase %s (set ENABLE_WARMUP=1 to enable)", - phase.name, - ) - continue result = await self._run_phase(phase) if result is not None: phase_results.append(result) 
@@ -318,6 +311,7 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: ) self._current_phase_issuer = phase_issuer + self._current_phase_type = phase.phase_type self._current_strategy = strategy # Performance phases get tracking events @@ -333,8 +327,7 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: finally: self._strategy_task = None - # Drain in-flight (skip for warmup — keep concurrency hot) - if phase.phase_type != PhaseType.WARMUP: + if phase.drain_after: await self._drain_inflight(phase_issuer) if phase.phase_type == PhaseType.PERFORMANCE: @@ -363,9 +356,9 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: async def _drain_inflight(self, phase_issuer: PhaseIssuer) -> None: """Wait for all in-flight responses from this phase to complete. - Currently, there is no timeout for the drain step. In the future, - we can possibly add a dynamic timeout based on the rate of completion - throughout the current phase.""" + Bounded by the global experiment timeout: if the caller schedules a + loop.call_later that calls stop(), stop() sets _drain_event, unblocking + this wait without leaving it hung indefinitely.""" if phase_issuer.inflight <= 0 or self._stop_requested: return logger.info("Draining %d in-flight responses...", phase_issuer.inflight) @@ -425,7 +418,10 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: self._drain_event.set() if self._current_strategy: self._current_strategy.on_query_complete(query_id) - if self._on_sample_complete: + if ( + self._on_sample_complete + and self._current_phase_type != PhaseType.WARMUP + ): self._on_sample_complete(resp) elif isinstance(resp, StreamChunk): diff --git a/src/inference_endpoint/testing/echo_server.py b/src/inference_endpoint/testing/echo_server.py index 6555f2e6..a0049d1f 100644 --- a/src/inference_endpoint/testing/echo_server.py +++ b/src/inference_endpoint/testing/echo_server.py @@ -17,12 +17,14 @@ import argparse import asyncio +import inspect import json import logging import threading import time import uuid from abc import abstractmethod +from collections.abc import Awaitable, Callable from aiohttp import web @@ -31,6 +33,8 @@ from inference_endpoint.openai.openai_types_gen import CreateChatCompletionRequest from inference_endpoint.utils.logging import setup_logging +RequestHandler = Callable[[web.Request], web.Response | Awaitable[web.Response]] + class HTTPServer: @property @@ -49,11 +53,17 @@ def stop(self): class EchoServer(HTTPServer): def __init__( - self, *, host: str = "127.0.0.1", port: int = 0, max_osl: int | None = None + self, + *, + host: str = "127.0.0.1", + port: int = 0, + max_osl: int | None = None, + request_handler: RequestHandler | None = None, ): self.host = host self.port = port # If 0, will auto-assign available port self.max_osl = max_osl + self._request_handler = request_handler self._actual_port = None # Store the actual port after binding self.app = None @@ -97,6 +107,15 @@ def get_response(self, request: str) -> str: """ return request + async def _dispatch(self, request: web.Request) -> web.Response | None: + """Call the custom request_handler if set; return None to fall through to defaults.""" + if self._request_handler is None: + return None + result = self._request_handler(request) + if inspect.isawaitable(result): + result = await result + return result + async def _handle_echo_request(self, request: web.Request) -> web.Response: """ Handle a generic HTTP request and return a JSON response that echoes all request 
details. @@ -105,6 +124,10 @@ async def _handle_echo_request(self, request: web.Request) -> web.Response: Returns a standardized JSON response containing the full request details and a success message. """ + custom = await self._dispatch(request) + if custom is not None: + return custom + # Extract request data endpoint = request.path query_params = dict(request.query) @@ -241,6 +264,9 @@ async def _handle_echo_chat_completions_request( Raises: ValueError: If no messages are present in the request payload. """ + custom = await self._dispatch(request) + if custom is not None: + return custom # Get request body try: diff --git a/tests/conftest.py b/tests/conftest.py index d69d3c75..c70bc4bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,7 +34,11 @@ from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat from inference_endpoint.dataset_manager.transforms import ColumnRemap from inference_endpoint.testing.docker_server import DockerServer -from inference_endpoint.testing.echo_server import EchoServer, HTTPServer +from inference_endpoint.testing.echo_server import ( + EchoServer, + HTTPServer, + RequestHandler, +) logger = logging.getLogger(__name__) # Add src to path for imports @@ -72,38 +76,65 @@ def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: @pytest.fixture(scope="function") -def mock_http_echo_server(): +def mock_http_echo_server_factory(): + """Factory fixture for EchoServer with an optional custom request handler. + + Yields a factory callable: ``factory(handler=None) -> EchoServer``. + + When called with no arguments the server uses the default echo behaviour. + When called with a ``handler`` the handler is invoked for every incoming + request *instead of* the built-in echo logic. The handler may be a plain + (sync) callable or an async coroutine function; in both cases it receives + the ``aiohttp.web.Request`` and must return an ``aiohttp.web.Response``. + + All servers created by the factory are stopped automatically at the end of + the test. + + Example:: + + def test_custom(mock_http_echo_server_factory): + server = mock_http_echo_server_factory() # default echo + custom = mock_http_echo_server_factory( # custom handler + lambda req: web.json_response({"ok": True}) + ) """ - Mock HTTP server that echoes back the request payload in the appropriate format. + servers: list[EchoServer] = [] + + def factory(handler: RequestHandler | None = None) -> EchoServer: + server = EchoServer(port=0, request_handler=handler) + logging.info("Starting mock HTTP echo server") + server.start() + servers.append(server) + return server + + yield factory + + for server in servers: + logging.info("Stopping mock HTTP echo server") + server.stop() + + +@pytest.fixture(scope="function") +def mock_http_echo_server(mock_http_echo_server_factory): + """Mock HTTP server that echoes back the request payload in the appropriate format. This fixture creates a real HTTP server running on localhost that captures any HTTP request and returns the request payload as the response. Useful for testing HTTP clients with real network calls but controlled responses. + For a server with a custom per-request handler, use + ``mock_http_echo_server_factory`` instead. + Returns: - A server instance with URL. + A running ``EchoServer`` instance. 
+ + Example:: - Example: def test_my_http_client(mock_http_echo_server): - server = mock_http_echo_server - # Make real HTTP requests to server.url - # The response will contain the exact payload you sent + # Make real HTTP requests to mock_http_echo_server.url + # The response will contain the exact payload you sent. """ - - # Create and start the server with dynamic port allocation (port=0) - - try: - server = EchoServer(port=0) - logging.info("Starting mock HTTP echo server") - server.start() - yield server - except Exception as e: - logging.error(f"Mock Echo Server error: {e}") - raise RuntimeError(f"Mock Echo Server error: {e}") from e - finally: - logging.info("Stopping mock HTTP echo server") - if server: - server.stop() + yield mock_http_echo_server_factory() @pytest.fixture diff --git a/tests/integration/commands/test_warmup.py b/tests/integration/commands/test_warmup.py new file mode 100644 index 00000000..4622365e --- /dev/null +++ b/tests/integration/commands/test_warmup.py @@ -0,0 +1,435 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for warmup phase behaviour. + +Covers two properties: + +Salt + When ``salt=True`` the server receives distinct prompts for warmup vs perf + (preventing KV-cache reuse). When ``salt=False`` the prompts are identical. + +Drain + When ``drain=True`` all warmup responses complete before perf requests start + (zero concurrent overlap at the server). When ``drain=False`` the perf + phase starts immediately, so the server handles both warmup and perf + requests simultaneously. 
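+
+Both behaviours are driven through ``WarmupConfig``; the tests below
+construct, for example::
+
+    WarmupConfig(enabled=True, n_requests=5, salt=True, drain=True)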
+""" + +from __future__ import annotations + +import asyncio +import json +import re +import uuid +from pathlib import Path + +import pytest +from aiohttp import web +from inference_endpoint.commands.benchmark.execute import run_benchmark +from inference_endpoint.config.schema import Dataset as ConfigDataset +from inference_endpoint.config.schema import ( + DatasetType, + EndpointConfig, + LoadPattern, + LoadPatternType, + ModelParams, + OfflineBenchmarkConfig, + OfflineSettings, + RuntimeConfig, + StreamingMode, + TestMode, + WarmupConfig, +) +from inference_endpoint.core.types import QueryResult, TextModelOutput +from inference_endpoint.endpoint_client.config import HTTPClientConfig +from inference_endpoint.openai.openai_adapter import OpenAIAdapter + +# ── helpers ────────────────────────────────────────────────────────────────── + +_SALT_RE = re.compile(r"^\[([0-9a-f]{16})\] (.+)$") + +_MINIMAL_CLIENT = HTTPClientConfig( + num_workers=1, warmup_connections=0, max_connections=10 +) + + +def _echo_response(prompt: str) -> dict: + """Build a minimal valid OpenAI chat-completion response.""" + req_id = str(uuid.uuid4()) + result = QueryResult(id=req_id, response_output=TextModelOutput(output=prompt)) + body = OpenAIAdapter.to_endpoint_response(result).model_dump(mode="json") + body["id"] = req_id + return body + + +async def _user_prompt(request: web.Request) -> str: + """Extract the user message content from an OpenAI chat-completion request.""" + body = await request.json() + for msg in body.get("messages", []): + if msg.get("role") == "user": + content = msg.get("content", "") + return str(content) if content is not None else "" + return "" + + +def _offline_config( + endpoint_url: str, + dataset_path: str | Path, + warmup: WarmupConfig, + n_perf_samples: int = 5, +) -> OfflineBenchmarkConfig: + return OfflineBenchmarkConfig( + endpoint_config=EndpointConfig(endpoints=[endpoint_url]), + model_params=ModelParams(name="test-model", streaming=StreamingMode.OFF), + datasets=[ConfigDataset(path=str(dataset_path), type=DatasetType.PERFORMANCE)], + settings=OfflineSettings( + runtime=RuntimeConfig(min_duration_ms=0, n_samples_to_issue=n_perf_samples), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + client=_MINIMAL_CLIENT, + warmup=warmup, + ), + ) + + +# ── fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def single_prompt_dataset(tmp_path: Path) -> Path: + """JSONL with one sample: ``{"prompt": "hello world"}``.""" + f = tmp_path / "single.jsonl" + f.write_text(json.dumps({"prompt": "hello world"}) + "\n") + return f + + +@pytest.fixture +def multi_prompt_dataset(tmp_path: Path) -> Path: + """JSONL with five distinct samples.""" + f = tmp_path / "multi.jsonl" + f.write_text( + "\n".join(json.dumps({"prompt": f"prompt_{i}"}) for i in range(5)) + "\n" + ) + return f + + +# ── salt tests ──────────────────────────────────────────────────────────────── + + +@pytest.mark.integration +class TestWarmupSalt: + """Server-side prompt observations when salt is on vs off.""" + + def test_salt_enabled_warmup_prompts_differ_from_perf( + self, + mock_http_echo_server_factory, + single_prompt_dataset: Path, + ): + """With salt=True the server sees a distinct prompt for warmup vs perf. + + The warmup request carries a ``[<16-hex>] `` prefix; the perf request + carries the raw prompt. The base content after stripping the salt must + match the perf prompt. 
+ """ + received: list[str] = [] + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + received.append(prompt) + return web.json_response(_echo_response(prompt)) + + server = mock_http_echo_server_factory(handler) + config = _offline_config( + server.url, + single_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=1, salt=True, drain=True), + n_perf_samples=1, + ) + run_benchmark(config, TestMode.PERF) + + assert ( + len(received) == 2 + ), f"Expected 2 requests (1 warmup + 1 perf), got: {received}" + + warmup_prompts = [p for p in received if _SALT_RE.match(p)] + perf_prompts = [p for p in received if not _SALT_RE.match(p)] + assert len(warmup_prompts) == 1, "Expected exactly 1 salted (warmup) prompt" + assert len(perf_prompts) == 1, "Expected exactly 1 unsalted (perf) prompt" + + # Base content after stripping salt must equal the perf prompt + m = _SALT_RE.match(warmup_prompts[0]) + assert m is not None + assert ( + m.group(2) == perf_prompts[0] + ), f"Stripped warmup prompt {m.group(2)!r} != perf prompt {perf_prompts[0]!r}" + + def test_salt_disabled_server_sees_identical_prompts( + self, + mock_http_echo_server_factory, + single_prompt_dataset: Path, + ): + """With salt=False warmup and perf send the exact same prompt. + + A KV cache on the real server would reuse the cached result, defeating + the purpose of warming up unique sequences. + """ + received: list[str] = [] + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + received.append(prompt) + return web.json_response(_echo_response(prompt)) + + server = mock_http_echo_server_factory(handler) + config = _offline_config( + server.url, + single_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=1, salt=False, drain=True), + n_perf_samples=1, + ) + run_benchmark(config, TestMode.PERF) + + assert len(received) == 2, f"Expected 2 requests, got: {received}" + assert ( + received[0] == received[1] + ), f"Expected identical prompts with salt=False, got: {received}" + assert not any( + _SALT_RE.match(p) for p in received + ), "No prompt should have a salt prefix when salt=False" + + def test_salt_count_matches_n_requests( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """Exactly ``n_requests`` salted prompts reach the server.""" + received: list[str] = [] + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + received.append(prompt) + return web.json_response(_echo_response(prompt)) + + server = mock_http_echo_server_factory(handler) + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=3, salt=True, drain=True), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + warmup_count = sum(1 for p in received if _SALT_RE.match(p)) + assert ( + warmup_count == 3 + ), f"Expected 3 warmup (salted) prompts, got {warmup_count} from: {received}" + + def test_each_salted_warmup_prompt_is_unique( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """Every warmup request has a distinct salt even when the same sample is reused. + + With n_requests (10) > dataset size (5) samples are cycled; without + unique salts the same raw text would appear twice, allowing cache reuse. 
+ """ + received: list[str] = [] + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + received.append(prompt) + return web.json_response(_echo_response(prompt)) + + server = mock_http_echo_server_factory(handler) + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=10, salt=True, drain=True), + n_perf_samples=1, + ) + run_benchmark(config, TestMode.PERF) + + warmup_prompts = [p for p in received if _SALT_RE.match(p)] + assert len(warmup_prompts) == 10 + assert ( + len(set(warmup_prompts)) == 10 + ), "Expected all salted warmup prompts to be unique (distinct salt per request)" + + +# ── drain tests ─────────────────────────────────────────────────────────────── + + +@pytest.mark.integration +class TestWarmupDrain: + """Concurrency overlap between warmup and perf at the server. + + Strategy + -------- + * Use a slow server (``_DELAY`` seconds per response) so that warmup + requests stay in-flight long enough for the perf phase to start. + * Use ``salt=True`` to identify warmup vs perf requests at the server. + * Track concurrent in-flight counts inside the server handler. Since all + handlers share a single asyncio event loop, plain Python ints are safe + (no actual concurrent mutation — only cooperative interleaving at + ``await`` points). + * ``run_benchmark`` is synchronous and blocks until the perf phase drains, + guaranteeing all overlap events have been recorded before we assert. + """ + + _DELAY = 0.15 # seconds; long enough for perf to start while warmup is pending + + def _make_handler(self, delay: float): + """Return ``(handler_coro, state)``. + + ``state`` keys after ``run_benchmark`` returns: + + * ``overlap`` – ``True`` if warmup and perf were in-flight simultaneously + * ``max_concurrent`` – peak total in-flight request count + """ + state: dict = { + "warmup_inflight": 0, + "perf_inflight": 0, + "overlap": False, + "max_concurrent": 0, + } + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + is_warmup = bool(_SALT_RE.match(prompt)) + + if is_warmup: + state["warmup_inflight"] += 1 + else: + state["perf_inflight"] += 1 + + total = state["warmup_inflight"] + state["perf_inflight"] + if total > state["max_concurrent"]: + state["max_concurrent"] = total + + if state["warmup_inflight"] > 0 and state["perf_inflight"] > 0: + state["overlap"] = True + + await asyncio.sleep(delay) + + if is_warmup: + state["warmup_inflight"] -= 1 + else: + state["perf_inflight"] -= 1 + + return web.json_response(_echo_response(prompt)) + + return handler, state + + def test_drain_true_no_overlap( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """With ``drain=True`` the server never handles warmup and perf simultaneously. + + Timeline: all warmup responses arrive → perf phase starts → no overlap. + """ + handler, state = self._make_handler(self._DELAY) + server = mock_http_echo_server_factory(handler) + + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=5, salt=True, drain=True), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + assert not state[ + "overlap" + ], "With drain=True warmup and perf must not be in-flight at the same time" + + def test_drain_false_overlap_observed( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """With ``drain=False`` perf requests start before warmup responses arrive. 
+ + Timeline: warmup issues 5 requests at MAX_THROUGHPUT → immediately + perf phase starts → both sets in-flight → overlap detected. + """ + handler, state = self._make_handler(self._DELAY) + server = mock_http_echo_server_factory(handler) + + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=5, salt=True, drain=False), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + assert state[ + "overlap" + ], "With drain=False perf requests should start while warmup is in-flight" + + def test_drain_true_max_concurrent_bounded_by_phase_size( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """With ``drain=True`` at most one phase worth of requests is in-flight. + + Both warmup and perf have 5 requests; phases never overlap so the peak + concurrent count is at most 5. + """ + handler, state = self._make_handler(self._DELAY) + server = mock_http_echo_server_factory(handler) + + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=5, salt=True, drain=True), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + assert state["max_concurrent"] <= 5, ( + f"With drain=True max concurrent {state['max_concurrent']} should be ≤ 5 " + "(one phase at a time)" + ) + + def test_drain_false_max_concurrent_exceeds_single_phase_size( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """With ``drain=False`` the server concurrently handles requests from both phases. + + 5 warmup requests stay in-flight while 5 perf requests arrive, so the + server reaches a peak of more than 5 simultaneous requests. + """ + handler, state = self._make_handler(self._DELAY) + server = mock_http_echo_server_factory(handler) + + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=5, salt=True, drain=False), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + assert state["max_concurrent"] > 5, ( + f"With drain=False peak concurrent {state['max_concurrent']} should be > 5 " + "(warmup + perf requests overlap)" + ) diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index c664234f..e91063aa 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -15,18 +15,25 @@ """Tests for benchmark CLI models, config building, and command handlers.""" +import random import tempfile from pathlib import Path from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pandas as pd import pytest from inference_endpoint.commands.benchmark.cli import ( from_config, offline, online, ) -from inference_endpoint.commands.benchmark.execute import ResponseCollector +from inference_endpoint.commands.benchmark.execute import ( + BenchmarkContext, + ResponseCollector, + _build_phases, +) +from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( BenchmarkConfig, DatasetType, @@ -35,9 +42,11 @@ OfflineSettings, OnlineSettings, RuntimeConfig, + ScorerMethod, StreamingMode, TestMode, TestType, + WarmupConfig, ) from inference_endpoint.config.schema import ( OfflineBenchmarkConfig as OfflineConfig, @@ -47,8 +56,13 @@ ) from inference_endpoint.config.utils import cli_error_formatter as _error_formatter from inference_endpoint.core.types import QueryResult +from inference_endpoint.dataset_manager.dataset import Dataset, SaltedDataset from 
inference_endpoint.endpoint_client.config import HTTPClientConfig +from inference_endpoint.evaluation.scoring import Scorer from inference_endpoint.exceptions import InputValidationError +from inference_endpoint.load_generator.sample_order import create_sample_order +from inference_endpoint.load_generator.session import PhaseType +from inference_endpoint.metrics.metric import Throughput from pydantic import ValidationError TEMPLATE_DIR = ( @@ -385,14 +399,390 @@ def test_valid_templates_parse(self, template): assert config.endpoint_config.endpoints +class TestWarmupConfig: + """Tests for WarmupConfig schema model.""" + + @pytest.mark.unit + def test_defaults(self): + cfg = WarmupConfig() + assert cfg.enabled is False + assert cfg.n_requests is None + assert cfg.salt is False + assert cfg.drain is False + + @pytest.mark.unit + @pytest.mark.parametrize("n", [1, 10, 1000]) + def test_n_requests_valid(self, n): + cfg = WarmupConfig(n_requests=n) + assert cfg.n_requests == n + + @pytest.mark.unit + @pytest.mark.parametrize("n", [0, -1, -100]) + def test_n_requests_must_be_positive(self, n): + with pytest.raises(ValidationError): + WarmupConfig(n_requests=n) + + @pytest.mark.unit + def test_extra_fields_rejected(self): + with pytest.raises(ValidationError): + WarmupConfig(unknown_field=True) + + @pytest.mark.unit + def test_immutable(self): + cfg = WarmupConfig() + with pytest.raises(ValidationError): + cfg.enabled = True # type: ignore[misc] + + @pytest.mark.unit + def test_all_flags_enabled(self): + cfg = WarmupConfig(enabled=True, n_requests=50, salt=True, drain=True) + assert cfg.enabled is True + assert cfg.n_requests == 50 + assert cfg.salt is True + assert cfg.drain is True + + @pytest.mark.unit + def test_yaml_roundtrip(self, tmp_path): + yaml_content = """ +type: "offline" +model_params: + name: "test-model" +endpoint_config: + endpoints: ["http://test:8000"] +datasets: + - path: "test.jsonl" +settings: + warmup: + enabled: true + n_requests: 20 + salt: true + drain: true +""" + config_file = tmp_path / "warmup.yaml" + config_file.write_text(yaml_content) + config = BenchmarkConfig.from_yaml_file(config_file) + warmup = config.settings.warmup + assert warmup.enabled is True + assert warmup.n_requests == 20 + assert warmup.salt is True + assert warmup.drain is True + + @pytest.mark.unit + def test_warmup_default_in_settings(self): + config = OfflineConfig(**_OFFLINE_KWARGS) + warmup = config.settings.warmup + assert warmup.enabled is False + assert warmup.n_requests is None + + +class TestBuildPhases: + """Tests for _build_phases() in execute.py.""" + + @pytest.fixture + def base_rt_settings(self): + return RuntimeSettings( + metric_target=Throughput(10.0), + reported_metrics=[Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=5, + n_samples_to_issue=None, + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + + @pytest.fixture + def simple_dataset(self): + df = pd.DataFrame({"prompt": [f"q{i}" for i in range(5)]}) + ds = Dataset(df) + ds.load() + return ds + + def _make_ctx(self, config, rt_settings, dataloader): + return BenchmarkContext( + config=config, + test_mode=TestMode.PERF, + report_dir=Path("/tmp"), + tokenizer_name=None, + dataloader=dataloader, + rt_settings=rt_settings, + total_samples=dataloader.num_samples(), + accuracy_datasets=[], + eval_configs=[], + ) + + @pytest.mark.unit + def 
test_warmup_disabled_produces_only_perf_phase( + self, base_rt_settings, simple_dataset + ): + config = OfflineConfig(**_OFFLINE_KWARGS) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert len(phases) == 1 + assert phases[0].phase_type == PhaseType.PERFORMANCE + + @pytest.mark.unit + def test_warmup_enabled_produces_two_phases(self, base_rt_settings, simple_dataset): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert len(phases) == 2 + assert phases[0].phase_type == PhaseType.WARMUP + assert phases[1].phase_type == PhaseType.PERFORMANCE + + @pytest.mark.unit + def test_warmup_phase_named_warmup(self, base_rt_settings, simple_dataset): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].name == "warmup" + + @pytest.mark.unit + def test_warmup_phase_uses_max_throughput(self, base_rt_settings, simple_dataset): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + warmup_rt = phases[0].runtime_settings + assert warmup_rt.load_pattern is not None + assert warmup_rt.load_pattern.type == LoadPatternType.MAX_THROUGHPUT + + @pytest.mark.unit + def test_warmup_phase_min_duration_is_zero(self, base_rt_settings, simple_dataset): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].runtime_settings.min_duration_ms == 0 + + @pytest.mark.unit + def test_warmup_phase_no_max_duration(self, base_rt_settings, simple_dataset): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].runtime_settings.max_duration_ms is None + + @pytest.mark.unit + def test_warmup_n_requests_propagated(self, base_rt_settings, simple_dataset): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, n_requests=7)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].runtime_settings.n_samples_to_issue == 7 + + @pytest.mark.unit + def test_warmup_n_requests_none_when_unset(self, base_rt_settings, simple_dataset): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings( + warmup=WarmupConfig(enabled=True, n_requests=None) + ), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].runtime_settings.n_samples_to_issue is None + + @pytest.mark.unit + def test_warmup_without_salt_uses_raw_dataloader( + self, base_rt_settings, simple_dataset + ): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, salt=False)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert not isinstance(phases[0].dataset, SaltedDataset) + assert 
phases[0].dataset is simple_dataset + + @pytest.mark.unit + def test_warmup_with_salt_uses_salted_dataset( + self, base_rt_settings, simple_dataset + ): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, salt=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert isinstance(phases[0].dataset, SaltedDataset) + + @pytest.mark.unit + def test_warmup_drain_false_by_default(self, base_rt_settings, simple_dataset): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, drain=False)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].drain_after is False + + @pytest.mark.unit + def test_warmup_drain_true_propagated(self, base_rt_settings, simple_dataset): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, drain=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].drain_after is True + + @pytest.mark.unit + def test_warmup_n_samples_from_dataset_matches_dataloader( + self, base_rt_settings, simple_dataset + ): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert ( + phases[0].runtime_settings.n_samples_from_dataset + == simple_dataset.num_samples() + ) + + @pytest.mark.unit + def test_performance_phase_dataset_is_always_raw_dataloader( + self, base_rt_settings, simple_dataset + ): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, salt=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + perf_phase = phases[1] + assert perf_phase.dataset is simple_dataset + + @pytest.mark.unit + def test_performance_phase_uses_original_rt_settings( + self, base_rt_settings, simple_dataset + ): + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[1].runtime_settings is base_rt_settings + + @pytest.mark.unit + def test_warmup_uses_independent_rng_instances( + self, base_rt_settings, simple_dataset + ): + """Warmup RuntimeSettings must not share RNG instances with the perf phase. + + Sharing would cause warmup sample-ordering to consume state from the + perf phase's deterministic random sequence, breaking reproducibility. + """ + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + warmup_rt = phases[0].runtime_settings + perf_rt = phases[1].runtime_settings + assert warmup_rt.rng_sched is not perf_rt.rng_sched + assert warmup_rt.rng_sample_index is not perf_rt.rng_sample_index + + @pytest.mark.unit + def test_performance_sample_order_identical_with_and_without_warmup( + self, simple_dataset + ): + """Warmup must not perturb the performance phase's sample ordering. + + Both runs use separate RuntimeSettings instances seeded identically so + the comparison is valid. 
If warmup ever accidentally shared or advanced + the perf-phase RNG, the two sequences would diverge. + """ + n_draw = 20 + + def make_rt(): + return RuntimeSettings( + metric_target=Throughput(10.0), + reported_metrics=[Throughput(10.0)], + min_duration_ms=0, + max_duration_ms=None, + n_samples_from_dataset=simple_dataset.num_samples(), + n_samples_to_issue=None, + min_sample_count=1, + rng_sched=random.Random(99), + rng_sample_index=random.Random(99), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + + config_with = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, n_requests=5)), + ) + config_without = OfflineConfig(**_OFFLINE_KWARGS) + + ctx_with = self._make_ctx(config_with, make_rt(), simple_dataset) + ctx_without = self._make_ctx(config_without, make_rt(), simple_dataset) + + perf_with = next( + p for p in _build_phases(ctx_with) if p.phase_type == PhaseType.PERFORMANCE + ) + perf_without = next( + p + for p in _build_phases(ctx_without) + if p.phase_type == PhaseType.PERFORMANCE + ) + + order_with = [ + next(create_sample_order(perf_with.runtime_settings)) for _ in range(n_draw) + ] + order_without = [ + next(create_sample_order(perf_without.runtime_settings)) + for _ in range(n_draw) + ] + + assert order_with == order_without, ( + "Performance sample order changed when warmup is enabled — " + "warmup may be sharing or advancing the perf-phase RNG." + ) + + class TestScorerMethodSync: """Ensure ScorerMethod enum stays in sync with the scorer registry.""" @pytest.mark.unit def test_scorer_enum_matches_registry(self): - from inference_endpoint.config.schema import ScorerMethod - from inference_endpoint.evaluation.scoring import Scorer - enum_values = {m.value for m in ScorerMethod} registry_keys = set(Scorer.PREDEFINED.keys()) assert enum_values == registry_keys, ( diff --git a/tests/unit/dataset_manager/test_salted_dataset.py b/tests/unit/dataset_manager/test_salted_dataset.py new file mode 100644 index 00000000..782a8d44 --- /dev/null +++ b/tests/unit/dataset_manager/test_salted_dataset.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for SaltedDataset.""" + +import re +from unittest.mock import MagicMock + +import pandas as pd +import pytest +from inference_endpoint.dataset_manager.dataset import Dataset, SaltedDataset + + +def _make_loaded_dataset(rows: list[dict]) -> Dataset: + """Return a Dataset with .data already populated (no file I/O).""" + ds = Dataset.__new__(Dataset) + ds.dataframe = None + ds.transforms = None + ds.repeats = 1 + ds.data = list(rows) + ds.logger = MagicMock() + return ds + + +@pytest.mark.unit +class TestSaltedDatasetDelegation: + """SaltedDataset correctly mirrors inner-dataset properties.""" + + def test_num_samples_delegates_to_inner(self): + inner = _make_loaded_dataset( + [{"prompt": "a"}, {"prompt": "b"}, {"prompt": "c"}] + ) + sd = SaltedDataset(inner) + assert sd.num_samples() == 3 + + def test_data_attribute_is_inner_data(self): + inner = _make_loaded_dataset([{"prompt": "x"}]) + sd = SaltedDataset(inner) + assert sd.data is inner.data + + def test_repeats_matches_inner(self): + inner = _make_loaded_dataset([{"prompt": "x"}]) + inner.repeats = 5 + sd = SaltedDataset(inner) + assert sd.repeats == 5 + + def test_load_is_noop(self): + inner = _make_loaded_dataset([{"prompt": "x"}]) + sd = SaltedDataset(inner) + sd.load() # must not raise + assert sd.data is inner.data # data unchanged after load() + + +@pytest.mark.unit +class TestSaltedDatasetSaltBehavior: + """Salt is injected correctly into prompts.""" + + _SALT_RE = re.compile(r"^\[([0-9a-f]{16})\] (.+)$") + + def test_prompt_prefixed_with_salt(self): + inner = _make_loaded_dataset([{"prompt": "hello world"}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + assert self._SALT_RE.match( + result["prompt"] + ), f"Expected '[<16-hex>] hello world', got: {result['prompt']!r}" + + def test_salt_is_exactly_16_hex_chars(self): + inner = _make_loaded_dataset([{"prompt": "test"}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + m = self._SALT_RE.match(result["prompt"]) + assert m is not None + assert len(m.group(1)) == 16 + + def test_original_prompt_preserved_after_salt(self): + inner = _make_loaded_dataset([{"prompt": "my question"}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + m = self._SALT_RE.match(result["prompt"]) + assert m is not None + assert m.group(2) == "my question" + + def test_salt_unique_across_calls_same_index(self): + inner = _make_loaded_dataset([{"prompt": "repeated"}]) + sd = SaltedDataset(inner) + salts = {sd.load_sample(0)["prompt"][:18] for _ in range(20)} + # With 8 bytes of randomness we expect all 20 samples to be distinct + assert len(salts) == 20 + + def test_salt_unique_across_different_indices(self): + inner = _make_loaded_dataset([{"prompt": "a"}, {"prompt": "b"}]) + sd = SaltedDataset(inner) + prompt0 = sd.load_sample(0)["prompt"] + prompt1 = sd.load_sample(1)["prompt"] + # Salts should differ (original prompts differ, salts are random) + assert prompt0 != prompt1 + + def test_other_fields_unchanged(self): + inner = _make_loaded_dataset( + [{"prompt": "hi", "system": "you are helpful", "extra": 42}] + ) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + assert result["system"] == "you are helpful" + assert result["extra"] == 42 + + def test_original_dict_not_mutated(self): + row = {"prompt": "original"} + inner = _make_loaded_dataset([row]) + sd = SaltedDataset(inner) + sd.load_sample(0) + assert inner.data[0]["prompt"] == "original" + + +@pytest.mark.unit +class TestSaltedDatasetPassthrough: + """Samples without a 'prompt' key, or non-dict 
samples, are passed through unchanged.""" + + def test_dict_without_prompt_key_is_unchanged(self): + inner = _make_loaded_dataset([{"question": "what is 2+2?", "answer": "4"}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + assert result == {"question": "what is 2+2?", "answer": "4"} + + def test_empty_dict_is_unchanged(self): + inner = _make_loaded_dataset([{}]) + sd = SaltedDataset(inner) + assert sd.load_sample(0) == {} + + def test_non_dict_sample_is_returned_as_is(self): + inner = _make_loaded_dataset([{"prompt": "x"}]) + inner.data = ["raw string sample"] # override with non-dict + sd = SaltedDataset(inner) + sd.data = inner.data + assert sd.load_sample(0) == "raw string sample" + + def test_multimodal_list_prompt_first_text_part_is_salted(self): + """Salt is injected into the 'text' field of the first content part.""" + content_parts = [ + {"type": "text", "text": "describe this image"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ] + inner = _make_loaded_dataset([{"prompt": content_parts}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + parts = result["prompt"] + assert isinstance(parts, list) + assert len(parts) == 2 + # First text part must carry a salt prefix + assert re.match(r"^\[([0-9a-f]{16})\] describe this image$", parts[0]["text"]) + # Image part must be unchanged + assert parts[1] == content_parts[1] + + def test_multimodal_list_prompt_original_not_mutated(self): + """Salting a multimodal prompt must not modify the original list or dicts.""" + content_parts = [{"type": "text", "text": "original text"}] + inner = _make_loaded_dataset([{"prompt": content_parts}]) + sd = SaltedDataset(inner) + sd.load_sample(0) + assert inner.data[0]["prompt"][0]["text"] == "original text" + + def test_unknown_prompt_type_is_not_salted(self): + """A prompt that is neither str nor a recognised list-of-parts is returned unchanged.""" + inner = _make_loaded_dataset([{"prompt": 42}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + assert result == {"prompt": 42} + + +@pytest.mark.unit +class TestSaltedDatasetWithRealDataset: + """Integration-style: SaltedDataset wrapping a real Dataset loaded from a DataFrame.""" + + @pytest.fixture + def loaded_inner(self): + df = pd.DataFrame( + { + "prompt": ["What is AI?", "Explain gradient descent"], + "category": ["general", "ml"], + } + ) + ds = Dataset(df) + ds.load() + return ds + + def test_wraps_real_dataset_correctly(self, loaded_inner): + sd = SaltedDataset(loaded_inner) + assert sd.num_samples() == 2 + + def test_real_dataset_prompts_are_salted(self, loaded_inner): + sd = SaltedDataset(loaded_inner) + for i in range(sd.num_samples()): + result = sd.load_sample(i) + assert result["prompt"].startswith("[") + assert "] " in result["prompt"] + + def test_category_field_preserved(self, loaded_inner): + sd = SaltedDataset(loaded_inner) + assert sd.load_sample(0)["category"] == "general" + assert sd.load_sample(1)["category"] == "ml" + + def test_all_salts_unique_across_full_dataset(self, loaded_inner): + sd = SaltedDataset(loaded_inner) + prompts = [sd.load_sample(i)["prompt"] for i in range(sd.num_samples())] + assert len(set(prompts)) == len(prompts) diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index 38dd014e..0901c99f 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -20,7 +20,6 @@ import asyncio import random -import 
 
-import inference_endpoint.load_generator.session as _session_mod
 import pytest
 from inference_endpoint.config.runtime_settings import RuntimeSettings
 from inference_endpoint.config.schema import LoadPattern, LoadPatternType
@@ -42,13 +41,6 @@
 )
 from inference_endpoint.metrics.metric import Throughput
 
-
-@pytest.fixture(autouse=False)
-def enable_warmup(monkeypatch):
-    """Enable warmup phases for tests that use PhaseType.WARMUP."""
-    monkeypatch.setattr(_session_mod, "_WARMUP_ENABLED", True)
-
-
 # ---------------------------------------------------------------------------
 # Test doubles
 # ---------------------------------------------------------------------------
@@ -251,7 +243,7 @@ async def test_accuracy_phase(self):
         )
 
     @pytest.mark.asyncio
-    async def test_warmup_produces_no_result(self, enable_warmup):
+    async def test_warmup_produces_no_result(self):
         loop = asyncio.get_running_loop()
         issuer = FakeIssuer()
         issuer._loop = loop
@@ -270,7 +262,7 @@ async def test_warmup_produces_no_result(self, enable_warmup):
         assert len(result.phase_results) == 0
 
     @pytest.mark.asyncio
-    async def test_multi_phase(self, enable_warmup):
+    async def test_multi_phase(self):
         loop = asyncio.get_running_loop()
         issuer = FakeIssuer()
         issuer._loop = loop
@@ -346,7 +338,7 @@ def on_complete(result: QueryResult) -> None:
         assert len(completed) == 5
 
     @pytest.mark.asyncio
-    async def test_stale_completions_ignored_by_strategy(self, enable_warmup):
+    async def test_stale_completions_ignored_by_strategy(self):
         """Responses from warmup phase should not affect perf phase strategy."""
         loop = asyncio.get_running_loop()
         publisher = FakePublisher()
@@ -737,7 +729,7 @@ class TestBenchmarkSessionMultiPhaseSatPerfSequence:
     """Multi-perf + warmup sequence (sat -> perf -> sat -> perf)."""
 
     @pytest.mark.asyncio
-    async def test_sat_perf_sat_perf(self, enable_warmup):
+    async def test_sat_perf_sat_perf(self):
         loop = asyncio.get_running_loop()
         issuer = FakeIssuer()
         issuer._loop = loop
@@ -798,7 +790,7 @@ class TestBenchmarkSessionStaleStreamChunk:
     """Stale StreamChunk from previous phase is ignored."""
 
    @pytest.mark.asyncio
-    async def test_stale_stream_chunk_ignored(self, enable_warmup):
+    async def test_stale_stream_chunk_ignored(self):
         """StreamChunk from warmup phase should not affect perf phase counts."""
         loop = asyncio.get_running_loop()
         publisher = FakePublisher()
@@ -826,7 +818,9 @@ def on_complete(result: QueryResult | StreamChunk) -> None:
         )
 
         phases = [
-            PhaseConfig("sat", sat_settings, FakeDataset(2), PhaseType.WARMUP),
+            PhaseConfig(
+                "sat", sat_settings, FakeDataset(2), PhaseType.WARMUP, drain_after=False
+            ),
             PhaseConfig("perf", perf_settings, FakeDataset(2), PhaseType.PERFORMANCE),
         ]
 
@@ -871,7 +865,7 @@ async def inject_responses():
 
 @pytest.mark.unit
 class TestSessionResult:
-    def test_perf_results_filter(self, enable_warmup):
+    def test_perf_results_filter(self):
         results = [
             PhaseResult("sat", PhaseType.WARMUP, {}, 0, 0, 0),
             PhaseResult("perf1", PhaseType.PERFORMANCE, {"a": 1}, 10, 0, 100),
diff --git a/tests/unit/test_http_mock_fixtures.py b/tests/unit/test_http_mock_fixtures.py
index 8897b81c..786f28dd 100644
--- a/tests/unit/test_http_mock_fixtures.py
+++ b/tests/unit/test_http_mock_fixtures.py
@@ -20,10 +20,12 @@
 with real HTTP server that echoes requests back.
""" +import json import logging import aiohttp import pytest +from aiohttp import web from inference_endpoint.core.types import Query, TextModelOutput from inference_endpoint.openai.openai_adapter import OpenAIAdapter from inference_endpoint.openai.openai_types_gen import CreateChatCompletionResponse @@ -135,3 +137,112 @@ async def test_real_http_server_post_request_with_max_osl( assert len(str(response.response_output)) == 5 mock_http_echo_server.set_max_osl(old_max_osl) + + +@pytest.mark.unit +class TestHttpEchoServerFactory: + """Tests for the mock_http_echo_server_factory fixture.""" + + @pytest.mark.asyncio + async def test_factory_default_is_echo(self, mock_http_echo_server_factory): + """Factory with no handler behaves identically to mock_http_echo_server.""" + server = mock_http_echo_server_factory() + async with aiohttp.ClientSession() as session: + payload = {"query": "hello"} + async with session.post(f"{server.url}/echo", json=payload) as resp: + assert resp.status == 200 + data = await resp.json() + assert data["echo"] is True + assert data["request"]["json_payload"] == payload + + @pytest.mark.asyncio + async def test_factory_sync_handler_overrides_response( + self, mock_http_echo_server_factory + ): + """A plain (sync) lambda completely replaces the built-in handler.""" + server = mock_http_echo_server_factory( + lambda req: web.json_response({"custom": True}, status=201) + ) + async with aiohttp.ClientSession() as session: + async with session.post( + f"{server.url}/v1/chat/completions", + json={"model": "gpt-4", "messages": []}, + ) as resp: + assert resp.status == 201 + data = await resp.json() + assert data == {"custom": True} + + @pytest.mark.asyncio + async def test_factory_async_handler_overrides_response( + self, mock_http_echo_server_factory + ): + """An async handler is awaited and its response is returned.""" + + async def async_handler(request: web.Request) -> web.Response: + body = await request.json() + return web.json_response({"received_model": body.get("model")}) + + server = mock_http_echo_server_factory(async_handler) + async with aiohttp.ClientSession() as session: + async with session.post( + f"{server.url}/v1/chat/completions", + json={"model": "my-model", "messages": []}, + ) as resp: + assert resp.status == 200 + data = await resp.json() + assert data == {"received_model": "my-model"} + + @pytest.mark.asyncio + async def test_factory_handler_receives_request_body( + self, mock_http_echo_server_factory + ): + """Handler can inspect the raw request body for assertions.""" + captured: list[bytes] = [] + + async def capturing_handler(request: web.Request) -> web.Response: + captured.append(await request.read()) + return web.json_response({"ok": True}) + + server = mock_http_echo_server_factory(capturing_handler) + payload = {"hello": "world"} + async with aiohttp.ClientSession() as session: + async with session.post( + f"{server.url}/v1/chat/completions", json=payload + ) as resp: + assert resp.status == 200 + + assert len(captured) == 1 + assert json.loads(captured[0]) == payload + + @pytest.mark.asyncio + async def test_factory_handler_applies_to_echo_route_too( + self, mock_http_echo_server_factory + ): + """Custom handler intercepts /echo requests as well.""" + server = mock_http_echo_server_factory( + lambda req: web.json_response({"intercepted": True}) + ) + async with aiohttp.ClientSession() as session: + async with session.post(f"{server.url}/echo", json={}) as resp: + assert resp.status == 200 + assert (await resp.json()) == {"intercepted": True} 
+
+    @pytest.mark.asyncio
+    async def test_factory_creates_independent_servers(
+        self, mock_http_echo_server_factory
+    ):
+        """Two servers from the same factory operate independently."""
+        echo_server = mock_http_echo_server_factory()
+        custom_server = mock_http_echo_server_factory(
+            lambda req: web.json_response({"server": "custom"})
+        )
+        assert echo_server.url != custom_server.url
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{custom_server.url}/echo", json={}) as resp:
+                data = await resp.json()
+                assert data == {"server": "custom"}
+
+            async with session.post(f"{echo_server.url}/echo", json={"x": 1}) as resp:
+                data = await resp.json()
+                assert data["echo"] is True
diff --git a/uv.lock b/uv.lock
index ed84bd7e..9017350e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -877,7 +877,7 @@ requires-dist = [
     { name = "sphinx-autodoc-typehints", marker = "extra == 'dev'", specifier = "==3.9.11" },
     { name = "sphinx-rtd-theme", marker = "extra == 'dev'", specifier = "==3.1.0" },
     { name = "sqlalchemy", marker = "extra == 'sql'", specifier = "==2.0.48" },
-    { name = "transformers", specifier = "==5.4.0" },
+    { name = "transformers", specifier = "==5.5.0" },
     { name = "typing-extensions", specifier = "==4.15.0" },
     { name = "uvloop", specifier = "==0.22.1" },
     { name = "websocket-client", specifier = "==1.9.0" },
@@ -2403,7 +2403,7 @@
 
 [[package]]
 name = "transformers"
-version = "5.4.0"
+version = "5.5.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "huggingface-hub", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -2416,9 +2416,9 @@
     { name = "tqdm", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
     { name = "typer", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/0b/4c/42a8e1c7bbe668d8e073941ec3205263afb1cd02683fa5a8a75e615fdfbe/transformers-5.4.0.tar.gz", hash = "sha256:cb34ca89dce345ae3224b290346b9c0fa9694b951d54f3ed16334a4b1bfe3d04", size = 8152836, upload-time = "2026-03-27T00:24:24.692Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/ff/9d/fb46e729b461985f41a5740167688b924a4019141e5c164bea77548d3d9e/transformers-5.5.0.tar.gz", hash = "sha256:c8db656cf51c600cd8c75f06b20ef85c72e8b8ff9abc880c5d3e8bc70e0ddcbd", size = 8237745, upload-time = "2026-04-02T16:13:08.113Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/0b/a0/0a87883e564e364baab32adcacb4bec2e200b28a568423c8cf7fde316461/transformers-5.4.0-py3-none-any.whl", hash = "sha256:9fbe50602d2a4e6d0aa8a35a605433dfac72d595ee2192eae192590a6cc2df86", size = 10105556, upload-time = "2026-03-27T00:24:21.735Z" },
+    { url = "https://files.pythonhosted.org/packages/e7/28/35f7411ff80a3640c1f4fc907dcbb6a65061ebb82f66950e38bfc9f7f740/transformers-5.5.0-py3-none-any.whl", hash = "sha256:821a9ff0961abbb29eb1eb686d78df1c85929fdf213a3fe49dc6bd94f9efa944", size = 10245591, upload-time = "2026-04-02T16:13:03.462Z" },
 ]
 
 [[package]]