Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ explorer:
return_partial_tasks: false
dynamic_timeout:
enable: false
warmup_min_steps: 1
ratio: 3.0
runner_state_report_interval: 0
```
Expand Down Expand Up @@ -452,9 +453,10 @@ explorer:
- `ratio`: Explorer will only wait for `(1 - ratio) * batch_size` of tasks at each step. Default is `0.0`, meaning waiting for all tasks.
- `wait_after_min`: After reaching the minimum task threshold, wait for this many seconds before proceeding. Default is `30.0` seconds.
- `return_partial_tasks`: Whether to return the results of tasks that have only completed partially (e.g., only some runs in GRPO). Default is `false`, meaning only return results of tasks that have completed all runs.
- `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each task based on the average time taken for successful tasks.
- `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each scheduled execution unit based on historical runtime statistics.
- `enable`: Whether to enable dynamic timeout. Default is `false`.
- `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`.
- `warmup_min_steps`: The minimum number of fully observed non-eval steps required before dynamic timeout takes effect. Default is `1`. This is equivalent to a warmup batch/step count, and helps avoid enabling dynamic timeout based only on a few fast early completions.
- `ratio`: The timeout for each scheduled execution unit is dynamically set to `average_time_per_success_execution * ratio`. Default is `3.0`.
- `runner_state_report_interval`: Workflow runner report interval (in seconds). If set to a value greater than `0`, the workflow runner will periodically report its status to the main explorer process and print it in the command line for monitoring. Default is `0`, meaning this feature is disabled. If you want to use this feature, it is recommended to set it to `10` seconds or longer to minimize performance impact.

---
Expand Down
6 changes: 4 additions & 2 deletions docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ explorer:
return_partial_tasks: false
dynamic_timeout:
enable: false
warmup_min_steps: 1
ratio: 3.0
runner_state_report_interval: 0
```
Expand Down Expand Up @@ -449,9 +450,10 @@ explorer:
- `ratio`: explorer 在每个步骤中仅等待 `(1 - ratio) * batch_size` 的任务。默认为 `0.0`,表示等待所有任务。
- `wait_after_min`: 达到最小任务阈值后,等待此秒数后再继续。默认为 `30.0` 秒。
- `return_partial_tasks`: 是否返回仅部分完成的任务结果(例如,在 GRPO 中仅完成部分 run 的任务)。默认为 `false`,表示仅返回已完成组内所有 run 的任务结果。
- `dynamic_timeout`: [实验性] 动态超时机制的配置,根据成功任务的平均耗时调整每个任务的超时时间
- `dynamic_timeout`: [实验性] 动态超时机制的配置,根据历史执行耗时动态调整每个调度执行单元的超时时间
- `enable`: 是否启用动态超时。默认为 `false`。
- `ratio`: 每个任务的超时时间动态设置为 `average_time_per_success_task * ratio`。默认为 `3.0`。
- `warmup_min_steps`: 动态超时生效前至少需要观测到多少个完整结束的非评估 step。默认为 `1`。它等价于预热所需的 batch/step 数,可以避免只根据少量较快完成的早期任务就提前启用动态超时。
- `ratio`: 每个调度执行单元的超时时间动态设置为 `average_time_per_success_execution * ratio`。默认为 `3.0`。
- `runner_state_report_interval`: WorkflowRunner 报告自身状态的时间间隔(秒)。若设为大于 0 的值,工作流执行器会定期将其状态报告给 explorer 主进程并打印在命令行中,以便监控其运行状态。默认为 `0`,表示不启用此功能。推荐如需使用此功能,将其设置为 `10` 秒或更长时间以减少对性能的影响。

---
Expand Down
162 changes: 162 additions & 0 deletions tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import os
import random
import shutil
import time
import unittest
from collections import deque
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock

import httpx
import ray
Expand Down Expand Up @@ -34,6 +39,98 @@
from trinity.manager.state_manager import StateManager


