-
Notifications
You must be signed in to change notification settings - Fork 422
support localhost agent #1842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Harold-lkk
wants to merge
4
commits into
InternLM:agentic_branch
Choose a base branch
from
Harold-lkk:lkk/localhost_agent
base: agentic_branch
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
support localhost agent #1842
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| """Public surface for the localhost_agent_loop runner.""" | ||
|
|
||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.agent_in_localhost_loop import ( | ||
| AgentInLocalhostLoop, | ||
| AgentInLocalhostLoopConfig, | ||
| ) | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.compose import LocalhostComposeStage | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.judger import LocalhostJudgerStage | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.runner import LocalhostRunner | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.schemas import LocalhostAgentSpec | ||
| from xtuner.v1.rl.agent_loop.localhost_agent_loop.stage import LocalhostStage | ||
| from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import ( | ||
| AgentRolloutItem, | ||
| RolloutError, | ||
| RolloutStatus, | ||
| StageRecord, | ||
| StageStatus, | ||
| ) | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "AgentInLocalhostLoop", | ||
| "AgentInLocalhostLoopConfig", | ||
| "AgentRolloutItem", | ||
| "LocalhostAgentSpec", | ||
| "LocalhostComposeStage", | ||
| "LocalhostJudgerStage", | ||
| "LocalhostRunner", | ||
| "LocalhostStage", | ||
| "RolloutError", | ||
| "RolloutStatus", | ||
| "StageRecord", | ||
| "StageStatus", | ||
| ] |
143 changes: 143 additions & 0 deletions
143
xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import copy | ||
| import importlib | ||
| import traceback | ||
| import uuid | ||
| from typing import Any | ||
|
|
||
| from lagent.utils import create_object, ctx_session_id | ||
|
|
||
| from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status | ||
| from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import ( | ||
| AgentRolloutItem, | ||
| RolloutStatus, | ||
| ) | ||
| from xtuner.v1.rl.judger import Judger | ||
| from xtuner.v1.rl.rollout import RolloutController | ||
| from xtuner.v1.rl.rollout.trace_store import get_store | ||
| from xtuner.v1.rl.utils import create_task | ||
|
|
||
| from ..agent_loop import AgentLoop, AgentLoopConfig | ||
|
|
||
|
|
||
| def _import_from_path(path: str) -> Any: | ||
| module_name, _, attr = path.rpartition(".") | ||
| if not module_name or not attr: | ||
| raise ValueError(f"Invalid import path: {path!r}. Expected 'module.attr'.") | ||
| module = importlib.import_module(module_name) | ||
| return getattr(module, attr) | ||
|
|
||
|
|
||
| def _resolve_runner(pipeline: Any) -> Any: | ||
| if isinstance(pipeline, str): | ||
| pipeline = _import_from_path(pipeline) | ||
| if isinstance(pipeline, dict): | ||
| return create_object(copy.deepcopy(pipeline)) | ||
| return pipeline | ||
|
|
||
|
|
||
| class AgentInLocalhostLoopConfig(AgentLoopConfig): | ||
| """Run a localhost agent runner from ``RolloutState.extra_fields``.""" | ||
|
|
||
| max_concurrent_samples: int | None = None | ||
|
|
||
| def build_local( | ||
| self, | ||
| rollout_controller: RolloutController | None = None, | ||
| judger: Judger | None = None, | ||
| logger=None, | ||
| ) -> AgentInLocalhostLoop: | ||
| return AgentInLocalhostLoop( | ||
| rollout_ctl=rollout_controller, | ||
| sample_params=self.sample_params, | ||
| hf_checkpoint=self.hf_checkpoint, | ||
| judger=judger, | ||
| logger=logger, | ||
| max_concurrent_samples=self.max_concurrent_samples, | ||
| ) | ||
|
|
||
|
|
||
| class AgentInLocalhostLoop(AgentLoop): | ||
| """AgentLoop adapter for localhost_agent_loop runners.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| rollout_ctl: RolloutController | None = None, | ||
| sample_params: SampleParams | None = None, | ||
| hf_checkpoint: str | None = None, | ||
| judger: Judger | None = None, | ||
| logger=None, | ||
| max_concurrent_samples: int | None = None, | ||
| ): | ||
| super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger) | ||
| self.max_concurrent_samples = max_concurrent_samples | ||
| self._sample_semaphore = asyncio.Semaphore(max_concurrent_samples) if max_concurrent_samples else None | ||
|
|
||
| async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: | ||
| async def generate_one(state: RolloutState) -> RolloutState: | ||
| if self._sample_semaphore is None: | ||
| return await self.generate_sample(state, **kwargs) | ||
| async with self._sample_semaphore: | ||
| return await self.generate_sample(state, **kwargs) | ||
|
|
||
| tasks = [] | ||
| for state in rollout_state: | ||
| state.sample_params = self.sample_params | ||
| tasks.append(create_task(generate_one(state))) | ||
| return await asyncio.gather(*tasks) | ||
|
|
||
| async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: | ||
| try: | ||
| item = rollout_state.extra_fields["rollout_item"].model_copy(deep=True) | ||
| if rollout_state.uid is None: | ||
| rollout_state.uid = uuid.uuid4().int | ||
| item.uid = rollout_state.uid | ||
| item.group_id = rollout_state.message_uid | ||
| result = await self._run_item(item) | ||
| await self._fill_rollout_state(rollout_state, result) | ||
| return rollout_state | ||
| except Exception as exc: | ||
| rollout_state.status = Status.FAILED | ||
| rollout_state.finish_reason = "error" | ||
| rollout_state.error_msg = f"{type(exc).__name__}: {exc}" | ||
| self.logger.error(f"[AgentInLocalhostLoop] failed: {exc}\n{traceback.format_exc()}") | ||
| return rollout_state | ||
|
|
||
| async def _run_item(self, item: AgentRolloutItem) -> AgentRolloutItem: | ||
| runner = _resolve_runner(item.pipeline) | ||
| if runner is None: | ||
| raise ValueError("AgentRolloutItem.pipeline is required.") | ||
| with ctx_session_id.set(str(item.uid)): | ||
| return await runner.run(item) | ||
|
|
||
| async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRolloutItem) -> None: | ||
| segment = item.artifacts["messages"][-1] | ||
| text = self.tokenizer.apply_chat_template( | ||
| segment["messages"], | ||
| tools=segment["tools"], | ||
| tokenize=False, | ||
| add_generation_prompt=False, | ||
| ) | ||
| prompt_text = text[:-1] if text.endswith("\n") else text | ||
| data = await get_store().export_training_trace.remote(str(rollout_state.uid), prompt_text) | ||
|
|
||
| rollout_state.input_ids = data["input_ids"] | ||
| rollout_state.labels = data["labels"] | ||
| rollout_state.response_ids = [ | ||
| token_id for token_id, label in zip(data["input_ids"][1:], data["labels"][1:]) if label != -100 | ||
| ] | ||
| rollout_state.logprobs = data["logprobs"] | ||
| rollout_state.routed_experts = data["routed_experts"] | ||
| rollout_state.response = str(item.artifacts.get("response") or "") | ||
| rollout_state.finish_reason = "stop" if item.status == RolloutStatus.COMPLETED else "error" | ||
| rollout_state.status = Status.COMPLETED if item.status == RolloutStatus.COMPLETED else Status.FAILED | ||
| rollout_state.reward = {"score": item.reward} | ||
| rollout_state.extra_fields["raw_prompt"] = prompt_text | ||
| rollout_state.extra_fields["agent_artifacts"] = item.artifacts | ||
| if item.error is not None: | ||
| rollout_state.error_msg = f"{item.error.stage}/{item.error.category}: {item.error.message}" | ||
|
|
||
|
|
||
| __all__ = ["AgentInLocalhostLoop", "AgentInLocalhostLoopConfig"] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| """Composable localhost stages.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import time | ||
| from typing import Any | ||
|
|
||
| from lagent.utils import create_object | ||
|
|
||
| from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import ( | ||
| AgentRolloutItem, | ||
| RolloutError, | ||
| StageRecord, | ||
| StageStatus, | ||
| ) | ||
|
|
||
|
|
||
| class LocalhostComposeStage: | ||
| """Compose multiple local validation stages behind ``run(item, record) -> | ||
| float``.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| stages: list[Any], | ||
| *, | ||
| name: str = "validate", | ||
| weight: float = 1.0, | ||
| ): | ||
| if not stages: | ||
| raise ValueError("LocalhostComposeStage.stages is empty") | ||
| self.name = name | ||
| self.stages = [create_object(stage) for stage in stages] | ||
| self.weight = weight | ||
|
|
||
| async def run(self, item: AgentRolloutItem, record: StageRecord) -> float: | ||
| record.status = StageStatus.RUNNING | ||
| record.started_at = record.started_at or time.monotonic() | ||
| try: | ||
| weighted_score = 0.0 | ||
| total_weight = 0.0 | ||
| for stage in self.stages: | ||
| name = getattr(stage, "name", stage.__class__.__name__) | ||
| child_record = item.judgers.setdefault(name, StageRecord()) | ||
| score = float(await stage.run(item, child_record)) | ||
| stage_weight = max(float(getattr(stage, "weight", 1.0)), 0.0) | ||
| weighted_score += score * stage_weight | ||
| total_weight += stage_weight | ||
| record.score = weighted_score / total_weight if total_weight > 0 else 0.0 | ||
| record.status = StageStatus.COMPLETED | ||
| return record.score | ||
| except Exception as exc: | ||
| record.status = StageStatus.FAILED | ||
| child_error = next( | ||
| (child.error for child in item.judgers.values() if child.error is not None), | ||
| None, | ||
| ) | ||
| record.error = ( | ||
| record.error | ||
| or child_error | ||
| or RolloutError( | ||
| stage=self.name, | ||
| category="validate_failed", | ||
| type=type(exc).__name__, | ||
| message=str(exc), | ||
| ) | ||
| ) | ||
| raise | ||
| finally: | ||
| record.finished_at = time.monotonic() | ||
|
|
||
|
|
||
| __all__ = ["LocalhostComposeStage"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| """Localhost judger stages.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import time | ||
| from copy import deepcopy | ||
| from typing import Any | ||
|
|
||
| from lagent.utils import create_object | ||
|
|
||
| from xtuner.v1.data_proto.rl_data import RolloutState, Status | ||
| from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import ( | ||
| AgentRolloutItem, | ||
| RolloutError, | ||
| StageRecord, | ||
| StageStatus, | ||
| ) | ||
| from xtuner.v1.rl.judger.native import Judger | ||
|
|
||
|
|
||
| class LocalhostJudgerStage: | ||
| """Run one local validation stage. | ||
|
|
||
| Public stage interface is ``run(item, record) -> float``. ``RolloutState`` | ||
| is only the internal shape needed to reuse xtuner judgers. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| name: str, | ||
| judger_config: Any, | ||
| reward_key: str = "score", | ||
| weight: float = 1.0, | ||
| ): | ||
| config = create_object(deepcopy(judger_config)) if isinstance(judger_config, dict) else judger_config | ||
| self.name = name | ||
| self.judger: Judger = config.build() | ||
| self.reward_key = reward_key | ||
| self.weight = weight | ||
|
|
||
| async def run(self, item: AgentRolloutItem, record: StageRecord) -> float: | ||
| record.status = StageStatus.RUNNING | ||
| record.started_at = record.started_at or time.monotonic() | ||
| try: | ||
| reward_model = dict(item.reward_model or {}) | ||
|
|
||
| messages = item.artifacts["messages"][-1]["messages"] | ||
| tool_turns = sum( | ||
| 1 for message in messages if isinstance(message.get("tool_calls"), list) and message["tool_calls"] | ||
| ) | ||
| reward_model.setdefault("agent_trace", messages) | ||
| reward_model.setdefault("num_turns", tool_turns) | ||
|
|
||
| response = str(item.artifacts.get("response") or "") | ||
| rollout_state = RolloutState( | ||
| message=[{"role": "user", "content": item.instruction}], | ||
| response=response, | ||
| reward_model=reward_model, | ||
| status=Status.COMPLETED, | ||
| ) | ||
| judged = await self.judger.judge(rollout_state) | ||
| reward_payload = judged.reward or {} | ||
| if self.reward_key not in reward_payload: | ||
| raise KeyError(f"judger reward payload has no {self.reward_key!r}: {reward_payload!r}") | ||
| record.metadata["reward"] = reward_payload | ||
| record.score = float(reward_payload[self.reward_key]) | ||
| record.status = StageStatus.COMPLETED | ||
| return record.score | ||
| except Exception as exc: | ||
| record.status = StageStatus.FAILED | ||
| record.error = record.error or RolloutError( | ||
| stage=self.name, | ||
| category="judger", | ||
| type=type(exc).__name__, | ||
| message=str(exc), | ||
| ) | ||
| raise | ||
| finally: | ||
| record.finished_at = time.monotonic() | ||
|
|
||
|
|
||
| __all__ = ["LocalhostJudgerStage"] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个代码可以考虑放到基类里面去,因为是通用的。否则其他agent 也要写一遍