From 113a52cc0b9273762d9e20e71f9705069bbcd38b Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 13:41:03 +0800 Subject: [PATCH 01/11] support task level wait and reduce experience communication overhead --- tests/explorer/explorer_test.py | 407 +++++++++++++++++- tests/explorer/scheduler_test.py | 111 ++++- .../buffer/pipelines/experience_pipeline.py | 58 ++- trinity/explorer/explorer.py | 123 +++++- trinity/explorer/scheduler.py | 136 ++++-- trinity/explorer/workflow_runner.py | 20 +- 6 files changed, 821 insertions(+), 34 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 075c7ed0f6b..9d745b2d5d6 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -5,10 +5,15 @@ import os import random import shutil +import unittest +from collections import deque from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock import httpx import ray +import torch from tests.tools import ( RayUnittestBase, @@ -29,11 +34,116 @@ OperatorConfig, ) from trinity.common.constants import StorageType -from trinity.explorer.explorer import Explorer +from trinity.common.experience import Experience +from trinity.explorer.explorer import ExploreStepBuffer, Explorer from trinity.explorer.proxy.client import TrinityClient +from trinity.explorer.scheduler import CompletedTaskRef, CompletedTaskResult +from trinity.explorer.workflow_runner import Status from trinity.manager.state_manager import StateManager +def _build_fake_task_event_explorer(use_payloads: bool = False): + class FakeScheduler: + def __init__(self): + self.default_timeout = 5.0 + # Step 2 finishes before step 1 on purpose. This simulates the + # fully async path where Scheduler can return completed tasks in + # task order rather than in step order. 
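+            # Each CompletedTaskRef below is only a lightweight (batch_id,
+            # task_id) handle; the result payload is fetched separately via
+            # pop_completed_task, mirroring the real Scheduler's two-step API.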
+ self.completed_task_refs = deque( + [ + CompletedTaskRef(batch_id=2, task_id=0), + CompletedTaskRef(batch_id=1, task_id=0), + CompletedTaskRef(batch_id=1, task_id=1), + ] + ) + self.completed_results = { + (2, 0): CompletedTaskResult( + task_id=0, + status=Status( + completed_runs=1, + total_runs=1, + metrics=[{"run_metrics": 20.0}], + ), + experiences=[], + experience_payloads=[b"step-2-task-0"] if use_payloads else [], + ), + (1, 0): CompletedTaskResult( + task_id=0, + status=Status( + completed_runs=1, + total_runs=1, + metrics=[{"run_metrics": 10.0}], + ), + experiences=[], + experience_payloads=[b"step-1-task-0"] if use_payloads else [], + ), + (1, 1): CompletedTaskResult( + task_id=1, + status=Status( + completed_runs=1, + total_runs=1, + metrics=[{"run_metrics": 11.0}], + ), + experiences=[], + experience_payloads=[b"step-1-task-1"] if use_payloads else [], + ), + } + self.get_results_calls = [] + + async def wait_completed_task(self, timeout=None): + if self.completed_task_refs: + return self.completed_task_refs.popleft() + return None + + def pop_completed_task(self, batch_id, task_id): + return self.completed_results.pop((batch_id, task_id), None) + + async def get_results( + self, + batch_id, + timeout=None, + clear_timeout_tasks=True, + return_partial_tasks=False, + ): + self.get_results_calls.append( + { + "batch_id": batch_id, + "timeout": timeout, + "clear_timeout_tasks": clear_timeout_tasks, + "return_partial_tasks": return_partial_tasks, + } + ) + return [], [] + + class FakeMonitor: + def __init__(self): + self.logged = [] + + def log(self, metric, step): + self.logged.append((step, metric)) + + explorer = Explorer.__new__(Explorer) + explorer.logger = MagicMock() + explorer.scheduler = FakeScheduler() + explorer.monitor = FakeMonitor() + explorer.experience_pipeline = None + explorer.taskset = SimpleNamespace(feedback=lambda metrics: None) + explorer.use_task_event_completion = True + explorer.pending_eval_tasks = deque() + explorer.pending_step_buffers = { + 1: ExploreStepBuffer(expected_task_count=2), + 2: ExploreStepBuffer(expected_task_count=1), + } + explorer.explore_start_time = None + explorer.last_monitored_step = 0 + explorer.explore_step_num = 2 + explorer.model_version = 7 + explorer.config = SimpleNamespace( + explorer=SimpleNamespace(over_rollout=SimpleNamespace(return_partial_tasks=False)) + ) + return explorer + + class BaseExplorerCase(RayUnittestBase): def setUp(self): self.config = get_template_config() @@ -226,6 +336,301 @@ def test_explorer(self): ray.get(explorer.shutdown.remote()) +class TestExplorerTaskLevelCompletion(unittest.IsolatedAsyncioTestCase): + """Tests Explorer's task-level completion buffering. + + The task event stream may arrive out of step order because Scheduler emits a + completion event as soon as one task is fully done. Explorer must therefore + accept task-level completions eagerly, buffer them by step, and only publish + aggregated step metrics when all tasks of the next step are ready. + """ + + async def test_out_of_order_task_completion_is_buffered_before_step_finish(self): + explorer = _build_fake_task_event_explorer() + + # Waiting for step 1 is allowed to consume and buffer task events from + # later steps. The important property is that step 2 is not blocked from + # returning at task granularity just because step 1 is still incomplete. 
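+        # _wait_step_buffer(1) drains the shared completion-event queue, so the
+        # early step-2 event is absorbed into its own step buffer along the way.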
+ await explorer._wait_step_buffer(1) + + self.assertEqual(explorer.pending_step_buffers[1].completed_task_count, 2) + self.assertEqual(explorer.pending_step_buffers[2].completed_task_count, 1) + self.assertEqual(explorer.scheduler.get_results_calls, []) + + # When Explorer finally flushes steps, aggregation and monitor logging + # must still happen in ascending step order even though the underlying + # task completion order was 2 -> 1 -> 1. + await explorer._finish_steps(1, 2, model_version=7) + + self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) + self.assertEqual(explorer.scheduler.get_results_calls, []) + self.assertEqual(explorer.pending_step_buffers, {}) + + async def test_finish_current_steps_flushes_buffered_steps_in_order(self): + explorer = _build_fake_task_event_explorer() + + # This exercises the public sync boundary: finish_current_steps should + # drain any out-of-order task events gathered so far, but still flush + # step metrics strictly in step order and advance the monitored cursor. + await explorer.finish_current_steps() + + self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) + self.assertEqual(explorer.last_monitored_step, 2) + self.assertEqual(explorer.scheduler.get_results_calls, []) + self.assertEqual(explorer.pending_step_buffers, {}) + + async def test_finish_current_steps_stages_payloads_before_ordered_finalize(self): + class FakeRemoteMethod: + def __init__(self, func): + self.func = func + + async def remote(self, *args, **kwargs): + return await self.func(*args, **kwargs) + + class FakePipeline: + def __init__(self): + self.stage_calls = [] + self.finalize_calls = [] + self.chunk_process_calls = [] + self.stage_task_payloads = FakeRemoteMethod(self._stage_task_payloads) + self.finalize_batch = FakeRemoteMethod(self._finalize_batch) + self.process_serialized_chunks = FakeRemoteMethod(self._process_serialized_chunks) + + async def _stage_task_payloads(self, batch_id, task_id, exp_chunks): + self.stage_calls.append((batch_id, task_id, list(exp_chunks))) + return f"{batch_id}:{task_id}" + + async def _finalize_batch(self, batch_id, task_ids): + self.finalize_calls.append((batch_id, list(task_ids))) + return {"experience_pipeline/experience_count": float(len(task_ids))} + + async def _process_serialized_chunks(self, exp_chunks): + self.chunk_process_calls.append(list(exp_chunks)) + return {"experience_pipeline/experience_count": float(len(exp_chunks))} + + explorer = _build_fake_task_event_explorer(use_payloads=True) + explorer.experience_pipeline = FakePipeline() + explorer.taskset = SimpleNamespace(feedback=lambda metrics: None) + + # Task payloads are staged immediately when completion events arrive, + # but batch finalize still runs in step order during the flush. 
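+        # The expected staging order below is the raw event order (2 -> 1 -> 1),
+        # while finalize_batch must still run as batch 1 and then batch 2.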
+ await explorer.finish_current_steps() + + self.assertEqual( + explorer.experience_pipeline.stage_calls, + [ + (2, 0, [b"step-2-task-0"]), + (1, 0, [b"step-1-task-0"]), + (1, 1, [b"step-1-task-1"]), + ], + ) + self.assertEqual( + explorer.experience_pipeline.finalize_calls, + [ + (1, [0, 1]), + (2, [0]), + ], + ) + self.assertEqual(explorer.experience_pipeline.chunk_process_calls, []) + self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) + + +class TestExplorerFallbackPaths(unittest.IsolatedAsyncioTestCase): + async def test_over_rollout_path_keeps_batch_get_results_and_process(self): + class FakeRemoteMethod: + def __init__(self, func): + self.func = func + + async def remote(self, *args, **kwargs): + return await self.func(*args, **kwargs) + + class FakeScheduler: + def __init__(self): + self.calls = [] + + async def get_results( + self, + batch_id, + min_num=None, + timeout=None, + clear_timeout_tasks=True, + return_partial_tasks=False, + ): + self.calls.append( + { + "batch_id": batch_id, + "min_num": min_num, + "timeout": timeout, + "clear_timeout_tasks": clear_timeout_tasks, + "return_partial_tasks": return_partial_tasks, + } + ) + return [Status(1, 1, metrics=[{"run_metrics": 1.0}])], [ + Experience(tokens=torch.zeros(5), prompt_length=2) + ] + + class FakePipeline: + def __init__(self): + self.process_calls = [] + self.finalize_calls = [] + self.process = FakeRemoteMethod(self._process) + self.finalize_batch = FakeRemoteMethod(self._finalize_batch) + + async def _process(self, exp_bytes): + self.process_calls.append(exp_bytes) + return {"experience_pipeline/experience_count": 1.0} + + async def _finalize_batch(self, batch_id, task_ids): + self.finalize_calls.append((batch_id, list(task_ids))) + return {"experience_pipeline/experience_count": 0.0} + + class FakeMonitor: + def __init__(self): + self.logged = [] + + def log(self, metric, step): + self.logged.append((step, metric)) + + explorer = Explorer.__new__(Explorer) + explorer.logger = MagicMock() + explorer.scheduler = FakeScheduler() + explorer.monitor = FakeMonitor() + explorer.experience_pipeline = FakePipeline() + explorer.taskset = SimpleNamespace(feedback=lambda metrics: None) + explorer.use_task_event_completion = False + explorer.min_wait_num = 1 + explorer.pending_eval_tasks = deque() + explorer.pending_step_buffers = {} + explorer.explore_start_time = None + explorer.last_monitored_step = 0 + explorer.explore_step_num = 1 + explorer.model_version = 7 + explorer.config = SimpleNamespace( + explorer=SimpleNamespace( + over_rollout=SimpleNamespace(return_partial_tasks=True) + ) + ) + + await explorer.finish_current_steps() + + self.assertEqual(len(explorer.scheduler.calls), 1) + self.assertEqual(explorer.scheduler.calls[0]["batch_id"], 1) + self.assertEqual(explorer.scheduler.calls[0]["min_num"], 1) + self.assertTrue(explorer.scheduler.calls[0]["return_partial_tasks"]) + self.assertEqual(len(explorer.experience_pipeline.process_calls), 1) + self.assertEqual(explorer.experience_pipeline.finalize_calls, []) + self.assertEqual([step for step, _ in explorer.monitor.logged], [1]) + + async def test_eval_flush_does_not_use_training_pipeline_staging(self): + class FakeScheduler: + def __init__(self): + self.calls = [] + + async def get_results( + self, + batch_id, + min_num=None, + timeout=None, + clear_timeout_tasks=True, + return_partial_tasks=False, + ): + self.calls.append( + { + "batch_id": batch_id, + "min_num": min_num, + "timeout": timeout, + "clear_timeout_tasks": clear_timeout_tasks, + 
"return_partial_tasks": return_partial_tasks, + } + ) + return [Status(1, 1, metrics=[{"accuracy": 1.0}])], [] + + class FakePipeline: + def __init__(self): + self.finalize_calls = [] + self.process_calls = [] + + class FakeMonitor: + def __init__(self): + self.logged = [] + + def log(self, metric, step): + self.logged.append((step, metric)) + + explorer = Explorer.__new__(Explorer) + explorer.logger = MagicMock() + explorer.scheduler = FakeScheduler() + explorer.monitor = FakeMonitor() + explorer.experience_pipeline = FakePipeline() + explorer.pending_eval_tasks = deque([(3, "eval-set")]) + explorer.eval_start_time = None + explorer.explore_step_num = 3 + explorer.detailed_stats = False + explorer.config = SimpleNamespace( + explorer=SimpleNamespace( + over_rollout=SimpleNamespace(return_partial_tasks=False) + ) + ) + + await explorer._finish_eval_step(step=3) + + self.assertEqual(len(explorer.scheduler.calls), 1) + self.assertEqual(explorer.scheduler.calls[0]["batch_id"], "3/eval-set") + self.assertEqual(explorer.experience_pipeline.process_calls, []) + self.assertEqual(explorer.experience_pipeline.finalize_calls, []) + self.assertEqual([step for step, _ in explorer.monitor.logged], [3]) + + async def test_finish_current_steps_finalizes_upstream_staged_tasks_in_order(self): + class FakeRemoteMethod: + def __init__(self, func): + self.func = func + + async def remote(self, *args, **kwargs): + return await self.func(*args, **kwargs) + + class FakePipeline: + def __init__(self): + self.stage_calls = [] + self.finalize_calls = [] + self.chunk_process_calls = [] + self.stage_task_payloads = FakeRemoteMethod(self._stage_task_payloads) + self.finalize_batch = FakeRemoteMethod(self._finalize_batch) + self.process_serialized_chunks = FakeRemoteMethod(self._process_serialized_chunks) + + async def _stage_task_payloads(self, batch_id, task_id, exp_chunks): + self.stage_calls.append((batch_id, task_id, list(exp_chunks))) + return f"{batch_id}:{task_id}" + + async def _finalize_batch(self, batch_id, task_ids): + self.finalize_calls.append((batch_id, list(task_ids))) + return {"experience_pipeline/experience_count": float(len(task_ids))} + + async def _process_serialized_chunks(self, exp_chunks): + self.chunk_process_calls.append(list(exp_chunks)) + return {"experience_pipeline/experience_count": float(len(exp_chunks))} + + explorer = _build_fake_task_event_explorer(use_payloads=False) + explorer.experience_pipeline = FakePipeline() + explorer.taskset = SimpleNamespace(feedback=lambda metrics: None) + + # In the real direct-staging path, Scheduler returns only lightweight + # task completion metadata here because payloads were already staged by + # WorkflowRunner. Explorer should therefore skip re-staging and only + # drive ordered batch finalization. 
+ await explorer.finish_current_steps() + + self.assertEqual(explorer.experience_pipeline.stage_calls, []) + self.assertEqual( + explorer.experience_pipeline.finalize_calls, + [ + (1, [0, 1]), + (2, [0]), + ], + ) + self.assertEqual(explorer.experience_pipeline.chunk_process_calls, []) + self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) + + def run_serve(config): config.check_and_update() run_stage(config) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 430eb7ec7d6..52a4709ba73 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -15,7 +15,7 @@ from trinity.common.experience import EID, Experience from trinity.common.models.model import InferenceModel, ModelWrapper from trinity.common.workflows import WORKFLOWS, Task, Workflow -from trinity.explorer.scheduler import Scheduler +from trinity.explorer.scheduler import CompletedTaskRef, Scheduler @WORKFLOWS.register_module("dummy_workflow") @@ -351,6 +351,31 @@ def get_api_server_url(self) -> str: return "http://localhost:12345" +@ray.remote +class DummyPayloadStage: + def __init__(self): + self.staged_payloads = defaultdict(dict) + + async def stage_task_payloads(self, batch_id, task_id: int, exp_chunks: list[bytes]): + self.staged_payloads[batch_id][task_id] = list(exp_chunks) + return f"{batch_id}:{task_id}" + + async def take_staged_task_payloads(self, batch_id, task_ids: list[int]) -> list[bytes]: + batch_payloads = self.staged_payloads.get(batch_id, {}) + exp_chunks = [] + for task_id in task_ids: + exp_chunks.extend(batch_payloads.pop(task_id, [])) + if batch_id in self.staged_payloads and not self.staged_payloads[batch_id]: + del self.staged_payloads[batch_id] + return exp_chunks + + async def get_staged_task_ids(self, batch_id): + return sorted(self.staged_payloads.get(batch_id, {}).keys()) + + async def abort_batch(self, batch_id): + self.staged_payloads.pop(batch_id, None) + + def generate_tasks( total_num: int, timeout_num: int = 0, @@ -1233,6 +1258,90 @@ async def test_dynamic_timeout(self): self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4) + async def test_completed_task_events_keep_get_results_compatible(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + scheduler.enable_completed_task_events() + + scheduler.schedule(generate_tasks(4, repeat_times=2), batch_id=0) + + completed_refs = [] + for _ in range(4): + completed_ref = await scheduler.wait_completed_task(timeout=10) + self.assertIsNotNone(completed_ref) + self.assertIsInstance(completed_ref, CompletedTaskRef) + completed_refs.append(completed_ref) + + self.assertEqual({ref.batch_id for ref in completed_refs}, {0}) + self.assertEqual({ref.task_id for ref in completed_refs}, {0, 1, 2, 3}) + + statuses, exps = await scheduler.get_results(batch_id=0, timeout=1) + self.assertEqual(len(statuses), 4) + self.assertEqual(len(exps), 8) + + await scheduler.stop() + + async def test_get_results_reads_payloads_staged_by_workflow_runner(self): + payload_stage = DummyPayloadStage.remote() + scheduler = Scheduler( + self.config, + [DummyModel.remote(), DummyModel.remote()], + experience_pipeline=payload_stage, + ) + await scheduler.start() + + scheduler.schedule(generate_tasks(3, repeat_times=2), batch_id=0) + + statuses, exps = await scheduler.get_results(batch_id=0, timeout=10) + self.assertEqual(len(statuses), 3) + self.assertEqual(len(exps), 6) + 
self.assertEqual(ray.get(payload_stage.get_staged_task_ids.remote(0)), []) + + await scheduler.stop() + + async def test_timeout_cleanup_aborts_leftover_staged_payloads(self): + payload_stage = DummyPayloadStage.remote() + scheduler = Scheduler( + self.config, + [DummyModel.remote(), DummyModel.remote()], + experience_pipeline=payload_stage, + ) + await scheduler.start() + + # Seed one orphan staged payload to verify timeout cleanup clears the + # whole batch after draining any successfully completed task payloads. + await payload_stage.stage_task_payloads.remote(0, 99, [b"orphan-payload"]) + scheduler.schedule(generate_tasks(1, timeout_num=1, timeout_seconds=10), batch_id=0) + + statuses, exps = await scheduler.get_results(batch_id=0, min_num=2, timeout=1) + self.assertEqual(len(statuses), 1) + self.assertEqual(len(exps), 1) + self.assertEqual(ray.get(payload_stage.get_staged_task_ids.remote(0)), []) + + await scheduler.stop() + + async def test_eval_tasks_do_not_stage_payloads(self): + payload_stage = DummyPayloadStage.remote() + scheduler = Scheduler( + self.config, + [DummyModel.remote(), DummyModel.remote()], + experience_pipeline=payload_stage, + ) + await scheduler.start() + + eval_tasks = generate_tasks(2, repeat_times=2) + for task in eval_tasks: + task.is_eval = True + + scheduler.schedule(eval_tasks, batch_id="0/eval") + statuses, exps = await scheduler.get_results(batch_id="0/eval", timeout=10) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 0) + self.assertEqual(ray.get(payload_stage.get_staged_task_ids.remote("0/eval")), []) + + await scheduler.stop() + def tearDown(self): try: ray.shutdown() diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index 8052d8ccd1e..7f7510681d2 100644 --- a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -1,6 +1,7 @@ import asyncio import time import traceback +from collections import defaultdict from typing import Dict, Optional from trinity.buffer.buffer import BufferWriter, get_buffer_reader, get_buffer_writer @@ -46,6 +47,7 @@ def __init__(self, config: Config): ) self.auxiliary_model_wrappers = {} self.auxiliary_models = {} + self.staged_task_payloads = defaultdict(dict) def _init_input_storage( self, @@ -137,8 +139,62 @@ async def process(self, exp_bytes: bytes) -> Dict: Returns: Dict: A dictionary containing metrics collected during the processing of experiences. 
""" - st = time.time() exps = Experience.deserialize_many(exp_bytes) + return await self._process_experiences(exps) + + async def process_serialized_chunks(self, exp_chunks: list[bytes]) -> Dict: + """Process a batch assembled from multiple serialized task payloads.""" + exps = [] + for exp_bytes in exp_chunks: + if not exp_bytes: + continue + exps.extend(Experience.deserialize_many(exp_bytes)) + return await self._process_experiences(exps) + + async def stage_task_payloads( + self, batch_id, task_id: int, exp_chunks: list[bytes] + ) -> Optional[str]: + """Stage serialized payload chunks for one completed task.""" + valid_chunks = [chunk for chunk in exp_chunks if chunk] + if not valid_chunks: + return None + self.staged_task_payloads[batch_id][task_id] = valid_chunks + return f"{batch_id}:{task_id}" + + async def finalize_batch(self, batch_id, task_ids: Optional[list[int]] = None) -> Dict: + """Finalize a staged batch and process all staged task payloads.""" + batch_payloads = self.staged_task_payloads.get(batch_id, {}) + if not batch_payloads: + return await self._process_experiences([]) + + selected_task_ids = task_ids or list(batch_payloads.keys()) + exp_chunks = [] + for task_id in selected_task_ids: + exp_chunks.extend(batch_payloads.pop(task_id, [])) + + if batch_id in self.staged_task_payloads and not self.staged_task_payloads[batch_id]: + del self.staged_task_payloads[batch_id] + + return await self.process_serialized_chunks(exp_chunks) + + async def take_staged_task_payloads(self, batch_id, task_ids: list[int]) -> list[bytes]: + """Drain staged payload chunks for selected tasks without processing them.""" + batch_payloads = self.staged_task_payloads.get(batch_id, {}) + exp_chunks = [] + for task_id in task_ids: + exp_chunks.extend(batch_payloads.pop(task_id, [])) + + if batch_id in self.staged_task_payloads and not self.staged_task_payloads[batch_id]: + del self.staged_task_payloads[batch_id] + + return exp_chunks + + async def abort_batch(self, batch_id) -> None: + """Discard any staged payloads for a batch.""" + self.staged_task_payloads.pop(batch_id, None) + + async def _process_experiences(self, exps: list[Experience]) -> Dict: + st = time.time() if self.input_store is not None: await self.input_store.write_async(exps) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index db82c83067a..db6836185b6 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -8,7 +8,8 @@ import time import traceback from collections import deque -from typing import List, Optional +from dataclasses import dataclass, field +from typing import Dict, List, Optional import ray import torch @@ -36,6 +37,17 @@ from trinity.utils.timer import Timer +@dataclass +class ExploreStepBuffer: + """Buffered task results for one explore step.""" + + expected_task_count: int + statuses: List = field(default_factory=list) + experience_payloads: List[bytes] = field(default_factory=list) + staged_task_ids: List[int] = field(default_factory=list) + completed_task_count: int = 0 + + class Explorer: """Responsible for exploring the taskset.""" @@ -76,8 +88,13 @@ def __init__(self, config: Config): ) else: self.min_wait_num = None + self.use_task_event_completion = ( + self.min_wait_num is None + and not self.config.explorer.over_rollout.return_partial_tasks + ) self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL self.pending_eval_tasks = deque() + self.pending_step_buffers: Dict[int, ExploreStepBuffer] = {} # For checkpoint weights update # Use 
explorer to periodically load the latest model weights and @@ -206,8 +223,17 @@ async def prepare(self) -> None: await self.setup_weight_sync_group(master_address, master_port) if self.config.mode != "serve": - self.scheduler = Scheduler(self.config, self.models, self.auxiliary_models) + self.scheduler = Scheduler( + self.config, + self.models, + self.auxiliary_models, + experience_pipeline=self.experience_pipeline + if self.use_task_event_completion + else None, + ) await self.scheduler.start() + if self.use_task_event_completion: + self.scheduler.enable_completed_task_events() if self.config.explorer.eval_on_startup and self.explore_step_num == 0: await self.eval() @@ -269,6 +295,10 @@ async def explore_step(self) -> bool: return False self.explore_step_num += 1 self.scheduler.schedule(tasks, batch_id=self.explore_step_num) + if self.use_task_event_completion: + self.pending_step_buffers[self.explore_step_num] = ExploreStepBuffer( + expected_task_count=len(tasks) + ) return True async def finish_current_steps(self) -> None: @@ -388,6 +418,36 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int self.monitor.log(metric, step=end_step) async def _finish_explore_step(self, step: int, model_version: int) -> None: + if self.use_task_event_completion: + await self._wait_step_buffer(step) + step_buffer = self.pending_step_buffers.pop(step, None) + if step_buffer is None: + return + + metric = {"rollout/model_version": model_version} + if self.experience_pipeline is not None: + if step_buffer.staged_task_ids: + pipeline_metrics = await self.experience_pipeline.finalize_batch.remote( + step, step_buffer.staged_task_ids + ) + else: + pipeline_metrics = await self.experience_pipeline.process_serialized_chunks.remote( + step_buffer.experience_payloads + ) + self.taskset.feedback(pipeline_metrics) + metric.update(pipeline_metrics) + if step_buffer.statuses: + metric.update( + gather_metrics( + [status.metrics[0] for status in step_buffer.statuses], + "rollout", + ) + ) + metric["rollout/finished_task_count"] = len(step_buffer.statuses) + if self.monitor is not None: + self.monitor.log(metric, step=step) + return + metric = {"rollout/model_version": model_version} with Timer(metric, "time/wait_explore_step"): statuses, exps = await self.scheduler.get_results( @@ -407,6 +467,65 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: if self.monitor is not None: self.monitor.log(metric, step=step) + async def _store_completed_task_result(self, step: int, result) -> None: + step_buffer = self.pending_step_buffers.get(step) + if step_buffer is None: + return + step_buffer.completed_task_count += 1 + step_buffer.statuses.append(result.status) + if self.experience_pipeline is not None and result.experience_payloads: + staged = await self.experience_pipeline.stage_task_payloads.remote( + step, + result.task_id, + result.experience_payloads, + ) + if staged is not None: + step_buffer.staged_task_ids.append(result.task_id) + elif self.experience_pipeline is not None and result.status.completed_runs > 0: + step_buffer.staged_task_ids.append(result.task_id) + elif result.experience_payloads: + step_buffer.experience_payloads.extend(result.experience_payloads) + elif result.experiences: + step_buffer.experience_payloads.append(Experience.serialize_many(result.experiences)) + + async def _wait_step_buffer(self, step: int) -> None: + if self.scheduler is None: + return + + step_buffer = self.pending_step_buffers.get(step) + if step_buffer is None: + return + + 
start_time = time.time() + while step_buffer.completed_task_count < step_buffer.expected_task_count: + remaining = self.scheduler.default_timeout - (time.time() - start_time) + if remaining <= 0: + break + completed_ref = await self.scheduler.wait_completed_task(timeout=remaining) + if completed_ref is None: + break + completed_result = self.scheduler.pop_completed_task( + completed_ref.batch_id, completed_ref.task_id + ) + if completed_result is None: + continue + if isinstance(completed_ref.batch_id, int): + await self._store_completed_task_result(completed_ref.batch_id, completed_result) + + if step_buffer.completed_task_count >= step_buffer.expected_task_count: + return + + statuses, exps = await self.scheduler.get_results( + batch_id=step, + timeout=0, + clear_timeout_tasks=True, + return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks, + ) + step_buffer.statuses.extend(statuses) + if exps: + step_buffer.experience_payloads.append(Experience.serialize_many(exps)) + step_buffer.completed_task_count += len(statuses) + async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: if not self.pending_eval_tasks: return diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index ea3206ef00b..d013b1dc257 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -33,11 +33,29 @@ class TaskWrapper: completed_runs: int = 0 total_runs: int = 0 # total planned runs for the whole task metrics: List[Dict[str, float]] = field(default_factory=list) - experiences: List[Experience] = field(default_factory=list) + experience_payloads: List[bytes] = field(default_factory=list) first_error: Optional[str] = None emitted: bool = False +@dataclass(frozen=True) +class CompletedTaskResult: + """A completed task result stored by batch and task id.""" + + task_id: int + status: Status + experiences: List[Experience] = field(default_factory=list) + experience_payloads: List[bytes] = field(default_factory=list) + + +@dataclass(frozen=True) +class CompletedTaskRef: + """A lightweight reference to a completed task result.""" + + batch_id: Union[int, str] + task_id: int + + # Adapted from verl/trainer/ppo/metric_utils.py def bootstrap_metric( data: list[Any], @@ -138,12 +156,14 @@ def __init__( rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config, + experience_pipeline=None, ): self.logger = get_logger(__name__) self.runner_id = runner_id self.rollout_model = rollout_model self.auxiliary_models = auxiliary_models self.config = config + self.experience_pipeline = experience_pipeline self.retry_times = config.explorer.max_retry_times self.timeout = config.explorer.max_timeout self.namespace = config.ray_namespace @@ -161,7 +181,13 @@ def _create_runner(self): "env_vars": self.config.explorer.env_vars, }, ) - .remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id) + .remote( + self.config, + self.rollout_model, + self.auxiliary_models, + self.experience_pipeline, + self.runner_id, + ) ) async def prepare(self): @@ -179,7 +205,7 @@ async def run_with_retry( run_id_base: int, timeout: float, collect_partial_runs: bool, - ) -> Tuple[Status, List, int, float]: + ) -> Tuple[Status, bytes, int, float]: """ Args: task (`TaskWrapper`): The task to run. 
@@ -197,7 +223,7 @@ async def run_with_retry( await self.runner.__ray_ready__.remote() start_time = time.time() status = Status(completed_runs=0, total_runs=repeat_times, metrics=list()) - exps = [] + exp_payload = b"" run_task_ref = None task2run = replace( task.task, @@ -216,7 +242,7 @@ async def run_with_retry( run_id_base=run_id_base, collect_partial_runs=collect_partial_runs, ) - status, exps = await asyncio.wait_for( + status, exp_payload = await asyncio.wait_for( run_task_ref, timeout=timeout, ) @@ -261,7 +287,7 @@ async def run_with_retry( finally: end_time = time.time() status.metrics.append({"time/task_execution": end_time - start_time}) - return status, exps, self.runner_id, end_time - start_time + return status, exp_payload, self.runner_id, end_time - start_time async def restart_runner(self): old_runner = self.runner @@ -299,11 +325,13 @@ def __init__( config: Config, rollout_model: List[InferenceModel], auxiliary_models: Optional[List[List[InferenceModel]]] = None, + experience_pipeline=None, ): self.logger = get_logger(__name__) self.config = config self.rollout_model = rollout_model self.auxiliary_models = auxiliary_models or [] + self.experience_pipeline = experience_pipeline self.namespace = ray.get_runtime_context().namespace self.default_timeout = config.explorer.max_timeout * (config.explorer.max_retry_times + 1) self.max_retry_times = config.explorer.max_retry_times @@ -329,10 +357,12 @@ def __init__( self.running_task_runner_map: Dict[asyncio.Future, int] = dict() # future -> runner_id self.cancelled_task_restart_map: Dict[asyncio.Future, bool] = dict() self.completed_tasks: Dict[ - Union[int, str], deque[Tuple[Status, List[Experience]]] + Union[int, str], Dict[int, CompletedTaskResult] ] = defaultdict( - deque + dict ) # batch_id -> results + self.completed_task_refs: asyncio.Queue[CompletedTaskRef] = asyncio.Queue() + self.emit_completed_task_events = False self.background_tasks: set[asyncio.Task] = set() self.scheduler_task: Optional[asyncio.Task] = None @@ -353,6 +383,7 @@ async def _create_runner( for j in range(len(self.auxiliary_models)) ], config=self.config, + experience_pipeline=self.experience_pipeline, ) await runner.prepare() self.runners[runner_id] = runner @@ -448,11 +479,11 @@ def task_done_callback(self, async_task: asyncio.Task): self._schedule_runner_restart(runner_id) else: self.cancelled_task_restart_map.pop(async_task, None) - status, exps, runner_id, run_time = async_task.result() + status, exp_payload, runner_id, run_time = async_task.result() if not task.task.is_eval: # only count running time for non-eval tasks self.total_running_time += run_time self.total_completed_tasks += 1 - self._accumulate_task_result(task, status, exps) + self._accumulate_task_result(task, status, exp_payload) self.busy_runners.pop(runner_id, None) self.idle_runners.add(runner_id) # If all sub runs in a task are completed @@ -466,16 +497,17 @@ def task_done_callback(self, async_task: asyncio.Task): del self.running_tasks[task.batch_id] def _accumulate_task_result( - self, task: TaskWrapper, status: Status, experiences: List[Experience] + self, task: TaskWrapper, status: Status, experience_payload: bytes ) -> None: task.finished_sub_task_num += 1 task.completed_runs += status.completed_runs task.metrics.extend(status.metrics) - task.experiences.extend(experiences) + if experience_payload: + task.experience_payloads.append(experience_payload) if not status.ok and task.first_error is None: task.first_error = status.message - def _build_task_result(self, task: 
TaskWrapper) -> Tuple[Status, List[Experience]]: + def _build_task_result(self, task: TaskWrapper) -> Tuple[Status, List[bytes]]: if task.completed_runs < task.total_runs: message = f"{task.completed_runs}/{task.total_runs} runs completed successfully." if task.first_error: @@ -490,14 +522,50 @@ def _build_task_result(self, task: TaskWrapper) -> Tuple[Status, List[Experience metrics=[calculate_task_level_metrics(task.metrics, task.task.is_eval)], message=message, ) - return status, list(task.experiences) + return status, list(task.experience_payloads) def _emit_task_result(self, task: TaskWrapper) -> None: if task.emitted: return - self.completed_tasks[task.batch_id].appendleft(self._build_task_result(task)) + status, experience_payloads = self._build_task_result(task) + self.completed_tasks[task.batch_id][task.task.task_id] = CompletedTaskResult( + task_id=task.task.task_id, + status=status, + experience_payloads=experience_payloads, + ) + if self.emit_completed_task_events and not task.task.is_eval: + self.completed_task_refs.put_nowait( + CompletedTaskRef(batch_id=task.batch_id, task_id=task.task.task_id) + ) task.emitted = True + def enable_completed_task_events(self) -> None: + self.emit_completed_task_events = True + + def disable_completed_task_events(self) -> None: + self.emit_completed_task_events = False + + async def wait_completed_task( + self, timeout: Optional[float] = None + ) -> Optional[CompletedTaskRef]: + try: + if timeout is None: + return await self.completed_task_refs.get() + return await asyncio.wait_for(self.completed_task_refs.get(), timeout=timeout) + except asyncio.TimeoutError: + return None + + def pop_completed_task( + self, batch_id: Union[int, str], task_id: int + ) -> Optional[CompletedTaskResult]: + batch_results = self.completed_tasks.get(batch_id) + if not batch_results: + return None + result = batch_results.pop(task_id, None) + if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]: + del self.completed_tasks[batch_id] + return result + def _collect_incomplete_tasks(self, batch_id: Union[int, str]) -> List[TaskWrapper]: tasks = {} for task, _, _ in self.pending_tasks.get(batch_id, deque()): @@ -709,18 +777,35 @@ async def get_results( statuses = [] experiences = [] - completed_queue = self.completed_tasks.get(batch_id, deque()) - while completed_queue: - status, exps = completed_queue.pop() - statuses.append(status) - if isinstance(exps, list): - experiences.extend(exps) + completed_results = list(self.completed_tasks.get(batch_id, {}).values()) + staged_task_ids = [] + for result in completed_results: + statuses.append(result.status) + if result.experience_payloads: + for exp_payload in result.experience_payloads: + experiences.extend(Experience.deserialize_many(exp_payload)) else: - experiences.append(exps) + experiences.extend(result.experiences) + if ( + self.experience_pipeline is not None + and result.status.completed_runs > 0 + and not result.experiences + ): + staged_task_ids.append(result.task_id) - if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]: + if staged_task_ids: + staged_payloads = await self.experience_pipeline.take_staged_task_payloads.remote( + batch_id, staged_task_ids + ) + for exp_payload in staged_payloads: + experiences.extend(Experience.deserialize_many(exp_payload)) + + if batch_id in self.completed_tasks: del self.completed_tasks[batch_id] + if clear_timeout_tasks and self.experience_pipeline is not None: + await self.experience_pipeline.abort_batch.remote(batch_id) + 
completed_count = len(statuses) if completed_count < min_num: self.logger.warning( @@ -772,8 +857,11 @@ async def wait_all( self.logger.error(error_msg) if clear_timeout_tasks: - for batch_id in self.pending_tasks.keys() | self.running_tasks.keys(): + batch_ids_to_abort = self.pending_tasks.keys() | self.running_tasks.keys() + for batch_id in batch_ids_to_abort: self._clear_timeout_tasks(batch_id) + if self.experience_pipeline is not None: + await self.experience_pipeline.abort_batch.remote(batch_id) asyncio.gather( *[self._restart_runner(runner_id) for runner_id in self.busy_runners.keys()] ) diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 5bd7fc667c3..ee8fe7cf771 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -68,6 +68,7 @@ def __init__( config: Config, model: InferenceModel, auxiliary_models: Optional[List[InferenceModel]] = None, + experience_pipeline=None, runner_id: Optional[int] = None, ) -> None: self.name = f"{config.explorer.name}_runner_{runner_id}" @@ -80,6 +81,7 @@ def __init__( enable_history=config.explorer.rollout_model.enable_history, ) self.auxiliary_models = auxiliary_models or [] + self.experience_pipeline = experience_pipeline self.auxiliary_model_wrappers = [ ModelWrapper( model, @@ -398,11 +400,11 @@ async def run_task( repeat_times: int = 1, run_id_base: int = 0, collect_partial_runs: bool = True, - ) -> Tuple[Status, List[Experience]]: + ) -> Tuple[Status, bytes]: """Run the task and return the states.""" # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead + st = time.time() try: - st = time.time() model_version = await self.model_wrapper.model_version_async self.runner_state["model_version"] = model_version self.logger.info( @@ -436,9 +438,17 @@ async def run_task( if task.is_eval: # If the task is an evaluation task, we do not record the experiences to the buffer - return status, [] + return status, b"" else: - return status, exps + exp_payload = Experience.serialize_many(exps) + if self.experience_pipeline is not None: + await self.experience_pipeline.stage_task_payloads.remote( + task.batch_id, + task.task_id, + [exp_payload], + ) + return status, b"" + return status, exp_payload except Exception as e: error_trace_back = traceback.format_exc() @@ -450,7 +460,7 @@ async def run_task( metrics=[{"time/run_execution": time.time() - st}], message=error_trace_back.rstrip(), ), - [], + b"", ) From c444d2d3a1284087a24b32e9768f5b5505de8b62 Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 14:04:24 +0800 Subject: [PATCH 02/11] fix dynamic timeout --- tests/explorer/explorer_test.py | 10 +++------- tests/explorer/scheduler_test.py | 24 ++++++++++++++++++++++++ trinity/explorer/explorer.py | 9 +++++---- trinity/explorer/scheduler.py | 13 +++++++------ 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 9d745b2d5d6..bb379a9f15d 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -35,7 +35,7 @@ ) from trinity.common.constants import StorageType from trinity.common.experience import Experience -from trinity.explorer.explorer import ExploreStepBuffer, Explorer +from trinity.explorer.explorer import Explorer, ExploreStepBuffer from trinity.explorer.proxy.client import TrinityClient from trinity.explorer.scheduler import CompletedTaskRef, CompletedTaskResult from trinity.explorer.workflow_runner import Status @@ -506,9 
+506,7 @@ def log(self, metric, step): explorer.explore_step_num = 1 explorer.model_version = 7 explorer.config = SimpleNamespace( - explorer=SimpleNamespace( - over_rollout=SimpleNamespace(return_partial_tasks=True) - ) + explorer=SimpleNamespace(over_rollout=SimpleNamespace(return_partial_tasks=True)) ) await explorer.finish_current_steps() @@ -567,9 +565,7 @@ def log(self, metric, step): explorer.explore_step_num = 3 explorer.detailed_stats = False explorer.config = SimpleNamespace( - explorer=SimpleNamespace( - over_rollout=SimpleNamespace(return_partial_tasks=False) - ) + explorer=SimpleNamespace(over_rollout=SimpleNamespace(return_partial_tasks=False)) ) await explorer._finish_eval_step(step=3) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 52a4709ba73..4fb303d86ac 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -1258,6 +1258,30 @@ async def test_dynamic_timeout(self): self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4) + async def test_dynamic_timeout_counts_completed_tasks_instead_of_sub_tasks(self): + self.config.explorer.dynamic_timeout.enable = True + self.config.explorer.dynamic_timeout.ratio = 1.5 + self.config.buffer.batch_size = 3 + self.config.explorer.max_timeout = 20 + self.config.explorer.max_retry_times = 0 + self.config.explorer.max_repeat_times_per_runner = 2 + self.config.check_and_update() + + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + tasks = generate_tasks(0, timeout_num=2, repeat_times=4, timeout_seconds=1) + scheduler.schedule(tasks, batch_id=0) + statuses, exps = await scheduler.get_results(batch_id=0) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 8) + self.assertEqual(scheduler.total_completed_tasks, 2) + self.assertEqual(scheduler.total_completed_sub_tasks, 4) + self.assertEqual(scheduler.dynamic_timeout(), scheduler.default_timeout) + + await scheduler.stop() + async def test_completed_task_events_keep_get_results_compatible(self): scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index db6836185b6..14069e9ac7d 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -89,8 +89,7 @@ def __init__(self, config: Config): else: self.min_wait_num = None self.use_task_event_completion = ( - self.min_wait_num is None - and not self.config.explorer.over_rollout.return_partial_tasks + self.min_wait_num is None and not self.config.explorer.over_rollout.return_partial_tasks ) self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL self.pending_eval_tasks = deque() @@ -431,8 +430,10 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: step, step_buffer.staged_task_ids ) else: - pipeline_metrics = await self.experience_pipeline.process_serialized_chunks.remote( - step_buffer.experience_payloads + pipeline_metrics = ( + await self.experience_pipeline.process_serialized_chunks.remote( + step_buffer.experience_payloads + ) ) self.taskset.feedback(pipeline_metrics) metric.update(pipeline_metrics) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index d013b1dc257..cc673b83c70 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -356,9 +356,7 @@ def __init__( self.running_task_map: Dict[asyncio.Future, TaskWrapper] = dict() # future -> task 
self.running_task_runner_map: Dict[asyncio.Future, int] = dict() # future -> runner_id self.cancelled_task_restart_map: Dict[asyncio.Future, bool] = dict() - self.completed_tasks: Dict[ - Union[int, str], Dict[int, CompletedTaskResult] - ] = defaultdict( + self.completed_tasks: Dict[Union[int, str], Dict[int, CompletedTaskResult]] = defaultdict( dict ) # batch_id -> results self.completed_task_refs: asyncio.Queue[CompletedTaskRef] = asyncio.Queue() @@ -369,6 +367,7 @@ def __init__( self.running = False self.total_running_time = 0.0 + self.total_completed_sub_tasks = 0 self.total_completed_tasks = 0 async def _create_runner( @@ -480,14 +479,16 @@ def task_done_callback(self, async_task: asyncio.Task): else: self.cancelled_task_restart_map.pop(async_task, None) status, exp_payload, runner_id, run_time = async_task.result() - if not task.task.is_eval: # only count running time for non-eval tasks + if not task.task.is_eval: self.total_running_time += run_time - self.total_completed_tasks += 1 + self.total_completed_sub_tasks += 1 self._accumulate_task_result(task, status, exp_payload) self.busy_runners.pop(runner_id, None) self.idle_runners.add(runner_id) # If all sub runs in a task are completed if task.finished_sub_task_num == task.sub_task_num: + if not task.task.is_eval: + self.total_completed_tasks += 1 self._emit_task_result(task) self.logger.debug(f"Task completed (batch_id {task.batch_id}).") @@ -686,7 +687,7 @@ def dynamic_timeout(self, timeout: Optional[float] = None) -> float: return max_timeout if self.total_completed_tasks < self.default_batch_size: return max_timeout - avg_time_per_task = self.total_running_time / self.total_completed_tasks + avg_time_per_task = self.total_running_time / self.total_completed_sub_tasks return min( max_timeout, avg_time_per_task * self.config.explorer.dynamic_timeout.ratio, From dc965dd1388f8b5d5e35c33e853d9911f523521b Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 14:38:00 +0800 Subject: [PATCH 03/11] fix dynamic timeout --- trinity/explorer/scheduler.py | 131 ++++++++++++++++++---------- trinity/explorer/workflow_runner.py | 15 ++-- 2 files changed, 92 insertions(+), 54 deletions(-) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index cc673b83c70..59fe7189420 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -529,14 +529,17 @@ def _emit_task_result(self, task: TaskWrapper) -> None: if task.emitted: return status, experience_payloads = self._build_task_result(task) - self.completed_tasks[task.batch_id][task.task.task_id] = CompletedTaskResult( - task_id=task.task.task_id, + task_id = task.task.task_id + if not isinstance(task_id, int): + raise TypeError(f"Expected task_id to be int when emitting results, got {task_id!r}") + self.completed_tasks[task.batch_id][task_id] = CompletedTaskResult( + task_id=task_id, status=status, experience_payloads=experience_payloads, ) if self.emit_completed_task_events and not task.task.is_eval: self.completed_task_refs.put_nowait( - CompletedTaskRef(batch_id=task.batch_id, task_id=task.task.task_id) + CompletedTaskRef(batch_id=task.batch_id, task_id=task_id) ) task.emitted = True @@ -716,42 +719,36 @@ async def _cleanup_batch( if runners_to_restart: await asyncio.gather(*[self._restart_runner(rid) for rid in runners_to_restart]) - async def get_results( - self, - batch_id: Union[int, str], - min_num: Optional[int] = None, - timeout: Optional[float] = None, - clear_timeout_tasks: bool = True, - return_partial_tasks: bool = False, - ) -> 
Tuple[List[Status], List[Experience]]: - """Get the result of tasks at the specific batch_id. - - Args: - batch_id (`Union[int, str]`): Only wait for tasks at this batch. - min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `batch_id`. - timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout. - clear_timeout_tasks (`bool`): Whether to clear timeout tasks. - return_partial_tasks (`bool`): Whether to emit tasks with partial successful runs when cleaning up unfinished tasks. - """ - timeout = timeout or self.default_timeout - start_time = time.time() + def _resolve_result_target( + self, batch_id: Union[int, str], min_num: Optional[int] + ) -> Tuple[int, int]: scheduled_num = self.task_num_map.get(batch_id, 0) if min_num is None: - min_num = scheduled_num - elif min_num > scheduled_num: + return scheduled_num, scheduled_num + if min_num > scheduled_num: self.logger.warning( f"Requested min_num {min_num} is greater than scheduled tasks {scheduled_num} at batch_id {batch_id}. Adjusting min_num to {scheduled_num}." ) - min_num = scheduled_num + return scheduled_num, scheduled_num + return scheduled_num, min_num - self.logger.debug(f"Waiting for {min_num} tasks to complete...") + async def _wait_for_batch_results( + self, + batch_id: Union[int, str], + min_num: int, + scheduled_num: int, + timeout: float, + clear_timeout_tasks: bool, + return_partial_tasks: bool, + ) -> bool: + start_time = time.time() min_threshold_reached_time = None while time.time() - start_time <= timeout: - completed_count = len(self.completed_tasks.get(batch_id, [])) + completed_count = len(self.completed_tasks.get(batch_id, {})) if completed_count >= min_num: min_threshold_reached_time = min_threshold_reached_time or time.time() if completed_count >= scheduled_num: - break + return False if ( time.time() - min_threshold_reached_time >= self.config.explorer.over_rollout.wait_after_min @@ -762,20 +759,13 @@ async def get_results( return_partial_tasks=return_partial_tasks, restart_runners=False, ) - break + return False await asyncio.sleep(0.1) + return True - if time.time() - start_time > timeout: - self.logger.error( - f"Timed out waiting for tasks at batch {batch_id} to complete after {timeout} seconds" - ) - if clear_timeout_tasks: - await self._cleanup_batch( - batch_id, - return_partial_tasks=return_partial_tasks, - restart_runners=True, - ) - + async def _collect_batch_results( + self, batch_id: Union[int, str] + ) -> Tuple[List[Status], List[Experience]]: statuses = [] experiences = [] completed_results = list(self.completed_tasks.get(batch_id, {}).values()) @@ -785,14 +775,15 @@ async def get_results( if result.experience_payloads: for exp_payload in result.experience_payloads: experiences.extend(Experience.deserialize_many(exp_payload)) - else: - experiences.extend(result.experiences) - if ( - self.experience_pipeline is not None - and result.status.completed_runs > 0 - and not result.experiences - ): - staged_task_ids.append(result.task_id) + continue + + experiences.extend(result.experiences) + if ( + self.experience_pipeline is not None + and result.status.completed_runs > 0 + and not result.experiences + ): + staged_task_ids.append(result.task_id) if staged_task_ids: staged_payloads = await self.experience_pipeline.take_staged_task_payloads.remote( @@ -801,6 +792,50 @@ async def get_results( for exp_payload in staged_payloads: experiences.extend(Experience.deserialize_many(exp_payload)) + return statuses, experiences + + async 
def get_results( + self, + batch_id: Union[int, str], + min_num: Optional[int] = None, + timeout: Optional[float] = None, + clear_timeout_tasks: bool = True, + return_partial_tasks: bool = False, + ) -> Tuple[List[Status], List[Experience]]: + """Get the result of tasks at the specific batch_id. + + Args: + batch_id (`Union[int, str]`): Only wait for tasks at this batch. + min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `batch_id`. + timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout. + clear_timeout_tasks (`bool`): Whether to clear timeout tasks. + return_partial_tasks (`bool`): Whether to emit tasks with partial successful runs when cleaning up unfinished tasks. + """ + timeout = timeout or self.default_timeout + scheduled_num, min_num = self._resolve_result_target(batch_id, min_num) + + self.logger.debug(f"Waiting for {min_num} tasks to complete...") + timed_out = await self._wait_for_batch_results( + batch_id=batch_id, + min_num=min_num, + scheduled_num=scheduled_num, + timeout=timeout, + clear_timeout_tasks=clear_timeout_tasks, + return_partial_tasks=return_partial_tasks, + ) + if timed_out: + self.logger.error( + f"Timed out waiting for tasks at batch {batch_id} to complete after {timeout} seconds" + ) + if clear_timeout_tasks: + await self._cleanup_batch( + batch_id, + return_partial_tasks=return_partial_tasks, + restart_runners=True, + ) + + statuses, experiences = await self._collect_batch_results(batch_id) + if batch_id in self.completed_tasks: del self.completed_tasks[batch_id] diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index ee8fe7cf771..b75f91978c4 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -521,20 +521,23 @@ async def debug(self) -> None: task = tasks[0] self.logger.info(f"Start debugging task:\n{task.raw_task}") if not self.enable_profiling: - status, exps = await self.run_task( + status, exp_payload = await self.run_task( task=task, batch_id="debug", repeat_times=1, run_id_base=0 ) else: from viztracer import VizTracer with VizTracer(output_file=self.output_profiling_file): - status, exps = await self.run_task( + status, exp_payload = await self.run_task( task=task, batch_id="debug", repeat_times=1, run_id_base=0 ) - if not status.ok and len(exps) == 0: - exps = self.model_wrapper.extract_experience_from_history() - self.logger.info(f"Debugging failed, extracting {len(exps)} experiences from history.") - await self.sqlite_writer.write_async(exps) + experiences = Experience.deserialize_many(exp_payload) if exp_payload else [] + if not status.ok and not experiences: + experiences = self.model_wrapper.extract_experience_from_history() + self.logger.info( + f"Debugging failed, extracting {len(experiences)} experiences from history." 
+ ) + await self.sqlite_writer.write_async(experiences) if status.ok: print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}") else: From 5c920c6125439804218682f200dfd20a3e2ab326 Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 14:57:28 +0800 Subject: [PATCH 04/11] fix dynamic timeout warmup --- .../source/tutorial/trinity_configs.md | 6 +++-- .../source_zh/tutorial/trinity_configs.md | 6 +++-- tests/explorer/scheduler_test.py | 15 ++++++++++- trinity/common/config.py | 1 + trinity/explorer/scheduler.py | 25 ++++++++++++++----- 5 files changed, 42 insertions(+), 11 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 1cd0467be93..080c8f3a64d 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -424,6 +424,7 @@ explorer: return_partial_tasks: false dynamic_timeout: enable: false + warmup_min_steps: 1 ratio: 3.0 runner_state_report_interval: 0 ``` @@ -452,9 +453,10 @@ explorer: - `ratio`: Explorer will only wait for `(1 - ratio) * batch_size` of tasks at each step. Default is `0.0`, meaning waiting for all tasks. - `wait_after_min`: After reaching the minimum task threshold, wait for this many seconds before proceeding. Default is `30.0` seconds. - `return_partial_tasks`: Whether to return the results of tasks that have only completed partially (e.g., only some runs in GRPO). Default is `false`, meaning only return results of tasks that have completed all runs. -- `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each task based on the average time taken for successful tasks. +- `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each scheduled execution unit based on historical runtime statistics. - `enable`: Whether to enable dynamic timeout. Default is `false`. - - `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`. + - `warmup_min_steps`: The minimum number of fully observed non-eval steps required before dynamic timeout takes effect. Default is `1`. This is equivalent to a warmup batch/step count, and helps avoid enabling dynamic timeout based only on a few fast early completions. + - `ratio`: The timeout for each scheduled execution unit is dynamically set to `average_time_per_success_execution * ratio`. Default is `3.0`. - `runner_state_report_interval`: Workflow runner report interval (in seconds). If set to a value greater than `0`, the workflow runner will periodically report its status to the main explorer process and print it in the command line for monitoring. Default is `0`, meaning this feature is disabled. If you want to use this feature, it is recommended to set it to `10` seconds or longer to minimize performance impact. 
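+
+As a minimal sketch, enabling dynamic timeout with a two-step warmup would look like this (values are illustrative, not recommendations):
+
+```yaml
+explorer:
+  dynamic_timeout:
+    enable: true
+    warmup_min_steps: 2
+    ratio: 3.0
+```
+
+With these settings, once two non-eval steps have fully completed, the timeout for each scheduled execution unit becomes `min(max_timeout, average_time_per_success_execution * 3.0)`.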
--- diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 5192f1b2786..5792ca77d1c 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -421,6 +421,7 @@ explorer: return_partial_tasks: false dynamic_timeout: enable: false + warmup_min_steps: 1 ratio: 3.0 runner_state_report_interval: 0 ``` @@ -449,9 +450,10 @@ explorer: - `ratio`: explorer 在每个步骤中仅等待 `(1 - ratio) * batch_size` 的任务。默认为 `0.0`,表示等待所有任务。 - `wait_after_min`: 达到最小任务阈值后,等待此秒数后再继续。 - `return_partial_tasks`: 是否返回仅部分完成的任务结果(例如,在 GRPO 中仅完成部分 run 的任务)。默认为 `false`,表示仅返回已完成组内所有 run 的任务结果。 -- `dynamic_timeout`: [实验性] 动态超时机制的配置,根据成功任务的平均耗时调整每个任务的超时时间。 +- `dynamic_timeout`: [实验性] 动态超时机制的配置,根据历史执行耗时动态调整每个调度执行单元的超时时间。 - `enable`: 是否启用动态超时。默认为 `false`。 - - `ratio`: 每个任务的超时时间动态设置为 `average_time_per_success_task * ratio`。默认为 `3.0`。 + - `warmup_min_steps`: 动态超时生效前至少需要观测到多少个完整结束的非评估 step。默认为 `1`。它等价于预热所需的 batch/step 数,可以避免只根据少量较快完成的早期任务就提前启用动态超时。 + - `ratio`: 每个调度执行单元的超时时间动态设置为 `average_time_per_success_execution * ratio`。默认为 `3.0`。 - `runner_state_report_interval`: WorkflowRunner 报告自身状态的时间间隔(秒)。若设为大于 0 的值,工作流执行器会定期将其状态报告给 explorer 主进程并打印在命令行中,以便监控其运行状态。默认为 `0`,表示不启用此功能。推荐如需使用此功能,将其设置为 `10` 秒或更长时间以减少对性能的影响。 --- diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 4fb303d86ac..7969a9c2b3d 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -1214,6 +1214,7 @@ async def fake_restart_runner(runner_id): async def test_dynamic_timeout(self): self.config.explorer.dynamic_timeout.enable = True + self.config.explorer.dynamic_timeout.warmup_min_steps = 1 self.config.explorer.dynamic_timeout.ratio = 3.0 self.config.buffer.batch_size = 4 self.config.explorer.max_timeout = 20 @@ -1231,6 +1232,7 @@ async def test_dynamic_timeout(self): self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 0) self.assertEqual(scheduler.total_running_time, 0) + self.assertEqual(scheduler.total_completed_steps, 0) self.assertEqual(scheduler.total_completed_tasks, 0) tasks = [] # generate 4 tasks that will run 1 second @@ -1258,8 +1260,9 @@ async def test_dynamic_timeout(self): self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4) - async def test_dynamic_timeout_counts_completed_tasks_instead_of_sub_tasks(self): + async def test_dynamic_timeout_warmup_min_steps_uses_completed_steps(self): self.config.explorer.dynamic_timeout.enable = True + self.config.explorer.dynamic_timeout.warmup_min_steps = 2 self.config.explorer.dynamic_timeout.ratio = 1.5 self.config.buffer.batch_size = 3 self.config.explorer.max_timeout = 20 @@ -1276,10 +1279,20 @@ async def test_dynamic_timeout_counts_completed_tasks_instead_of_sub_tasks(self) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 8) + self.assertEqual(scheduler.total_completed_steps, 1) self.assertEqual(scheduler.total_completed_tasks, 2) self.assertEqual(scheduler.total_completed_sub_tasks, 4) self.assertEqual(scheduler.dynamic_timeout(), scheduler.default_timeout) + tasks = generate_tasks(0, timeout_num=2, repeat_times=4, timeout_seconds=1) + scheduler.schedule(tasks, batch_id=1) + statuses, exps = await scheduler.get_results(batch_id=1) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 8) + self.assertEqual(scheduler.total_completed_steps, 2) + self.assertAlmostEqual(scheduler.dynamic_timeout(), 1.5, delta=0.8) + await scheduler.stop() async def 
test_completed_task_events_keep_get_results_compatible(self): diff --git a/trinity/common/config.py b/trinity/common/config.py index c3ac8ab1bfd..57ff7f003a0 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -157,6 +157,7 @@ class DynamicTimeoutConfig: """Config for dynamic timeout in explorer.""" enable: bool = False + warmup_min_steps: int = 1 # only enable dynamic timeout after this many fully observed steps ratio: float = 3.0 # the timeout for each step will be min(max_timeout, average_time_per_task * dynamic_timeout.ratio) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 59fe7189420..77d7fc356c2 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -42,7 +42,7 @@ class TaskWrapper: class CompletedTaskResult: """A completed task result stored by batch and task id.""" - task_id: int + task_id: Union[int, str] status: Status experiences: List[Experience] = field(default_factory=list) experience_payloads: List[bytes] = field(default_factory=list) @@ -53,7 +53,7 @@ class CompletedTaskRef: """A lightweight reference to a completed task result.""" batch_id: Union[int, str] - task_id: int + task_id: Union[int, str] # Adapted from verl/trainer/ppo/metric_utils.py @@ -356,7 +356,10 @@ def __init__( self.running_task_map: Dict[asyncio.Future, TaskWrapper] = dict() # future -> task self.running_task_runner_map: Dict[asyncio.Future, int] = dict() # future -> runner_id self.cancelled_task_restart_map: Dict[asyncio.Future, bool] = dict() - self.completed_tasks: Dict[Union[int, str], Dict[int, CompletedTaskResult]] = defaultdict( + self.batch_is_eval_map: Dict[Union[int, str], bool] = dict() + self.completed_tasks: Dict[ + Union[int, str], Dict[Union[int, str], CompletedTaskResult] + ] = defaultdict( dict ) # batch_id -> results self.completed_task_refs: asyncio.Queue[CompletedTaskRef] = asyncio.Queue() @@ -367,6 +370,7 @@ def __init__( self.running = False self.total_running_time = 0.0 + self.total_completed_steps = 0 self.total_completed_sub_tasks = 0 self.total_completed_tasks = 0 @@ -530,8 +534,6 @@ def _emit_task_result(self, task: TaskWrapper) -> None: return status, experience_payloads = self._build_task_result(task) task_id = task.task.task_id - if not isinstance(task_id, int): - raise TypeError(f"Expected task_id to be int when emitting results, got {task_id!r}") self.completed_tasks[task.batch_id][task_id] = CompletedTaskResult( task_id=task_id, status=status, @@ -661,6 +663,7 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: """ if not tasks: return + self.batch_is_eval_map[batch_id] = tasks[0].is_eval self.task_num_map[batch_id] += len(tasks) self._split_and_submit_tasks(tasks, batch_id=batch_id) @@ -688,7 +691,7 @@ def dynamic_timeout(self, timeout: Optional[float] = None) -> float: max_timeout = timeout or self.default_timeout if not self.config.explorer.dynamic_timeout.enable: return max_timeout - if self.total_completed_tasks < self.default_batch_size: + if self.total_completed_steps < self.config.explorer.dynamic_timeout.warmup_min_steps: return max_timeout avg_time_per_task = self.total_running_time / self.total_completed_sub_tasks return min( @@ -732,6 +735,15 @@ def _resolve_result_target( return scheduled_num, scheduled_num return scheduled_num, min_num + def _finalize_dynamic_timeout_step( + self, batch_id: Union[int, str], scheduled_num: int, completed_count: int + ) -> None: + if batch_id in self.pending_tasks or batch_id in self.running_tasks: + return + is_eval = 
self.batch_is_eval_map.pop(batch_id, False) + if not is_eval and completed_count >= scheduled_num: + self.total_completed_steps += 1 + async def _wait_for_batch_results( self, batch_id: Union[int, str], @@ -843,6 +855,7 @@ async def get_results( await self.experience_pipeline.abort_batch.remote(batch_id) completed_count = len(statuses) + self._finalize_dynamic_timeout_step(batch_id, scheduled_num, completed_count) if completed_count < min_num: self.logger.warning( f"Timeout reached, only {completed_count}/{min_num} tasks completed" From 68f4c3b750ec218b5887bec2a1c80efb83000e0c Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 19:15:17 +0800 Subject: [PATCH 05/11] refactor explorer --- tests/explorer/explorer_test.py | 474 ++++++---------------- tests/explorer/scheduler_test.py | 67 ++-- trinity/explorer/__init__.py | 5 +- trinity/explorer/explorer.py | 227 +++-------- trinity/explorer/rollout_coordinator.py | 506 ++++++++++++++++++++++++ trinity/explorer/scheduler.py | 141 ++++--- trinity/explorer/workflow_runner.py | 9 - 7 files changed, 807 insertions(+), 622 deletions(-) create mode 100644 trinity/explorer/rollout_coordinator.py diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index bb379a9f15d..136f0a0ea17 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -5,6 +5,7 @@ import os import random import shutil +import time import unittest from collections import deque from datetime import datetime @@ -13,7 +14,6 @@ import httpx import ray -import torch from tests.tools import ( RayUnittestBase, @@ -34,86 +34,65 @@ OperatorConfig, ) from trinity.common.constants import StorageType -from trinity.common.experience import Experience -from trinity.explorer.explorer import Explorer, ExploreStepBuffer +from trinity.explorer.explorer import Explorer from trinity.explorer.proxy.client import TrinityClient -from trinity.explorer.scheduler import CompletedTaskRef, CompletedTaskResult -from trinity.explorer.workflow_runner import Status from trinity.manager.state_manager import StateManager -def _build_fake_task_event_explorer(use_payloads: bool = False): - class FakeScheduler: +def _build_fake_coordinator_explorer(): + class FakeRemoteMethod: + def __init__(self, func): + self.func = func + + async def remote(self, *args, **kwargs): + return await self.func(*args, **kwargs) + + class FakeCoordinator: def __init__(self): - self.default_timeout = 5.0 - # Step 2 finishes before step 1 on purpose. This simulates the - # fully async path where Scheduler can return completed tasks in - # task order rather than in step order. 
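The `FakeRemoteMethod` helper added in this test fixture mimics Ray's `actor.method.remote(...)` calling convention without a Ray cluster, so the awaited coordinator calls can be exercised in-process. A self-contained sketch of the same test-double pattern, with a hypothetical `FakeActor` standing in for the fake coordinator:

```python
import asyncio


class FakeRemoteMethod:
    """Make `await obj.method.remote(...)` dispatch to a plain coroutine."""

    def __init__(self, func):
        self.func = func

    async def remote(self, *args, **kwargs):
        # A real Ray actor call returns an awaitable ObjectRef; the callers
        # under test only ever await the result, so this is sufficient.
        return await self.func(*args, **kwargs)


class FakeActor:
    def __init__(self):
        self.calls = []
        self.submit_batch = FakeRemoteMethod(self._submit_batch)

    async def _submit_batch(self, **kwargs):
        self.calls.append(kwargs)


async def main():
    actor = FakeActor()
    await actor.submit_batch.remote(batch_id=1, batch_type="train")
    assert actor.calls == [{"batch_id": 1, "batch_type": "train"}]


asyncio.run(main())
```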
- self.completed_task_refs = deque( - [ - CompletedTaskRef(batch_id=2, task_id=0), - CompletedTaskRef(batch_id=1, task_id=0), - CompletedTaskRef(batch_id=1, task_id=1), - ] - ) - self.completed_results = { - (2, 0): CompletedTaskResult( - task_id=0, - status=Status( - completed_runs=1, - total_runs=1, - metrics=[{"run_metrics": 20.0}], - ), - experiences=[], - experience_payloads=[b"step-2-task-0"] if use_payloads else [], - ), - (1, 0): CompletedTaskResult( - task_id=0, - status=Status( - completed_runs=1, - total_runs=1, - metrics=[{"run_metrics": 10.0}], - ), - experiences=[], - experience_payloads=[b"step-1-task-0"] if use_payloads else [], - ), - (1, 1): CompletedTaskResult( - task_id=1, - status=Status( - completed_runs=1, - total_runs=1, - metrics=[{"run_metrics": 11.0}], - ), - experiences=[], - experience_payloads=[b"step-1-task-1"] if use_payloads else [], - ), + self.submit_calls = [] + self.finalize_train_calls = [] + self.finalize_eval_calls = [] + self.shutdown_calls = 0 + self.submit_batch = FakeRemoteMethod(self._submit_batch) + self.finalize_train_batch = FakeRemoteMethod(self._finalize_train_batch) + self.finalize_eval_batch = FakeRemoteMethod(self._finalize_eval_batch) + self.shutdown = FakeRemoteMethod(self._shutdown) + + async def _submit_batch(self, **kwargs): + self.submit_calls.append(kwargs) + + async def _finalize_train_batch(self, batch_id): + self.finalize_train_calls.append(batch_id) + return { + "batch_id": batch_id, + "batch_type": "train", + "finished_task_count": 1, + "metrics": { + "experience_pipeline/experience_count": 2.0, + "rollout/run_metrics/mean": float(batch_id), + "rollout/finished_task_count": 1.0, + }, + "finalize_reason": "complete", + "finalized": True, } - self.get_results_calls = [] - - async def wait_completed_task(self, timeout=None): - if self.completed_task_refs: - return self.completed_task_refs.popleft() - return None - - def pop_completed_task(self, batch_id, task_id): - return self.completed_results.pop((batch_id, task_id), None) - - async def get_results( - self, - batch_id, - timeout=None, - clear_timeout_tasks=True, - return_partial_tasks=False, - ): - self.get_results_calls.append( - { - "batch_id": batch_id, - "timeout": timeout, - "clear_timeout_tasks": clear_timeout_tasks, - "return_partial_tasks": return_partial_tasks, - } - ) - return [], [] + + async def _finalize_eval_batch(self, batch_id): + self.finalize_eval_calls.append(batch_id) + eval_name = batch_id.split("/", 1)[1] + return { + "batch_id": batch_id, + "batch_type": "eval", + "finished_task_count": 2, + "metrics": { + f"eval/{eval_name}/accuracy": 0.5, + f"eval/{eval_name}/finished_task_count": 2.0, + }, + "finalize_reason": "complete", + "finalized": True, + } + + async def _shutdown(self): + self.shutdown_calls += 1 class FakeMonitor: def __init__(self): @@ -122,26 +101,34 @@ def __init__(self): def log(self, metric, step): self.logged.append((step, metric)) + feedback_calls = [] + + async def read_async(): + return [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)] + + def record_feedback(metrics): + feedback_calls.append(metrics) + explorer = Explorer.__new__(Explorer) explorer.logger = MagicMock() - explorer.scheduler = FakeScheduler() + explorer.rollout_coordinator = FakeCoordinator() explorer.monitor = FakeMonitor() - explorer.experience_pipeline = None - explorer.taskset = SimpleNamespace(feedback=lambda metrics: None) - explorer.use_task_event_completion = True + explorer.taskset = SimpleNamespace(read_async=read_async, 
feedback=record_feedback) + explorer.min_wait_num = None explorer.pending_eval_tasks = deque() - explorer.pending_step_buffers = { - 1: ExploreStepBuffer(expected_task_count=2), - 2: ExploreStepBuffer(expected_task_count=1), - } explorer.explore_start_time = None + explorer.eval_start_time = None explorer.last_monitored_step = 0 - explorer.explore_step_num = 2 + explorer.explore_step_num = 0 explorer.model_version = 7 + explorer.detailed_stats = False explorer.config = SimpleNamespace( - explorer=SimpleNamespace(over_rollout=SimpleNamespace(return_partial_tasks=False)) + explorer=SimpleNamespace( + over_rollout=SimpleNamespace(return_partial_tasks=False), + eval_interval=1, + ) ) - return explorer + return explorer, feedback_calls class BaseExplorerCase(RayUnittestBase): @@ -336,295 +323,72 @@ def test_explorer(self): ray.get(explorer.shutdown.remote()) -class TestExplorerTaskLevelCompletion(unittest.IsolatedAsyncioTestCase): - """Tests Explorer's task-level completion buffering. +class TestExplorerCoordinatorPath(unittest.IsolatedAsyncioTestCase): + async def test_explore_step_submits_train_batch_to_rollout_coordinator(self): + explorer, _ = _build_fake_coordinator_explorer() - The task event stream may arrive out of step order because Scheduler emits a - completion event as soon as one task is fully done. Explorer must therefore - accept task-level completions eagerly, buffer them by step, and only publish - aggregated step metrics when all tasks of the next step are ready. - """ - - async def test_out_of_order_task_completion_is_buffered_before_step_finish(self): - explorer = _build_fake_task_event_explorer() - - # Waiting for step 1 is allowed to consume and buffer task events from - # later steps. The important property is that step 2 is not blocked from - # returning at task granularity just because step 1 is still incomplete. - await explorer._wait_step_buffer(1) - - self.assertEqual(explorer.pending_step_buffers[1].completed_task_count, 2) - self.assertEqual(explorer.pending_step_buffers[2].completed_task_count, 1) - self.assertEqual(explorer.scheduler.get_results_calls, []) - - # When Explorer finally flushes steps, aggregation and monitor logging - # must still happen in ascending step order even though the underlying - # task completion order was 2 -> 1 -> 1. - await explorer._finish_steps(1, 2, model_version=7) - - self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) - self.assertEqual(explorer.scheduler.get_results_calls, []) - self.assertEqual(explorer.pending_step_buffers, {}) - - async def test_finish_current_steps_flushes_buffered_steps_in_order(self): - explorer = _build_fake_task_event_explorer() - - # This exercises the public sync boundary: finish_current_steps should - # drain any out-of-order task events gathered so far, but still flush - # step metrics strictly in step order and advance the monitored cursor. 
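These fixtures build the object under test with `Explorer.__new__(Explorer)` instead of `Explorer(config)`, so the heavyweight constructor (Ray actors, model creation) never runs and only the attributes a given test reads need to be attached. A generic sketch of that idiom, with hypothetical names:

```python
class Service:
    def __init__(self) -> None:
        # Heavy construction (I/O, remote actors, model loading) that
        # unit tests deliberately avoid.
        raise RuntimeError("should not run in unit tests")

    def step(self) -> int:
        self.counter += 1
        return self.counter


def make_bare_service() -> Service:
    # Allocate without running __init__, then attach only what the
    # method under test actually touches.
    service = Service.__new__(Service)
    service.counter = 0
    return service


service = make_bare_service()
assert service.step() == 1
```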
- await explorer.finish_current_steps() - - self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) - self.assertEqual(explorer.last_monitored_step, 2) - self.assertEqual(explorer.scheduler.get_results_calls, []) - self.assertEqual(explorer.pending_step_buffers, {}) - - async def test_finish_current_steps_stages_payloads_before_ordered_finalize(self): - class FakeRemoteMethod: - def __init__(self, func): - self.func = func - - async def remote(self, *args, **kwargs): - return await self.func(*args, **kwargs) - - class FakePipeline: - def __init__(self): - self.stage_calls = [] - self.finalize_calls = [] - self.chunk_process_calls = [] - self.stage_task_payloads = FakeRemoteMethod(self._stage_task_payloads) - self.finalize_batch = FakeRemoteMethod(self._finalize_batch) - self.process_serialized_chunks = FakeRemoteMethod(self._process_serialized_chunks) - - async def _stage_task_payloads(self, batch_id, task_id, exp_chunks): - self.stage_calls.append((batch_id, task_id, list(exp_chunks))) - return f"{batch_id}:{task_id}" - - async def _finalize_batch(self, batch_id, task_ids): - self.finalize_calls.append((batch_id, list(task_ids))) - return {"experience_pipeline/experience_count": float(len(task_ids))} - - async def _process_serialized_chunks(self, exp_chunks): - self.chunk_process_calls.append(list(exp_chunks)) - return {"experience_pipeline/experience_count": float(len(exp_chunks))} - - explorer = _build_fake_task_event_explorer(use_payloads=True) - explorer.experience_pipeline = FakePipeline() - explorer.taskset = SimpleNamespace(feedback=lambda metrics: None) - - # Task payloads are staged immediately when completion events arrive, - # but batch finalize still runs in step order during the flush. - await explorer.finish_current_steps() + should_continue = await explorer.explore_step() + self.assertTrue(should_continue) + self.assertEqual(explorer.explore_step_num, 1) self.assertEqual( - explorer.experience_pipeline.stage_calls, + explorer.rollout_coordinator.submit_calls, [ - (2, 0, [b"step-2-task-0"]), - (1, 0, [b"step-1-task-0"]), - (1, 1, [b"step-1-task-1"]), - ], - ) - self.assertEqual( - explorer.experience_pipeline.finalize_calls, - [ - (1, [0, 1]), - (2, [0]), + { + "batch_id": 1, + "tasks": [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + "batch_type": "train", + "min_wait_num": None, + "allow_partial_finalize": False, + } ], ) - self.assertEqual(explorer.experience_pipeline.chunk_process_calls, []) - self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) - -class TestExplorerFallbackPaths(unittest.IsolatedAsyncioTestCase): - async def test_over_rollout_path_keeps_batch_get_results_and_process(self): - class FakeRemoteMethod: - def __init__(self, func): - self.func = func - - async def remote(self, *args, **kwargs): - return await self.func(*args, **kwargs) - - class FakeScheduler: - def __init__(self): - self.calls = [] - - async def get_results( - self, - batch_id, - min_num=None, - timeout=None, - clear_timeout_tasks=True, - return_partial_tasks=False, - ): - self.calls.append( - { - "batch_id": batch_id, - "min_num": min_num, - "timeout": timeout, - "clear_timeout_tasks": clear_timeout_tasks, - "return_partial_tasks": return_partial_tasks, - } - ) - return [Status(1, 1, metrics=[{"run_metrics": 1.0}])], [ - Experience(tokens=torch.zeros(5), prompt_length=2) - ] - - class FakePipeline: - def __init__(self): - self.process_calls = [] - self.finalize_calls = [] - self.process = FakeRemoteMethod(self._process) - 
self.finalize_batch = FakeRemoteMethod(self._finalize_batch) - - async def _process(self, exp_bytes): - self.process_calls.append(exp_bytes) - return {"experience_pipeline/experience_count": 1.0} - - async def _finalize_batch(self, batch_id, task_ids): - self.finalize_calls.append((batch_id, list(task_ids))) - return {"experience_pipeline/experience_count": 0.0} - - class FakeMonitor: - def __init__(self): - self.logged = [] - - def log(self, metric, step): - self.logged.append((step, metric)) - - explorer = Explorer.__new__(Explorer) - explorer.logger = MagicMock() - explorer.scheduler = FakeScheduler() - explorer.monitor = FakeMonitor() - explorer.experience_pipeline = FakePipeline() - explorer.taskset = SimpleNamespace(feedback=lambda metrics: None) - explorer.use_task_event_completion = False - explorer.min_wait_num = 1 - explorer.pending_eval_tasks = deque() - explorer.pending_step_buffers = {} - explorer.explore_start_time = None - explorer.last_monitored_step = 0 - explorer.explore_step_num = 1 - explorer.model_version = 7 - explorer.config = SimpleNamespace( - explorer=SimpleNamespace(over_rollout=SimpleNamespace(return_partial_tasks=True)) - ) + async def test_finish_current_steps_uses_rollout_coordinator_finalize(self): + explorer, feedback_calls = _build_fake_coordinator_explorer() + explorer.explore_step_num = 2 await explorer.finish_current_steps() - self.assertEqual(len(explorer.scheduler.calls), 1) - self.assertEqual(explorer.scheduler.calls[0]["batch_id"], 1) - self.assertEqual(explorer.scheduler.calls[0]["min_num"], 1) - self.assertTrue(explorer.scheduler.calls[0]["return_partial_tasks"]) - self.assertEqual(len(explorer.experience_pipeline.process_calls), 1) - self.assertEqual(explorer.experience_pipeline.finalize_calls, []) - self.assertEqual([step for step, _ in explorer.monitor.logged], [1]) - - async def test_eval_flush_does_not_use_training_pipeline_staging(self): - class FakeScheduler: - def __init__(self): - self.calls = [] - - async def get_results( - self, - batch_id, - min_num=None, - timeout=None, - clear_timeout_tasks=True, - return_partial_tasks=False, - ): - self.calls.append( - { - "batch_id": batch_id, - "min_num": min_num, - "timeout": timeout, - "clear_timeout_tasks": clear_timeout_tasks, - "return_partial_tasks": return_partial_tasks, - } - ) - return [Status(1, 1, metrics=[{"accuracy": 1.0}])], [] - - class FakePipeline: - def __init__(self): - self.finalize_calls = [] - self.process_calls = [] - - class FakeMonitor: - def __init__(self): - self.logged = [] - - def log(self, metric, step): - self.logged.append((step, metric)) - - explorer = Explorer.__new__(Explorer) - explorer.logger = MagicMock() - explorer.scheduler = FakeScheduler() - explorer.monitor = FakeMonitor() - explorer.experience_pipeline = FakePipeline() - explorer.pending_eval_tasks = deque([(3, "eval-set")]) - explorer.eval_start_time = None - explorer.explore_step_num = 3 - explorer.detailed_stats = False - explorer.config = SimpleNamespace( - explorer=SimpleNamespace(over_rollout=SimpleNamespace(return_partial_tasks=False)) - ) + self.assertEqual(explorer.rollout_coordinator.finalize_train_calls, [1, 2]) + self.assertEqual(len(feedback_calls), 2) + self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) + self.assertEqual(explorer.last_monitored_step, 2) + + async def test_finish_eval_step_uses_rollout_coordinator_finalize(self): + explorer, _ = _build_fake_coordinator_explorer() + explorer.pending_eval_tasks.append((3, "eval_set")) + explorer.eval_start_time = 
time.time() await explorer._finish_eval_step(step=3) - self.assertEqual(len(explorer.scheduler.calls), 1) - self.assertEqual(explorer.scheduler.calls[0]["batch_id"], "3/eval-set") - self.assertEqual(explorer.experience_pipeline.process_calls, []) - self.assertEqual(explorer.experience_pipeline.finalize_calls, []) + self.assertEqual(explorer.rollout_coordinator.finalize_eval_calls, ["3/eval_set"]) self.assertEqual([step for step, _ in explorer.monitor.logged], [3]) + self.assertIn("eval/eval_set/accuracy", explorer.monitor.logged[0][1]) - async def test_finish_current_steps_finalizes_upstream_staged_tasks_in_order(self): - class FakeRemoteMethod: - def __init__(self, func): - self.func = func - - async def remote(self, *args, **kwargs): - return await self.func(*args, **kwargs) - - class FakePipeline: - def __init__(self): - self.stage_calls = [] - self.finalize_calls = [] - self.chunk_process_calls = [] - self.stage_task_payloads = FakeRemoteMethod(self._stage_task_payloads) - self.finalize_batch = FakeRemoteMethod(self._finalize_batch) - self.process_serialized_chunks = FakeRemoteMethod(self._process_serialized_chunks) - - async def _stage_task_payloads(self, batch_id, task_id, exp_chunks): - self.stage_calls.append((batch_id, task_id, list(exp_chunks))) - return f"{batch_id}:{task_id}" - - async def _finalize_batch(self, batch_id, task_ids): - self.finalize_calls.append((batch_id, list(task_ids))) - return {"experience_pipeline/experience_count": float(len(task_ids))} - - async def _process_serialized_chunks(self, exp_chunks): - self.chunk_process_calls.append(list(exp_chunks)) - return {"experience_pipeline/experience_count": float(len(exp_chunks))} - - explorer = _build_fake_task_event_explorer(use_payloads=False) - explorer.experience_pipeline = FakePipeline() - explorer.taskset = SimpleNamespace(feedback=lambda metrics: None) - - # In the real direct-staging path, Scheduler returns only lightweight - # task completion metadata here because payloads were already staged by - # WorkflowRunner. Explorer should therefore skip re-staging and only - # drive ordered batch finalization. 
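The fake coordinator and these assertions together pin down the plain-dict result contract of `finalize_train_batch` / `finalize_eval_batch`. As a reading aid, here is a consumer-side sketch of that contract; the `TypedDict` is purely illustrative, since the patch itself passes untyped dicts:

```python
from typing import Dict, Literal, TypedDict, Union


class FinalizeResult(TypedDict):
    batch_id: Union[int, str]
    batch_type: Literal["train", "eval"]
    finished_task_count: int
    metrics: Dict[str, float]
    finalize_reason: str  # "complete" | "partial" | "timeout" | "abort"
    finalized: bool


def log_result(result: FinalizeResult, model_version: int, step: int, monitor) -> None:
    # Mirrors Explorer._finish_explore_step: merge the coordinator's
    # metrics into one record and log only when tasks actually finished.
    metric = {"rollout/model_version": model_version, **result["metrics"]}
    if result["finished_task_count"] > 0:
        monitor.log(metric, step=step)
```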
- await explorer.finish_current_steps() - self.assertEqual(explorer.experience_pipeline.stage_calls, []) +class TestExplorerCoordinatorPolicies(unittest.IsolatedAsyncioTestCase): + async def test_over_rollout_submits_partial_finalize_policy_to_rollout_coordinator(self): + explorer, _ = _build_fake_coordinator_explorer() + explorer.min_wait_num = 1 + explorer.config.explorer.over_rollout.return_partial_tasks = True + + should_continue = await explorer.explore_step() + + self.assertTrue(should_continue) self.assertEqual( - explorer.experience_pipeline.finalize_calls, + explorer.rollout_coordinator.submit_calls, [ - (1, [0, 1]), - (2, [0]), + { + "batch_id": 1, + "tasks": [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + "batch_type": "train", + "min_wait_num": 1, + "allow_partial_finalize": True, + } ], ) - self.assertEqual(explorer.experience_pipeline.chunk_process_calls, []) - self.assertEqual([step for step, _ in explorer.monitor.logged], [1, 2]) + def run_serve(config): diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 7969a9c2b3d..542e1ca2fa6 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -1318,13 +1318,8 @@ async def test_completed_task_events_keep_get_results_compatible(self): await scheduler.stop() - async def test_get_results_reads_payloads_staged_by_workflow_runner(self): - payload_stage = DummyPayloadStage.remote() - scheduler = Scheduler( - self.config, - [DummyModel.remote(), DummyModel.remote()], - experience_pipeline=payload_stage, - ) + async def test_get_results_reads_payloads_returned_by_workflow_runner(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() scheduler.schedule(generate_tasks(3, repeat_times=2), batch_id=0) @@ -1332,38 +1327,23 @@ async def test_get_results_reads_payloads_staged_by_workflow_runner(self): statuses, exps = await scheduler.get_results(batch_id=0, timeout=10) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 6) - self.assertEqual(ray.get(payload_stage.get_staged_task_ids.remote(0)), []) await scheduler.stop() - async def test_timeout_cleanup_aborts_leftover_staged_payloads(self): - payload_stage = DummyPayloadStage.remote() - scheduler = Scheduler( - self.config, - [DummyModel.remote(), DummyModel.remote()], - experience_pipeline=payload_stage, - ) + async def test_timeout_cleanup_keeps_completed_payloads_local(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() - # Seed one orphan staged payload to verify timeout cleanup clears the - # whole batch after draining any successfully completed task payloads. 
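The payload-focused scheduler tests in this hunk enforce the patch's overhead goal: experiences cross the runner/scheduler boundary as opaque `bytes` and are decoded only by the final consumer. A minimal sketch of that serialize-once pattern, with `pickle` standing in for `Experience.serialize_many` / `deserialize_many` (which this repo defines elsewhere):

```python
import pickle
from dataclasses import dataclass
from typing import List


@dataclass
class Record:
    tokens: List[int]


def serialize_many(records: List[Record]) -> bytes:
    # Stand-in for Experience.serialize_many: one bytes payload per task.
    return pickle.dumps(records)


def deserialize_many(payload: bytes) -> List[Record]:
    # Stand-in for Experience.deserialize_many.
    return pickle.loads(payload)


# Producer (workflow runner): serialize once when the task completes.
payload = serialize_many([Record([1, 2]), Record([3])])

# Intermediary (scheduler): forward the bytes untouched, so large
# rollouts are never re-encoded or re-decoded per hop.
forwarded: List[bytes] = [payload]

# Final consumer (experience pipeline): decode exactly once.
assert len(deserialize_many(forwarded[0])) == 2
```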
- await payload_stage.stage_task_payloads.remote(0, 99, [b"orphan-payload"]) scheduler.schedule(generate_tasks(1, timeout_num=1, timeout_seconds=10), batch_id=0) statuses, exps = await scheduler.get_results(batch_id=0, min_num=2, timeout=1) self.assertEqual(len(statuses), 1) self.assertEqual(len(exps), 1) - self.assertEqual(ray.get(payload_stage.get_staged_task_ids.remote(0)), []) await scheduler.stop() - async def test_eval_tasks_do_not_stage_payloads(self): - payload_stage = DummyPayloadStage.remote() - scheduler = Scheduler( - self.config, - [DummyModel.remote(), DummyModel.remote()], - experience_pipeline=payload_stage, - ) + async def test_eval_tasks_do_not_return_training_experiences(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() eval_tasks = generate_tasks(2, repeat_times=2) @@ -1375,7 +1355,40 @@ async def test_eval_tasks_do_not_stage_payloads(self): self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 0) - self.assertEqual(ray.get(payload_stage.get_staged_task_ids.remote("0/eval")), []) + + await scheduler.stop() + + async def test_get_statuses_skips_payload_deserialization(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + scheduler.schedule(generate_tasks(2, repeat_times=2), batch_id=0) + + with patch( + "trinity.explorer.scheduler.Experience.deserialize_many", + side_effect=AssertionError("payload deserialization should not happen"), + ): + statuses = await scheduler.get_statuses(batch_id=0, timeout=10) + + self.assertEqual(len(statuses), 2) + + await scheduler.stop() + + async def test_get_payload_results_keeps_payloads_serialized(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + scheduler.schedule(generate_tasks(2, repeat_times=2), batch_id=0) + + with patch( + "trinity.explorer.scheduler.Experience.deserialize_many", + side_effect=AssertionError("payload deserialization should not happen"), + ): + statuses, payloads = await scheduler.get_payload_results(batch_id=0, timeout=10) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(payloads), 2) + self.assertTrue(all(isinstance(payload, bytes) for payload in payloads)) await scheduler.stop() diff --git a/trinity/explorer/__init__.py b/trinity/explorer/__init__.py index 8665a1b125f..2b93afc5857 100644 --- a/trinity/explorer/__init__.py +++ b/trinity/explorer/__init__.py @@ -1,3 +1,6 @@ +"""Explorer package exports.""" + from trinity.explorer.explorer import Explorer +from trinity.explorer.rollout_coordinator import RolloutCoordinator -__all__ = ["Explorer"] +__all__ = ["Explorer", "RolloutCoordinator"] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 14069e9ac7d..9e159cb3b25 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -8,15 +8,13 @@ import time import traceback from collections import deque -from dataclasses import dataclass, field -from typing import Dict, List, Optional +from typing import List, Optional import ray import torch from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from trinity.buffer.buffer import get_buffer_reader -from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline from trinity.buffer.task_scheduler import get_taskset_scheduler from trinity.common.config import Config from trinity.common.constants import ( @@ -25,9 +23,8 @@ SyncMethod, SyncStyle, ) -from trinity.common.experience 
import Experience from trinity.common.models import create_explorer_models -from trinity.explorer.scheduler import Scheduler +from trinity.explorer.rollout_coordinator import RolloutCoordinator from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer from trinity.utils.annotations import Experimental @@ -37,17 +34,6 @@ from trinity.utils.timer import Timer -@dataclass -class ExploreStepBuffer: - """Buffered task results for one explore step.""" - - expected_task_count: int - statuses: List = field(default_factory=list) - experience_payloads: List[bytes] = field(default_factory=list) - staged_task_ids: List[int] = field(default_factory=list) - completed_task_count: int = 0 - - class Explorer: """Responsible for exploring the taskset.""" @@ -64,13 +50,11 @@ def __init__(self, config: Config): self.config = config self.model_type = config.explorer.rollout_model.engine_type self.models, self.auxiliary_models = create_explorer_models(config) - self.experience_pipeline = self._init_experience_pipeline() self.taskset = ( get_taskset_scheduler(explorer_state=explorer_state, config=config) if self.config.mode not in {"bench", "serve"} else None ) - self.scheduler = None self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, group=self.config.group, @@ -88,12 +72,9 @@ def __init__(self, config: Config): ) else: self.min_wait_num = None - self.use_task_event_completion = ( - self.min_wait_num is None and not self.config.explorer.over_rollout.return_partial_tasks - ) + self.rollout_coordinator = None self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL self.pending_eval_tasks = deque() - self.pending_step_buffers: Dict[int, ExploreStepBuffer] = {} # For checkpoint weights update # Use explorer to periodically load the latest model weights and @@ -207,10 +188,6 @@ async def prepare(self) -> None: ) await asyncio.gather(*run_api_ref) self.logger.info("All models are ready.") - # prepare experience pipeline - if self.experience_pipeline: - await self.experience_pipeline.prepare.remote() - self.logger.info("Experience pipeline is ready.") if not self.use_nccl_sync and self.model_type not in {"tinker", "external"}: if self.config.mode == "serve": # In serving mode, each engine will setup its own process group @@ -222,17 +199,9 @@ async def prepare(self) -> None: await self.setup_weight_sync_group(master_address, master_port) if self.config.mode != "serve": - self.scheduler = Scheduler( - self.config, - self.models, - self.auxiliary_models, - experience_pipeline=self.experience_pipeline - if self.use_task_event_completion - else None, - ) - await self.scheduler.start() - if self.use_task_event_completion: - self.scheduler.enable_completed_task_events() + self.rollout_coordinator = self._init_rollout_coordinator() + await self.rollout_coordinator.prepare.remote() + self.logger.info("Rollout coordinator is ready.") if self.config.explorer.eval_on_startup and self.explore_step_num == 0: await self.eval() @@ -293,15 +262,18 @@ async def explore_step(self) -> bool: await self.shutdown() return False self.explore_step_num += 1 - self.scheduler.schedule(tasks, batch_id=self.explore_step_num) - if self.use_task_event_completion: - self.pending_step_buffers[self.explore_step_num] = ExploreStepBuffer( - expected_task_count=len(tasks) - ) + assert self.rollout_coordinator is not None, "Rollout coordinator must be prepared first." 
+ await self.rollout_coordinator.submit_batch.remote( + batch_id=self.explore_step_num, + tasks=tasks, + batch_type="train", + min_wait_num=self.min_wait_num, + allow_partial_finalize=self.min_wait_num is not None, + ) return True async def finish_current_steps(self) -> None: - if self.scheduler: + if self.rollout_coordinator is not None: await self._finish_steps( self.last_monitored_step + 1, self.explore_step_num, self.model_version ) @@ -345,7 +317,14 @@ async def eval(self): while True: try: data = await eval_taskset.read_async() - self.scheduler.schedule(data, batch_id=eval_batch_id) + assert ( + self.rollout_coordinator is not None + ), "Rollout coordinator must be prepared first." + await self.rollout_coordinator.submit_batch.remote( + batch_id=eval_batch_id, + tasks=data, + batch_type="eval", + ) except StopAsyncIteration: break @@ -389,7 +368,7 @@ async def save_checkpoint(self) -> None: async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights - if self.scheduler and self.explore_step_num == 0: + if self.rollout_coordinator is not None and self.explore_step_num == 0: await self._finish_eval_step(step=0) self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.") @@ -417,115 +396,15 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int self.monitor.log(metric, step=end_step) async def _finish_explore_step(self, step: int, model_version: int) -> None: - if self.use_task_event_completion: - await self._wait_step_buffer(step) - step_buffer = self.pending_step_buffers.pop(step, None) - if step_buffer is None: - return - - metric = {"rollout/model_version": model_version} - if self.experience_pipeline is not None: - if step_buffer.staged_task_ids: - pipeline_metrics = await self.experience_pipeline.finalize_batch.remote( - step, step_buffer.staged_task_ids - ) - else: - pipeline_metrics = ( - await self.experience_pipeline.process_serialized_chunks.remote( - step_buffer.experience_payloads - ) - ) - self.taskset.feedback(pipeline_metrics) - metric.update(pipeline_metrics) - if step_buffer.statuses: - metric.update( - gather_metrics( - [status.metrics[0] for status in step_buffer.statuses], - "rollout", - ) - ) - metric["rollout/finished_task_count"] = len(step_buffer.statuses) - if self.monitor is not None: - self.monitor.log(metric, step=step) - return - + assert self.rollout_coordinator is not None, "Rollout coordinator must be prepared first." 
metric = {"rollout/model_version": model_version} with Timer(metric, "time/wait_explore_step"): - statuses, exps = await self.scheduler.get_results( - batch_id=step, - min_num=self.min_wait_num, - return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks, - ) - if self.experience_pipeline is not None: - pipeline_metrics = await self.experience_pipeline.process.remote( - Experience.serialize_many(exps) - ) - self.taskset.feedback(pipeline_metrics) - metric.update(pipeline_metrics) - if statuses: - metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout")) - metric["rollout/finished_task_count"] = len(statuses) - if self.monitor is not None: - self.monitor.log(metric, step=step) - - async def _store_completed_task_result(self, step: int, result) -> None: - step_buffer = self.pending_step_buffers.get(step) - if step_buffer is None: - return - step_buffer.completed_task_count += 1 - step_buffer.statuses.append(result.status) - if self.experience_pipeline is not None and result.experience_payloads: - staged = await self.experience_pipeline.stage_task_payloads.remote( - step, - result.task_id, - result.experience_payloads, - ) - if staged is not None: - step_buffer.staged_task_ids.append(result.task_id) - elif self.experience_pipeline is not None and result.status.completed_runs > 0: - step_buffer.staged_task_ids.append(result.task_id) - elif result.experience_payloads: - step_buffer.experience_payloads.extend(result.experience_payloads) - elif result.experiences: - step_buffer.experience_payloads.append(Experience.serialize_many(result.experiences)) - - async def _wait_step_buffer(self, step: int) -> None: - if self.scheduler is None: - return - - step_buffer = self.pending_step_buffers.get(step) - if step_buffer is None: - return - - start_time = time.time() - while step_buffer.completed_task_count < step_buffer.expected_task_count: - remaining = self.scheduler.default_timeout - (time.time() - start_time) - if remaining <= 0: - break - completed_ref = await self.scheduler.wait_completed_task(timeout=remaining) - if completed_ref is None: - break - completed_result = self.scheduler.pop_completed_task( - completed_ref.batch_id, completed_ref.task_id - ) - if completed_result is None: - continue - if isinstance(completed_ref.batch_id, int): - await self._store_completed_task_result(completed_ref.batch_id, completed_result) - - if step_buffer.completed_task_count >= step_buffer.expected_task_count: - return - - statuses, exps = await self.scheduler.get_results( - batch_id=step, - timeout=0, - clear_timeout_tasks=True, - return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks, - ) - step_buffer.statuses.extend(statuses) - if exps: - step_buffer.experience_payloads.append(Experience.serialize_many(exps)) - step_buffer.completed_task_count += len(statuses) + result = await self.rollout_coordinator.finalize_train_batch.remote(step) + if self.taskset is not None: + self.taskset.feedback(result["metrics"]) + metric.update(result["metrics"]) + if result["finished_task_count"] > 0 and self.monitor is not None: + self.monitor.log(metric, step=step) async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: if not self.pending_eval_tasks: @@ -537,18 +416,11 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva if eval_step != step: return self.pending_eval_tasks.popleft() - statuses, _ = await self.scheduler.get_results( - batch_id=f"{step}/{eval_task_name}", - 
return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks, - ) - metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses) - metric.update( - gather_eval_metrics( - [status.metrics[0] for status in statuses], - f"{prefix}/{eval_task_name}", - detailed_stats=self.detailed_stats, - ) + assert self.rollout_coordinator is not None, "Rollout coordinator must be prepared first." + result = await self.rollout_coordinator.finalize_eval_batch.remote( + f"{step}/{eval_task_name}" ) + metric.update(result["metrics"]) if self.eval_start_time is not None: metric.update({"time/eval": time.time() - self.eval_start_time}) self.eval_start_time = None @@ -556,15 +428,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva self.monitor.log(metric, step) async def shutdown(self) -> None: - if self.scheduler: - await self.scheduler.stop() - self.scheduler = None - if self.experience_pipeline: - await self.experience_pipeline.close.remote() - # reserve `experience_pipeline.output` for trainer - # TODO: refactor the lifecycle of buffer actor - self._old_experience_pipeline = self.experience_pipeline - self.experience_pipeline = None + if self.rollout_coordinator: + await self.rollout_coordinator.shutdown.remote() + self.rollout_coordinator = None if self.monitor: self.monitor.close() self.monitor = None @@ -583,24 +449,23 @@ async def is_alive(self) -> bool: """Check if the explorer is alive.""" return True - def _init_experience_pipeline(self) -> ray.actor.ActorHandle: - """Init experience pipeline for the explorer.""" - if self.config.mode == "bench": - return None - # place the pipeline on the same node as the explorer to - # avoid unnecessary data transfer between nodes + def _init_rollout_coordinator(self) -> ray.actor.ActorHandle: + """Init rollout coordinator for the task-event-completion path.""" node_id = ray.get_runtime_context().get_node_id() return ( - ray.remote(ExperiencePipeline) + ray.remote(RolloutCoordinator) .options( - name=f"{self.config.explorer.name}_pipeline", namespace=self.config.ray_namespace, scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=node_id, soft=False, ), ) - .remote(self.config) + .remote( + self.config, + self.models, + self.auxiliary_models, + ) ) @Experimental diff --git a/trinity/explorer/rollout_coordinator.py b/trinity/explorer/rollout_coordinator.py new file mode 100644 index 00000000000..f1a220208a7 --- /dev/null +++ b/trinity/explorer/rollout_coordinator.py @@ -0,0 +1,506 @@ +"""Rollout coordinator for async batch submission and finalize.""" + +import asyncio +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline +from trinity.common.config import Config +from trinity.common.experience import Experience +from trinity.common.models import InferenceModel +from trinity.common.workflows import Task +from trinity.explorer.scheduler import CompletedTaskResult, Scheduler +from trinity.utils.log import get_logger +from trinity.utils.monitor import gather_eval_metrics, gather_metrics + +BatchId = Union[int, str] +BatchType = Literal["train", "eval"] + + +class BatchLifecycleState(str, Enum): + """Lifecycle states for one submitted batch.""" + + PENDING = "pending" + RUNNING = "running" + READY_TO_FINALIZE = "ready_to_finalize" + FINALIZING = "finalizing" + FINALIZED = "finalized" + ABORTED = "aborted" + + +class FinalizeReason(str, Enum): + 
"""Reasons why a batch finalize call returns.""" + + COMPLETE = "complete" + PARTIAL = "partial" + TIMEOUT = "timeout" + ABORT = "abort" + + +@dataclass +class BatchState: + """In-memory state tracked for one train or eval batch.""" + + batch_id: BatchId + batch_type: BatchType + expected_task_count: int + completed_task_count: int = 0 + statuses: Dict[Union[int, str], Any] = field(default_factory=dict) + staged_task_ids: set[int] = field(default_factory=set) + payload_chunks_by_task: Dict[Union[int, str], List[bytes]] = field(default_factory=dict) + min_wait_num: Optional[int] = None + allow_partial_finalize: bool = False + state: BatchLifecycleState = BatchLifecycleState.PENDING + finalize_reason: Optional[FinalizeReason] = None + final_result: Optional[dict] = None + result_future: Optional[asyncio.Future] = None + ready_event: asyncio.Event = field(default_factory=asyncio.Event) + finalize_lock: asyncio.Lock = field(default_factory=asyncio.Lock) + first_submit_time: float = 0.0 + last_progress_time: float = 0.0 + min_threshold_reached_time: Optional[float] = None + + @property + def has_partial_results(self) -> bool: + """Whether at least one task has completed.""" + + return self.completed_task_count > 0 + + @property + def has_staged_payloads(self) -> bool: + """Whether this batch still owns unconsumed payload data.""" + + return bool(self.staged_task_ids or self.payload_chunks_by_task) + + +class RolloutCoordinator: + """Own scheduler-side batch state and expose batch-level finalize APIs.""" + + def __init__( + self, + config: Config, + rollout_model: List[InferenceModel], + auxiliary_models: Optional[List[List[InferenceModel]]] = None, + ): + """Create a coordinator with internally managed scheduler and pipeline.""" + + self.logger = get_logger("rollout_coordinator", in_ray_actor=True) + self.config = config + self.rollout_model = rollout_model + self.auxiliary_models = auxiliary_models or [] + self.experience_pipeline = None + self.scheduler: Optional[Scheduler] = None + self.pending_batches: Dict[BatchId, BatchState] = {} + self.event_loop_task: Optional[asyncio.Task] = None + self.running = False + self.detailed_stats = getattr(getattr(config, "monitor", None), "detailed_stats", False) + + async def prepare(self) -> None: + """Initialize the owned pipeline and scheduler.""" + + if self.running: + return + if self.experience_pipeline is None: + self.experience_pipeline = self._init_experience_pipeline() + if self.experience_pipeline is not None: + await self.experience_pipeline.prepare() + if self.scheduler is None: + self.scheduler = self._init_scheduler() + await self.scheduler.start() + self.scheduler.enable_completed_task_events() + self.running = True + self.event_loop_task = asyncio.create_task(self._completed_task_event_loop()) + + async def shutdown(self) -> None: + """Stop background work and close owned dependencies.""" + + self.running = False + if self.event_loop_task is not None: + self.event_loop_task.cancel() + try: + await self.event_loop_task + except asyncio.CancelledError: + pass + self.event_loop_task = None + if self.scheduler is not None: + await self.scheduler.stop() + self.scheduler = None + if self.experience_pipeline is not None: + await self.experience_pipeline.close() + self.experience_pipeline = None + + def _init_experience_pipeline(self): + """Create the experience pipeline owned by this coordinator actor.""" + + if self.config.mode == "bench": + return None + return ExperiencePipeline(self.config) + + def _init_scheduler(self) -> Scheduler: + 
"""Create the scheduler owned by this coordinator.""" + + return Scheduler( + self.config, + self.rollout_model, + self.auxiliary_models, + ) + + def _require_scheduler(self) -> Scheduler: + """Return the initialized scheduler.""" + + assert self.scheduler is not None, "RolloutCoordinator.prepare() must be called first." + return self.scheduler + + async def submit_batch( + self, + *, + batch_id: BatchId, + tasks: list[Task], + batch_type: BatchType, + min_wait_num: Optional[int] = None, + allow_partial_finalize: bool = False, + ) -> None: + """Register a new batch and schedule its tasks.""" + + existing_state = self.pending_batches.get(batch_id) + if existing_state is not None and existing_state.state not in { + BatchLifecycleState.FINALIZED, + BatchLifecycleState.ABORTED, + }: + raise ValueError(f"Batch {batch_id} is already active.") + + loop = asyncio.get_running_loop() + now = time.time() + batch_state = BatchState( + batch_id=batch_id, + batch_type=batch_type, + expected_task_count=len(tasks), + min_wait_num=min_wait_num, + allow_partial_finalize=allow_partial_finalize, + first_submit_time=now, + last_progress_time=now, + result_future=loop.create_future(), + ) + self.pending_batches[batch_id] = batch_state + + if tasks: + self._require_scheduler().schedule(tasks, batch_id=batch_id) + batch_state.state = BatchLifecycleState.RUNNING + else: + batch_state.state = BatchLifecycleState.READY_TO_FINALIZE + batch_state.finalize_reason = FinalizeReason.COMPLETE + batch_state.ready_event.set() + + async def finalize_train_batch( + self, + batch_id: int, + *, + timeout: Optional[float] = None, + ) -> dict: + """Finalize one train batch and return aggregated metrics.""" + + batch_state = self._get_batch_state(batch_id, expected_type="train") + return await self._finalize_batch(batch_state, timeout=timeout) + + async def finalize_eval_batch( + self, + batch_id: str, + *, + timeout: Optional[float] = None, + ) -> dict: + """Finalize one eval batch and return aggregated eval metrics.""" + + scheduler = self._require_scheduler() + batch_state = self._get_batch_state(batch_id, expected_type="eval") + async with batch_state.finalize_lock: + if batch_state.final_result is not None: + return dict(batch_state.final_result) + if batch_state.state == BatchLifecycleState.ABORTED: + batch_state.final_result = self._build_batch_result( + batch_state, FinalizeReason.ABORT, {} + ) + return dict(batch_state.final_result) + + statuses = await scheduler.get_statuses( + batch_id=batch_id, + timeout=timeout, + return_partial_tasks=self._return_partial_tasks(), + ) + for task_id, status in enumerate(statuses): + if task_id in batch_state.statuses: + continue + batch_state.statuses[task_id] = status + batch_state.completed_task_count += 1 + batch_state.last_progress_time = time.time() + reason = ( + FinalizeReason.COMPLETE + if batch_state.completed_task_count >= batch_state.expected_task_count + else FinalizeReason.TIMEOUT + ) + batch_state.finalize_reason = reason + batch_state.state = BatchLifecycleState.FINALIZED + batch_state.final_result = self._build_batch_result(batch_state, reason, {}) + if batch_state.result_future is not None and not batch_state.result_future.done(): + batch_state.result_future.set_result(batch_state.final_result) + return dict(batch_state.final_result) + + async def abort_batch( + self, + batch_id: BatchId, + *, + reason: str, + keep_partial_results: bool = False, + ) -> None: + """Abort one batch and cleanup its running and staged state.""" + + scheduler = self._require_scheduler() + 
batch_state = self.pending_batches.get(batch_id) + if batch_state is None: + return + if batch_state.state in {BatchLifecycleState.FINALIZED, BatchLifecycleState.ABORTED}: + return + + self.logger.warning("Abort batch %s: %s", batch_id, reason) + await scheduler.abort_batch( + batch_id, + return_partial_tasks=keep_partial_results, + restart_runners=True, + ) + if self.experience_pipeline is not None: + await self.experience_pipeline.abort_batch(batch_id) + + batch_state.state = BatchLifecycleState.ABORTED + batch_state.finalize_reason = FinalizeReason.ABORT + batch_state.final_result = self._build_batch_result(batch_state, FinalizeReason.ABORT, {}) + batch_state.ready_event.set() + if batch_state.result_future is not None and not batch_state.result_future.done(): + batch_state.result_future.set_result(batch_state.final_result) + + async def _completed_task_event_loop(self) -> None: + """Consume task completion events emitted by the scheduler.""" + + scheduler = self._require_scheduler() + while self.running: + try: + completed_ref = await scheduler.wait_completed_task(timeout=0.1) + if completed_ref is None: + continue + if not isinstance(completed_ref.task_id, int): + self.logger.warning( + "Skip completed task event with non-integer task id: %s", + completed_ref.task_id, + ) + continue + completed_result = scheduler.pop_completed_task(completed_ref.batch_id, completed_ref.task_id) + if completed_result is None: + continue + batch_state = self.pending_batches.get(completed_ref.batch_id) + if batch_state is None: + continue + await self._store_completed_task_result(batch_state, completed_result) + self._maybe_mark_ready(batch_state) + except Exception: # noqa: BLE001 + self.logger.exception("RolloutCoordinator task event loop failed.") + + async def _store_completed_task_result( + self, batch_state: BatchState, result: CompletedTaskResult + ) -> None: + """Persist one completed task into batch-level aggregation state.""" + + if result.task_id in batch_state.statuses: + return + batch_state.statuses[result.task_id] = result.status + batch_state.completed_task_count += 1 + batch_state.last_progress_time = time.time() + + if batch_state.batch_type != "train": + return + + if self.experience_pipeline is not None and result.experience_payloads: + staged_task_id = int(result.task_id) + staged = await self.experience_pipeline.stage_task_payloads( + batch_state.batch_id, + staged_task_id, + result.experience_payloads, + ) + if staged is not None: + batch_state.staged_task_ids.add(staged_task_id) + return + + if self.experience_pipeline is not None and result.status.completed_runs > 0: + batch_state.staged_task_ids.add(int(result.task_id)) + return + + if result.experience_payloads: + batch_state.payload_chunks_by_task[result.task_id] = list(result.experience_payloads) + return + + def _get_batch_state(self, batch_id: BatchId, *, expected_type: BatchType) -> BatchState: + """Return one registered batch and validate its type.""" + + batch_state = self.pending_batches.get(batch_id) + if batch_state is None: + raise KeyError(f"Batch {batch_id} is not registered.") + if batch_state.batch_type != expected_type: + raise ValueError( + f"Batch {batch_id} is {batch_state.batch_type}, expected {expected_type}." 
+ ) + return batch_state + + def _get_ready_reason(self, batch_state: BatchState) -> Optional[FinalizeReason]: + """Check whether a batch is ready to finalize and why.""" + + if batch_state.state == BatchLifecycleState.ABORTED: + return FinalizeReason.ABORT + if batch_state.completed_task_count >= batch_state.expected_task_count: + return FinalizeReason.COMPLETE + if batch_state.batch_type != "train": + return None + if not batch_state.allow_partial_finalize or batch_state.min_wait_num is None: + return None + if batch_state.completed_task_count < batch_state.min_wait_num: + batch_state.min_threshold_reached_time = None + return None + + if batch_state.min_threshold_reached_time is None: + batch_state.min_threshold_reached_time = time.time() + + wait_after_min = getattr(self.config.explorer.over_rollout, "wait_after_min", 0.0) + if time.time() - batch_state.min_threshold_reached_time >= wait_after_min: + return FinalizeReason.PARTIAL + return None + + def _maybe_mark_ready(self, batch_state: BatchState) -> Optional[FinalizeReason]: + """Transition a batch to ready state when finalize conditions are met.""" + + ready_reason = self._get_ready_reason(batch_state) + if ready_reason is None: + return None + if batch_state.state not in { + BatchLifecycleState.FINALIZED, + BatchLifecycleState.ABORTED, + BatchLifecycleState.FINALIZING, + }: + batch_state.state = BatchLifecycleState.READY_TO_FINALIZE + batch_state.finalize_reason = ready_reason + batch_state.ready_event.set() + return ready_reason + + async def _wait_for_ready( + self, batch_state: BatchState, timeout: Optional[float] + ) -> Optional[FinalizeReason]: + """Wait until a batch is ready to finalize or the timeout expires.""" + + start_time = time.time() + while True: + ready_reason = self._maybe_mark_ready(batch_state) + if ready_reason is not None: + return ready_reason + if timeout is not None and (time.time() - start_time) >= timeout: + return None + await asyncio.sleep(0.05) + + async def _finalize_batch(self, batch_state: BatchState, *, timeout: Optional[float]) -> dict: + """Finalize one train batch with idempotent result reuse.""" + + async with batch_state.finalize_lock: + if batch_state.final_result is not None: + return dict(batch_state.final_result) + if batch_state.state == BatchLifecycleState.ABORTED: + batch_state.final_result = self._build_batch_result( + batch_state, FinalizeReason.ABORT, {} + ) + return dict(batch_state.final_result) + + ready_reason = await self._wait_for_ready(batch_state, timeout) + if ready_reason is None: + if batch_state.allow_partial_finalize and batch_state.has_partial_results: + ready_reason = FinalizeReason.TIMEOUT + else: + raise TimeoutError(f"Timeout waiting for batch {batch_state.batch_id}.") + + batch_state.state = BatchLifecycleState.FINALIZING + batch_state.finalize_reason = ready_reason + try: + pipeline_metrics = await self._finalize_train_payloads(batch_state) + except Exception: + batch_state.state = BatchLifecycleState.READY_TO_FINALIZE + batch_state.finalize_reason = None + raise + + batch_state.state = BatchLifecycleState.FINALIZED + batch_state.final_result = self._build_batch_result( + batch_state, ready_reason, pipeline_metrics + ) + if batch_state.result_future is not None and not batch_state.result_future.done(): + batch_state.result_future.set_result(batch_state.final_result) + return dict(batch_state.final_result) + + async def _finalize_train_payloads(self, batch_state: BatchState) -> dict: + """Flush staged train payloads through the experience pipeline.""" + + if 
self.experience_pipeline is not None and batch_state.staged_task_ids: + return await self.experience_pipeline.finalize_batch( + batch_state.batch_id, + task_ids=sorted(batch_state.staged_task_ids), + ) + if self.experience_pipeline is not None and batch_state.payload_chunks_by_task: + exp_chunks = [] + for task_id in sorted(batch_state.payload_chunks_by_task): + exp_chunks.extend(batch_state.payload_chunks_by_task[task_id]) + return await self.experience_pipeline.process_serialized_chunks( + exp_chunks, + ) + return {} + + def _build_batch_result( + self, + batch_state: BatchState, + reason: FinalizeReason, + pipeline_metrics: dict, + ) -> dict: + """Build the public finalize result returned to Explorer.""" + + metrics = dict(pipeline_metrics) + status_metrics = [ + status.metrics[0] for status in batch_state.statuses.values() if status.metrics + ] + if batch_state.batch_type == "train": + if status_metrics: + metrics.update(gather_metrics(status_metrics, "rollout")) + metrics["rollout/finished_task_count"] = float(batch_state.completed_task_count) + else: + prefix = self._eval_metric_prefix(batch_state.batch_id) + if status_metrics: + metrics.update( + gather_eval_metrics( + status_metrics, + prefix, + detailed_stats=self.detailed_stats, + ) + ) + metrics[f"{prefix}/finished_task_count"] = float(batch_state.completed_task_count) + + return { + "batch_id": batch_state.batch_id, + "batch_type": batch_state.batch_type, + "finished_task_count": batch_state.completed_task_count, + "metrics": metrics, + "finalize_reason": reason.value, + "finalized": reason != FinalizeReason.ABORT, + } + + def _eval_metric_prefix(self, batch_id: BatchId) -> str: + """Return the metric namespace prefix for one eval batch.""" + + batch_name = str(batch_id) + if "/" in batch_name: + batch_name = batch_name.split("/", 1)[1] + return f"eval/{batch_name}" + + def _return_partial_tasks(self) -> bool: + """Return whether scheduler cleanup may emit partial task results.""" + + return bool(getattr(self.config.explorer.over_rollout, "return_partial_tasks", False)) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 77d7fc356c2..526500de636 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -44,7 +44,6 @@ class CompletedTaskResult: task_id: Union[int, str] status: Status - experiences: List[Experience] = field(default_factory=list) experience_payloads: List[bytes] = field(default_factory=list) @@ -156,14 +155,12 @@ def __init__( rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config, - experience_pipeline=None, ): self.logger = get_logger(__name__) self.runner_id = runner_id self.rollout_model = rollout_model self.auxiliary_models = auxiliary_models self.config = config - self.experience_pipeline = experience_pipeline self.retry_times = config.explorer.max_retry_times self.timeout = config.explorer.max_timeout self.namespace = config.ray_namespace @@ -185,7 +182,6 @@ def _create_runner(self): self.config, self.rollout_model, self.auxiliary_models, - self.experience_pipeline, self.runner_id, ) ) @@ -325,13 +321,11 @@ def __init__( config: Config, rollout_model: List[InferenceModel], auxiliary_models: Optional[List[List[InferenceModel]]] = None, - experience_pipeline=None, ): self.logger = get_logger(__name__) self.config = config self.rollout_model = rollout_model self.auxiliary_models = auxiliary_models or [] - self.experience_pipeline = experience_pipeline self.namespace = ray.get_runtime_context().namespace self.default_timeout = 
config.explorer.max_timeout * (config.explorer.max_retry_times + 1) self.max_retry_times = config.explorer.max_retry_times @@ -386,7 +380,6 @@ async def _create_runner( for j in range(len(self.auxiliary_models)) ], config=self.config, - experience_pipeline=self.experience_pipeline, ) await runner.prepare() self.runners[runner_id] = runner @@ -777,52 +770,28 @@ async def _wait_for_batch_results( async def _collect_batch_results( self, batch_id: Union[int, str] - ) -> Tuple[List[Status], List[Experience]]: + ) -> Tuple[List[Status], List[bytes]]: statuses = [] - experiences = [] + payload_chunks = [] completed_results = list(self.completed_tasks.get(batch_id, {}).values()) - staged_task_ids = [] for result in completed_results: statuses.append(result.status) if result.experience_payloads: - for exp_payload in result.experience_payloads: - experiences.extend(Experience.deserialize_many(exp_payload)) - continue + payload_chunks.extend(result.experience_payloads) - experiences.extend(result.experiences) - if ( - self.experience_pipeline is not None - and result.status.completed_runs > 0 - and not result.experiences - ): - staged_task_ids.append(result.task_id) - - if staged_task_ids: - staged_payloads = await self.experience_pipeline.take_staged_task_payloads.remote( - batch_id, staged_task_ids - ) - for exp_payload in staged_payloads: - experiences.extend(Experience.deserialize_many(exp_payload)) - - return statuses, experiences + return statuses, payload_chunks - async def get_results( + async def _get_batch_payload_results( self, batch_id: Union[int, str], - min_num: Optional[int] = None, - timeout: Optional[float] = None, - clear_timeout_tasks: bool = True, - return_partial_tasks: bool = False, - ) -> Tuple[List[Status], List[Experience]]: - """Get the result of tasks at the specific batch_id. + *, + min_num: Optional[int], + timeout: Optional[float], + clear_timeout_tasks: bool, + return_partial_tasks: bool, + ) -> Tuple[List[Status], List[bytes]]: + """Wait for one batch and drain its completed payload chunks.""" - Args: - batch_id (`Union[int, str]`): Only wait for tasks at this batch. - min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `batch_id`. - timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout. - clear_timeout_tasks (`bool`): Whether to clear timeout tasks. - return_partial_tasks (`bool`): Whether to emit tasks with partial successful runs when cleaning up unfinished tasks. 
- """ timeout = timeout or self.default_timeout scheduled_num, min_num = self._resolve_result_target(batch_id, min_num) @@ -846,14 +815,11 @@ async def get_results( restart_runners=True, ) - statuses, experiences = await self._collect_batch_results(batch_id) + statuses, payload_chunks = await self._collect_batch_results(batch_id) if batch_id in self.completed_tasks: del self.completed_tasks[batch_id] - if clear_timeout_tasks and self.experience_pipeline is not None: - await self.experience_pipeline.abort_batch.remote(batch_id) - completed_count = len(statuses) self._finalize_dynamic_timeout_step(batch_id, scheduled_num, completed_count) if completed_count < min_num: @@ -861,8 +827,87 @@ async def get_results( f"Timeout reached, only {completed_count}/{min_num} tasks completed" ) + return statuses, payload_chunks + + async def get_payload_results( + self, + batch_id: Union[int, str], + min_num: Optional[int] = None, + timeout: Optional[float] = None, + clear_timeout_tasks: bool = True, + return_partial_tasks: bool = False, + ) -> Tuple[List[Status], List[bytes]]: + """Wait for one batch and return task statuses plus serialized payload chunks.""" + + return await self._get_batch_payload_results( + batch_id=batch_id, + min_num=min_num, + timeout=timeout, + clear_timeout_tasks=clear_timeout_tasks, + return_partial_tasks=return_partial_tasks, + ) + + async def get_results( + self, + batch_id: Union[int, str], + min_num: Optional[int] = None, + timeout: Optional[float] = None, + clear_timeout_tasks: bool = True, + return_partial_tasks: bool = False, + ) -> Tuple[List[Status], List[Experience]]: + """Get the result of tasks at the specific batch_id. + + Args: + batch_id (`Union[int, str]`): Only wait for tasks at this batch. + min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `batch_id`. + timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout. + clear_timeout_tasks (`bool`): Whether to clear timeout tasks. + return_partial_tasks (`bool`): Whether to emit tasks with partial successful runs when cleaning up unfinished tasks. 
+ """ + statuses, payload_chunks = await self._get_batch_payload_results( + batch_id=batch_id, + min_num=min_num, + timeout=timeout, + clear_timeout_tasks=clear_timeout_tasks, + return_partial_tasks=return_partial_tasks, + ) + experiences = [] + for exp_payload in payload_chunks: + experiences.extend(Experience.deserialize_many(exp_payload)) return statuses, experiences + async def get_statuses( + self, + batch_id: Union[int, str], + min_num: Optional[int] = None, + timeout: Optional[float] = None, + clear_timeout_tasks: bool = True, + return_partial_tasks: bool = False, + ) -> List[Status]: + """Wait for one batch and return only task statuses without materializing experiences.""" + + statuses, _ = await self._get_batch_payload_results( + batch_id=batch_id, + min_num=min_num, + timeout=timeout, + clear_timeout_tasks=clear_timeout_tasks, + return_partial_tasks=return_partial_tasks, + ) + return statuses + + async def abort_batch( + self, + batch_id: Union[int, str], + return_partial_tasks: bool = False, + restart_runners: bool = True, + ) -> None: + """Abort one batch and cleanup unfinished scheduler state.""" + await self._cleanup_batch( + batch_id, + return_partial_tasks=return_partial_tasks, + restart_runners=restart_runners, + ) + def has_step(self, batch_id: Union[int, str]) -> bool: return ( batch_id in self.completed_tasks @@ -909,8 +954,6 @@ async def wait_all( batch_ids_to_abort = self.pending_tasks.keys() | self.running_tasks.keys() for batch_id in batch_ids_to_abort: self._clear_timeout_tasks(batch_id) - if self.experience_pipeline is not None: - await self.experience_pipeline.abort_batch.remote(batch_id) asyncio.gather( *[self._restart_runner(runner_id) for runner_id in self.busy_runners.keys()] ) diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index b75f91978c4..f10fca58c8c 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -68,7 +68,6 @@ def __init__( config: Config, model: InferenceModel, auxiliary_models: Optional[List[InferenceModel]] = None, - experience_pipeline=None, runner_id: Optional[int] = None, ) -> None: self.name = f"{config.explorer.name}_runner_{runner_id}" @@ -81,7 +80,6 @@ def __init__( enable_history=config.explorer.rollout_model.enable_history, ) self.auxiliary_models = auxiliary_models or [] - self.experience_pipeline = experience_pipeline self.auxiliary_model_wrappers = [ ModelWrapper( model, @@ -441,13 +439,6 @@ async def run_task( return status, b"" else: exp_payload = Experience.serialize_many(exps) - if self.experience_pipeline is not None: - await self.experience_pipeline.stage_task_payloads.remote( - task.batch_id, - task.task_id, - [exp_payload], - ) - return status, b"" return status, exp_payload except Exception as e: From 4a6e64acf345fe9a3ea135bedebdada4e963549a Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 19:30:23 +0800 Subject: [PATCH 06/11] clean stale interface --- tests/explorer/scheduler_test.py | 148 +++++++++++++++++-------------- trinity/explorer/scheduler.py | 30 ------- 2 files changed, 82 insertions(+), 96 deletions(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 542e1ca2fa6..2abaf82d8ba 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -434,6 +434,16 @@ def generate_tasks( return tasks +async def collect_results(scheduler: Scheduler, **kwargs): + """Collect serialized payloads and materialize experiences for assertions.""" + + statuses, payloads = await 
scheduler.get_payload_results(**kwargs) + experiences = [] + for payload in payloads: + experiences.extend(Experience.deserialize_many(payload)) + return statuses, experiences + + class SchedulerTest(unittest.IsolatedAsyncioTestCase): def setUp(self): ray.init(ignore_reinit_error=True) @@ -456,17 +466,17 @@ def setUp(self): self.config.algorithm.repeat_times = 1 self.config.check_and_update() - async def test_get_results(self): + async def test_get_payload_results(self): scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() tasks = generate_tasks(8) scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=8, timeout=20) self.assertEqual(len(statuses), 8) self.assertEqual(len(exps), 8) - _, exps = await scheduler.get_results(batch_id=0, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=0, min_num=1, timeout=1) self.assertEqual(len(exps), 0) for result in statuses: @@ -478,17 +488,19 @@ async def test_get_results(self): for batch_id in range(1, 4): self.assertTrue(scheduler.has_step(batch_id)) - statuses, exps = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10) + statuses, exps = await collect_results( + scheduler, batch_id=batch_id, min_num=4, timeout=10 + ) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4) self.assertFalse(scheduler.has_step(batch_id)) - _, exps = await scheduler.get_results(batch_id=0, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=0, min_num=1, timeout=1) self.assertEqual(len(exps), 0) tasks = generate_tasks(3) scheduler.schedule(tasks, batch_id=4) self.assertTrue(scheduler.has_step(4)) - statuses, exps = await scheduler.get_results(batch_id=4) + statuses, exps = await collect_results(scheduler, batch_id=4) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3) self.assertFalse(scheduler.has_step(4)) @@ -498,7 +510,7 @@ async def test_get_results(self): scheduler.schedule(tasks, batch_id=0) start_time = time.time() - statuses, exps = await scheduler.get_results(batch_id=0, min_num=4, timeout=3) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=4, timeout=3) end_time = time.time() self.assertLessEqual(end_time - start_time, 15) # sync wait for runner restart @@ -510,50 +522,46 @@ async def test_get_results(self): scheduler.schedule(tasks, batch_id=0) # actor restart is slow, set a big timeout - statuses, exps = await scheduler.get_results(batch_id=0, timeout=20) + statuses, exps = await collect_results(scheduler, batch_id=0, timeout=20) self.assertEqual(len(statuses), 4) success_count = sum(1 for r in statuses if r.ok) self.assertEqual(success_count, 4) self.assertEqual(len(exps), 4) - _, exps = await scheduler.get_results(batch_id=0, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=0, min_num=1, timeout=1) self.assertEqual(len(exps), 0) # test exception tasks tasks = generate_tasks(1, exception_num=3) scheduler.schedule(tasks, batch_id=1) - statuses, exps = await scheduler.get_results(batch_id=1, timeout=5) + statuses, exps = await collect_results(scheduler, batch_id=1, timeout=5) self.assertEqual(len(statuses), 4) success_count = sum(1 for r in statuses if r.ok) self.assertEqual(success_count, 1) self.assertEqual(len(exps), 1) - _, exps = await scheduler.get_results(batch_id=1, min_num=1, timeout=1) + _, exps = await 
collect_results(scheduler, batch_id=1, min_num=1, timeout=1) self.assertEqual(len(exps), 0) # test _cleanup_batch_and_restart_runners: part I, no clear tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) scheduler.schedule(tasks, batch_id=2) - statuses, exps = await scheduler.get_results( - batch_id=2, timeout=2, clear_timeout_tasks=False - ) + statuses, exps = await collect_results(scheduler, batch_id=2, timeout=2, clear_timeout_tasks=False) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3) - statuses, exps = await scheduler.get_results( - batch_id=2, timeout=2, clear_timeout_tasks=False - ) + statuses, exps = await collect_results(scheduler, batch_id=2, timeout=2, clear_timeout_tasks=False) self.assertEqual(len(statuses), 1) self.assertEqual(len(exps), 1) # test _cleanup_batch_and_restart_runners: part II, clear tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) scheduler.schedule(tasks, batch_id=3) - statuses, exps = await scheduler.get_results(batch_id=3, timeout=2) + statuses, exps = await collect_results(scheduler, batch_id=3, timeout=2) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3) - statuses, exps = await scheduler.get_results(batch_id=3, timeout=2) + statuses, exps = await collect_results(scheduler, batch_id=3, timeout=2) self.assertEqual(len(statuses), 0) self.assertEqual(len(exps), 0) - _, exps = await scheduler.get_results(batch_id=3, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=3, min_num=1, timeout=1) self.assertEqual(len(exps), 0) await scheduler.stop() @@ -577,8 +585,8 @@ async def test_wait_all(self): self.assertEqual(len(scheduler.pending_tasks), 0) self.assertEqual(len(scheduler.running_tasks), 0) - status0, exps0 = await scheduler.get_results(batch_id=0, min_num=4, timeout=1) - status1, exps1 = await scheduler.get_results(batch_id=1, min_num=3, timeout=1) + status0, exps0 = await collect_results(scheduler, batch_id=0, min_num=4, timeout=1) + status1, exps1 = await collect_results(scheduler, batch_id=1, min_num=3, timeout=1) self.assertEqual(len(status0), 4) self.assertEqual(len(status1), 3) @@ -634,7 +642,7 @@ async def test_concurrent_operations(self): async def schedule_tasks(batch_id, num_tasks): tasks = generate_tasks(num_tasks) scheduler.schedule(tasks, batch_id=batch_id) - return await scheduler.get_results(batch_id=batch_id, min_num=num_tasks, timeout=10) + return await collect_results(scheduler, batch_id=batch_id, min_num=num_tasks, timeout=10) results = await asyncio.gather( schedule_tasks(0, 3), @@ -654,7 +662,7 @@ async def test_scheduler_restart_after_stop(self): await scheduler.start() tasks = generate_tasks(2) scheduler.schedule(tasks, batch_id=0) - results, exps = await scheduler.get_results(batch_id=0, min_num=2, timeout=10) + results, exps = await collect_results(scheduler, batch_id=0, min_num=2, timeout=10) self.assertEqual(len(results), 2) self.assertEqual(len(exps), 2) await scheduler.stop() @@ -662,7 +670,7 @@ async def test_scheduler_restart_after_stop(self): await scheduler.start() tasks = generate_tasks(3, repeat_times=2) scheduler.schedule(tasks, batch_id=1) - results, exps = await scheduler.get_results(batch_id=1, min_num=3, timeout=10) + results, exps = await collect_results(scheduler, batch_id=1, min_num=3, timeout=10) self.assertEqual(len(results), 3) self.assertEqual(len(exps), 3 * 2) await scheduler.stop() @@ -673,13 +681,13 @@ async def test_scheduler_all_methods(self): tasks = generate_tasks(8) scheduler.schedule(tasks, batch_id=0) 
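         # `has_step` reports whether a batch still has pending, running, or
         # undrained completed work; the assertions below exercise it in both
         # directions: True right after scheduling, False once results are drained.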
self.assertTrue(scheduler.has_step(0)) - statuses, exps = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=8, timeout=20) self.assertEqual(len(statuses), 8) self.assertEqual(len(exps), 8) scheduler.schedule(tasks, batch_id=1) scheduler.schedule(tasks[:4], batch_id=2) self.assertFalse(scheduler.has_step(0)) - statuses, exps = await scheduler.get_results(batch_id=0, min_num=8) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=8) self.assertFalse(scheduler.has_step(0)) self.assertEqual(len(statuses), 0) # batch_id 0 has no more tasks self.assertEqual(len(exps), 0) @@ -688,7 +696,7 @@ async def test_scheduler_all_methods(self): self.assertTrue(scheduler.has_step(2)) await scheduler.wait_all() st = time.time() - statuses, exps = await scheduler.get_results(batch_id=1) + statuses, exps = await collect_results(scheduler, batch_id=1) et = time.time() self.assertTrue(et - st < 1.0) self.assertEqual(len(statuses), 8) @@ -696,7 +704,7 @@ async def test_scheduler_all_methods(self): self.assertFalse(scheduler.has_step(1)) self.assertTrue(scheduler.has_step(2)) st = time.time() - statuses, exps = await scheduler.get_results(batch_id=2) + statuses, exps = await collect_results(scheduler, batch_id=2) et = time.time() self.assertTrue(et - st < 1.0) self.assertEqual(len(statuses), 4) @@ -713,29 +721,29 @@ async def test_split_tasks(self): tasks = generate_tasks(4, repeat_times=8) # ceil(8 / 2) == 4 scheduler.schedule(tasks, batch_id=1) - statuses, exps = await scheduler.get_results(batch_id=1) + statuses, exps = await collect_results(scheduler, batch_id=1) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4 * 8) exp_list.extend(exps) - _, exps = await scheduler.get_results(batch_id=1, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=1, min_num=1, timeout=1) self.assertEqual(len(exps), 0) tasks = generate_tasks(4, repeat_times=5) # ceil(5 / 2) == 3 scheduler.schedule(tasks, batch_id=2) - statuses, exps = await scheduler.get_results(batch_id=2) + statuses, exps = await collect_results(scheduler, batch_id=2) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4 * 5) exp_list.extend(exps) - _, exps = await scheduler.get_results(batch_id=2, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=2, min_num=1, timeout=1) self.assertEqual(len(exps), 0) tasks = generate_tasks(3, repeat_times=1) # ceil(1 / 2) == 1 scheduler.schedule(tasks, batch_id=3) - statuses, exps = await scheduler.get_results(batch_id=3) + statuses, exps = await collect_results(scheduler, batch_id=3) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3 * 1) exp_list.extend(exps) - _, exps = await scheduler.get_results(batch_id=3, min_num=1, timeout=1) + _, exps = await collect_results(scheduler, batch_id=3, min_num=1, timeout=1) self.assertEqual(len(exps), 0) # test task_id, run_id and unique_id @@ -758,7 +766,7 @@ async def test_multi_step_execution(self): n_steps = 3 for i in range(1, n_steps + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 2 * 4) @@ -776,7 +784,7 @@ async def test_non_repeatable_workflow(self): exp_list = [] for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = 
await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times) exp_list.extend(exps) @@ -817,7 +825,7 @@ async def test_async_workflow(self): exp_list = [] for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times * step_num) exp_list.extend(exps) @@ -847,7 +855,7 @@ async def test_stepwise_experience_eid(self): exp_list = [] for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times * step_num) exp_list.extend(exps) @@ -867,7 +875,7 @@ async def test_stepwise_experience_eid(self): exp_list = [] for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) - statuses, exps = await scheduler.get_results(batch_id=i) + statuses, exps = await collect_results(scheduler, batch_id=i) self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times * step_num) exp_list.extend(exps) @@ -895,7 +903,7 @@ async def test_metric_calculation_with_repeatable_workflow(self, max_repeat_time tasks.extend(generate_tasks(total_num=1, step_num=1, repeat_times=4, repeatable=True)) tasks.extend(generate_tasks(total_num=1, step_num=4, repeat_times=8, repeatable=True)) scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0) + statuses, exps = await collect_results(scheduler, batch_id=0) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 1 * 4 * 1 + 1 * 8 * 4) expected_run_metrics = set({1.5, 3.5}) # (0+1+2+3)/4 and (0+1+2+3+4+5+6+7)/8 @@ -921,7 +929,7 @@ async def test_metric_calculation_with_non_repeatable_workflow( tasks.extend(generate_tasks(total_num=1, step_num=8, repeat_times=5, repeatable=False)) tasks[-1].workflow_args["metrics"] = [2 * i for i in range(8)] scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0) + statuses, exps = await collect_results(scheduler, batch_id=0) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 1 * 4 * 3 + 1 * 5 * 8) # (1+2+3)/3 = 2.0 @@ -944,7 +952,7 @@ async def test_over_rollout_min_wait(self): tasks.extend(generate_tasks(0, timeout_num=1, repeat_times=1, timeout_seconds=3)) tasks.extend(generate_tasks(0, timeout_num=1, repeat_times=1, timeout_seconds=6)) scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0, min_num=2) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=2) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3 * 1) @@ -991,7 +999,8 @@ async def test_over_rollout_return_partial_tasks(self): scheduler.schedule(tasks, batch_id=0) start_time = time.time() - statuses, exps = await scheduler.get_results( + statuses, exps = await collect_results( + scheduler, batch_id=0, min_num=1, timeout=2, @@ -1041,7 +1050,7 @@ async def test_over_rollout_return_partial_tasks(self): self.assertEqual(partial_status.metrics[0]["run_metrics"], 10.0) self.assertIn("1/3 runs completed successfully", partial_status.message) - statuses, exps = await scheduler.get_results(batch_id=0, timeout=1) + statuses, exps = await 
collect_results(scheduler, batch_id=0, timeout=1) self.assertEqual(len(statuses), 0) self.assertEqual(len(exps), 0) @@ -1080,7 +1089,8 @@ async def test_over_rollout_async_cancelled_runner_accepts_next_batch(self): ] scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results( + statuses, exps = await collect_results( + scheduler, batch_id=0, min_num=1, timeout=3, @@ -1093,7 +1103,9 @@ async def test_over_rollout_async_cancelled_runner_accepts_next_batch(self): follow_up_tasks = generate_tasks(2) scheduler.schedule(follow_up_tasks, batch_id=1) start_time = time.time() - next_statuses, next_exps = await scheduler.get_results(batch_id=1, min_num=2, timeout=2) + next_statuses, next_exps = await collect_results( + scheduler, batch_id=1, min_num=2, timeout=2 + ) elapsed = time.time() - start_time self.assertEqual(len(next_statuses), 2) @@ -1135,7 +1147,8 @@ async def test_over_rollout_sync_cancel_does_not_imply_immediate_runner_reuse(se ] scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results( + statuses, exps = await collect_results( + scheduler, batch_id=0, min_num=1, timeout=3, @@ -1148,7 +1161,8 @@ async def test_over_rollout_sync_cancel_does_not_imply_immediate_runner_reuse(se follow_up_tasks = generate_tasks(2) scheduler.schedule(follow_up_tasks, batch_id=1) start_time = time.time() - next_statuses, next_exps = await scheduler.get_results( + next_statuses, next_exps = await collect_results( + scheduler, batch_id=1, min_num=2, timeout=0.5, @@ -1161,7 +1175,9 @@ async def test_over_rollout_sync_cancel_does_not_imply_immediate_runner_reuse(se self.assertGreaterEqual(elapsed, 0.5) self.assertTrue(scheduler.has_step(1)) - next_statuses, next_exps = await scheduler.get_results(batch_id=1, min_num=2, timeout=4) + next_statuses, next_exps = await collect_results( + scheduler, batch_id=1, min_num=2, timeout=4 + ) self.assertEqual(len(next_statuses), 2) self.assertEqual(len(next_exps), 2) @@ -1179,7 +1195,7 @@ async def test_timeout_cleanup_still_restarts_runner(self): scheduler.schedule(tasks, batch_id=0) with patch.object(scheduler, "_restart_runner", new=AsyncMock()) as restart_runner_mock: - await scheduler.get_results(batch_id=0, timeout=1) + await collect_results(scheduler, batch_id=0, timeout=1) self.assertGreaterEqual(restart_runner_mock.await_count, 1) @@ -1228,7 +1244,7 @@ async def test_dynamic_timeout(self): scheduler.schedule( tasks, batch_id="0/eval" ) # eval tasks will not count into dynamic timeout - statuses, exps = await scheduler.get_results(batch_id="0/eval") + statuses, exps = await collect_results(scheduler, batch_id="0/eval") self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 0) self.assertEqual(scheduler.total_running_time, 0) @@ -1238,14 +1254,14 @@ async def test_dynamic_timeout(self): # generate 4 tasks that will run 1 second tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=1)) scheduler.schedule(tasks, batch_id=0) # first step will not use dynamic timeout - statuses, exps = await scheduler.get_results(batch_id=0) + statuses, exps = await collect_results(scheduler, batch_id=0) self.assertEqual(len(statuses), 4) # dynamic timeout will be set to 3.0 * 1.0 = 3.0 seconds for next step tasks = [] tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=4)) st = time.time() scheduler.schedule(tasks, batch_id=1) - statuses, exps = await scheduler.get_results(batch_id=1) + statuses, exps = await collect_results(scheduler, batch_id=1) et = time.time() 
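         # With the 3.0 * 1.0 = 3.0s dynamic deadline now in effect, the 4s tasks
         # are expected to be cut off early; the elapsed-time check below verifies
         # that the wait returned before the tasks' own 4s duration.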
self.assertTrue( et - st < 4 @@ -1256,7 +1272,7 @@ async def test_dynamic_timeout(self): tasks = [] tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=2)) scheduler.schedule(tasks, batch_id=2) - statuses, exps = await scheduler.get_results(batch_id=2) + statuses, exps = await collect_results(scheduler, batch_id=2) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4) @@ -1275,7 +1291,7 @@ async def test_dynamic_timeout_warmup_min_steps_uses_completed_steps(self): tasks = generate_tasks(0, timeout_num=2, repeat_times=4, timeout_seconds=1) scheduler.schedule(tasks, batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0) + statuses, exps = await collect_results(scheduler, batch_id=0) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 8) @@ -1286,7 +1302,7 @@ async def test_dynamic_timeout_warmup_min_steps_uses_completed_steps(self): tasks = generate_tasks(0, timeout_num=2, repeat_times=4, timeout_seconds=1) scheduler.schedule(tasks, batch_id=1) - statuses, exps = await scheduler.get_results(batch_id=1) + statuses, exps = await collect_results(scheduler, batch_id=1) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 8) @@ -1295,7 +1311,7 @@ async def test_dynamic_timeout_warmup_min_steps_uses_completed_steps(self): await scheduler.stop() - async def test_completed_task_events_keep_get_results_compatible(self): + async def test_completed_task_events_keep_payload_results_compatible(self): scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() scheduler.enable_completed_task_events() @@ -1312,19 +1328,19 @@ async def test_completed_task_events_keep_get_results_compatible(self): self.assertEqual({ref.batch_id for ref in completed_refs}, {0}) self.assertEqual({ref.task_id for ref in completed_refs}, {0, 1, 2, 3}) - statuses, exps = await scheduler.get_results(batch_id=0, timeout=1) + statuses, exps = await collect_results(scheduler, batch_id=0, timeout=1) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 8) await scheduler.stop() - async def test_get_results_reads_payloads_returned_by_workflow_runner(self): + async def test_collect_results_reads_payloads_returned_by_workflow_runner(self): scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() scheduler.schedule(generate_tasks(3, repeat_times=2), batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0, timeout=10) + statuses, exps = await collect_results(scheduler, batch_id=0, timeout=10) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 6) @@ -1336,7 +1352,7 @@ async def test_timeout_cleanup_keeps_completed_payloads_local(self): scheduler.schedule(generate_tasks(1, timeout_num=1, timeout_seconds=10), batch_id=0) - statuses, exps = await scheduler.get_results(batch_id=0, min_num=2, timeout=1) + statuses, exps = await collect_results(scheduler, batch_id=0, min_num=2, timeout=1) self.assertEqual(len(statuses), 1) self.assertEqual(len(exps), 1) @@ -1351,7 +1367,7 @@ async def test_eval_tasks_do_not_return_training_experiences(self): task.is_eval = True scheduler.schedule(eval_tasks, batch_id="0/eval") - statuses, exps = await scheduler.get_results(batch_id="0/eval", timeout=10) + statuses, exps = await collect_results(scheduler, batch_id="0/eval", timeout=10) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 0) @@ -1365,7 +1381,7 @@ async def test_get_statuses_skips_payload_deserialization(self): 
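         # The `Experience` import was removed from scheduler.py in this cleanup,
         # so the patch target below must name the defining module,
         # trinity.common.experience; patching via trinity.explorer.scheduler
         # would now fail on the missing attribute.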
scheduler.schedule(generate_tasks(2, repeat_times=2), batch_id=0) with patch( - "trinity.explorer.scheduler.Experience.deserialize_many", + "trinity.common.experience.Experience.deserialize_many", side_effect=AssertionError("payload deserialization should not happen"), ): statuses = await scheduler.get_statuses(batch_id=0, timeout=10) @@ -1381,7 +1397,7 @@ async def test_get_payload_results_keeps_payloads_serialized(self): scheduler.schedule(generate_tasks(2, repeat_times=2), batch_id=0) with patch( - "trinity.explorer.scheduler.Experience.deserialize_many", + "trinity.common.experience.Experience.deserialize_many", side_effect=AssertionError("payload deserialization should not happen"), ): statuses, payloads = await scheduler.get_payload_results(batch_id=0, timeout=10) @@ -1450,5 +1466,5 @@ async def monitor_routine(): await asyncio.gather( monitor_routine(), - scheduler.get_results(batch_id=0), + collect_results(scheduler, batch_id=0), ) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 526500de636..b64bff812dd 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -12,7 +12,6 @@ import ray from trinity.common.config import Config -from trinity.common.experience import Experience from trinity.common.models import InferenceModel from trinity.common.workflows import Task from trinity.explorer.workflow_runner import Status, WorkflowRunner @@ -847,35 +846,6 @@ async def get_payload_results( return_partial_tasks=return_partial_tasks, ) - async def get_results( - self, - batch_id: Union[int, str], - min_num: Optional[int] = None, - timeout: Optional[float] = None, - clear_timeout_tasks: bool = True, - return_partial_tasks: bool = False, - ) -> Tuple[List[Status], List[Experience]]: - """Get the result of tasks at the specific batch_id. - - Args: - batch_id (`Union[int, str]`): Only wait for tasks at this batch. - min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `batch_id`. - timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout. - clear_timeout_tasks (`bool`): Whether to clear timeout tasks. - return_partial_tasks (`bool`): Whether to emit tasks with partial successful runs when cleaning up unfinished tasks. 
- """ - statuses, payload_chunks = await self._get_batch_payload_results( - batch_id=batch_id, - min_num=min_num, - timeout=timeout, - clear_timeout_tasks=clear_timeout_tasks, - return_partial_tasks=return_partial_tasks, - ) - experiences = [] - for exp_payload in payload_chunks: - experiences.extend(Experience.deserialize_many(exp_payload)) - return statuses, experiences - async def get_statuses( self, batch_id: Union[int, str], From 200fedfa421df4365b04e35cd678bdff467709c5 Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 19:35:05 +0800 Subject: [PATCH 07/11] fix pre-commit --- tests/explorer/explorer_test.py | 1 - tests/explorer/scheduler_test.py | 12 +++++++++--- trinity/explorer/explorer.py | 6 ++++-- trinity/explorer/rollout_coordinator.py | 5 +++-- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 136f0a0ea17..85128cab8bb 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -390,7 +390,6 @@ async def test_over_rollout_submits_partial_finalize_policy_to_rollout_coordinat ) - def run_serve(config): config.check_and_update() run_stage(config) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 2abaf82d8ba..76eab176d43 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -546,10 +546,14 @@ async def test_get_payload_results(self): # test _cleanup_batch_and_restart_runners: part I, no clear tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) scheduler.schedule(tasks, batch_id=2) - statuses, exps = await collect_results(scheduler, batch_id=2, timeout=2, clear_timeout_tasks=False) + statuses, exps = await collect_results( + scheduler, batch_id=2, timeout=2, clear_timeout_tasks=False + ) self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3) - statuses, exps = await collect_results(scheduler, batch_id=2, timeout=2, clear_timeout_tasks=False) + statuses, exps = await collect_results( + scheduler, batch_id=2, timeout=2, clear_timeout_tasks=False + ) self.assertEqual(len(statuses), 1) self.assertEqual(len(exps), 1) # test _cleanup_batch_and_restart_runners: part II, clear @@ -642,7 +646,9 @@ async def test_concurrent_operations(self): async def schedule_tasks(batch_id, num_tasks): tasks = generate_tasks(num_tasks) scheduler.schedule(tasks, batch_id=batch_id) - return await collect_results(scheduler, batch_id=batch_id, min_num=num_tasks, timeout=10) + return await collect_results( + scheduler, batch_id=batch_id, min_num=num_tasks, timeout=10 + ) results = await asyncio.gather( schedule_tasks(0, 3), diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 9e159cb3b25..0633ed565eb 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -29,7 +29,7 @@ from trinity.manager.synchronizer import Synchronizer from trinity.utils.annotations import Experimental from trinity.utils.log import get_logger -from trinity.utils.monitor import MONITOR, gather_eval_metrics, gather_metrics +from trinity.utils.monitor import MONITOR from trinity.utils.plugin_loader import load_plugins from trinity.utils.timer import Timer @@ -416,7 +416,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva if eval_step != step: return self.pending_eval_tasks.popleft() - assert self.rollout_coordinator is not None, "Rollout coordinator must be prepared first." 
+ assert ( + self.rollout_coordinator is not None + ), "Rollout coordinator must be prepared first." result = await self.rollout_coordinator.finalize_eval_batch.remote( f"{step}/{eval_task_name}" ) diff --git a/trinity/explorer/rollout_coordinator.py b/trinity/explorer/rollout_coordinator.py index f1a220208a7..7922ccd0420 100644 --- a/trinity/explorer/rollout_coordinator.py +++ b/trinity/explorer/rollout_coordinator.py @@ -8,7 +8,6 @@ from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline from trinity.common.config import Config -from trinity.common.experience import Experience from trinity.common.models import InferenceModel from trinity.common.workflows import Task from trinity.explorer.scheduler import CompletedTaskResult, Scheduler @@ -293,7 +292,9 @@ async def _completed_task_event_loop(self) -> None: completed_ref.task_id, ) continue - completed_result = scheduler.pop_completed_task(completed_ref.batch_id, completed_ref.task_id) + completed_result = scheduler.pop_completed_task( + completed_ref.batch_id, completed_ref.task_id + ) if completed_result is None: continue batch_state = self.pending_batches.get(completed_ref.batch_id) From f4a8817c638786e41d1061815ffed04f8e3927cf Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 20:50:28 +0800 Subject: [PATCH 08/11] simplify implementation --- tests/explorer/scheduler_test.py | 33 +++++++----- trinity/explorer/rollout_coordinator.py | 71 ++++++++++++++----------- trinity/explorer/scheduler.py | 47 +++++----------- 3 files changed, 73 insertions(+), 78 deletions(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 76eab176d43..a25d7c604d9 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -15,7 +15,7 @@ from trinity.common.experience import EID, Experience from trinity.common.models.model import InferenceModel, ModelWrapper from trinity.common.workflows import WORKFLOWS, Task, Workflow -from trinity.explorer.scheduler import CompletedTaskRef, Scheduler +from trinity.explorer.scheduler import CompletedTaskResult, Scheduler @WORKFLOWS.register_module("dummy_workflow") @@ -1317,24 +1317,33 @@ async def test_dynamic_timeout_warmup_min_steps_uses_completed_steps(self): await scheduler.stop() - async def test_completed_task_events_keep_payload_results_compatible(self): - scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + async def test_completed_task_events_return_full_task_results_directly(self): + scheduler = Scheduler( + self.config, + [DummyModel.remote(), DummyModel.remote()], + emit_completed_task_events=True, + ) await scheduler.start() - scheduler.enable_completed_task_events() scheduler.schedule(generate_tasks(4, repeat_times=2), batch_id=0) - completed_refs = [] + completed_results = [] for _ in range(4): - completed_ref = await scheduler.wait_completed_task(timeout=10) - self.assertIsNotNone(completed_ref) - self.assertIsInstance(completed_ref, CompletedTaskRef) - completed_refs.append(completed_ref) + completed_result = await scheduler.wait_completed_task(timeout=10) + self.assertIsNotNone(completed_result) + self.assertIsInstance(completed_result, CompletedTaskResult) + completed_results.append(completed_result) - self.assertEqual({ref.batch_id for ref in completed_refs}, {0}) - self.assertEqual({ref.task_id for ref in completed_refs}, {0, 1, 2, 3}) + self.assertEqual({result.batch_id for result in completed_results}, {0}) + self.assertEqual({result.task_id for result in completed_results}, {0, 1, 2, 
3}) + self.assertNotIn(0, scheduler.completed_tasks) + + statuses = [result.status for result in completed_results] + exps = [] + for result in completed_results: + for payload in result.experience_payloads: + exps.extend(Experience.deserialize_many(payload)) - statuses, exps = await collect_results(scheduler, batch_id=0, timeout=1) self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 8) diff --git a/trinity/explorer/rollout_coordinator.py b/trinity/explorer/rollout_coordinator.py index 7922ccd0420..44e00d724bd 100644 --- a/trinity/explorer/rollout_coordinator.py +++ b/trinity/explorer/rollout_coordinator.py @@ -54,10 +54,7 @@ class BatchState: state: BatchLifecycleState = BatchLifecycleState.PENDING finalize_reason: Optional[FinalizeReason] = None final_result: Optional[dict] = None - result_future: Optional[asyncio.Future] = None - ready_event: asyncio.Event = field(default_factory=asyncio.Event) finalize_lock: asyncio.Lock = field(default_factory=asyncio.Lock) - first_submit_time: float = 0.0 last_progress_time: float = 0.0 min_threshold_reached_time: Optional[float] = None @@ -67,12 +64,6 @@ def has_partial_results(self) -> bool: return self.completed_task_count > 0 - @property - def has_staged_payloads(self) -> bool: - """Whether this batch still owns unconsumed payload data.""" - - return bool(self.staged_task_ids or self.payload_chunks_by_task) - class RolloutCoordinator: """Own scheduler-side batch state and expose batch-level finalize APIs.""" @@ -92,6 +83,7 @@ def __init__( self.experience_pipeline = None self.scheduler: Optional[Scheduler] = None self.pending_batches: Dict[BatchId, BatchState] = {} + self.terminal_batch_results: Dict[BatchId, dict] = {} self.event_loop_task: Optional[asyncio.Task] = None self.running = False self.detailed_stats = getattr(getattr(config, "monitor", None), "detailed_stats", False) @@ -108,7 +100,6 @@ async def prepare(self) -> None: if self.scheduler is None: self.scheduler = self._init_scheduler() await self.scheduler.start() - self.scheduler.enable_completed_task_events() self.running = True self.event_loop_task = asyncio.create_task(self._completed_task_event_loop()) @@ -144,6 +135,7 @@ def _init_scheduler(self) -> Scheduler: self.config, self.rollout_model, self.auxiliary_models, + emit_completed_task_events=True, ) def _require_scheduler(self) -> Scheduler: @@ -163,6 +155,7 @@ async def submit_batch( ) -> None: """Register a new batch and schedule its tasks.""" + self.terminal_batch_results.pop(batch_id, None) existing_state = self.pending_batches.get(batch_id) if existing_state is not None and existing_state.state not in { BatchLifecycleState.FINALIZED, @@ -170,7 +163,6 @@ async def submit_batch( }: raise ValueError(f"Batch {batch_id} is already active.") - loop = asyncio.get_running_loop() now = time.time() batch_state = BatchState( batch_id=batch_id, @@ -178,9 +170,7 @@ async def submit_batch( expected_task_count=len(tasks), min_wait_num=min_wait_num, allow_partial_finalize=allow_partial_finalize, - first_submit_time=now, last_progress_time=now, - result_future=loop.create_future(), ) self.pending_batches[batch_id] = batch_state @@ -190,7 +180,6 @@ async def submit_batch( else: batch_state.state = BatchLifecycleState.READY_TO_FINALIZE batch_state.finalize_reason = FinalizeReason.COMPLETE - batch_state.ready_event.set() async def finalize_train_batch( self, @@ -200,6 +189,9 @@ async def finalize_train_batch( ) -> dict: """Finalize one train batch and return aggregated metrics.""" + terminal_result = 
self.terminal_batch_results.get(batch_id) + if terminal_result is not None: + return dict(terminal_result) batch_state = self._get_batch_state(batch_id, expected_type="train") return await self._finalize_batch(batch_state, timeout=timeout) @@ -211,6 +203,9 @@ async def finalize_eval_batch( ) -> dict: """Finalize one eval batch and return aggregated eval metrics.""" + terminal_result = self.terminal_batch_results.get(batch_id) + if terminal_result is not None: + return dict(terminal_result) scheduler = self._require_scheduler() batch_state = self._get_batch_state(batch_id, expected_type="eval") async with batch_state.finalize_lock: @@ -241,8 +236,7 @@ async def finalize_eval_batch( batch_state.finalize_reason = reason batch_state.state = BatchLifecycleState.FINALIZED batch_state.final_result = self._build_batch_result(batch_state, reason, {}) - if batch_state.result_future is not None and not batch_state.result_future.done(): - batch_state.result_future.set_result(batch_state.final_result) + self._cache_terminal_batch_result(batch_state) return dict(batch_state.final_result) async def abort_batch( @@ -273,9 +267,7 @@ async def abort_batch( batch_state.state = BatchLifecycleState.ABORTED batch_state.finalize_reason = FinalizeReason.ABORT batch_state.final_result = self._build_batch_result(batch_state, FinalizeReason.ABORT, {}) - batch_state.ready_event.set() - if batch_state.result_future is not None and not batch_state.result_future.done(): - batch_state.result_future.set_result(batch_state.final_result) + self._cache_terminal_batch_result(batch_state) async def _completed_task_event_loop(self) -> None: """Consume task completion events emitted by the scheduler.""" @@ -283,21 +275,16 @@ async def _completed_task_event_loop(self) -> None: scheduler = self._require_scheduler() while self.running: try: - completed_ref = await scheduler.wait_completed_task(timeout=0.1) - if completed_ref is None: + completed_result = await scheduler.wait_completed_task(timeout=0.1) + if completed_result is None: continue - if not isinstance(completed_ref.task_id, int): + if not isinstance(completed_result.task_id, int): self.logger.warning( "Skip completed task event with non-integer task id: %s", - completed_ref.task_id, + completed_result.task_id, ) continue - completed_result = scheduler.pop_completed_task( - completed_ref.batch_id, completed_ref.task_id - ) - if completed_result is None: - continue - batch_state = self.pending_batches.get(completed_ref.batch_id) + batch_state = self.pending_batches.get(completed_result.batch_id) if batch_state is None: continue await self._store_completed_task_result(batch_state, completed_result) @@ -386,7 +373,6 @@ def _maybe_mark_ready(self, batch_state: BatchState) -> Optional[FinalizeReason] }: batch_state.state = BatchLifecycleState.READY_TO_FINALIZE batch_state.finalize_reason = ready_reason - batch_state.ready_event.set() return ready_reason async def _wait_for_ready( @@ -426,6 +412,8 @@ async def _finalize_batch(self, batch_state: BatchState, *, timeout: Optional[fl batch_state.finalize_reason = ready_reason try: pipeline_metrics = await self._finalize_train_payloads(batch_state) + if ready_reason != FinalizeReason.COMPLETE: + await self._cleanup_train_batch_runtime(batch_state) except Exception: batch_state.state = BatchLifecycleState.READY_TO_FINALIZE batch_state.finalize_reason = None @@ -435,10 +423,29 @@ async def _finalize_batch(self, batch_state: BatchState, *, timeout: Optional[fl batch_state.final_result = self._build_batch_result( batch_state, 
ready_reason, pipeline_metrics ) - if batch_state.result_future is not None and not batch_state.result_future.done(): - batch_state.result_future.set_result(batch_state.final_result) + self._cache_terminal_batch_result(batch_state) return dict(batch_state.final_result) + async def _cleanup_train_batch_runtime(self, batch_state: BatchState) -> None: + """Drop unfinished train work after a non-complete finalize result.""" + + scheduler = self._require_scheduler() + await scheduler.abort_batch( + batch_state.batch_id, + return_partial_tasks=False, + restart_runners=True, + ) + if self.experience_pipeline is not None: + await self.experience_pipeline.abort_batch(batch_state.batch_id) + + def _cache_terminal_batch_result(self, batch_state: BatchState) -> None: + """Store one terminal result outside the active batch map for idempotent reuse.""" + + if batch_state.final_result is None: + return + self.terminal_batch_results[batch_state.batch_id] = dict(batch_state.final_result) + self.pending_batches.pop(batch_state.batch_id, None) + async def _finalize_train_payloads(self, batch_state: BatchState) -> dict: """Flush staged train payloads through the experience pipeline.""" diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index b64bff812dd..1d135ef8871 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -41,19 +41,12 @@ class TaskWrapper: class CompletedTaskResult: """A completed task result stored by batch and task id.""" + batch_id: Union[int, str] task_id: Union[int, str] status: Status experience_payloads: List[bytes] = field(default_factory=list) -@dataclass(frozen=True) -class CompletedTaskRef: - """A lightweight reference to a completed task result.""" - - batch_id: Union[int, str] - task_id: Union[int, str] - - # Adapted from verl/trainer/ppo/metric_utils.py def bootstrap_metric( data: list[Any], @@ -320,6 +313,8 @@ def __init__( config: Config, rollout_model: List[InferenceModel], auxiliary_models: Optional[List[List[InferenceModel]]] = None, + *, + emit_completed_task_events: bool = False, ): self.logger = get_logger(__name__) self.config = config @@ -355,8 +350,8 @@ def __init__( ] = defaultdict( dict ) # batch_id -> results - self.completed_task_refs: asyncio.Queue[CompletedTaskRef] = asyncio.Queue() - self.emit_completed_task_events = False + self.completed_task_results: asyncio.Queue[CompletedTaskResult] = asyncio.Queue() + self.emit_completed_task_events = emit_completed_task_events self.background_tasks: set[asyncio.Task] = set() self.scheduler_task: Optional[asyncio.Task] = None @@ -526,44 +521,28 @@ def _emit_task_result(self, task: TaskWrapper) -> None: return status, experience_payloads = self._build_task_result(task) task_id = task.task.task_id - self.completed_tasks[task.batch_id][task_id] = CompletedTaskResult( + completed_result = CompletedTaskResult( + batch_id=task.batch_id, task_id=task_id, status=status, experience_payloads=experience_payloads, ) if self.emit_completed_task_events and not task.task.is_eval: - self.completed_task_refs.put_nowait( - CompletedTaskRef(batch_id=task.batch_id, task_id=task_id) - ) + self.completed_task_results.put_nowait(completed_result) + else: + self.completed_tasks[task.batch_id][task_id] = completed_result task.emitted = True - def enable_completed_task_events(self) -> None: - self.emit_completed_task_events = True - - def disable_completed_task_events(self) -> None: - self.emit_completed_task_events = False - async def wait_completed_task( self, timeout: Optional[float] = None - ) 
-> Optional[CompletedTaskRef]: + ) -> Optional[CompletedTaskResult]: try: if timeout is None: - return await self.completed_task_refs.get() - return await asyncio.wait_for(self.completed_task_refs.get(), timeout=timeout) + return await self.completed_task_results.get() + return await asyncio.wait_for(self.completed_task_results.get(), timeout=timeout) except asyncio.TimeoutError: return None - def pop_completed_task( - self, batch_id: Union[int, str], task_id: int - ) -> Optional[CompletedTaskResult]: - batch_results = self.completed_tasks.get(batch_id) - if not batch_results: - return None - result = batch_results.pop(task_id, None) - if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]: - del self.completed_tasks[batch_id] - return result - def _collect_incomplete_tasks(self, batch_id: Union[int, str]) -> List[TaskWrapper]: tasks = {} for task, _, _ in self.pending_tasks.get(batch_id, deque()): From b07eb29a3a16002be595fbf35094ee26a4708f6b Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 20:54:39 +0800 Subject: [PATCH 09/11] add coordinator test --- tests/explorer/rollout_coordinator_test.py | 327 +++++++++++++++++++++ 1 file changed, 327 insertions(+) create mode 100644 tests/explorer/rollout_coordinator_test.py diff --git a/tests/explorer/rollout_coordinator_test.py b/tests/explorer/rollout_coordinator_test.py new file mode 100644 index 00000000000..64d68463505 --- /dev/null +++ b/tests/explorer/rollout_coordinator_test.py @@ -0,0 +1,327 @@ +"""Unit tests for RolloutCoordinator.""" + +import asyncio +import unittest +from types import SimpleNamespace + +from trinity.explorer.rollout_coordinator import RolloutCoordinator +from trinity.explorer.scheduler import CompletedTaskResult +from trinity.explorer.workflow_runner import Status + + +class FakePipeline: + """Minimal in-memory pipeline double for coordinator tests.""" + + def __init__(self): + """Initialize call tracking state.""" + + self.stage_calls = [] + self.finalize_calls = [] + self.process_chunk_calls = [] + self.abort_calls = [] + self.prepare_called = False + self.close_called = False + + async def prepare(self): + """Record pipeline preparation.""" + + self.prepare_called = True + + async def stage_task_payloads(self, batch_id, task_id, exp_chunks): + """Record task payload staging.""" + + self.stage_calls.append((batch_id, task_id, list(exp_chunks))) + return f"{batch_id}:{task_id}" if exp_chunks else None + + async def finalize_batch(self, batch_id, task_ids=None): + """Record batch finalization.""" + + task_ids = [] if task_ids is None else list(task_ids) + self.finalize_calls.append((batch_id, task_ids)) + return {"experience_pipeline/experience_count": float(len(task_ids))} + + async def process_serialized_chunks(self, exp_chunks): + """Record serialized chunk processing.""" + + chunks = list(exp_chunks) + self.process_chunk_calls.append(chunks) + return {"experience_pipeline/experience_count": float(len(chunks))} + + async def abort_batch(self, batch_id): + """Record batch abort cleanup.""" + + self.abort_calls.append(batch_id) + + async def close(self): + """Record pipeline closure.""" + + self.close_called = True + + +class FakeScheduler: + """Minimal scheduler double for coordinator tests.""" + + def __init__(self): + """Initialize scheduler state and recorded calls.""" + + self.default_timeout = 1.0 + self.started = False + self.stopped = False + self.schedule_calls = [] + self.abort_calls = [] + self.completed_task_results = asyncio.Queue() + self.batch_results = {} + 
self.get_statuses_calls = [] + + async def start(self): + """Mark the scheduler as started.""" + + self.started = True + + async def stop(self): + """Mark the scheduler as stopped.""" + + self.stopped = True + + def schedule(self, tasks, batch_id): + """Record scheduled tasks for one batch.""" + + self.schedule_calls.append((batch_id, list(tasks))) + + async def wait_completed_task(self, timeout=None): + """Return the next queued completed task result.""" + + try: + if timeout is None: + return await self.completed_task_results.get() + return await asyncio.wait_for(self.completed_task_results.get(), timeout=timeout) + except asyncio.TimeoutError: + return None + + async def get_statuses( + self, + batch_id, + min_num=None, + timeout=None, + clear_timeout_tasks=True, + return_partial_tasks=False, + ): + """Return only preconfigured statuses for one batch.""" + + self.get_statuses_calls.append( + { + "batch_id": batch_id, + "min_num": min_num, + "timeout": timeout, + "clear_timeout_tasks": clear_timeout_tasks, + "return_partial_tasks": return_partial_tasks, + } + ) + statuses, _ = self.batch_results.pop(batch_id, ([], [])) + return statuses + + async def abort_batch(self, batch_id, return_partial_tasks=False, restart_runners=True): + """Record one scheduler abort request.""" + + self.abort_calls.append( + { + "batch_id": batch_id, + "return_partial_tasks": return_partial_tasks, + "restart_runners": restart_runners, + } + ) + + def emit_completed_task(self, batch_id, task_id, result): + """Queue one completed task event and result.""" + + self.completed_task_results.put_nowait(result) + + +def _build_config(wait_after_min=0.0, return_partial_tasks=True, detailed_stats=False): + return SimpleNamespace( + explorer=SimpleNamespace( + over_rollout=SimpleNamespace( + wait_after_min=wait_after_min, + return_partial_tasks=return_partial_tasks, + ) + ), + monitor=SimpleNamespace(detailed_stats=detailed_stats), + ) + + +def _build_status(metric_value): + return Status( + completed_runs=1, + total_runs=1, + metrics=[{"run_metrics": float(metric_value)}], + ) + + +class CoordinatorHarness(RolloutCoordinator): + """Coordinator subclass that injects fake owned dependencies.""" + + def __init__(self, config, rollout_model, auxiliary_models=None, *, pipeline, scheduler): + """Store doubles and delegate the main initialization to the parent.""" + + self._test_pipeline = pipeline + self._test_scheduler = scheduler + super().__init__(config, rollout_model, auxiliary_models) + + def _init_experience_pipeline(self): + """Return the injected fake pipeline.""" + + return self._test_pipeline + + def _init_scheduler(self): + """Return the injected fake scheduler.""" + + return self._test_scheduler + + +class TestRolloutCoordinator(unittest.IsolatedAsyncioTestCase): + """Focused behavioral tests for the first coordinator implementation.""" + + async def asyncSetUp(self): + """Create one coordinator wired to fake owned dependencies.""" + + self.scheduler = FakeScheduler() + self.pipeline = FakePipeline() + self.coordinator = CoordinatorHarness( + _build_config(), + rollout_model=[], + pipeline=self.pipeline, + scheduler=self.scheduler, + ) + await self.coordinator.prepare() + + async def asyncTearDown(self): + """Shutdown the coordinator after each test.""" + + await self.coordinator.shutdown() + + async def test_finalize_train_batch_tracks_scheduler_events_and_is_idempotent(self): + """Train finalize should consume task events once and reuse the final result.""" + + await self.coordinator.submit_batch( + batch_id=1, 
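+            # Two train tasks: the batch reaches COMPLETE only after both
+            # completed-task events queued below are consumed.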
+ tasks=[SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + batch_type="train", + ) + + self.scheduler.emit_completed_task( + 1, + 0, + CompletedTaskResult( + batch_id=1, + task_id=0, + status=_build_status(10.0), + experience_payloads=[b"payload-0"], + ), + ) + self.scheduler.emit_completed_task( + 1, + 1, + CompletedTaskResult( + batch_id=1, + task_id=1, + status=_build_status(20.0), + experience_payloads=[b"payload-1"], + ), + ) + + result = await self.coordinator.finalize_train_batch(1, timeout=1.0) + repeated = await self.coordinator.finalize_train_batch(1, timeout=1.0) + + self.assertEqual(result["finalize_reason"], "complete") + self.assertEqual(result["finished_task_count"], 2) + self.assertEqual(result["metrics"]["rollout/run_metrics/mean"], 15.0) + self.assertEqual(result["metrics"]["experience_pipeline/experience_count"], 2.0) + self.assertTrue(self.pipeline.prepare_called) + self.assertEqual(len(self.pipeline.stage_calls), 2) + self.assertEqual(self.pipeline.finalize_calls, [(1, [0, 1])]) + self.assertEqual(result, repeated) + self.assertNotIn(1, self.coordinator.pending_batches) + self.assertEqual( + self.coordinator.terminal_batch_results[1]["finalize_reason"], + "complete", + ) + + async def test_finalize_train_batch_supports_partial_finalize(self): + """Train finalize should allow partial completion when policy permits it.""" + + await self.coordinator.submit_batch( + batch_id=2, + tasks=[SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + batch_type="train", + min_wait_num=1, + allow_partial_finalize=True, + ) + + self.scheduler.emit_completed_task( + 2, + 0, + CompletedTaskResult( + batch_id=2, + task_id=0, + status=_build_status(7.0), + experience_payloads=[b"payload-0"], + ), + ) + + result = await self.coordinator.finalize_train_batch(2, timeout=1.0) + + self.assertEqual(result["finalize_reason"], "partial") + self.assertEqual(result["finished_task_count"], 1) + self.assertEqual(self.pipeline.finalize_calls[-1], (2, [0])) + self.assertEqual(self.scheduler.abort_calls[-1]["batch_id"], 2) + self.assertEqual(self.pipeline.abort_calls[-1], 2) + self.assertNotIn(2, self.coordinator.pending_batches) + + async def test_finalize_eval_batch_aggregates_eval_metrics(self): + """Eval finalize should aggregate scheduler results without pipeline writes.""" + + batch_id = "3/eval_set" + await self.coordinator.submit_batch( + batch_id=batch_id, + tasks=[SimpleNamespace(is_eval=True), SimpleNamespace(is_eval=True)], + batch_type="eval", + ) + self.scheduler.batch_results[batch_id] = ( + [_build_status(3.0), _build_status(5.0)], + [], + ) + + result = await self.coordinator.finalize_eval_batch(batch_id, timeout=1.0) + + self.assertEqual(result["finalize_reason"], "complete") + self.assertEqual(result["finished_task_count"], 2) + self.assertEqual(result["metrics"]["eval/eval_set/run_metrics"], 4.0) + self.assertEqual(self.pipeline.finalize_calls, []) + self.assertEqual(self.scheduler.get_statuses_calls[0]["batch_id"], batch_id) + self.assertNotIn(batch_id, self.coordinator.pending_batches) + + async def test_abort_batch_marks_batch_aborted_and_is_visible_to_finalize(self): + """Abort should short-circuit later finalize calls.""" + + await self.coordinator.submit_batch( + batch_id=4, + tasks=[SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], + batch_type="train", + ) + + await self.coordinator.abort_batch(4, reason="shutdown") + result = await self.coordinator.finalize_train_batch(4, timeout=0.1) + + self.assertEqual(result["finalize_reason"], 
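The fakes and tests above define the coordinator's driver-facing contract: submit a batch, feed it completion events, then finalize. As a rough end-to-end sketch of that flow (names are taken from this test file; the batch id, metric value, and payload bytes are illustrative only, not part of the patch):

    # Hypothetical driver-side walkthrough of the submit/finalize lifecycle,
    # reusing the doubles defined in the test file above.
    scheduler = FakeScheduler()
    pipeline = FakePipeline()
    coordinator = CoordinatorHarness(
        _build_config(),
        rollout_model=[],
        pipeline=pipeline,
        scheduler=scheduler,
    )
    await coordinator.prepare()
    await coordinator.submit_batch(
        batch_id=9,
        tasks=[SimpleNamespace(is_eval=False)],
        batch_type="train",
    )
    # The scheduler (here the fake) reports each finished task as an event;
    # the coordinator's event loop stages its payloads into the pipeline.
    scheduler.emit_completed_task(
        9,
        0,
        CompletedTaskResult(
            batch_id=9,
            task_id=0,
            status=_build_status(1.0),
            experience_payloads=[b"payload"],
        ),
    )
    result = await coordinator.finalize_train_batch(9, timeout=5.0)
    assert result["finalize_reason"] == "complete"
    await coordinator.shutdown()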
"abort") + self.assertFalse(result["finalized"]) + self.assertEqual(self.scheduler.abort_calls[0]["batch_id"], 4) + self.assertEqual(self.pipeline.abort_calls, [4]) + self.assertNotIn(4, self.coordinator.pending_batches) + + async def test_shutdown_closes_internal_dependencies(self): + """Shutdown should close both owned scheduler and owned pipeline.""" + + await self.coordinator.shutdown() + + self.assertTrue(self.scheduler.stopped) + self.assertTrue(self.pipeline.close_called) From 353632825c069212cc9ea77d2107a33d087d4360 Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 21:27:53 +0800 Subject: [PATCH 10/11] simplify state --- tests/explorer/explorer_test.py | 2 - tests/explorer/rollout_coordinator_test.py | 1 - trinity/explorer/explorer.py | 27 +++--------- trinity/explorer/rollout_coordinator.py | 48 +++++++++------------- 4 files changed, 24 insertions(+), 54 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 85128cab8bb..da170517d63 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -339,7 +339,6 @@ async def test_explore_step_submits_train_batch_to_rollout_coordinator(self): "tasks": [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], "batch_type": "train", "min_wait_num": None, - "allow_partial_finalize": False, } ], ) @@ -384,7 +383,6 @@ async def test_over_rollout_submits_partial_finalize_policy_to_rollout_coordinat "tasks": [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], "batch_type": "train", "min_wait_num": 1, - "allow_partial_finalize": True, } ], ) diff --git a/tests/explorer/rollout_coordinator_test.py b/tests/explorer/rollout_coordinator_test.py index 64d68463505..aa04c50f6cf 100644 --- a/tests/explorer/rollout_coordinator_test.py +++ b/tests/explorer/rollout_coordinator_test.py @@ -254,7 +254,6 @@ async def test_finalize_train_batch_supports_partial_finalize(self): tasks=[SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)], batch_type="train", min_wait_num=1, - allow_partial_finalize=True, ) self.scheduler.emit_completed_task( diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 0633ed565eb..24d62c0e027 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -12,7 +12,6 @@ import ray import torch -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from trinity.buffer.buffer import get_buffer_reader from trinity.buffer.task_scheduler import get_taskset_scheduler @@ -199,7 +198,11 @@ async def prepare(self) -> None: await self.setup_weight_sync_group(master_address, master_port) if self.config.mode != "serve": - self.rollout_coordinator = self._init_rollout_coordinator() + self.rollout_coordinator = RolloutCoordinator.get_actor( + self.config, + self.models, + self.auxiliary_models, + ) await self.rollout_coordinator.prepare.remote() self.logger.info("Rollout coordinator is ready.") if self.config.explorer.eval_on_startup and self.explore_step_num == 0: @@ -268,7 +271,6 @@ async def explore_step(self) -> bool: tasks=tasks, batch_type="train", min_wait_num=self.min_wait_num, - allow_partial_finalize=self.min_wait_num is not None, ) return True @@ -451,25 +453,6 @@ async def is_alive(self) -> bool: """Check if the explorer is alive.""" return True - def _init_rollout_coordinator(self) -> ray.actor.ActorHandle: - """Init rollout coordinator for the task-event-completion path.""" - node_id = ray.get_runtime_context().get_node_id() - return ( - ray.remote(RolloutCoordinator) 
- .options( - namespace=self.config.ray_namespace, - scheduling_strategy=NodeAffinitySchedulingStrategy( - node_id=node_id, - soft=False, - ), - ) - .remote( - self.config, - self.models, - self.auxiliary_models, - ) - ) - @Experimental async def serve(self) -> None: """Run the explorer in serving mode. diff --git a/trinity/explorer/rollout_coordinator.py b/trinity/explorer/rollout_coordinator.py index 44e00d724bd..73c73692f06 100644 --- a/trinity/explorer/rollout_coordinator.py +++ b/trinity/explorer/rollout_coordinator.py @@ -6,6 +6,8 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union +import ray + from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline from trinity.common.config import Config from trinity.common.models import InferenceModel @@ -48,14 +50,10 @@ class BatchState: completed_task_count: int = 0 statuses: Dict[Union[int, str], Any] = field(default_factory=dict) staged_task_ids: set[int] = field(default_factory=set) - payload_chunks_by_task: Dict[Union[int, str], List[bytes]] = field(default_factory=dict) min_wait_num: Optional[int] = None - allow_partial_finalize: bool = False state: BatchLifecycleState = BatchLifecycleState.PENDING - finalize_reason: Optional[FinalizeReason] = None final_result: Optional[dict] = None finalize_lock: asyncio.Lock = field(default_factory=asyncio.Lock) - last_progress_time: float = 0.0 min_threshold_reached_time: Optional[float] = None @property @@ -151,7 +149,6 @@ async def submit_batch( tasks: list[Task], batch_type: BatchType, min_wait_num: Optional[int] = None, - allow_partial_finalize: bool = False, ) -> None: """Register a new batch and schedule its tasks.""" @@ -163,14 +160,11 @@ async def submit_batch( }: raise ValueError(f"Batch {batch_id} is already active.") - now = time.time() batch_state = BatchState( batch_id=batch_id, batch_type=batch_type, expected_task_count=len(tasks), min_wait_num=min_wait_num, - allow_partial_finalize=allow_partial_finalize, - last_progress_time=now, ) self.pending_batches[batch_id] = batch_state @@ -179,7 +173,6 @@ async def submit_batch( batch_state.state = BatchLifecycleState.RUNNING else: batch_state.state = BatchLifecycleState.READY_TO_FINALIZE - batch_state.finalize_reason = FinalizeReason.COMPLETE async def finalize_train_batch( self, @@ -227,13 +220,11 @@ async def finalize_eval_batch( continue batch_state.statuses[task_id] = status batch_state.completed_task_count += 1 - batch_state.last_progress_time = time.time() reason = ( FinalizeReason.COMPLETE if batch_state.completed_task_count >= batch_state.expected_task_count else FinalizeReason.TIMEOUT ) - batch_state.finalize_reason = reason batch_state.state = BatchLifecycleState.FINALIZED batch_state.final_result = self._build_batch_result(batch_state, reason, {}) self._cache_terminal_batch_result(batch_state) @@ -265,10 +256,24 @@ async def abort_batch( await self.experience_pipeline.abort_batch(batch_id) batch_state.state = BatchLifecycleState.ABORTED - batch_state.finalize_reason = FinalizeReason.ABORT batch_state.final_result = self._build_batch_result(batch_state, FinalizeReason.ABORT, {}) self._cache_terminal_batch_result(batch_state) + @classmethod + def get_actor( + cls, config: Config, models: List, auxiliary_models: List + ) -> ray.actor.ActorHandle: + """Init rollout coordinator for the task-event-completion path.""" + return ( + ray.remote(RolloutCoordinator) + .options(namespace=config.ray_namespace) + .remote( + config, + models, + auxiliary_models, + ) + ) + async def 
_completed_task_event_loop(self) -> None: """Consume task completion events emitted by the scheduler.""" @@ -301,7 +306,6 @@ async def _store_completed_task_result( return batch_state.statuses[result.task_id] = result.status batch_state.completed_task_count += 1 - batch_state.last_progress_time = time.time() if batch_state.batch_type != "train": return @@ -321,10 +325,6 @@ async def _store_completed_task_result( batch_state.staged_task_ids.add(int(result.task_id)) return - if result.experience_payloads: - batch_state.payload_chunks_by_task[result.task_id] = list(result.experience_payloads) - return - def _get_batch_state(self, batch_id: BatchId, *, expected_type: BatchType) -> BatchState: """Return one registered batch and validate its type.""" @@ -346,7 +346,7 @@ def _get_ready_reason(self, batch_state: BatchState) -> Optional[FinalizeReason] return FinalizeReason.COMPLETE if batch_state.batch_type != "train": return None - if not batch_state.allow_partial_finalize or batch_state.min_wait_num is None: + if batch_state.min_wait_num is None: return None if batch_state.completed_task_count < batch_state.min_wait_num: batch_state.min_threshold_reached_time = None @@ -372,7 +372,6 @@ def _maybe_mark_ready(self, batch_state: BatchState) -> Optional[FinalizeReason] BatchLifecycleState.FINALIZING, }: batch_state.state = BatchLifecycleState.READY_TO_FINALIZE - batch_state.finalize_reason = ready_reason return ready_reason async def _wait_for_ready( @@ -403,20 +402,18 @@ async def _finalize_batch(self, batch_state: BatchState, *, timeout: Optional[fl ready_reason = await self._wait_for_ready(batch_state, timeout) if ready_reason is None: - if batch_state.allow_partial_finalize and batch_state.has_partial_results: + if batch_state.min_wait_num is not None and batch_state.has_partial_results: ready_reason = FinalizeReason.TIMEOUT else: raise TimeoutError(f"Timeout waiting for batch {batch_state.batch_id}.") batch_state.state = BatchLifecycleState.FINALIZING - batch_state.finalize_reason = ready_reason try: pipeline_metrics = await self._finalize_train_payloads(batch_state) if ready_reason != FinalizeReason.COMPLETE: await self._cleanup_train_batch_runtime(batch_state) except Exception: batch_state.state = BatchLifecycleState.READY_TO_FINALIZE - batch_state.finalize_reason = None raise batch_state.state = BatchLifecycleState.FINALIZED @@ -454,13 +451,6 @@ async def _finalize_train_payloads(self, batch_state: BatchState) -> dict: batch_state.batch_id, task_ids=sorted(batch_state.staged_task_ids), ) - if self.experience_pipeline is not None and batch_state.payload_chunks_by_task: - exp_chunks = [] - for task_id in sorted(batch_state.payload_chunks_by_task): - exp_chunks.extend(batch_state.payload_chunks_by_task[task_id]) - return await self.experience_pipeline.process_serialized_chunks( - exp_chunks, - ) return {} def _build_batch_result( From ff0e00a8c3ab91d3e06ce4dcae939df1630a0496 Mon Sep 17 00:00:00 2001 From: pxc Date: Sat, 25 Apr 2026 22:16:38 +0800 Subject: [PATCH 11/11] simplify --- tests/explorer/rollout_coordinator_test.py | 97 +++++++++--- trinity/explorer/rollout_coordinator.py | 166 ++++++++------------- 2 files changed, 142 insertions(+), 121 deletions(-) diff --git a/tests/explorer/rollout_coordinator_test.py b/tests/explorer/rollout_coordinator_test.py index aa04c50f6cf..3afd8dbc8b2 100644 --- a/tests/explorer/rollout_coordinator_test.py +++ b/tests/explorer/rollout_coordinator_test.py @@ -2,6 +2,7 @@ import asyncio import unittest +from collections import defaultdict from types import 
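Patch 10 also moves actor construction from the explorer's private _init_rollout_coordinator helper into RolloutCoordinator.get_actor, so there is a single place that builds the actor. A minimal sketch of the new call site, mirroring the prepare() hunk in trinity/explorer/explorer.py above (config, models, and auxiliary_models are assumed to already be in scope):

    # Create the coordinator actor in the configured Ray namespace and wait
    # for its owned scheduler and experience pipeline to come up.
    coordinator = RolloutCoordinator.get_actor(config, models, auxiliary_models)
    await coordinator.prepare.remote()

Note that get_actor drops the NodeAffinitySchedulingStrategy pinning used by the removed helper, which leaves Ray free to place the coordinator on any node.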
From ff0e00a8c3ab91d3e06ce4dcae939df1630a0496 Mon Sep 17 00:00:00 2001
From: pxc
Date: Sat, 25 Apr 2026 22:16:38 +0800
Subject: [PATCH 11/11] simplify

---
 tests/explorer/rollout_coordinator_test.py |  97 +++++++++---
 trinity/explorer/rollout_coordinator.py    | 166 ++++++++-------------
 2 files changed, 142 insertions(+), 121 deletions(-)

diff --git a/tests/explorer/rollout_coordinator_test.py b/tests/explorer/rollout_coordinator_test.py
index aa04c50f6cf..3afd8dbc8b2 100644
--- a/tests/explorer/rollout_coordinator_test.py
+++ b/tests/explorer/rollout_coordinator_test.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import unittest
+from collections import defaultdict
 from types import SimpleNamespace
 
 from trinity.explorer.rollout_coordinator import RolloutCoordinator
@@ -21,6 +22,7 @@ def __init__(self):
         self.abort_calls = []
         self.prepare_called = False
         self.close_called = False
+        self.staged_payloads = defaultdict(dict)
 
     async def prepare(self):
         """Record pipeline preparation."""
@@ -30,15 +32,25 @@ async def prepare(self):
     async def stage_task_payloads(self, batch_id, task_id, exp_chunks):
         """Record task payload staging."""
 
-        self.stage_calls.append((batch_id, task_id, list(exp_chunks)))
-        return f"{batch_id}:{task_id}" if exp_chunks else None
+        chunks = [chunk for chunk in exp_chunks if chunk]
+        self.stage_calls.append((batch_id, task_id, chunks))
+        if not chunks:
+            return None
+        self.staged_payloads[batch_id][task_id] = chunks
+        return f"{batch_id}:{task_id}"
 
     async def finalize_batch(self, batch_id, task_ids=None):
         """Record batch finalization."""
 
-        task_ids = [] if task_ids is None else list(task_ids)
-        self.finalize_calls.append((batch_id, task_ids))
-        return {"experience_pipeline/experience_count": float(len(task_ids))}
+        staged = self.staged_payloads.get(batch_id, {})
+        selected_task_ids = sorted(staged) if task_ids is None else list(task_ids)
+        self.finalize_calls.append((batch_id, selected_task_ids))
+        exp_chunks = []
+        for task_id in selected_task_ids:
+            exp_chunks.extend(staged.pop(task_id, []))
+        if batch_id in self.staged_payloads and not self.staged_payloads[batch_id]:
+            del self.staged_payloads[batch_id]
+        return {"experience_pipeline/experience_count": float(len(exp_chunks))}
 
     async def process_serialized_chunks(self, exp_chunks):
         """Record serialized chunk processing."""
@@ -51,6 +63,7 @@ async def abort_batch(self, batch_id):
         """Record batch abort cleanup."""
 
         self.abort_calls.append(batch_id)
+        self.staged_payloads.pop(batch_id, None)
 
     async def close(self):
         """Record pipeline closure."""
@@ -230,8 +243,6 @@ async def test_finalize_train_batch_tracks_scheduler_events_and_is_idempotent(se
         result = await self.coordinator.finalize_train_batch(1, timeout=1.0)
-        repeated = await self.coordinator.finalize_train_batch(1, timeout=1.0)
-
         self.assertEqual(result["finalize_reason"], "complete")
         self.assertEqual(result["finished_task_count"], 2)
         self.assertEqual(result["metrics"]["rollout/run_metrics/mean"], 15.0)
         self.assertEqual(result["metrics"]["experience_pipeline/experience_count"], 2.0)
@@ -239,12 +250,10 @@ async def test_finalize_train_batch_tracks_scheduler_events_and_is_idempotent(se
         self.assertTrue(self.pipeline.prepare_called)
         self.assertEqual(len(self.pipeline.stage_calls), 2)
         self.assertEqual(self.pipeline.finalize_calls, [(1, [0, 1])])
-        self.assertEqual(result, repeated)
         self.assertNotIn(1, self.coordinator.pending_batches)
-        self.assertEqual(
-            self.coordinator.terminal_batch_results[1]["finalize_reason"],
-            "complete",
-        )
+
+        with self.assertRaisesRegex(KeyError, "not registered"):
+            await self.coordinator.finalize_train_batch(1, timeout=1.0)
 
     async def test_finalize_train_batch_supports_partial_finalize(self):
         """Train finalize should allow partial completion when policy permits it."""
@@ -299,8 +308,62 @@ async def test_finalize_eval_batch_aggregates_eval_metrics(self):
         self.assertEqual(self.scheduler.get_statuses_calls[0]["batch_id"], batch_id)
         self.assertNotIn(batch_id, self.coordinator.pending_batches)
 
-    async def test_abort_batch_marks_batch_aborted_and_is_visible_to_finalize(self):
-        """Abort should short-circuit later finalize calls."""
+    async def test_finalize_train_batch_rejects_eval_batches_before_waiting(self):
+        """Train finalize should reject an active eval batch instead of entering wait logic."""
+
+        batch_id = "4/eval_set"
+        await self.coordinator.submit_batch(
+            batch_id=batch_id,
+            tasks=[SimpleNamespace(is_eval=True)],
+            batch_type="eval",
+        )
+
+        with self.assertRaisesRegex(ValueError, "expected train"):
+            await self.coordinator.finalize_train_batch(batch_id, timeout=0.1)
+
+    async def test_terminal_batches_are_not_reusable_after_finalize(self):
+        """A finalized batch should be evicted instead of being cached for later reuse."""
+
+        eval_batch_id = "5/eval_set"
+        await self.coordinator.submit_batch(
+            batch_id=eval_batch_id,
+            tasks=[SimpleNamespace(is_eval=True)],
+            batch_type="eval",
+        )
+        self.scheduler.batch_results[eval_batch_id] = ([_build_status(3.0)], [])
+        await self.coordinator.finalize_eval_batch(eval_batch_id, timeout=1.0)
+
+        with self.assertRaisesRegex(KeyError, "not registered"):
+            await self.coordinator.finalize_train_batch(eval_batch_id, timeout=0.1)
+
+        with self.assertRaisesRegex(KeyError, "not registered"):
+            await self.coordinator.finalize_eval_batch(eval_batch_id, timeout=0.1)
+
+        await self.coordinator.submit_batch(
+            batch_id=6,
+            tasks=[SimpleNamespace(is_eval=False)],
+            batch_type="train",
+        )
+        self.scheduler.emit_completed_task(
+            6,
+            0,
+            CompletedTaskResult(
+                batch_id=6,
+                task_id=0,
+                status=_build_status(11.0),
+                experience_payloads=[b"payload-0"],
+            ),
+        )
+        await self.coordinator.finalize_train_batch(6, timeout=1.0)
+
+        with self.assertRaisesRegex(KeyError, "not registered"):
+            await self.coordinator.finalize_eval_batch(6, timeout=0.1)
+
+        with self.assertRaisesRegex(KeyError, "not registered"):
+            await self.coordinator.finalize_train_batch(6, timeout=0.1)
+
+    async def test_abort_batch_marks_batch_aborted_and_evicts_it(self):
+        """Abort should clean up the batch immediately instead of caching a terminal result."""
 
         await self.coordinator.submit_batch(
             batch_id=4,
@@ -309,14 +372,14 @@ async def test_abort_batch_marks_batch_aborted_and_is_visible_to_finalize(self):
         )
 
         await self.coordinator.abort_batch(4, reason="shutdown")
-        result = await self.coordinator.finalize_train_batch(4, timeout=0.1)
 
-        self.assertEqual(result["finalize_reason"], "abort")
-        self.assertFalse(result["finalized"])
         self.assertEqual(self.scheduler.abort_calls[0]["batch_id"], 4)
         self.assertEqual(self.pipeline.abort_calls, [4])
         self.assertNotIn(4, self.coordinator.pending_batches)
 
+        with self.assertRaisesRegex(KeyError, "not registered"):
+            await self.coordinator.finalize_train_batch(4, timeout=0.1)
+
     async def test_shutdown_closes_internal_dependencies(self):
         """Shutdown should close both owned scheduler and owned pipeline."""
diff --git a/trinity/explorer/rollout_coordinator.py b/trinity/explorer/rollout_coordinator.py
index 73c73692f06..8e8773badaa 100644
--- a/trinity/explorer/rollout_coordinator.py
+++ b/trinity/explorer/rollout_coordinator.py
@@ -25,7 +25,6 @@ class BatchLifecycleState(str, Enum):
 
     PENDING = "pending"
     RUNNING = "running"
-    READY_TO_FINALIZE = "ready_to_finalize"
     FINALIZING = "finalizing"
     FINALIZED = "finalized"
     ABORTED = "aborted"
@@ -47,9 +46,7 @@ class BatchState:
     batch_id: BatchId
     batch_type: BatchType
     expected_task_count: int
-    completed_task_count: int = 0
     statuses: Dict[Union[int, str], Any] = field(default_factory=dict)
-    staged_task_ids: set[int] = field(default_factory=set)
     min_wait_num: Optional[int] = None
     state: BatchLifecycleState = BatchLifecycleState.PENDING
     final_result: Optional[dict] = None
@@ -57,10 +54,10 @@ class BatchState:
     min_threshold_reached_time: Optional[float] = None
 
     @property
-    def has_partial_results(self) -> bool:
-        """Whether at least one task has completed."""
+    def completed_task_count(self) -> int:
+        """Return the number of completed tasks tracked by status."""
 
-        return self.completed_task_count > 0
+        return len(self.statuses)
 
 
 class RolloutCoordinator:
@@ -73,7 +70,6 @@ def __init__(
         auxiliary_models: Optional[List[List[InferenceModel]]] = None,
     ):
         """Create a coordinator with internally managed scheduler and pipeline."""
-
         self.logger = get_logger("rollout_coordinator", in_ray_actor=True)
         self.config = config
         self.rollout_model = rollout_model
@@ -81,14 +77,12 @@ def __init__(
         self.experience_pipeline = None
         self.scheduler: Optional[Scheduler] = None
         self.pending_batches: Dict[BatchId, BatchState] = {}
-        self.terminal_batch_results: Dict[BatchId, dict] = {}
         self.event_loop_task: Optional[asyncio.Task] = None
         self.running = False
         self.detailed_stats = getattr(getattr(config, "monitor", None), "detailed_stats", False)
 
     async def prepare(self) -> None:
         """Initialize the owned pipeline and scheduler."""
-
         if self.running:
             return
         if self.experience_pipeline is None:
@@ -103,7 +97,6 @@ async def prepare(self) -> None:
 
     async def shutdown(self) -> None:
         """Stop background work and close owned dependencies."""
-
         self.running = False
         if self.event_loop_task is not None:
             self.event_loop_task.cancel()
@@ -121,14 +114,12 @@ async def shutdown(self) -> None:
 
     def _init_experience_pipeline(self):
         """Create the experience pipeline owned by this coordinator actor."""
-
         if self.config.mode == "bench":
             return None
         return ExperiencePipeline(self.config)
 
     def _init_scheduler(self) -> Scheduler:
         """Create the scheduler owned by this coordinator."""
-
         return Scheduler(
             self.config,
             self.rollout_model,
@@ -138,7 +129,6 @@ def _init_scheduler(self) -> Scheduler:
 
     def _require_scheduler(self) -> Scheduler:
         """Return the initialized scheduler."""
-
         assert self.scheduler is not None, "RolloutCoordinator.prepare() must be called first."
         return self.scheduler
 
@@ -151,8 +141,6 @@ async def submit_batch(
         min_wait_num: Optional[int] = None,
     ) -> None:
         """Register a new batch and schedule its tasks."""
-
-        self.terminal_batch_results.pop(batch_id, None)
         existing_state = self.pending_batches.get(batch_id)
         if existing_state is not None and existing_state.state not in {
             BatchLifecycleState.FINALIZED,
@@ -171,8 +159,6 @@ async def submit_batch(
         if tasks:
             self._require_scheduler().schedule(tasks, batch_id=batch_id)
             batch_state.state = BatchLifecycleState.RUNNING
-        else:
-            batch_state.state = BatchLifecycleState.READY_TO_FINALIZE
 
     async def finalize_train_batch(
         self,
@@ -181,12 +167,8 @@ async def finalize_train_batch(
         timeout: Optional[float] = None,
     ) -> dict:
         """Finalize one train batch and return aggregated metrics."""
-
-        terminal_result = self.terminal_batch_results.get(batch_id)
-        if terminal_result is not None:
-            return dict(terminal_result)
         batch_state = self._get_batch_state(batch_id, expected_type="train")
-        return await self._finalize_batch(batch_state, timeout=timeout)
+        return await self._finalize_train_batch(batch_state, timeout=timeout)
 
     async def finalize_eval_batch(
         self,
@@ -195,40 +177,35 @@ async def finalize_eval_batch(
         timeout: Optional[float] = None,
     ) -> dict:
         """Finalize one eval batch and return aggregated eval metrics."""
+        batch_state = self._get_batch_state(batch_id, expected_type="eval")
+        return await self._finalize_eval_batch(batch_state, timeout=timeout)
+
+    async def _finalize_eval_batch(
+        self, batch_state: BatchState, *, timeout: Optional[float]
+    ) -> dict:
+        """Finalize one eval batch."""
 
-        terminal_result = self.terminal_batch_results.get(batch_id)
-        if terminal_result is not None:
-            return dict(terminal_result)
         scheduler = self._require_scheduler()
-        batch_state = self._get_batch_state(batch_id, expected_type="eval")
         async with batch_state.finalize_lock:
-            if batch_state.final_result is not None:
-                return dict(batch_state.final_result)
-            if batch_state.state == BatchLifecycleState.ABORTED:
-                batch_state.final_result = self._build_batch_result(
-                    batch_state, FinalizeReason.ABORT, {}
-                )
-                return dict(batch_state.final_result)
+            existing_result = self._get_existing_final_result(batch_state)
+            if existing_result is not None:
+                return existing_result
             statuses = await scheduler.get_statuses(
-                batch_id=batch_id,
+                batch_id=batch_state.batch_id,
                 timeout=timeout,
-                return_partial_tasks=self._return_partial_tasks(),
+                return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks,
             )
             for task_id, status in enumerate(statuses):
                 if task_id in batch_state.statuses:
                     continue
                 batch_state.statuses[task_id] = status
-                batch_state.completed_task_count += 1
             reason = (
                 FinalizeReason.COMPLETE
                 if batch_state.completed_task_count >= batch_state.expected_task_count
                 else FinalizeReason.TIMEOUT
            )
-            batch_state.state = BatchLifecycleState.FINALIZED
-            batch_state.final_result = self._build_batch_result(batch_state, reason, {})
-            self._cache_terminal_batch_result(batch_state)
-            return dict(batch_state.final_result)
+            return self._finish_batch(batch_state, reason, {})
 
     async def abort_batch(
         self,
@@ -238,7 +215,6 @@ async def abort_batch(
         keep_partial_results: bool = False,
     ) -> None:
         """Abort one batch and cleanup its running and staged state."""
-
         scheduler = self._require_scheduler()
         batch_state = self.pending_batches.get(batch_id)
         if batch_state is None:
@@ -257,7 +233,7 @@ async def abort_batch(
 
         batch_state.state = BatchLifecycleState.ABORTED
         batch_state.final_result = self._build_batch_result(batch_state, FinalizeReason.ABORT, {})
-        self._cache_terminal_batch_result(batch_state)
+        self.pending_batches.pop(batch_id, None)
 
     @classmethod
     def get_actor(
@@ -276,22 +252,21 @@ def get_actor(
 
     async def _completed_task_event_loop(self) -> None:
         """Consume task completion events emitted by the scheduler."""
-
         scheduler = self._require_scheduler()
         while self.running:
             try:
                 completed_result = await scheduler.wait_completed_task(timeout=0.1)
                 if completed_result is None:
                     continue
                 if not isinstance(completed_result.task_id, int):
                     self.logger.warning(
                         "Skip completed task event with non-integer task id: %s",
                         completed_result.task_id,
                     )
                     continue
                 batch_state = self.pending_batches.get(completed_result.batch_id)
                 if batch_state is None:
                     continue
                 await self._store_completed_task_result(batch_state, completed_result)
                 self._maybe_mark_ready(batch_state)
             except Exception:  # noqa: BLE001
@@ -301,33 +276,24 @@ async def _store_completed_task_result(
         self, batch_state: BatchState, result: CompletedTaskResult
     ) -> None:
         """Persist one completed task into batch-level aggregation state."""
-
         if result.task_id in batch_state.statuses:
             return
         batch_state.statuses[result.task_id] = result.status
-        batch_state.completed_task_count += 1
 
         if batch_state.batch_type != "train":
             return
 
         if self.experience_pipeline is not None and result.experience_payloads:
             staged_task_id = int(result.task_id)
-            staged = await self.experience_pipeline.stage_task_payloads(
+            await self.experience_pipeline.stage_task_payloads(
                 batch_state.batch_id,
                 staged_task_id,
                 result.experience_payloads,
             )
-            if staged is not None:
-                batch_state.staged_task_ids.add(staged_task_id)
-            return
-
-        if self.experience_pipeline is not None and result.status.completed_runs > 0:
-            batch_state.staged_task_ids.add(int(result.task_id))
             return
 
     def _get_batch_state(self, batch_id: BatchId, *, expected_type: BatchType) -> BatchState:
         """Return one registered batch and validate its type."""
-
         batch_state = self.pending_batches.get(batch_id)
         if batch_state is None:
             raise KeyError(f"Batch {batch_id} is not registered.")
@@ -337,9 +303,18 @@ def _get_batch_state(self, batch_id: BatchId, *, expected_type: BatchType) -> Ba
         )
         return batch_state
 
+    def _get_existing_final_result(self, batch_state: BatchState) -> Optional[dict]:
+        """Reuse an in-flight final result or synthesize an abort result."""
+
+        if batch_state.final_result is not None:
+            return dict(batch_state.final_result)
+        if batch_state.state != BatchLifecycleState.ABORTED:
+            return None
+        batch_state.final_result = self._build_batch_result(batch_state, FinalizeReason.ABORT, {})
+        return dict(batch_state.final_result)
+
     def _get_ready_reason(self, batch_state: BatchState) -> Optional[FinalizeReason]:
         """Check whether a batch is ready to finalize and why."""
-
         if batch_state.state == BatchLifecycleState.ABORTED:
             return FinalizeReason.ABORT
         if batch_state.completed_task_count >= batch_state.expected_task_count:
@@ -362,23 +337,15 @@ def _get_ready_reason(self, batch_state: BatchState) -> Optional[FinalizeReason]
 
     def _maybe_mark_ready(self, batch_state: BatchState) -> Optional[FinalizeReason]:
-        """Transition a batch to ready state when finalize conditions are met."""
-
+        """Return the reason a batch is ready to finalize, if any."""
         ready_reason = self._get_ready_reason(batch_state)
         if ready_reason is None:
             return None
-        if batch_state.state not in {
-            BatchLifecycleState.FINALIZED,
-            BatchLifecycleState.ABORTED,
-            BatchLifecycleState.FINALIZING,
-        }:
-            batch_state.state = BatchLifecycleState.READY_TO_FINALIZE
         return ready_reason
 
     async def _wait_for_ready(
         self, batch_state: BatchState, timeout: Optional[float]
     ) -> Optional[FinalizeReason]:
         """Wait until a batch is ready to finalize or the timeout expires."""
-
         start_time = time.time()
         while True:
             ready_reason = self._maybe_mark_ready(batch_state)
@@ -388,21 +355,18 @@ async def _wait_for_ready(
                 return None
             await asyncio.sleep(0.05)
 
-    async def _finalize_batch(self, batch_state: BatchState, *, timeout: Optional[float]) -> dict:
-        """Finalize one train batch with idempotent result reuse."""
-
+    async def _finalize_train_batch(
+        self, batch_state: BatchState, *, timeout: Optional[float]
+    ) -> dict:
+        """Finalize one train batch."""
         async with batch_state.finalize_lock:
-            if batch_state.final_result is not None:
-                return dict(batch_state.final_result)
-            if batch_state.state == BatchLifecycleState.ABORTED:
-                batch_state.final_result = self._build_batch_result(
-                    batch_state, FinalizeReason.ABORT, {}
-                )
-                return dict(batch_state.final_result)
+            existing_result = self._get_existing_final_result(batch_state)
+            if existing_result is not None:
+                return existing_result
 
             ready_reason = await self._wait_for_ready(batch_state, timeout)
             if ready_reason is None:
-                if batch_state.min_wait_num is not None and batch_state.has_partial_results:
+                if batch_state.min_wait_num is not None and batch_state.statuses:
                     ready_reason = FinalizeReason.TIMEOUT
                 else:
                     raise TimeoutError(f"Timeout waiting for batch {batch_state.batch_id}.")
@@ -413,19 +377,31 @@ async def _finalize_batch(self, batch_state: BatchState, *, timeout: Optional[fl
                 if ready_reason != FinalizeReason.COMPLETE:
                     await self._cleanup_train_batch_runtime(batch_state)
             except Exception:
-                batch_state.state = BatchLifecycleState.READY_TO_FINALIZE
+                batch_state.state = self._get_active_batch_state(batch_state)
                 raise
 
-            batch_state.state = BatchLifecycleState.FINALIZED
-            batch_state.final_result = self._build_batch_result(
-                batch_state, ready_reason, pipeline_metrics
-            )
-            self._cache_terminal_batch_result(batch_state)
-            return dict(batch_state.final_result)
+            return self._finish_batch(batch_state, ready_reason, pipeline_metrics)
+
+    def _finish_batch(
+        self,
+        batch_state: BatchState,
+        reason: FinalizeReason,
+        pipeline_metrics: dict,
+    ) -> dict:
+        """Persist one terminal result and evict the batch from active state."""
+        batch_state.state = BatchLifecycleState.FINALIZED
+        batch_state.final_result = self._build_batch_result(batch_state, reason, pipeline_metrics)
+        self.pending_batches.pop(batch_state.batch_id, None)
+        return dict(batch_state.final_result)
+
+    def _get_active_batch_state(self, batch_state: BatchState) -> BatchLifecycleState:
+        """Return the active lifecycle state to restore after a failed finalize attempt."""
+        if batch_state.expected_task_count == 0:
+            return BatchLifecycleState.PENDING
+        return BatchLifecycleState.RUNNING
 
     async def _cleanup_train_batch_runtime(self, batch_state: BatchState) -> None:
         """Drop unfinished train work after a non-complete finalize result."""
-
         scheduler = self._require_scheduler()
         await scheduler.abort_batch(
             batch_state.batch_id,
@@ -435,22 +411,10 @@ async def _cleanup_train_batch_runtime(self, batch_state: BatchState) -> None:
         if self.experience_pipeline is not None:
             await self.experience_pipeline.abort_batch(batch_state.batch_id)
 
-    def _cache_terminal_batch_result(self, batch_state: BatchState) -> None:
-        """Store one terminal result outside the active batch map for idempotent reuse."""
-
-        if batch_state.final_result is None:
-            return
-        self.terminal_batch_results[batch_state.batch_id] = dict(batch_state.final_result)
-        self.pending_batches.pop(batch_state.batch_id, None)
-
     async def _finalize_train_payloads(self, batch_state: BatchState) -> dict:
         """Flush staged train payloads through the experience pipeline."""
-
-        if self.experience_pipeline is not None and batch_state.staged_task_ids:
-            return await self.experience_pipeline.finalize_batch(
-                batch_state.batch_id,
-                task_ids=sorted(batch_state.staged_task_ids),
-            )
+        if self.experience_pipeline is not None and batch_state.statuses:
+            return await self.experience_pipeline.finalize_batch(batch_state.batch_id)
         return {}
 
     def _build_batch_result(
@@ -492,13 +456,7 @@ def _build_batch_result(
 
     def _eval_metric_prefix(self, batch_id: BatchId) -> str:
         """Return the metric namespace prefix for one eval batch."""
-
         batch_name = str(batch_id)
         if "/" in batch_name:
             batch_name = batch_name.split("/", 1)[1]
         return f"eval/{batch_name}"
-
-    def _return_partial_tasks(self) -> bool:
-        """Return whether scheduler cleanup may emit partial task results."""
-
-        return bool(getattr(self.config.explorer.over_rollout, "return_partial_tasks", False))
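With completed_task_count now derived from statuses, the readiness rule for a train batch depends only on the statuses map, min_wait_num, and the over-rollout grace period. The following condensed paraphrase of _get_ready_reason is a sketch only: the wait_after_min handling is inferred from _build_config and the "partial" test above, since those hunks are not shown in this patch.

    import time
    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class _Batch:
        # Hypothetical stand-in for BatchState, reduced to the fields the rule reads.
        expected_task_count: int
        statuses: dict = field(default_factory=dict)
        min_wait_num: Optional[int] = None
        aborted: bool = False
        min_threshold_reached_time: Optional[float] = None

    def ready_reason(batch: _Batch, wait_after_min: float) -> Optional[str]:
        """Distilled readiness check for train batches; mirrors _get_ready_reason."""
        if batch.aborted:
            return "abort"
        done = len(batch.statuses)  # completed_task_count is derived, not stored
        if done >= batch.expected_task_count:
            return "complete"
        if batch.min_wait_num is None or done < batch.min_wait_num:
            batch.min_threshold_reached_time = None
            return None
        if batch.min_threshold_reached_time is None:
            batch.min_threshold_reached_time = time.time()
        if time.time() - batch.min_threshold_reached_time >= wait_after_min:
            return "partial"
        return None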