Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
589 changes: 589 additions & 0 deletions docs/design/trace_store_lifecycle_impl.md

Large diffs are not rendered by default.

431 changes: 431 additions & 0 deletions docs/design/trace_store_lifecycle_simple.md

Large diffs are not rendered by default.

16 changes: 12 additions & 4 deletions tests/rl/test_multi_task_agent_loop_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_TaskRunner,
)
from xtuner.v1.rl.agent_loop_manager.producer import GROUP_GENERATE_TIME_KEY, ProduceBatchStatus
from xtuner.v1.rl.replay_buffer import RefreshStalenessResult
from xtuner.v1.rl.utils import calculate_seq_staleness


Expand Down Expand Up @@ -144,7 +145,11 @@ async def produce_batch(self, ctx) -> ProduceBatchStatus:


class _FakeReplayBuffer:
def __init__(self, rollout_states_by_task: dict[str, list[list[str]]], leftover_counts: dict[tuple[str, Status], int]):
def __init__(
self,
rollout_states_by_task: dict[str, list[list[str]]],
leftover_counts: dict[tuple[str, Status], int],
):
self._rollout_states_by_task = rollout_states_by_task
self._leftover_counts = leftover_counts
self.saved_paths: list[Path] = []
Expand All @@ -168,7 +173,7 @@ async def refresh_staleness(
current_train_step: int,
statuses: list[Status] | None = None,
):
expired_counts = {}
refresh_results = {}
for task_name, stale_threshold in task_stale_thresholds.items():
self.refresh_staleness_calls.append(
(task_name, current_train_step, stale_threshold, tuple(statuses or ()))
Expand All @@ -180,8 +185,11 @@ async def refresh_staleness(
state.seq_staleness = calculate_seq_staleness(
min(response_model_steps), current_train_step
)
expired_counts[task_name] = 0
return expired_counts
refresh_results[task_name] = RefreshStalenessResult(
expired_count=0,
expired_session_ids=[],
)
return refresh_results

async def is_ready(self, task_batch_sizes: dict[str, int], *, group_status: Status = Status.COMPLETED):
for task_name, batch_size in task_batch_sizes.items():
Expand Down
16 changes: 13 additions & 3 deletions tests/rl/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,17 @@


class MockRolloutState:
def __init__(self, id, seq_staleness=1, status=Status.COMPLETED, reward_score=None):
def __init__(
self,
id,
seq_staleness=1,
status=Status.COMPLETED,
reward_score=None,
session_uid=None,
):
self.id = id
self.uid = id
self.session_uid = session_uid
self.status = status
self.seq_staleness = seq_staleness
self.response_ids = []
Expand Down Expand Up @@ -114,6 +122,7 @@ def _build_context(
progress=progress,
is_valid_sample_fn=strategy.is_valid_sample_fn,
stale_threshold=getattr(strategy, "stale_threshold", None),
enable_partial_rollout=getattr(strategy, "enable_partial_rollout", False),
)

def test_produce_progress_methods_keep_absolute_window(self):
Expand Down Expand Up @@ -827,12 +836,13 @@ async def test_refresh_staleness_refreshes_before_expire_check(self):
stale_item.response_model_steps = [3]
await self.replay_buffer.put([stale_item], task_name)

expired_counts = await self.replay_buffer.refresh_staleness(
refresh_results = await self.replay_buffer.refresh_staleness(
task_stale_thresholds={task_name: 2},
current_train_step=6,
)
expired_groups = await self.replay_buffer.get(10, task_name, Status.EXPIRED)

self.assertEqual(expired_counts, {task_name: 1})
self.assertEqual(refresh_results[task_name].expired_count, 1)
self.assertEqual(refresh_results[task_name].expired_session_ids, [])
self.assertEqual(len(expired_groups), 1)
self.assertEqual(expired_groups[0][0].seq_staleness, 2)
21 changes: 13 additions & 8 deletions tests/rl/test_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ def __init__(
status=Status.COMPLETED,
response_model_steps=None,
response_ids=None,
session_uid=None,
):
self.id = state_id
self.seq_staleness = staleness
self.status = status
self.input_ids = input_ids if input_ids is not None else [state_id]
self.response_model_steps = response_model_steps
self.response_ids = response_ids if response_ids is not None else []
self.session_uid = session_uid


class TestReplayBuffer(unittest.IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -195,12 +197,13 @@ async def test_refresh_staleness_expires_completed_in_place(self):
"task",
)

expired_counts = await replay_buffer.refresh_staleness(
refresh_results = await replay_buffer.refresh_staleness(
task_stale_thresholds={"task": 2},
current_train_step=6,
)

self.assertEqual(expired_counts, {"task": 1})
self.assertEqual(refresh_results["task"].expired_count, 1)
self.assertEqual(refresh_results["task"].expired_session_ids, [])
self.assertEqual(await replay_buffer.count("task", Status.COMPLETED), 1)
self.assertEqual(await replay_buffer.count("task", Status.EXPIRED), 1)
expired = await replay_buffer.get(1, "task", Status.EXPIRED)
Expand All @@ -212,20 +215,21 @@ async def test_refresh_staleness_expires_completed_in_place(self):
async def test_refresh_staleness_expires_aborted_in_place(self):
replay_buffer = AsyncReplayBufferConfig().build()
await replay_buffer.put(
[MockState("stale-aborted", response_model_steps=[3], status=Status.ABORTED)],
[MockState("stale-aborted", response_model_steps=[3], status=Status.ABORTED, session_uid="sid-aborted")],
"task",
)
await replay_buffer.put(
[MockState("fresh-aborted", response_model_steps=[5], status=Status.ABORTED)],
"task",
)

expired_counts = await replay_buffer.refresh_staleness(
refresh_results = await replay_buffer.refresh_staleness(
task_stale_thresholds={"task": 2},
current_train_step=6,
)

self.assertEqual(expired_counts, {"task": 1})
self.assertEqual(refresh_results["task"].expired_count, 1)
self.assertEqual(refresh_results["task"].expired_session_ids, ["sid-aborted"])
self.assertEqual(await replay_buffer.count("task", Status.ABORTED), 1)
self.assertEqual(await replay_buffer.count("task", Status.EXPIRED), 1)
expired = await replay_buffer.get(1, "task", Status.EXPIRED)
Expand All @@ -244,17 +248,18 @@ async def test_refresh_staleness_respects_status_filter(self):
"task",
)
await replay_buffer.put(
[MockState("stale-aborted", response_model_steps=[3], status=Status.ABORTED)],
[MockState("stale-aborted", response_model_steps=[3], status=Status.ABORTED, session_uid="sid-filter")],
"task",
)

expired_counts = await replay_buffer.refresh_staleness(
refresh_results = await replay_buffer.refresh_staleness(
task_stale_thresholds={"task": 2},
current_train_step=6,
statuses=[Status.ABORTED],
)

self.assertEqual(expired_counts, {"task": 1})
self.assertEqual(refresh_results["task"].expired_count, 1)
self.assertEqual(refresh_results["task"].expired_session_ids, ["sid-filter"])
self.assertEqual(await replay_buffer.count("task", Status.COMPLETED), 1)
self.assertEqual(await replay_buffer.count("task", Status.ABORTED), 0)
self.assertEqual(await replay_buffer.count("task", Status.EXPIRED), 1)
Expand Down
159 changes: 159 additions & 0 deletions tests/rl/test_trace_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from xtuner.v1.data_proto.rl_data import Status
from xtuner.v1.rl.agent_loop_manager import ProduceContext, ProduceProgress
from xtuner.v1.rl.agent_loop_manager.agent_loop_manager import AgentLoopManager, _TaskRunner
from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig, RefreshStalenessResult
from xtuner.v1.rl.rollout.trace_store import RolloutTraceStore, TokenizedSegment, TraceState


class _TraceRolloutState:
def __init__(
self,
uid: int | str,
*,
status: Status = Status.COMPLETED,
reward_score: float | None = None,
session_uid: int | str | None = None,
):
self.uid = uid
self.id = uid
self.session_uid = session_uid
self.status = status
self.seq_staleness = 0
self.response_ids = []
self.extra_fields = {}
self.reward = {"score": reward_score} if reward_score is not None else None


class _TraceRefreshReplayBuffer:
def __init__(self, expired_session_ids_by_task: dict[str, list[int | str]]):
self._expired_session_ids_by_task = expired_session_ids_by_task

async def refresh_staleness(
self,
*,
task_stale_thresholds: dict[str, int],
current_train_step: int,
statuses: list[Status] | None = None,
):
return {
task_name: RefreshStalenessResult(
expired_count=len(self._expired_session_ids_by_task.get(task_name, [])),
expired_session_ids=self._expired_session_ids_by_task.get(task_name, []),
)
for task_name in task_stale_thresholds
}


class _TraceProduceStrategy:
stale_threshold = 5
enable_partial_rollout = False


class TestRolloutTraceStoreRolloutStatus(unittest.TestCase):
def setUp(self):
store_cls = RolloutTraceStore.__ray_metadata__.modified_class
self.store = store_cls()

def _insert_segment(self, session_id: str):
self.store.insert(session_id, "prompt", TokenizedSegment(text="prompt", token_ids=[1]))

def test_mark_rollout_statuses_marks_completed_and_releases_filtered(self):
self._insert_segment("completed")
self._insert_segment("filtered")

results = self.store.mark_rollout_statuses(
[
("completed", Status.COMPLETED),
("filtered", Status.FILTERED),
]
)

self.assertEqual(results["completed"], TraceState.ROLLOUT_FINISHED.value)
self.assertEqual(results["filtered"], TraceState.RELEASED.value)
self.assertEqual(
self.store.get_state("completed")["state"],
TraceState.ROLLOUT_FINISHED.value,
)
self.assertIsNone(self.store.get_state("filtered"))

def test_mark_rollout_statuses_discards_expired_finished_session(self):
self._insert_segment("expired")
self.store.mark_rollout_status("expired", Status.COMPLETED)

results = self.store.mark_rollout_statuses([("expired", Status.EXPIRED)])

self.assertEqual(results["expired"], TraceState.RELEASED.value)
self.assertIsNone(self.store.get_state("expired"))


class TestTraceStoreProducerReporting(unittest.IsolatedAsyncioTestCase):
async def test_put_generated_group_reports_final_status_to_trace_store(self):
task_name = "test_trace_status"
progress = ProduceProgress.build([task_name])
replay_buffer = AsyncReplayBufferConfig().build()
ctx = ProduceContext(
agent_loop=MagicMock(),
sampler=MagicMock(),
replay_buffer=replay_buffer,
task_batch_size=1,
task_name=task_name,
train_step=0,
update_event=asyncio.Event(),
model_step=0,
progress=progress,
is_valid_sample_fn=lambda samples: False,
)
store = MagicMock()
store.mark_rollout_statuses.remote = AsyncMock(return_value={})

completed_group = [
_TraceRolloutState(
1,
status=Status.COMPLETED,
reward_score=1.0,
session_uid="trace-session",
)
]
with patch("xtuner.v1.rl.agent_loop_manager.producer.get_store", return_value=store):
self.assertFalse(await ctx.put_generated_group(completed_group))

store.mark_rollout_statuses.remote.assert_awaited_once_with(
[("trace-session", Status.FILTERED)],
enable_partial_rollout=False,
)


class TestTraceStoreManagerReporting(unittest.IsolatedAsyncioTestCase):
async def test_refresh_for_all_tasks_reports_expired_sessions_to_trace_store(self):
replay_buffer = _TraceRefreshReplayBuffer({"task_a": ["sid-a", "sid-b"]})
manager = AgentLoopManager(
task_runners=[
_TaskRunner(
task_name="task_a",
agent_loop=MagicMock(),
produce_strategy=_TraceProduceStrategy(),
sampler=MagicMock(),
weight=1.0,
order=0,
),
],
replay_buffer=replay_buffer,
)
store = MagicMock()
store.mark_rollout_statuses.remote = AsyncMock(return_value={})

with patch("xtuner.v1.rl.agent_loop_manager.agent_loop_manager.get_store", return_value=store):
await manager._refresh_for_all_tasks(9, [Status.COMPLETED, Status.ABORTED])

store.mark_rollout_statuses.remote.assert_awaited_once_with(
[("sid-a", Status.EXPIRED), ("sid-b", Status.EXPIRED)],
enable_partial_rollout=False,
)


if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions xtuner/v1/data_proto/rl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

logger = get_logger()

USE_TRACE_STORE_KEY = "use_trace_store"
TRACE_STORE_PROMPT_TEXT_KEY = "trace_store_prompt_text"


class SampleParams(BaseModel):
model_config = ConfigDict(extra="forbid")
Expand Down
4 changes: 2 additions & 2 deletions xtuner/v1/data_proto/sequence_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import cast
from typing import Any, cast

import torch
from torch.distributed.device_mesh import DeviceMesh
Expand Down Expand Up @@ -48,7 +48,7 @@ class SequenceContext:
num_img_tokens: list[list[int]] | None

# moe routed_experts
rollout_routed_experts: torch.Tensor | None
rollout_routed_experts: Any

# Private backing attributes for SP shard reconstruction
_raw_input_ids: torch.LongTensor | None
Expand Down
21 changes: 18 additions & 3 deletions xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger
from xtuner.v1.rl.replay_buffer import ReplayBuffer
from xtuner.v1.rl.rollout import RolloutController
from xtuner.v1.rl.rollout.trace_store import get_store
from xtuner.v1.rl.utils import asyncio_run
from xtuner.v1.utils import get_logger

Expand Down Expand Up @@ -209,6 +210,7 @@ def _build_produce_context(
progress=progress,
is_valid_sample_fn=getattr(task_runner.produce_strategy, "is_valid_sample_fn", default_is_valid_sample_fn),
stale_threshold=getattr(task_runner.produce_strategy, "stale_threshold", None),
enable_partial_rollout=getattr(task_runner.produce_strategy, "enable_partial_rollout", False),
)


Expand Down Expand Up @@ -461,15 +463,28 @@ async def _refresh_for_all_tasks(self, train_step: int, statuses: list[Status])
stale_threshold = getattr(task.produce_strategy, "stale_threshold", 1)
task_stale_thresholds[task.task_name] = stale_threshold

expired_counts = await self.replay_buffer.refresh_staleness(
refresh_results = await self.replay_buffer.refresh_staleness(
task_stale_thresholds=task_stale_thresholds,
current_train_step=train_step,
statuses=statuses,
)
for task_name, expired_count in expired_counts.items():
for task in self.task_runners:
task_name = task.task_name
refresh_result = refresh_results[task_name]
self.logger.info(
f"[AgentLoopManager][{self.name}] Refresh staleness for task {task_name}: expired_count={expired_count}"
f"[AgentLoopManager][{self.name}] Refresh staleness for task {task_name}: "
f"expired_count={refresh_result.expired_count}"
)
if refresh_result.expired_session_ids:
trace_events = [(session_id, Status.EXPIRED) for session_id in refresh_result.expired_session_ids]
try:
store = get_store()
await store.mark_rollout_statuses.remote(
trace_events,
enable_partial_rollout=getattr(task.produce_strategy, "enable_partial_rollout", False),
)
except Exception as exc:
self.logger.error(f"Failed to report trace store expired rollout status events: {exc}")

def _get_task_batch_sizes_for_step(self, batch_size: int, train_step: int) -> dict[str, int]:
if len(self.task_runners) == 1:
Expand Down
Loading