From 96e9340500ff47acba7addeac5a3eb305adbfc39 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 26 May 2026 08:00:16 +0000 Subject: [PATCH] add reasoning rl of url --- .../rl/agent_loop/single_turn_agent_loop.py | 151 +++++++++++++++++- xtuner/v1/rl/rollout/session_server.py | 32 +++- xtuner/v1/rl/rollout/worker.py | 2 +- 3 files changed, 170 insertions(+), 15 deletions(-) diff --git a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py index 40edc2dca..0b501a847 100644 --- a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py +++ b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py @@ -1,13 +1,27 @@ import asyncio +import os +import traceback +import uuid +from typing import Any -from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status +import httpx + +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status, update_status_from_finish_reason from xtuner.v1.rl.judger import Judger from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout.trace_store import get_store from xtuner.v1.rl.utils import create_task from .agent_loop import AgentLoop, AgentLoopConfig +ROUTED_APIPROXY_BASE_URL = "http://s-20260104203038-22bhb.ailab-evalservice.pjh-service.org.cn/v1" +ROUTED_APIPROXY_API_KEY = "sk-admin" +ROUTED_APIPROXY_TIMEOUT = 3600.0 +ROUTED_APIPROXY_MAX_CONNECTIONS = 512 +ROUTED_APIPROXY_MAX_KEEPALIVE_CONNECTIONS = 128 + + class SingleTurnAgentLoopConfig(AgentLoopConfig): """Configuration for the built-in single-turn agent loop. @@ -38,6 +52,11 @@ class SingleTurnAgentLoopConfig(AgentLoopConfig): """ enable_batch_judge: bool = False + api_base_url: str = ROUTED_APIPROXY_BASE_URL + api_key: str = ROUTED_APIPROXY_API_KEY + api_timeout: float = ROUTED_APIPROXY_TIMEOUT + api_max_connections: int = ROUTED_APIPROXY_MAX_CONNECTIONS + api_max_keepalive_connections: int = ROUTED_APIPROXY_MAX_KEEPALIVE_CONNECTIONS def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "SingleTurnAgentLoop": return SingleTurnAgentLoop( @@ -47,6 +66,11 @@ def build_local(self, rollout_controller, judger: Judger | None = None, logger=N judger=judger, logger=logger, enable_batch_judge=self.enable_batch_judge, + api_base_url=self.api_base_url, + api_key=self.api_key, + api_timeout=self.api_timeout, + api_max_connections=self.api_max_connections, + api_max_keepalive_connections=self.api_max_keepalive_connections, ) @@ -59,28 +83,143 @@ def __init__( judger: Judger | None = None, logger=None, enable_batch_judge: bool = False, + api_base_url: str = ROUTED_APIPROXY_BASE_URL, + api_key: str = ROUTED_APIPROXY_API_KEY, + api_timeout: float = ROUTED_APIPROXY_TIMEOUT, + api_max_connections: int = ROUTED_APIPROXY_MAX_CONNECTIONS, + api_max_keepalive_connections: int = ROUTED_APIPROXY_MAX_KEEPALIVE_CONNECTIONS, ): super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger) self.enable_batch_judge = enable_batch_judge + self.api_base_url = api_base_url.rstrip("/") + self.api_key = api_key + self.api_timeout = api_timeout + self.api_max_connections = api_max_connections + self.api_max_keepalive_connections = api_max_keepalive_connections + self._model_name = os.environ.get("MODEL_NAME") + self._http_client: httpx.AsyncClient | None = None + + def _get_http_client(self) -> httpx.AsyncClient: + if self._http_client is None or self._http_client.is_closed: + timeout = httpx.Timeout(self.api_timeout) + limits = httpx.Limits( + max_connections=self.api_max_connections, + max_keepalive_connections=self.api_max_keepalive_connections, + ) + self._http_client = httpx.AsyncClient(timeout=timeout, limits=limits) + return self._http_client async def generate_sample( self, rollout_state: RolloutState, **kwargs, ) -> RolloutState: - if not rollout_state.tokens: - rollout_state.tokens = rollout_state.prompt_ids + try: + if rollout_state.uid is None: + rollout_state.uid = uuid.uuid4().int + response = await self._chat_completions(rollout_state) + await self._fill_rollout_state_from_response(rollout_state, response) + except Exception as exc: + rollout_state.status = Status.FAILED + rollout_state.finish_reason = "error" + rollout_state.error_msg = f"{type(exc).__name__}: {exc}" + self.logger.error(f"[SingleTurnAgentLoop] failed: {exc}\n{traceback.format_exc()}") + return rollout_state - # 推理引擎generate, 生成的结果会覆盖到 rollout_state.response_ids 上 - rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined] - # 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分 if rollout_state.status != Status.COMPLETED: + # 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分 return rollout_state if self.judger is not None and not self.enable_batch_judge: # 如果开启了批量打分,则在 generate_group 里统一打分,不在这里逐条打分 rollout_state = await self.judger.judge(rollout_state) return rollout_state + def _build_http_payload(self, rollout_state: RolloutState, model_name: str) -> dict[str, Any]: + sample_params = rollout_state.sample_params + payload: dict[str, Any] = { + "model": model_name, + "session_id": str(rollout_state.uid), + "messages": rollout_state.message, + "max_tokens": sample_params.max_tokens, + "temperature": sample_params.temperature, + "top_p": sample_params.top_p, + "extra_body": {"spaces_between_special_tokens": sample_params.spaces_between_special_tokens}, + } + if sample_params.stops: + payload["stop"] = sample_params.stops + if rollout_state.tools is not None: + payload["tools"] = rollout_state.tools + if rollout_state.tool_choice is not None: + payload["tool_choice"] = rollout_state.tool_choice + return payload + + async def _chat_completions(self, rollout_state: RolloutState) -> dict[str, Any]: + model_name = self._model_name + if model_name is None: + raise ValueError("RL_LLM_MODEL environment variable is required for routed API rollout.") + url = f"{self.api_base_url}/chat/completions" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + client = self._get_http_client() + payload = self._build_http_payload(rollout_state, model_name) + response = await client.post(url, headers=headers, json=payload) + try: + response.raise_for_status() + except httpx.HTTPStatusError as exc: + raise RuntimeError(f"HTTP rollout failed: {exc}. response={response.text}") from exc + data = response.json() + if "choices" not in data: + raise RuntimeError(f"HTTP rollout response missing choices: {data}") + return data + + async def _fill_rollout_state_from_response( + self, + rollout_state: RolloutState, + response: dict[str, Any], + ) -> None: + choice = response["choices"][0] + message = choice["message"] + content = message.get("content") + reasoning_content = message.get("reasoning_content") + finish_reason = choice.get("finish_reason") + status = update_status_from_finish_reason(finish_reason) + + rollout_state.response = content if content is not None else reasoning_content + rollout_state.finish_reason = finish_reason + rollout_state.status = status + if status != Status.COMPLETED: + rollout_state.error_msg = f"HTTP rollout finished with status={status.value}, finish_reason={finish_reason}" + return + + messages = [dict(item) for item in rollout_state.message] + messages.append(message) + text = self.tokenizer.apply_chat_template( + messages, + tools=rollout_state.tools, + tokenize=False, + add_generation_prompt=False, + ).rstrip() + + trace_store = get_store() + data = await trace_store.export_training_trace.remote(str(rollout_state.uid), text) + rollout_state.input_ids = data["input_ids"] + rollout_state.labels = data["labels"] + rollout_state.response_ids = [ + token_id + for token_id, label in zip(data["input_ids"][1:], data["labels"][1:]) + if label != -100 + ] + if rollout_state.response is None: + rollout_state.response = self.tokenizer.decode( + rollout_state.response_ids, + skip_special_tokens=True, + ) + rollout_state.response_mask = [1] * len(rollout_state.response_ids) + rollout_state.logprobs = data["logprobs"] + rollout_state.routed_experts = data["routed_experts"] + async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: pending_tasks = [] for state in rollout_state: diff --git a/xtuner/v1/rl/rollout/session_server.py b/xtuner/v1/rl/rollout/session_server.py index 7f430950f..1849c09e9 100644 --- a/xtuner/v1/rl/rollout/session_server.py +++ b/xtuner/v1/rl/rollout/session_server.py @@ -5,7 +5,7 @@ import numpy as np import ray -from aiohttp import ClientSession, web +from aiohttp import ClientSession, ClientTimeout, web from transformers import AutoTokenizer from xtuner.v1.utils import get_logger @@ -187,6 +187,18 @@ async def stop(self): get_logger().info("SessionServer stopped.") async def _handle_request(self, request: web.Request) -> web.Response: + try: + return await self._handle_request_impl(request) + except BaseException: + get_logger().exception( + "SessionServer request failed: remote=%s method=%s path=%s", + request.remote, + request.method, + request.path_qs, + ) + raise + + async def _handle_request_impl(self, request: web.Request) -> web.Response: """Proxy handler for the worker API.""" # Read the request body @@ -210,8 +222,9 @@ async def _handle_request(self, request: web.Request) -> web.Response: request_data = await self.on_request(request_data) # Re-serialize the modified payload back to bytes request_body = json.dumps(request_data).encode("utf-8") - except json.JSONDecodeError: - pass + except json.JSONDecodeError as e: + get_logger().error(f"Failed to parse request body: {request_body} error: {e}") + raise e # Build forwarding headers, dropping original Host forward_headers = dict(request.headers) @@ -260,7 +273,8 @@ def _clean_data(data: dict) -> bool: # read_bufsize controls StreamReader's line buffer limit; SSE events with large # tool_calls/reasoning_content payloads can exceed the 64KB default and trigger # "Chunk too big" from readuntil(b"\n"). - async with ClientSession(read_bufsize=self.read_bufsize) as client: + timeout = ClientTimeout(total=None, sock_connect=30) + async with ClientSession(read_bufsize=self.read_bufsize, timeout=timeout) as client: async with client.request( method=request.method, url=target_url, headers=forward_headers, data=request_body ) as resp: @@ -325,8 +339,9 @@ def _clean_data(data: dict) -> bool: else: try: response_data = json.loads(raw_response) - except json.JSONDecodeError: - pass + except json.JSONDecodeError as e: + get_logger().error(f"Failed to parse response body: {raw_response} error: {e}") + raise e if response_data is not None: for c in response_data.get("choices", []): @@ -367,8 +382,9 @@ def _parse_stream_response(raw: bytes) -> Optional[dict]: if line.startswith("data: ") and line != "data: [DONE]": try: events.append(json.loads(line[6:])) - except json.JSONDecodeError: - pass + except json.JSONDecodeError as e: + get_logger().error(f"Failed to parse stream response body: {line} error: {e}") + raise e if not events: return None diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index 4dabb19f7..56dbad7ae 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -698,7 +698,7 @@ def check_health(self) -> bool: "Authorization": f"Bearer {self.config.api_key}", } response = requests.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers, timeout=5.0 + f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers, timeout=60.0 ) return response.status_code == 200 except requests.RequestException as e: