-
Notifications
You must be signed in to change notification settings - Fork 9
feat: Plug in warmup phase #305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,11 +27,13 @@ | |
| import json | ||
| import logging | ||
| import platform | ||
| import random | ||
| import shutil | ||
| import signal | ||
| import tempfile | ||
| import uuid | ||
| from dataclasses import dataclass, field | ||
| from dataclasses import replace as dataclass_replace | ||
| from datetime import datetime | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
@@ -70,7 +72,7 @@ | |
| TestType, | ||
| ) | ||
| from inference_endpoint.core.types import QueryResult | ||
| from inference_endpoint.dataset_manager.dataset import Dataset | ||
| from inference_endpoint.dataset_manager.dataset import Dataset, SaltedDataset | ||
| from inference_endpoint.dataset_manager.factory import DataLoaderFactory | ||
| from inference_endpoint.endpoint_client.cpu_affinity import AffinityPlan, pin_loadgen | ||
| from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient | ||
|
|
@@ -347,6 +349,33 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: | |
| """Build the phase list from BenchmarkContext.""" | ||
| phases: list[PhaseConfig] = [] | ||
|
|
||
| # Warmup phase (optional, before performance) | ||
| warmup_cfg = ctx.config.settings.warmup | ||
| if warmup_cfg.enabled: | ||
| warmup_dataset: Dataset = ( | ||
| SaltedDataset(ctx.dataloader) if warmup_cfg.salt else ctx.dataloader | ||
| ) | ||
| warmup_rt = dataclass_replace( | ||
| ctx.rt_settings, | ||
| min_duration_ms=0, | ||
| max_duration_ms=None, | ||
| n_samples_from_dataset=ctx.dataloader.num_samples(), | ||
| n_samples_to_issue=warmup_cfg.n_requests, | ||
| min_sample_count=1, | ||
| rng_sched=random.Random(warmup_cfg.warmup_random_seed), | ||
| rng_sample_index=random.Random(warmup_cfg.warmup_random_seed + 1), | ||
| load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), | ||
| ) | ||
| phases.append( | ||
| PhaseConfig( | ||
| "warmup", | ||
| warmup_rt, | ||
| warmup_dataset, | ||
| PhaseType.WARMUP, | ||
| drain_after=warmup_cfg.drain, | ||
| ) | ||
| ) | ||
|
|
||
| # Performance phase | ||
| phases.append( | ||
| PhaseConfig( | ||
|
|
@@ -525,12 +554,31 @@ async def _run_benchmark_async( | |
| phases = _build_phases(ctx) | ||
| report: Report | None = None | ||
|
|
||
| # Global wall-clock timeout covers warmup + performance + accuracy phases | ||
| # combined, and bounds the warmup drain so a dropped request can't hang forever. | ||
| global_timeout_handle = None | ||
| max_duration_ms = ctx.rt_settings.max_duration_ms | ||
| if max_duration_ms is not None: | ||
|
|
||
| def _on_global_timeout() -> None: | ||
| logger.warning( | ||
| "Global experiment timeout reached (%d ms); stopping session.", | ||
| max_duration_ms, | ||
| ) | ||
| session.stop() | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Claude] low · concurrency — The global timer's callback calls:

def _on_global_timeout() -> None:
logger.warning(
"Global experiment timeout reached (%d ms); stopping session.",
max_duration_ms,
)
session.stop()
Suggested fix: def _on_global_timeout() -> None:
if session._done: # or expose a public flag
return
logger.warning(...)
    session.stop()

Also consider moving |
||
|
|
||
| global_timeout_handle = loop.call_later( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Both (Codex P1 + Claude)] high · api-contract — The new global timer redefines what `max_duration_ms` means:

max_duration_ms = ctx.rt_settings.max_duration_ms
if max_duration_ms is not None:
...
global_timeout_handle = loop.call_later(
max_duration_ms / 1000.0, _on_global_timeout
)
For a user with The schema description still reads Suggested fixes (any one):
|
||
| max_duration_ms / 1000.0, _on_global_timeout | ||
| ) | ||
|
|
||
| loop.add_signal_handler(signal.SIGINT, session.stop) | ||
| try: | ||
| result = await session.run(phases) | ||
| except Exception as e: | ||
| raise ExecutionError(f"Benchmark execution failed: {e}") from e | ||
| finally: | ||
| if global_timeout_handle is not None: | ||
| global_timeout_handle.cancel() | ||
| loop.remove_signal_handler(signal.SIGINT) | ||
| logger.info("Cleaning up...") | ||
| try: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -276,11 +276,12 @@ class Dataset: | |
| def __init_subclass__( | ||
| cls, | ||
| dataset_id: str | None = None, | ||
| register: bool = True, | ||
| **kwargs, | ||
| ): | ||
| super().__init_subclass__(**kwargs) | ||
|
|
||
| if not inspect.isabstract(cls): | ||
| if register and not inspect.isabstract(cls): | ||
| if dataset_id is None: | ||
| dataset_id = cls.__name__ | ||
| cls.DATASET_ID = dataset_id | ||
|
|
@@ -411,7 +412,7 @@ def num_samples(self) -> int: | |
| @classmethod | ||
| def get_dataloader( | ||
| cls, | ||
| datasets_dir: Path = Path("datasets"), | ||
| datasets_dir: Path = Path("dataset_cache"), | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Claude] high · api-contract — Default changed:

datasets_dir: Path = Path("dataset_cache"),

This is a silent breaking change. Any existing user who relied on the previous default will, after upgrade, fail to find the cached dataframe under the old `datasets` directory.

Suggest: keep |
||
| num_repeats: int = 1, | ||
| transforms: list[Transform] | None = None, | ||
| force_regenerate: bool = False, | ||
|
|
@@ -431,6 +432,74 @@ def get_dataloader( | |
| return cls(df, transforms=transforms, repeats=num_repeats) | ||
|
|
||
|
|
||
| class SaltedDataset(Dataset, register=False): | ||
|
|
||
| """Wraps a loaded Dataset, prepending a unique random salt to each prompt on load_sample(). | ||
|
|
||
| Each call to load_sample() generates a fresh salt, so reused samples (when | ||
| n_requests > dataset size) each receive a distinct salt. | ||
| """ | ||
|
|
||
| def __init__(self, inner: Dataset) -> None: | ||
| # Skip Dataset.__init__ — all state is delegated to inner | ||
| self._inner = inner | ||
| self.dataframe = None | ||
| self.transforms = None | ||
| self.repeats = inner.repeats | ||
| self.logger = getLogger(__name__) | ||
|
Comment on lines
+442
to
+448
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Specifically, calling the inherited |
||
|
|
||
| @property # type: ignore[override] | ||
| def data(self) -> list[Any] | None: | ||
| return self._inner.data | ||
|
|
||
| @data.setter | ||
| def data(self, value: list[Any] | None) -> None: | ||
| self._inner.data = value | ||
|
|
||
| def load( | ||
| self, | ||
| adapter: "HttpRequestAdapter | None" = None, | ||
| api_type: APIType | None = None, | ||
| model_params: ModelParams | None = None, | ||
| force: bool = False, | ||
| ) -> None: | ||
| pass # Inner dataset already loaded | ||
|
|
||
| def load_sample(self, index: int) -> Any: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Claude] high · bug
input_tokens = query.data["input_tokens"]
return cls._request_encoder.encode(
SGLangGenerateRequest(input_ids=input_tokens, ...))

When the sample provides only `input_tokens` (no `prompt`), the salt cannot be applied. Fix options: (a) detect |
||
| data = self._inner.load_sample(index) | ||
| if not isinstance(data, dict): | ||
| return data | ||
| if "input_tokens" in data and "prompt" not in data: | ||
| self.logger.warning( | ||
| "SaltedDataset: sample has 'input_tokens' but no 'prompt' — " | ||
| "salt cannot be applied to pre-tokenized input; KV-cache reuse may not be prevented" | ||
| ) | ||
| return data | ||
| if "prompt" not in data: | ||
| return data | ||
| prompt = data["prompt"] | ||
| salt = os.urandom(8).hex() | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Claude] low · design
Fix: thread the |
||
| if isinstance(prompt, str): | ||
| return {**data, "prompt": f"[{salt}] {prompt}"} | ||
| if isinstance(prompt, list) and prompt: | ||
| # Find the first text part at any index (image-first prompts place text at index 1+) | ||
| for i, part in enumerate(prompt): | ||
| if isinstance(part, dict) and part.get("type") == "text": | ||
| salted_parts = [ | ||
| *prompt[:i], | ||
| {**part, "text": f"[{salt}] {part['text']}"}, | ||
| *prompt[i + 1 :], | ||
| ] | ||
| return {**data, "prompt": salted_parts} | ||
| self.logger.warning( | ||
| "SaltedDataset: multimodal prompt has no text part — " | ||
| "salt cannot be applied; KV-cache reuse may not be prevented" | ||
| ) | ||
| return data # unsupported prompt type — skip salting | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Claude] medium · bug — When the prompt is a multimodal list where ALL parts are non-text (e.g., all image parts):

