diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 1cd0467be93..080c8f3a64d 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -424,6 +424,7 @@ explorer:
     return_partial_tasks: false
   dynamic_timeout:
     enable: false
+    warmup_min_steps: 1
     ratio: 3.0
   runner_state_report_interval: 0
 ```
@@ -452,9 +453,10 @@ explorer:
 - `ratio`: Explorer will only wait for `(1 - ratio) * batch_size` of tasks at each step. Default is `0.0`, meaning waiting for all tasks.
 - `wait_after_min`: After reaching the minimum task threshold, wait for this many seconds before proceeding. Default is `30.0` seconds.
 - `return_partial_tasks`: Whether to return the results of tasks that have only completed partially (e.g., only some runs in GRPO). Default is `false`, meaning only return results of tasks that have completed all runs.
-- `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each task based on the average time taken for successful tasks.
+- `dynamic_timeout`: [Experimental] Configuration for the dynamic timeout mechanism, which adjusts the timeout of each scheduled execution unit based on historical runtime statistics.
   - `enable`: Whether to enable dynamic timeout. Default is `false`.
-  - `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`.
+  - `warmup_min_steps`: The minimum number of fully observed non-eval steps required before dynamic timeout takes effect. Default is `1`. This acts as a warmup batch/step count and avoids enabling dynamic timeout based on only a few fast early completions.
+  - `ratio`: The timeout for each scheduled execution unit is dynamically set to `average_time_per_success_execution * ratio`, capped at `explorer.max_timeout`. Default is `3.0`.
 - `runner_state_report_interval`: Workflow runner report interval (in seconds). If set to a value greater than `0`, the workflow runner will periodically report its status to the main explorer process and print it in the command line for monitoring. Default is `0`, meaning this feature is disabled. If you want to use this feature, it is recommended to set it to `10` seconds or longer to minimize performance impact.
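The combined effect of `warmup_min_steps`, `ratio`, and the existing `max_timeout` option can be sketched in a few lines. This is a minimal illustration of the rule stated in the `DynamicTimeoutConfig` comment (`min(max_timeout, average_time * ratio)`, gated by warmup); the function name, parameter names, and the `1800.0` defaults below are illustrative rather than the scheduler's actual attributes, although the counters mirror the `total_running_time` / `total_completed_steps` bookkeeping asserted in `scheduler_test.py`:

```python
def compute_dynamic_timeout(
    total_running_time: float,  # summed runtime of successful executions so far
    total_completed: int,       # successfully completed execution units so far
    completed_steps: int,       # fully observed non-eval steps so far
    warmup_min_steps: int = 1,
    ratio: float = 3.0,
    max_timeout: float = 1800.0,
    default_timeout: float = 1800.0,
) -> float:
    # During warmup (or before anything has completed), keep the static timeout.
    if completed_steps < warmup_min_steps or total_completed == 0:
        return default_timeout
    average_time = total_running_time / total_completed
    # Scale the observed average, but never exceed the configured ceiling.
    return min(max_timeout, average_time * ratio)
```

With the defaults, a warmup step whose executions average 1 second would set the next step's timeout to roughly 3 seconds, which is the behavior `test_dynamic_timeout` in `scheduler_test.py` exercises.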
---

diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index 5192f1b2786..5792ca77d1c 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -421,6 +421,7 @@ explorer:
     return_partial_tasks: false
   dynamic_timeout:
     enable: false
+    warmup_min_steps: 1
     ratio: 3.0
   runner_state_report_interval: 0
 ```
@@ -449,9 +450,10 @@ explorer:
 - `ratio`: The explorer waits for only `(1 - ratio) * batch_size` tasks at each step. Default is `0.0`, meaning it waits for all tasks.
 - `wait_after_min`: After reaching the minimum task threshold, wait this many seconds before proceeding.
 - `return_partial_tasks`: Whether to return results of tasks that are only partially complete (e.g., tasks in GRPO for which only some runs finished). Default is `false`, meaning only results of tasks whose entire group of runs has completed are returned.
-- `dynamic_timeout`: [Experimental] Configuration for the dynamic timeout mechanism, which adjusts each task's timeout based on the average time taken by successful tasks.
+- `dynamic_timeout`: [Experimental] Configuration for the dynamic timeout mechanism, which dynamically adjusts the timeout of each scheduled execution unit based on historical execution times.
   - `enable`: Whether to enable dynamic timeout. Default is `false`.
-  - `ratio`: The timeout of each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`.
+  - `warmup_min_steps`: The minimum number of fully finished non-eval steps that must be observed before dynamic timeout takes effect. Default is `1`. It is equivalent to the number of warmup batches/steps and avoids enabling dynamic timeout prematurely based on only a few fast early completions.
+  - `ratio`: The timeout of each scheduled execution unit is dynamically set to `average_time_per_success_execution * ratio`, capped at `explorer.max_timeout`. Default is `3.0`.
 - `runner_state_report_interval`: The interval (in seconds) at which the WorkflowRunner reports its own state. If set to a value greater than 0, the workflow runner periodically reports its status to the explorer main process and prints it on the command line for monitoring. Default is `0`, meaning the feature is disabled. If you need this feature, it is recommended to set it to `10` seconds or longer to reduce the performance impact.

---

diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 075c7ed0f6b..da170517d63 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -5,7 +5,12 @@
 import os
 import random
 import shutil
+import time
+import unittest
+from collections import deque
 from datetime import datetime
+from types import SimpleNamespace
+from unittest.mock import MagicMock

 import httpx
 import ray
@@ -34,6 +39,98 @@
 from trinity.manager.state_manager import StateManager


+def _build_fake_coordinator_explorer():
+    class FakeRemoteMethod:
+        def __init__(self, func):
+            self.func = func
+
+        async def remote(self, *args, **kwargs):
+            return await self.func(*args, **kwargs)
+
+    class FakeCoordinator:
+        def __init__(self):
+            self.submit_calls = []
+            self.finalize_train_calls = []
+            self.finalize_eval_calls = []
+            self.shutdown_calls = 0
+            self.submit_batch = FakeRemoteMethod(self._submit_batch)
+            self.finalize_train_batch = FakeRemoteMethod(self._finalize_train_batch)
+            self.finalize_eval_batch = FakeRemoteMethod(self._finalize_eval_batch)
+            self.shutdown = FakeRemoteMethod(self._shutdown)
+
+        async def _submit_batch(self, **kwargs):
+            self.submit_calls.append(kwargs)
+
+        async def _finalize_train_batch(self, batch_id):
+            self.finalize_train_calls.append(batch_id)
+            return {
+                "batch_id": batch_id,
+                "batch_type": "train",
+                "finished_task_count": 1,
+                "metrics": {
+                    "experience_pipeline/experience_count": 2.0,
+                    "rollout/run_metrics/mean": float(batch_id),
+                    "rollout/finished_task_count": 1.0,
+                },
+                "finalize_reason": "complete",
+                "finalized": True,
+            }
+
+        async def _finalize_eval_batch(self, batch_id):
+            self.finalize_eval_calls.append(batch_id)
+            eval_name = batch_id.split("/", 1)[1]
+            return {
+                "batch_id": batch_id,
+                "batch_type": "eval",
+                "finished_task_count": 2,
+                "metrics": {
+                    f"eval/{eval_name}/accuracy": 0.5,
+                    f"eval/{eval_name}/finished_task_count": 2.0,
+                },
+                "finalize_reason": "complete",
+                "finalized": True,
+            }
+
+        async def _shutdown(self):
+            self.shutdown_calls += 1
+
+    class FakeMonitor:
+        def __init__(self):
+            self.logged = []
+
+        def log(self, metric, step):
+            self.logged.append((step, metric))
+
+    feedback_calls = []
+
+    async def 
read_async(): + return [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)] + + def record_feedback(metrics): + feedback_calls.append(metrics) + + explorer = Explorer.__new__(Explorer) + explorer.logger = MagicMock() + explorer.rollout_coordinator = FakeCoordinator() + explorer.monitor = FakeMonitor() + explorer.taskset = SimpleNamespace(read_async=read_async, feedback=record_feedback) + explorer.min_wait_num = None + explorer.pending_eval_tasks = deque() + explorer.explore_start_time = None + explorer.eval_start_time = None + explorer.last_monitored_step = 0 + explorer.explore_step_num = 0 + explorer.model_version = 7 + explorer.detailed_stats = False + explorer.config = SimpleNamespace( + explorer=SimpleNamespace( + over_rollout=SimpleNamespace(return_partial_tasks=False), + eval_interval=1, + ) + ) + return explorer, feedback_calls + + class BaseExplorerCase(RayUnittestBase): def setUp(self): self.config = get_template_config() @@ -226,6 +323,71 @@ def test_explorer(self): ray.get(explorer.shutdown.remote()) +class TestExplorerCoordinatorPath(unittest.IsolatedAsyncioTestCase): + async def test_explore_step_submits_train_batch_to_rollout_coordinator(self): + explorer, _ = _build_fake_coordinator_explorer() + + should_continue = await explorer.explore_step() + + self.assertTrue(should_continue) + self.assertEqual(explorer.explore_step_num, 1) + self.assertEqual( + explorer.rollout_coordinator.submit_calls, + [ + { + "batch_id": 1, + "tasks": [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + "batch_type": "train", + "min_wait_num": None, + } + ], + ) + + async def test_finish_current_steps_uses_rollout_coordinator_finalize(self): + explorer, feedback_calls = _build_fake_coordinator_explorer() + explorer.explore_step_num = 2 + + await explorer.finish_current_steps() + + self.assertEqual(explorer.rollout_coordinator.finalize_train_calls, [1, 2]) + self.assertEqual(len(feedback_calls), 2) + self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) + self.assertEqual(explorer.last_monitored_step, 2) + + async def test_finish_eval_step_uses_rollout_coordinator_finalize(self): + explorer, _ = _build_fake_coordinator_explorer() + explorer.pending_eval_tasks.append((3, "eval_set")) + explorer.eval_start_time = time.time() + + await explorer._finish_eval_step(step=3) + + self.assertEqual(explorer.rollout_coordinator.finalize_eval_calls, ["3/eval_set"]) + self.assertEqual([step for step, _ in explorer.monitor.logged], [3]) + self.assertIn("eval/eval_set/accuracy", explorer.monitor.logged[0][1]) + + +class TestExplorerCoordinatorPolicies(unittest.IsolatedAsyncioTestCase): + async def test_over_rollout_submits_partial_finalize_policy_to_rollout_coordinator(self): + explorer, _ = _build_fake_coordinator_explorer() + explorer.min_wait_num = 1 + explorer.config.explorer.over_rollout.return_partial_tasks = True + + should_continue = await explorer.explore_step() + + self.assertTrue(should_continue) + self.assertEqual( + explorer.rollout_coordinator.submit_calls, + [ + { + "batch_id": 1, + "tasks": [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + "batch_type": "train", + "min_wait_num": 1, + } + ], + ) + + def run_serve(config): config.check_and_update() run_stage(config) diff --git a/tests/explorer/rollout_coordinator_test.py b/tests/explorer/rollout_coordinator_test.py new file mode 100644 index 00000000000..3afd8dbc8b2 --- /dev/null +++ b/tests/explorer/rollout_coordinator_test.py @@ -0,0 +1,389 @@ +"""Unit tests for RolloutCoordinator.""" + 
+import asyncio +import unittest +from collections import defaultdict +from types import SimpleNamespace + +from trinity.explorer.rollout_coordinator import RolloutCoordinator +from trinity.explorer.scheduler import CompletedTaskResult +from trinity.explorer.workflow_runner import Status + + +class FakePipeline: + """Minimal in-memory pipeline double for coordinator tests.""" + + def __init__(self): + """Initialize call tracking state.""" + + self.stage_calls = [] + self.finalize_calls = [] + self.process_chunk_calls = [] + self.abort_calls = [] + self.prepare_called = False + self.close_called = False + self.staged_payloads = defaultdict(dict) + + async def prepare(self): + """Record pipeline preparation.""" + + self.prepare_called = True + + async def stage_task_payloads(self, batch_id, task_id, exp_chunks): + """Record task payload staging.""" + + chunks = [chunk for chunk in exp_chunks if chunk] + self.stage_calls.append((batch_id, task_id, chunks)) + if not chunks: + return None + self.staged_payloads[batch_id][task_id] = chunks + return f"{batch_id}:{task_id}" + + async def finalize_batch(self, batch_id, task_ids=None): + """Record batch finalization.""" + + staged = self.staged_payloads.get(batch_id, {}) + selected_task_ids = sorted(staged) if task_ids is None else list(task_ids) + self.finalize_calls.append((batch_id, selected_task_ids)) + exp_chunks = [] + for task_id in selected_task_ids: + exp_chunks.extend(staged.pop(task_id, [])) + if batch_id in self.staged_payloads and not self.staged_payloads[batch_id]: + del self.staged_payloads[batch_id] + return {"experience_pipeline/experience_count": float(len(exp_chunks))} + + async def process_serialized_chunks(self, exp_chunks): + """Record serialized chunk processing.""" + + chunks = list(exp_chunks) + self.process_chunk_calls.append(chunks) + return {"experience_pipeline/experience_count": float(len(chunks))} + + async def abort_batch(self, batch_id): + """Record batch abort cleanup.""" + + self.abort_calls.append(batch_id) + self.staged_payloads.pop(batch_id, None) + + async def close(self): + """Record pipeline closure.""" + + self.close_called = True + + +class FakeScheduler: + """Minimal scheduler double for coordinator tests.""" + + def __init__(self): + """Initialize scheduler state and recorded calls.""" + + self.default_timeout = 1.0 + self.started = False + self.stopped = False + self.schedule_calls = [] + self.abort_calls = [] + self.completed_task_results = asyncio.Queue() + self.batch_results = {} + self.get_statuses_calls = [] + + async def start(self): + """Mark the scheduler as started.""" + + self.started = True + + async def stop(self): + """Mark the scheduler as stopped.""" + + self.stopped = True + + def schedule(self, tasks, batch_id): + """Record scheduled tasks for one batch.""" + + self.schedule_calls.append((batch_id, list(tasks))) + + async def wait_completed_task(self, timeout=None): + """Return the next queued completed task result.""" + + try: + if timeout is None: + return await self.completed_task_results.get() + return await asyncio.wait_for(self.completed_task_results.get(), timeout=timeout) + except asyncio.TimeoutError: + return None + + async def get_statuses( + self, + batch_id, + min_num=None, + timeout=None, + clear_timeout_tasks=True, + return_partial_tasks=False, + ): + """Return only preconfigured statuses for one batch.""" + + self.get_statuses_calls.append( + { + "batch_id": batch_id, + "min_num": min_num, + "timeout": timeout, + "clear_timeout_tasks": clear_timeout_tasks, + 
"return_partial_tasks": return_partial_tasks, + } + ) + statuses, _ = self.batch_results.pop(batch_id, ([], [])) + return statuses + + async def abort_batch(self, batch_id, return_partial_tasks=False, restart_runners=True): + """Record one scheduler abort request.""" + + self.abort_calls.append( + { + "batch_id": batch_id, + "return_partial_tasks": return_partial_tasks, + "restart_runners": restart_runners, + } + ) + + def emit_completed_task(self, batch_id, task_id, result): + """Queue one completed task event and result.""" + + self.completed_task_results.put_nowait(result) + + +def _build_config(wait_after_min=0.0, return_partial_tasks=True, detailed_stats=False): + return SimpleNamespace( + explorer=SimpleNamespace( + over_rollout=SimpleNamespace( + wait_after_min=wait_after_min, + return_partial_tasks=return_partial_tasks, + ) + ), + monitor=SimpleNamespace(detailed_stats=detailed_stats), + ) + + +def _build_status(metric_value): + return Status( + completed_runs=1, + total_runs=1, + metrics=[{"run_metrics": float(metric_value)}], + ) + + +class CoordinatorHarness(RolloutCoordinator): + """Coordinator subclass that injects fake owned dependencies.""" + + def __init__(self, config, rollout_model, auxiliary_models=None, *, pipeline, scheduler): + """Store doubles and delegate the main initialization to the parent.""" + + self._test_pipeline = pipeline + self._test_scheduler = scheduler + super().__init__(config, rollout_model, auxiliary_models) + + def _init_experience_pipeline(self): + """Return the injected fake pipeline.""" + + return self._test_pipeline + + def _init_scheduler(self): + """Return the injected fake scheduler.""" + + return self._test_scheduler + + +class TestRolloutCoordinator(unittest.IsolatedAsyncioTestCase): + """Focused behavioral tests for the first coordinator implementation.""" + + async def asyncSetUp(self): + """Create one coordinator wired to fake owned dependencies.""" + + self.scheduler = FakeScheduler() + self.pipeline = FakePipeline() + self.coordinator = CoordinatorHarness( + _build_config(), + rollout_model=[], + pipeline=self.pipeline, + scheduler=self.scheduler, + ) + await self.coordinator.prepare() + + async def asyncTearDown(self): + """Shutdown the coordinator after each test.""" + + await self.coordinator.shutdown() + + async def test_finalize_train_batch_tracks_scheduler_events_and_is_idempotent(self): + """Train finalize should consume task events once and reuse the final result.""" + + await self.coordinator.submit_batch( + batch_id=1, + tasks=[SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + batch_type="train", + ) + + self.scheduler.emit_completed_task( + 1, + 0, + CompletedTaskResult( + batch_id=1, + task_id=0, + status=_build_status(10.0), + experience_payloads=[b"payload-0"], + ), + ) + self.scheduler.emit_completed_task( + 1, + 1, + CompletedTaskResult( + batch_id=1, + task_id=1, + status=_build_status(20.0), + experience_payloads=[b"payload-1"], + ), + ) + + result = await self.coordinator.finalize_train_batch(1, timeout=1.0) + self.assertEqual(result["finalize_reason"], "complete") + self.assertEqual(result["finished_task_count"], 2) + self.assertEqual(result["metrics"]["rollout/run_metrics/mean"], 15.0) + self.assertEqual(result["metrics"]["experience_pipeline/experience_count"], 2.0) + self.assertTrue(self.pipeline.prepare_called) + self.assertEqual(len(self.pipeline.stage_calls), 2) + self.assertEqual(self.pipeline.finalize_calls, [(1, [0, 1])]) + self.assertNotIn(1, self.coordinator.pending_batches) + + with 
self.assertRaisesRegex(KeyError, "not registered"): + await self.coordinator.finalize_train_batch(1, timeout=1.0) + + async def test_finalize_train_batch_supports_partial_finalize(self): + """Train finalize should allow partial completion when policy permits it.""" + + await self.coordinator.submit_batch( + batch_id=2, + tasks=[SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + batch_type="train", + min_wait_num=1, + ) + + self.scheduler.emit_completed_task( + 2, + 0, + CompletedTaskResult( + batch_id=2, + task_id=0, + status=_build_status(7.0), + experience_payloads=[b"payload-0"], + ), + ) + + result = await self.coordinator.finalize_train_batch(2, timeout=1.0) + + self.assertEqual(result["finalize_reason"], "partial") + self.assertEqual(result["finished_task_count"], 1) + self.assertEqual(self.pipeline.finalize_calls[-1], (2, [0])) + self.assertEqual(self.scheduler.abort_calls[-1]["batch_id"], 2) + self.assertEqual(self.pipeline.abort_calls[-1], 2) + self.assertNotIn(2, self.coordinator.pending_batches) + + async def test_finalize_eval_batch_aggregates_eval_metrics(self): + """Eval finalize should aggregate scheduler results without pipeline writes.""" + + batch_id = "3/eval_set" + await self.coordinator.submit_batch( + batch_id=batch_id, + tasks=[SimpleNamespace(is_eval=True), SimpleNamespace(is_eval=True)], + batch_type="eval", + ) + self.scheduler.batch_results[batch_id] = ( + [_build_status(3.0), _build_status(5.0)], + [], + ) + + result = await self.coordinator.finalize_eval_batch(batch_id, timeout=1.0) + + self.assertEqual(result["finalize_reason"], "complete") + self.assertEqual(result["finished_task_count"], 2) + self.assertEqual(result["metrics"]["eval/eval_set/run_metrics"], 4.0) + self.assertEqual(self.pipeline.finalize_calls, []) + self.assertEqual(self.scheduler.get_statuses_calls[0]["batch_id"], batch_id) + self.assertNotIn(batch_id, self.coordinator.pending_batches) + + async def test_finalize_train_batch_rejects_eval_batches_before_waiting(self): + """Train finalize should reject an active eval batch instead of entering wait logic.""" + + batch_id = "4/eval_set" + await self.coordinator.submit_batch( + batch_id=batch_id, + tasks=[SimpleNamespace(is_eval=True)], + batch_type="eval", + ) + + with self.assertRaisesRegex(ValueError, "expected train"): + await self.coordinator.finalize_train_batch(batch_id, timeout=0.1) + + async def test_terminal_batches_are_not_reusable_after_finalize(self): + """A finalized batch should be evicted instead of being cached for later reuse.""" + + eval_batch_id = "5/eval_set" + await self.coordinator.submit_batch( + batch_id=eval_batch_id, + tasks=[SimpleNamespace(is_eval=True)], + batch_type="eval", + ) + self.scheduler.batch_results[eval_batch_id] = ([_build_status(3.0)], []) + await self.coordinator.finalize_eval_batch(eval_batch_id, timeout=1.0) + + with self.assertRaisesRegex(KeyError, "not registered"): + await self.coordinator.finalize_train_batch(eval_batch_id, timeout=0.1) + + with self.assertRaisesRegex(KeyError, "not registered"): + await self.coordinator.finalize_eval_batch(eval_batch_id, timeout=0.1) + + await self.coordinator.submit_batch( + batch_id=6, + tasks=[SimpleNamespace(is_eval=False)], + batch_type="train", + ) + self.scheduler.emit_completed_task( + 6, + 0, + CompletedTaskResult( + batch_id=6, + task_id=0, + status=_build_status(11.0), + experience_payloads=[b"payload-0"], + ), + ) + await self.coordinator.finalize_train_batch(6, timeout=1.0) + + with self.assertRaisesRegex(KeyError, "not registered"): 
+ await self.coordinator.finalize_eval_batch(6, timeout=0.1) + + with self.assertRaisesRegex(KeyError, "not registered"): + await self.coordinator.finalize_train_batch(6, timeout=0.1) + + async def test_abort_batch_marks_batch_aborted_and_evicts_it(self): + """Abort should cleanup the batch immediately instead of caching a terminal result.""" + + await self.coordinator.submit_batch( + batch_id=4, + tasks=[SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + batch_type="train", + ) + + await self.coordinator.abort_batch(4, reason="shutdown") + + self.assertEqual(self.scheduler.abort_calls[0]["batch_id"], 4) + self.assertEqual(self.pipeline.abort_calls, [4]) + self.assertNotIn(4, self.coordinator.pending_batches) + + with self.assertRaisesRegex(KeyError, "not registered"): + await self.coordinator.finalize_train_batch(4, timeout=0.1) + + async def test_shutdown_closes_internal_dependencies(self): + """Shutdown should close both owned scheduler and owned pipeline.""" + + await self.coordinator.shutdown() + + self.assertTrue(self.scheduler.stopped) + self.assertTrue(self.pipeline.close_called) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 430eb7ec7d6..a25d7c604d9 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -15,7 +15,7 @@ from trinity.common.experience import EID, Experience from trinity.common.models.model import InferenceModel, ModelWrapper from trinity.common.workflows import WORKFLOWS, Task, Workflow -from trinity.explorer.scheduler import Scheduler +from trinity.explorer.scheduler import CompletedTaskResult, Scheduler @WORKFLOWS.register_module("dummy_workflow") @@ -351,6 +351,31 @@ def get_api_server_url(self) -> str: return "http://localhost:12345" +@ray.remote +class DummyPayloadStage: + def __init__(self): + self.staged_payloads = defaultdict(dict) + + async def stage_task_payloads(self, batch_id, task_id: int, exp_chunks: list[bytes]): + self.staged_payloads[batch_id][task_id] = list(exp_chunks) + return f"{batch_id}:{task_id}" + + async def take_staged_task_payloads(self, batch_id, task_ids: list[int]) -> list[bytes]: + batch_payloads = self.staged_payloads.get(batch_id, {}) + exp_chunks = [] + for task_id in task_ids: + exp_chunks.extend(batch_payloads.pop(task_id, [])) + if batch_id in self.staged_payloads and not self.staged_payloads[batch_id]: + del self.staged_payloads[batch_id] + return exp_chunks + + async def get_staged_task_ids(self, batch_id): + return sorted(self.staged_payloads.get(batch_id, {}).keys()) + + async def abort_batch(self, batch_id): + self.staged_payloads.pop(batch_id, None) + + def generate_tasks( total_num: int, timeout_num: int = 0, @@ -409,6 +434,16 @@ def generate_tasks( return tasks +async def collect_results(scheduler: Scheduler, **kwargs): + """Collect serialized payloads and materialize experiences for assertions.""" + + statuses, payloads = await scheduler.get_payload_results(**kwargs) + experiences = [] + for payload in payloads: + experiences.extend(Experience.deserialize_many(payload)) + return statuses, experiences + + class SchedulerTest(unittest.IsolatedAsyncioTestCase): def setUp(self): ray.init(ignore_reinit_error=True) @@ -431,17 +466,17 @@ def setUp(self): self.config.algorithm.repeat_times = 1 self.config.check_and_update() - async def test_get_results(self): + async def test_get_payload_results(self): scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() tasks = generate_tasks(8) 
scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=8, timeout=20) self.assertEqual(len(statuses), 8) self.assertEqual(len(exps), 8) - _, exps = await scheduler.get_results(batch_id=0, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=0, min_num=1, timeout=1) self.assertEqual(len(exps), 0) for result in statuses: @@ -453,17 +488,19 @@ async def test_get_results(self): for batch_id in range(1, 4): self.assertTrue(scheduler.has_step(batch_id)) - statuses, exps = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10) + statuses, exps = await collect_results( + scheduler, batch_id=batch_id, min_num=4, timeout=10 + ) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4) self.assertFalse(scheduler.has_step(batch_id)) - _, exps = await scheduler.get_results(batch_id=0, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=0, min_num=1, timeout=1) self.assertEqual(len(exps), 0) tasks = generate_tasks(3) scheduler.schedule(tasks, batch_id=4) self.assertTrue(scheduler.has_step(4)) - statuses, exps = await scheduler.get_results(batch_id=4) + statuses, exps = await collect_results(scheduler, batch_id=4) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3) self.assertFalse(scheduler.has_step(4)) @@ -473,7 +510,7 @@ async def test_get_results(self): scheduler.schedule(tasks, batch_id=0) start_time = time.time() - statuses, exps = await scheduler.get_results(batch_id=0, min_num=4, timeout=3) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=4, timeout=3) end_time = time.time() self.assertLessEqual(end_time - start_time, 15) # sync wait for runner restart @@ -485,50 +522,50 @@ async def test_get_results(self): scheduler.schedule(tasks, batch_id=0) # actor restart is slow, set a big timeout - statuses, exps = await scheduler.get_results(batch_id=0, timeout=20) + statuses, exps = await collect_results(scheduler, batch_id=0, timeout=20) self.assertEqual(len(statuses), 4) success_count = sum(1 for r in statuses if r.ok) self.assertEqual(success_count, 4) self.assertEqual(len(exps), 4) - _, exps = await scheduler.get_results(batch_id=0, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=0, min_num=1, timeout=1) self.assertEqual(len(exps), 0) # test exception tasks tasks = generate_tasks(1, exception_num=3) scheduler.schedule(tasks, batch_id=1) - statuses, exps = await scheduler.get_results(batch_id=1, timeout=5) + statuses, exps = await collect_results(scheduler, batch_id=1, timeout=5) self.assertEqual(len(statuses), 4) success_count = sum(1 for r in statuses if r.ok) self.assertEqual(success_count, 1) self.assertEqual(len(exps), 1) - _, exps = await scheduler.get_results(batch_id=1, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=1, min_num=1, timeout=1) self.assertEqual(len(exps), 0) # test _cleanup_batch_and_restart_runners: part I, no clear tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) scheduler.schedule(tasks, batch_id=2) - statuses, exps = await scheduler.get_results( - batch_id=2, timeout=2, clear_timeout_tasks=False + statuses, exps = await collect_results( + scheduler, batch_id=2, timeout=2, clear_timeout_tasks=False ) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3) - statuses, exps = await scheduler.get_results( - batch_id=2, timeout=2, clear_timeout_tasks=False + 
statuses, exps = await collect_results( + scheduler, batch_id=2, timeout=2, clear_timeout_tasks=False ) self.assertEqual(len(statuses), 1) self.assertEqual(len(exps), 1) # test _cleanup_batch_and_restart_runners: part II, clear tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) scheduler.schedule(tasks, batch_id=3) - statuses, exps = await scheduler.get_results(batch_id=3, timeout=2) + statuses, exps = await collect_results(scheduler, batch_id=3, timeout=2) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3) - statuses, exps = await scheduler.get_results(batch_id=3, timeout=2) + statuses, exps = await collect_results(scheduler, batch_id=3, timeout=2) self.assertEqual(len(statuses), 0) self.assertEqual(len(exps), 0) - _, exps = await scheduler.get_results(batch_id=3, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=3, min_num=1, timeout=1) self.assertEqual(len(exps), 0) await scheduler.stop() @@ -552,8 +589,8 @@ async def test_wait_all(self): self.assertEqual(len(scheduler.pending_tasks), 0) self.assertEqual(len(scheduler.running_tasks), 0) - status0, exps0 = await scheduler.get_results(batch_id=0, min_num=4, timeout=1) - status1, exps1 = await scheduler.get_results(batch_id=1, min_num=3, timeout=1) + status0, exps0 = await collect_results(scheduler, batch_id=0, min_num=4, timeout=1) + status1, exps1 = await collect_results(scheduler, batch_id=1, min_num=3, timeout=1) self.assertEqual(len(status0), 4) self.assertEqual(len(status1), 3) @@ -609,7 +646,9 @@ async def test_concurrent_operations(self): async def schedule_tasks(batch_id, num_tasks): tasks = generate_tasks(num_tasks) scheduler.schedule(tasks, batch_id=batch_id) - return await scheduler.get_results(batch_id=batch_id, min_num=num_tasks, timeout=10) + return await collect_results( + scheduler, batch_id=batch_id, min_num=num_tasks, timeout=10 + ) results = await asyncio.gather( schedule_tasks(0, 3), @@ -629,7 +668,7 @@ async def test_scheduler_restart_after_stop(self): await scheduler.start() tasks = generate_tasks(2) scheduler.schedule(tasks, batch_id=0) - results, exps = await scheduler.get_results(batch_id=0, min_num=2, timeout=10) + results, exps = await collect_results(scheduler, batch_id=0, min_num=2, timeout=10) self.assertEqual(len(results), 2) self.assertEqual(len(exps), 2) await scheduler.stop() @@ -637,7 +676,7 @@ async def test_scheduler_restart_after_stop(self): await scheduler.start() tasks = generate_tasks(3, repeat_times=2) scheduler.schedule(tasks, batch_id=1) - results, exps = await scheduler.get_results(batch_id=1, min_num=3, timeout=10) + results, exps = await collect_results(scheduler, batch_id=1, min_num=3, timeout=10) self.assertEqual(len(results), 3) self.assertEqual(len(exps), 3 * 2) await scheduler.stop() @@ -648,13 +687,13 @@ async def test_scheduler_all_methods(self): tasks = generate_tasks(8) scheduler.schedule(tasks, batch_id=0) self.assertTrue(scheduler.has_step(0)) - statuses, exps = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=8, timeout=20) self.assertEqual(len(statuses), 8) self.assertEqual(len(exps), 8) scheduler.schedule(tasks, batch_id=1) scheduler.schedule(tasks[:4], batch_id=2) self.assertFalse(scheduler.has_step(0)) - statuses, exps = await scheduler.get_results(batch_id=0, min_num=8) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=8) self.assertFalse(scheduler.has_step(0)) self.assertEqual(len(statuses), 0) # batch_id 0 has no more 
tasks self.assertEqual(len(exps), 0) @@ -663,7 +702,7 @@ async def test_scheduler_all_methods(self): self.assertTrue(scheduler.has_step(2)) await scheduler.wait_all() st = time.time() - statuses, exps = await scheduler.get_results(batch_id=1) + statuses, exps = await collect_results(scheduler, batch_id=1) et = time.time() self.assertTrue(et - st < 1.0) self.assertEqual(len(statuses), 8) @@ -671,7 +710,7 @@ async def test_scheduler_all_methods(self): self.assertFalse(scheduler.has_step(1)) self.assertTrue(scheduler.has_step(2)) st = time.time() - statuses, exps = await scheduler.get_results(batch_id=2) + statuses, exps = await collect_results(scheduler, batch_id=2) et = time.time() self.assertTrue(et - st < 1.0) self.assertEqual(len(statuses), 4) @@ -688,29 +727,29 @@ async def test_split_tasks(self): tasks = generate_tasks(4, repeat_times=8) # ceil(8 / 2) == 4 scheduler.schedule(tasks, batch_id=1) - statuses, exps = await scheduler.get_results(batch_id=1) + statuses, exps = await collect_results(scheduler, batch_id=1) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4 * 8) exp_list.extend(exps) - _, exps = await scheduler.get_results(batch_id=1, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=1, min_num=1, timeout=1) self.assertEqual(len(exps), 0) tasks = generate_tasks(4, repeat_times=5) # ceil(5 / 2) == 3 scheduler.schedule(tasks, batch_id=2) - statuses, exps = await scheduler.get_results(batch_id=2) + statuses, exps = await collect_results(scheduler, batch_id=2) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4 * 5) exp_list.extend(exps) - _, exps = await scheduler.get_results(batch_id=2, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=2, min_num=1, timeout=1) self.assertEqual(len(exps), 0) tasks = generate_tasks(3, repeat_times=1) # ceil(1 / 2) == 1 scheduler.schedule(tasks, batch_id=3) - statuses, exps = await scheduler.get_results(batch_id=3) + statuses, exps = await collect_results(scheduler, batch_id=3) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3 * 1) exp_list.extend(exps) - _, exps = await scheduler.get_results(batch_id=3, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=3, min_num=1, timeout=1) self.assertEqual(len(exps), 0) # test task_id, run_id and unique_id @@ -733,7 +772,7 @@ async def test_multi_step_execution(self): n_steps = 3 for i in range(1, n_steps + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 2 * 4) @@ -751,7 +790,7 @@ async def test_non_repeatable_workflow(self): exp_list = [] for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times) exp_list.extend(exps) @@ -792,7 +831,7 @@ async def test_async_workflow(self): exp_list = [] for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times * step_num) exp_list.extend(exps) @@ -822,7 +861,7 @@ async def test_stepwise_experience_eid(self): exp_list = [] for i 
in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times * step_num) exp_list.extend(exps) @@ -842,7 +881,7 @@ async def test_stepwise_experience_eid(self): exp_list = [] for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times * step_num) exp_list.extend(exps) @@ -870,7 +909,7 @@ async def test_metric_calculation_with_repeatable_workflow(self, max_repeat_time tasks.extend(generate_tasks(total_num=1, step_num=1, repeat_times=4, repeatable=True)) tasks.extend(generate_tasks(total_num=1, step_num=4, repeat_times=8, repeatable=True)) scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0) + statuses, exps = await collect_results(scheduler, batch_id=0) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 1 * 4 * 1 + 1 * 8 * 4) expected_run_metrics = set({1.5, 3.5}) # (0+1+2+3)/4 and (0+1+2+3+4+5+6+7)/8 @@ -896,7 +935,7 @@ async def test_metric_calculation_with_non_repeatable_workflow( tasks.extend(generate_tasks(total_num=1, step_num=8, repeat_times=5, repeatable=False)) tasks[-1].workflow_args["metrics"] = [2 * i for i in range(8)] scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0) + statuses, exps = await collect_results(scheduler, batch_id=0) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 1 * 4 * 3 + 1 * 5 * 8) # (1+2+3)/3 = 2.0 @@ -919,7 +958,7 @@ async def test_over_rollout_min_wait(self): tasks.extend(generate_tasks(0, timeout_num=1, repeat_times=1, timeout_seconds=3)) tasks.extend(generate_tasks(0, timeout_num=1, repeat_times=1, timeout_seconds=6)) scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0, min_num=2) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=2) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3 * 1) @@ -966,7 +1005,8 @@ async def test_over_rollout_return_partial_tasks(self): scheduler.schedule(tasks, batch_id=0) start_time = time.time() - statuses, exps = await scheduler.get_results( + statuses, exps = await collect_results( + scheduler, batch_id=0, min_num=1, timeout=2, @@ -1016,7 +1056,7 @@ async def test_over_rollout_return_partial_tasks(self): self.assertEqual(partial_status.metrics[0]["run_metrics"], 10.0) self.assertIn("1/3 runs completed successfully", partial_status.message) - statuses, exps = await scheduler.get_results(batch_id=0, timeout=1) + statuses, exps = await collect_results(scheduler, batch_id=0, timeout=1) self.assertEqual(len(statuses), 0) self.assertEqual(len(exps), 0) @@ -1055,7 +1095,8 @@ async def test_over_rollout_async_cancelled_runner_accepts_next_batch(self): ] scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results( + statuses, exps = await collect_results( + scheduler, batch_id=0, min_num=1, timeout=3, @@ -1068,7 +1109,9 @@ async def test_over_rollout_async_cancelled_runner_accepts_next_batch(self): follow_up_tasks = generate_tasks(2) scheduler.schedule(follow_up_tasks, batch_id=1) start_time = time.time() - next_statuses, next_exps = await 
scheduler.get_results(batch_id=1, min_num=2, timeout=2) + next_statuses, next_exps = await collect_results( + scheduler, batch_id=1, min_num=2, timeout=2 + ) elapsed = time.time() - start_time self.assertEqual(len(next_statuses), 2) @@ -1110,7 +1153,8 @@ async def test_over_rollout_sync_cancel_does_not_imply_immediate_runner_reuse(se ] scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results( + statuses, exps = await collect_results( + scheduler, batch_id=0, min_num=1, timeout=3, @@ -1123,7 +1167,8 @@ async def test_over_rollout_sync_cancel_does_not_imply_immediate_runner_reuse(se follow_up_tasks = generate_tasks(2) scheduler.schedule(follow_up_tasks, batch_id=1) start_time = time.time() - next_statuses, next_exps = await scheduler.get_results( + next_statuses, next_exps = await collect_results( + scheduler, batch_id=1, min_num=2, timeout=0.5, @@ -1136,7 +1181,9 @@ async def test_over_rollout_sync_cancel_does_not_imply_immediate_runner_reuse(se self.assertGreaterEqual(elapsed, 0.5) self.assertTrue(scheduler.has_step(1)) - next_statuses, next_exps = await scheduler.get_results(batch_id=1, min_num=2, timeout=4) + next_statuses, next_exps = await collect_results( + scheduler, batch_id=1, min_num=2, timeout=4 + ) self.assertEqual(len(next_statuses), 2) self.assertEqual(len(next_exps), 2) @@ -1154,7 +1201,7 @@ async def test_timeout_cleanup_still_restarts_runner(self): scheduler.schedule(tasks, batch_id=0) with patch.object(scheduler, "_restart_runner", new=AsyncMock()) as restart_runner_mock: - await scheduler.get_results(batch_id=0, timeout=1) + await collect_results(scheduler, batch_id=0, timeout=1) self.assertGreaterEqual(restart_runner_mock.await_count, 1) @@ -1189,6 +1236,7 @@ async def fake_restart_runner(runner_id): async def test_dynamic_timeout(self): self.config.explorer.dynamic_timeout.enable = True + self.config.explorer.dynamic_timeout.warmup_min_steps = 1 self.config.explorer.dynamic_timeout.ratio = 3.0 self.config.buffer.batch_size = 4 self.config.explorer.max_timeout = 20 @@ -1202,23 +1250,24 @@ async def test_dynamic_timeout(self): scheduler.schedule( tasks, batch_id="0/eval" ) # eval tasks will not count into dynamic timeout - statuses, exps = await scheduler.get_results(batch_id="0/eval") + statuses, exps = await collect_results(scheduler, batch_id="0/eval") self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 0) self.assertEqual(scheduler.total_running_time, 0) + self.assertEqual(scheduler.total_completed_steps, 0) self.assertEqual(scheduler.total_completed_tasks, 0) tasks = [] # generate 4 tasks that will run 1 second tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=1)) scheduler.schedule(tasks, batch_id=0) # first step will not use dynamic timeout - statuses, exps = await scheduler.get_results(batch_id=0) + statuses, exps = await collect_results(scheduler, batch_id=0) self.assertEqual(len(statuses), 4) # dynamic timeout will be set to 3.0 * 1.0 = 3.0 seconds for next step tasks = [] tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=4)) st = time.time() scheduler.schedule(tasks, batch_id=1) - statuses, exps = await scheduler.get_results(batch_id=1) + statuses, exps = await collect_results(scheduler, batch_id=1) et = time.time() self.assertTrue( et - st < 4 @@ -1229,10 +1278,151 @@ async def test_dynamic_timeout(self): tasks = [] tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=2)) scheduler.schedule(tasks, batch_id=2) - statuses, exps = await 
scheduler.get_results(batch_id=2) + statuses, exps = await collect_results(scheduler, batch_id=2) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4) + async def test_dynamic_timeout_warmup_min_steps_uses_completed_steps(self): + self.config.explorer.dynamic_timeout.enable = True + self.config.explorer.dynamic_timeout.warmup_min_steps = 2 + self.config.explorer.dynamic_timeout.ratio = 1.5 + self.config.buffer.batch_size = 3 + self.config.explorer.max_timeout = 20 + self.config.explorer.max_retry_times = 0 + self.config.explorer.max_repeat_times_per_runner = 2 + self.config.check_and_update() + + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + tasks = generate_tasks(0, timeout_num=2, repeat_times=4, timeout_seconds=1) + scheduler.schedule(tasks, batch_id=0) + statuses, exps = await collect_results(scheduler, batch_id=0) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 8) + self.assertEqual(scheduler.total_completed_steps, 1) + self.assertEqual(scheduler.total_completed_tasks, 2) + self.assertEqual(scheduler.total_completed_sub_tasks, 4) + self.assertEqual(scheduler.dynamic_timeout(), scheduler.default_timeout) + + tasks = generate_tasks(0, timeout_num=2, repeat_times=4, timeout_seconds=1) + scheduler.schedule(tasks, batch_id=1) + statuses, exps = await collect_results(scheduler, batch_id=1) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 8) + self.assertEqual(scheduler.total_completed_steps, 2) + self.assertAlmostEqual(scheduler.dynamic_timeout(), 1.5, delta=0.8) + + await scheduler.stop() + + async def test_completed_task_events_return_full_task_results_directly(self): + scheduler = Scheduler( + self.config, + [DummyModel.remote(), DummyModel.remote()], + emit_completed_task_events=True, + ) + await scheduler.start() + + scheduler.schedule(generate_tasks(4, repeat_times=2), batch_id=0) + + completed_results = [] + for _ in range(4): + completed_result = await scheduler.wait_completed_task(timeout=10) + self.assertIsNotNone(completed_result) + self.assertIsInstance(completed_result, CompletedTaskResult) + completed_results.append(completed_result) + + self.assertEqual({result.batch_id for result in completed_results}, {0}) + self.assertEqual({result.task_id for result in completed_results}, {0, 1, 2, 3}) + self.assertNotIn(0, scheduler.completed_tasks) + + statuses = [result.status for result in completed_results] + exps = [] + for result in completed_results: + for payload in result.experience_payloads: + exps.extend(Experience.deserialize_many(payload)) + + self.assertEqual(len(statuses), 4) + self.assertEqual(len(exps), 8) + + await scheduler.stop() + + async def test_collect_results_reads_payloads_returned_by_workflow_runner(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + scheduler.schedule(generate_tasks(3, repeat_times=2), batch_id=0) + + statuses, exps = await collect_results(scheduler, batch_id=0, timeout=10) + self.assertEqual(len(statuses), 3) + self.assertEqual(len(exps), 6) + + await scheduler.stop() + + async def test_timeout_cleanup_keeps_completed_payloads_local(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + scheduler.schedule(generate_tasks(1, timeout_num=1, timeout_seconds=10), batch_id=0) + + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=2, timeout=1) + self.assertEqual(len(statuses), 1) + 
self.assertEqual(len(exps), 1) + + await scheduler.stop() + + async def test_eval_tasks_do_not_return_training_experiences(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + eval_tasks = generate_tasks(2, repeat_times=2) + for task in eval_tasks: + task.is_eval = True + + scheduler.schedule(eval_tasks, batch_id="0/eval") + statuses, exps = await collect_results(scheduler, batch_id="0/eval", timeout=10) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 0) + + await scheduler.stop() + + async def test_get_statuses_skips_payload_deserialization(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + scheduler.schedule(generate_tasks(2, repeat_times=2), batch_id=0) + + with patch( + "trinity.common.experience.Experience.deserialize_many", + side_effect=AssertionError("payload deserialization should not happen"), + ): + statuses = await scheduler.get_statuses(batch_id=0, timeout=10) + + self.assertEqual(len(statuses), 2) + + await scheduler.stop() + + async def test_get_payload_results_keeps_payloads_serialized(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + scheduler.schedule(generate_tasks(2, repeat_times=2), batch_id=0) + + with patch( + "trinity.common.experience.Experience.deserialize_many", + side_effect=AssertionError("payload deserialization should not happen"), + ): + statuses, payloads = await scheduler.get_payload_results(batch_id=0, timeout=10) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(payloads), 2) + self.assertTrue(all(isinstance(payload, bytes) for payload in payloads)) + + await scheduler.stop() + def tearDown(self): try: ray.shutdown() @@ -1291,5 +1481,5 @@ async def monitor_routine(): await asyncio.gather( monitor_routine(), - scheduler.get_results(batch_id=0), + collect_results(scheduler, batch_id=0), ) diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index 8052d8ccd1e..7f7510681d2 100644 --- a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -1,6 +1,7 @@ import asyncio import time import traceback +from collections import defaultdict from typing import Dict, Optional from trinity.buffer.buffer import BufferWriter, get_buffer_reader, get_buffer_writer @@ -46,6 +47,7 @@ def __init__(self, config: Config): ) self.auxiliary_model_wrappers = {} self.auxiliary_models = {} + self.staged_task_payloads = defaultdict(dict) def _init_input_storage( self, @@ -137,8 +139,62 @@ async def process(self, exp_bytes: bytes) -> Dict: Returns: Dict: A dictionary containing metrics collected during the processing of experiences. 
""" - st = time.time() exps = Experience.deserialize_many(exp_bytes) + return await self._process_experiences(exps) + + async def process_serialized_chunks(self, exp_chunks: list[bytes]) -> Dict: + """Process a batch assembled from multiple serialized task payloads.""" + exps = [] + for exp_bytes in exp_chunks: + if not exp_bytes: + continue + exps.extend(Experience.deserialize_many(exp_bytes)) + return await self._process_experiences(exps) + + async def stage_task_payloads( + self, batch_id, task_id: int, exp_chunks: list[bytes] + ) -> Optional[str]: + """Stage serialized payload chunks for one completed task.""" + valid_chunks = [chunk for chunk in exp_chunks if chunk] + if not valid_chunks: + return None + self.staged_task_payloads[batch_id][task_id] = valid_chunks + return f"{batch_id}:{task_id}" + + async def finalize_batch(self, batch_id, task_ids: Optional[list[int]] = None) -> Dict: + """Finalize a staged batch and process all staged task payloads.""" + batch_payloads = self.staged_task_payloads.get(batch_id, {}) + if not batch_payloads: + return await self._process_experiences([]) + + selected_task_ids = task_ids or list(batch_payloads.keys()) + exp_chunks = [] + for task_id in selected_task_ids: + exp_chunks.extend(batch_payloads.pop(task_id, [])) + + if batch_id in self.staged_task_payloads and not self.staged_task_payloads[batch_id]: + del self.staged_task_payloads[batch_id] + + return await self.process_serialized_chunks(exp_chunks) + + async def take_staged_task_payloads(self, batch_id, task_ids: list[int]) -> list[bytes]: + """Drain staged payload chunks for selected tasks without processing them.""" + batch_payloads = self.staged_task_payloads.get(batch_id, {}) + exp_chunks = [] + for task_id in task_ids: + exp_chunks.extend(batch_payloads.pop(task_id, [])) + + if batch_id in self.staged_task_payloads and not self.staged_task_payloads[batch_id]: + del self.staged_task_payloads[batch_id] + + return exp_chunks + + async def abort_batch(self, batch_id) -> None: + """Discard any staged payloads for a batch.""" + self.staged_task_payloads.pop(batch_id, None) + + async def _process_experiences(self, exps: list[Experience]) -> Dict: + st = time.time() if self.input_store is not None: await self.input_store.write_async(exps) diff --git a/trinity/common/config.py b/trinity/common/config.py index c3ac8ab1bfd..57ff7f003a0 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -157,6 +157,7 @@ class DynamicTimeoutConfig: """Config for dynamic timeout in explorer.""" enable: bool = False + warmup_min_steps: int = 1 # only enable dynamic timeout after this many fully observed steps ratio: float = 3.0 # the timeout for each step will be min(max_timeout, average_time_per_task * dynamic_timeout.ratio) diff --git a/trinity/explorer/__init__.py b/trinity/explorer/__init__.py index 8665a1b125f..2b93afc5857 100644 --- a/trinity/explorer/__init__.py +++ b/trinity/explorer/__init__.py @@ -1,3 +1,6 @@ +"""Explorer package exports.""" + from trinity.explorer.explorer import Explorer +from trinity.explorer.rollout_coordinator import RolloutCoordinator -__all__ = ["Explorer"] +__all__ = ["Explorer", "RolloutCoordinator"] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index db82c83067a..24d62c0e027 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -12,10 +12,8 @@ import ray import torch -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from trinity.buffer.buffer import get_buffer_reader -from 
trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline from trinity.buffer.task_scheduler import get_taskset_scheduler from trinity.common.config import Config from trinity.common.constants import ( @@ -24,14 +22,13 @@ SyncMethod, SyncStyle, ) -from trinity.common.experience import Experience from trinity.common.models import create_explorer_models -from trinity.explorer.scheduler import Scheduler +from trinity.explorer.rollout_coordinator import RolloutCoordinator from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer from trinity.utils.annotations import Experimental from trinity.utils.log import get_logger -from trinity.utils.monitor import MONITOR, gather_eval_metrics, gather_metrics +from trinity.utils.monitor import MONITOR from trinity.utils.plugin_loader import load_plugins from trinity.utils.timer import Timer @@ -52,13 +49,11 @@ def __init__(self, config: Config): self.config = config self.model_type = config.explorer.rollout_model.engine_type self.models, self.auxiliary_models = create_explorer_models(config) - self.experience_pipeline = self._init_experience_pipeline() self.taskset = ( get_taskset_scheduler(explorer_state=explorer_state, config=config) if self.config.mode not in {"bench", "serve"} else None ) - self.scheduler = None self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, group=self.config.group, @@ -76,6 +71,7 @@ def __init__(self, config: Config): ) else: self.min_wait_num = None + self.rollout_coordinator = None self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL self.pending_eval_tasks = deque() @@ -191,10 +187,6 @@ async def prepare(self) -> None: ) await asyncio.gather(*run_api_ref) self.logger.info("All models are ready.") - # prepare experience pipeline - if self.experience_pipeline: - await self.experience_pipeline.prepare.remote() - self.logger.info("Experience pipeline is ready.") if not self.use_nccl_sync and self.model_type not in {"tinker", "external"}: if self.config.mode == "serve": # In serving mode, each engine will setup its own process group @@ -206,8 +198,13 @@ async def prepare(self) -> None: await self.setup_weight_sync_group(master_address, master_port) if self.config.mode != "serve": - self.scheduler = Scheduler(self.config, self.models, self.auxiliary_models) - await self.scheduler.start() + self.rollout_coordinator = RolloutCoordinator.get_actor( + self.config, + self.models, + self.auxiliary_models, + ) + await self.rollout_coordinator.prepare.remote() + self.logger.info("Rollout coordinator is ready.") if self.config.explorer.eval_on_startup and self.explore_step_num == 0: await self.eval() @@ -268,11 +265,17 @@ async def explore_step(self) -> bool: await self.shutdown() return False self.explore_step_num += 1 - self.scheduler.schedule(tasks, batch_id=self.explore_step_num) + assert self.rollout_coordinator is not None, "Rollout coordinator must be prepared first." 
+ await self.rollout_coordinator.submit_batch.remote( + batch_id=self.explore_step_num, + tasks=tasks, + batch_type="train", + min_wait_num=self.min_wait_num, + ) return True async def finish_current_steps(self) -> None: - if self.scheduler: + if self.rollout_coordinator is not None: await self._finish_steps( self.last_monitored_step + 1, self.explore_step_num, self.model_version ) @@ -316,7 +319,14 @@ async def eval(self): while True: try: data = await eval_taskset.read_async() - self.scheduler.schedule(data, batch_id=eval_batch_id) + assert ( + self.rollout_coordinator is not None + ), "Rollout coordinator must be prepared first." + await self.rollout_coordinator.submit_batch.remote( + batch_id=eval_batch_id, + tasks=data, + batch_type="eval", + ) except StopAsyncIteration: break @@ -360,7 +370,7 @@ async def save_checkpoint(self) -> None: async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights - if self.scheduler and self.explore_step_num == 0: + if self.rollout_coordinator is not None and self.explore_step_num == 0: await self._finish_eval_step(step=0) self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.") @@ -388,24 +398,15 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int self.monitor.log(metric, step=end_step) async def _finish_explore_step(self, step: int, model_version: int) -> None: + assert self.rollout_coordinator is not None, "Rollout coordinator must be prepared first." metric = {"rollout/model_version": model_version} with Timer(metric, "time/wait_explore_step"): - statuses, exps = await self.scheduler.get_results( - batch_id=step, - min_num=self.min_wait_num, - return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks, - ) - if self.experience_pipeline is not None: - pipeline_metrics = await self.experience_pipeline.process.remote( - Experience.serialize_many(exps) - ) - self.taskset.feedback(pipeline_metrics) - metric.update(pipeline_metrics) - if statuses: - metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout")) - metric["rollout/finished_task_count"] = len(statuses) - if self.monitor is not None: - self.monitor.log(metric, step=step) + result = await self.rollout_coordinator.finalize_train_batch.remote(step) + if self.taskset is not None: + self.taskset.feedback(result["metrics"]) + metric.update(result["metrics"]) + if result["finished_task_count"] > 0 and self.monitor is not None: + self.monitor.log(metric, step=step) async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: if not self.pending_eval_tasks: @@ -417,18 +418,13 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva if eval_step != step: return self.pending_eval_tasks.popleft() - statuses, _ = await self.scheduler.get_results( - batch_id=f"{step}/{eval_task_name}", - return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks, - ) - metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses) - metric.update( - gather_eval_metrics( - [status.metrics[0] for status in statuses], - f"{prefix}/{eval_task_name}", - detailed_stats=self.detailed_stats, - ) + assert ( + self.rollout_coordinator is not None + ), "Rollout coordinator must be prepared first." 
+ result = await self.rollout_coordinator.finalize_eval_batch.remote( + f"{step}/{eval_task_name}" ) + metric.update(result["metrics"]) if self.eval_start_time is not None: metric.update({"time/eval": time.time() - self.eval_start_time}) self.eval_start_time = None @@ -436,15 +432,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva self.monitor.log(metric, step) async def shutdown(self) -> None: - if self.scheduler: - await self.scheduler.stop() - self.scheduler = None - if self.experience_pipeline: - await self.experience_pipeline.close.remote() - # reserve `experience_pipeline.output` for trainer - # TODO: refactor the lifecycle of buffer actor - self._old_experience_pipeline = self.experience_pipeline - self.experience_pipeline = None + if self.rollout_coordinator: + await self.rollout_coordinator.shutdown.remote() + self.rollout_coordinator = None if self.monitor: self.monitor.close() self.monitor = None @@ -463,26 +453,6 @@ async def is_alive(self) -> bool: """Check if the explorer is alive.""" return True - def _init_experience_pipeline(self) -> ray.actor.ActorHandle: - """Init experience pipeline for the explorer.""" - if self.config.mode == "bench": - return None - # place the pipeline on the same node as the explorer to - # avoid unnecessary data transfer between nodes - node_id = ray.get_runtime_context().get_node_id() - return ( - ray.remote(ExperiencePipeline) - .options( - name=f"{self.config.explorer.name}_pipeline", - namespace=self.config.ray_namespace, - scheduling_strategy=NodeAffinitySchedulingStrategy( - node_id=node_id, - soft=False, - ), - ) - .remote(self.config) - ) - @Experimental async def serve(self) -> None: """Run the explorer in serving mode. diff --git a/trinity/explorer/rollout_coordinator.py b/trinity/explorer/rollout_coordinator.py new file mode 100644 index 00000000000..8e8773badaa --- /dev/null +++ b/trinity/explorer/rollout_coordinator.py @@ -0,0 +1,462 @@ +"""Rollout coordinator for async batch submission and finalize.""" + +import asyncio +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +import ray + +from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline +from trinity.common.config import Config +from trinity.common.models import InferenceModel +from trinity.common.workflows import Task +from trinity.explorer.scheduler import CompletedTaskResult, Scheduler +from trinity.utils.log import get_logger +from trinity.utils.monitor import gather_eval_metrics, gather_metrics + +BatchId = Union[int, str] +BatchType = Literal["train", "eval"] + + +class BatchLifecycleState(str, Enum): + """Lifecycle states for one submitted batch.""" + + PENDING = "pending" + RUNNING = "running" + FINALIZING = "finalizing" + FINALIZED = "finalized" + ABORTED = "aborted" + + +class FinalizeReason(str, Enum): + """Reasons why a batch finalize call returns.""" + + COMPLETE = "complete" + PARTIAL = "partial" + TIMEOUT = "timeout" + ABORT = "abort" + + +@dataclass +class BatchState: + """In-memory state tracked for one train or eval batch.""" + + batch_id: BatchId + batch_type: BatchType + expected_task_count: int + statuses: Dict[Union[int, str], Any] = field(default_factory=dict) + min_wait_num: Optional[int] = None + state: BatchLifecycleState = BatchLifecycleState.PENDING + final_result: Optional[dict] = None + finalize_lock: asyncio.Lock = field(default_factory=asyncio.Lock) + min_threshold_reached_time: Optional[float] = None + + 
@property + def completed_task_count(self) -> int: + """Return the number of completed tasks tracked by status.""" + + return len(self.statuses) + + +class RolloutCoordinator: + """Own scheduler-side batch state and expose batch-level finalize APIs.""" + + def __init__( + self, + config: Config, + rollout_model: List[InferenceModel], + auxiliary_models: Optional[List[List[InferenceModel]]] = None, + ): + """Create a coordinator with internally managed scheduler and pipeline.""" + self.logger = get_logger("rollout_coordinator", in_ray_actor=True) + self.config = config + self.rollout_model = rollout_model + self.auxiliary_models = auxiliary_models or [] + self.experience_pipeline = None + self.scheduler: Optional[Scheduler] = None + self.pending_batches: Dict[BatchId, BatchState] = {} + self.event_loop_task: Optional[asyncio.Task] = None + self.running = False + self.detailed_stats = getattr(getattr(config, "monitor", None), "detailed_stats", False) + + async def prepare(self) -> None: + """Initialize the owned pipeline and scheduler.""" + if self.running: + return + if self.experience_pipeline is None: + self.experience_pipeline = self._init_experience_pipeline() + if self.experience_pipeline is not None: + await self.experience_pipeline.prepare() + if self.scheduler is None: + self.scheduler = self._init_scheduler() + await self.scheduler.start() + self.running = True + self.event_loop_task = asyncio.create_task(self._completed_task_event_loop()) + + async def shutdown(self) -> None: + """Stop background work and close owned dependencies.""" + self.running = False + if self.event_loop_task is not None: + self.event_loop_task.cancel() + try: + await self.event_loop_task + except asyncio.CancelledError: + pass + self.event_loop_task = None + if self.scheduler is not None: + await self.scheduler.stop() + self.scheduler = None + if self.experience_pipeline is not None: + await self.experience_pipeline.close() + self.experience_pipeline = None + + def _init_experience_pipeline(self): + """Create the experience pipeline owned by this coordinator actor.""" + if self.config.mode == "bench": + return None + return ExperiencePipeline(self.config) + + def _init_scheduler(self) -> Scheduler: + """Create the scheduler owned by this coordinator.""" + return Scheduler( + self.config, + self.rollout_model, + self.auxiliary_models, + emit_completed_task_events=True, + ) + + def _require_scheduler(self) -> Scheduler: + """Return the initialized scheduler.""" + assert self.scheduler is not None, "RolloutCoordinator.prepare() must be called first." 
+ return self.scheduler + + async def submit_batch( + self, + *, + batch_id: BatchId, + tasks: list[Task], + batch_type: BatchType, + min_wait_num: Optional[int] = None, + ) -> None: + """Register a new batch and schedule its tasks.""" + existing_state = self.pending_batches.get(batch_id) + if existing_state is not None and existing_state.state not in { + BatchLifecycleState.FINALIZED, + BatchLifecycleState.ABORTED, + }: + raise ValueError(f"Batch {batch_id} is already active.") + + batch_state = BatchState( + batch_id=batch_id, + batch_type=batch_type, + expected_task_count=len(tasks), + min_wait_num=min_wait_num, + ) + self.pending_batches[batch_id] = batch_state + + if tasks: + self._require_scheduler().schedule(tasks, batch_id=batch_id) + batch_state.state = BatchLifecycleState.RUNNING + + async def finalize_train_batch( + self, + batch_id: int, + *, + timeout: Optional[float] = None, + ) -> dict: + """Finalize one train batch and return aggregated metrics.""" + batch_state = self._get_batch_state(batch_id, expected_type="train") + return await self._finalize_train_batch(batch_state, timeout=timeout) + + async def finalize_eval_batch( + self, + batch_id: str, + *, + timeout: Optional[float] = None, + ) -> dict: + """Finalize one eval batch and return aggregated eval metrics.""" + batch_state = self._get_batch_state(batch_id, expected_type="eval") + return await self._finalize_eval_batch(batch_state, timeout=timeout) + + async def _finalize_eval_batch( + self, batch_state: BatchState, *, timeout: Optional[float] + ) -> dict: + """Finalize one eval batch.""" + + scheduler = self._require_scheduler() + async with batch_state.finalize_lock: + existing_result = self._get_existing_final_result(batch_state) + if existing_result is not None: + return existing_result + + statuses = await scheduler.get_statuses( + batch_id=batch_state.batch_id, + timeout=timeout, + return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks, + ) + for task_id, status in enumerate(statuses): + if task_id in batch_state.statuses: + continue + batch_state.statuses[task_id] = status + reason = ( + FinalizeReason.COMPLETE + if batch_state.completed_task_count >= batch_state.expected_task_count + else FinalizeReason.TIMEOUT + ) + return self._finish_batch(batch_state, reason, {}) + + async def abort_batch( + self, + batch_id: BatchId, + *, + reason: str, + keep_partial_results: bool = False, + ) -> None: + """Abort one batch and cleanup its running and staged state.""" + scheduler = self._require_scheduler() + batch_state = self.pending_batches.get(batch_id) + if batch_state is None: + return + if batch_state.state in {BatchLifecycleState.FINALIZED, BatchLifecycleState.ABORTED}: + return + + self.logger.warning("Abort batch %s: %s", batch_id, reason) + await scheduler.abort_batch( + batch_id, + return_partial_tasks=keep_partial_results, + restart_runners=True, + ) + if self.experience_pipeline is not None: + await self.experience_pipeline.abort_batch(batch_id) + + batch_state.state = BatchLifecycleState.ABORTED + batch_state.final_result = self._build_batch_result(batch_state, FinalizeReason.ABORT, {}) + self.pending_batches.pop(batch_id, None) + + @classmethod + def get_actor( + cls, config: Config, models: List, auxiliary_models: List + ) -> ray.actor.ActorHandle: + """Init rollout coordinator for the task-event-completion path.""" + return ( + ray.remote(RolloutCoordinator) + .options(namespace=config.ray_namespace) + .remote( + config, + models, + auxiliary_models, + ) + ) + + async def 
_completed_task_event_loop(self) -> None:
+        """Consume task completion events emitted by the scheduler."""
+        scheduler = self._require_scheduler()
+        while self.running:
+            try:
+                completed_result = await scheduler.wait_completed_task(timeout=0.1)
+                if completed_result is None:
+                    continue  # poll timeout, not an error; re-check self.running
+                if not isinstance(completed_result.task_id, int):
+                    self.logger.warning(
+                        "Skip completed task event with non-integer task id: %s",
+                        completed_result.task_id,
+                    )
+                    continue  # skip the malformed event, keep the loop alive
+                batch_state = self.pending_batches.get(completed_result.batch_id)
+                if batch_state is None:
+                    continue  # batch already finalized or aborted; drop the event
+                await self._store_completed_task_result(batch_state, completed_result)
+                self._maybe_mark_ready(batch_state)
+            except Exception:  # noqa: BLE001
+                self.logger.exception("RolloutCoordinator task event loop failed.")
+
+    async def _store_completed_task_result(
+        self, batch_state: BatchState, result: CompletedTaskResult
+    ) -> None:
+        """Persist one completed task into batch-level aggregation state."""
+        if result.task_id in batch_state.statuses:
+            return
+        batch_state.statuses[result.task_id] = result.status
+
+        if batch_state.batch_type != "train":
+            return
+
+        if self.experience_pipeline is not None and result.experience_payloads:
+            staged_task_id = int(result.task_id)
+            await self.experience_pipeline.stage_task_payloads(
+                batch_state.batch_id,
+                staged_task_id,
+                result.experience_payloads,
+            )
+        return
+
+    def _get_batch_state(self, batch_id: BatchId, *, expected_type: BatchType) -> BatchState:
+        """Return one registered batch and validate its type."""
+        batch_state = self.pending_batches.get(batch_id)
+        if batch_state is None:
+            raise KeyError(f"Batch {batch_id} is not registered.")
+        if batch_state.batch_type != expected_type:
+            raise ValueError(
+                f"Batch {batch_id} is {batch_state.batch_type}, expected {expected_type}."
+ ) + return batch_state + + def _get_existing_final_result(self, batch_state: BatchState) -> Optional[dict]: + """Reuse an in-flight final result or synthesize an abort result.""" + + if batch_state.final_result is not None: + return dict(batch_state.final_result) + if batch_state.state != BatchLifecycleState.ABORTED: + return None + batch_state.final_result = self._build_batch_result(batch_state, FinalizeReason.ABORT, {}) + return dict(batch_state.final_result) + + def _get_ready_reason(self, batch_state: BatchState) -> Optional[FinalizeReason]: + """Check whether a batch is ready to finalize and why.""" + if batch_state.state == BatchLifecycleState.ABORTED: + return FinalizeReason.ABORT + if batch_state.completed_task_count >= batch_state.expected_task_count: + return FinalizeReason.COMPLETE + if batch_state.batch_type != "train": + return None + if batch_state.min_wait_num is None: + return None + if batch_state.completed_task_count < batch_state.min_wait_num: + batch_state.min_threshold_reached_time = None + return None + + if batch_state.min_threshold_reached_time is None: + batch_state.min_threshold_reached_time = time.time() + + wait_after_min = getattr(self.config.explorer.over_rollout, "wait_after_min", 0.0) + if time.time() - batch_state.min_threshold_reached_time >= wait_after_min: + return FinalizeReason.PARTIAL + return None + + def _maybe_mark_ready(self, batch_state: BatchState) -> Optional[FinalizeReason]: + """Transition a batch to ready state when finalize conditions are met.""" + ready_reason = self._get_ready_reason(batch_state) + if ready_reason is None: + return None + return ready_reason + + async def _wait_for_ready( + self, batch_state: BatchState, timeout: Optional[float] + ) -> Optional[FinalizeReason]: + """Wait until a batch is ready to finalize or the timeout expires.""" + start_time = time.time() + while True: + ready_reason = self._maybe_mark_ready(batch_state) + if ready_reason is not None: + return ready_reason + if timeout is not None and (time.time() - start_time) >= timeout: + return None + await asyncio.sleep(0.05) + + async def _finalize_train_batch( + self, batch_state: BatchState, *, timeout: Optional[float] + ) -> dict: + """Finalize one train batch.""" + async with batch_state.finalize_lock: + existing_result = self._get_existing_final_result(batch_state) + if existing_result is not None: + return existing_result + + ready_reason = await self._wait_for_ready(batch_state, timeout) + if ready_reason is None: + if batch_state.min_wait_num is not None and batch_state.statuses: + ready_reason = FinalizeReason.TIMEOUT + else: + raise TimeoutError(f"Timeout waiting for batch {batch_state.batch_id}.") + + batch_state.state = BatchLifecycleState.FINALIZING + try: + pipeline_metrics = await self._finalize_train_payloads(batch_state) + if ready_reason != FinalizeReason.COMPLETE: + await self._cleanup_train_batch_runtime(batch_state) + except Exception: + batch_state.state = self._get_active_batch_state(batch_state) + raise + + return self._finish_batch(batch_state, ready_reason, pipeline_metrics) + + def _finish_batch( + self, + batch_state: BatchState, + reason: FinalizeReason, + pipeline_metrics: dict, + ) -> dict: + """Persist one terminal result and evict the batch from active state.""" + batch_state.state = BatchLifecycleState.FINALIZED + batch_state.final_result = self._build_batch_result(batch_state, reason, pipeline_metrics) + self.pending_batches.pop(batch_state.batch_id, None) + return dict(batch_state.final_result) + + def 
_get_active_batch_state(self, batch_state: BatchState) -> BatchLifecycleState: + """Return the active lifecycle state to restore after a failed finalize attempt.""" + if batch_state.expected_task_count == 0: + return BatchLifecycleState.PENDING + return BatchLifecycleState.RUNNING + + async def _cleanup_train_batch_runtime(self, batch_state: BatchState) -> None: + """Drop unfinished train work after a non-complete finalize result.""" + scheduler = self._require_scheduler() + await scheduler.abort_batch( + batch_state.batch_id, + return_partial_tasks=False, + restart_runners=True, + ) + if self.experience_pipeline is not None: + await self.experience_pipeline.abort_batch(batch_state.batch_id) + + async def _finalize_train_payloads(self, batch_state: BatchState) -> dict: + """Flush staged train payloads through the experience pipeline.""" + if self.experience_pipeline is not None and batch_state.statuses: + return await self.experience_pipeline.finalize_batch(batch_state.batch_id) + return {} + + def _build_batch_result( + self, + batch_state: BatchState, + reason: FinalizeReason, + pipeline_metrics: dict, + ) -> dict: + """Build the public finalize result returned to Explorer.""" + + metrics = dict(pipeline_metrics) + status_metrics = [ + status.metrics[0] for status in batch_state.statuses.values() if status.metrics + ] + if batch_state.batch_type == "train": + if status_metrics: + metrics.update(gather_metrics(status_metrics, "rollout")) + metrics["rollout/finished_task_count"] = float(batch_state.completed_task_count) + else: + prefix = self._eval_metric_prefix(batch_state.batch_id) + if status_metrics: + metrics.update( + gather_eval_metrics( + status_metrics, + prefix, + detailed_stats=self.detailed_stats, + ) + ) + metrics[f"{prefix}/finished_task_count"] = float(batch_state.completed_task_count) + + return { + "batch_id": batch_state.batch_id, + "batch_type": batch_state.batch_type, + "finished_task_count": batch_state.completed_task_count, + "metrics": metrics, + "finalize_reason": reason.value, + "finalized": reason != FinalizeReason.ABORT, + } + + def _eval_metric_prefix(self, batch_id: BatchId) -> str: + """Return the metric namespace prefix for one eval batch.""" + batch_name = str(batch_id) + if "/" in batch_name: + batch_name = batch_name.split("/", 1)[1] + return f"eval/{batch_name}" diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index ea3206ef00b..1d135ef8871 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -12,7 +12,6 @@ import ray from trinity.common.config import Config -from trinity.common.experience import Experience from trinity.common.models import InferenceModel from trinity.common.workflows import Task from trinity.explorer.workflow_runner import Status, WorkflowRunner @@ -33,11 +32,21 @@ class TaskWrapper: completed_runs: int = 0 total_runs: int = 0 # total planned runs for the whole task metrics: List[Dict[str, float]] = field(default_factory=list) - experiences: List[Experience] = field(default_factory=list) + experience_payloads: List[bytes] = field(default_factory=list) first_error: Optional[str] = None emitted: bool = False +@dataclass(frozen=True) +class CompletedTaskResult: + """A completed task result stored by batch and task id.""" + + batch_id: Union[int, str] + task_id: Union[int, str] + status: Status + experience_payloads: List[bytes] = field(default_factory=list) + + # Adapted from verl/trainer/ppo/metric_utils.py def bootstrap_metric( data: list[Any], @@ -161,7 +170,12 @@ def 
_create_runner(self): "env_vars": self.config.explorer.env_vars, }, ) - .remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id) + .remote( + self.config, + self.rollout_model, + self.auxiliary_models, + self.runner_id, + ) ) async def prepare(self): @@ -179,7 +193,7 @@ async def run_with_retry( run_id_base: int, timeout: float, collect_partial_runs: bool, - ) -> Tuple[Status, List, int, float]: + ) -> Tuple[Status, bytes, int, float]: """ Args: task (`TaskWrapper`): The task to run. @@ -197,7 +211,7 @@ async def run_with_retry( await self.runner.__ray_ready__.remote() start_time = time.time() status = Status(completed_runs=0, total_runs=repeat_times, metrics=list()) - exps = [] + exp_payload = b"" run_task_ref = None task2run = replace( task.task, @@ -216,7 +230,7 @@ async def run_with_retry( run_id_base=run_id_base, collect_partial_runs=collect_partial_runs, ) - status, exps = await asyncio.wait_for( + status, exp_payload = await asyncio.wait_for( run_task_ref, timeout=timeout, ) @@ -261,7 +275,7 @@ async def run_with_retry( finally: end_time = time.time() status.metrics.append({"time/task_execution": end_time - start_time}) - return status, exps, self.runner_id, end_time - start_time + return status, exp_payload, self.runner_id, end_time - start_time async def restart_runner(self): old_runner = self.runner @@ -299,6 +313,8 @@ def __init__( config: Config, rollout_model: List[InferenceModel], auxiliary_models: Optional[List[List[InferenceModel]]] = None, + *, + emit_completed_task_events: bool = False, ): self.logger = get_logger(__name__) self.config = config @@ -328,17 +344,22 @@ def __init__( self.running_task_map: Dict[asyncio.Future, TaskWrapper] = dict() # future -> task self.running_task_runner_map: Dict[asyncio.Future, int] = dict() # future -> runner_id self.cancelled_task_restart_map: Dict[asyncio.Future, bool] = dict() + self.batch_is_eval_map: Dict[Union[int, str], bool] = dict() self.completed_tasks: Dict[ - Union[int, str], deque[Tuple[Status, List[Experience]]] + Union[int, str], Dict[Union[int, str], CompletedTaskResult] ] = defaultdict( - deque + dict ) # batch_id -> results + self.completed_task_results: asyncio.Queue[CompletedTaskResult] = asyncio.Queue() + self.emit_completed_task_events = emit_completed_task_events self.background_tasks: set[asyncio.Task] = set() self.scheduler_task: Optional[asyncio.Task] = None self.running = False self.total_running_time = 0.0 + self.total_completed_steps = 0 + self.total_completed_sub_tasks = 0 self.total_completed_tasks = 0 async def _create_runner( @@ -448,15 +469,17 @@ def task_done_callback(self, async_task: asyncio.Task): self._schedule_runner_restart(runner_id) else: self.cancelled_task_restart_map.pop(async_task, None) - status, exps, runner_id, run_time = async_task.result() - if not task.task.is_eval: # only count running time for non-eval tasks + status, exp_payload, runner_id, run_time = async_task.result() + if not task.task.is_eval: self.total_running_time += run_time - self.total_completed_tasks += 1 - self._accumulate_task_result(task, status, exps) + self.total_completed_sub_tasks += 1 + self._accumulate_task_result(task, status, exp_payload) self.busy_runners.pop(runner_id, None) self.idle_runners.add(runner_id) # If all sub runs in a task are completed if task.finished_sub_task_num == task.sub_task_num: + if not task.task.is_eval: + self.total_completed_tasks += 1 self._emit_task_result(task) self.logger.debug(f"Task completed (batch_id {task.batch_id}).") @@ -466,16 +489,17 @@ def 
task_done_callback(self, async_task: asyncio.Task):
             del self.running_tasks[task.batch_id]
 
     def _accumulate_task_result(
-        self, task: TaskWrapper, status: Status, experiences: List[Experience]
+        self, task: TaskWrapper, status: Status, experience_payload: bytes
     ) -> None:
         task.finished_sub_task_num += 1
         task.completed_runs += status.completed_runs
         task.metrics.extend(status.metrics)
-        task.experiences.extend(experiences)
+        if experience_payload:
+            task.experience_payloads.append(experience_payload)
         if not status.ok and task.first_error is None:
             task.first_error = status.message
 
-    def _build_task_result(self, task: TaskWrapper) -> Tuple[Status, List[Experience]]:
+    def _build_task_result(self, task: TaskWrapper) -> Tuple[Status, List[bytes]]:
         if task.completed_runs < task.total_runs:
             message = f"{task.completed_runs}/{task.total_runs} runs completed successfully."
             if task.first_error:
@@ -490,14 +514,35 @@ def _build_task_result(self, task: TaskWrapper) -> Tuple[Status, List[Experience
             metrics=[calculate_task_level_metrics(task.metrics, task.task.is_eval)],
             message=message,
         )
-        return status, list(task.experiences)
+        return status, list(task.experience_payloads)
 
     def _emit_task_result(self, task: TaskWrapper) -> None:
         if task.emitted:
             return
-        self.completed_tasks[task.batch_id].appendleft(self._build_task_result(task))
+        status, experience_payloads = self._build_task_result(task)
+        task_id = task.task.task_id
+        completed_result = CompletedTaskResult(
+            batch_id=task.batch_id,
+            task_id=task_id,
+            status=status,
+            experience_payloads=experience_payloads,
+        )
+        if self.emit_completed_task_events and not task.task.is_eval:
+            self.completed_task_results.put_nowait(completed_result)
+        else:
+            self.completed_tasks[task.batch_id][task_id] = completed_result
         task.emitted = True
 
+    async def wait_completed_task(
+        self, timeout: Optional[float] = None
+    ) -> Optional[CompletedTaskResult]:
+        try:
+            if timeout is None:
+                return await self.completed_task_results.get()
+            return await asyncio.wait_for(self.completed_task_results.get(), timeout=timeout)
+        except asyncio.TimeoutError:
+            return None
+
     def _collect_incomplete_tasks(self, batch_id: Union[int, str]) -> List[TaskWrapper]:
         tasks = {}
         for task, _, _ in self.pending_tasks.get(batch_id, deque()):
@@ -589,6 +634,7 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None:
         """
         if not tasks:
            return
+        self.batch_is_eval_map[batch_id] = tasks[0].is_eval
         self.task_num_map[batch_id] += len(tasks)
         self._split_and_submit_tasks(tasks, batch_id=batch_id)
 
@@ -616,9 +662,11 @@ def dynamic_timeout(self, timeout: Optional[float] = None) -> float:
         max_timeout = timeout or self.default_timeout
         if not self.config.explorer.dynamic_timeout.enable:
             return max_timeout
-        if self.total_completed_tasks < self.default_batch_size:
+        if self.total_completed_steps < self.config.explorer.dynamic_timeout.warmup_min_steps:
             return max_timeout
-        avg_time_per_task = self.total_running_time / self.total_completed_tasks
+        if self.total_completed_sub_tasks == 0:
+            return max_timeout
+        avg_time_per_task = self.total_running_time / self.total_completed_sub_tasks
         return min(
             max_timeout,
             avg_time_per_task * self.config.explorer.dynamic_timeout.ratio,
         )
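The warmup gate above replaces the old `default_batch_size` heuristic: the timeout only tightens once at least `warmup_min_steps` non-eval steps have fully finished. A minimal standalone sketch of the gating arithmetic (the function body mirrors the hunk above; the concrete numbers are illustrative, not defaults from the codebase):

```python
def sketch_dynamic_timeout(
    total_completed_steps: int,
    total_completed_sub_tasks: int,
    total_running_time: float,
    max_timeout: float,
    warmup_min_steps: int = 1,
    ratio: float = 3.0,
) -> float:
    # Before `warmup_min_steps` fully observed non-eval steps, keep the static timeout.
    if total_completed_steps < warmup_min_steps or total_completed_sub_tasks == 0:
        return max_timeout
    avg_time_per_sub_task = total_running_time / total_completed_sub_tasks
    # The dynamic budget never exceeds the static timeout.
    return min(max_timeout, avg_time_per_sub_task * ratio)


# One fully observed step with 8 sub-task executions totaling 336s:
# avg = 42s, so the budget shrinks from 1800s to min(1800, 42 * 3.0) = 126s.
assert sketch_dynamic_timeout(1, 8, 336.0, max_timeout=1800.0) == 126.0
```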
@@ -647,42 +695,45 @@ async def _cleanup_batch(
         if runners_to_restart:
             await asyncio.gather(*[self._restart_runner(rid) for rid in runners_to_restart])
 
-    async def get_results(
-        self,
-        batch_id: Union[int, str],
-        min_num: Optional[int] = None,
-        timeout: Optional[float] = None,
-        clear_timeout_tasks: bool = True,
-        return_partial_tasks: bool = False,
-    ) -> Tuple[List[Status], List[Experience]]:
-        """Get the result of tasks at the specific batch_id.
-
-        Args:
-            batch_id (`Union[int, str]`): Only wait for tasks at this batch.
-            min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `batch_id`.
-            timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout.
-            clear_timeout_tasks (`bool`): Whether to clear timeout tasks.
-            return_partial_tasks (`bool`): Whether to emit tasks with partial successful runs when cleaning up unfinished tasks.
-        """
-        timeout = timeout or self.default_timeout
-        start_time = time.time()
+    def _resolve_result_target(
+        self, batch_id: Union[int, str], min_num: Optional[int]
+    ) -> Tuple[int, int]:
         scheduled_num = self.task_num_map.get(batch_id, 0)
         if min_num is None:
-            min_num = scheduled_num
-        elif min_num > scheduled_num:
+            return scheduled_num, scheduled_num
+        if min_num > scheduled_num:
             self.logger.warning(
                 f"Requested min_num {min_num} is greater than scheduled tasks {scheduled_num} at batch_id {batch_id}. Adjusting min_num to {scheduled_num}."
             )
-            min_num = scheduled_num
+            return scheduled_num, scheduled_num
+        return scheduled_num, min_num
 
-        self.logger.debug(f"Waiting for {min_num} tasks to complete...")
+    def _finalize_dynamic_timeout_step(
+        self, batch_id: Union[int, str], scheduled_num: int, completed_count: int
+    ) -> None:
+        if batch_id in self.pending_tasks or batch_id in self.running_tasks:
+            return
+        is_eval = self.batch_is_eval_map.pop(batch_id, False)
+        if not is_eval and completed_count >= scheduled_num:
+            self.total_completed_steps += 1
+
+    async def _wait_for_batch_results(
+        self,
+        batch_id: Union[int, str],
+        min_num: int,
+        scheduled_num: int,
+        timeout: float,
+        clear_timeout_tasks: bool,
+        return_partial_tasks: bool,
+    ) -> bool:
+        start_time = time.time()
         min_threshold_reached_time = None
         while time.time() - start_time <= timeout:
-            completed_count = len(self.completed_tasks.get(batch_id, []))
+            completed_count = len(self.completed_tasks.get(batch_id, {}))
             if completed_count >= min_num:
                 min_threshold_reached_time = min_threshold_reached_time or time.time()
                 if completed_count >= scheduled_num:
-                    break
+                    return False
             if (
                 time.time() - min_threshold_reached_time
                 >= self.config.explorer.over_rollout.wait_after_min
             ):
@@ -693,10 +744,47 @@ async def get_results(
                 self._clear_timeout_tasks(
                     batch_id,
                     return_partial_tasks=return_partial_tasks,
                     restart_runners=False,
                 )
-                break
+                return False
             await asyncio.sleep(0.1)
+        return True
+
+    async def _collect_batch_results(
+        self, batch_id: Union[int, str]
+    ) -> Tuple[List[Status], List[bytes]]:
+        statuses = []
+        payload_chunks = []
+        completed_results = list(self.completed_tasks.get(batch_id, {}).values())
+        for result in completed_results:
+            statuses.append(result.status)
+            if result.experience_payloads:
+                payload_chunks.extend(result.experience_payloads)
-        if time.time() - start_time > timeout:
+        return statuses, payload_chunks
+
+    async def _get_batch_payload_results(
+        self,
+        batch_id: Union[int, str],
+        *,
+        min_num: Optional[int],
+        timeout: Optional[float],
+        clear_timeout_tasks: bool,
+        return_partial_tasks: bool,
+    ) -> Tuple[List[Status], List[bytes]]:
+        """Wait for one batch and drain its completed payload chunks."""
+
+        timeout = timeout or self.default_timeout
+        scheduled_num, min_num = self._resolve_result_target(batch_id, min_num)
+
+        self.logger.debug(f"Waiting for {min_num} tasks to complete...")
+        timed_out = await self._wait_for_batch_results(
+            batch_id=batch_id,
+            min_num=min_num,
+            scheduled_num=scheduled_num,
+            timeout=timeout,
+            clear_timeout_tasks=clear_timeout_tasks,
+            return_partial_tasks=return_partial_tasks,
+        )
+        if timed_out:
             self.logger.error(
                 f"Timed out waiting for tasks at batch {batch_id} to complete after {timeout} seconds"
             )
@@ -707,27 +795,69 @@
                 restart_runners=True,
             )
 
-        statuses = []
-        experiences = []
-        completed_queue = self.completed_tasks.get(batch_id, deque())
-        while completed_queue:
-            status, exps = completed_queue.pop()
-            statuses.append(status)
-            if isinstance(exps, list):
-                experiences.extend(exps)
-            else:
-                experiences.append(exps)
+        statuses, payload_chunks = await self._collect_batch_results(batch_id)
 
-        if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
+        if batch_id in self.completed_tasks:
             del self.completed_tasks[batch_id]
 
         completed_count = len(statuses)
+        self._finalize_dynamic_timeout_step(batch_id, scheduled_num, completed_count)
         if completed_count < min_num:
             self.logger.warning(
                 f"Timeout reached, only {completed_count}/{min_num} tasks completed"
             )
-        return statuses, experiences
+        return statuses, payload_chunks
+
+    async def get_payload_results(
+        self,
+        batch_id: Union[int, str],
+        min_num: Optional[int] = None,
+        timeout: Optional[float] = None,
+        clear_timeout_tasks: bool = True,
+        return_partial_tasks: bool = False,
+    ) -> Tuple[List[Status], List[bytes]]:
+        """Wait for one batch and return task statuses plus serialized payload chunks."""
+
+        return await self._get_batch_payload_results(
+            batch_id=batch_id,
+            min_num=min_num,
+            timeout=timeout,
+            clear_timeout_tasks=clear_timeout_tasks,
+            return_partial_tasks=return_partial_tasks,
+        )
+
+    async def get_statuses(
+        self,
+        batch_id: Union[int, str],
+        min_num: Optional[int] = None,
+        timeout: Optional[float] = None,
+        clear_timeout_tasks: bool = True,
+        return_partial_tasks: bool = False,
+    ) -> List[Status]:
+        """Wait for one batch and return only task statuses without materializing experiences."""
+
+        statuses, _ = await self._get_batch_payload_results(
+            batch_id=batch_id,
+            min_num=min_num,
+            timeout=timeout,
+            clear_timeout_tasks=clear_timeout_tasks,
+            return_partial_tasks=return_partial_tasks,
+        )
+        return statuses
+
+    async def abort_batch(
+        self,
+        batch_id: Union[int, str],
+        return_partial_tasks: bool = False,
+        restart_runners: bool = True,
+    ) -> None:
+        """Abort one batch and cleanup unfinished scheduler state."""
+        await self._cleanup_batch(
+            batch_id,
+            return_partial_tasks=return_partial_tasks,
+            restart_runners=restart_runners,
+        )
 
     def has_step(self, batch_id: Union[int, str]) -> bool:
         return (
@@ -772,7 +902,8 @@ async def wait_all(
         self.logger.error(error_msg)
 
         if clear_timeout_tasks:
-            for batch_id in self.pending_tasks.keys() | self.running_tasks.keys():
+            batch_ids_to_abort = self.pending_tasks.keys() | self.running_tasks.keys()
+            for batch_id in batch_ids_to_abort:
                 self._clear_timeout_tasks(batch_id)
             asyncio.gather(
                 *[self._restart_runner(runner_id) for runner_id in self.busy_runners.keys()]
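With `emit_completed_task_events=True`, train results flow through the `completed_task_results` queue instead of the per-batch dict, and the coordinator polls `wait_completed_task` with a short timeout so it can re-check its `running` flag between events. A self-contained toy of that consume loop, assuming nothing beyond the standard library (plain strings stand in for `CompletedTaskResult`; all names and timings here are illustrative):

```python
import asyncio
from typing import Optional


async def wait_completed_task(
    queue: "asyncio.Queue[str]", timeout: Optional[float] = None
) -> Optional[str]:
    # Mirrors the scheduler helper: a poll timeout yields None rather than an error.
    try:
        if timeout is None:
            return await queue.get()
        return await asyncio.wait_for(queue.get(), timeout=timeout)
    except asyncio.TimeoutError:
        return None


async def main() -> None:
    queue: "asyncio.Queue[str]" = asyncio.Queue()
    expected = 3

    async def producer() -> None:
        for task_id in range(expected):
            await asyncio.sleep(0.2)  # simulate uneven task completion times
            queue.put_nowait(f"task-{task_id}")

    async def consumer() -> None:
        done = 0
        while done < expected:
            event = await wait_completed_task(queue, timeout=0.1)
            if event is None:
                continue  # no event inside the poll window; loop and re-check
            done += 1
            print("completed:", event)

    await asyncio.gather(producer(), consumer())


asyncio.run(main())
```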
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 5bd7fc667c3..f10fca58c8c 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -398,11 +398,11 @@ async def run_task(
         repeat_times: int = 1,
         run_id_base: int = 0,
         collect_partial_runs: bool = True,
-    ) -> Tuple[Status, List[Experience]]:
+    ) -> Tuple[Status, bytes]:
         """Run the task and return the states."""
         # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead
+        st = time.time()
         try:
-            st = time.time()
             model_version = await self.model_wrapper.model_version_async
             self.runner_state["model_version"] = model_version
             self.logger.info(
@@ -436,9 +436,10 @@
 
             if task.is_eval:
                 # If the task is an evaluation task, we do not record the experiences to the buffer
-                return status, []
+                return status, b""
             else:
-                return status, exps
+                exp_payload = Experience.serialize_many(exps)
+                return status, exp_payload
 
         except Exception as e:
             error_trace_back = traceback.format_exc()
@@ -450,7 +451,7 @@
                     metrics=[{"time/run_execution": time.time() - st}],
                     message=error_trace_back.rstrip(),
                 ),
-                [],
+                b"",
             )
 
 
@@ -511,20 +512,23 @@ async def debug(self) -> None:
         task = tasks[0]
         self.logger.info(f"Start debugging task:\n{task.raw_task}")
         if not self.enable_profiling:
-            status, exps = await self.run_task(
+            status, exp_payload = await self.run_task(
                 task=task, batch_id="debug", repeat_times=1, run_id_base=0
             )
         else:
             from viztracer import VizTracer
 
             with VizTracer(output_file=self.output_profiling_file):
-                status, exps = await self.run_task(
+                status, exp_payload = await self.run_task(
                     task=task, batch_id="debug", repeat_times=1, run_id_base=0
                 )
-        if not status.ok and len(exps) == 0:
-            exps = self.model_wrapper.extract_experience_from_history()
-        self.logger.info(f"Debugging failed, extracting {len(exps)} experiences from history.")
-        await self.sqlite_writer.write_async(exps)
+        experiences = Experience.deserialize_many(exp_payload) if exp_payload else []
+        if not status.ok and not experiences:
+            experiences = self.model_wrapper.extract_experience_from_history()
+            self.logger.info(
+                f"Debugging failed, extracting {len(experiences)} experiences from history."
+            )
+        await self.sqlite_writer.write_async(experiences)
         if status.ok:
             print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
         else: