Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions auto_tune_vllm/core/study_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,11 @@ def _run_baseline_trials(self):

Adds max-num-seqs when concurrency > 256.
Baseline trials are added to the Optuna study and appear in the dashboard.

Baseline runs do not consume the configured optimization budget
(``config.n_trials``), which remains reserved for optimization trials only.
However, because baseline runs are added to the Optuna study, any totals
derived from ``self.study.trials`` will include them unless filtered out.
"""
if not self.config.baseline or not self.config.baseline.enabled:
logger.warning(
Expand Down Expand Up @@ -1204,8 +1209,6 @@ def _run_baseline_trials(self):
values=trial_result.objective_values,
state=TrialState.COMPLETE,
)
# Count baseline trial as completed
self.completed_trials += 1
else:
logger.error(
f"❌ Baseline trial failed: "
Expand All @@ -1217,8 +1220,6 @@ def _run_baseline_trials(self):
values=None,
state=TrialState.FAIL,
)
# Count baseline trial as completed (even if failed)
self.completed_trials += 1

# Clean up trial object cache
if trial.number in self.trial_objects:
Expand All @@ -1241,8 +1242,6 @@ def _run_baseline_trials(self):
values=None,
state=TrialState.FAIL,
)
# Count baseline trial as completed (even if timed out)
self.completed_trials += 1
# Clean up trial object cache
if trial.number in self.trial_objects:
del self.trial_objects[trial.number]
Expand All @@ -1258,8 +1257,6 @@ def _run_baseline_trials(self):
values=None,
state=TrialState.FAIL,
)
# Count baseline trial as completed (even if excepted)
self.completed_trials += 1
# Clean up trial object cache
if trial.number in self.trial_objects:
del self.trial_objects[trial.number]
Expand Down
82 changes: 82 additions & 0 deletions tests/core/test_study_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Unit tests for ``StudyController`` (baseline budget, orchestration hooks, etc.)."""

from __future__ import annotations

from pathlib import Path

from typing_extensions import override

from auto_tune_vllm.benchmarks.config import BenchmarkConfig
from auto_tune_vllm.core.config import (
BaselineConfig,
OptimizationConfig,
StudyConfig,
)
from auto_tune_vllm.core.study_controller import StudyController
from auto_tune_vllm.core.trial import TrialConfig, TrialResult
from auto_tune_vllm.execution.backends import ExecutionBackend, JobHandle


class _ImmediateCompleteBackend(ExecutionBackend):
    """Test double: every submitted trial succeeds on its first poll.

    Nothing is actually launched (no GPU / vLLM); the backend only remembers
    the most recently submitted trial config and turns it into a successful
    ``TrialResult`` the next time it is polled with at least one handle.
    """

    def __init__(self) -> None:
        # The trial awaiting its first poll; ``None`` when nothing is queued.
        self._pending: TrialConfig | None = None

    @override
    def submit_trial(self, trial_config: TrialConfig) -> JobHandle:
        """Remember ``trial_config`` and hand back a handle for it."""
        self._pending = trial_config
        handle = JobHandle(trial_config.trial_id, "immediate-job")
        return handle

    @override
    def poll_trials(
        self, job_handles: list[JobHandle]
    ) -> tuple[list[TrialResult], list[JobHandle]]:
        """Return ``(completed_results, still_running_handles)``.

        With no handles, or with nothing queued, no result is produced and a
        copy of the incoming handles is returned as still running. Otherwise
        the queued trial is reported as a success and no handles remain.
        """
        queued = self._pending
        if queued is None or not job_handles:
            return [], list(job_handles)

        # Clear the slot first so the same trial is never reported twice.
        self._pending = None
        finished = TrialResult(
            trial_id=queued.trial_id,
            trial_number=queued.trial_number,
            trial_type=queued.trial_type,
            objective_values=[1.0],
            success=True,
        )
        return [finished], []

    @override
    def shutdown(self) -> None:
        """No-op: this backend holds no resources."""

    @override
    def cleanup_all_trials(self) -> None:
        """No-op: there are no real trials to clean up."""


def test_baseline_trials_do_not_increment_completed_trials(tmp_path: Path) -> None:
    """Running baselines must leave ``completed_trials`` untouched.

    ``n_trials`` is the budget for optimization trials only, so the counter
    must still read zero after the baseline runs have finished.
    """
    study_config = StudyConfig(
        study_name="test_study_controller_baseline_budget",
        database_url=None,
        optimization=OptimizationConfig(
            preset="high_throughput",
            sampler="random",
            n_trials=5,
            n_startup_trials=0,
            max_concurrent_trials=1,
        ),
        benchmark=BenchmarkConfig(model="dummy-model", max_seconds=1, rate=50),
        baseline=BaselineConfig(enabled=True, concurrency_levels=[50, 100]),
        storage_file=str(tmp_path / "optuna.db"),
    )
    controller = StudyController.create_from_config(
        _ImmediateCompleteBackend(), study_config, create_db=False
    )

    # Sanity check: a freshly created controller starts with an empty counter.
    assert controller.completed_trials == 0

    controller._run_baseline_trials()  # pyright: ignore[reportPrivateUsage]

    # Baselines ran, yet none of the optimization budget was consumed.
    assert controller.completed_trials == 0
Loading