if isinstance(prompt, list) and prompt:
for i, part in enumerate(prompt):
if isinstance(part, dict) and part.get("type") == "text":
...
return {**data, "prompt": salted_parts}
return data # unsupported prompt type — skip salting (silent!)Unlike the
||
|
|
||
| def num_samples(self) -> int: | ||
| return self._inner.num_samples() | ||
|
|
||
|
|
||
| class EmptyDataset(Dataset): | ||
| """Empty dataset to be used as performance dataset when running only accuracy tests.""" | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,6 @@ | |
|
|
||
| import asyncio | ||
| import logging | ||
| import os | ||
| import time | ||
| import uuid | ||
| from collections.abc import Callable | ||
|
|
@@ -44,8 +43,6 @@ | |
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _WARMUP_ENABLED = os.environ.get("ENABLE_WARMUP") == "1" | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Phase configuration | ||
|
|
@@ -68,6 +65,7 @@ class PhaseConfig: | |
| runtime_settings: RuntimeSettings | ||
| dataset: Dataset | ||
| phase_type: PhaseType = PhaseType.PERFORMANCE | ||
| drain_after: bool = True | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Claude] low · api-contract
Fix: either default |
||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
|
|
@@ -242,6 +240,7 @@ def __init__( | |
| self._stop_requested = False | ||
| self._done = False | ||
| self._current_phase_issuer: PhaseIssuer | None = None | ||
| self._current_phase_type: PhaseType | None = None | ||
| self._current_strategy: LoadStrategy | None = None | ||
| self._recv_task: asyncio.Task | None = None | ||
| self._strategy_task: asyncio.Task | None = None | ||
|
|
@@ -274,12 +273,6 @@ async def run(self, phases: list[PhaseConfig]) -> SessionResult: | |
| for phase in phases: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Claude] medium · api-contract — The `ENABLE_WARMUP` environment-variable override has been removed. Suggested fix: log a one-time warning at session start if |
||
| if self._stop_requested: | ||
| break | ||
| if phase.phase_type == PhaseType.WARMUP and not _WARMUP_ENABLED: | ||
| logger.info( | ||
| "Skipping warmup phase %s (set ENABLE_WARMUP=1 to enable)", | ||
| phase.name, | ||
| ) | ||
| continue | ||
| result = await self._run_phase(phase) | ||
| if result is not None: | ||
| phase_results.append(result) | ||
|
|
@@ -318,6 +311,7 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: | |
| ) | ||
|
|
||
| self._current_phase_issuer = phase_issuer | ||
| self._current_phase_type = phase.phase_type | ||
| self._current_strategy = strategy | ||
|
|
||
| # Performance phases get tracking events | ||
|
|
@@ -333,8 +327,7 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: | |
| finally: | ||
| self._strategy_task = None | ||
|
|
||
| # Drain in-flight (skip for warmup — keep concurrency hot) | ||
| if phase.phase_type != PhaseType.WARMUP: | ||
| if phase.drain_after: | ||
| await self._drain_inflight(phase_issuer) | ||
|
|
||
| if phase.phase_type == PhaseType.PERFORMANCE: | ||
|
|
@@ -363,9 +356,9 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: | |
| async def _drain_inflight(self, phase_issuer: PhaseIssuer) -> None: | ||
| """Wait for all in-flight responses from this phase to complete. | ||
|
|
||
| Currently, there is no timeout for the drain step. In the future, | ||
| we can possibly add a dynamic timeout based on the rate of completion | ||
| throughout the current phase.""" | ||
| Bounded by the global experiment timeout: if the caller schedules a | ||
| loop.call_later that calls stop(), stop() sets _drain_event, unblocking | ||
| this wait without leaving it hung indefinitely.""" | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Claude] medium · error-handling — The new docstring on:

async def _drain_inflight(self, phase_issuer: PhaseIssuer) -> None:
"""Wait for all in-flight responses from this phase to complete.
Bounded by the global experiment timeout: if the caller schedules a
loop.call_later that calls stop(), stop() sets _drain_event, unblocking
this wait without leaving it hung indefinitely."""
...
self._drain_event.clear()
await self._drain_event.wait()The The docstring suggests this case is solved; it isn't. Suggested fixes:
|
||
| if phase_issuer.inflight <= 0 or self._stop_requested: | ||
| return | ||
| logger.info("Draining %d in-flight responses...", phase_issuer.inflight) | ||
|
|
@@ -425,7 +418,10 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: | |
| self._drain_event.set() | ||
| if self._current_strategy: | ||
| self._current_strategy.on_query_complete(query_id) | ||
| if self._on_sample_complete: | ||
| if ( | ||
| self._on_sample_complete | ||
| and self._current_phase_type != PhaseType.WARMUP | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Review Council — Claude] low · data-integrity — The publish call:

self._publisher.publish(EventRecord(
event_type=SampleEventType.COMPLETE, ..., sample_uuid=query_id, ...
))When Fix: tag warmup queries with a phase marker on the |
||
| ): | ||
| self._on_sample_complete(resp) | ||
|
|
||
| elif isinstance(resp, StreamChunk): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Review Council — Claude] low · error-handling
Tokenizer load failure is silently downgraded to "no token metrics":
A user who explicitly passes
--tokenizer Xand thus expects token metrics will instead see a single WARNING line buried in the log, then a silent benchmark run with empty ISL/OSL/TPOT columns. By the time they check the report, the run is over.Fix: re-raise unless an explicit
--allow-tokenizer-fallbackflag is set, OR at least log at ERROR level and surface the missing-tokenizer state in the final report. Also use the module logger (logger = logging.getLogger(__name__)) rather than the rootloggingmodule.