Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xtuner/v1/rl/agent_loop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
RouterAgentLoop,
get_agent_loop_rollout_ctl,
)
from .localhost_agent_loop.agent_in_localhost_loop import AgentInLocalhostLoop, AgentInLocalhostLoopConfig
from .sandbox_agent_loop.agent_in_sandbox_loop import AgentInSandboxLoop, AgentInSandboxLoopConfig
from .single_turn_agent_loop import SingleTurnAgentLoop, SingleTurnAgentLoopConfig


__all__ = [
"AgentInLocalhostLoop",
"AgentInLocalhostLoopConfig",
"AgentInSandboxLoop",
"AgentInSandboxLoopConfig",
"AgentLoopConfig",
Expand Down
34 changes: 34 additions & 0 deletions xtuner/v1/rl/agent_loop/localhost_agent_loop/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Public surface for the localhost_agent_loop runner."""

from xtuner.v1.rl.agent_loop.localhost_agent_loop.agent_in_localhost_loop import (
AgentInLocalhostLoop,
AgentInLocalhostLoopConfig,
)
from xtuner.v1.rl.agent_loop.localhost_agent_loop.compose import LocalhostComposeStage
from xtuner.v1.rl.agent_loop.localhost_agent_loop.judger import LocalhostJudgerStage
from xtuner.v1.rl.agent_loop.localhost_agent_loop.runner import LocalhostRunner
from xtuner.v1.rl.agent_loop.localhost_agent_loop.schemas import LocalhostAgentSpec
from xtuner.v1.rl.agent_loop.localhost_agent_loop.stage import LocalhostStage
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import (
AgentRolloutItem,
RolloutError,
RolloutStatus,
StageRecord,
StageStatus,
)


__all__ = [
"AgentInLocalhostLoop",
"AgentInLocalhostLoopConfig",
"AgentRolloutItem",
"LocalhostAgentSpec",
"LocalhostComposeStage",
"LocalhostJudgerStage",
"LocalhostRunner",
"LocalhostStage",
"RolloutError",
"RolloutStatus",
"StageRecord",
"StageStatus",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from __future__ import annotations

import asyncio
import copy
import importlib
import traceback
import uuid
from typing import Any

from lagent.utils import create_object, ctx_session_id

from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import (
AgentRolloutItem,
RolloutStatus,
)
from xtuner.v1.rl.judger import Judger
from xtuner.v1.rl.rollout import RolloutController
from xtuner.v1.rl.rollout.trace_store import get_store
from xtuner.v1.rl.utils import create_task

from ..agent_loop import AgentLoop, AgentLoopConfig


def _import_from_path(path: str) -> Any:
module_name, _, attr = path.rpartition(".")
if not module_name or not attr:
raise ValueError(f"Invalid import path: {path!r}. Expected 'module.attr'.")
module = importlib.import_module(module_name)
return getattr(module, attr)


def _resolve_runner(pipeline: Any) -> Any:
if isinstance(pipeline, str):
pipeline = _import_from_path(pipeline)
if isinstance(pipeline, dict):
return create_object(copy.deepcopy(pipeline))
return pipeline


class AgentInLocalhostLoopConfig(AgentLoopConfig):
"""Run a localhost agent runner from ``RolloutState.extra_fields``."""

max_concurrent_samples: int | None = None

def build_local(
self,
rollout_controller: RolloutController | None = None,
judger: Judger | None = None,
logger=None,
) -> AgentInLocalhostLoop:
return AgentInLocalhostLoop(
rollout_ctl=rollout_controller,
sample_params=self.sample_params,
hf_checkpoint=self.hf_checkpoint,
judger=judger,
logger=logger,
max_concurrent_samples=self.max_concurrent_samples,
)


class AgentInLocalhostLoop(AgentLoop):
"""AgentLoop adapter for localhost_agent_loop runners."""

def __init__(
self,
rollout_ctl: RolloutController | None = None,
sample_params: SampleParams | None = None,
hf_checkpoint: str | None = None,
judger: Judger | None = None,
logger=None,
max_concurrent_samples: int | None = None,
):
super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
self.max_concurrent_samples = max_concurrent_samples
self._sample_semaphore = asyncio.Semaphore(max_concurrent_samples) if max_concurrent_samples else None

async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个代码可以考虑放到基类里面去,因为是通用的。否则其他agent 也要写一遍

async def generate_one(state: RolloutState) -> RolloutState:
if self._sample_semaphore is None:
return await self.generate_sample(state, **kwargs)
async with self._sample_semaphore:
return await self.generate_sample(state, **kwargs)

tasks = []
for state in rollout_state:
state.sample_params = self.sample_params
tasks.append(create_task(generate_one(state)))
return await asyncio.gather(*tasks)

async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
try:
item = rollout_state.extra_fields["rollout_item"].model_copy(deep=True)
if rollout_state.uid is None:
rollout_state.uid = uuid.uuid4().int
item.uid = rollout_state.uid
item.group_id = rollout_state.message_uid
result = await self._run_item(item)
await self._fill_rollout_state(rollout_state, result)
return rollout_state
except Exception as exc:
rollout_state.status = Status.FAILED
rollout_state.finish_reason = "error"
rollout_state.error_msg = f"{type(exc).__name__}: {exc}"
self.logger.error(f"[AgentInLocalhostLoop] failed: {exc}\n{traceback.format_exc()}")
return rollout_state

async def _run_item(self, item: AgentRolloutItem) -> AgentRolloutItem:
runner = _resolve_runner(item.pipeline)
if runner is None:
raise ValueError("AgentRolloutItem.pipeline is required.")
with ctx_session_id.set(str(item.uid)):
return await runner.run(item)

async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRolloutItem) -> None:
segment = item.artifacts["messages"][-1]
text = self.tokenizer.apply_chat_template(
segment["messages"],
tools=segment["tools"],
tokenize=False,
add_generation_prompt=False,
)
prompt_text = text[:-1] if text.endswith("\n") else text
data = await get_store().export_training_trace.remote(str(rollout_state.uid), prompt_text)

rollout_state.input_ids = data["input_ids"]
rollout_state.labels = data["labels"]
rollout_state.response_ids = [
token_id for token_id, label in zip(data["input_ids"][1:], data["labels"][1:]) if label != -100
]
rollout_state.logprobs = data["logprobs"]
rollout_state.routed_experts = data["routed_experts"]
rollout_state.response = str(item.artifacts.get("response") or "")
rollout_state.finish_reason = "stop" if item.status == RolloutStatus.COMPLETED else "error"
rollout_state.status = Status.COMPLETED if item.status == RolloutStatus.COMPLETED else Status.FAILED
rollout_state.reward = {"score": item.reward}
rollout_state.extra_fields["raw_prompt"] = prompt_text
rollout_state.extra_fields["agent_artifacts"] = item.artifacts
if item.error is not None:
rollout_state.error_msg = f"{item.error.stage}/{item.error.category}: {item.error.message}"


__all__ = ["AgentInLocalhostLoop", "AgentInLocalhostLoopConfig"]
72 changes: 72 additions & 0 deletions xtuner/v1/rl/agent_loop/localhost_agent_loop/compose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Composable localhost stages."""

from __future__ import annotations

import time
from typing import Any

from lagent.utils import create_object

from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import (
AgentRolloutItem,
RolloutError,
StageRecord,
StageStatus,
)


class LocalhostComposeStage:
"""Compose multiple local validation stages behind ``run(item, record) ->
float``."""

def __init__(
self,
stages: list[Any],
*,
name: str = "validate",
weight: float = 1.0,
):
if not stages:
raise ValueError("LocalhostComposeStage.stages is empty")
self.name = name
self.stages = [create_object(stage) for stage in stages]
self.weight = weight

async def run(self, item: AgentRolloutItem, record: StageRecord) -> float:
record.status = StageStatus.RUNNING
record.started_at = record.started_at or time.monotonic()
try:
weighted_score = 0.0
total_weight = 0.0
for stage in self.stages:
name = getattr(stage, "name", stage.__class__.__name__)
child_record = item.judgers.setdefault(name, StageRecord())
score = float(await stage.run(item, child_record))
stage_weight = max(float(getattr(stage, "weight", 1.0)), 0.0)
weighted_score += score * stage_weight
total_weight += stage_weight
record.score = weighted_score / total_weight if total_weight > 0 else 0.0
record.status = StageStatus.COMPLETED
return record.score
except Exception as exc:
record.status = StageStatus.FAILED
child_error = next(
(child.error for child in item.judgers.values() if child.error is not None),
None,
)
record.error = (
record.error
or child_error
or RolloutError(
stage=self.name,
category="validate_failed",
type=type(exc).__name__,
message=str(exc),
)
)
raise
finally:
record.finished_at = time.monotonic()


__all__ = ["LocalhostComposeStage"]
83 changes: 83 additions & 0 deletions xtuner/v1/rl/agent_loop/localhost_agent_loop/judger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Localhost judger stages."""

from __future__ import annotations

import time
from copy import deepcopy
from typing import Any

from lagent.utils import create_object

from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import (
AgentRolloutItem,
RolloutError,
StageRecord,
StageStatus,
)
from xtuner.v1.rl.judger.native import Judger


class LocalhostJudgerStage:
"""Run one local validation stage.

Public stage interface is ``run(item, record) -> float``. ``RolloutState``
is only the internal shape needed to reuse xtuner judgers.
"""

def __init__(
self,
*,
name: str,
judger_config: Any,
reward_key: str = "score",
weight: float = 1.0,
):
config = create_object(deepcopy(judger_config)) if isinstance(judger_config, dict) else judger_config
self.name = name
self.judger: Judger = config.build()
self.reward_key = reward_key
self.weight = weight

async def run(self, item: AgentRolloutItem, record: StageRecord) -> float:
record.status = StageStatus.RUNNING
record.started_at = record.started_at or time.monotonic()
try:
reward_model = dict(item.reward_model or {})

messages = item.artifacts["messages"][-1]["messages"]
tool_turns = sum(
1 for message in messages if isinstance(message.get("tool_calls"), list) and message["tool_calls"]
)
reward_model.setdefault("agent_trace", messages)
reward_model.setdefault("num_turns", tool_turns)

response = str(item.artifacts.get("response") or "")
rollout_state = RolloutState(
message=[{"role": "user", "content": item.instruction}],
response=response,
reward_model=reward_model,
status=Status.COMPLETED,
)
judged = await self.judger.judge(rollout_state)
reward_payload = judged.reward or {}
if self.reward_key not in reward_payload:
raise KeyError(f"judger reward payload has no {self.reward_key!r}: {reward_payload!r}")
record.metadata["reward"] = reward_payload
record.score = float(reward_payload[self.reward_key])
record.status = StageStatus.COMPLETED
return record.score
except Exception as exc:
record.status = StageStatus.FAILED
record.error = record.error or RolloutError(
stage=self.name,
category="judger",
type=type(exc).__name__,
message=str(exc),
)
raise
finally:
record.finished_at = time.monotonic()


__all__ = ["LocalhostJudgerStage"]
Loading
Loading