From 06a923777f6ee819413b95a4bcb7b138377ef9c9 Mon Sep 17 00:00:00 2001 From: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> Date: Mon, 4 May 2026 18:40:08 -0700 Subject: [PATCH 1/5] Plug in warmup phase Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> --- .gitignore | 3 + .../services/metrics_aggregator/__main__.py | 13 +- .../commands/benchmark/execute.py | 30 +- src/inference_endpoint/config/schema.py | 21 + .../templates/concurrency_template_full.yaml | 5 + .../templates/offline_template_full.yaml | 5 + .../templates/online_template_full.yaml | 5 + .../dataset_manager/dataset.py | 39 +- .../load_generator/session.py | 13 +- src/inference_endpoint/testing/echo_server.py | 28 +- tests/conftest.py | 77 +++- tests/integration/commands/test_warmup.py | 435 ++++++++++++++++++ tests/unit/commands/test_benchmark.py | 351 ++++++++++++++ .../dataset_manager/test_salted_dataset.py | 186 ++++++++ .../unit/load_generator/test_async_session.py | 20 +- tests/unit/test_http_mock_fixtures.py | 111 +++++ uv.lock | 8 +- 17 files changed, 1292 insertions(+), 58 deletions(-) create mode 100644 tests/integration/commands/test_warmup.py create mode 100644 tests/unit/dataset_manager/test_salted_dataset.py diff --git a/.gitignore b/.gitignore index 30372720..86e1584b 100644 --- a/.gitignore +++ b/.gitignore @@ -197,3 +197,6 @@ docs/superpowers/ # User-specific local dev configs; do not commit CLAUDE.local.md + +# Generated dataset cache (created by Dataset.get_dataloader()) +dataset_cache/ diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py index 50a3163d..0ef46e2a 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py @@ -17,6 +17,7 @@ import argparse import asyncio +import logging from contextlib import AbstractContextManager, nullcontext from pathlib import Path @@ -91,10 +92,16 @@ async def main() -> None: # Using ternary operator causes errors in MyPy object type coalescing # (coalesces to 'object' not 'AbstractContextManager[TokenizePool | None]') + pool_cm: AbstractContextManager[TokenizePool | None] if args.tokenizer: - pool_cm: AbstractContextManager[TokenizePool | None] = TokenizePool( - args.tokenizer, n_workers=args.tokenizer_workers - ) + try: + pool_cm = TokenizePool(args.tokenizer, n_workers=args.tokenizer_workers) + except Exception as e: + logging.warning( + f"Failed to load tokenizer '{args.tokenizer}': {e}. " + "ISL/OSL/TPOT token metrics will be unavailable." + ) + pool_cm = nullcontext() else: pool_cm = nullcontext() diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 73c3427f..2fb7a8e8 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -70,7 +70,7 @@ TestType, ) from inference_endpoint.core.types import QueryResult -from inference_endpoint.dataset_manager.dataset import Dataset +from inference_endpoint.dataset_manager.dataset import Dataset, SaltedDataset from inference_endpoint.dataset_manager.factory import DataLoaderFactory from inference_endpoint.endpoint_client.cpu_affinity import AffinityPlan, pin_loadgen from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient @@ -347,6 +347,34 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: """Build the phase list from BenchmarkContext.""" phases: list[PhaseConfig] = [] + # Warmup phase (optional, before performance) + warmup_cfg = ctx.config.settings.warmup + if warmup_cfg.enabled: + warmup_dataset: Dataset = ( + SaltedDataset(ctx.dataloader) if warmup_cfg.salt else ctx.dataloader + ) + warmup_rt = RuntimeSettings( + metric_target=ctx.rt_settings.metric_target, + reported_metrics=ctx.rt_settings.reported_metrics, + min_duration_ms=0, + max_duration_ms=None, + n_samples_from_dataset=ctx.dataloader.num_samples(), + n_samples_to_issue=warmup_cfg.n_requests, + min_sample_count=1, + rng_sched=ctx.rt_settings.rng_sched, + rng_sample_index=ctx.rt_settings.rng_sample_index, + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phases.append( + PhaseConfig( + "warmup", + warmup_rt, + warmup_dataset, + PhaseType.WARMUP, + drain_after=warmup_cfg.drain, + ) + ) + # Performance phase phases.append( PhaseConfig( diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 6a1884b4..c5101dd2 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -392,6 +392,26 @@ def _validate_completeness(self) -> Self: return self +class WarmupConfig(BaseModel): + """Warmup phase configuration. Runs before the performance phase; results are not recorded.""" + + model_config = ConfigDict(extra="forbid", frozen=True) + + enabled: bool = Field( + False, description="Enable warmup phase before performance run" + ) + n_requests: int | None = Field( + None, gt=0, description="Warmup request count (None = full dataset once)" + ) + salt: bool = Field( + False, description="Prepend a unique random hex salt to each warmup prompt" + ) + drain: bool = Field( + False, + description="Drain in-flight warmup requests before starting the performance phase", + ) + + @cyclopts.Parameter(name="*") class Settings(BaseModel): """Test settings.""" @@ -401,6 +421,7 @@ class Settings(BaseModel): runtime: RuntimeConfig = Field(default_factory=RuntimeConfig) load_pattern: LoadPattern = Field(default_factory=LoadPattern) client: HTTPClientConfig = Field(default_factory=HTTPClientConfig) + warmup: WarmupConfig = Field(default_factory=WarmupConfig) class OfflineSettings(Settings): diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index 3a8e004f..fa9bab42 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -68,6 +68,11 @@ settings: max_idle_time: 4.0 # Discard connections idle longer than this (seconds) min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system + warmup: + enabled: false # Enable warmup phase before performance run + n_requests: null # Warmup request count (None = full dataset once) + salt: false # Prepend a unique random hex salt to each warmup prompt + drain: false # Drain in-flight warmup requests before starting the performance phase endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index faabffde..368c29ee 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -68,6 +68,11 @@ settings: max_idle_time: 4.0 # Discard connections idle longer than this (seconds) min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system + warmup: + enabled: false # Enable warmup phase before performance run + n_requests: null # Warmup request count (None = full dataset once) + salt: false # Prepend a unique random hex salt to each warmup prompt + drain: false # Drain in-flight warmup requests before starting the performance phase endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index e9b7a673..598bf82d 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -68,6 +68,11 @@ settings: max_idle_time: 4.0 # Discard connections idle longer than this (seconds) min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system + warmup: + enabled: false # Enable warmup phase before performance run + n_requests: null # Warmup request count (None = full dataset once) + salt: false # Prepend a unique random hex salt to each warmup prompt + drain: false # Drain in-flight warmup requests before starting the performance phase endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/dataset_manager/dataset.py b/src/inference_endpoint/dataset_manager/dataset.py index 93e3cfa6..98d45c76 100644 --- a/src/inference_endpoint/dataset_manager/dataset.py +++ b/src/inference_endpoint/dataset_manager/dataset.py @@ -24,6 +24,7 @@ import numpy as np import pandas as pd + from datasets import load_dataset, load_from_disk from ..config.schema import APIType, ModelParams @@ -411,7 +412,7 @@ def num_samples(self) -> int: @classmethod def get_dataloader( cls, - datasets_dir: Path = Path("datasets"), + datasets_dir: Path = Path("dataset_cache"), num_repeats: int = 1, transforms: list[Transform] | None = None, force_regenerate: bool = False, @@ -431,6 +432,42 @@ def get_dataloader( return cls(df, transforms=transforms, repeats=num_repeats) +class SaltedDataset(Dataset): + """Wraps a loaded Dataset, prepending a unique random salt to each prompt on load_sample(). + + Each call to load_sample() generates a fresh salt, so reused samples (when + n_requests > dataset size) each receive a distinct salt. + """ + + def __init__(self, inner: Dataset) -> None: + # Skip Dataset.__init__ — all state is delegated to inner + self._inner = inner + self.data = inner.data + self.dataframe = None + self.transforms = None + self.repeats = inner.repeats + self.logger = getLogger(__name__) + + def load( + self, + adapter: "HttpRequestAdapter | None" = None, + api_type: APIType | None = None, + model_params: ModelParams | None = None, + force: bool = False, + ) -> None: + pass # Inner dataset already loaded + + def load_sample(self, index: int) -> Any: + data = self._inner.load_sample(index) + if isinstance(data, dict) and "prompt" in data: + salt = os.urandom(8).hex() + return {**data, "prompt": f"[{salt}] {data['prompt']}"} + return data + + def num_samples(self) -> int: + return self._inner.num_samples() + + class EmptyDataset(Dataset): """Empty dataset to be used as performance dataset when running only accuracy tests.""" diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 1c8ad992..78e5b271 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -22,7 +22,6 @@ import asyncio import logging -import os import time import uuid from collections.abc import Callable @@ -44,8 +43,6 @@ logger = logging.getLogger(__name__) -_WARMUP_ENABLED = os.environ.get("ENABLE_WARMUP") == "1" - # --------------------------------------------------------------------------- # Phase configuration @@ -68,6 +65,7 @@ class PhaseConfig: runtime_settings: RuntimeSettings dataset: Dataset phase_type: PhaseType = PhaseType.PERFORMANCE + drain_after: bool = True # --------------------------------------------------------------------------- @@ -274,12 +272,6 @@ async def run(self, phases: list[PhaseConfig]) -> SessionResult: for phase in phases: if self._stop_requested: break - if phase.phase_type == PhaseType.WARMUP and not _WARMUP_ENABLED: - logger.info( - "Skipping warmup phase %s (set ENABLE_WARMUP=1 to enable)", - phase.name, - ) - continue result = await self._run_phase(phase) if result is not None: phase_results.append(result) @@ -333,8 +325,7 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: finally: self._strategy_task = None - # Drain in-flight (skip for warmup — keep concurrency hot) - if phase.phase_type != PhaseType.WARMUP: + if phase.drain_after: await self._drain_inflight(phase_issuer) if phase.phase_type == PhaseType.PERFORMANCE: diff --git a/src/inference_endpoint/testing/echo_server.py b/src/inference_endpoint/testing/echo_server.py index 6555f2e6..a0049d1f 100644 --- a/src/inference_endpoint/testing/echo_server.py +++ b/src/inference_endpoint/testing/echo_server.py @@ -17,12 +17,14 @@ import argparse import asyncio +import inspect import json import logging import threading import time import uuid from abc import abstractmethod +from collections.abc import Awaitable, Callable from aiohttp import web @@ -31,6 +33,8 @@ from inference_endpoint.openai.openai_types_gen import CreateChatCompletionRequest from inference_endpoint.utils.logging import setup_logging +RequestHandler = Callable[[web.Request], web.Response | Awaitable[web.Response]] + class HTTPServer: @property @@ -49,11 +53,17 @@ def stop(self): class EchoServer(HTTPServer): def __init__( - self, *, host: str = "127.0.0.1", port: int = 0, max_osl: int | None = None + self, + *, + host: str = "127.0.0.1", + port: int = 0, + max_osl: int | None = None, + request_handler: RequestHandler | None = None, ): self.host = host self.port = port # If 0, will auto-assign available port self.max_osl = max_osl + self._request_handler = request_handler self._actual_port = None # Store the actual port after binding self.app = None @@ -97,6 +107,15 @@ def get_response(self, request: str) -> str: """ return request + async def _dispatch(self, request: web.Request) -> web.Response | None: + """Call the custom request_handler if set; return None to fall through to defaults.""" + if self._request_handler is None: + return None + result = self._request_handler(request) + if inspect.isawaitable(result): + result = await result + return result + async def _handle_echo_request(self, request: web.Request) -> web.Response: """ Handle a generic HTTP request and return a JSON response that echoes all request details. @@ -105,6 +124,10 @@ async def _handle_echo_request(self, request: web.Request) -> web.Response: Returns a standardized JSON response containing the full request details and a success message. """ + custom = await self._dispatch(request) + if custom is not None: + return custom + # Extract request data endpoint = request.path query_params = dict(request.query) @@ -241,6 +264,9 @@ async def _handle_echo_chat_completions_request( Raises: ValueError: If no messages are present in the request payload. """ + custom = await self._dispatch(request) + if custom is not None: + return custom # Get request body try: diff --git a/tests/conftest.py b/tests/conftest.py index d69d3c75..c70bc4bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,7 +34,11 @@ from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat from inference_endpoint.dataset_manager.transforms import ColumnRemap from inference_endpoint.testing.docker_server import DockerServer -from inference_endpoint.testing.echo_server import EchoServer, HTTPServer +from inference_endpoint.testing.echo_server import ( + EchoServer, + HTTPServer, + RequestHandler, +) logger = logging.getLogger(__name__) # Add src to path for imports @@ -72,38 +76,65 @@ def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: @pytest.fixture(scope="function") -def mock_http_echo_server(): +def mock_http_echo_server_factory(): + """Factory fixture for EchoServer with an optional custom request handler. + + Yields a factory callable: ``factory(handler=None) -> EchoServer``. + + When called with no arguments the server uses the default echo behaviour. + When called with a ``handler`` the handler is invoked for every incoming + request *instead of* the built-in echo logic. The handler may be a plain + (sync) callable or an async coroutine function; in both cases it receives + the ``aiohttp.web.Request`` and must return an ``aiohttp.web.Response``. + + All servers created by the factory are stopped automatically at the end of + the test. + + Example:: + + def test_custom(mock_http_echo_server_factory): + server = mock_http_echo_server_factory() # default echo + custom = mock_http_echo_server_factory( # custom handler + lambda req: web.json_response({"ok": True}) + ) """ - Mock HTTP server that echoes back the request payload in the appropriate format. + servers: list[EchoServer] = [] + + def factory(handler: RequestHandler | None = None) -> EchoServer: + server = EchoServer(port=0, request_handler=handler) + logging.info("Starting mock HTTP echo server") + server.start() + servers.append(server) + return server + + yield factory + + for server in servers: + logging.info("Stopping mock HTTP echo server") + server.stop() + + +@pytest.fixture(scope="function") +def mock_http_echo_server(mock_http_echo_server_factory): + """Mock HTTP server that echoes back the request payload in the appropriate format. This fixture creates a real HTTP server running on localhost that captures any HTTP request and returns the request payload as the response. Useful for testing HTTP clients with real network calls but controlled responses. + For a server with a custom per-request handler, use + ``mock_http_echo_server_factory`` instead. + Returns: - A server instance with URL. + A running ``EchoServer`` instance. + + Example:: - Example: def test_my_http_client(mock_http_echo_server): - server = mock_http_echo_server - # Make real HTTP requests to server.url - # The response will contain the exact payload you sent + # Make real HTTP requests to mock_http_echo_server.url + # The response will contain the exact payload you sent. """ - - # Create and start the server with dynamic port allocation (port=0) - - try: - server = EchoServer(port=0) - logging.info("Starting mock HTTP echo server") - server.start() - yield server - except Exception as e: - logging.error(f"Mock Echo Server error: {e}") - raise RuntimeError(f"Mock Echo Server error: {e}") from e - finally: - logging.info("Stopping mock HTTP echo server") - if server: - server.stop() + yield mock_http_echo_server_factory() @pytest.fixture diff --git a/tests/integration/commands/test_warmup.py b/tests/integration/commands/test_warmup.py new file mode 100644 index 00000000..4622365e --- /dev/null +++ b/tests/integration/commands/test_warmup.py @@ -0,0 +1,435 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for warmup phase behaviour. + +Covers two properties: + +Salt + When ``salt=True`` the server receives distinct prompts for warmup vs perf + (preventing KV-cache reuse). When ``salt=False`` the prompts are identical. + +Drain + When ``drain=True`` all warmup responses complete before perf requests start + (zero concurrent overlap at the server). When ``drain=False`` the perf + phase starts immediately, so the server handles both warmup and perf + requests simultaneously. +""" + +from __future__ import annotations + +import asyncio +import json +import re +import uuid +from pathlib import Path + +import pytest +from aiohttp import web +from inference_endpoint.commands.benchmark.execute import run_benchmark +from inference_endpoint.config.schema import Dataset as ConfigDataset +from inference_endpoint.config.schema import ( + DatasetType, + EndpointConfig, + LoadPattern, + LoadPatternType, + ModelParams, + OfflineBenchmarkConfig, + OfflineSettings, + RuntimeConfig, + StreamingMode, + TestMode, + WarmupConfig, +) +from inference_endpoint.core.types import QueryResult, TextModelOutput +from inference_endpoint.endpoint_client.config import HTTPClientConfig +from inference_endpoint.openai.openai_adapter import OpenAIAdapter + +# ── helpers ────────────────────────────────────────────────────────────────── + +_SALT_RE = re.compile(r"^\[([0-9a-f]{16})\] (.+)$") + +_MINIMAL_CLIENT = HTTPClientConfig( + num_workers=1, warmup_connections=0, max_connections=10 +) + + +def _echo_response(prompt: str) -> dict: + """Build a minimal valid OpenAI chat-completion response.""" + req_id = str(uuid.uuid4()) + result = QueryResult(id=req_id, response_output=TextModelOutput(output=prompt)) + body = OpenAIAdapter.to_endpoint_response(result).model_dump(mode="json") + body["id"] = req_id + return body + + +async def _user_prompt(request: web.Request) -> str: + """Extract the user message content from an OpenAI chat-completion request.""" + body = await request.json() + for msg in body.get("messages", []): + if msg.get("role") == "user": + content = msg.get("content", "") + return str(content) if content is not None else "" + return "" + + +def _offline_config( + endpoint_url: str, + dataset_path: str | Path, + warmup: WarmupConfig, + n_perf_samples: int = 5, +) -> OfflineBenchmarkConfig: + return OfflineBenchmarkConfig( + endpoint_config=EndpointConfig(endpoints=[endpoint_url]), + model_params=ModelParams(name="test-model", streaming=StreamingMode.OFF), + datasets=[ConfigDataset(path=str(dataset_path), type=DatasetType.PERFORMANCE)], + settings=OfflineSettings( + runtime=RuntimeConfig(min_duration_ms=0, n_samples_to_issue=n_perf_samples), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + client=_MINIMAL_CLIENT, + warmup=warmup, + ), + ) + + +# ── fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def single_prompt_dataset(tmp_path: Path) -> Path: + """JSONL with one sample: ``{"prompt": "hello world"}``.""" + f = tmp_path / "single.jsonl" + f.write_text(json.dumps({"prompt": "hello world"}) + "\n") + return f + + +@pytest.fixture +def multi_prompt_dataset(tmp_path: Path) -> Path: + """JSONL with five distinct samples.""" + f = tmp_path / "multi.jsonl" + f.write_text( + "\n".join(json.dumps({"prompt": f"prompt_{i}"}) for i in range(5)) + "\n" + ) + return f + + +# ── salt tests ──────────────────────────────────────────────────────────────── + + +@pytest.mark.integration +class TestWarmupSalt: + """Server-side prompt observations when salt is on vs off.""" + + def test_salt_enabled_warmup_prompts_differ_from_perf( + self, + mock_http_echo_server_factory, + single_prompt_dataset: Path, + ): + """With salt=True the server sees a distinct prompt for warmup vs perf. + + The warmup request carries a ``[<16-hex>] `` prefix; the perf request + carries the raw prompt. The base content after stripping the salt must + match the perf prompt. + """ + received: list[str] = [] + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + received.append(prompt) + return web.json_response(_echo_response(prompt)) + + server = mock_http_echo_server_factory(handler) + config = _offline_config( + server.url, + single_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=1, salt=True, drain=True), + n_perf_samples=1, + ) + run_benchmark(config, TestMode.PERF) + + assert ( + len(received) == 2 + ), f"Expected 2 requests (1 warmup + 1 perf), got: {received}" + + warmup_prompts = [p for p in received if _SALT_RE.match(p)] + perf_prompts = [p for p in received if not _SALT_RE.match(p)] + assert len(warmup_prompts) == 1, "Expected exactly 1 salted (warmup) prompt" + assert len(perf_prompts) == 1, "Expected exactly 1 unsalted (perf) prompt" + + # Base content after stripping salt must equal the perf prompt + m = _SALT_RE.match(warmup_prompts[0]) + assert m is not None + assert ( + m.group(2) == perf_prompts[0] + ), f"Stripped warmup prompt {m.group(2)!r} != perf prompt {perf_prompts[0]!r}" + + def test_salt_disabled_server_sees_identical_prompts( + self, + mock_http_echo_server_factory, + single_prompt_dataset: Path, + ): + """With salt=False warmup and perf send the exact same prompt. + + A KV cache on the real server would reuse the cached result, defeating + the purpose of warming up unique sequences. + """ + received: list[str] = [] + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + received.append(prompt) + return web.json_response(_echo_response(prompt)) + + server = mock_http_echo_server_factory(handler) + config = _offline_config( + server.url, + single_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=1, salt=False, drain=True), + n_perf_samples=1, + ) + run_benchmark(config, TestMode.PERF) + + assert len(received) == 2, f"Expected 2 requests, got: {received}" + assert ( + received[0] == received[1] + ), f"Expected identical prompts with salt=False, got: {received}" + assert not any( + _SALT_RE.match(p) for p in received + ), "No prompt should have a salt prefix when salt=False" + + def test_salt_count_matches_n_requests( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """Exactly ``n_requests`` salted prompts reach the server.""" + received: list[str] = [] + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + received.append(prompt) + return web.json_response(_echo_response(prompt)) + + server = mock_http_echo_server_factory(handler) + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=3, salt=True, drain=True), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + warmup_count = sum(1 for p in received if _SALT_RE.match(p)) + assert ( + warmup_count == 3 + ), f"Expected 3 warmup (salted) prompts, got {warmup_count} from: {received}" + + def test_each_salted_warmup_prompt_is_unique( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """Every warmup request has a distinct salt even when the same sample is reused. + + With n_requests (10) > dataset size (5) samples are cycled; without + unique salts the same raw text would appear twice, allowing cache reuse. + """ + received: list[str] = [] + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + received.append(prompt) + return web.json_response(_echo_response(prompt)) + + server = mock_http_echo_server_factory(handler) + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=10, salt=True, drain=True), + n_perf_samples=1, + ) + run_benchmark(config, TestMode.PERF) + + warmup_prompts = [p for p in received if _SALT_RE.match(p)] + assert len(warmup_prompts) == 10 + assert ( + len(set(warmup_prompts)) == 10 + ), "Expected all salted warmup prompts to be unique (distinct salt per request)" + + +# ── drain tests ─────────────────────────────────────────────────────────────── + + +@pytest.mark.integration +class TestWarmupDrain: + """Concurrency overlap between warmup and perf at the server. + + Strategy + -------- + * Use a slow server (``_DELAY`` seconds per response) so that warmup + requests stay in-flight long enough for the perf phase to start. + * Use ``salt=True`` to identify warmup vs perf requests at the server. + * Track concurrent in-flight counts inside the server handler. Since all + handlers share a single asyncio event loop, plain Python ints are safe + (no actual concurrent mutation — only cooperative interleaving at + ``await`` points). + * ``run_benchmark`` is synchronous and blocks until the perf phase drains, + guaranteeing all overlap events have been recorded before we assert. + """ + + _DELAY = 0.15 # seconds; long enough for perf to start while warmup is pending + + def _make_handler(self, delay: float): + """Return ``(handler_coro, state)``. + + ``state`` keys after ``run_benchmark`` returns: + + * ``overlap`` – ``True`` if warmup and perf were in-flight simultaneously + * ``max_concurrent`` – peak total in-flight request count + """ + state: dict = { + "warmup_inflight": 0, + "perf_inflight": 0, + "overlap": False, + "max_concurrent": 0, + } + + async def handler(req: web.Request) -> web.Response: + prompt = await _user_prompt(req) + is_warmup = bool(_SALT_RE.match(prompt)) + + if is_warmup: + state["warmup_inflight"] += 1 + else: + state["perf_inflight"] += 1 + + total = state["warmup_inflight"] + state["perf_inflight"] + if total > state["max_concurrent"]: + state["max_concurrent"] = total + + if state["warmup_inflight"] > 0 and state["perf_inflight"] > 0: + state["overlap"] = True + + await asyncio.sleep(delay) + + if is_warmup: + state["warmup_inflight"] -= 1 + else: + state["perf_inflight"] -= 1 + + return web.json_response(_echo_response(prompt)) + + return handler, state + + def test_drain_true_no_overlap( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """With ``drain=True`` the server never handles warmup and perf simultaneously. + + Timeline: all warmup responses arrive → perf phase starts → no overlap. + """ + handler, state = self._make_handler(self._DELAY) + server = mock_http_echo_server_factory(handler) + + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=5, salt=True, drain=True), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + assert not state[ + "overlap" + ], "With drain=True warmup and perf must not be in-flight at the same time" + + def test_drain_false_overlap_observed( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """With ``drain=False`` perf requests start before warmup responses arrive. + + Timeline: warmup issues 5 requests at MAX_THROUGHPUT → immediately + perf phase starts → both sets in-flight → overlap detected. + """ + handler, state = self._make_handler(self._DELAY) + server = mock_http_echo_server_factory(handler) + + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=5, salt=True, drain=False), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + assert state[ + "overlap" + ], "With drain=False perf requests should start while warmup is in-flight" + + def test_drain_true_max_concurrent_bounded_by_phase_size( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """With ``drain=True`` at most one phase worth of requests is in-flight. + + Both warmup and perf have 5 requests; phases never overlap so the peak + concurrent count is at most 5. + """ + handler, state = self._make_handler(self._DELAY) + server = mock_http_echo_server_factory(handler) + + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=5, salt=True, drain=True), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + assert state["max_concurrent"] <= 5, ( + f"With drain=True max concurrent {state['max_concurrent']} should be ≤ 5 " + "(one phase at a time)" + ) + + def test_drain_false_max_concurrent_exceeds_single_phase_size( + self, + mock_http_echo_server_factory, + multi_prompt_dataset: Path, + ): + """With ``drain=False`` the server concurrently handles requests from both phases. + + 5 warmup requests stay in-flight while 5 perf requests arrive, so the + server reaches a peak of more than 5 simultaneous requests. + """ + handler, state = self._make_handler(self._DELAY) + server = mock_http_echo_server_factory(handler) + + config = _offline_config( + server.url, + multi_prompt_dataset, + warmup=WarmupConfig(enabled=True, n_requests=5, salt=True, drain=False), + n_perf_samples=5, + ) + run_benchmark(config, TestMode.PERF) + + assert state["max_concurrent"] > 5, ( + f"With drain=False peak concurrent {state['max_concurrent']} should be > 5 " + "(warmup + perf requests overlap)" + ) diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index c664234f..241451fb 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -27,6 +27,7 @@ online, ) from inference_endpoint.commands.benchmark.execute import ResponseCollector +from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( BenchmarkConfig, DatasetType, @@ -38,6 +39,7 @@ StreamingMode, TestMode, TestType, + WarmupConfig, ) from inference_endpoint.config.schema import ( OfflineBenchmarkConfig as OfflineConfig, @@ -385,6 +387,355 @@ def test_valid_templates_parse(self, template): assert config.endpoint_config.endpoints +class TestWarmupConfig: + """Tests for WarmupConfig schema model.""" + + @pytest.mark.unit + def test_defaults(self): + cfg = WarmupConfig() + assert cfg.enabled is False + assert cfg.n_requests is None + assert cfg.salt is False + assert cfg.drain is False + + @pytest.mark.unit + @pytest.mark.parametrize("n", [1, 10, 1000]) + def test_n_requests_valid(self, n): + cfg = WarmupConfig(n_requests=n) + assert cfg.n_requests == n + + @pytest.mark.unit + @pytest.mark.parametrize("n", [0, -1, -100]) + def test_n_requests_must_be_positive(self, n): + with pytest.raises(ValidationError): + WarmupConfig(n_requests=n) + + @pytest.mark.unit + def test_extra_fields_rejected(self): + with pytest.raises(ValidationError): + WarmupConfig(unknown_field=True) + + @pytest.mark.unit + def test_immutable(self): + cfg = WarmupConfig() + with pytest.raises(ValidationError): + cfg.enabled = True # type: ignore[misc] + + @pytest.mark.unit + def test_all_flags_enabled(self): + cfg = WarmupConfig(enabled=True, n_requests=50, salt=True, drain=True) + assert cfg.enabled is True + assert cfg.n_requests == 50 + assert cfg.salt is True + assert cfg.drain is True + + @pytest.mark.unit + def test_yaml_roundtrip(self, tmp_path): + yaml_content = """ +type: "offline" +model_params: + name: "test-model" +endpoint_config: + endpoints: ["http://test:8000"] +datasets: + - path: "test.jsonl" +settings: + warmup: + enabled: true + n_requests: 20 + salt: true + drain: true +""" + config_file = tmp_path / "warmup.yaml" + config_file.write_text(yaml_content) + config = BenchmarkConfig.from_yaml_file(config_file) + warmup = config.settings.warmup + assert warmup.enabled is True + assert warmup.n_requests == 20 + assert warmup.salt is True + assert warmup.drain is True + + @pytest.mark.unit + def test_warmup_default_in_settings(self): + config = OfflineConfig(**_OFFLINE_KWARGS) + warmup = config.settings.warmup + assert warmup.enabled is False + assert warmup.n_requests is None + + +class TestBuildPhases: + """Tests for _build_phases() in execute.py.""" + + @pytest.fixture + def base_rt_settings(self): + import random + + from inference_endpoint.config.schema import LoadPattern, LoadPatternType + from inference_endpoint.metrics.metric import Throughput + + return RuntimeSettings( + metric_target=Throughput(10.0), + reported_metrics=[Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=5, + n_samples_to_issue=None, + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + + @pytest.fixture + def simple_dataset(self): + import pandas as pd + from inference_endpoint.dataset_manager.dataset import Dataset + + df = pd.DataFrame({"prompt": [f"q{i}" for i in range(5)]}) + ds = Dataset(df) + ds.load() + return ds + + def _make_ctx(self, config, rt_settings, dataloader): + from pathlib import Path + + from inference_endpoint.commands.benchmark.execute import BenchmarkContext + from inference_endpoint.config.schema import TestMode + + return BenchmarkContext( + config=config, + test_mode=TestMode.PERF, + report_dir=Path("/tmp"), + tokenizer_name=None, + dataloader=dataloader, + rt_settings=rt_settings, + total_samples=dataloader.num_samples(), + accuracy_datasets=[], + eval_configs=[], + ) + + @pytest.mark.unit + def test_warmup_disabled_produces_only_perf_phase( + self, base_rt_settings, simple_dataset + ): + from inference_endpoint.commands.benchmark.execute import _build_phases + from inference_endpoint.load_generator.session import PhaseType + + config = OfflineConfig(**_OFFLINE_KWARGS) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert len(phases) == 1 + assert phases[0].phase_type == PhaseType.PERFORMANCE + + @pytest.mark.unit + def test_warmup_enabled_produces_two_phases(self, base_rt_settings, simple_dataset): + from inference_endpoint.commands.benchmark.execute import _build_phases + from inference_endpoint.load_generator.session import PhaseType + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert len(phases) == 2 + assert phases[0].phase_type == PhaseType.WARMUP + assert phases[1].phase_type == PhaseType.PERFORMANCE + + @pytest.mark.unit + def test_warmup_phase_named_warmup(self, base_rt_settings, simple_dataset): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].name == "warmup" + + @pytest.mark.unit + def test_warmup_phase_uses_max_throughput(self, base_rt_settings, simple_dataset): + from inference_endpoint.commands.benchmark.execute import _build_phases + from inference_endpoint.config.schema import LoadPatternType + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + warmup_rt = phases[0].runtime_settings + assert warmup_rt.load_pattern is not None + assert warmup_rt.load_pattern.type == LoadPatternType.MAX_THROUGHPUT + + @pytest.mark.unit + def test_warmup_phase_min_duration_is_zero(self, base_rt_settings, simple_dataset): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].runtime_settings.min_duration_ms == 0 + + @pytest.mark.unit + def test_warmup_phase_no_max_duration(self, base_rt_settings, simple_dataset): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].runtime_settings.max_duration_ms is None + + @pytest.mark.unit + def test_warmup_n_requests_propagated(self, base_rt_settings, simple_dataset): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, n_requests=7)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].runtime_settings.n_samples_to_issue == 7 + + @pytest.mark.unit + def test_warmup_n_requests_none_when_unset(self, base_rt_settings, simple_dataset): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings( + warmup=WarmupConfig(enabled=True, n_requests=None) + ), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].runtime_settings.n_samples_to_issue is None + + @pytest.mark.unit + def test_warmup_without_salt_uses_raw_dataloader( + self, base_rt_settings, simple_dataset + ): + from inference_endpoint.commands.benchmark.execute import _build_phases + from inference_endpoint.dataset_manager.dataset import SaltedDataset + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, salt=False)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert not isinstance(phases[0].dataset, SaltedDataset) + assert phases[0].dataset is simple_dataset + + @pytest.mark.unit + def test_warmup_with_salt_uses_salted_dataset( + self, base_rt_settings, simple_dataset + ): + from inference_endpoint.commands.benchmark.execute import _build_phases + from inference_endpoint.dataset_manager.dataset import SaltedDataset + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, salt=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert isinstance(phases[0].dataset, SaltedDataset) + + @pytest.mark.unit + def test_warmup_drain_false_by_default(self, base_rt_settings, simple_dataset): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, drain=False)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].drain_after is False + + @pytest.mark.unit + def test_warmup_drain_true_propagated(self, base_rt_settings, simple_dataset): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, drain=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[0].drain_after is True + + @pytest.mark.unit + def test_warmup_n_samples_from_dataset_matches_dataloader( + self, base_rt_settings, simple_dataset + ): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert ( + phases[0].runtime_settings.n_samples_from_dataset + == simple_dataset.num_samples() + ) + + @pytest.mark.unit + def test_performance_phase_dataset_is_always_raw_dataloader( + self, base_rt_settings, simple_dataset + ): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, salt=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + perf_phase = phases[1] + assert perf_phase.dataset is simple_dataset + + @pytest.mark.unit + def test_performance_phase_uses_original_rt_settings( + self, base_rt_settings, simple_dataset + ): + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + assert phases[1].runtime_settings is base_rt_settings + + class TestScorerMethodSync: """Ensure ScorerMethod enum stays in sync with the scorer registry.""" diff --git a/tests/unit/dataset_manager/test_salted_dataset.py b/tests/unit/dataset_manager/test_salted_dataset.py new file mode 100644 index 00000000..1e7488c0 --- /dev/null +++ b/tests/unit/dataset_manager/test_salted_dataset.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for SaltedDataset.""" + +import re +from unittest.mock import MagicMock + +import pandas as pd +import pytest +from inference_endpoint.dataset_manager.dataset import Dataset, SaltedDataset + + +def _make_loaded_dataset(rows: list[dict]) -> Dataset: + """Return a Dataset with .data already populated (no file I/O).""" + ds = Dataset.__new__(Dataset) + ds.dataframe = None + ds.transforms = None + ds.repeats = 1 + ds.data = list(rows) + ds.logger = MagicMock() + return ds + + +@pytest.mark.unit +class TestSaltedDatasetDelegation: + """SaltedDataset correctly mirrors inner-dataset properties.""" + + def test_num_samples_delegates_to_inner(self): + inner = _make_loaded_dataset( + [{"prompt": "a"}, {"prompt": "b"}, {"prompt": "c"}] + ) + sd = SaltedDataset(inner) + assert sd.num_samples() == 3 + + def test_data_attribute_is_inner_data(self): + inner = _make_loaded_dataset([{"prompt": "x"}]) + sd = SaltedDataset(inner) + assert sd.data is inner.data + + def test_repeats_matches_inner(self): + inner = _make_loaded_dataset([{"prompt": "x"}]) + inner.repeats = 5 + sd = SaltedDataset(inner) + assert sd.repeats == 5 + + def test_load_is_noop(self): + inner = _make_loaded_dataset([{"prompt": "x"}]) + sd = SaltedDataset(inner) + sd.load() # must not raise + assert sd.data is inner.data # data unchanged after load() + + +@pytest.mark.unit +class TestSaltedDatasetSaltBehavior: + """Salt is injected correctly into prompts.""" + + _SALT_RE = re.compile(r"^\[([0-9a-f]{16})\] (.+)$") + + def test_prompt_prefixed_with_salt(self): + inner = _make_loaded_dataset([{"prompt": "hello world"}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + assert self._SALT_RE.match( + result["prompt"] + ), f"Expected '[<16-hex>] hello world', got: {result['prompt']!r}" + + def test_salt_is_exactly_16_hex_chars(self): + inner = _make_loaded_dataset([{"prompt": "test"}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + m = self._SALT_RE.match(result["prompt"]) + assert m is not None + assert len(m.group(1)) == 16 + + def test_original_prompt_preserved_after_salt(self): + inner = _make_loaded_dataset([{"prompt": "my question"}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + m = self._SALT_RE.match(result["prompt"]) + assert m is not None + assert m.group(2) == "my question" + + def test_salt_unique_across_calls_same_index(self): + inner = _make_loaded_dataset([{"prompt": "repeated"}]) + sd = SaltedDataset(inner) + salts = {sd.load_sample(0)["prompt"][:18] for _ in range(20)} + # With 8 bytes of randomness we expect all 20 samples to be distinct + assert len(salts) == 20 + + def test_salt_unique_across_different_indices(self): + inner = _make_loaded_dataset([{"prompt": "a"}, {"prompt": "b"}]) + sd = SaltedDataset(inner) + prompt0 = sd.load_sample(0)["prompt"] + prompt1 = sd.load_sample(1)["prompt"] + # Salts should differ (original prompts differ, salts are random) + assert prompt0 != prompt1 + + def test_other_fields_unchanged(self): + inner = _make_loaded_dataset( + [{"prompt": "hi", "system": "you are helpful", "extra": 42}] + ) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + assert result["system"] == "you are helpful" + assert result["extra"] == 42 + + def test_original_dict_not_mutated(self): + row = {"prompt": "original"} + inner = _make_loaded_dataset([row]) + sd = SaltedDataset(inner) + sd.load_sample(0) + assert inner.data[0]["prompt"] == "original" + + +@pytest.mark.unit +class TestSaltedDatasetPassthrough: + """Samples without a 'prompt' key, or non-dict samples, are passed through unchanged.""" + + def test_dict_without_prompt_key_is_unchanged(self): + inner = _make_loaded_dataset([{"question": "what is 2+2?", "answer": "4"}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + assert result == {"question": "what is 2+2?", "answer": "4"} + + def test_empty_dict_is_unchanged(self): + inner = _make_loaded_dataset([{}]) + sd = SaltedDataset(inner) + assert sd.load_sample(0) == {} + + def test_non_dict_sample_is_returned_as_is(self): + inner = _make_loaded_dataset([{"prompt": "x"}]) + inner.data = ["raw string sample"] # override with non-dict + sd = SaltedDataset(inner) + sd.data = inner.data + assert sd.load_sample(0) == "raw string sample" + + +@pytest.mark.unit +class TestSaltedDatasetWithRealDataset: + """Integration-style: SaltedDataset wrapping a real Dataset loaded from a DataFrame.""" + + @pytest.fixture + def loaded_inner(self): + df = pd.DataFrame( + { + "prompt": ["What is AI?", "Explain gradient descent"], + "category": ["general", "ml"], + } + ) + ds = Dataset(df) + ds.load() + return ds + + def test_wraps_real_dataset_correctly(self, loaded_inner): + sd = SaltedDataset(loaded_inner) + assert sd.num_samples() == 2 + + def test_real_dataset_prompts_are_salted(self, loaded_inner): + sd = SaltedDataset(loaded_inner) + for i in range(sd.num_samples()): + result = sd.load_sample(i) + assert result["prompt"].startswith("[") + assert "] " in result["prompt"] + + def test_category_field_preserved(self, loaded_inner): + sd = SaltedDataset(loaded_inner) + assert sd.load_sample(0)["category"] == "general" + assert sd.load_sample(1)["category"] == "ml" + + def test_all_salts_unique_across_full_dataset(self, loaded_inner): + sd = SaltedDataset(loaded_inner) + prompts = [sd.load_sample(i)["prompt"] for i in range(sd.num_samples())] + assert len(set(prompts)) == len(prompts) diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index 38dd014e..670db9cc 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -20,7 +20,6 @@ import asyncio import random -import inference_endpoint.load_generator.session as _session_mod import pytest from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import LoadPattern, LoadPatternType @@ -42,13 +41,6 @@ ) from inference_endpoint.metrics.metric import Throughput - -@pytest.fixture(autouse=False) -def enable_warmup(monkeypatch): - """Enable warmup phases for tests that use PhaseType.WARMUP.""" - monkeypatch.setattr(_session_mod, "_WARMUP_ENABLED", True) - - # --------------------------------------------------------------------------- # Test doubles # --------------------------------------------------------------------------- @@ -251,7 +243,7 @@ async def test_accuracy_phase(self): ) @pytest.mark.asyncio - async def test_warmup_produces_no_result(self, enable_warmup): + async def test_warmup_produces_no_result(self): loop = asyncio.get_running_loop() issuer = FakeIssuer() issuer._loop = loop @@ -270,7 +262,7 @@ async def test_warmup_produces_no_result(self, enable_warmup): assert len(result.phase_results) == 0 @pytest.mark.asyncio - async def test_multi_phase(self, enable_warmup): + async def test_multi_phase(self): loop = asyncio.get_running_loop() issuer = FakeIssuer() issuer._loop = loop @@ -346,7 +338,7 @@ def on_complete(result: QueryResult) -> None: assert len(completed) == 5 @pytest.mark.asyncio - async def test_stale_completions_ignored_by_strategy(self, enable_warmup): + async def test_stale_completions_ignored_by_strategy(self): """Responses from warmup phase should not affect perf phase strategy.""" loop = asyncio.get_running_loop() publisher = FakePublisher() @@ -737,7 +729,7 @@ class TestBenchmarkSessionMultiPhaseSatPerfSequence: """Multi-perf + warmup sequence (sat -> perf -> sat -> perf).""" @pytest.mark.asyncio - async def test_sat_perf_sat_perf(self, enable_warmup): + async def test_sat_perf_sat_perf(self): loop = asyncio.get_running_loop() issuer = FakeIssuer() issuer._loop = loop @@ -798,7 +790,7 @@ class TestBenchmarkSessionStaleStreamChunk: """Stale StreamChunk from previous phase is ignored.""" @pytest.mark.asyncio - async def test_stale_stream_chunk_ignored(self, enable_warmup): + async def test_stale_stream_chunk_ignored(self): """StreamChunk from warmup phase should not affect perf phase counts.""" loop = asyncio.get_running_loop() publisher = FakePublisher() @@ -871,7 +863,7 @@ async def inject_responses(): @pytest.mark.unit class TestSessionResult: - def test_perf_results_filter(self, enable_warmup): + def test_perf_results_filter(self): results = [ PhaseResult("sat", PhaseType.WARMUP, {}, 0, 0, 0), PhaseResult("perf1", PhaseType.PERFORMANCE, {"a": 1}, 10, 0, 100), diff --git a/tests/unit/test_http_mock_fixtures.py b/tests/unit/test_http_mock_fixtures.py index 8897b81c..786f28dd 100644 --- a/tests/unit/test_http_mock_fixtures.py +++ b/tests/unit/test_http_mock_fixtures.py @@ -20,10 +20,12 @@ with real HTTP server that echoes requests back. """ +import json import logging import aiohttp import pytest +from aiohttp import web from inference_endpoint.core.types import Query, TextModelOutput from inference_endpoint.openai.openai_adapter import OpenAIAdapter from inference_endpoint.openai.openai_types_gen import CreateChatCompletionResponse @@ -135,3 +137,112 @@ async def test_real_http_server_post_request_with_max_osl( assert len(str(response.response_output)) == 5 mock_http_echo_server.set_max_osl(old_max_osl) + + +@pytest.mark.unit +class TestHttpEchoServerFactory: + """Tests for the mock_http_echo_server_factory fixture.""" + + @pytest.mark.asyncio + async def test_factory_default_is_echo(self, mock_http_echo_server_factory): + """Factory with no handler behaves identically to mock_http_echo_server.""" + server = mock_http_echo_server_factory() + async with aiohttp.ClientSession() as session: + payload = {"query": "hello"} + async with session.post(f"{server.url}/echo", json=payload) as resp: + assert resp.status == 200 + data = await resp.json() + assert data["echo"] is True + assert data["request"]["json_payload"] == payload + + @pytest.mark.asyncio + async def test_factory_sync_handler_overrides_response( + self, mock_http_echo_server_factory + ): + """A plain (sync) lambda completely replaces the built-in handler.""" + server = mock_http_echo_server_factory( + lambda req: web.json_response({"custom": True}, status=201) + ) + async with aiohttp.ClientSession() as session: + async with session.post( + f"{server.url}/v1/chat/completions", + json={"model": "gpt-4", "messages": []}, + ) as resp: + assert resp.status == 201 + data = await resp.json() + assert data == {"custom": True} + + @pytest.mark.asyncio + async def test_factory_async_handler_overrides_response( + self, mock_http_echo_server_factory + ): + """An async handler is awaited and its response is returned.""" + + async def async_handler(request: web.Request) -> web.Response: + body = await request.json() + return web.json_response({"received_model": body.get("model")}) + + server = mock_http_echo_server_factory(async_handler) + async with aiohttp.ClientSession() as session: + async with session.post( + f"{server.url}/v1/chat/completions", + json={"model": "my-model", "messages": []}, + ) as resp: + assert resp.status == 200 + data = await resp.json() + assert data == {"received_model": "my-model"} + + @pytest.mark.asyncio + async def test_factory_handler_receives_request_body( + self, mock_http_echo_server_factory + ): + """Handler can inspect the raw request body for assertions.""" + captured: list[bytes] = [] + + async def capturing_handler(request: web.Request) -> web.Response: + captured.append(await request.read()) + return web.json_response({"ok": True}) + + server = mock_http_echo_server_factory(capturing_handler) + payload = {"hello": "world"} + async with aiohttp.ClientSession() as session: + async with session.post( + f"{server.url}/v1/chat/completions", json=payload + ) as resp: + assert resp.status == 200 + + assert len(captured) == 1 + assert json.loads(captured[0]) == payload + + @pytest.mark.asyncio + async def test_factory_handler_applies_to_echo_route_too( + self, mock_http_echo_server_factory + ): + """Custom handler intercepts /echo requests as well.""" + server = mock_http_echo_server_factory( + lambda req: web.json_response({"intercepted": True}) + ) + async with aiohttp.ClientSession() as session: + async with session.post(f"{server.url}/echo", json={}) as resp: + assert resp.status == 200 + assert (await resp.json()) == {"intercepted": True} + + @pytest.mark.asyncio + async def test_factory_creates_independent_servers( + self, mock_http_echo_server_factory + ): + """Two servers from the same factory operate independently.""" + echo_server = mock_http_echo_server_factory() + custom_server = mock_http_echo_server_factory( + lambda req: web.json_response({"server": "custom"}) + ) + assert echo_server.url != custom_server.url + + async with aiohttp.ClientSession() as session: + async with session.post(f"{custom_server.url}/echo", json={}) as resp: + data = await resp.json() + assert data == {"server": "custom"} + + async with session.post(f"{echo_server.url}/echo", json={"x": 1}) as resp: + data = await resp.json() + assert data["echo"] is True diff --git a/uv.lock b/uv.lock index ed84bd7e..9017350e 100644 --- a/uv.lock +++ b/uv.lock @@ -877,7 +877,7 @@ requires-dist = [ { name = "sphinx-autodoc-typehints", marker = "extra == 'dev'", specifier = "==3.9.11" }, { name = "sphinx-rtd-theme", marker = "extra == 'dev'", specifier = "==3.1.0" }, { name = "sqlalchemy", marker = "extra == 'sql'", specifier = "==2.0.48" }, - { name = "transformers", specifier = "==5.4.0" }, + { name = "transformers", specifier = "==5.5.0" }, { name = "typing-extensions", specifier = "==4.15.0" }, { name = "uvloop", specifier = "==0.22.1" }, { name = "websocket-client", specifier = "==1.9.0" }, @@ -2403,7 +2403,7 @@ wheels = [ [[package]] name = "transformers" -version = "5.4.0" +version = "5.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -2416,9 +2416,9 @@ dependencies = [ { name = "tqdm", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "typer", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0b/4c/42a8e1c7bbe668d8e073941ec3205263afb1cd02683fa5a8a75e615fdfbe/transformers-5.4.0.tar.gz", hash = "sha256:cb34ca89dce345ae3224b290346b9c0fa9694b951d54f3ed16334a4b1bfe3d04", size = 8152836, upload-time = "2026-03-27T00:24:24.692Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/9d/fb46e729b461985f41a5740167688b924a4019141e5c164bea77548d3d9e/transformers-5.5.0.tar.gz", hash = "sha256:c8db656cf51c600cd8c75f06b20ef85c72e8b8ff9abc880c5d3e8bc70e0ddcbd", size = 8237745, upload-time = "2026-04-02T16:13:08.113Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0b/a0/0a87883e564e364baab32adcacb4bec2e200b28a568423c8cf7fde316461/transformers-5.4.0-py3-none-any.whl", hash = "sha256:9fbe50602d2a4e6d0aa8a35a605433dfac72d595ee2192eae192590a6cc2df86", size = 10105556, upload-time = "2026-03-27T00:24:21.735Z" }, + { url = "https://files.pythonhosted.org/packages/e7/28/35f7411ff80a3640c1f4fc907dcbb6a65061ebb82f66950e38bfc9f7f740/transformers-5.5.0-py3-none-any.whl", hash = "sha256:821a9ff0961abbb29eb1eb686d78df1c85929fdf213a3fe49dc6bd94f9efa944", size = 10245591, upload-time = "2026-04-02T16:13:03.462Z" }, ] [[package]] From 56258beeabb9b822d2479fbabe63d3e5edee860f Mon Sep 17 00:00:00 2001 From: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> Date: Mon, 4 May 2026 19:37:11 -0700 Subject: [PATCH 2/5] Missing changes. Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> --- .../commands/benchmark/execute.py | 5 +-- .../dataset_manager/dataset.py | 23 ++++++++++--- .../load_generator/session.py | 7 +++- tests/unit/commands/test_benchmark.py | 23 +++++++++++++ .../dataset_manager/test_salted_dataset.py | 32 +++++++++++++++++++ 5 files changed, 82 insertions(+), 8 deletions(-) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 2fb7a8e8..cc4fc4f1 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -27,6 +27,7 @@ import json import logging import platform +import random import shutil import signal import tempfile @@ -361,8 +362,8 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: n_samples_from_dataset=ctx.dataloader.num_samples(), n_samples_to_issue=warmup_cfg.n_requests, min_sample_count=1, - rng_sched=ctx.rt_settings.rng_sched, - rng_sample_index=ctx.rt_settings.rng_sample_index, + rng_sched=random.Random(), + rng_sample_index=random.Random(), load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), ) phases.append( diff --git a/src/inference_endpoint/dataset_manager/dataset.py b/src/inference_endpoint/dataset_manager/dataset.py index 98d45c76..ae6b234c 100644 --- a/src/inference_endpoint/dataset_manager/dataset.py +++ b/src/inference_endpoint/dataset_manager/dataset.py @@ -24,7 +24,6 @@ import numpy as np import pandas as pd - from datasets import load_dataset, load_from_disk from ..config.schema import APIType, ModelParams @@ -459,10 +458,24 @@ def load( def load_sample(self, index: int) -> Any: data = self._inner.load_sample(index) - if isinstance(data, dict) and "prompt" in data: - salt = os.urandom(8).hex() - return {**data, "prompt": f"[{salt}] {data['prompt']}"} - return data + if not (isinstance(data, dict) and "prompt" in data): + return data + prompt = data["prompt"] + salt = os.urandom(8).hex() + if isinstance(prompt, str): + return {**data, "prompt": f"[{salt}] {prompt}"} + if ( + isinstance(prompt, list) + and prompt + and isinstance(prompt[0], dict) + and prompt[0].get("type") == "text" + ): + salted_parts = [ + {**prompt[0], "text": f"[{salt}] {prompt[0]['text']}"}, + *prompt[1:], + ] + return {**data, "prompt": salted_parts} + return data # unsupported prompt type — skip salting def num_samples(self) -> int: return self._inner.num_samples() diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 78e5b271..2184512c 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -240,6 +240,7 @@ def __init__( self._stop_requested = False self._done = False self._current_phase_issuer: PhaseIssuer | None = None + self._current_phase_type: PhaseType | None = None self._current_strategy: LoadStrategy | None = None self._recv_task: asyncio.Task | None = None self._strategy_task: asyncio.Task | None = None @@ -310,6 +311,7 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: ) self._current_phase_issuer = phase_issuer + self._current_phase_type = phase.phase_type self._current_strategy = strategy # Performance phases get tracking events @@ -416,7 +418,10 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: self._drain_event.set() if self._current_strategy: self._current_strategy.on_query_complete(query_id) - if self._on_sample_complete: + if ( + self._on_sample_complete + and self._current_phase_type != PhaseType.WARMUP + ): self._on_sample_complete(resp) elif isinstance(resp, StreamChunk): diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index 241451fb..05f106f2 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -735,6 +735,29 @@ def test_performance_phase_uses_original_rt_settings( assert phases[1].runtime_settings is base_rt_settings + @pytest.mark.unit + def test_warmup_uses_independent_rng_instances( + self, base_rt_settings, simple_dataset + ): + """Warmup RuntimeSettings must not share RNG instances with the perf phase. + + Sharing would cause warmup sample-ordering to consume state from the + perf phase's deterministic random sequence, breaking reproducibility. + """ + from inference_endpoint.commands.benchmark.execute import _build_phases + + config = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), + ) + ctx = self._make_ctx(config, base_rt_settings, simple_dataset) + phases = _build_phases(ctx) + + warmup_rt = phases[0].runtime_settings + perf_rt = phases[1].runtime_settings + assert warmup_rt.rng_sched is not perf_rt.rng_sched + assert warmup_rt.rng_sample_index is not perf_rt.rng_sample_index + class TestScorerMethodSync: """Ensure ScorerMethod enum stays in sync with the scorer registry.""" diff --git a/tests/unit/dataset_manager/test_salted_dataset.py b/tests/unit/dataset_manager/test_salted_dataset.py index 1e7488c0..782a8d44 100644 --- a/tests/unit/dataset_manager/test_salted_dataset.py +++ b/tests/unit/dataset_manager/test_salted_dataset.py @@ -147,6 +147,38 @@ def test_non_dict_sample_is_returned_as_is(self): sd.data = inner.data assert sd.load_sample(0) == "raw string sample" + def test_multimodal_list_prompt_first_text_part_is_salted(self): + """Salt is injected into the 'text' field of the first content part.""" + content_parts = [ + {"type": "text", "text": "describe this image"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ] + inner = _make_loaded_dataset([{"prompt": content_parts}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + parts = result["prompt"] + assert isinstance(parts, list) + assert len(parts) == 2 + # First text part must carry a salt prefix + assert re.match(r"^\[([0-9a-f]{16})\] describe this image$", parts[0]["text"]) + # Image part must be unchanged + assert parts[1] == content_parts[1] + + def test_multimodal_list_prompt_original_not_mutated(self): + """Salting a multimodal prompt must not modify the original list or dicts.""" + content_parts = [{"type": "text", "text": "original text"}] + inner = _make_loaded_dataset([{"prompt": content_parts}]) + sd = SaltedDataset(inner) + sd.load_sample(0) + assert inner.data[0]["prompt"][0]["text"] == "original text" + + def test_unknown_prompt_type_is_not_salted(self): + """A prompt that is neither str nor a recognised list-of-parts is returned unchanged.""" + inner = _make_loaded_dataset([{"prompt": 42}]) + sd = SaltedDataset(inner) + result = sd.load_sample(0) + assert result == {"prompt": 42} + @pytest.mark.unit class TestSaltedDatasetWithRealDataset: From 4d1c89ed77f3c4062d16386939e2a6cb33b88d03 Mon Sep 17 00:00:00 2001 From: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> Date: Mon, 4 May 2026 20:08:37 -0700 Subject: [PATCH 3/5] Fix test Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> --- tests/unit/load_generator/test_async_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index 670db9cc..48ad134e 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -818,7 +818,7 @@ def on_complete(result: QueryResult | StreamChunk) -> None: ) phases = [ - PhaseConfig("sat", sat_settings, FakeDataset(2), PhaseType.WARMUP), + PhaseConfig("sat", sat_settings, FakeDataset(2), PhaseType.WARMUP, drain_after=False), PhaseConfig("perf", perf_settings, FakeDataset(2), PhaseType.PERFORMANCE), ] From 1143d643397ab8628a819ef535b05cb6aae07ca9 Mon Sep 17 00:00:00 2001 From: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> Date: Tue, 5 May 2026 04:48:15 -0700 Subject: [PATCH 4/5] fix: address warmup phase review comments - Add `random_seed` field to WarmupConfig for deterministic warmup scheduling - Use `dataclasses.replace` + warmup seed for warmup RuntimeSettings - Fix SaltedDataset.data to be a property (avoids stale snapshot after inner reload) - Fix multimodal salting to find first text part at any index (handles image-first prompts) - Log warning when input_tokens present without prompt (salting not possible) - Fix ruff-format CI failure in test_async_session.py - Move inline imports to top of test_benchmark.py (AGENTS.md compliance) Co-Authored-By: Claude Sonnet 4.6 --- .../commands/benchmark/execute.py | 10 +-- src/inference_endpoint/config/schema.py | 4 ++ .../templates/concurrency_template_full.yaml | 1 + .../templates/offline_template_full.yaml | 1 + .../templates/online_template_full.yaml | 1 + .../dataset_manager/dataset.py | 40 +++++++---- tests/unit/commands/test_benchmark.py | 66 ++++--------------- .../unit/load_generator/test_async_session.py | 4 +- 8 files changed, 54 insertions(+), 73 deletions(-) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index cc4fc4f1..1eb66256 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -33,6 +33,7 @@ import tempfile import uuid from dataclasses import dataclass, field +from dataclasses import replace as dataclass_replace from datetime import datetime from pathlib import Path from typing import Any @@ -354,16 +355,15 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: warmup_dataset: Dataset = ( SaltedDataset(ctx.dataloader) if warmup_cfg.salt else ctx.dataloader ) - warmup_rt = RuntimeSettings( - metric_target=ctx.rt_settings.metric_target, - reported_metrics=ctx.rt_settings.reported_metrics, + warmup_rt = dataclass_replace( + ctx.rt_settings, min_duration_ms=0, max_duration_ms=None, n_samples_from_dataset=ctx.dataloader.num_samples(), n_samples_to_issue=warmup_cfg.n_requests, min_sample_count=1, - rng_sched=random.Random(), - rng_sample_index=random.Random(), + rng_sched=random.Random(warmup_cfg.random_seed), + rng_sample_index=random.Random(warmup_cfg.random_seed), load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), ) phases.append( diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index c5101dd2..04ed9bf4 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -410,6 +410,10 @@ class WarmupConfig(BaseModel): False, description="Drain in-flight warmup requests before starting the performance phase", ) + random_seed: int = Field( + 0, + description="Random seed for warmup phase scheduling and sample ordering", + ) @cyclopts.Parameter(name="*") diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index fa9bab42..bdcf1fbe 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -73,6 +73,7 @@ settings: n_requests: null # Warmup request count (None = full dataset once) salt: false # Prepend a unique random hex salt to each warmup prompt drain: false # Drain in-flight warmup requests before starting the performance phase + random_seed: 0 # Random seed for warmup phase scheduling and sample ordering endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index 368c29ee..2433897b 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -73,6 +73,7 @@ settings: n_requests: null # Warmup request count (None = full dataset once) salt: false # Prepend a unique random hex salt to each warmup prompt drain: false # Drain in-flight warmup requests before starting the performance phase + random_seed: 0 # Random seed for warmup phase scheduling and sample ordering endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index 598bf82d..0c3ebeb1 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -73,6 +73,7 @@ settings: n_requests: null # Warmup request count (None = full dataset once) salt: false # Prepend a unique random hex salt to each warmup prompt drain: false # Drain in-flight warmup requests before starting the performance phase + random_seed: 0 # Random seed for warmup phase scheduling and sample ordering endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/dataset_manager/dataset.py b/src/inference_endpoint/dataset_manager/dataset.py index ae6b234c..a6bde635 100644 --- a/src/inference_endpoint/dataset_manager/dataset.py +++ b/src/inference_endpoint/dataset_manager/dataset.py @@ -441,12 +441,19 @@ class SaltedDataset(Dataset): def __init__(self, inner: Dataset) -> None: # Skip Dataset.__init__ — all state is delegated to inner self._inner = inner - self.data = inner.data self.dataframe = None self.transforms = None self.repeats = inner.repeats self.logger = getLogger(__name__) + @property # type: ignore[override] + def data(self) -> list[Any] | None: + return self._inner.data + + @data.setter + def data(self, value: list[Any] | None) -> None: + self._inner.data = value + def load( self, adapter: "HttpRequestAdapter | None" = None, @@ -458,23 +465,30 @@ def load( def load_sample(self, index: int) -> Any: data = self._inner.load_sample(index) - if not (isinstance(data, dict) and "prompt" in data): + if not isinstance(data, dict): + return data + if "input_tokens" in data and "prompt" not in data: + self.logger.warning( + "SaltedDataset: sample has 'input_tokens' but no 'prompt' — " + "salt cannot be applied to pre-tokenized input; KV-cache reuse may not be prevented" + ) + return data + if "prompt" not in data: return data prompt = data["prompt"] salt = os.urandom(8).hex() if isinstance(prompt, str): return {**data, "prompt": f"[{salt}] {prompt}"} - if ( - isinstance(prompt, list) - and prompt - and isinstance(prompt[0], dict) - and prompt[0].get("type") == "text" - ): - salted_parts = [ - {**prompt[0], "text": f"[{salt}] {prompt[0]['text']}"}, - *prompt[1:], - ] - return {**data, "prompt": salted_parts} + if isinstance(prompt, list) and prompt: + # Find the first text part at any index (image-first prompts place text at index 1+) + for i, part in enumerate(prompt): + if isinstance(part, dict) and part.get("type") == "text": + salted_parts = [ + *prompt[:i], + {**part, "text": f"[{salt}] {part['text']}"}, + *prompt[i + 1 :], + ] + return {**data, "prompt": salted_parts} return data # unsupported prompt type — skip salting def num_samples(self) -> int: diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index 05f106f2..6838aaf1 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -15,18 +15,24 @@ """Tests for benchmark CLI models, config building, and command handlers.""" +import random import tempfile from pathlib import Path from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pandas as pd import pytest from inference_endpoint.commands.benchmark.cli import ( from_config, offline, online, ) -from inference_endpoint.commands.benchmark.execute import ResponseCollector +from inference_endpoint.commands.benchmark.execute import ( + BenchmarkContext, + ResponseCollector, + _build_phases, +) from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( BenchmarkConfig, @@ -36,6 +42,7 @@ OfflineSettings, OnlineSettings, RuntimeConfig, + ScorerMethod, StreamingMode, TestMode, TestType, @@ -49,8 +56,12 @@ ) from inference_endpoint.config.utils import cli_error_formatter as _error_formatter from inference_endpoint.core.types import QueryResult +from inference_endpoint.dataset_manager.dataset import Dataset, SaltedDataset from inference_endpoint.endpoint_client.config import HTTPClientConfig +from inference_endpoint.evaluation.scoring import Scorer from inference_endpoint.exceptions import InputValidationError +from inference_endpoint.load_generator.session import PhaseType +from inference_endpoint.metrics.metric import Throughput from pydantic import ValidationError TEMPLATE_DIR = ( @@ -468,11 +479,6 @@ class TestBuildPhases: @pytest.fixture def base_rt_settings(self): - import random - - from inference_endpoint.config.schema import LoadPattern, LoadPatternType - from inference_endpoint.metrics.metric import Throughput - return RuntimeSettings( metric_target=Throughput(10.0), reported_metrics=[Throughput(10.0)], @@ -488,20 +494,12 @@ def base_rt_settings(self): @pytest.fixture def simple_dataset(self): - import pandas as pd - from inference_endpoint.dataset_manager.dataset import Dataset - df = pd.DataFrame({"prompt": [f"q{i}" for i in range(5)]}) ds = Dataset(df) ds.load() return ds def _make_ctx(self, config, rt_settings, dataloader): - from pathlib import Path - - from inference_endpoint.commands.benchmark.execute import BenchmarkContext - from inference_endpoint.config.schema import TestMode - return BenchmarkContext( config=config, test_mode=TestMode.PERF, @@ -518,9 +516,6 @@ def _make_ctx(self, config, rt_settings, dataloader): def test_warmup_disabled_produces_only_perf_phase( self, base_rt_settings, simple_dataset ): - from inference_endpoint.commands.benchmark.execute import _build_phases - from inference_endpoint.load_generator.session import PhaseType - config = OfflineConfig(**_OFFLINE_KWARGS) ctx = self._make_ctx(config, base_rt_settings, simple_dataset) phases = _build_phases(ctx) @@ -530,9 +525,6 @@ def test_warmup_disabled_produces_only_perf_phase( @pytest.mark.unit def test_warmup_enabled_produces_two_phases(self, base_rt_settings, simple_dataset): - from inference_endpoint.commands.benchmark.execute import _build_phases - from inference_endpoint.load_generator.session import PhaseType - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), @@ -546,8 +538,6 @@ def test_warmup_enabled_produces_two_phases(self, base_rt_settings, simple_datas @pytest.mark.unit def test_warmup_phase_named_warmup(self, base_rt_settings, simple_dataset): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), @@ -559,9 +549,6 @@ def test_warmup_phase_named_warmup(self, base_rt_settings, simple_dataset): @pytest.mark.unit def test_warmup_phase_uses_max_throughput(self, base_rt_settings, simple_dataset): - from inference_endpoint.commands.benchmark.execute import _build_phases - from inference_endpoint.config.schema import LoadPatternType - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), @@ -575,8 +562,6 @@ def test_warmup_phase_uses_max_throughput(self, base_rt_settings, simple_dataset @pytest.mark.unit def test_warmup_phase_min_duration_is_zero(self, base_rt_settings, simple_dataset): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), @@ -588,8 +573,6 @@ def test_warmup_phase_min_duration_is_zero(self, base_rt_settings, simple_datase @pytest.mark.unit def test_warmup_phase_no_max_duration(self, base_rt_settings, simple_dataset): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), @@ -601,8 +584,6 @@ def test_warmup_phase_no_max_duration(self, base_rt_settings, simple_dataset): @pytest.mark.unit def test_warmup_n_requests_propagated(self, base_rt_settings, simple_dataset): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True, n_requests=7)), @@ -614,8 +595,6 @@ def test_warmup_n_requests_propagated(self, base_rt_settings, simple_dataset): @pytest.mark.unit def test_warmup_n_requests_none_when_unset(self, base_rt_settings, simple_dataset): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings( @@ -631,9 +610,6 @@ def test_warmup_n_requests_none_when_unset(self, base_rt_settings, simple_datase def test_warmup_without_salt_uses_raw_dataloader( self, base_rt_settings, simple_dataset ): - from inference_endpoint.commands.benchmark.execute import _build_phases - from inference_endpoint.dataset_manager.dataset import SaltedDataset - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True, salt=False)), @@ -648,9 +624,6 @@ def test_warmup_without_salt_uses_raw_dataloader( def test_warmup_with_salt_uses_salted_dataset( self, base_rt_settings, simple_dataset ): - from inference_endpoint.commands.benchmark.execute import _build_phases - from inference_endpoint.dataset_manager.dataset import SaltedDataset - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True, salt=True)), @@ -662,8 +635,6 @@ def test_warmup_with_salt_uses_salted_dataset( @pytest.mark.unit def test_warmup_drain_false_by_default(self, base_rt_settings, simple_dataset): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True, drain=False)), @@ -675,8 +646,6 @@ def test_warmup_drain_false_by_default(self, base_rt_settings, simple_dataset): @pytest.mark.unit def test_warmup_drain_true_propagated(self, base_rt_settings, simple_dataset): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True, drain=True)), @@ -690,8 +659,6 @@ def test_warmup_drain_true_propagated(self, base_rt_settings, simple_dataset): def test_warmup_n_samples_from_dataset_matches_dataloader( self, base_rt_settings, simple_dataset ): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), @@ -708,8 +675,6 @@ def test_warmup_n_samples_from_dataset_matches_dataloader( def test_performance_phase_dataset_is_always_raw_dataloader( self, base_rt_settings, simple_dataset ): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True, salt=True)), @@ -724,8 +689,6 @@ def test_performance_phase_dataset_is_always_raw_dataloader( def test_performance_phase_uses_original_rt_settings( self, base_rt_settings, simple_dataset ): - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), @@ -744,8 +707,6 @@ def test_warmup_uses_independent_rng_instances( Sharing would cause warmup sample-ordering to consume state from the perf phase's deterministic random sequence, breaking reproducibility. """ - from inference_endpoint.commands.benchmark.execute import _build_phases - config = OfflineConfig( **_OFFLINE_KWARGS, settings=OfflineSettings(warmup=WarmupConfig(enabled=True)), @@ -764,9 +725,6 @@ class TestScorerMethodSync: @pytest.mark.unit def test_scorer_enum_matches_registry(self): - from inference_endpoint.config.schema import ScorerMethod - from inference_endpoint.evaluation.scoring import Scorer - enum_values = {m.value for m in ScorerMethod} registry_keys = set(Scorer.PREDEFINED.keys()) assert enum_values == registry_keys, ( diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index 48ad134e..0901c99f 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -818,7 +818,9 @@ def on_complete(result: QueryResult | StreamChunk) -> None: ) phases = [ - PhaseConfig("sat", sat_settings, FakeDataset(2), PhaseType.WARMUP, drain_after=False), + PhaseConfig( + "sat", sat_settings, FakeDataset(2), PhaseType.WARMUP, drain_after=False + ), PhaseConfig("perf", perf_settings, FakeDataset(2), PhaseType.PERFORMANCE), ] From c9e80b0f82e9e004b547fe925a22e607faefac8e Mon Sep 17 00:00:00 2001 From: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> Date: Tue, 5 May 2026 08:17:27 -0700 Subject: [PATCH 5/5] Address comments Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com> --- .../commands/benchmark/execute.py | 23 +++++++- src/inference_endpoint/config/schema.py | 6 +- .../templates/concurrency_template_full.yaml | 2 +- .../templates/offline_template_full.yaml | 2 +- .../templates/online_template_full.yaml | 2 +- .../dataset_manager/dataset.py | 9 ++- .../load_generator/session.py | 6 +- tests/unit/commands/test_benchmark.py | 58 +++++++++++++++++++ 8 files changed, 95 insertions(+), 13 deletions(-) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 1eb66256..9641c7a8 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -362,8 +362,8 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: n_samples_from_dataset=ctx.dataloader.num_samples(), n_samples_to_issue=warmup_cfg.n_requests, min_sample_count=1, - rng_sched=random.Random(warmup_cfg.random_seed), - rng_sample_index=random.Random(warmup_cfg.random_seed), + rng_sched=random.Random(warmup_cfg.warmup_random_seed), + rng_sample_index=random.Random(warmup_cfg.warmup_random_seed + 1), load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), ) phases.append( @@ -554,12 +554,31 @@ async def _run_benchmark_async( phases = _build_phases(ctx) report: Report | None = None + # Global wall-clock timeout covers warmup + performance + accuracy phases + # combined, and bounds the warmup drain so a dropped request can't hang forever. + global_timeout_handle = None + max_duration_ms = ctx.rt_settings.max_duration_ms + if max_duration_ms is not None: + + def _on_global_timeout() -> None: + logger.warning( + "Global experiment timeout reached (%d ms); stopping session.", + max_duration_ms, + ) + session.stop() + + global_timeout_handle = loop.call_later( + max_duration_ms / 1000.0, _on_global_timeout + ) + loop.add_signal_handler(signal.SIGINT, session.stop) try: result = await session.run(phases) except Exception as e: raise ExecutionError(f"Benchmark execution failed: {e}") from e finally: + if global_timeout_handle is not None: + global_timeout_handle.cancel() loop.remove_signal_handler(signal.SIGINT) logger.info("Cleaning up...") try: diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 04ed9bf4..c53e6e44 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -410,9 +410,9 @@ class WarmupConfig(BaseModel): False, description="Drain in-flight warmup requests before starting the performance phase", ) - random_seed: int = Field( - 0, - description="Random seed for warmup phase scheduling and sample ordering", + warmup_random_seed: int = Field( + 42, + description="RNG seed for warmup scheduling and sample ordering", ) diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index bdcf1fbe..f39db1a7 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -73,7 +73,7 @@ settings: n_requests: null # Warmup request count (None = full dataset once) salt: false # Prepend a unique random hex salt to each warmup prompt drain: false # Drain in-flight warmup requests before starting the performance phase - random_seed: 0 # Random seed for warmup phase scheduling and sample ordering + warmup_random_seed: 42 # RNG seed for warmup scheduling and sample ordering endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index 2433897b..71d4efea 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -73,7 +73,7 @@ settings: n_requests: null # Warmup request count (None = full dataset once) salt: false # Prepend a unique random hex salt to each warmup prompt drain: false # Drain in-flight warmup requests before starting the performance phase - random_seed: 0 # Random seed for warmup phase scheduling and sample ordering + warmup_random_seed: 42 # RNG seed for warmup scheduling and sample ordering endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index 0c3ebeb1..83d79c4e 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -73,7 +73,7 @@ settings: n_requests: null # Warmup request count (None = full dataset once) salt: false # Prepend a unique random hex salt to each warmup prompt drain: false # Drain in-flight warmup requests before starting the performance phase - random_seed: 0 # Random seed for warmup phase scheduling and sample ordering + warmup_random_seed: 42 # RNG seed for warmup scheduling and sample ordering endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. - http://localhost:8000 diff --git a/src/inference_endpoint/dataset_manager/dataset.py b/src/inference_endpoint/dataset_manager/dataset.py index a6bde635..bd5f12f0 100644 --- a/src/inference_endpoint/dataset_manager/dataset.py +++ b/src/inference_endpoint/dataset_manager/dataset.py @@ -276,11 +276,12 @@ class Dataset: def __init_subclass__( cls, dataset_id: str | None = None, + register: bool = True, **kwargs, ): super().__init_subclass__(**kwargs) - if not inspect.isabstract(cls): + if register and not inspect.isabstract(cls): if dataset_id is None: dataset_id = cls.__name__ cls.DATASET_ID = dataset_id @@ -431,7 +432,7 @@ def get_dataloader( return cls(df, transforms=transforms, repeats=num_repeats) -class SaltedDataset(Dataset): +class SaltedDataset(Dataset, register=False): """Wraps a loaded Dataset, prepending a unique random salt to each prompt on load_sample(). Each call to load_sample() generates a fresh salt, so reused samples (when @@ -489,6 +490,10 @@ def load_sample(self, index: int) -> Any: *prompt[i + 1 :], ] return {**data, "prompt": salted_parts} + self.logger.warning( + "SaltedDataset: multimodal prompt has no text part — " + "salt cannot be applied; KV-cache reuse may not be prevented" + ) return data # unsupported prompt type — skip salting def num_samples(self) -> int: diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 2184512c..c324f976 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -356,9 +356,9 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: async def _drain_inflight(self, phase_issuer: PhaseIssuer) -> None: """Wait for all in-flight responses from this phase to complete. - Currently, there is no timeout for the drain step. In the future, - we can possibly add a dynamic timeout based on the rate of completion - throughout the current phase.""" + Bounded by the global experiment timeout: if the caller schedules a + loop.call_later that calls stop(), stop() sets _drain_event, unblocking + this wait without leaving it hung indefinitely.""" if phase_issuer.inflight <= 0 or self._stop_requested: return logger.info("Draining %d in-flight responses...", phase_issuer.inflight) diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index 6838aaf1..e91063aa 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -60,6 +60,7 @@ from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.evaluation.scoring import Scorer from inference_endpoint.exceptions import InputValidationError +from inference_endpoint.load_generator.sample_order import create_sample_order from inference_endpoint.load_generator.session import PhaseType from inference_endpoint.metrics.metric import Throughput from pydantic import ValidationError @@ -719,6 +720,63 @@ def test_warmup_uses_independent_rng_instances( assert warmup_rt.rng_sched is not perf_rt.rng_sched assert warmup_rt.rng_sample_index is not perf_rt.rng_sample_index + @pytest.mark.unit + def test_performance_sample_order_identical_with_and_without_warmup( + self, simple_dataset + ): + """Warmup must not perturb the performance phase's sample ordering. + + Both runs use separate RuntimeSettings instances seeded identically so + the comparison is valid. If warmup ever accidentally shared or advanced + the perf-phase RNG, the two sequences would diverge. + """ + n_draw = 20 + + def make_rt(): + return RuntimeSettings( + metric_target=Throughput(10.0), + reported_metrics=[Throughput(10.0)], + min_duration_ms=0, + max_duration_ms=None, + n_samples_from_dataset=simple_dataset.num_samples(), + n_samples_to_issue=None, + min_sample_count=1, + rng_sched=random.Random(99), + rng_sample_index=random.Random(99), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + + config_with = OfflineConfig( + **_OFFLINE_KWARGS, + settings=OfflineSettings(warmup=WarmupConfig(enabled=True, n_requests=5)), + ) + config_without = OfflineConfig(**_OFFLINE_KWARGS) + + ctx_with = self._make_ctx(config_with, make_rt(), simple_dataset) + ctx_without = self._make_ctx(config_without, make_rt(), simple_dataset) + + perf_with = next( + p for p in _build_phases(ctx_with) if p.phase_type == PhaseType.PERFORMANCE + ) + perf_without = next( + p + for p in _build_phases(ctx_without) + if p.phase_type == PhaseType.PERFORMANCE + ) + + order_with = [ + next(create_sample_order(perf_with.runtime_settings)) for _ in range(n_draw) + ] + order_without = [ + next(create_sample_order(perf_without.runtime_settings)) + for _ in range(n_draw) + ] + + assert order_with == order_without, ( + "Performance sample order changed when warmup is enabled — " + "warmup may be sharing or advancing the perf-phase RNG." + ) + class TestScorerMethodSync: """Ensure ScorerMethod enum stays in sync with the scorer registry."""