From 59c4fde27d5d77a3a87c18eb458de93d2737da05 Mon Sep 17 00:00:00 2001 From: Vincent Gimenes Date: Wed, 13 May 2026 16:47:32 +0200 Subject: [PATCH 1/2] fix: exclude baseline runs from optimization n_trials budget Signed-off-by: Vincent Gimenes --- auto_tune_vllm/core/study_controller.py | 11 +--- tests/core/test_study_controller.py | 82 +++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 8 deletions(-) create mode 100644 tests/core/test_study_controller.py diff --git a/auto_tune_vllm/core/study_controller.py b/auto_tune_vllm/core/study_controller.py index fd20573..854523c 100644 --- a/auto_tune_vllm/core/study_controller.py +++ b/auto_tune_vllm/core/study_controller.py @@ -1101,6 +1101,9 @@ def _run_baseline_trials(self): Adds max-num-seqs when concurrency > 256. Baseline trials are now added to the Optuna study and appear in dashboard. + + Baseline runs do not consume ``n_trials``: that budget is for optimization + trials only (consistent with ``_collect_completed_trials`` for async paths). """ if not self.config.baseline or not self.config.baseline.enabled: logger.warning( @@ -1204,8 +1207,6 @@ def _run_baseline_trials(self): values=trial_result.objective_values, state=TrialState.COMPLETE, ) - # Count baseline trial as completed - self.completed_trials += 1 else: logger.error( f"❌ Baseline trial failed: " @@ -1217,8 +1218,6 @@ def _run_baseline_trials(self): values=None, state=TrialState.FAIL, ) - # Count baseline trial as completed (even if failed) - self.completed_trials += 1 # Clean up trial object cache if trial.number in self.trial_objects: @@ -1241,8 +1240,6 @@ def _run_baseline_trials(self): values=None, state=TrialState.FAIL, ) - # Count baseline trial as completed (even if timed out) - self.completed_trials += 1 # Clean up trial object cache if trial.number in self.trial_objects: del self.trial_objects[trial.number] @@ -1258,8 +1255,6 @@ def _run_baseline_trials(self): values=None, state=TrialState.FAIL, ) - # Count baseline trial as completed (even if excepted) - self.completed_trials += 1 # Clean up trial object cache if trial.number in self.trial_objects: del self.trial_objects[trial.number] diff --git a/tests/core/test_study_controller.py b/tests/core/test_study_controller.py new file mode 100644 index 0000000..893ab38 --- /dev/null +++ b/tests/core/test_study_controller.py @@ -0,0 +1,82 @@ +"""Unit tests for ``StudyController`` (baseline budget, orchestration hooks, etc.).""" + +from __future__ import annotations + +from pathlib import Path + +from typing_extensions import override + +from auto_tune_vllm.benchmarks.config import BenchmarkConfig +from auto_tune_vllm.core.config import ( + BaselineConfig, + OptimizationConfig, + StudyConfig, +) +from auto_tune_vllm.core.study_controller import StudyController +from auto_tune_vllm.core.trial import TrialConfig, TrialResult +from auto_tune_vllm.execution.backends import ExecutionBackend, JobHandle + + +class _ImmediateCompleteBackend(ExecutionBackend): + """Backend that completes any submitted trial on the first poll (no GPU / vLLM).""" + + def __init__(self) -> None: + self._pending: TrialConfig | None = None + + @override + def submit_trial(self, trial_config: TrialConfig) -> JobHandle: + self._pending = trial_config + return JobHandle(trial_config.trial_id, "immediate-job") + + @override + def poll_trials( + self, job_handles: list[JobHandle] + ) -> tuple[list[TrialResult], list[JobHandle]]: + if not job_handles or self._pending is None: + return [], list(job_handles) + cfg = self._pending + self._pending = None + result = TrialResult( + trial_id=cfg.trial_id, + trial_number=cfg.trial_number, + trial_type=cfg.trial_type, + objective_values=[1.0], + success=True, + ) + return [result], [] + + @override + def shutdown(self) -> None: + return None + + @override + def cleanup_all_trials(self) -> None: + return None + + +def test_baseline_trials_do_not_increment_completed_trials(tmp_path: Path) -> None: + """``n_trials`` counts optimization trials only; baselines must not shrink that budget.""" + optimization = OptimizationConfig( + preset="high_throughput", + sampler="random", + n_trials=5, + n_startup_trials=0, + max_concurrent_trials=1, + ) + config = StudyConfig( + study_name="test_study_controller_baseline_budget", + database_url=None, + optimization=optimization, + benchmark=BenchmarkConfig(model="dummy-model", max_seconds=1, rate=50), + baseline=BaselineConfig( + enabled=True, + concurrency_levels=[50, 100], + ), + storage_file=str(tmp_path / "optuna.db"), + ) + backend = _ImmediateCompleteBackend() + controller = StudyController.create_from_config(backend, config, create_db=False) + + assert controller.completed_trials == 0 + controller._run_baseline_trials() # pyright: ignore[reportPrivateUsage] + assert controller.completed_trials == 0 From 56a4a6cece1aeb30b4eb11f12a9dcd2bacb34d79 Mon Sep 17 00:00:00 2001 From: Vincent Gimenes <147169146+VincentG1234@users.noreply.github.com> Date: Wed, 20 May 2026 16:08:54 +0200 Subject: [PATCH 2/2] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- auto_tune_vllm/core/study_controller.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/auto_tune_vllm/core/study_controller.py b/auto_tune_vllm/core/study_controller.py index 854523c..5775a1c 100644 --- a/auto_tune_vllm/core/study_controller.py +++ b/auto_tune_vllm/core/study_controller.py @@ -1102,8 +1102,10 @@ def _run_baseline_trials(self): Adds max-num-seqs when concurrency > 256. Baseline trials are now added to the Optuna study and appear in dashboard. - Baseline runs do not consume ``n_trials``: that budget is for optimization - trials only (consistent with ``_collect_completed_trials`` for async paths). + Baseline runs do not consume the configured optimization budget + (``config.n_trials``), which remains reserved for optimization trials only. + However, because baseline runs are added to the Optuna study, any totals + derived from ``self.study.trials`` will include them unless filtered out. """ if not self.config.baseline or not self.config.baseline.enabled: logger.warning(