def _build_fake_coordinator_explorer():
    """Build an ``Explorer`` wired to in-memory fakes for coordinator tests.

    Returns a tuple ``(explorer, feedback_calls)`` where ``feedback_calls``
    accumulates every metrics dict the explorer hands to the taskset's
    ``feedback`` hook.
    """

    class AsyncRemoteStub:
        """Mimics a Ray actor method handle: ``.remote()`` awaits the wrapped
        coroutine function directly instead of going through Ray."""

        def __init__(self, coro_func):
            self._coro_func = coro_func

        async def remote(self, *args, **kwargs):
            return await self._coro_func(*args, **kwargs)

    class CoordinatorStub:
        """Records every coordinator call and answers with canned results."""

        def __init__(self):
            self.submit_calls = []
            self.finalize_train_calls = []
            self.finalize_eval_calls = []
            self.shutdown_calls = 0
            # Expose the same "actor method" surface the real coordinator has.
            self.submit_batch = AsyncRemoteStub(self._submit_batch)
            self.finalize_train_batch = AsyncRemoteStub(self._finalize_train_batch)
            self.finalize_eval_batch = AsyncRemoteStub(self._finalize_eval_batch)
            self.shutdown = AsyncRemoteStub(self._shutdown)

        async def _submit_batch(self, **kwargs):
            self.submit_calls.append(kwargs)

        async def _finalize_train_batch(self, batch_id):
            self.finalize_train_calls.append(batch_id)
            metrics = {
                "experience_pipeline/experience_count": 2.0,
                "rollout/run_metrics/mean": float(batch_id),
                "rollout/finished_task_count": 1.0,
            }
            return {
                "batch_id": batch_id,
                "batch_type": "train",
                "finished_task_count": 1,
                "metrics": metrics,
                "finalize_reason": "complete",
                "finalized": True,
            }

        async def _finalize_eval_batch(self, batch_id):
            self.finalize_eval_calls.append(batch_id)
            # Eval batch ids look like "<step>/<eval_set_name>".
            eval_name = batch_id.split("/", 1)[1]
            metrics = {
                f"eval/{eval_name}/accuracy": 0.5,
                f"eval/{eval_name}/finished_task_count": 2.0,
            }
            return {
                "batch_id": batch_id,
                "batch_type": "eval",
                "finished_task_count": 2,
                "metrics": metrics,
                "finalize_reason": "complete",
                "finalized": True,
            }

        async def _shutdown(self):
            self.shutdown_calls += 1

    class MonitorStub:
        """Captures ``(step, metric)`` pairs instead of emitting them."""

        def __init__(self):
            self.logged = []

        def log(self, metric, step):
            self.logged.append((step, metric))

    feedback_calls = []

    async def read_async():
        # Serve two non-eval tasks on every read.
        return [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)]

    def record_feedback(metrics):
        feedback_calls.append(metrics)

    # Bypass Explorer.__init__ (which needs Ray and real config) and attach
    # only the collaborators these tests exercise.
    explorer = Explorer.__new__(Explorer)
    explorer.logger = MagicMock()
    explorer.rollout_coordinator = CoordinatorStub()
    explorer.monitor = MonitorStub()
    explorer.taskset = SimpleNamespace(read_async=read_async, feedback=record_feedback)
    explorer.min_wait_num = None
    explorer.pending_eval_tasks = deque()
    explorer.explore_start_time = None
    explorer.eval_start_time = None
    explorer.last_monitored_step = 0
    explorer.explore_step_num = 0
    explorer.model_version = 7
    explorer.detailed_stats = False
    explorer.config = SimpleNamespace(
        explorer=SimpleNamespace(
            over_rollout=SimpleNamespace(return_partial_tasks=False),
            eval_interval=1,
        )
    )
    return explorer, feedback_calls


class BaseExplorerCase(RayUnittestBase):
def setUp(self):
self.config = get_template_config()
Expand Down Expand Up @@ -226,6 +323,71 @@ def test_explorer(self):
ray.get(explorer.shutdown.remote())


class TestExplorerCoordinatorPath(unittest.IsolatedAsyncioTestCase):
    """Explorer flows that route batches through the rollout coordinator."""

    async def test_explore_step_submits_train_batch_to_rollout_coordinator(self):
        explorer, _ = _build_fake_coordinator_explorer()

        self.assertTrue(await explorer.explore_step())

        self.assertEqual(explorer.explore_step_num, 1)
        expected_submit = {
            "batch_id": 1,
            "tasks": [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)],
            "batch_type": "train",
            "min_wait_num": None,
        }
        self.assertEqual(explorer.rollout_coordinator.submit_calls, [expected_submit])

    async def test_finish_current_steps_uses_rollout_coordinator_finalize(self):
        explorer, feedback_calls = _build_fake_coordinator_explorer()
        explorer.explore_step_num = 2

        await explorer.finish_current_steps()

        coordinator = explorer.rollout_coordinator
        self.assertEqual(coordinator.finalize_train_calls, [1, 2])
        self.assertEqual(len(feedback_calls), 2)
        logged_steps = [entry[0] for entry in explorer.monitor.logged]
        self.assertEqual(logged_steps, [1, 2])
        self.assertEqual(explorer.last_monitored_step, 2)

    async def test_finish_eval_step_uses_rollout_coordinator_finalize(self):
        explorer, _ = _build_fake_coordinator_explorer()
        explorer.pending_eval_tasks.append((3, "eval_set"))
        explorer.eval_start_time = time.time()

        await explorer._finish_eval_step(step=3)

        self.assertEqual(explorer.rollout_coordinator.finalize_eval_calls, ["3/eval_set"])
        logged = explorer.monitor.logged
        self.assertEqual([entry[0] for entry in logged], [3])
        self.assertIn("eval/eval_set/accuracy", logged[0][1])


class TestExplorerCoordinatorPolicies(unittest.IsolatedAsyncioTestCase):
    """Policy settings that must be forwarded to the rollout coordinator."""

    async def test_over_rollout_submits_partial_finalize_policy_to_rollout_coordinator(self):
        explorer, _ = _build_fake_coordinator_explorer()
        explorer.min_wait_num = 1
        explorer.config.explorer.over_rollout.return_partial_tasks = True

        should_continue = await explorer.explore_step()

        self.assertTrue(should_continue)
        expected_calls = [
            {
                "batch_id": 1,
                "tasks": [SimpleNamespace(is_eval=False), SimpleNamespace(is_eval=False)],
                "batch_type": "train",
                "min_wait_num": 1,
            }
        ]
        self.assertEqual(explorer.rollout_coordinator.submit_calls, expected_calls)


def run_serve(config):
config.check_and_update()
run_stage(config)
Expand Down
Loading
Loading