From 628c389e39ff6a90a1f46aaa05940546981f4bb8 Mon Sep 17 00:00:00 2001 From: nuzant Date: Mon, 20 Apr 2026 13:59:08 +0800 Subject: [PATCH 1/3] feat(service): add external model API support for inference service (#1183) * feat(service): add external model API support for inference service * fix(service): address review feedback for external model API Key changes: - Default model field to "default" and validate non-empty on init - Remove redundant if-else in chat_completion; always set model/api_key - Revert cosmetic variable renames in workflow _run_online - Replace register_external_model with lazy get_or_create_session - Remove is_external_api; set needs_online_callback=True always - Add test for empty model validation * refactor(service): unify session resolution to bearer token auth Remove model-name-based session lookup in set_reward and chat/completions endpoints. All session resolution now uses bearer token, and external model interactions are recorded on the token-resolved session instead of auto-created per-model sessions. Key changes: - Simplify set_reward to use only bearer token auth - Move token extraction before external model dispatch - Remove unused SessionStore.get_or_create_session method - Add docstring to SessionData.export_interactions - Update tests to pass bearer token for external model flows * fix(service): remove unused imports and variable in data proxy Clean up lint issues flagged by Ruff: unused imports (orjson, Body, JSONResponse) and unused local variable in register_model. * fix(examples): fix external model routing and result printing in HITL demo Zeroclaw was not sending the correct model name because _patch_zeroclaw_config did not set default_model, causing requests to miss the registered external model in the data proxy and fall through to the non-existent internal path. The result printing also crashed because concat_string_interactions returns {"interactions": [...]}, not InteractionWithTokenLogpReward objects. 
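To make the printing fix concrete, a minimal sketch of how an external-mode
trajectory is now consumed (`traj` here is illustrative; its shape matches
`concat_string_interactions` introduced below):

```python
# External-mode trajectories are plain dicts of strings, not tensor dicts:
# {"interactions": [{"request": ..., "response": ..., "reward": ...}, ...]}
for item in traj.get("interactions", []):
    print(f"reward={item['reward']!r} response={item['response'][:80]!r}")
```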
Key changes: - Set default_model in zeroclaw config when --model is provided - Rename CLI args from --external-* to --api-url/--provider-api-key/--model - Fix result printing to use traj.get("interactions") dict access - Remove model param from _set_reward and _do_round (unused) - Restore original comments and step numbering --- .../inference_service/controller/config.py | 17 +- .../controller/controller.py | 223 +++-- .../inference_service/controller/workflow.py | 21 +- .../inference_service/data_proxy/__main__.py | 22 + .../inference_service/data_proxy/app.py | 202 ++++- .../inference_service/data_proxy/config.py | 6 + .../inference_service/data_proxy/session.py | 29 +- .../inference_service/gateway/app.py | 144 +++- .../inference_service/gateway/streaming.py | 93 ++ .../inference_service/router/app.py | 120 ++- .../inference_service/router/state.py | 51 ++ areal/experimental/openai/proxy/server.py | 25 +- areal/experimental/openai/types.py | 29 + areal/infra/workflow_executor.py | 19 +- .../experimental/inference_service/README.md | 33 +- .../human_in_the_loop_demo.py | 59 +- .../inference_service/online_rollout.py | 80 +- .../inference_service/test_controller.py | 120 +-- .../test_controller_integration.py | 32 +- .../test_controller_version.py | 1 + .../inference_service/test_external_model.py | 795 ++++++++++++++++++ .../test_external_model_integration.py | 376 +++++++++ .../test_ipv6_entrypoints.py | 4 + .../inference_service/test_online_stack.py | 1 + 24 files changed, 2238 insertions(+), 264 deletions(-) create mode 100644 tests/experimental/inference_service/test_external_model.py create mode 100644 tests/experimental/inference_service/test_external_model_integration.py diff --git a/areal/experimental/inference_service/controller/config.py b/areal/experimental/inference_service/controller/config.py index 2b8832a722..4d45b1391b 100644 --- a/areal/experimental/inference_service/controller/config.py +++ b/areal/experimental/inference_service/controller/config.py @@ -6,8 +6,6 @@ from dataclasses import dataclass, field -from areal.api.cli_args import OpenAIProxyConfig - @dataclass class GatewayControllerConfig: @@ -20,6 +18,7 @@ class GatewayControllerConfig: # -- Model / tokenizer ------------------------------------------------- tokenizer_path: str = "" model_path: str = "" + model: str = "default" # -- Routing ----------------------------------------------------------- routing_strategy: str = "round_robin" @@ -53,5 +52,15 @@ class GatewayControllerConfig: pause_grace_period: float = 0.5 n_gpus_per_node: int | None = None # GPUs per physical node; None = single-node - # -- OpenAI proxy configuration (for agent-like workflows) --------------- - openai: OpenAIProxyConfig = field(default_factory=lambda: OpenAIProxyConfig()) + # -- Admin / workflow -------------------------------------------------- + admin_api_key: str | None = None + turn_discount: float = 1.0 + export_style: str = "individual" + tool_call_parser: str = "qwen" + reasoning_parser: str = "qwen3" + engine_max_tokens: int | None = None + chat_template_type: str = "hf" + + # -- External model API ------------------------------------------------ + api_url: str | None = None + provider_api_key: str | None = None diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index 70bdfba148..7cc546a554 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ 
-82,30 +82,45 @@ def __init__( config: GatewayControllerConfig, scheduler: Scheduler, ) -> None: - from areal.api.alloc_mode import ModelAllocation - + if config.admin_api_key is None: + raise ValueError( + "GatewayControllerConfig.admin_api_key must be set (not None)" + ) + if not config.model: + raise ValueError("GatewayControllerConfig.model must not be empty") self.config = config self.scheduler = scheduler - # Parse allocation from config.backend - self.rollout_alloc = ModelAllocation.from_str(config.backend) + if config.api_url is not None: + self.rollout_alloc = None + else: + from areal.api.alloc_mode import ModelAllocation - # Multi-node: derive nnodes_per_instance from n_gpus_per_node - total_gpus = ( - self.rollout_alloc.parallel.tp_size * self.rollout_alloc.parallel.pp_size - ) - n_gpus_per_node = config.n_gpus_per_node - if n_gpus_per_node is None: + self.rollout_alloc = ModelAllocation.from_str(config.backend) + + # Multi-node: derive nnodes_per_instance from n_gpus_per_node. + # External mode has no local inference servers, so always single-node. + if self.rollout_alloc is None: nnodes_per_instance = 1 else: - if n_gpus_per_node < 1: - raise ValueError(f"n_gpus_per_node must be >= 1, got {n_gpus_per_node}") - if total_gpus % n_gpus_per_node != 0: - raise ValueError( - f"tp_size * pp_size ({total_gpus}) must be divisible " - f"by n_gpus_per_node ({n_gpus_per_node})" - ) - nnodes_per_instance = total_gpus // n_gpus_per_node + total_gpus = ( + self.rollout_alloc.parallel.tp_size + * self.rollout_alloc.parallel.pp_size + ) + n_gpus_per_node = config.n_gpus_per_node + if n_gpus_per_node is None: + nnodes_per_instance = 1 + else: + if n_gpus_per_node < 1: + raise ValueError( + f"n_gpus_per_node must be >= 1, got {n_gpus_per_node}" + ) + if total_gpus % n_gpus_per_node != 0: + raise ValueError( + f"tp_size * pp_size ({total_gpus}) must be divisible " + f"by n_gpus_per_node ({n_gpus_per_node})" + ) + nnodes_per_instance = total_gpus // n_gpus_per_node self._nnodes_per_instance = nnodes_per_instance # Worker management @@ -209,6 +224,19 @@ def initialize( logger.info("GatewayInferenceController initialized (role=%s)", role) + if self.config.model: + self.register_model( + model=self.config.model, + url=self.config.api_url or "", + api_key=self.config.provider_api_key, + ) + if self.external_mode: + logger.info( + "External model mode: url=%s, model=%s", + self.config.api_url, + self.config.model, + ) + async def _async_initialize( self, server_args: dict[str, Any] | None, @@ -226,6 +254,8 @@ async def _async_initialize( * **server_infos is not None** — SGLang servers already exist so we only fork data proxy on every worker; fork router + gateway on worker 0. + * **external_mode** — skip inference servers entirely; data proxies + start with an empty ``--backend-addr``. 
""" from dataclasses import asdict @@ -234,35 +264,49 @@ async def _async_initialize( from areal.api.cli_args import SchedulingSpec, SchedulingStrategy from areal.api.scheduler_api import Job - alloc = self.rollout_alloc - dp_size = alloc.parallel.dp_size cfg = self.config - admin_api_key = self.config.openai.admin_api_key + admin_api_key = self.config.admin_api_key - inf_backend = alloc.backend + if self.external_mode: + dp_size = 1 + inf_backend = None + else: + alloc = self.rollout_alloc + dp_size = alloc.parallel.dp_size + inf_backend = alloc.backend # ================================================================== # Step 0: Create RPCGuard workers (dp_size × nnodes_per_instance) # ================================================================== - inf_spec = SchedulingSpec(**asdict(cfg.scheduling_spec[0])) - instance_size = alloc.parallel.tp_size * alloc.parallel.pp_size - nnodes_per_instance = self._nnodes_per_instance - gpus_per_worker = instance_size // nnodes_per_instance - - if server_infos is not None: - # Pre-existing inference servers — only need dp_size workers - # for CPU services (data proxy, router, gateway), no GPUs. + if self.external_mode: + inf_spec = SchedulingSpec( + task_type="worker", + port_count=2, + gpu=0, + mem=8, + cmd="python -m areal.experimental.inference_service.guard", + ) total_workers = dp_size - inf_spec.gpu = 0 else: - total_workers = dp_size * nnodes_per_instance - inf_spec.cpu *= gpus_per_worker - inf_spec.mem *= gpus_per_worker - if inf_spec.gpu > 0: - inf_spec.gpu = gpus_per_worker + inf_spec = SchedulingSpec(**asdict(cfg.scheduling_spec[0])) + instance_size = alloc.parallel.tp_size * alloc.parallel.pp_size + nnodes_per_instance = self._nnodes_per_instance + gpus_per_worker = instance_size // nnodes_per_instance + + if server_infos is not None: + # Pre-existing inference servers — only need dp_size workers + # for CPU services (data proxy, router, gateway), no GPUs. 
+ total_workers = dp_size + inf_spec.gpu = 0 + else: + total_workers = dp_size * nnodes_per_instance + inf_spec.cpu *= gpus_per_worker + inf_spec.mem *= gpus_per_worker + if inf_spec.gpu > 0: + inf_spec.gpu = gpus_per_worker - # Override cmd to launch RPCGuard instead of RPC server - inf_spec.cmd = "python -m areal.experimental.inference_service.guard" + # Override cmd to launch RPCGuard instead of RPC server + inf_spec.cmd = "python -m areal.experimental.inference_service.guard" inf_role = f"{self._worker_role}{self._INF_SUFFIX}" inf_job = Job( @@ -284,9 +328,11 @@ async def _async_initialize( logger.info("RPCGuard workers ready: %s", [w.id for w in inf_workers]) # ================================================================== - # Step 1: Launch inference servers (skip when pre-existing) + # Step 1: Launch inference servers (skip in external mode or when pre-existing) # ================================================================== - if server_infos is not None: + if self.external_mode: + logger.info("External mode — skipping inference server launch") + elif server_infos is not None: # Pre-existing servers — just record their addresses self.server_infos = server_infos self._inf_addrs = [ @@ -536,22 +582,39 @@ def _build_launch_cmd( str(cfg.set_reward_finish_timeout), "--callback-server-addr", f"http://{self.callback_addr}", + "--tool-call-parser", + cfg.tool_call_parser, + "--reasoning-parser", + cfg.reasoning_parser, + "--chat-template-type", + cfg.chat_template_type, ] + if cfg.engine_max_tokens is not None: + data_proxy_base_cmd += [ + "--engine-max-tokens", + str(cfg.engine_max_tokens), + ] for group_idx in range(dp_size): - head_worker = inf_workers[ - group_idx - if server_infos is not None - else group_idx * nnodes_per_instance - ] + if self.external_mode: + head_worker = inf_workers[group_idx] + else: + head_worker = inf_workers[ + group_idx + if server_infos is not None + else group_idx * nnodes_per_instance + ] guard_addr = f"http://{format_hostport(head_worker.ip, int(head_worker.worker_ports[0]))}" # Each data proxy connects to its group's head inference server - data_proxy_cmd = data_proxy_base_cmd + [ - "--backend-addr", - self._inf_addrs[group_idx], - "--backend-type", - inf_backend or "sglang", - ] + if self.external_mode: + data_proxy_cmd = data_proxy_base_cmd + ["--backend-addr", ""] + else: + data_proxy_cmd = data_proxy_base_cmd + [ + "--backend-addr", + self._inf_addrs[group_idx], + "--backend-type", + inf_backend or "sglang", + ] data_proxy_host, data_proxy_port = self._fork_on_guard( guard_addr=guard_addr, role="data-proxy", @@ -619,7 +682,7 @@ def _register_data_proxies_in_router(self) -> None: resp = requests.post( f"{self._router_addr}/register", json={"worker_addr": data_proxy_addr}, - headers={"Authorization": f"Bearer {self.config.openai.admin_api_key}"}, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, timeout=5, ) resp.raise_for_status() @@ -632,6 +695,34 @@ def _register_data_proxies_in_router(self) -> None: worker_id, ) + def register_model( + self, + model: str, + url: str = "", + api_key: str | None = None, + data_proxy_addrs: list[str] | None = None, + ) -> None: + import requests + + if data_proxy_addrs is None: + data_proxy_addrs = self._data_proxy_addrs + resp = requests.post( + f"{self._gateway_addr}/register_model", + json={ + "model": model, + "url": url, + "api_key": api_key, + "data_proxy_addrs": data_proxy_addrs, + }, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, + 
timeout=self.config.request_timeout, + ) + resp.raise_for_status() + + @property + def external_mode(self) -> bool: + return self.config.api_url is not None + def _start_online_callback_server(self) -> None: """Start callback server used by the router to deliver ready trajectories.""" if self._callback_server is not None: @@ -647,7 +738,7 @@ def _start_online_callback_server(self) -> None: @app.route("/callback/online_ready", methods=["POST"]) def online_ready(): if request.headers.get("Authorization") != ( - f"Bearer {self.config.openai.admin_api_key}" + f"Bearer {self.config.admin_api_key}" ): return jsonify({"error": "Invalid admin API key"}), 403 payload = request.get_json() or {} @@ -1089,10 +1180,11 @@ async def chat_completion( if extra_body and isinstance(extra_body, dict): body.update(extra_body) + body["model"] = self.config.model api_key = ( session_api_key if session_api_key is not None - else self.config.openai.admin_api_key + else self.config.admin_api_key ) url = f"{self._gateway_addr}/chat/completions" headers = { @@ -1261,7 +1353,7 @@ def _wrap_agent(self, agent: Any): "Gateway address is unavailable; initialize the controller first" ) - openai_cfg = self.config.openai + openai_cfg = self.config admin_api_key = openai_cfg.admin_api_key turn_discount = openai_cfg.turn_discount export_style = openai_cfg.export_style @@ -1300,6 +1392,19 @@ def _resolve_workflow( from areal.api.workflow_api import RolloutWorkflow from areal.utils.dynamic_import import import_from_string + # External mode only supports online mode (workflow=None) + if self.external_mode and workflow is not None: + raise ValueError( + "External model mode only supports online mode (workflow=None). " + "Agent-based workflows are not supported with external models." + ) + + if self.external_mode and group_size > 1: + raise ValueError( + "External model mode requires group_size=1, " + f"got group_size={group_size}." 
+ ) + # (a) None → online mode: create InferenceServiceWorkflow without agent if workflow is None: from areal.experimental.inference_service.controller.workflow import ( @@ -1312,7 +1417,7 @@ def _resolve_workflow( controller=self, agent=None, gateway_addr=self._gateway_addr, - admin_api_key=self.config.openai.admin_api_key, + admin_api_key=self.config.admin_api_key, **online_kwargs, ) @@ -1473,7 +1578,7 @@ def _gateway_http_post(self, endpoint: str, payload: dict[str, Any]) -> None: resp = requests.post( url, json=payload, - headers={"Authorization": f"Bearer {self.config.openai.admin_api_key}"}, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, timeout=self.config.request_timeout, ) if resp.status_code >= 400: @@ -1500,9 +1605,7 @@ async def _async_gateway_http_post( resp = await client.post( url, json=payload, - headers={ - "Authorization": f"Bearer {self.config.openai.admin_api_key}" - }, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, ) if resp.status_code >= 400: raise RuntimeError( diff --git a/areal/experimental/inference_service/controller/workflow.py b/areal/experimental/inference_service/controller/workflow.py index 7790edb01c..95f0571770 100644 --- a/areal/experimental/inference_service/controller/workflow.py +++ b/areal/experimental/inference_service/controller/workflow.py @@ -7,6 +7,7 @@ import aiohttp from areal.api.workflow_api import RolloutWorkflow +from areal.experimental.openai.proxy.server import deserialize_interactions from areal.infra import workflow_context from areal.utils import logging, stats_tracker @@ -25,23 +26,6 @@ _EXPORT_TRAJECTORIES_PATHNAME = "export_trajectories" -def _deserialize_interactions( - data: dict[str, Any], -) -> dict[str, InteractionWithTokenLogpReward]: - from areal.experimental.openai.types import InteractionWithTokenLogpReward - from areal.infra.rpc.serialization import deserialize_value - - data = deserialize_value(data) - result: dict[str, InteractionWithTokenLogpReward] = {} - for key, item in data.items(): - interaction = InteractionWithTokenLogpReward() - interaction._cache = item["tensor_dict"] - interaction.reward = item["reward"] - interaction.interaction_id = item["interaction_id"] - result[key] = interaction - return result - - class InferenceServiceWorkflow(RolloutWorkflow): def __init__( self, @@ -110,7 +94,8 @@ async def _export_interactions( async with session.post(url, json=payload, headers=headers) as resp: resp.raise_for_status() data = await resp.json() - return _deserialize_interactions(data["interactions"]) + + return deserialize_interactions(data["interactions"]) async def arun_episode( self, diff --git a/areal/experimental/inference_service/data_proxy/__main__.py b/areal/experimental/inference_service/data_proxy/__main__.py index e36722d8d8..ef2ea87894 100644 --- a/areal/experimental/inference_service/data_proxy/__main__.py +++ b/areal/experimental/inference_service/data_proxy/__main__.py @@ -52,6 +52,24 @@ def main(): "--callback-server-addr", default="", ) + parser.add_argument( + "--tool-call-parser", + default="qwen", + ) + parser.add_argument( + "--reasoning-parser", + default="qwen3", + ) + parser.add_argument( + "--engine-max-tokens", + type=int, + default=None, + ) + parser.add_argument( + "--chat-template-type", + default="hf", + choices=("hf", "concat"), + ) args, _ = parser.parse_known_args() # Resolve the actual serving host (replace 0.0.0.0 with real IP) @@ -73,6 +91,10 @@ def main(): admin_api_key=args.admin_api_key, callback_server_addr=args.callback_server_addr, 
serving_addr=format_hostport(serving_host, args.port), + tool_call_parser=args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + engine_max_tokens=args.engine_max_tokens, + chat_template_type=args.chat_template_type, ) app = create_app(config) uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level) diff --git a/areal/experimental/inference_service/data_proxy/app.py b/areal/experimental/inference_service/data_proxy/app.py index 78d55f7821..947b4a6ebe 100644 --- a/areal/experimental/inference_service/data_proxy/app.py +++ b/areal/experimental/inference_service/data_proxy/app.py @@ -4,16 +4,16 @@ import asyncio import hmac +import json from contextlib import asynccontextmanager from typing import Any import httpx from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.wsgi import WSGIMiddleware +from fastapi.responses import Response as RawResponse from fastapi.responses import StreamingResponse from flask import Flask -from openai.types.chat.completion_create_params import CompletionCreateParams -from pydantic import BaseModel from areal.experimental.inference_service.data_proxy.backend import ( SGLangBridgeBackend, @@ -140,11 +140,16 @@ def _create_inf_bridge( def _create_areal_client( inf_bridge: InfBridge, tok: TokenizerProxy, + config: DataProxyConfig, ) -> ArealOpenAI: """Create an ArealOpenAI client backed by the given InfBridge.""" return ArealOpenAI( engine=inf_bridge, tokenizer=tok._tok, + tool_call_parser=config.tool_call_parser, + reasoning_parser=config.reasoning_parser, + engine_max_tokens=config.engine_max_tokens, + chat_template_type=config.chat_template_type, ) @@ -228,19 +233,11 @@ def create_app(config: DataProxyConfig) -> FastAPI: async def lifespan(app: FastAPI): logger.info( "Data proxy starting — backend=%s, tokenizer=%s", - config.backend_addr, + config.backend_addr or "(none)", config.tokenizer_path, ) - tok = TokenizerProxy(config.tokenizer_path) - pause_state = PauseState() - # InfBridge + ArealOpenAI for /chat/completions - inf_bridge = _create_inf_bridge(config.backend_addr, pause_state, config) - areal_client = _create_areal_client(inf_bridge, tok) - - app.state.tokenizer = tok - app.state.inf_bridge = inf_bridge - app.state.areal_client = areal_client + pause_state = PauseState() app.state.pause_state = pause_state app.state.config = config app.state.session_store = SessionStore( @@ -248,6 +245,19 @@ async def lifespan(app: FastAPI): ) app.state.session_store.set_admin_key(config.admin_api_key) app.state.version = 0 + + if not config.backend_addr: + app.state.tokenizer = None + app.state.inf_bridge = None + app.state.areal_client = None + else: + tok = TokenizerProxy(config.tokenizer_path) + inf_bridge = _create_inf_bridge(config.backend_addr, pause_state, config) + areal_client = _create_areal_client(inf_bridge, tok, config) + app.state.tokenizer = tok + app.state.inf_bridge = inf_bridge + app.state.areal_client = areal_client + ready_task = asyncio.create_task(_ready_trajectory_loop(app)) try: yield @@ -260,6 +270,7 @@ async def lifespan(app: FastAPI): logger.info("Data proxy shutting down") app = FastAPI(title="AReaL Data Proxy", lifespan=lifespan) + _registered_models: dict[str, dict[str, str | None]] = {} # ========================================================================= # Health @@ -287,13 +298,23 @@ async def configure(): @app.post("/pause_generation") async def pause_generation(): - inf_bridge: InfBridge = app.state.inf_bridge + inf_bridge: InfBridge | None = app.state.inf_bridge + if 
inf_bridge is None: + raise HTTPException( + status_code=503, + detail="No inference backend configured (external model mode).", + ) await inf_bridge.pause() return {"status": "ok", "paused": True} @app.post("/continue_generation") async def continue_generation(): - inf_bridge: InfBridge = app.state.inf_bridge + inf_bridge: InfBridge | None = app.state.inf_bridge + if inf_bridge is None: + raise HTTPException( + status_code=503, + detail="No inference backend configured (external model mode).", + ) await inf_bridge.resume() return {"status": "ok", "paused": False} @@ -368,24 +389,137 @@ async def set_reward(body: SetRewardRequest, request: Request): # ========================================================================= @app.post("/chat/completions") - async def chat_completions(body: CompletionCreateParams, request: Request): + async def chat_completions(request: Request): + raw_body = await request.body() + try: + body_json = json.loads(raw_body) + except (json.JSONDecodeError, AttributeError): + raise HTTPException(status_code=400, detail="Invalid JSON body") + + model_name = body_json.get("model") store: SessionStore = app.state.session_store - areal_client: ArealOpenAI = app.state.areal_client token = _try_extract_bearer_token(request) session = _resolve_session_from_token(token, store) if session is not None: session.update_last_access() + + # ----------------------------------------------------------------- + # External model path: model is a registered external model name + # ----------------------------------------------------------------- + ext_info = _registered_models.get(model_name) if model_name else None + if ext_info is not None and ext_info.get("url"): + ext_url = (ext_info["url"] or "").rstrip("/") + ext_model = ext_info["model"] + provider_api_key = ext_info.get("api_key") + + forward_body = dict(body_json) + if ext_model is not None: + forward_body["model"] = ext_model + else: + forward_body.pop("model", None) + + _skip = {"host", "content-length", "transfer-encoding", "authorization"} + forward_headers = { + k: v for k, v in dict(request.headers).items() if k.lower() not in _skip + } + if provider_api_key: + forward_headers["authorization"] = f"Bearer {provider_api_key}" + + is_streaming = forward_body.get("stream", False) or False + messages = body_json.get("messages", []) + + if is_streaming: + collected_chunks: list[str] = [] + + async def _stream_and_cache(): + success = False + try: + async with httpx.AsyncClient( + timeout=httpx.Timeout(config.request_timeout) + ) as client: + async with client.stream( + "POST", + f"{ext_url}/chat/completions", + json=forward_body, + headers=forward_headers, + ) as resp: + if resp.status_code != 200: + error_body = await resp.aread() + yield ( + f"data: {json.dumps({'error': error_body.decode()})}\n\n".encode() + ) + return + async for chunk in resp.aiter_bytes(): + decoded = chunk.decode("utf-8", errors="replace") + collected_chunks.append(decoded) + yield chunk + success = True + except Exception as exc: + logger.error( + "External stream error for %s: %s", model_name, exc + ) + yield f"data: {json.dumps({'error': str(exc)})}\n\n".encode() + finally: + if success and collected_chunks and session is not None: + session.add_string_interaction( + messages, + "".join(collected_chunks), + ) + + return StreamingResponse( + _stream_and_cache(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) + + full_url = f"{ext_url}/chat/completions" + try: + async with 
httpx.AsyncClient(timeout=config.request_timeout) as client: + resp = await client.post( + full_url, + json=forward_body, + headers=forward_headers, + ) + except Exception as exc: + raise HTTPException( + status_code=502, detail=f"External API error: {exc}" + ) + + if resp.status_code != 200: + logger.error( + "External API returned %d for %s: %s", + resp.status_code, + full_url, + resp.text[:500], + ) + + response_str = resp.text + + if resp.status_code == 200 and session is not None: + session.add_string_interaction(messages, response_str) + + return RawResponse( + content=resp.content, + status_code=resp.status_code, + media_type=resp.headers.get("content-type"), + ) + + # ----------------------------------------------------------------- + # Internal model path: use AReaL inference server + # ----------------------------------------------------------------- + areal_client: ArealOpenAI = app.state.areal_client + + if session is not None: areal_cache: Any = session.active_completions else: areal_cache = None # Build kwargs from request body - if isinstance(body, BaseModel): - kwargs = body.model_dump() - else: - kwargs = dict(body) - + kwargs = dict(body_json) # Remove model (ArealOpenAI ignores it) kwargs.pop("model", None) @@ -430,6 +564,19 @@ async def _sse_stream(): return result + @app.post("/register_model") + async def register_model(request: Request): + body = await request.json() + name = body.get("name") or body.get("model") + url = body.get("url", "") + model = body.get("model", name) + api_key = body.get("api_key") + if not name: + raise HTTPException(status_code=400, detail="model name is required") + _registered_models[name] = {"url": url, "model": model, "api_key": api_key} + logger.info("Model registered: name=%s url=%s", name, url or "(internal)") + return {"status": "ok", "name": name} + # ========================================================================= # Trajectory export (admin key required) # ========================================================================= @@ -465,10 +612,13 @@ async def export_trajectories( from areal.infra.rpc.rtensor import RTensor for item in interactions.values(): - # Set the internal cache - item.to_tensor_dict() - # Remotize the tensor dict cache - item._cache = RTensor.remotize(item._cache, node_addr=config.serving_addr) + if item.has_tensor_data: + # Set the internal cache + item.to_tensor_dict() + # Remotize the tensor dict cache + item._cache = RTensor.remotize( + item._cache, node_addr=config.serving_addr + ) # serialize RTensors serialized = serialize_interactions(interactions) @@ -501,7 +651,7 @@ async def configure_backend(request: Request): # Recreate InfBridge + ArealOpenAI with new backend address new_inf_bridge = _create_inf_bridge(new_addr, pause_state, app.state.config) - new_areal_client = _create_areal_client(new_inf_bridge, tok) + new_areal_client = _create_areal_client(new_inf_bridge, tok, app.state.config) # Build updated config copy, then swap all three state fields. # Concurrent requests already hold their own references so they diff --git a/areal/experimental/inference_service/data_proxy/config.py b/areal/experimental/inference_service/data_proxy/config.py index 5299934ecd..44b2f6fe91 100644 --- a/areal/experimental/inference_service/data_proxy/config.py +++ b/areal/experimental/inference_service/data_proxy/config.py @@ -20,3 +20,9 @@ class DataProxyConfig: # Resolved serving address (host:port) used as node_addr for RTensor shards. # Set at startup by __main__.py after the host is resolved. 
serving_addr: str = "" + + # ArealOpenAI client parameters (forwarded from OpenAIProxyConfig) + tool_call_parser: str = "qwen" + reasoning_parser: str = "qwen3" + engine_max_tokens: int | None = None + chat_template_type: str = "hf" diff --git a/areal/experimental/inference_service/data_proxy/session.py b/areal/experimental/inference_service/data_proxy/session.py index b8f79d1bcf..5646104fb8 100644 --- a/areal/experimental/inference_service/data_proxy/session.py +++ b/areal/experimental/inference_service/data_proxy/session.py @@ -7,16 +7,15 @@ import secrets import threading import time +import uuid from collections import OrderedDict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import Any from pydantic import BaseModel from areal.experimental.openai.cache import InteractionCache - -if TYPE_CHECKING: - from areal.experimental.openai.types import InteractionWithTokenLogpReward +from areal.experimental.openai.types import InteractionWithTokenLogpReward # Session timeout for cleanup (1 hour) SESSION_TIMEOUT_SECONDS = 3600 @@ -46,6 +45,7 @@ class SetRewardRequest(BaseModel): interaction_id: str | None = None reward: float + model: str | None = None class ExportTrajectoriesRequest(BaseModel): @@ -61,7 +61,7 @@ class ExportTrajectoriesRequest(BaseModel): class ExportTrajectoriesResponse(BaseModel): """Response containing serialized interactions.""" - interactions: dict[str, Any] + interactions: Any @dataclass(frozen=True) @@ -110,7 +110,11 @@ class SessionData: via repeated ``set_reward`` → ``export_trajectory`` calls. """ - def __init__(self, session_id: str, set_reward_finish_timeout: float = 0.0): + def __init__( + self, + session_id: str, + set_reward_finish_timeout: float = 0.0, + ): self.session_id = session_id self._set_reward_finish_timeout = set_reward_finish_timeout self._last_access_time = time.time() @@ -179,7 +183,7 @@ def _mark_active_trajectory_ready_locked( interaction_id=resolved_interaction_id, completions=completions, created_at=now, - needs_online_callback=self.session_id == "__hitl__", + needs_online_callback=True, ) self._ready_trajectories[trajectory_id] = ready self._active_completions = InteractionCache() @@ -273,6 +277,17 @@ def mark_online_callback_delivered(self, trajectory_id: int) -> bool: ready.callback_delivered = True return True + def add_string_interaction(self, messages: list[dict], response: str) -> str: + interaction_id = str(uuid.uuid4()) + interaction = InteractionWithTokenLogpReward( + messages=messages, + output_message_list=[{"role": "assistant", "content": response}], + ) + interaction._interaction_id = interaction_id + self._active_completions[interaction_id] = interaction + self.update_last_access() + return interaction_id + def export_trajectory( self, discount: float, diff --git a/areal/experimental/inference_service/gateway/app.py b/areal/experimental/inference_service/gateway/app.py index fe6720e671..6619258381 100644 --- a/areal/experimental/inference_service/gateway/app.py +++ b/areal/experimental/inference_service/gateway/app.py @@ -27,8 +27,11 @@ forward_request, forward_sse_stream, grant_capacity_in_router, + list_models_from_router, query_router, + register_model_in_router, register_session_in_router, + remove_model_from_router, resolve_worker_addr, revoke_session_in_router, ) @@ -67,6 +70,18 @@ async def health(): @app.post("/chat/completions") async def chat_completions(request: Request): token = extract_bearer_token(request) + body = await request.body() + headers = 
_forwarding_headers(dict(request.headers)) + + model_name = None + is_streaming = False + try: + body_json = json.loads(body) + model_name = body_json.get("model") + is_streaming = body_json.get("stream", False) or False + except (json.JSONDecodeError, AttributeError): + pass + try: worker_addr = await query_router( config.router_addr, @@ -74,21 +89,11 @@ async def chat_completions(request: Request): "/chat/completions", config.router_timeout, admin_api_key=config.admin_api_key, + model=model_name, ) except (RouterUnreachableError, RouterKeyRejectedError) as exc: return _router_error_response(exc) - body = await request.body() - headers = _forwarding_headers(dict(request.headers)) - - # Detect streaming from request body - is_streaming = False - try: - body_json = json.loads(body) - is_streaming = body_json.get("stream", False) or False - except (json.JSONDecodeError, AttributeError): - pass - if is_streaming: return StreamingResponse( forward_sse_stream( @@ -110,6 +115,72 @@ async def chat_completions(request: Request): media_type=resp.headers.get("content-type"), ) + @app.post("/register_model") + async def register_model(request: Request): + require_admin_key(request, config.admin_api_key) + body = await request.json() + model = body.get("model") + url = body.get("url", "") + api_key = body.get("api_key") + data_proxy_addrs = body.get("data_proxy_addrs", []) + if not model: + return JSONResponse({"error": "model is required"}, status_code=400) + try: + result = await register_model_in_router( + config.router_addr, + model, + url, + api_key, + data_proxy_addrs, + config.admin_api_key, + config.router_timeout, + ) + except (RouterUnreachableError, RouterKeyRejectedError) as exc: + return _router_error_response(exc) + + resolved_addrs = result.get("data_proxy_addrs", data_proxy_addrs) + headers = _forwarding_headers(dict(request.headers)) + + for addr in resolved_addrs: + resp = await forward_request( + f"{addr}/register_model", + json.dumps( + { + "name": model, + "url": url, + "model": model, + "api_key": api_key, + } + ).encode(), + headers, + config.forward_timeout, + ) + if resp.status_code != 200: + await remove_model_from_router( + config.router_addr, + model, + config.admin_api_key, + config.router_timeout, + ) + return JSONResponse( + {"error": f"Data proxy registration failed: {resp.text}"}, + status_code=502, + ) + return result + + @app.get("/models") + async def list_models(request: Request): + require_admin_key(request, config.admin_api_key) + try: + names = await list_models_from_router( + config.router_addr, + config.admin_api_key, + config.router_timeout, + ) + except (RouterUnreachableError, RouterKeyRejectedError) as exc: + return _router_error_response(exc) + return {"models": names} + # ========================================================================= # POST /rl/start_session — admin key ONLY, intercept response # ========================================================================= @@ -176,6 +247,16 @@ async def start_session(request: Request): @app.post("/rl/set_reward") async def set_reward(request: Request): token = extract_bearer_token(request) + body = await request.body() + headers = _forwarding_headers(dict(request.headers)) + + model = None + try: + body_json = json.loads(body) + model = body_json.get("model") + except (json.JSONDecodeError, AttributeError): + pass + try: worker_addr = await query_router( config.router_addr, @@ -183,16 +264,14 @@ async def set_reward(request: Request): "/rl/set_reward", config.router_timeout, 
                admin_api_key=config.admin_api_key,
+                model=model,
            )
        except (RouterUnreachableError, RouterKeyRejectedError) as exc:
            return _router_error_response(exc)

-        body = await request.body()
-        headers = _forwarding_headers(dict(request.headers))
        resp = await forward_request(
            f"{worker_addr}/rl/set_reward", body, headers, config.forward_timeout
        )
-
        return Response(
            content=resp.content,
            status_code=resp.status_code,
@@ -248,24 +327,26 @@ async def continue_generation(worker_id: str, request: Request):
        return {"results": results}

    # =========================================================================
-    # POST /export_trajectories — admin key ONLY, route by session_id
+    # POST /export_trajectories — admin key ONLY, route by session_id or model field
    # =========================================================================
    @app.post("/export_trajectories")
    async def export_trajectories(request: Request):
        require_admin_key(request, config.admin_api_key)
        body = await request.body()

-        # Parse body to extract session_id for routing
        try:
            body_json = json.loads(body)
-            session_id = body_json.get("session_id")
        except (json.JSONDecodeError, AttributeError):
-            return JSONResponse(
-                {"error": "Invalid JSON body or missing session_id"},
-                status_code=400,
-            )
+            return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
+
+        model = body_json.get("model")
+        session_id = body_json.get("session_id")
+
+        if model and not session_id:
+            session_id = model
+            body_json["session_id"] = session_id
+            body = json.dumps(body_json).encode()

        if not session_id:
            return JSONResponse({"error": "session_id is required"}, status_code=400)
@@ -276,6 +357,7 @@
                timeout=config.router_timeout,
                session_id=session_id,
                admin_api_key=config.admin_api_key,
+                model=model,
            )
        except (RouterUnreachableError, RouterKeyRejectedError) as exc:
            return _router_error_response(exc)
@@ -288,9 +370,6 @@
            config.forward_timeout,
        )

-        # Always ask the router to clean up after successful export.
-        # The router itself distinguishes offline one-shot sessions from
-        # persistent online sessions and will keep online bindings intact.
        if resp.status_code == 200:
            await revoke_session_in_router(
                config.router_addr,
@@ -399,4 +478,19 @@ async def grant_capacity(request: Request):
        continue_generation,
        methods=["POST"],
    )
+
+    # =========================================================================
+    # OpenAI / OpenRouter compatibility aliases — /v1/* prefixed routes
+    # =========================================================================
+    app.add_api_route(
+        "/v1/chat/completions",
+        chat_completions,
+        methods=["POST"],
+    )
+    app.add_api_route(
+        "/v1/models",
+        list_models,
+        methods=["GET"],
+    )
+
    return app
diff --git a/areal/experimental/inference_service/gateway/streaming.py b/areal/experimental/inference_service/gateway/streaming.py
index 19b30404cf..3a9cbda77d 100644
--- a/areal/experimental/inference_service/gateway/streaming.py
+++ b/areal/experimental/inference_service/gateway/streaming.py
@@ -38,6 +38,7 @@ async def query_router(
    *,
    session_id: str | None = None,
    admin_api_key: str | None = None,
+    model: str | None = None,
) -> str:
    """Ask the Router for a worker address.
@@ -59,6 +60,8 @@
        Router returned 404 (unknown key / session) or 503 (no healthy workers).
""" payload: dict[str, str] = {} + if model is not None: + payload["model"] = model if session_id is not None: payload["session_id"] = session_id else: @@ -234,6 +237,96 @@ async def get_all_worker_addrs( raise RouterUnreachableError(f"Failed to get workers: {exc}") from exc +async def register_model_in_router( + router_addr: str, + model: str, + url: str, + api_key: str | None, + data_proxy_addrs: list[str], + admin_api_key: str, + timeout: float, +) -> dict: + try: + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post( + f"{router_addr}/register_model", + json={ + "model": model, + "url": url, + "api_key": api_key, + "data_proxy_addrs": data_proxy_addrs, + }, + headers={"Authorization": f"Bearer {admin_api_key}"}, + ) + if resp.status_code == 503: + raise RouterKeyRejectedError("No healthy workers", 503) + resp.raise_for_status() + return resp.json() + except (httpx.ConnectError, httpx.ConnectTimeout) as exc: + raise RouterUnreachableError(f"Router unreachable: {exc}") from exc + + +async def route_external_model( + router_addr: str, + name: str, + admin_api_key: str, + timeout: float, +) -> dict: + try: + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post( + f"{router_addr}/route", + json={"model": name}, + headers={"Authorization": f"Bearer {admin_api_key}"}, + ) + if resp.status_code == 404: + raise RouterKeyRejectedError(f"Model '{name}' not found", 404) + if resp.status_code == 503: + raise RouterKeyRejectedError(f"No healthy workers for model '{name}'", 503) + resp.raise_for_status() + return resp.json() + except RouterKeyRejectedError: + raise + except (httpx.ConnectError, httpx.ConnectTimeout) as exc: + raise RouterUnreachableError(f"Router unreachable: {exc}") from exc + + +async def list_models_from_router( + router_addr: str, + admin_api_key: str, + timeout: float, +) -> list[str]: + try: + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.get( + f"{router_addr}/models", + headers={"Authorization": f"Bearer {admin_api_key}"}, + ) + resp.raise_for_status() + return resp.json().get("models", []) + except (httpx.ConnectError, httpx.ConnectTimeout) as exc: + raise RouterUnreachableError(f"Router unreachable: {exc}") from exc + + +async def remove_model_from_router( + router_addr: str, + name: str, + admin_api_key: str, + timeout: float, +) -> None: + """Remove an external model from the router registry (best-effort rollback).""" + try: + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post( + f"{router_addr}/remove_model", + json={"name": name}, + headers={"Authorization": f"Bearer {admin_api_key}"}, + ) + resp.raise_for_status() + except Exception: + pass # Best-effort rollback; swallow errors + + async def resolve_worker_addr( router_addr: str, admin_api_key: str, diff --git a/areal/experimental/inference_service/router/app.py b/areal/experimental/inference_service/router/app.py index a1ad2de4b2..8ad85a3ad5 100644 --- a/areal/experimental/inference_service/router/app.py +++ b/areal/experimental/inference_service/router/app.py @@ -24,6 +24,7 @@ from areal.experimental.inference_service.router.config import RouterConfig from areal.experimental.inference_service.router.state import ( CapacityManager, + ModelRegistry, SessionRegistry, WorkerRegistry, ) @@ -75,6 +76,7 @@ class RouteRequest(BaseModel): api_key: str | None = None path: str | None = None session_id: str | None = None + model: str | None = None class RegisterSessionRequest(BaseModel): @@ -87,6 
+89,17 @@ class RemoveSessionRequest(BaseModel): session_id: str +class RegisterModelRequest(BaseModel): + model: str + url: str = "" + api_key: str | None = None + data_proxy_addrs: list[str] = [] + + +class RemoveModelRequest(BaseModel): + name: str + + # ============================================================================= # App factory # ============================================================================= @@ -97,6 +110,7 @@ def create_app(config: RouterConfig) -> FastAPI: worker_registry = WorkerRegistry() session_registry = SessionRegistry() + model_registry = ModelRegistry() capacity_manager = CapacityManager() strategy = get_strategy(config.routing_strategy) @@ -127,6 +141,7 @@ async def lifespan(app: FastAPI): poll_task = asyncio.create_task(_poll_workers()) app.state.worker_registry = worker_registry app.state.session_registry = session_registry + app.state.model_registry = model_registry app.state.capacity_manager = capacity_manager app.state.strategy = strategy yield @@ -142,6 +157,7 @@ async def lifespan(app: FastAPI): # Expose registries on app.state for tests that bypass lifespan app.state.worker_registry = worker_registry app.state.session_registry = session_registry + app.state.model_registry = model_registry app.state.capacity_manager = capacity_manager app.state.strategy = strategy @@ -218,12 +234,57 @@ async def unregister(body: UnregisterWorkerRequest, request: Request): @app.post("/route") async def route(body: RouteRequest, request: Request): _require_admin_key(request, config.admin_api_key) - # 0. session_id lookup takes precedence + + # Step A: resolve model → candidate worker addrs + model_addrs: list[str] | None = None + if body.model is not None: + info = await model_registry.get(body.model) + if info is not None: + model_addrs = info.data_proxy_addrs + if model_addrs is None: + first = await model_registry.first() + if first is not None: + model_addrs = first.data_proxy_addrs + + def _filter_healthy(workers: list, addrs: list[str] | None) -> list: + if addrs is None: + return workers + addr_set = set(addrs) + return [w for w in workers if w.worker_addr in addr_set] + + # Step B: session_id lookup if body.session_id is not None: worker = await session_registry.lookup_by_id(body.session_id) - if worker is None: + if worker is not None: + return {"worker_addr": worker} + if model_addrs is None: raise HTTPException(status_code=404, detail="Session not found") - return {"worker_addr": worker} + + # Step C: model-only routing (no api_key/session_id) + if body.api_key is None and model_addrs is not None: + healthy = await worker_registry.get_healthy_workers() + addr_set = set(model_addrs) + healthy = [w for w in healthy if w.worker_addr in addr_set] + if not healthy: + raise HTTPException(status_code=503, detail="No healthy workers") + worker = strategy.pick(healthy) + if worker is None: + raise HTTPException(status_code=503, detail="No healthy workers") + info = ( + await model_registry.get(body.model) + if body.model + else await model_registry.first() + ) + return { + "worker_addr": worker.worker_addr, + "url": info.url if info else "", + "api_key": info.api_key if info else None, + } + + if body.api_key is None and body.model is not None and model_addrs is None: + raise HTTPException( + status_code=404, detail=f"Model '{body.model}' not found" + ) if body.api_key is None: raise HTTPException( @@ -231,10 +292,9 @@ async def route(body: RouteRequest, request: Request): detail="Either 'api_key' or 'session_id' must be provided", ) - # 1. 
Session key → pinned worker
+        # Step D: Session key → pinned worker
        pinned = await session_registry.lookup_by_key(body.api_key)
        if pinned is not None:
-            # Check if pinned worker is healthy
            all_workers = await worker_registry.get_all_workers()
            worker_map = {w.worker_addr: w for w in all_workers}
            w = worker_map.get(pinned)
@@ -242,9 +302,10 @@ async def route(body: RouteRequest, request: Request):
                raise HTTPException(status_code=503, detail="Pinned worker unhealthy")
            return {"worker_addr": pinned}

-        # 2. Admin key → HITL routing (sticky session)
+        # Step E: Admin key → pick from model addrs
        if hmac.compare_digest(body.api_key, config.admin_api_key):
            healthy = await worker_registry.get_healthy_workers()
+            healthy = _filter_healthy(healthy, model_addrs)
            if not healthy:
                raise HTTPException(status_code=503, detail="No healthy workers")
            worker = strategy.pick(healthy)
@@ -257,7 +318,7 @@
            )
            return {"worker_addr": worker.worker_addr}

-        # 3. Unknown key
+        # Step F: Unknown key
        raise HTTPException(status_code=404, detail="Unknown API key")

    # =========================================================================
@@ -331,6 +392,51 @@ async def list_workers(request: Request):
            ]
        }

+    @app.post("/register_model")
+    async def register_model(body: RegisterModelRequest, request: Request):
+        _require_admin_key(request, config.admin_api_key)
+        addrs = body.data_proxy_addrs
+        if not addrs:
+            healthy = await worker_registry.get_healthy_workers()
+            if not healthy:
+                raise HTTPException(status_code=503, detail="No healthy workers")
+            addrs = [w.worker_addr for w in healthy]
+        await model_registry.register(
+            body.model,
+            body.url,
+            body.api_key,
+            addrs,
+        )
+        logger.info(
+            "Model registered: model=%s url=%s data_proxy_addrs=%s",
+            body.model,
+            body.url or "(internal)",
+            addrs,
+        )
+        return {
+            "status": "ok",
+            "model": body.model,
+            "data_proxy_addrs": addrs,
+        }
+
+    @app.get("/models")
+    async def list_models(request: Request):
+        _require_admin_key(request, config.admin_api_key)
+        names = await model_registry.list_names()
+        return {"models": names}
+
+    @app.post("/remove_model")
+    async def remove_model(body: RemoveModelRequest, request: Request):
+        _require_admin_key(request, config.admin_api_key)
+        removed = await model_registry.remove(body.name)
+        if not removed:
+            raise HTTPException(
+                status_code=404,
+                detail=f"External model '{body.name}' not found",
+            )
+        logger.info("External model removed: name=%s", body.name)
+        return {"status": "ok", "name": body.name}
+
    # =========================================================================
    # Worker resolution by ID (admin key required)
    # =========================================================================
diff --git a/areal/experimental/inference_service/router/state.py b/areal/experimental/inference_service/router/state.py
index 026be954e9..eeff5582e9 100644
--- a/areal/experimental/inference_service/router/state.py
+++ b/areal/experimental/inference_service/router/state.py
@@ -207,3 +207,54 @@ async def count(self) -> int:
        """Return the number of registered session keys."""
        async with self._lock:
            return len(self._key_to_worker)
+
+
+@dataclass
+class ModelInfo:
+    """A registered model (internal or external)."""
+
+    name: str
+    url: str  # empty string for internal models
+    api_key: str | None
+    data_proxy_addrs: list[str] = field(default_factory=list)
+
+
+class ModelRegistry:
+    """Thread-safe registry for model routing."""
+
+    def __init__(self) -> None:
+        self._models: 
dict[str, ModelInfo] = {} + self._lock = asyncio.Lock() + + async def register( + self, + name: str, + url: str, + api_key: str | None, + data_proxy_addrs: list[str], + ) -> None: + async with self._lock: + self._models[name] = ModelInfo( + name=name, + url=url, + api_key=api_key, + data_proxy_addrs=data_proxy_addrs, + ) + + async def get(self, name: str) -> ModelInfo | None: + async with self._lock: + return self._models.get(name) + + async def first(self) -> ModelInfo | None: + async with self._lock: + if not self._models: + return None + return next(iter(self._models.values())) + + async def list_names(self) -> list[str]: + async with self._lock: + return list(self._models.keys()) + + async def remove(self, name: str) -> bool: + async with self._lock: + return self._models.pop(name, None) is not None diff --git a/areal/experimental/openai/proxy/server.py b/areal/experimental/openai/proxy/server.py index bed06da5a5..6e160c4644 100644 --- a/areal/experimental/openai/proxy/server.py +++ b/areal/experimental/openai/proxy/server.py @@ -134,11 +134,19 @@ def serialize_interactions( result = {} for key, interaction in interactions.items(): - result[key] = { - "tensor_dict": interaction.to_tensor_dict(), - "reward": interaction.reward, - "interaction_id": interaction.interaction_id, - } + if interaction.has_tensor_data: + result[key] = { + "tensor_dict": interaction.to_tensor_dict(), + "reward": interaction.reward, + "interaction_id": interaction.interaction_id, + } + else: + result[key] = { + "messages": interaction.messages, + "output_message_list": interaction.output_message_list, + "reward": interaction.reward, + "interaction_id": interaction.interaction_id, + } return serialize_value(result) @@ -152,9 +160,12 @@ def deserialize_interactions( data = deserialize_value(data) result = {} for key, item in data.items(): - # Create a minimal InteractionWithTokenLogpReward with cached tensor dict interaction = InteractionWithTokenLogpReward() - interaction._cache = item["tensor_dict"] + if "tensor_dict" in item: + interaction._cache = item["tensor_dict"] + else: + interaction.messages = item["messages"] + interaction.output_message_list = item["output_message_list"] interaction.reward = item["reward"] interaction.interaction_id = item["interaction_id"] result[key] = interaction diff --git a/areal/experimental/openai/types.py b/areal/experimental/openai/types.py index 8af3d01a26..2582c37e3b 100644 --- a/areal/experimental/openai/types.py +++ b/areal/experimental/openai/types.py @@ -57,6 +57,10 @@ class InteractionWithTokenLogpReward: # Interaction ID cache (used for deserialization) _interaction_id: str | None = None + @property + def has_tensor_data(self) -> bool: + return self.model_response is not None or self._cache is not None + @property def is_completion(self) -> bool: return self.completion is not None @@ -200,3 +204,28 @@ def to_tensor_dict(self) -> dict[str, torch.Tensor]: ) self._cache = result return result + + +def concat_string_interactions( + interactions: dict[str, InteractionWithTokenLogpReward], +) -> dict[str, list[dict]]: + """Concat interactions that lack tensor data (e.g. external API mode). + + Returns a dict with an ``"interactions"`` key containing a list of + ``{"request": ..., "response": ..., "reward": ...}`` dicts, one per + interaction. This is the counterpart of + :func:`~areal.utils.data.concat_padded_tensors` for string-only + trajectories. 
+    """
+    return {
+        "interactions": [
+            {
+                "request": v.messages,
+                "response": (
+                    v.output_message_list[0]["content"] if v.output_message_list else ""
+                ),
+                "reward": v.reward,
+            }
+            for v in interactions.values()
+        ]
+    }
diff --git a/areal/infra/workflow_executor.py b/areal/infra/workflow_executor.py
index b5b384e18c..949d2af0b2 100644
--- a/areal/infra/workflow_executor.py
+++ b/areal/infra/workflow_executor.py
@@ -30,7 +30,10 @@
 from .staleness_manager import StalenessManager
 from areal.infra import workflow_context
 from .workflow_context import WorkflowContext
-from areal.experimental.openai.types import InteractionWithTokenLogpReward
+from areal.experimental.openai.types import (
+    InteractionWithTokenLogpReward,
+    concat_string_interactions,
+)
 from areal.utils import logging, perf_tracer, stats_tracker
 from areal.infra.utils.concurrent import get_executor
 from areal.utils.data import concat_padded_tensors, cycle_dataloader
@@ -1067,13 +1070,19 @@ async def _execute_workflow() -> _RolloutResult | None:
                 self._expected_trajectory_keys,
             )

-            # Convert InteractionWithTokenLogpReward to tensor dict if needed
+            # Convert InteractionWithTokenLogpReward to tensor dict if needed.
+            # External-API interactions have no tensor data; fall back to
+            # concat_string_interactions which produces a plain dict of
+            # request/response strings instead of padded tensors.
             if isinstance(traj, dict) and all(
                 isinstance(v, InteractionWithTokenLogpReward) for v in traj.values()
             ):
-                traj = concat_padded_tensors(
-                    [v.to_tensor_dict() for v in traj.values()]
-                )
+                if all(v.has_tensor_data for v in traj.values()):
+                    traj = concat_padded_tensors(
+                        [v.to_tensor_dict() for v in traj.values()]
+                    )
+                else:
+                    traj = concat_string_interactions(traj)

         assert traj is None or isinstance(traj, dict), traj
diff --git a/examples/experimental/inference_service/README.md b/examples/experimental/inference_service/README.md
index 4565015993..1b7c0af21f 100644
--- a/examples/experimental/inference_service/README.md
+++ b/examples/experimental/inference_service/README.md
@@ -113,13 +113,16 @@ python3 examples/experimental/inference_service/human_in_the_loop_demo.py

 Key CLI arguments:

-| Argument            | Default               | Description                                                           |
-| ------------------- | --------------------- | --------------------------------------------------------------------- |
-| `--actor-path`      | `Qwen/Qwen3-0.6B`     | Path to the HuggingFace model weights                                 |
-| `--admin-key`       | `sk-test123456`       | Admin API key (must match `rollout.openai.admin_api_key` in the YAML) |
-| `--request-timeout` | `3600`                | Per-request timeout in seconds                                        |
-| `--gateway-wait`    | `600`                 | Seconds to wait for the gateway to become ready                       |
-| `--question`        | *strawberry question* | Question posed in every HITL round                                    |
+| Argument             | Default               | Description                                                           |
+| -------------------- | --------------------- | --------------------------------------------------------------------- |
+| `--actor-path`       | `Qwen/Qwen3-0.6B`     | Path to the HuggingFace model weights                                 |
+| `--admin-key`        | `sk-test123456`       | Admin API key (must match `rollout.openai.admin_api_key` in the YAML) |
+| `--request-timeout`  | `3600`                | Per-request timeout in seconds                                        |
+| `--gateway-wait`     | `600`                 | Seconds to wait for the gateway to become ready                       |
+| `--question`         | *strawberry question* | Question posed in every HITL round                                    |
+| `--api-url`          | `None`                | External API URL (enables external model mode)                        |
+| `--provider-api-key` | `None`                | API key for the external provider                                     |
+| `--model`            | `None`                | Model name for the gateway controller (sent to the external API)     |
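+Once both services are up, the gateway can be exercised directly. A minimal
+smoke-test sketch (the gateway address below is an assumption — use whatever
+address `online_rollout.py` prints — and the admin key must match `--admin-key`):
+
+```python
+# Sketch: list registered models via the admin endpoint, then chat through
+# the gateway's OpenAI-compatible /v1 aliases.
+import requests
+from openai import OpenAI
+
+GATEWAY_ADDR = "http://127.0.0.1:8080"  # assumption: your gateway address
+ADMIN_KEY = "sk-test123456"             # must match --admin-key
+
+models = requests.get(
+    f"{GATEWAY_ADDR}/models",
+    headers={"Authorization": f"Bearer {ADMIN_KEY}"},
+    timeout=10,
+).json()["models"]
+print("registered models:", models)
+
+client = OpenAI(base_url=f"{GATEWAY_ADDR}/v1", api_key=ADMIN_KEY)
+resp = client.chat.completions.create(
+    model=models[0],
+    messages=[{"role": "user", "content": "How many r's are in strawberry?"}],
+)
+print(resp.choices[0].message.content)
+```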
 
 You can override the model path without editing the script:
 
@@ -128,6 +131,22 @@ python3 examples/experimental/inference_service/human_in_the_loop_demo.py \
   --actor-path /path/to/your/model
 ```
 
+### External Model Mode (optional)
+
+Example 2 can also run HITL with an external OpenAI-compatible provider instead of the
+local rollout model. Pass `--api-url`, `--provider-api-key`, and `--model` to
+`human_in_the_loop_demo.py`; they are forwarded to `online_rollout.py`:
+
+```bash
+python3 examples/experimental/inference_service/human_in_the_loop_demo.py \
+  --api-url https://api.openai.com/v1 \
+  --provider-api-key sk-... \
+  --model gpt-4o
+```
+
+When `--api-url` is set, the controller enables external model mode and routes chat
+traffic through the unified `/chat/completions` + `/export_trajectories` external flow.
+
 ### Running a Manual HITL Session
 
 To drive the rollout interactively instead of using the automated script:
diff --git a/examples/experimental/inference_service/human_in_the_loop_demo.py b/examples/experimental/inference_service/human_in_the_loop_demo.py
index 45485176af..96f08c00a8 100644
--- a/examples/experimental/inference_service/human_in_the_loop_demo.py
+++ b/examples/experimental/inference_service/human_in_the_loop_demo.py
@@ -54,7 +54,12 @@ def _print_header(title: str) -> None:
 # ── Zeroclaw config helpers ────────────────────────────────────────────────
 
 
-def _patch_zeroclaw_config(config_path: Path, gateway_addr: str, api_key: str) -> Path:
+def _patch_zeroclaw_config(
+    config_path: Path,
+    gateway_addr: str,
+    api_key: str,
+    model: str | None = None,
+) -> Path:
     backup = config_path.with_suffix(".demo_bak")
     shutil.copy2(config_path, backup)
 
@@ -75,6 +80,14 @@ def _patch_zeroclaw_config(config_path: Path, gateway_addr: str, api_key: str) -> Path:
     else:
         text = f'api_key = "{api_key}"\n' + text
 
+    if model is not None:
+        text = re.sub(
+            r'^default_model\s*=\s*".*"',
+            f'default_model = "{model}"',
+            text,
+            flags=re.MULTILINE,
+        )
+
     config_path.write_text(text)
     return backup
 
@@ -214,6 +227,21 @@ def main() -> None:
         default=DEFAULT_INFERENCE_BACKEND,
         help="Inference backend used by online_rollout.py",
     )
+    parser.add_argument(
+        "--api-url",
+        default=None,
+        help="External API URL (enables external model mode)",
+    )
+    parser.add_argument(
+        "--provider-api-key",
+        default=None,
+        help="API key for the external provider",
+    )
+    parser.add_argument(
+        "--model",
+        default=None,
+        help="Model name for the gateway controller",
+    )
     args = parser.parse_args()
 
     online_rollout = (
@@ -248,17 +276,24 @@ def cleanup(signum=None, frame=None):
     # ── Step 1: Launch online_rollout.py ──
     _print_header("Step 1: Launch online_rollout.py")
     log_fh = open(rollout_log, "w")
+    rollout_cmd = [
+        sys.executable,
+        str(online_rollout),
+        "--config",
+        str(config_yaml),
+        f"actor.path={args.actor_path}",
+        f"rollout.backend={args.inference_backend}:d1",
+        f"rollout.openai.admin_api_key={args.admin_key}",
+        f"rollout.request_timeout={args.request_timeout}",
+    ]
+    if args.api_url:
+        rollout_cmd.extend(["--api-url", args.api_url])
+    if args.provider_api_key:
+        rollout_cmd.extend(["--provider-api-key", args.provider_api_key])
+    if args.model:
+        rollout_cmd.extend(["--model", args.model])
     rollout_proc = subprocess.Popen(
-        [
-            sys.executable,
-            str(online_rollout),
-            "--config",
-            str(config_yaml),
-            f"actor.path={args.actor_path}",
-            f"rollout.backend={args.inference_backend}:d1",
-            f"rollout.openai.admin_api_key={args.admin_key}",
-            f"rollout.request_timeout={args.request_timeout}",
-        ],
+        rollout_cmd,
stdout=log_fh, stderr=subprocess.STDOUT, cwd=str(REPO_ROOT), @@ -295,7 +330,7 @@ def cleanup(signum=None, frame=None): # ── Step 2: Patch zeroclaw config ── _print_header("Step 2: Update ~/.zeroclaw/config.toml") zeroclaw_backup = _patch_zeroclaw_config( - zeroclaw_config, gateway_addr, args.admin_key + zeroclaw_config, gateway_addr, args.admin_key, model=args.model ) print(" Done.") diff --git a/examples/experimental/inference_service/online_rollout.py b/examples/experimental/inference_service/online_rollout.py index 90f1bd077b..106f8e86d0 100644 --- a/examples/experimental/inference_service/online_rollout.py +++ b/examples/experimental/inference_service/online_rollout.py @@ -2,19 +2,23 @@ from __future__ import annotations +import argparse import sys from dataclasses import asdict from pathlib import Path -import torch - def main(args: list[str]) -> None: repo_root = Path(__file__).resolve().parents[3] if str(repo_root) not in sys.path: sys.path.insert(0, str(repo_root)) - from areal.api.alloc_mode import ModelAllocation + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--api-url", default=None) + parser.add_argument("--provider-api-key", default=None) + parser.add_argument("--model", default=None) + ext_args, remaining = parser.parse_known_args(args) + from areal.api.cli_args import PPOConfig, load_expr_config from areal.experimental.inference_service.controller.config import ( GatewayControllerConfig, @@ -22,13 +26,12 @@ def main(args: list[str]) -> None: from areal.experimental.inference_service.controller.controller import ( GatewayInferenceController, ) - from areal.infra.rpc.rtensor import RTensor from areal.utils import logging from areal.utils.environ import is_single_controller logger = logging.getLogger("InferenceServiceOnlineTrain") - config, _ = load_expr_config(args, PPOConfig) + config, _ = load_expr_config(remaining, PPOConfig) openai_cfg = config.rollout.openai if openai_cfg is None or openai_cfg.mode != "online": raise ValueError( @@ -49,6 +52,8 @@ def main(args: list[str]) -> None: else: raise NotImplementedError(f"Unknown scheduler type: {sched_type}") + is_external = ext_args.api_url is not None + ctrl_config = GatewayControllerConfig( tokenizer_path=config.tokenizer_path, model_path=config.actor.path, @@ -65,15 +70,26 @@ def main(args: list[str]) -> None: scheduling_spec=config.rollout.scheduling_spec, setup_timeout=config.rollout.setup_timeout, request_timeout=config.rollout.request_timeout, - openai=openai_cfg, + admin_api_key=openai_cfg.admin_api_key, + turn_discount=openai_cfg.turn_discount, + export_style=openai_cfg.export_style, ) - rollout_alloc = ModelAllocation.from_str(config.rollout.backend, name="rollout") - if rollout_alloc.backend == "sglang": - server_args = asdict(config.sglang) - elif rollout_alloc.backend == "vllm": - server_args = asdict(config.vllm) + if ext_args.model: + ctrl_config.model = ext_args.model + if is_external: + ctrl_config.api_url = ext_args.api_url + ctrl_config.provider_api_key = ext_args.provider_api_key + server_args = None else: - raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") + from areal.api.alloc_mode import ModelAllocation + + rollout_alloc = ModelAllocation.from_str(config.rollout.backend, name="rollout") + if rollout_alloc.backend == "sglang": + server_args = asdict(config.sglang) + elif rollout_alloc.backend == "vllm": + server_args = asdict(config.vllm) + else: + raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") ctrl = 
GatewayInferenceController(config=ctrl_config, scheduler=scheduler) try: @@ -93,15 +109,37 @@ def main(args: list[str]) -> None: workflow=None, ) - # Localize RTensor references into real torch tensors so we - # can compute aggregate reward statistics. - localized_rewards = [RTensor.localize(traj)["rewards"] for traj in result] - all_rewards = torch.cat(localized_rewards, dim=0) - logger.info( - "Rollout complete (%d trajectories), avg_reward=%.4f", - len(result), - all_rewards.mean().item(), - ) + if is_external: + logger.info("Rollout complete (%d trajectories)", len(result)) + for i, traj in enumerate(result): + for j, interaction in enumerate(traj.get("interactions", [])): + request_msgs = interaction.get("request", []) + request = ( + request_msgs[-1].get("content", "") if request_msgs else "" + ) + response = interaction.get("response", "") + logger.info( + "Trajectory %d, interaction %d:\n" + " request: %s\n response: %s", + i, + j, + request[:300], + response[:300], + ) + else: + import torch + + from areal.infra.rpc.rtensor import RTensor + + # Localize RTensor references into real torch tensors so we + # can compute aggregate reward statistics. + localized_rewards = [RTensor.localize(traj)["rewards"] for traj in result] + all_rewards = torch.cat(localized_rewards, dim=0) + logger.info( + "Rollout complete (%d trajectories), avg_reward=%.4f", + len(result), + all_rewards.mean().item(), + ) finally: ctrl.destroy() scheduler.delete_workers(None) diff --git a/tests/experimental/inference_service/test_controller.py b/tests/experimental/inference_service/test_controller.py index 13b8f43e95..5e61f768f1 100644 --- a/tests/experimental/inference_service/test_controller.py +++ b/tests/experimental/inference_service/test_controller.py @@ -8,7 +8,6 @@ import httpx import pytest -from areal.api.cli_args import OpenAIProxyConfig from areal.experimental.inference_service.controller.config import ( GatewayControllerConfig, ) @@ -27,8 +26,8 @@ class TestGatewayControllerConfig: def test_defaults(self): cfg = GatewayControllerConfig() - assert isinstance(cfg.openai, OpenAIProxyConfig) - assert cfg.openai.admin_api_key == "areal-admin-key" + assert cfg.admin_api_key is None + assert cfg.model == "default" assert cfg.consumer_batch_size == 16 assert cfg.max_concurrent_rollouts is None assert cfg.max_head_offpolicyness == 0 @@ -37,14 +36,13 @@ def test_defaults(self): def test_custom_values(self): cfg = GatewayControllerConfig( - openai=OpenAIProxyConfig(admin_api_key="custom-key"), + admin_api_key="custom-key", consumer_batch_size=32, max_concurrent_rollouts=64, max_head_offpolicyness=5, set_reward_finish_timeout=3.0, ) - assert cfg.openai is not None - assert cfg.openai.admin_api_key == "custom-key" + assert cfg.admin_api_key == "custom-key" assert cfg.consumer_batch_size == 32 assert cfg.max_concurrent_rollouts == 64 assert cfg.max_head_offpolicyness == 5 @@ -71,7 +69,7 @@ def test_dump_to_file_defaults_to_false(self): class TestControllerWorkflowResolution: def test_resolve_workflow_with_instance(self): controller = GatewayInferenceController( - config=GatewayControllerConfig(), + config=GatewayControllerConfig(admin_api_key="test-key"), scheduler=MagicMock(), ) with pytest.raises(TypeError, match=r"callable run\(\) method"): @@ -79,7 +77,7 @@ def test_resolve_workflow_with_instance(self): def test_resolve_workflow_none_creates_online_inference_service_workflow(self): cfg = GatewayControllerConfig( - openai=OpenAIProxyConfig(admin_api_key="test-admin-key") + admin_api_key="test-admin-key", ) 
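+        # admin_api_key is now a plain required field on the config; the
+        # controller validates it at construction (see the construction tests).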
scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) @@ -97,7 +95,7 @@ def test_resolve_workflow_none_creates_online_inference_service_workflow(self): def test_resolve_workflow_agent_class_creates_offline_workflow(self): cfg = GatewayControllerConfig( - openai=OpenAIProxyConfig(admin_api_key="test-admin-key") + admin_api_key="test-admin-key", ) scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) @@ -125,7 +123,7 @@ def test_resolve_should_accept_fn_callable(self): def test_resolve_workflow_with_agent_class(self): """Test _resolve_workflow wraps agent-like classes in InferenceServiceWorkflow.""" - cfg = GatewayControllerConfig() + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) controller._gateway_addr = "http://test:8080" @@ -144,7 +142,7 @@ async def run(self, data, **kwargs): def test_resolve_workflow_agent_class_without_gateway_raises(self): controller = GatewayInferenceController( - config=GatewayControllerConfig(), + config=GatewayControllerConfig(admin_api_key="test-key"), scheduler=MagicMock(), ) @@ -157,7 +155,7 @@ async def run(self, data, **kwargs): def test_resolve_workflow_rollout_workflow_instance_raises(self): controller = GatewayInferenceController( - config=GatewayControllerConfig(), + config=GatewayControllerConfig(admin_api_key="test-key"), scheduler=MagicMock(), ) controller._gateway_addr = "http://test:8080" @@ -175,7 +173,7 @@ def test_resolve_workflow_rollout_workflow_instance_raises(self): def test_resolve_workflow_rollout_workflow_class_raises(self): controller = GatewayInferenceController( - config=GatewayControllerConfig(), + config=GatewayControllerConfig(admin_api_key="test-key"), scheduler=MagicMock(), ) controller._gateway_addr = "http://test:8080" @@ -245,8 +243,18 @@ def test_not_subclass_of_rollout_controller(self): class TestGatewayInferenceControllerConstruction: - def test_constructor(self): + def test_admin_api_key_none_raises(self): cfg = GatewayControllerConfig() + with pytest.raises(ValueError, match="admin_api_key must be set"): + GatewayInferenceController(config=cfg, scheduler=MagicMock()) + + def test_model_empty_raises(self): + cfg = GatewayControllerConfig(admin_api_key="test-key", model="") + with pytest.raises(ValueError, match="model must not be empty"): + GatewayInferenceController(config=cfg, scheduler=MagicMock()) + + def test_constructor(self): + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) @@ -259,15 +267,15 @@ def test_constructor(self): assert controller._worker_ids == {} assert controller.worker_ids == {} - def test_admin_api_key_defaults_from_openai_proxy_config(self): - cfg = GatewayControllerConfig() + def test_admin_api_key_defaults(self): + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) - assert controller.config.openai.admin_api_key == "areal-admin-key" + assert controller.config.admin_api_key == "test-key" def test_version_management_without_services(self): """set_version / get_version work even without gateway services.""" - cfg = GatewayControllerConfig() + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) @@ -276,14 
+284,14 @@ def test_version_management_without_services(self): assert controller.get_version() == 42 def test_export_stats_returns_dict(self): - cfg = GatewayControllerConfig() + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) stats = controller.export_stats() assert isinstance(stats, dict) def test_start_proxy_is_noop(self): - cfg = GatewayControllerConfig() + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) # Should not raise @@ -291,14 +299,14 @@ def test_start_proxy_is_noop(self): controller.start_proxy_gateway() def test_proxy_gateway_addr(self): - cfg = GatewayControllerConfig() + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) # Before initialize, proxy_gateway_addr returns the empty _gateway_addr assert controller.proxy_gateway_addr == "" def test_callback_addr_formats_ipv6_hostport(self): - cfg = GatewayControllerConfig() + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) controller._callback_host = "2001:db8::10" @@ -307,14 +315,14 @@ def test_callback_addr_formats_ipv6_hostport(self): assert controller.callback_addr == "[2001:db8::10]:19000" def test_workflow_executor_raises_before_init(self): - cfg = GatewayControllerConfig() + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) with pytest.raises(RuntimeError, match="initialize"): _ = controller.workflow_executor def test_config_perf_tracer_is_noop(self): - cfg = GatewayControllerConfig() + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) # Should not raise @@ -347,7 +355,7 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy cmd="python -m areal.experimental.inference_service.guard", ), ), - openai=OpenAIProxyConfig(admin_api_key="test-admin-key"), + admin_api_key="test-admin-key", ) controller = GatewayInferenceController(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" @@ -383,10 +391,9 @@ async def test_async_initialize_passes_callback_and_reward_timeout_to_data_proxy class TestGatewayInferenceControllerHTTP: def test_gateway_http_post_raises_on_failure(self): - cfg = GatewayControllerConfig() + cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) - # _gateway_addr points to unreachable host — should raise RuntimeError controller._gateway_addr = "http://127.0.0.1:19999" with pytest.raises(RuntimeError, match="Failed to POST"): controller._gateway_http_post("/test", {"key": "value"}) @@ -398,7 +405,7 @@ def test_gateway_http_post_sends_auth(self, mock_post): mock_post.return_value = mock_resp cfg = GatewayControllerConfig( - openai=OpenAIProxyConfig(admin_api_key="my-secret-key") + admin_api_key="my-secret-key", ) scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) @@ -416,7 +423,7 @@ class TestOnlineCallbackFlow: @pytest.mark.asyncio async def test_online_callback_without_waiter_buffers_export_request(self): cfg = 
GatewayControllerConfig( - openai=OpenAIProxyConfig(admin_api_key="test-admin-key") + admin_api_key="test-admin-key", ) scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) @@ -437,7 +444,7 @@ async def test_online_callback_without_waiter_buffers_export_request(self): @pytest.mark.asyncio async def test_online_callback_settles_waiter_once(self): cfg = GatewayControllerConfig( - openai=OpenAIProxyConfig(admin_api_key="test-admin-key") + admin_api_key="test-admin-key", ) scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) @@ -464,7 +471,7 @@ async def test_online_callback_settles_waiter_once(self): @pytest.mark.asyncio async def test_online_callback_invalid_payload_keeps_waiter_pending(self): cfg = GatewayControllerConfig( - openai=OpenAIProxyConfig(admin_api_key="test-admin-key") + admin_api_key="test-admin-key", ) scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) @@ -491,7 +498,7 @@ async def test_online_callback_invalid_payload_keeps_waiter_pending(self): @pytest.mark.asyncio async def test_cancelled_waiter_buffers_completed_online_result(self): cfg = GatewayControllerConfig( - openai=OpenAIProxyConfig(admin_api_key="test-admin-key") + admin_api_key="test-admin-key", ) scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) @@ -537,9 +544,6 @@ async def test_online_mode_waits_on_controller(self): timeout=3.0, ) workflow._grant_capacity = AsyncMock() - workflow._export_interactions = AsyncMock( - return_value={"chatcmpl-1": mock_interaction} - ) with ( patch( @@ -548,8 +552,27 @@ async def test_online_mode_waits_on_controller(self): patch( "areal.experimental.inference_service.controller.workflow.stats_tracker" ) as mock_st, + patch( + "areal.experimental.inference_service.controller.workflow.deserialize_interactions" + ) as mock_deserialize, ): - mock_http_session = AsyncMock() + mock_deserialize.return_value = {"chatcmpl-1": mock_interaction} + + # _run_online uses ``async with http_session.post(...)`` directly, + # so the mock must support the async context-manager protocol. 
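+            # ``post`` itself must be a regular synchronous call that returns
+            # the context manager, hence MagicMock rather than AsyncMock below.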
+ mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = AsyncMock( + return_value={"interactions": {"chatcmpl-1": {}}} + ) + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_response) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + mock_http_session = MagicMock() + mock_http_session.post = MagicMock(return_value=mock_cm) + mock_wf_ctx.get_aiohttp_session = AsyncMock(return_value=mock_http_session) mock_wf_ctx.stat_scope.return_value = "rollout" mock_st.get.return_value = MagicMock() @@ -560,11 +583,8 @@ async def test_online_mode_waits_on_controller(self): assert "chatcmpl-1" in result workflow._grant_capacity.assert_awaited_once() controller.wait_for_online_trajectory.assert_awaited_once_with(timeout=3.0) - workflow._export_interactions.assert_awaited_once_with( - mock_http_session, - "sess-1", - trajectory_id=7, - ) + mock_http_session.post.assert_called_once() + mock_deserialize.assert_called_once_with({"chatcmpl-1": {}}) @pytest.mark.asyncio async def test_offline_mode_runs_agent(self): @@ -632,23 +652,29 @@ def test_n_gpus_per_node_custom(self): assert cfg.n_gpus_per_node == 4 def test_n_gpus_per_node_zero_raises(self): - cfg = GatewayControllerConfig(n_gpus_per_node=0, backend="sglang:d1t8") + cfg = GatewayControllerConfig( + n_gpus_per_node=0, backend="sglang:d1t8", admin_api_key="test-key" + ) with pytest.raises(ValueError, match="n_gpus_per_node must be >= 1"): GatewayInferenceController(config=cfg, scheduler=MagicMock()) def test_gpus_not_divisible_raises(self): - cfg = GatewayControllerConfig(n_gpus_per_node=3, backend="sglang:d1t8") + cfg = GatewayControllerConfig( + n_gpus_per_node=3, backend="sglang:d1t8", admin_api_key="test-key" + ) with pytest.raises(ValueError, match="must be divisible by n_gpus_per_node"): GatewayInferenceController(config=cfg, scheduler=MagicMock()) def test_single_node_backward_compat(self): - cfg = GatewayControllerConfig(backend="sglang:d2t4") + cfg = GatewayControllerConfig(backend="sglang:d2t4", admin_api_key="test-key") controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) assert controller._nnodes_per_instance == 1 def test_multi_node_valid_config(self): # tp=16, n_gpus_per_node=8 → nnodes_per_instance=2 - cfg = GatewayControllerConfig(n_gpus_per_node=8, backend="sglang:d1t16") + cfg = GatewayControllerConfig( + n_gpus_per_node=8, backend="sglang:d1t16", admin_api_key="test-key" + ) controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) assert controller._nnodes_per_instance == 2 @@ -677,7 +703,7 @@ async def test_async_initialize_multinode_worker_count(self): backend="sglang:d1t8", n_gpus_per_node=4, scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), - openai=OpenAIProxyConfig(admin_api_key="test-key"), + admin_api_key="test-key", ) controller = GatewayInferenceController(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" @@ -735,7 +761,7 @@ async def test_async_initialize_multinode_fork_path(self): backend="sglang:d1t8", n_gpus_per_node=4, scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), - openai=OpenAIProxyConfig(admin_api_key="test-key"), + admin_api_key="test-key", ) controller = GatewayInferenceController(config=cfg, scheduler=scheduler) controller._callback_host = "127.0.0.1" diff --git a/tests/experimental/inference_service/test_controller_integration.py b/tests/experimental/inference_service/test_controller_integration.py index a4bc99ae40..cdaa0b1868 100644 --- 
a/tests/experimental/inference_service/test_controller_integration.py +++ b/tests/experimental/inference_service/test_controller_integration.py @@ -121,15 +121,11 @@ def _make_gateway_controller_config( online_mode: bool = False, set_reward_finish_timeout: float = 0.0, ): - from areal.api.cli_args import OpenAIProxyConfig, SchedulingSpec + from areal.api.cli_args import SchedulingSpec from areal.experimental.inference_service.controller.config import ( GatewayControllerConfig, ) - openai_cfg = OpenAIProxyConfig(admin_api_key="test-admin") - if online_mode: - openai_cfg = OpenAIProxyConfig(mode="online", admin_api_key="test-admin") - return GatewayControllerConfig( tokenizer_path=model_path, model_path=model_path, @@ -145,7 +141,7 @@ def _make_gateway_controller_config( consumer_batch_size=8, max_head_offpolicyness=1024, setup_timeout=180.0, - openai=openai_cfg, + admin_api_key="test-admin", ) @@ -630,8 +626,8 @@ def test_online_workflow_submit_wait_roundtrip(self, gateway_controller_online): import requests gateway_url = gateway_controller_online.proxy_gateway_addr - assert gateway_controller_online.config.openai is not None - admin_key = gateway_controller_online.config.openai.admin_api_key + assert gateway_controller_online.config.admin_api_key is not None + admin_key = gateway_controller_online.config.admin_api_key task_id = gateway_controller_online.submit( data={}, @@ -687,8 +683,8 @@ def test_offline_export_applies_discount_after_multiple_rewards_in_same_trajecto self, gateway_controller_with_reward_timeout ): gateway_url = gateway_controller_with_reward_timeout.proxy_gateway_addr - assert gateway_controller_with_reward_timeout.config.openai is not None - admin_key = gateway_controller_with_reward_timeout.config.openai.admin_api_key + assert gateway_controller_with_reward_timeout.config.admin_api_key is not None + admin_key = gateway_controller_with_reward_timeout.config.admin_api_key grant_resp = httpx.post( f"{gateway_url}/grant_capacity", @@ -790,7 +786,7 @@ def gateway_controller_full_init(model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import OpenAIProxyConfig, SchedulingSpec + from areal.api.cli_args import SchedulingSpec from areal.experimental.inference_service.controller.config import ( GatewayControllerConfig, ) @@ -810,7 +806,7 @@ def gateway_controller_full_init(model_path, tmp_path_factory): consumer_batch_size=8, max_head_offpolicyness=1024, setup_timeout=300.0, - openai=OpenAIProxyConfig(admin_api_key="test-admin"), + admin_api_key="test-admin", ) server_args = { @@ -1079,7 +1075,7 @@ def gateway_controller_full_init_vllm(model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import OpenAIProxyConfig, SchedulingSpec + from areal.api.cli_args import SchedulingSpec from areal.experimental.inference_service.controller.config import ( GatewayControllerConfig, ) @@ -1099,7 +1095,7 @@ def gateway_controller_full_init_vllm(model_path, tmp_path_factory): consumer_batch_size=8, max_head_offpolicyness=1024, setup_timeout=300.0, - openai=OpenAIProxyConfig(admin_api_key="test-admin"), + admin_api_key="test-admin", ) server_args = { @@ -1284,7 +1280,7 @@ def gateway_controller_full_init_vlm_sglang(vlm_model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import OpenAIProxyConfig, SchedulingSpec + from areal.api.cli_args import SchedulingSpec from areal.experimental.inference_service.controller.config import ( 
GatewayControllerConfig, ) @@ -1304,7 +1300,7 @@ def gateway_controller_full_init_vlm_sglang(vlm_model_path, tmp_path_factory): consumer_batch_size=8, max_head_offpolicyness=1024, setup_timeout=300.0, - openai=OpenAIProxyConfig(admin_api_key="test-admin"), + admin_api_key="test-admin", ) local_scheduler = _make_local_scheduler( @@ -1328,7 +1324,7 @@ def gateway_controller_full_init_vlm_vllm(vlm_model_path, tmp_path_factory): if not has_gpu(): pytest.skip("GPU required") - from areal.api.cli_args import OpenAIProxyConfig, SchedulingSpec + from areal.api.cli_args import SchedulingSpec from areal.experimental.inference_service.controller.config import ( GatewayControllerConfig, ) @@ -1348,7 +1344,7 @@ def gateway_controller_full_init_vlm_vllm(vlm_model_path, tmp_path_factory): consumer_batch_size=8, max_head_offpolicyness=1024, setup_timeout=300.0, - openai=OpenAIProxyConfig(admin_api_key="test-admin"), + admin_api_key="test-admin", ) local_scheduler = _make_local_scheduler( diff --git a/tests/experimental/inference_service/test_controller_version.py b/tests/experimental/inference_service/test_controller_version.py index a845e120bc..b2d90c3a5d 100644 --- a/tests/experimental/inference_service/test_controller_version.py +++ b/tests/experimental/inference_service/test_controller_version.py @@ -31,6 +31,7 @@ def _make_controller( Does NOT call initialize() — internal fields are set directly. """ cfg = GatewayControllerConfig( + admin_api_key="test-key", scheduling_spec=(SchedulingSpec(),), ) scheduler = MagicMock() diff --git a/tests/experimental/inference_service/test_external_model.py b/tests/experimental/inference_service/test_external_model.py new file mode 100644 index 0000000000..bef88ac7ec --- /dev/null +++ b/tests/experimental/inference_service/test_external_model.py @@ -0,0 +1,795 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +import pytest_asyncio + +from areal.experimental.inference_service.data_proxy.app import ( + create_app as create_data_proxy_app, +) +from areal.experimental.inference_service.data_proxy.config import DataProxyConfig +from areal.experimental.inference_service.data_proxy.session import SessionStore +from areal.experimental.inference_service.gateway.app import ( + create_app as create_gateway_app, +) +from areal.experimental.inference_service.gateway.config import GatewayConfig +from areal.experimental.inference_service.gateway.streaming import ( + RouterKeyRejectedError, +) +from areal.experimental.inference_service.router.app import ( + create_app as create_router_app, +) +from areal.experimental.inference_service.router.config import RouterConfig +from areal.experimental.inference_service.router.state import ModelRegistry + +ADMIN_KEY = "test-admin-key" +SESSION_KEY = "session-key-abc123" +WORKER_ADDR = "http://worker-1:18082" +ROUTER_MODULE = "areal.experimental.inference_service.gateway.app" + + +def admin_headers() -> dict[str, str]: + return {"Authorization": f"Bearer {ADMIN_KEY}"} + + +def session_headers() -> dict[str, str]: + return {"Authorization": f"Bearer {SESSION_KEY}"} + + +@pytest.fixture +def router_config() -> RouterConfig: + return RouterConfig( + host="127.0.0.1", + port=18081, + admin_api_key=ADMIN_KEY, + poll_interval=999, + routing_strategy="round_robin", + ) + + +@pytest_asyncio.fixture +async def router_client(router_config: RouterConfig): + app = create_router_app(router_config) + transport = 
httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +class TestModelRegistry: + @pytest.mark.asyncio + async def test_model_registry_register_get_list_remove(self): + reg = ModelRegistry() + await reg.register("ext-a", "http://api", None, [WORKER_ADDR]) + + info = await reg.get("ext-a") + assert info is not None + assert info.name == "ext-a" + assert info.url == "http://api" + assert info.data_proxy_addrs == [WORKER_ADDR] + + names = await reg.list_names() + assert names == ["ext-a"] + + removed = await reg.remove("ext-a") + assert removed is True + assert await reg.get("ext-a") is None + + +class TestRouterExternalEndpoints: + @pytest.mark.asyncio + async def test_register_model_success(self, router_client): + await router_client.post( + "/register", + json={"worker_addr": WORKER_ADDR}, + headers=admin_headers(), + ) + + resp = await router_client.post( + "/register_model", + json={ + "model": "ext-1", + "url": "http://ext-api", + "data_proxy_addrs": [WORKER_ADDR], + }, + headers=admin_headers(), + ) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["model"] == "ext-1" + assert data["data_proxy_addrs"] == [WORKER_ADDR] + + @pytest.mark.asyncio + async def test_register_model_no_workers_503(self, router_client): + resp = await router_client.post( + "/register_model", + json={"model": "ext-1", "url": "http://ext-api"}, + headers=admin_headers(), + ) + assert resp.status_code == 503 + + @pytest.mark.asyncio + async def test_register_model_no_auth_401(self, router_client): + resp = await router_client.post( + "/register_model", + json={"model": "ext-1", "url": "http://ext-api"}, + ) + assert resp.status_code == 401 + + @pytest.mark.asyncio + async def test_route_model_success(self, router_client): + await router_client.post( + "/register", + json={"worker_addr": WORKER_ADDR}, + headers=admin_headers(), + ) + await router_client.post( + "/register_model", + json={ + "model": "ext-1", + "url": "http://ext-api", + "data_proxy_addrs": [WORKER_ADDR], + }, + headers=admin_headers(), + ) + + resp = await router_client.post( + "/route", + json={"model": "ext-1"}, + headers=admin_headers(), + ) + assert resp.status_code == 200 + data = resp.json() + assert data["worker_addr"] == WORKER_ADDR + assert data["url"] == "http://ext-api" + + @pytest.mark.asyncio + async def test_route_model_not_found_404(self, router_client): + resp = await router_client.post( + "/route", + json={"model": "nope"}, + headers=admin_headers(), + ) + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_list_models_empty(self, router_client): + resp = await router_client.get("/models", headers=admin_headers()) + assert resp.status_code == 200 + assert resp.json()["models"] == [] + + @pytest.mark.asyncio + async def test_list_models_after_registration(self, router_client): + await router_client.post( + "/register", + json={"worker_addr": WORKER_ADDR}, + headers=admin_headers(), + ) + await router_client.post( + "/register_model", + json={ + "model": "ext-1", + "url": "http://ext-api", + "data_proxy_addrs": [WORKER_ADDR], + }, + headers=admin_headers(), + ) + + resp = await router_client.get("/models", headers=admin_headers()) + assert resp.status_code == 200 + assert resp.json()["models"] == ["ext-1"] + + +@pytest.fixture +def gateway_config() -> GatewayConfig: + return GatewayConfig( + host="127.0.0.1", + port=18080, + admin_api_key=ADMIN_KEY, + router_addr="http://mock-router:8081", + 
router_timeout=2.0, + forward_timeout=30.0, + ) + + +@pytest_asyncio.fixture +async def gateway_client(gateway_config: GatewayConfig): + app = create_gateway_app(gateway_config) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +class TestGatewayExternalEndpoints: + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.forward_request", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.register_model_in_router", new_callable=AsyncMock) + async def test_register_model_gateway_full_flow( + self, + mock_register_model, + mock_forward, + gateway_client, + ): + mock_register_model.return_value = { + "status": "ok", + "model": "ext-1", + "data_proxy_addrs": [WORKER_ADDR], + } + mock_forward.return_value = httpx.Response(200, json={"status": "ok"}) + + resp = await gateway_client.post( + "/register_model", + json={"model": "ext-1", "url": "http://ext-api"}, + headers=admin_headers(), + ) + assert resp.status_code == 200 + assert resp.json()["data_proxy_addrs"] == [WORKER_ADDR] + mock_register_model.assert_called_once() + mock_forward.assert_called_once() + + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.forward_request", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.query_router", new_callable=AsyncMock) + async def test_chat_completions_external_model( + self, + mock_query_router, + mock_forward, + gateway_client, + ): + mock_query_router.return_value = WORKER_ADDR + mock_forward.return_value = httpx.Response(200, json={"id": "ext-chat-1"}) + + resp = await gateway_client.post( + "/chat/completions", + json={"model": "ext-1", "messages": [{"role": "user", "content": "hi"}]}, + headers=admin_headers(), + ) + assert resp.status_code == 200 + assert resp.json()["id"] == "ext-chat-1" + assert "/chat/completions" in mock_forward.call_args.args[0] + + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.forward_sse_stream") + @patch(f"{ROUTER_MODULE}.query_router", new_callable=AsyncMock) + async def test_chat_completions_external_model_streaming( + self, + mock_query_router, + mock_forward_sse, + gateway_client, + ): + mock_query_router.return_value = WORKER_ADDR + + async def _stream() -> AsyncGenerator[bytes, None]: + yield b"data: hello\n\n" + yield b"data: [DONE]\n\n" + + mock_forward_sse.return_value = _stream() + + resp = await gateway_client.post( + "/chat/completions", + json={ + "model": "ext-1", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + headers=admin_headers(), + ) + assert resp.status_code == 200 + assert "text/event-stream" in resp.headers["content-type"] + + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.forward_request", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.query_router", new_callable=AsyncMock) + async def test_chat_completions_unregistered_model_falls_back( + self, + mock_query_router, + mock_forward, + gateway_client, + ): + mock_query_router.return_value = WORKER_ADDR + mock_forward.return_value = httpx.Response(200, json={"id": "internal-chat"}) + + resp = await gateway_client.post( + "/chat/completions", + json={ + "model": "missing", + "messages": [{"role": "user", "content": "hi"}], + }, + headers=session_headers(), + ) + assert resp.status_code == 200 + assert resp.json()["id"] == "internal-chat" + assert "/chat/completions" in mock_forward.call_args.args[0] + + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.forward_request", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.query_router", new_callable=AsyncMock) + async def 
test_chat_completions_no_model_internal_path( + self, + mock_query_router, + mock_forward, + gateway_client, + ): + mock_query_router.return_value = WORKER_ADDR + mock_forward.return_value = httpx.Response(200, json={"id": "internal-chat"}) + + resp = await gateway_client.post( + "/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}]}, + headers=session_headers(), + ) + assert resp.status_code == 200 + assert resp.json()["id"] == "internal-chat" + + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.list_models_from_router", new_callable=AsyncMock) + async def test_list_models_gateway(self, mock_list_models, gateway_client): + mock_list_models.return_value = ["ext-1", "ext-2"] + + resp = await gateway_client.get("/models", headers=admin_headers()) + assert resp.status_code == 200 + assert resp.json()["models"] == ["ext-1", "ext-2"] + + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.forward_request", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.query_router", new_callable=AsyncMock) + async def test_export_trajectories_routes_external_by_session_id( + self, + mock_query_router, + mock_forward, + gateway_client, + ): + mock_query_router.return_value = WORKER_ADDR + mock_forward.return_value = httpx.Response( + 200, + json={ + "interactions": {"id-1": {"messages": [], "reward": 0.0}}, + }, + ) + + resp = await gateway_client.post( + "/export_trajectories", + json={"session_id": "ext-1"}, + headers=admin_headers(), + ) + assert resp.status_code == 200 + assert "/export_trajectories" in mock_forward.call_args.args[0] + + +@pytest.fixture +def data_proxy_config() -> DataProxyConfig: + return DataProxyConfig( + host="127.0.0.1", + port=18082, + backend_addr="http://mock-sglang:30000", + tokenizer_path="mock-tokenizer", + request_timeout=10.0, + ) + + +@pytest.fixture +def mock_tokenizer(): + tok = MagicMock() + tok._tok = MagicMock() + tok._tok.eos_token_id = 2 + tok._tok.pad_token_id = 0 + return tok + + +@pytest.fixture +def mock_areal_client(): + from openai.types.chat import ChatCompletion, ChatCompletionMessage + from openai.types.chat.chat_completion import Choice + from openai.types.completion_usage import CompletionUsage + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=ChatCompletion( + id="chatcmpl-mock", + choices=[ + Choice( + finish_reason="stop", + index=0, + logprobs=None, + message=ChatCompletionMessage(content="Hello!", role="assistant"), + ) + ], + created=1234567890, + model="sglang", + object="chat.completion", + usage=CompletionUsage(completion_tokens=3, prompt_tokens=5, total_tokens=8), + ) + ) + return mock_client + + +@pytest_asyncio.fixture +async def data_proxy_client(data_proxy_config, mock_tokenizer, mock_areal_client): + from areal.experimental.inference_service.data_proxy.backend import ( + SGLangBridgeBackend, + ) + from areal.experimental.inference_service.data_proxy.inf_bridge import InfBridge + from areal.experimental.inference_service.data_proxy.pause import PauseState + + app = create_data_proxy_app(data_proxy_config) + pause_state = PauseState() + inf_bridge = InfBridge( + backend=SGLangBridgeBackend(), + backend_addr=data_proxy_config.backend_addr, + pause_state=pause_state, + request_timeout=data_proxy_config.request_timeout, + max_resubmit_retries=5, + resubmit_wait=0.01, + ) + app.state.tokenizer = mock_tokenizer + app.state.inf_bridge = inf_bridge + app.state.areal_client = mock_areal_client + app.state.pause_state = pause_state + app.state.config = data_proxy_config + store = 
SessionStore() + store.set_admin_key(data_proxy_config.admin_api_key) + app.state.session_store = store + app.state.version = 0 + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +class TestDataProxyExternalEndpoints: + @pytest.mark.asyncio + async def test_register_external_model(self, data_proxy_client): + resp = await data_proxy_client.post( + "/register_model", + json={"name": "ext-1", "url": "http://ext-api", "model": "gpt-4o"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + @pytest.mark.asyncio + async def test_external_chat_completions_non_streaming( + self, + data_proxy_client, + monkeypatch, + ): + await data_proxy_client.post( + "/register_model", + json={"name": "ext-1", "url": "http://ext-api", "model": "gpt-4o"}, + ) + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, json=None, headers=None): + assert url == "http://ext-api/chat/completions" + assert json["model"] == "gpt-4o" + return httpx.Response(200, json={"id": "ext-1-response"}) + + monkeypatch.setattr( + "areal.experimental.inference_service.data_proxy.app.httpx.AsyncClient", + _FakeClient, + ) + + resp = await data_proxy_client.post( + "/chat/completions", + json={"model": "ext-1", "messages": [{"role": "user", "content": "hi"}]}, + headers={"Authorization": "Bearer areal-admin-key"}, + ) + assert resp.status_code == 200 + assert resp.json()["id"] == "ext-1-response" + + not_ready = await data_proxy_client.post( + "/export_trajectories", + json={"session_id": "__hitl__"}, + headers={"Authorization": "Bearer areal-admin-key"}, + ) + assert not_ready.status_code == 409 + + set_reward = await data_proxy_client.post( + "/rl/set_reward", + json={"reward": 1.0}, + headers={"Authorization": "Bearer areal-admin-key"}, + ) + assert set_reward.status_code == 200 + assert set_reward.json()["trajectory_ready"] is True + + exported = await data_proxy_client.post( + "/export_trajectories", + json={"session_id": "__hitl__"}, + headers={"Authorization": "Bearer areal-admin-key"}, + ) + assert exported.status_code == 200 + payload = exported.json() + assert len(payload["interactions"]) == 1 + + @pytest.mark.asyncio + async def test_external_chat_completions_streaming( + self, + data_proxy_client, + monkeypatch, + ): + await data_proxy_client.post( + "/register_model", + json={"name": "ext-1", "url": "http://ext-api", "model": "gpt-4o"}, + ) + + class _FakeStreamResponse: + status_code = 200 + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def aread(self): + return b"" + + async def aiter_bytes(self): + yield b"data: chunk-1\n\n" + yield b"data: [DONE]\n\n" + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def stream(self, method, url, json=None, headers=None): + assert method == "POST" + assert url == "http://ext-api/chat/completions" + assert json["model"] == "gpt-4o" + return _FakeStreamResponse() + + monkeypatch.setattr( + "areal.experimental.inference_service.data_proxy.app.httpx.AsyncClient", + _FakeClient, + ) + + resp = await data_proxy_client.post( + "/chat/completions", + json={ + "model": "ext-1", + "messages": [{"role": "user", "content": "hi"}], 
+ "stream": True, + }, + headers={"Authorization": "Bearer areal-admin-key"}, + ) + assert resp.status_code == 200 + assert "text/event-stream" in resp.headers["content-type"] + + set_reward = await data_proxy_client.post( + "/rl/set_reward", + json={"reward": 1.0}, + headers={"Authorization": "Bearer areal-admin-key"}, + ) + assert set_reward.status_code == 200 + assert set_reward.json()["trajectory_ready"] is True + + exported = await data_proxy_client.post( + "/export_trajectories", + json={"session_id": "__hitl__"}, + headers={"Authorization": "Bearer areal-admin-key"}, + ) + assert exported.status_code == 200 + payload = exported.json() + assert len(payload["interactions"]) == 1 + + exported_again = await data_proxy_client.post( + "/export_trajectories", + json={"session_id": "__hitl__"}, + headers={"Authorization": "Bearer areal-admin-key"}, + ) + assert exported_again.status_code == 409 + + @pytest.mark.asyncio + async def test_unregistered_model_falls_through_to_internal( + self, data_proxy_client + ): + resp = await data_proxy_client.post( + "/chat/completions", + json={ + "model": "missing", + "messages": [{"role": "user", "content": "hi"}], + }, + ) + assert resp.status_code == 200 + + @pytest.mark.asyncio + async def test_external_chat_uses_stored_provider_api_key( + self, + data_proxy_client, + monkeypatch, + ): + await data_proxy_client.post( + "/register_model", + json={ + "name": "ext-1", + "url": "http://ext-api", + "model": "gpt-4o", + "api_key": "sk-provider-key-99", + }, + ) + + captured_headers: dict[str, str] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, json=None, headers=None): + if headers: + captured_headers.update(headers) + return httpx.Response(200, json={"id": "ext-1-response"}) + + monkeypatch.setattr( + "areal.experimental.inference_service.data_proxy.app.httpx.AsyncClient", + _FakeClient, + ) + + resp = await data_proxy_client.post( + "/chat/completions", + json={"model": "ext-1", "messages": [{"role": "user", "content": "hi"}]}, + headers={"Authorization": "Bearer sk-session-key"}, + ) + assert resp.status_code == 200 + assert captured_headers.get("authorization") == "Bearer sk-provider-key-99" + + +@pytest.mark.asyncio +async def test_external_model_end_to_end_register_then_chat(router_config): + router_app = create_router_app(router_config) + router_transport = httpx.ASGITransport(app=router_app) + async with httpx.AsyncClient( + transport=router_transport, + base_url="http://router", + ) as router_client: + await router_client.post( + "/register", + json={"worker_addr": WORKER_ADDR}, + headers=admin_headers(), + ) + + gateway_config = GatewayConfig( + host="127.0.0.1", + port=18080, + admin_api_key=ADMIN_KEY, + router_addr="http://router", + router_timeout=2.0, + forward_timeout=30.0, + ) + gateway_app = create_gateway_app(gateway_config) + gateway_transport = httpx.ASGITransport(app=gateway_app) + + proxy_state: dict[str, dict[str, str | None]] = {} + + async def _register_model_in_router( + router_addr: str, + model: str, + url: str, + api_key: str | None, + data_proxy_addrs: list[str], + admin_api_key: str, + timeout: float, + ) -> dict: + resp = await router_client.post( + "/register_model", + json={ + "model": model, + "url": url, + "api_key": api_key, + "data_proxy_addrs": data_proxy_addrs, + }, + headers=admin_headers(), + ) + resp.raise_for_status() + return resp.json() + + async 
def _query_router( + router_addr: str, + api_key: str | None = None, + path: str | None = None, + timeout: float = 2.0, + *, + session_id: str | None = None, + admin_api_key: str | None = None, + model: str | None = None, + ) -> str: + if model is not None: + resp = await router_client.post( + "/route", + json={"model": model}, + headers=admin_headers(), + ) + if resp.status_code == 404: + raise RouterKeyRejectedError("not found", 404) + if resp.status_code == 503: + raise RouterKeyRejectedError("no healthy workers", 503) + resp.raise_for_status() + return resp.json()["worker_addr"] + resp = await router_client.post( + "/route", + json={"api_key": api_key, "session_id": session_id}, + headers=admin_headers(), + ) + resp.raise_for_status() + return resp.json()["worker_addr"] + + async def _forward_request( + upstream_url: str, + body: bytes, + headers: dict[str, str], + timeout: float, + ) -> httpx.Response: + if upstream_url == f"{WORKER_ADDR}/register_model": + data = json.loads(body) + proxy_state[data["name"]] = { + "url": data["url"], + "model": data.get("model"), + } + return httpx.Response(200, json={"status": "ok"}) + if upstream_url == f"{WORKER_ADDR}/chat/completions": + data = json.loads(body) + assert data["model"] in proxy_state + return httpx.Response(200, json={"id": "ext-e2e"}) + return httpx.Response(500, json={"error": "unexpected"}) + + with ( + patch( + f"{ROUTER_MODULE}.register_model_in_router", + new=AsyncMock(side_effect=_register_model_in_router), + ), + patch( + f"{ROUTER_MODULE}.query_router", + new=AsyncMock(side_effect=_query_router), + ), + patch( + f"{ROUTER_MODULE}.forward_request", + new=AsyncMock(side_effect=_forward_request), + ), + ): + async with httpx.AsyncClient( + transport=gateway_transport, + base_url="http://gateway", + ) as gateway_client: + reg = await gateway_client.post( + "/register_model", + json={ + "model": "ext-1", + "url": "http://ext-api", + }, + headers=admin_headers(), + ) + assert reg.status_code == 200 + + chat = await gateway_client.post( + "/chat/completions", + json={ + "model": "ext-1", + "messages": [{"role": "user", "content": "hello"}], + }, + headers=admin_headers(), + ) + assert chat.status_code == 200 + assert chat.json()["id"] == "ext-e2e" diff --git a/tests/experimental/inference_service/test_external_model_integration.py b/tests/experimental/inference_service/test_external_model_integration.py new file mode 100644 index 0000000000..3733bb2bef --- /dev/null +++ b/tests/experimental/inference_service/test_external_model_integration.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, patch + +import httpx +import pytest +import pytest_asyncio +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +from areal.experimental.inference_service.data_proxy.app import ( + create_app as create_data_proxy_app, +) +from areal.experimental.inference_service.data_proxy.config import DataProxyConfig +from areal.experimental.inference_service.data_proxy.session import SessionStore +from areal.experimental.inference_service.gateway.app import ( + create_app as create_gateway_app, +) +from areal.experimental.inference_service.gateway.config import GatewayConfig +from areal.experimental.inference_service.gateway.streaming import ( + RouterKeyRejectedError, +) +from areal.experimental.inference_service.router.app import ( + create_app as create_router_app, +) +from areal.experimental.inference_service.router.config import RouterConfig + +ADMIN_KEY = 
"test-admin-key" +WORKER_ADDR = "http://worker-1:18082" +ROUTER_MODULE = "areal.experimental.inference_service.gateway.app" + + +def admin_headers() -> dict[str, str]: + return {"Authorization": f"Bearer {ADMIN_KEY}"} + + +@pytest.fixture +def router_config() -> RouterConfig: + return RouterConfig( + host="127.0.0.1", + port=18081, + admin_api_key=ADMIN_KEY, + poll_interval=999, + routing_strategy="round_robin", + ) + + +@pytest.fixture +def gateway_config() -> GatewayConfig: + return GatewayConfig( + host="127.0.0.1", + port=18080, + admin_api_key=ADMIN_KEY, + router_addr="http://mock-router:8081", + router_timeout=2.0, + forward_timeout=30.0, + ) + + +@pytest_asyncio.fixture +async def gateway_client(gateway_config: GatewayConfig): + app = create_gateway_app(gateway_config) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +class TestGatewayUnifiedExportTrajectories: + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.revoke_session_in_router", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.forward_request", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.query_router", new_callable=AsyncMock) + async def test_export_trajectories_with_session_id( + self, + mock_query_router, + mock_forward, + mock_revoke, + gateway_client, + ): + mock_query_router.return_value = WORKER_ADDR + mock_forward.return_value = httpx.Response( + 200, + json={ + "interactions": {"id-1": {"messages": [], "reward": 0.0}}, + }, + ) + + resp = await gateway_client.post( + "/export_trajectories", + json={"session_id": "ext-1"}, + headers=admin_headers(), + ) + + assert resp.status_code == 200 + mock_query_router.assert_called_once() + assert "/export_trajectories" in mock_forward.call_args.args[0] + + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.revoke_session_in_router", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.forward_request", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.query_router", new_callable=AsyncMock) + async def test_export_trajectories_internal_session( + self, + mock_query_router, + mock_forward, + mock_revoke, + gateway_client, + ): + mock_query_router.return_value = WORKER_ADDR + mock_forward.return_value = httpx.Response(200, json={"interactions": []}) + + resp = await gateway_client.post( + "/export_trajectories", + json={"session_id": "ses-1", "discount": 1.0, "style": "sft"}, + headers=admin_headers(), + ) + + assert resp.status_code == 200 + mock_query_router.assert_called_once() + mock_revoke.assert_called_once() + assert "/export_trajectories" in mock_forward.call_args.args[0] + + @pytest.mark.asyncio + @patch(f"{ROUTER_MODULE}.revoke_session_in_router", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.forward_request", new_callable=AsyncMock) + @patch(f"{ROUTER_MODULE}.query_router", new_callable=AsyncMock) + async def test_export_trajectories_without_session_id_returns_400( + self, + mock_query_router, + mock_forward, + mock_revoke, + gateway_client, + ): + resp = await gateway_client.post( + "/export_trajectories", + json={"discount": 1.0}, + headers=admin_headers(), + ) + + assert resp.status_code == 400 + assert "session_id is required" in resp.json()["error"] + mock_query_router.assert_not_called() + mock_forward.assert_not_called() + mock_revoke.assert_not_called() + + +@pytest.mark.asyncio +async def test_external_model_flow_end_to_end_gateway_router_data_proxy(router_config): + mock_external_app = FastAPI() + + @mock_external_app.post("/chat/completions") + async def 
mock_chat(request: Request): + body = await request.json() + return JSONResponse( + { + "id": "chatcmpl-mock", + "object": "chat.completion", + "created": 1234567890, + "model": body.get("model", "mock-model"), + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Mock response"}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + } + ) + + router_app = create_router_app(router_config) + data_proxy_app = create_data_proxy_app( + DataProxyConfig( + host="127.0.0.1", + port=18082, + backend_addr="http://mock-sglang:30000", + tokenizer_path="mock-tokenizer", + request_timeout=10.0, + admin_api_key=ADMIN_KEY, + ) + ) + data_proxy_app.state.config = DataProxyConfig( + host="127.0.0.1", + port=18082, + backend_addr="http://mock-sglang:30000", + tokenizer_path="mock-tokenizer", + request_timeout=10.0, + admin_api_key=ADMIN_KEY, + ) + store = SessionStore() + store.set_admin_key(ADMIN_KEY) + data_proxy_app.state.session_store = store + data_proxy_app.state.version = 0 + gateway_app = create_gateway_app( + GatewayConfig( + host="127.0.0.1", + port=18080, + admin_api_key=ADMIN_KEY, + router_addr="http://router", + router_timeout=2.0, + forward_timeout=30.0, + ) + ) + + router_transport = httpx.ASGITransport(app=router_app) + proxy_transport = httpx.ASGITransport(app=data_proxy_app) + gateway_transport = httpx.ASGITransport(app=gateway_app) + external_transport = httpx.ASGITransport(app=mock_external_app) + + async with ( + httpx.AsyncClient( + transport=router_transport, base_url="http://router" + ) as router_client, + httpx.AsyncClient( + transport=proxy_transport, base_url="http://worker-1:18082" + ) as data_proxy_client, + httpx.AsyncClient( + transport=gateway_transport, base_url="http://gateway" + ) as gateway_client, + httpx.AsyncClient( + transport=external_transport, base_url="http://mock-external" + ) as external_client, + ): + await router_client.post( + "/register", + json={"worker_addr": WORKER_ADDR}, + headers=admin_headers(), + ) + + async def _register_model_in_router( + router_addr: str, + model: str, + url: str, + api_key: str | None, + data_proxy_addrs: list[str], + admin_api_key: str, + timeout: float, + ) -> dict: + resp = await router_client.post( + "/register_model", + json={ + "model": model, + "url": url, + "api_key": api_key, + "data_proxy_addrs": data_proxy_addrs, + }, + headers=admin_headers(), + ) + resp.raise_for_status() + return resp.json() + + async def _query_router( + router_addr: str, + api_key: str | None = None, + path: str | None = None, + timeout: float = 2.0, + *, + session_id: str | None = None, + admin_api_key: str | None = None, + model: str | None = None, + ) -> str: + payload: dict = {} + if model is not None: + payload["model"] = model + if session_id is not None: + payload["session_id"] = session_id + elif api_key is not None: + payload["api_key"] = api_key + resp = await router_client.post( + "/route", + json=payload, + headers=admin_headers(), + ) + if resp.status_code in (404, 503): + raise RouterKeyRejectedError("routing failed", resp.status_code) + resp.raise_for_status() + return resp.json()["worker_addr"] + + async def _forward_request( + upstream_url: str, + body: bytes, + headers: dict[str, str], + timeout: float, + ) -> httpx.Response: + if upstream_url.startswith(WORKER_ADDR): + path = upstream_url.removeprefix(WORKER_ADDR) + return await data_proxy_client.post(path, content=body, headers=headers) + return httpx.Response(500, json={"error": 
"unexpected upstream"}) + + class _ExternalClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, json=None, headers=None): + assert url == "http://mock-external/chat/completions" + return await external_client.post( + "/chat/completions", + json=json, + headers=headers, + ) + + with ( + patch( + f"{ROUTER_MODULE}.register_model_in_router", + new=AsyncMock(side_effect=_register_model_in_router), + ), + patch( + f"{ROUTER_MODULE}.query_router", + new=AsyncMock(side_effect=_query_router), + ), + patch( + f"{ROUTER_MODULE}.forward_request", + new=AsyncMock(side_effect=_forward_request), + ), + patch( + "areal.experimental.inference_service.data_proxy.app.httpx.AsyncClient", + _ExternalClient, + ), + ): + reg = await gateway_client.post( + "/register_model", + json={ + "model": "ext-1", + "url": "http://mock-external", + }, + headers=admin_headers(), + ) + assert reg.status_code == 200 + + chat = await gateway_client.post( + "/chat/completions", + json={ + "model": "ext-1", + "messages": [{"role": "user", "content": "hello"}], + }, + headers=admin_headers(), + ) + assert chat.status_code == 200 + assert chat.json()["id"] == "chatcmpl-mock" + + set_reward = await gateway_client.post( + "/rl/set_reward", + json={"reward": 1.0}, + headers=admin_headers(), + ) + assert set_reward.status_code == 200 + assert set_reward.json()["trajectory_ready"] is True + + exported = await gateway_client.post( + "/export_trajectories", + json={"session_id": "__hitl__"}, + headers=admin_headers(), + ) + assert exported.status_code == 200 + payload = exported.json() + assert len(payload["interactions"]) == 1 + + interaction = next(iter(payload["interactions"].values())) + assert interaction["messages"][0]["content"] == "hello" + cached_response = json.loads( + interaction["output_message_list"][0]["content"] + ) + assert ( + cached_response["choices"][0]["message"]["content"] == "Mock response" + ) diff --git a/tests/experimental/inference_service/test_ipv6_entrypoints.py b/tests/experimental/inference_service/test_ipv6_entrypoints.py index ed0c0463ce..d382042c06 100644 --- a/tests/experimental/inference_service/test_ipv6_entrypoints.py +++ b/tests/experimental/inference_service/test_ipv6_entrypoints.py @@ -20,6 +20,10 @@ def test_data_proxy_main_formats_ipv6_serving_addr(): set_reward_finish_timeout=0.0, admin_api_key="admin-key", callback_server_addr="http://[::1]:19000", + tool_call_parser="qwen", + reasoning_parser="qwen3", + engine_max_tokens=None, + chat_template_type="hf", ) with ( diff --git a/tests/experimental/inference_service/test_online_stack.py b/tests/experimental/inference_service/test_online_stack.py index 3935777f07..93fd37f136 100644 --- a/tests/experimental/inference_service/test_online_stack.py +++ b/tests/experimental/inference_service/test_online_stack.py @@ -176,6 +176,7 @@ async def _query_router( *, session_id: str | None = None, admin_api_key: str | None = None, + model: str | None = None, ) -> str: del router_addr, timeout payload: dict[str, str] = {} From d37095aea00942f392a00072df9a4a4e0e4ff9f7 Mon Sep 17 00:00:00 2001 From: WeiHaocheng <20514172+WeiHaocheng@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:29:56 +0800 Subject: [PATCH 2/3] feat: add scaffolding rollout workflow (#1064) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add scaffolding rollout workflow Key design: 
https://github.com/inclusionAI/AReaL/issues/818 Co-Authored-By: narutolhy Co-Authored-By: Claude Opus 4.6 (1M context) * style(examples): fix mdformat line wrapping in scaffolding README --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: 博惟 --- areal/experimental/openai/cache.py | 16 + areal/reward/__init__.py | 80 +- examples/scaffolding/README.md | 228 ++++ examples/scaffolding/__init__.py | 46 + examples/scaffolding/_compat.py | 60 + examples/scaffolding/controllers.py | 1070 +++++++++++++++++ examples/scaffolding/core/__init__.py | 63 + examples/scaffolding/core/controller.py | 207 ++++ examples/scaffolding/core/math_utils.py | 89 ++ examples/scaffolding/core/result.py | 84 ++ examples/scaffolding/core/scaffolding_llm.py | 230 ++++ examples/scaffolding/core/task.py | 270 +++++ examples/scaffolding/core/task_collection.py | 492 ++++++++ examples/scaffolding/core/worker.py | 171 +++ examples/scaffolding/fake_tools.py | 72 ++ .../scaffolding/gsm8k_rlvr_scaffolding.py | 97 ++ .../scaffolding/gsm8k_rlvr_scaffolding.yaml | 178 +++ .../gsm8k_rlvr_scaffolding_2nodes.yaml | 187 +++ examples/scaffolding/real_tools.py | 161 +++ .../scaffolding/search_agent_controller.py | 241 ++++ examples/scaffolding/search_reward.py | 71 ++ examples/scaffolding/search_scaffolding.py | 321 +++++ examples/scaffolding/search_scaffolding.yaml | 178 +++ examples/scaffolding/task.py | 233 ++++ examples/scaffolding/tests/__init__.py | 0 .../scaffolding/tests/test_controllers.py | 935 ++++++++++++++ .../tests/test_scaffolding_llm_integration.py | 1060 ++++++++++++++++ .../scaffolding/tests/test_self_contained.py | 795 ++++++++++++ examples/scaffolding/tests/test_worker.py | 205 ++++ examples/scaffolding/worker.py | 249 ++++ examples/scaffolding/workflow.py | 217 ++++ 31 files changed, 8286 insertions(+), 20 deletions(-) create mode 100644 examples/scaffolding/README.md create mode 100644 examples/scaffolding/__init__.py create mode 100644 examples/scaffolding/_compat.py create mode 100644 examples/scaffolding/controllers.py create mode 100644 examples/scaffolding/core/__init__.py create mode 100644 examples/scaffolding/core/controller.py create mode 100644 examples/scaffolding/core/math_utils.py create mode 100644 examples/scaffolding/core/result.py create mode 100644 examples/scaffolding/core/scaffolding_llm.py create mode 100644 examples/scaffolding/core/task.py create mode 100644 examples/scaffolding/core/task_collection.py create mode 100644 examples/scaffolding/core/worker.py create mode 100644 examples/scaffolding/fake_tools.py create mode 100644 examples/scaffolding/gsm8k_rlvr_scaffolding.py create mode 100644 examples/scaffolding/gsm8k_rlvr_scaffolding.yaml create mode 100644 examples/scaffolding/gsm8k_rlvr_scaffolding_2nodes.yaml create mode 100644 examples/scaffolding/real_tools.py create mode 100644 examples/scaffolding/search_agent_controller.py create mode 100644 examples/scaffolding/search_reward.py create mode 100644 examples/scaffolding/search_scaffolding.py create mode 100644 examples/scaffolding/search_scaffolding.yaml create mode 100644 examples/scaffolding/task.py create mode 100644 examples/scaffolding/tests/__init__.py create mode 100644 examples/scaffolding/tests/test_controllers.py create mode 100644 examples/scaffolding/tests/test_scaffolding_llm_integration.py create mode 100644 examples/scaffolding/tests/test_self_contained.py create mode 100644 examples/scaffolding/tests/test_worker.py create mode 100644 examples/scaffolding/worker.py create mode 100644 
examples/scaffolding/workflow.py diff --git a/areal/experimental/openai/cache.py b/areal/experimental/openai/cache.py index a28fbe7fb9..20fee73ca8 100644 --- a/areal/experimental/openai/cache.py +++ b/areal/experimental/openai/cache.py @@ -17,6 +17,22 @@ def __init__(self, *args, **kwargs): self._total_reward = 0.0 self._lock = threading.Lock() + def __deepcopy__(self, memo): + """Allow deep-copy of the empty cache. + + ``threading.Lock`` cannot be deep-copied. Controllers that hold + an ``InteractionCache`` (e.g. ``ChatTracer``) are cloned via + ``Controller.clone()`` (``copy.deepcopy``). The cache must be + empty at clone time; a non-empty cache indicates a bug in the + caller. + """ + assert len(self) == 0, ( + f"InteractionCache must be empty when deep-copied, but has {len(self)} items" + ) + new = InteractionCache() + memo[id(self)] = new + return new + @property def last_interaction_id(self) -> str: return next(reversed(self)) diff --git a/areal/reward/__init__.py b/areal/reward/__init__.py index 7396261bdd..b38e9fd147 100644 --- a/areal/reward/__init__.py +++ b/areal/reward/__init__.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -from math_verify.metric import math_metric -from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig +import concurrent.futures + +from math_verify.grader import verify as math_verify_verify +from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig, parse from areal.utils import logging @@ -29,38 +31,76 @@ def get_custom_reward_fn(path: str, **kwargs): class MathVerifyWorker: """Thin wrapper over math_verify with configurable extraction/precision. + Uses ``parse()`` + ``verify()`` directly instead of ``math_metric()`` + so that signal-based timeouts can be disabled (``parsing_timeout=None``, + ``timeout_seconds=None``). This avoids ``signal.alarm()`` which only + works in the main thread. A thread-safe timeout is enforced via + ``concurrent.futures`` instead. + Args: try_extract_without_anchor: When False, only answers with explicit anchors (e.g., "answer = 1", "final answer = 1") are matched. When True, any numeric string in the text may be extracted. precision: Number of significant digits that must match. + timeout: Thread-safe timeout in seconds for the entire verify call + (parsing + comparison). ``None`` disables the timeout. Notes: Tune these knobs based on dataset format and model output style. 
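+
+    Example:
+        A minimal usage sketch (the boxed answer is only illustrative;
+        extraction depends on the configured anchors):
+
+            worker = MathVerifyWorker(precision=6, timeout=5.0)
+            score = worker.verify("The answer is \\boxed{42}.", "42")  # -> 1.0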
""" - def __init__(self, try_extract_without_anchor=True, precision: int = 6): - self.verify_func = math_metric( - gold_extraction_target=( - ExprExtractionConfig( - try_extract_without_anchor=try_extract_without_anchor - ), - LatexExtractionConfig(), - ), - pred_extraction_target=( - ExprExtractionConfig( - try_extract_without_anchor=try_extract_without_anchor - ), - LatexExtractionConfig(), - ), - precision=precision, + def __init__( + self, + try_extract_without_anchor=True, + precision: int = 6, + timeout: float | None = 5.0, + ): + self.gold_extraction_target = ( + ExprExtractionConfig(try_extract_without_anchor=try_extract_without_anchor), + LatexExtractionConfig(), + ) + self.pred_extraction_target = ( + ExprExtractionConfig(try_extract_without_anchor=try_extract_without_anchor), + LatexExtractionConfig(), + ) + self.precision = precision + self.timeout = timeout + + def _verify_impl(self, response: str, ground_truth: str) -> float: + """Core verification logic without timeout wrapper.""" + gold_parsed = parse( + ground_truth, + extraction_config=self.gold_extraction_target, + parsing_timeout=None, + ) + pred_parsed = parse( + response, + extraction_config=self.pred_extraction_target, + parsing_timeout=None, ) + if not gold_parsed or not pred_parsed: + return 0.0 + result = math_verify_verify( + gold_parsed, + pred_parsed, + float_rounding=self.precision, + timeout_seconds=None, + ) + return 1.0 if result else 0.0 def verify(self, response: str, ground_truth: str) -> float: - # ground_truth_parsable = "\\boxed{" + ground_truth + "}" try: - ret_score, _ = self.verify_func([ground_truth], [response]) - return float(ret_score) + if self.timeout is None: + return self._verify_impl(response, ground_truth) + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(self._verify_impl, response, ground_truth) + return future.result(timeout=self.timeout) + except concurrent.futures.TimeoutError: + logger.warning( + f"Timeout ({self.timeout}s) in MathVerifyWorker.verify for " + f"response={response!r} and ground_truth={ground_truth!r}", + ) + return 0.0 except Exception: logger.warning( f"Exception in MathVerifyWorker.verify for response={response} and ground_truth={ground_truth}", diff --git a/examples/scaffolding/README.md b/examples/scaffolding/README.md new file mode 100644 index 0000000000..eae9e1b9fb --- /dev/null +++ b/examples/scaffolding/README.md @@ -0,0 +1,228 @@ +# Scaffolding Framework Examples for AReaL + +This directory contains examples demonstrating how to use the Scaffolding framework with +AReaL for reinforcement learning training. + +## Overview + +The scaffolding framework provides a modular and extensible way to compose various +methods with RL training. It decouples the inference logic (Controllers) from the +execution backend (Workers), enabling flexible composition of different methods. With +Scaffolding, we can flexibly compose various rollout, reward, and trajectory tracing +methods. + +### Key Components + +1. **Controller**: Defines the inference-time compute logic (e.g., generation, reward + computation) +1. **Worker**: Handles the actual execution of tasks (e.g., TRT-LLM, OpenAI API) +1. **ScaffoldingLlm**: Orchestrates controllers and workers together +1. 
**ScaffoldingWorkflow**: Wraps ScaffoldingLlm as a RolloutWorkflow for AReaL training + +### AReaL-Specific Components + +The following components are implemented in `examples/scaffolding/`: + +- **`CreateWorkerFromEngine`**: Creates a scaffolding Worker from AReaL's + InferenceEngine (e.g., RemoteSGLangEngine). The returned Worker is similar to + scaffolding's `OpenaiWorker` but integrated with AReaL's engine. + +- **`RLVRRewardController`**: A Controller that computes rewards for generated samples + using verifiable reward functions (e.g., math answer verification). + +- **`PipelineTrajectoryMaker`**: A Controller that composes generation and reward + controllers into a pipeline that produces training trajectories. + +- **`ScaffoldingWorkflow`**: A `RolloutWorkflow` implementation that wraps + ScaffoldingLlm for integration with AReaL's training pipeline. + +## RLVR Example with GSM8K + +### Quick Start + +```bash +python examples/scaffolding/gsm8k_rlvr_scaffolding.py \ + --config examples/scaffolding/gsm8k_rlvr_scaffolding.yaml +``` + +### Architecture + +The scaffolding workflow follows this pattern from the RFC: + +```python +# Step 1: Create Worker from the SGLang engine +rollout_worker = CreateWorkerFromEngine(engine) + +# Step 2: Create controllers +rollout_controller = NativeGenerationController() +reward_controller = RLVRRewardController(gsm8k_reward_fn) + +# Step 3: Create trajectory maker (composes the controllers) +trajectory_maker = PipelineTrajectoryMaker(rollout_controller, reward_controller) + +# Step 4: Create ScaffoldingLlm (orchestrates controllers with workers) +scaffolding_llm = ScaffoldingLlm( + trajectory_maker, + {NativeGenerationController.WorkerTag.GENERATION: rollout_worker}, +) + +# Step 5: Create ScaffoldingWorkflow (wraps as RolloutWorkflow) +scaffolding_workflow = ScaffoldingWorkflow(scaffolding_llm) +``` + +### Data Flow Diagram + +``` + ┌─────────────────────────────────────────────────┐ + │ ScaffoldingWorkflow │ + │ │ + │ ┌───────────────────────────────────────────┐ │ + │ │ ScaffoldingLlm │ │ + │ │ │ │ + │ │ ┌─────────────────────────────────────┐ │ │ + │ │ │ PipelineTrajectoryMaker │ │ │ + │ │ │ │ │ │ + │ │ │ ┌───────────────────────────────┐ │ │ │ +Data ─────────────────────────┼──┼──┼──► NativeGenerationController │ │ │ │ + │ │ │ │ (from scaffolding.core) │ │ │ │ + │ │ │ └───────────────┬───────────────┘ │ │ │ + │ │ │ │ │ │ │ + │ │ │ ▼ │ │ │ + │ │ │ ┌───────────────────────────────┐ │ │ │ + │ │ │ │ RLVRRewardController │ │ │ │ + │ │ │ │ (from areal.experimental) │ │ │ │ + │ │ │ └───────────────┬───────────────┘ │ │ │ + │ │ │ │ │ │ │ + │ │ └──────────────────┼──────────────────┘ │ │ + │ │ │ │ │ + │ └─────────────────────┼─────────────────────┘ │ + │ │ │ + └────────────────────────┼────────────────────────┘ + │ + ▼ Trajectories + ┌─────────────────────────────┐ + │ PPOTrainer │ + │ (GRPO/PPO Training) │ + └─────────────────────────────┘ + │ + via CreateWorkerFromEngine │ + ▼ + ┌─────────────────────────────────────────┐ + │ RemoteSGLangEngine │ + │ (AReaL Inference Backend) │ + └─────────────────────────────────────────┘ +``` + +### How It Works + +1. **Engine Initialization**: `RemoteSGLangEngine` is initialized with the rollout + configuration and connected to the model server. + +1. **Worker Creation**: `CreateWorkerFromEngine(engine)` wraps the engine into a + scaffolding-compatible Worker. This allows scaffolding controllers to use AReaL's + inference backends. + +1. 
**Controller Pipeline**: + + - `NativeGenerationController()`: Handles text generation by yielding + `GenerationTask` objects to the Worker. + - `RLVRRewardController(reward_fn)`: Computes rewards for generated samples using the + provided reward function. + - `PipelineTrajectoryMaker(gen_ctrl, reward_ctrl)`: Composes these controllers into a + pipeline that produces training trajectories. + +1. **ScaffoldingLlm**: Orchestrates the trajectory maker with the worker, handling the + async execution of tasks. + +1. **ScaffoldingWorkflow**: Wraps the ScaffoldingLlm as a `RolloutWorkflow` that can be + used directly with AReaL's `PPOTrainer`. + +1. **Training**: The trainer calls the workflow to generate trajectories, which are then + used for GRPO/PPO training. + +### Configuration + +See `gsm8k_rlvr_scaffolding.yaml` for the full configuration. Key options: + +```yaml +# Model configuration +pretrain_path: Qwen/Qwen2.5-3B-Instruct +tokenizer_path: Qwen/Qwen2.5-3B-Instruct + +# Generation hyperparameters +gconfig: + max_new_tokens: 1024 + temperature: 1.0 + top_p: 1.0 + n_samples: 8 + +# Inference engine configuration +engine: + type: sglang + tp: 1 + max_model_len: 4096 +``` + +## Extending the Framework + +### Custom Reward Controllers + +You can create custom reward controllers by subclassing the base Controller: + +```python +from examples.scaffolding._compat import Controller + +class CustomRewardController(Controller): + def __init__(self, reward_fn): + super().__init__() + self.reward_fn = reward_fn + + def process(self, tasks, **kwargs): + # Compute rewards for completed generation tasks + for task in tasks: + reward = self.reward_fn( + prompt=task.input_str, + completion=task.output_str, + **kwargs + ) + task.customized_result_fields["reward"] = reward + yield tasks +``` + +### Custom Trajectory Makers + +For different RL algorithms, you may need different trajectory formats: + +```python +from examples.scaffolding._compat import Controller +import torch + +class CustomTrajectoryMaker(Controller): + def __init__(self, generation_controller, reward_controller): + super().__init__() + self.generation_controller = generation_controller + self.reward_controller = reward_controller + + def process(self, tasks, **kwargs): + # Run generation + yield from self.generation_controller.process(tasks, **kwargs) + + # Run reward computation + yield from self.reward_controller.process(tasks, **kwargs) + + # Build trajectories + trajectories = [] + for task in tasks: + trajectory = { + "input_ids": torch.tensor(task.output_tokens), + "rewards": torch.tensor(task.customized_result_fields["reward"]), + } + trajectories.append(trajectory) + yield trajectories +``` + +## References + +- [TensorRT-LLM Scaffolding README](https://github.com/NVIDIA/TensorRT-LLM/tree/main/tensorrt_llm/scaffolding) +- [AReaL Workflow Documentation](../../docs/customization/workflow.md) +- [RFC: Scaffolding Integration](https://github.com/inclusionAI/AReaL/issues/818) diff --git a/examples/scaffolding/__init__.py b/examples/scaffolding/__init__.py new file mode 100644 index 0000000000..7856147097 --- /dev/null +++ b/examples/scaffolding/__init__.py @@ -0,0 +1,46 @@ +""" +Scaffolding Framework Example for AReaL. + +This package provides the Scaffolding framework for composing inference-time +compute methods with AReaL's RL training pipeline. Core scaffolding primitives +are vendored from TensorRT-LLM under ``examples.scaffolding.core``. 
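+
+A minimal import of the public surface (all names below are re-exported
+here)::
+
+    from examples.scaffolding import RLVRRewardController, ScaffoldingWorkflow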
+ +Key Components: +- ScaffoldingWorkflow: RolloutWorkflow implementation that wraps ScaffoldingLlm +- RLVRRewardTask: Task for computing verifiable rewards +- RLVRRewardController: Controller for computing verifiable rewards +- PipelineTrajectoryMaker: Controller for composing generation and reward pipelines +- ChatTracer: TaskCollection for tracing multi-turn chat conversations +- TraceTrajectoryMaker: Controller that traces ChatTask objects during rollout +- TraceGenerationTask: Task for tracing multi-turn generation +- ChatRewardTask: Task for computing rewards on traced interactions +- CreateWorkerFromEngine: Creates a scaffolding Worker from AReaL's InferenceEngine +- SGLangWorker: Worker implementation for SGLang engines +""" + +from examples.scaffolding.controllers import ( + ChatTracer, + PipelineTrajectoryMaker, + RLVRRewardController, + TraceTrajectoryMaker, +) +from examples.scaffolding.task import ( + ChatRewardTask, + RLVRRewardTask, + TraceGenerationTask, +) +from examples.scaffolding.worker import CreateWorkerFromEngine, SGLangWorker +from examples.scaffolding.workflow import ScaffoldingWorkflow + +__all__ = [ + "ScaffoldingWorkflow", + "RLVRRewardTask", + "RLVRRewardController", + "PipelineTrajectoryMaker", + "ChatTracer", + "TraceTrajectoryMaker", + "TraceGenerationTask", + "ChatRewardTask", + "CreateWorkerFromEngine", + "SGLangWorker", +] diff --git a/examples/scaffolding/_compat.py b/examples/scaffolding/_compat.py new file mode 100644 index 0000000000..f5ea215714 --- /dev/null +++ b/examples/scaffolding/_compat.py @@ -0,0 +1,60 @@ +"""Scaffolding framework primitives vendored from TensorRT-LLM. + +This module re-exports the core scaffolding classes from the vendored copy +at ``core``, so the rest of the example can import them from a single location. +""" + +from .core.controller import ( + BestOfNController, + Controller, + MajorityVoteController, + NativeChatController, + NativeGenerationController, + NativeRewardController, + ParallelProcess, +) +from .core.result import ScaffoldingOutput +from .core.scaffolding_llm import ScaffoldingLlm +from .core.task import ( + AssistantMessage, + ChatTask, + GenerationTask, + OpenAIToolDescription, + RoleMessage, + StreamGenerationTask, + SystemMessage, + Task, + TaskStatus, + UserMessage, +) +from .core.task_collection import ( + TaskCollection, + with_task_collection, +) +from .core.worker import OpenaiWorker, Worker + +__all__ = [ + "AssistantMessage", + "BestOfNController", + "ChatTask", + "Controller", + "GenerationTask", + "MajorityVoteController", + "NativeChatController", + "NativeGenerationController", + "NativeRewardController", + "OpenAIToolDescription", + "OpenaiWorker", + "ParallelProcess", + "RoleMessage", + "ScaffoldingLlm", + "ScaffoldingOutput", + "StreamGenerationTask", + "SystemMessage", + "Task", + "TaskCollection", + "TaskStatus", + "UserMessage", + "Worker", + "with_task_collection", +] diff --git a/examples/scaffolding/controllers.py b/examples/scaffolding/controllers.py new file mode 100644 index 0000000000..30bb616f06 --- /dev/null +++ b/examples/scaffolding/controllers.py @@ -0,0 +1,1070 @@ +""" +RLVR Controllers for Scaffolding Framework. + +This module provides controllers for RLVR (Reinforcement Learning with Verifiable +Rewards) that integrate with the scaffolding framework. 
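+
+A reward function compatible with ``RLVRRewardController`` follows this
+signature (a sketch; the containment check is only illustrative)::
+
+    def my_reward_fn(prompt, completions, prompt_ids, completion_ids, **data):
+        return float(str(data.get("answer", "")) in completions)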
+ +Key Components: +- RLVRRewardController: Controller that processes reward computation +- LLMJudgeController: Controller that uses an LLM to judge answer correctness +- PipelineTrajectoryMaker: Controller that composes generation and reward pipelines +- MultiTurnChatController: Controller for multi-turn chat with reflection +- ChatTracer: TaskCollection for tracing multi-turn chat conversations +- TraceTrajectoryMaker: Controller that traces ChatTask objects during rollout +""" + +from __future__ import annotations + +import ast +import json +import re +from collections.abc import Callable +from enum import Enum +from typing import Any + +from areal.api.reward_api import AsyncRewardWrapper +from areal.experimental.openai.cache import InteractionCache +from areal.experimental.openai.types import InteractionWithTokenLogpReward +from areal.utils import logging + +from ._compat import ( + ChatTask, + Controller, + GenerationTask, + NativeGenerationController, + RoleMessage, + Task, + TaskCollection, + UserMessage, + with_task_collection, +) +from .task import ( + ChatRewardTask, + RLVRRewardTask, + TraceGenerationTask, +) + +logger = logging.getLogger("RLVRControllers") + + +class RLVRRewardController(Controller): + """Controller for computing RLVR (verifiable) rewards. + + This controller processes RLVRRewardTask objects and computes rewards + using a provided reward function. The reward function should verify + whether the generated answer is correct. + + The reward computation follows the pattern from RLVRWorkflow._compute_rewards: + 1. Decode output tokens to string (if needed) + 2. Call reward_fn(prompt_str, completion_str, input_tokens, output_tokens, **task_data) + 3. Store the reward in the task and update the interaction object + + Parameters + ---------- + reward_fn : Callable + The reward function that takes (prompt, completions, prompt_ids, completion_ids, **data) + and returns a reward value (typically 0 or 1 for verifiable rewards). + + Example + ------- + ```python + from areal.reward.gsm8k import gsm8k_reward_fn + + reward_controller = RLVRRewardController(gsm8k_reward_fn) + ``` + """ + + class WorkerTag(Enum): + """Worker tag for reward computation.""" + + REWARD = "rlvr_reward" + + def __init__(self, reward_fn: Callable[..., Any]): + """Initialize the RLVR reward controller. + + Parameters + ---------- + reward_fn : Callable + The reward function for verifying answers. + """ + super().__init__() + self.reward_fn = reward_fn + self.async_reward_fn = AsyncRewardWrapper(reward_fn) + self.scores: list[float] | None = None + + def __deepcopy__(self, memo): + """Create a new RLVRRewardController with the same reward function. + + ``AsyncRewardWrapper`` contains ``ProcessPoolExecutor`` and + ``threading.Lock`` which cannot be deep-copied. Instead of + copying those objects, we create a fresh controller that shares + the same underlying executor pool via ``AsyncRewardWrapper``'s + class-level state. + """ + new = RLVRRewardController(self.reward_fn) + memo[id(self)] = new + return new + + def process(self, tasks: list[Task], **kwargs) -> Any: + """Process reward tasks and compute rewards. + + This method computes rewards for each task using the reward function. + The rewards are stored in: + 1. task.reward - the computed reward value + 2. task.interaction.reward - if an interaction object is provided + 3. self.scores - list of all computed rewards + + Parameters + ---------- + tasks : list[Task] + List of RLVRRewardTask objects to process. + **kwargs + Additional keyword arguments. 
+ + Yields + ------ + list[Task] + The processed tasks with rewards computed. + """ + # Mark tasks with worker tag (for potential worker-based execution) + for task in tasks: + task.worker_tag = self.WorkerTag.REWARD + + # Compute rewards synchronously + # Note: For async execution, this would be handled by a worker + self.scores = [] + for task in tasks: + if isinstance(task, RLVRRewardTask): + reward = self._compute_reward(task) + task.reward = reward + self.scores.append(reward) + + # Update the interaction object if provided + if task.interaction is not None: + task.interaction.reward = reward + elif isinstance(task, GenerationTask): + # For generation tasks, compute reward from customized fields + reward = self._compute_reward_from_generation_task(task, **kwargs) + task.customized_result_fields["reward"] = reward + self.scores.append(reward) + + yield tasks + + def _compute_reward(self, task: RLVRRewardTask) -> float: + """Compute reward for an RLVR reward task. + + Parameters + ---------- + task : RLVRRewardTask + The reward task containing prompt, completion, and task data. + + Returns + ------- + float + The computed reward value. + """ + reward = self.reward_fn( + task.prompt_str, + task.completion_str, + task.input_tokens, + task.output_tokens, + **task.task_data, + ) + return float(reward) + + def _compute_reward_from_generation_task( + self, task: GenerationTask, **kwargs + ) -> float: + """Compute reward from a generation task. + + Parameters + ---------- + task : GenerationTask + The completed generation task. + **kwargs + Should contain 'task_data' with ground truth. + + Returns + ------- + float + The computed reward value. + """ + task_data = kwargs.get("task_data", {}) + prompt_str = kwargs.get("prompt_str", task.input_str or "") + + reward = self.reward_fn( + prompt_str, + task.output_str or "", + list(task.input_tokens or []), + list(task.output_tokens or []), + **task_data, + ) + return float(reward) + + async def aprocess(self, tasks: list[Task], **kwargs) -> Any: + """Process reward tasks asynchronously. + + This method computes rewards asynchronously using AsyncRewardWrapper. + + Parameters + ---------- + tasks : list[Task] + List of RLVRRewardTask objects to process. + **kwargs + Additional keyword arguments. + + Returns + ------- + list[Task] + The processed tasks with rewards computed. + """ + # Mark tasks with worker tag + for task in tasks: + task.worker_tag = self.WorkerTag.REWARD + + # Compute rewards asynchronously + self.scores = [] + for task in tasks: + if isinstance(task, RLVRRewardTask): + reward = await self._acompute_reward(task) + task.reward = reward + self.scores.append(reward) + + # Update the interaction object if provided + if task.interaction is not None: + task.interaction.reward = reward + + return tasks + + async def _acompute_reward(self, task: RLVRRewardTask) -> float: + """Compute reward asynchronously for an RLVR reward task. + + Parameters + ---------- + task : RLVRRewardTask + The reward task containing prompt, completion, and task data. + + Returns + ------- + float + The computed reward value. + """ + reward = await self.async_reward_fn( + task.prompt_str, + task.completion_str, + task.input_tokens, + task.output_tokens, + **task.task_data, + ) + return float(reward) + + +class ChatTracer(TaskCollection): + """TaskCollection for tracing multi-turn chat conversations. + + This class traces ChatTask objects during the controller's process execution. 
+ A multi-turn conversation uses the same ChatTask object across multiple yields, + allowing us to track the evolution of the conversation. + + The tracer: + 1. In `before_yield`: Records the state of ChatTask before worker execution + 2. In `after_yield`: Captures the new messages added by the worker and creates + InteractionWithTokenLogpReward objects + + The traced results can be exported via `get_trace_results()`, which returns + a dict[str, InteractionWithTokenLogpReward] similar to client.py's export_interactions. + + Parameters + ---------- + reward_discount : float + Discount factor for backward reward propagation across turns. + export_style : str + Export style for interactions: 'concat' (tree structure) or 'individual'. + + Example + ------- + ```python + tracer = ChatTracer(reward_discount=0.9, export_style="individual") + # Used via with_task_collection decorator or TraceTrajectoryMaker + ``` + """ + + def __init__( + self, + reward_discount: float = 1.0, + export_style: str = "individual", + ): + super().__init__() + self.reward_discount = reward_discount + self.export_style = export_style + + # Cache for storing interactions, similar to InteractionCache in client.py + self._cache = InteractionCache() + + def before_yield(self, tasks: list[Task]): + """Called before tasks are yielded to workers. + + Parameters + ---------- + tasks : list[Task] + List of tasks about to be yielded. + """ + pass + + def after_yield(self, tasks: list[Task]): + """Called after tasks return from workers. + + Creates InteractionWithTokenLogpReward objects for each ChatTask. + Uses task.completion.id as the interaction ID. + + Parameters + ---------- + tasks : list[Task] + List of tasks that have been processed by workers. + """ + for task in tasks: + if not isinstance(task, ChatTask): + continue + + # Skip tasks without a completion (e.g. not yet processed by worker) + if not hasattr(task, "completion") or task.completion is None: + continue + + interaction = self._create_interaction_from_chat_task(task) + # Use completion.id as the interaction key + completion_id = task.completion.id + self._cache[completion_id] = interaction + + def _create_interaction_from_chat_task( + self, + task: ChatTask, + ) -> InteractionWithTokenLogpReward: + """Create an InteractionWithTokenLogpReward from a ChatTask. + + Parameters + ---------- + task : ChatTask + The ChatTask. Must contain a `completion` attribute + with the ChatCompletion object. + + Returns + ------- + InteractionWithTokenLogpReward + The interaction object capturing this turn. + """ + from areal.api.io_struct import ModelResponse + + # Extract all messages + messages = [ + msg.to_dict() if hasattr(msg, "to_dict") else msg for msg in task.messages + ] + + # Create ModelResponse from task data + input_tokens = list(task.input_tokens or []) + output_tokens = list(task.output_tokens or []) + + model_response = ModelResponse( + input_tokens=input_tokens, + output_tokens=output_tokens, + output_logprobs=[0.0] * len(output_tokens), + output_versions=[-1] * len(output_tokens), + ) + + # Get completion from task (ChatTask will contain the ChatCompletion) + completion = task.completion + + interaction = InteractionWithTokenLogpReward( + model_response=model_response, + reward=None, + messages=messages, + output_message_list=[], + completion=completion, + chat_template_type=self.export_style, + ) + + return interaction + + def get_trace_results(self) -> dict[str, InteractionWithTokenLogpReward]: + """Export traced interactions. 
+ + Returns the traced interactions in the specified export style. + Applies reward discount before export if configured. + + Returns + ------- + dict[str, InteractionWithTokenLogpReward] + Dictionary mapping interaction IDs to their data. + + See Also + -------- + client.py : export_interactions method for similar functionality + """ + if len(self._cache) == 0: + return {} + + return self._cache.export_interactions( + style=self.export_style, + reward_discount=self.reward_discount, + ) + + def clear(self) -> None: + """Clear all traced data.""" + self._cache.clear() + + +class PipelineTrajectoryMaker(Controller): + """Controller that composes generation and reward controllers into a pipeline. + + This controller orchestrates the full RLVR pipeline: + 1. Run generation via the generation controller + 2. Compute rewards via the reward controller + 3. Assemble results into InteractionWithTokenLogpReward objects + + Parameters + ---------- + generation_controller : Controller + The controller for text generation (e.g., NativeGenerationController). + reward_controller : RLVRRewardController + The controller for reward computation. + task_data : dict[str, Any] + Task data containing ground truth (e.g., "answer" field) for reward computation. + prompt_str : str + The prompt string used for generation. + + Example + ------- + ```python + from examples.scaffolding._compat import NativeGenerationController + + gen_controller = NativeGenerationController() + reward_controller = RLVRRewardController(gsm8k_reward_fn) + trajectory_maker = PipelineTrajectoryMaker( + gen_controller, + reward_controller, + task_data={"answer": "42"}, + prompt_str="What is the answer?", + ) + ``` + """ + + def __init__( + self, + generation_controller: Controller, + reward_controller: RLVRRewardController, + task_data: dict[str, Any] | None = None, + prompt_str: str = "", + input_tokens: list[int] | None = None, + ): + """Initialize the pipeline trajectory maker. + + Parameters + ---------- + generation_controller : Controller + The generation controller. + reward_controller : RLVRRewardController + The reward controller. + task_data : dict[str, Any], optional + Task data containing ground truth for reward computation. + prompt_str : str, optional + The prompt string used for generation. + input_tokens : list[int], optional + The tokenized input IDs for the prompt. + """ + super().__init__() + self.generation_controller = generation_controller + self.reward_controller = reward_controller + self.task_data = task_data if task_data is not None else {} + self.prompt_str = prompt_str + self.input_tokens = input_tokens if input_tokens is not None else [] + + def process(self, tasks: list[Task], **kwargs) -> Any: + """Process tasks through the generation and reward pipeline. + + Yields task lists only for generation (worker execution). Reward + computation is done locally without yielding to workers. + Interactions are stored in ``task.customized_result_fields["interactions"]`` + for retrieval in ``generate()``. + + Parameters + ---------- + tasks : list[Task] + List of generation tasks to process. + **kwargs + Additional keyword arguments. + + Yields + ------ + list[Task] + Task lists for worker execution (generation only). 
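+
+        Notes
+        -----
+        After the generator is exhausted, the assembled interactions can
+        be read back from any processed task (a sketch)::
+
+            interactions = task.customized_result_fields["interactions"]
+            rewards = [i.reward for i in interactions.values()]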
+ """ + # Step 1: Run generation (yields task lists for worker execution) + yield from self.generation_controller.process(tasks, **kwargs) + + # Step 2: Create reward tasks and compute rewards locally + reward_tasks = [] + interactions = {} + + for i, task in enumerate(tasks): + if isinstance(task, GenerationTask): + # Create interaction object + interaction = self._create_interaction_from_task(task) + task_id = f"task_{i}" + interactions[task_id] = interaction + + # Create reward task using constructor-provided task_data and prompt_str + reward_task = RLVRRewardTask.create_from_generation_task( + gen_task=task, + prompt_str=self.prompt_str or task.input_str or "", + task_data=self.task_data, + interaction=interaction, + ) + reward_tasks.append(reward_task) + + # Compute rewards locally (no yield to workers) + for _ in self.reward_controller.process(reward_tasks, **kwargs): + pass + + # Store interactions on tasks for retrieval in generate() + for task in tasks: + task.customized_result_fields["interactions"] = interactions + + def generate(self, prompt: str, **kwargs) -> Any: + """Generate with the full pipeline and return interactions in output. + + Overrides the base ``Controller.generate()`` to: + 1. Set ``input_tokens`` on the task so ``to_tensor_dict()`` works. + 2. Reset ``stop`` to ``None`` so ``NativeGenerationController`` can + set it from ``sampling_params``. + 3. Return a ``ScaffoldingOutput`` with interactions in ``data``. + + Parameters + ---------- + prompt : str + The input prompt string. + **kwargs + Additional keyword arguments. + + Returns + ------- + ScaffoldingOutput + Output with ``data`` containing the interactions dict. + """ + + task = GenerationTask.create_from_prompt(prompt) + if self.input_tokens: + task.input_tokens = self.input_tokens + # Reset stop to None so NativeGenerationController can set from sampling_params + task.stop = None + + yield from self.process([task], **kwargs) + + output = task.create_scaffolding_output() + output.data = task.customized_result_fields.get("interactions") + return output + + def _create_interaction_from_task( + self, task: GenerationTask + ) -> InteractionWithTokenLogpReward: + """Create an InteractionWithTokenLogpReward from a generation task. + + Parameters + ---------- + task : GenerationTask + The completed generation task. + + Returns + ------- + InteractionWithTokenLogpReward + The interaction object with model response data. + """ + from areal.api.io_struct import ModelResponse + + # Build ModelResponse from task data + input_tokens = list(task.input_tokens or []) + output_tokens = list(task.output_tokens or []) + output_logprobs = task.customized_result_fields.get("output_logprobs", []) + output_versions = task.customized_result_fields.get("output_versions", []) + + # Create ModelResponse + model_response = ModelResponse( + input_tokens=input_tokens, + output_tokens=output_tokens, + output_logprobs=list(output_logprobs) + if output_logprobs + else [0.0] * len(output_tokens), + output_versions=list(output_versions) + if output_versions + else [-1] * len(output_tokens), + ) + + # Create interaction + interaction = InteractionWithTokenLogpReward( + model_response=model_response, + reward=None, # Will be set by reward controller + ) + + return interaction + + +class MultiTurnChatController(Controller): + """Controller for multi-turn chat with reflection between turns. + + Handles the chat loop: for each turn, yields a ChatTask to the worker + for generation, then appends a reflection message for non-final turns. 
+ + Per-episode data (``messages``, ``input_tokens``) should be set before + calling ``generate()`` or ``process()``. ``ScaffoldingLlm`` deep-copies + the controller via ``clone()`` so each request gets its own copy. + + Parameters + ---------- + generation_controller : Controller + The controller for text generation (e.g., NativeGenerationController). + max_turns : int + Maximum number of chat turns per episode. + reflection_message : str + Message appended after each non-final turn to prompt retry. + tokenizer : Any + Tokenizer for encoding the final output to token IDs. + messages : list[dict], optional + The original chat messages to create the ChatTask from (set per-episode). + input_tokens : list[int], optional + The tokenized input IDs for the original prompt (set per-episode). + """ + + def __init__( + self, + generation_controller: Controller, + max_turns: int = 2, + reflection_message: str = "", + tokenizer: Any = None, + messages: list[dict] | None = None, + input_tokens: list[int] | None = None, + ): + super().__init__() + self.generation_controller = generation_controller + self.max_turns = max_turns + self.reflection_message = reflection_message + self.tokenizer = tokenizer + self.messages = messages if messages is not None else [] + self.input_tokens = input_tokens if input_tokens is not None else [] + + def process(self, tasks: list[Task], **kwargs) -> Any: + """Run multi-turn chat generation. + + Creates a ChatTask from the stored messages and yields it to the + worker for each turn. Between turns, appends a reflection message. + + Parameters + ---------- + tasks : list[Task] + Ignored; the ChatTask is created from ``self.messages``. + **kwargs + Additional keyword arguments. + + Yields + ------ + list[Task] + Task lists for worker execution (one per turn). + """ + role_messages = [RoleMessage.from_dict(m) for m in self.messages] + chat_task = ChatTask.create_from_messages(role_messages) + # Reset stop so NativeGenerationController can set from sampling_params + chat_task.stop = None + if self.input_tokens: + chat_task.input_tokens = self.input_tokens + + for turn in range(self.max_turns): + yield from self.generation_controller.process([chat_task], **kwargs) + + if turn < self.max_turns - 1: + chat_task.add_message(UserMessage(self.reflection_message)) + + +@with_task_collection("chat_tracer", ChatTracer) +class TraceTrajectoryMaker(Controller): + """Controller that traces ChatTask objects during rollout using ChatTracer. + + This controller uses the @with_task_collection decorator to automatically + apply ChatTracer's before_yield and after_yield hooks around each yield + in the rollout controller's process execution. + + A multi-turn conversation uses the same ChatTask object, which is traced + across all yields. The trace results are stored in the TraceGenerationTask + after processing. + + Parameters + ---------- + rollout_controller : Controller + The controller for rollout (e.g., a chat or agent controller). + reward_controller : Controller + The controller for computing rewards on traced interactions. 
+ + Example + ------- + ```python + from examples.scaffolding._compat import NativeGenerationController + + chat_controller = SomeChatController() + reward_controller = RLVRRewardController(gsm8k_reward_fn) + + trace_maker = TraceTrajectoryMaker( + rollout_controller=chat_controller, + reward_controller=reward_controller, + ) + + # Process tasks + result = trace_maker.generate(prompt) + + # Or use process directly + task = TraceGenerationTask.create_from_prompt(prompt) + for _ in trace_maker.process([task]): + pass + trace_results = task.trace_results + ``` + """ + + def __init__( + self, + rollout_controller: Controller, + reward_controller: Controller, + ): + """Initialize the trace trajectory maker. + + Parameters + ---------- + rollout_controller : Controller + The controller for rollout execution. + reward_controller : Controller + The controller for reward computation. + """ + super().__init__() + self.rollout_controller = rollout_controller + self.reward_controller = reward_controller + + def process(self, tasks: list[Task], **kwargs) -> Any: + """Process tasks through the rollout and reward pipeline with tracing. + + This method: + 1. Extracts the generation_task from the TraceGenerationTask + 2. Runs the rollout_controller.process() with ChatTracer tracing + 3. Gets trace results from the ChatTracer + 4. Creates ChatRewardTask objects for each traced interaction + 5. Runs the reward_controller.process() to compute rewards + 6. Stores the trace results in the original task + + Parameters + ---------- + tasks : list[Task] + List of TraceGenerationTask objects to process. + **kwargs + Additional keyword arguments. + + Yields + ------ + Any + Results from the controllers. + """ + # Get the generation task from the first TraceGenerationTask + task = tasks[0] + if isinstance(task, TraceGenerationTask): + generation_task = task.generation_task + else: + generation_task = task + + # Run rollout with tracing (ChatTracer hooks applied via decorator) + yield from self.rollout_controller.process([generation_task], **kwargs) + + # Get trace results from the ChatTracer (registered via decorator) + chat_tracer = self.task_collections["chat_tracer"] + trace_results = chat_tracer.get_trace_results() + + # Create reward tasks for each traced interaction + reward_tasks = [ + ChatRewardTask.create_from_trace_result(interaction_id, interaction) + for interaction_id, interaction in trace_results.items() + ] + + # Run reward computation — yield from so that controllers like + # LLMJudgeController can send tasks to workers via yield. + if reward_tasks: + yield from self.reward_controller.process(reward_tasks, **kwargs) + + # Update trace_results with computed rewards + for reward_task in reward_tasks: + if ( + reward_task.interaction is not None + and reward_task.reward is not None + ): + reward_task.interaction.reward = reward_task.reward + + # Store trace results in the original task + if isinstance(task, TraceGenerationTask): + task.trace_results = trace_results + + def generate(self, prompt: str, **kwargs) -> Any: + """Generate with tracing from a prompt string. + + Parameters + ---------- + prompt : str + The input prompt. + **kwargs + Additional keyword arguments. + + Returns + ------- + ScaffoldingOutput + Output with trace results in ``data``. 
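+
+        Example
+        -------
+        Driving the generator manually (a sketch; ``ScaffoldingLlm``
+        normally drives this, and worker dispatch is elided)::
+
+            gen = trace_maker.generate("What is 2 + 2?")
+            try:
+                while True:
+                    task_batch = next(gen)  # hand tasks to a worker here
+            except StopIteration as stop:
+                output = stop.value  # ScaffoldingOutput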
+ """ + task = TraceGenerationTask.create_from_prompt(prompt) + + yield from self.process([task], **kwargs) + + return task.create_scaffolding_output() + + +def _parse_judge_result(raw_response: str) -> float: + """Parse the LLM judge response and extract a binary reward. + + The judge is expected to return a JSON block with a ``"judgement"`` field + set to ``"correct"`` or ``"incorrect"``. + + Parameters + ---------- + raw_response : str + Raw text response from the judge LLM. + + Returns + ------- + float + 1.0 if the judgement is ``"correct"``, 0.0 otherwise. + """ + mbe = None + for parse_fn in [json.loads, ast.literal_eval]: + try: + mbe = parse_fn(raw_response.split("```json")[-1].split("```")[0].strip()) + break + except Exception: + pass + if mbe is None and '"judgement": "incorrect"' in raw_response: + mbe = {"judgement": "incorrect"} + if mbe is None and '"judgement": "correct"' in raw_response: + mbe = {"judgement": "correct"} + if mbe is None: + logger.warning("Unknown judge result. Raw response: %s", raw_response) + mbe = {"judgement": "unknown"} + return float("judgement" in mbe and mbe["judgement"] == "correct") + + +_JUDGE_PROMPT_TEMPLATE = ( + "You are an evaluation assistant. Please determine if the predicted answer " + "is equivalent to the labeled answer.\n" + "You should first give your rationale for the judgement, and then give your " + "judgement result (i.e., correct or incorrect).\n\n" + "\n" + "question: {question}\n" + "ground truth answers: {gt_answer}\n" + "pred_answer: {pred_answer}\n\n" + "Did the model give an answer **equivalent** to the labeled answer? \n\n" + "The output should in the following json format:\n" + "```json\n" + "{{\n" + ' "rationale": "your rationale for the judgement, as a text",\n' + " \"judgement\": \"your judgement result, can only be 'correct' or 'incorrect'\n" + "}}\n" + "```\n" + "Your output:" +) + + +class LLMJudgeController(Controller): + """Controller that uses an LLM to judge answer correctness. + + Instead of using a deterministic reward function, this controller sends + a judge prompt to the same (or a separate) LLM worker and parses the + response to determine whether the predicted answer is correct. + + This is the scaffolding-framework equivalent of + ``MultiTurnReactAgent.calc_reward_with_llm_judge`` from the + tongyi_deepresearch example. + + Parameters + ---------- + judge_prompt_template : str, optional + The prompt template for the judge. Must contain ``{question}``, + ``{gt_answer}``, and ``{pred_answer}`` placeholders. + max_pred_chars : int + Maximum characters of the predicted answer to include in the + judge prompt (to avoid exceeding context limits). + max_judge_tokens : int + Maximum tokens for the judge LLM response. + """ + + class WorkerTag(Enum): + JUDGE = "llm_judge" + + def __init__( + self, + judge_prompt_template: str | None = None, + max_pred_chars: int = 200, + max_judge_tokens: int = 8192, + ): + super().__init__() + self.judge_prompt_template = judge_prompt_template or _JUDGE_PROMPT_TEMPLATE + self.max_pred_chars = max_pred_chars + self.max_judge_tokens = max_judge_tokens + self.scores: list[float] | None = None + # Per-episode data set before generate(); deep-copied via clone() + self.task_data: dict[str, Any] = {} + + def _build_judge_prompt( + self, + question: str, + ground_truth: str, + prediction: str, + ) -> str: + """Format the judge prompt with the given data. + + Parameters + ---------- + question : str + The original question. + ground_truth : str + The ground-truth answer. 
+ prediction : str + The model's predicted answer (truncated to ``max_pred_chars``). + + Returns + ------- + str + The formatted judge prompt. + """ + return self.judge_prompt_template.format( + question=question, + gt_answer=ground_truth, + pred_answer=prediction[: self.max_pred_chars], + ) + + def _extract_answer_and_data(self, task: Task) -> tuple[str, str, str] | None: + """Extract (question, ground_truth, prediction) from a task. + + Supports ``RLVRRewardTask``, ``ChatRewardTask``, and + ``GenerationTask`` with ``customized_result_fields``. + + For ``ChatRewardTask``, falls back to ``self.task_data`` for + question/answer and extracts the prediction from the last + assistant message in the traced interaction. + + Returns ``None`` if the task type is not recognised. + """ + if isinstance(task, RLVRRewardTask): + question = task.task_data.get("question", "") + gt = task.task_data.get("answer", "") + if isinstance(gt, list): + gt = str(gt[0]) if gt else "" + # Extract from tags if present, else use full completion + match = re.search(r"(.*?)", task.completion_str, re.DOTALL) + pred = match.group(1).strip() if match else task.completion_str + return question, str(gt), pred + + if isinstance(task, ChatRewardTask): + # Use per-episode task_data for question/answer + question = self.task_data.get("question", "") + gt = self.task_data.get("answer", "") + if isinstance(gt, list): + gt = str(gt[0]) if gt else "" + # Extract prediction from the interaction's completion + pred = "" + if task.interaction is not None: + completion = getattr(task.interaction, "completion", None) + if completion is not None: + pred = completion.choices[0].message.content or "" + # Extract from tags if present + match = re.search(r"(.*?)", pred, re.DOTALL) + if match: + pred = match.group(1).strip() + return question, str(gt), pred + + if isinstance(task, GenerationTask): + data = task.customized_result_fields + question = data.get("question", "") + gt = data.get("answer", "") + if isinstance(gt, list): + gt = str(gt[0]) if gt else "" + pred = task.output_str or "" + return question, str(gt), pred + + return None + + def process(self, tasks: list[Task], **kwargs) -> Any: + """Process tasks by sending judge prompts to the LLM worker. + + For each task, builds a judge prompt and yields a ``ChatTask`` to + the worker. After the worker responds, parses the judge result + and stores the reward. + + Parameters + ---------- + tasks : list[Task] + Tasks to compute rewards for. + **kwargs + Additional keyword arguments. + + Yields + ------ + list[Task] + ChatTask lists sent to the worker for judge LLM calls. 
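+
+        Notes
+        -----
+        The judge reply is parsed by ``_parse_judge_result`` and is
+        expected to contain a JSON block such as::
+
+            {"rationale": "...", "judgement": "correct"}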
+ """ + self.scores = [] + + # Build judge ChatTasks for each input task + # judge_map: input task index -> index in judge_chat_tasks list + judge_chat_tasks: list[ChatTask] = [] + judge_map: dict[int, int] = {} + + for i, task in enumerate(tasks): + extracted = self._extract_answer_and_data(task) + if extracted is None: + continue + question, gt, pred = extracted + if not question and not gt: + continue + + judge_prompt = self._build_judge_prompt(question, gt, pred) + judge_messages = [ + RoleMessage.from_dict({"role": "user", "content": judge_prompt}) + ] + chat_task = ChatTask.create_from_messages(judge_messages) + chat_task.worker_tag = NativeGenerationController.WorkerTag.GENERATION + chat_task.max_tokens = self.max_judge_tokens + chat_task.temperature = 1.0 + chat_task.stop = None + judge_map[i] = len(judge_chat_tasks) + judge_chat_tasks.append(chat_task) + + # Yield all judge tasks to the worker in one batch + if judge_chat_tasks: + yield judge_chat_tasks + + # Parse responses and assign rewards + for i, task in enumerate(tasks): + reward = 0.0 + if i in judge_map: + jt = judge_chat_tasks[judge_map[i]] + # The worker appends an AssistantMessage after the user message + if jt.messages and len(jt.messages) > 1: + judge_response = jt.messages[-1].content or "" + else: + judge_response = "" + reward = _parse_judge_result(judge_response) + + # Store reward on the original task + if isinstance(task, (RLVRRewardTask, ChatRewardTask)): + task.reward = reward + if task.interaction is not None: + task.interaction.reward = reward + elif isinstance(task, GenerationTask): + task.customized_result_fields["reward"] = reward + + self.scores.append(reward) diff --git a/examples/scaffolding/core/__init__.py b/examples/scaffolding/core/__init__.py new file mode 100644 index 0000000000..e821f00a76 --- /dev/null +++ b/examples/scaffolding/core/__init__.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Vendored from TensorRT-LLM scaffolding framework. +# Core scaffolding primitives adapted for standalone use in AReaL. 
+ +from .controller import ( + BestOfNController, + Controller, + MajorityVoteController, + NativeChatController, + NativeGenerationController, + NativeRewardController, + ParallelProcess, +) +from .math_utils import ( + extract_answer_from_boxed, + extract_answer_with_regex, + get_digit_majority_vote_result, +) +from .result import ScaffoldingOutput +from .scaffolding_llm import ScaffoldingLlm +from .task import ( + AssistantMessage, + ChatTask, + GenerationTask, + OpenAIToolDescription, + RoleMessage, + StreamGenerationTask, + SystemMessage, + Task, + TaskStatus, + UserMessage, +) +from .task_collection import TaskCollection, with_task_collection +from .worker import OpenaiWorker, Worker + +__all__ = [ + "ScaffoldingLlm", + "ParallelProcess", + "Controller", + "NativeChatController", + "NativeGenerationController", + "NativeRewardController", + "MajorityVoteController", + "BestOfNController", + "Task", + "GenerationTask", + "StreamGenerationTask", + "ChatTask", + "OpenAIToolDescription", + "RoleMessage", + "UserMessage", + "SystemMessage", + "AssistantMessage", + "Worker", + "OpenaiWorker", + "TaskStatus", + "extract_answer_from_boxed", + "extract_answer_with_regex", + "get_digit_majority_vote_result", + "TaskCollection", + "with_task_collection", + "ScaffoldingOutput", +] diff --git a/examples/scaffolding/core/controller.py b/examples/scaffolding/core/controller.py new file mode 100644 index 0000000000..77fdc18f99 --- /dev/null +++ b/examples/scaffolding/core/controller.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Vendored from tensorrt_llm.scaffolding.controller + +import copy +import logging +from abc import ABC +from collections.abc import Mapping +from enum import Enum +from typing import Any + +import torch + +from .math_utils import get_digit_majority_vote_result +from .task import ChatTask, GenerationTask, Task + +logger = logging.getLogger(__name__) + + +class Controller(ABC): + task_collections: dict = {} + + def __init__(self): + self.task_collections = {} + + def clone(self): + return copy.deepcopy(self) + + def generate(self, prompt: str, **kwargs): + task = GenerationTask.create_from_prompt(prompt) + + yield from self.process([task], **kwargs) + + return task.create_scaffolding_output() + + def process(self, tasks: list[Task], **kwargs): + raise NotImplementedError + + +class ParallelProcess: + def __init__( + self, + controllers: list[Controller], + tasks_list: list[list[Task]], + kwargs_list: list[Mapping[str, Any]], + ): + self.sub_gens = [] + for controller, tasks, kwargs in zip(controllers, tasks_list, kwargs_list): + gen = controller.process(tasks, **kwargs) + self.sub_gens.append(gen) + + +# Controller runs multiple generation tasks. 
+class NativeGenerationController(Controller):
+    class WorkerTag(Enum):
+        GENERATION = "generation"
+
+    def __init__(self, sampling_params: dict = None, streaming: bool = False):
+        super().__init__()
+        if sampling_params is None:
+            sampling_params = {}
+        for key, value in list(sampling_params.items()):
+            if key not in GenerationTask.__annotations__:
+                logger.warning(f"{key} is not a supported field for GenerationTask")
+                sampling_params.pop(key)
+        self.sampling_params = sampling_params
+        self.streaming = streaming
+
+    # [GenerationTask] -> [GenerationTask] | [ChatTask] -> [ChatTask]
+    def process(self, tasks: list[Task], **kwargs):
+        for task in tasks:
+            task.worker_tag = self.WorkerTag.GENERATION
+            for key, value in self.sampling_params.items():
+                if getattr(task, key) is None:
+                    setattr(task, key, value)
+
+            task.streaming_output_flag = self.streaming
+
+        yield tasks
+
+
+class NativeChatController(NativeGenerationController):
+    def __init__(self, sampling_params: dict = None, streaming: bool = False):
+        super().__init__(sampling_params, streaming)
+
+    def process(self, tasks: list[Task], **kwargs):
+        chat_tasks = [ChatTask.create_from_prompt(task.input_str) for task in tasks]
+        yield from super().process(chat_tasks, **kwargs)
+
+
+class NativeRewardController(Controller):
+    def __init__(self):
+        super().__init__()
+        self.scores = None
+
+    class WorkerTag(Enum):
+        REWARD = "reward"
+
+    def process(self, tasks: list[Task], **kwargs):
+        for task in tasks:
+            task.worker_tag = self.WorkerTag.REWARD
+
+        yield tasks
+
+
+class MajorityVoteController(Controller):
+    def __init__(self, generation_controller: Controller, default_sample_num: int = 1):
+        super().__init__()
+        self.generation_controller = generation_controller
+        self.default_sample_num = default_sample_num
+
+    def clone(self):
+        generation_controller = self.generation_controller.clone()
+        return MajorityVoteController(generation_controller, self.default_sample_num)
+
+    def process(
+        self,
+        tasks: list[Task],
+        sample_num: int = 1,
+        generation_kwargs: dict = {},
+        majority_vote_kwargs: dict = {},
+    ):
+        sample_num = max(sample_num, self.default_sample_num)
+        generation_controllers = [
+            self.generation_controller.clone() for _ in range(sample_num)
+        ]
+        tasks_list = [copy.deepcopy(tasks) for _ in range(sample_num)]
+        generation_kwargs_list = [
+            copy.deepcopy(generation_kwargs) for _ in range(sample_num)
+        ]
+
+        yield ParallelProcess(
+            generation_controllers, tasks_list, generation_kwargs_list
+        )
+
+        majority_index, majority_answer = self.majority_vote(
+            tasks_list, **majority_vote_kwargs
+        )
+
+        assert isinstance(majority_answer, str), "majority_vote failed"
+        tasks[0].result = tasks_list[majority_index][0].result
+
+    def majority_vote(
+        self, candidates_tasks: list[list[Task]], **kwargs
+    ) -> tuple[int, str]:
+        candidates = [tasks[0].output_str for tasks in candidates_tasks]
+        return get_digit_majority_vote_result(candidates)
+
+
+class BestOfNController(Controller):
+    def __init__(
+        self,
+        generation_controller: Controller,
+        reward_controller: Controller,
+        default_sample_num: int = 4,
+    ):
+        super().__init__()
+        self.generation_controller = generation_controller
+        self.reward_controller = reward_controller
+        self.default_sample_num = default_sample_num
+
+    def clone(self):
+        generation_controller = self.generation_controller.clone()
+        reward_controller = self.reward_controller.clone()
+        return BestOfNController(
+            generation_controller, reward_controller, self.default_sample_num
+        )
+
+    def process(
+        self,
+        tasks: list[Task],
+        sample_num: int = 4,
+        generation_kwargs: dict = {},
+        reward_kwargs: dict = {},
+        select_best_kwargs: dict = {},
+    ):
+        assert len(tasks) == 1, "BestOfNController only supports one task"
+        task = tasks[0]
+
+        sample_num = max(sample_num, self.default_sample_num)
+        generation_controllers = [self.generation_controller for _ in range(sample_num)]
+        generation_kwargs_list = [generation_kwargs for _ in range(sample_num)]
+        generation_tasks = [copy.deepcopy(task) for _ in range(sample_num)]
+
+        yield ParallelProcess(
+            generation_controllers,
+            [[t] for t in generation_tasks],
+            generation_kwargs_list,
+        )
+
+        yield from self.reward_controller.process(generation_tasks, **reward_kwargs)
+
+        assert self.reward_controller.scores is not None
+        reward_values = self.reward_controller.scores
+
+        for i, gen_task, reward_value in zip(
+            range(sample_num), generation_tasks, reward_values
+        ):
+            logger.info(f"[output {i}, score {reward_value}]:\n{gen_task.output_str}")
+
+        best_task, best_idx = self.select_best(
+            generation_tasks, reward_values, **select_best_kwargs
+        )
+        task.result = best_task.result
+
+    def select_best(
+        self, tasks: list[Task], reward_values, **kwargs
+    ) -> tuple[Task, int]:
+        max_index = torch.argmax(torch.tensor(reward_values)).item()
+        return tasks[max_index], max_index
diff --git a/examples/scaffolding/core/math_utils.py b/examples/scaffolding/core/math_utils.py
new file mode 100644
index 0000000000..e94319bacb
--- /dev/null
+++ b/examples/scaffolding/core/math_utils.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Vendored from tensorrt_llm.scaffolding.math_utils
+
+import re
+
+
+def extract_answer_with_regex(
+    string: str, extract_regex: str = r"The final answer is (.+)$"
+):
+    match = re.search(extract_regex, string)
+    if match:
+        return match.group(1)
+    return None
+
+
+def extract_answer_from_boxed(string: str):
+    """Extract Answer String from \\boxed expression or based on regex"""
+
+    if "\\boxed" not in string:
+        return None
+
+    idx = string.rfind("\\boxed")
+    if idx < 0:
+        idx = string.rfind("\\fbox")
+        if idx < 0:
+            return None
+
+    i = idx
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == "{":
+            num_left_braces_open += 1
+        if string[i] == "}":
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+        i += 1
+
+    if right_brace_idx is None:
+        retval = None
+    else:
+        retval = string[idx : right_brace_idx + 1]
+
+    if retval:
+        left = "\\boxed{"
+        try:
+            assert retval[: len(left)] == left
+            assert retval[-1] == "}"
+            return retval[len(left) : -1]
+        except AssertionError:
+            return None
+
+    return None
+
+
+def get_majority_result(
+    results: list,
+    result_extractor=lambda x: x,
+    result_validator=lambda x: True,
+):
+    extract_answers = [result_extractor(result) for result in results]
+    valid_answers = [
+        result
+        for result in extract_answers
+        if result is not None and result_validator(result) is True
+    ]
+    if len(valid_answers) == 0:
+        return None, None
+
+    answer_counts = {}
+    for answer in valid_answers:
+        answer_counts[answer] = answer_counts.get(answer, 0) + 1
+    majority_answer = max(answer_counts, key=answer_counts.get)
+    majority_index = next(
+        filter(lambda x: x[1] == majority_answer, enumerate(extract_answers))
+    )[0]
+    return majority_index, majority_answer
+
+
+def get_digit_majority_vote_result(results: list[str]) -> tuple[int, str | None]:
+    def is_digit(result: str):
+        return result.isdigit()
+
+    index, extract_answer = get_majority_result(
+        results,
result_extractor=extract_answer_from_boxed, result_validator=is_digit + ) + return (index, extract_answer) if extract_answer else (0, None) diff --git a/examples/scaffolding/core/result.py b/examples/scaffolding/core/result.py new file mode 100644 index 0000000000..52f786bef6 --- /dev/null +++ b/examples/scaffolding/core/result.py @@ -0,0 +1,84 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Vendored from tensorrt_llm.scaffolding.result + +import asyncio +import queue +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + + +@dataclass +class ScaffoldingOutput: + text: str + token_ids: list[int] + data: Any = None + + +class ScaffoldingResult: + """Result object for scaffolding requests. + + Uses a thread-safe ``queue.Queue`` for cross-thread communication so + that producers (ScaffoldingLlm loop thread) and consumers + (caller's event loop) can safely exchange data. + """ + + def __init__(self): + super().__init__() + self._queue: queue.Queue = queue.Queue() + self.outputs = [] + # only support one output for now, so use an empty obj to init + self.outputs.append(ScaffoldingOutput("", [])) + self._done = False + self.task_collections = None + + def set_output(self, output: ScaffoldingOutput | Any): + if isinstance(output, ScaffoldingOutput): + self.set_output_streaming(output) + # terminate + self.set_output_streaming(None) + + def set_output_streaming(self, output: ScaffoldingOutput | Any): + self._queue.put_nowait(output) + + def set_task_collections(self, task_collections: Mapping[str, Any]): + self.task_collections = task_collections + + async def _aresult_step(self): + """Asynchronously wait for the next item from the thread-safe queue.""" + loop = asyncio.get_running_loop() + obj = await loop.run_in_executor(None, self._queue.get) + if obj is None: + self._done = True + else: # obj is ScaffoldingOutput + self.outputs[0] = obj + + def result(self, timeout: float | None = None) -> "ScaffoldingResult": + while not self._done: + try: + obj = self._queue.get(timeout=timeout) + except queue.Empty: + break + if obj is None: + self._done = True + else: + self.outputs[0] = obj + return self + + async def aresult(self) -> "ScaffoldingResult": + while not self._done: + await self._aresult_step() + return self + + def __await__(self): + return self.aresult().__await__() + + def __aiter__(self): + return self + + async def __anext__(self): + if self._done: + raise StopAsyncIteration + + await self._aresult_step() + return self diff --git a/examples/scaffolding/core/scaffolding_llm.py b/examples/scaffolding/core/scaffolding_llm.py new file mode 100644 index 0000000000..224acd0f7a --- /dev/null +++ b/examples/scaffolding/core/scaffolding_llm.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
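+# Note on consuming results: ScaffoldingResult (result.py) supports both the
+# blocking ``result()`` call and async iteration. A hedged sketch of the async
+# pattern (the surrounding llm/prompt setup is assumed, not shown):
+#
+#   result = llm.generate_async("prompt")
+#   async for snapshot in result:        # each step refreshes outputs[0]
+#       print(snapshot.outputs[0].text)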
+# Vendored from tensorrt_llm.scaffolding.scaffolding_llm
+
+import asyncio
+import threading
+import traceback
+from collections import deque
+from collections.abc import Generator, Mapping
+from dataclasses import dataclass
+from typing import Any
+
+from .controller import Controller, ParallelProcess
+from .result import ScaffoldingResult
+from .task import Task
+from .worker import Worker
+
+
+@dataclass(frozen=True)
+class ScaffoldingRequest:
+    prompt: str
+    kwargs: Mapping[str, Any]
+    controller: Controller
+    result: "ScaffoldingResult"
+
+
+class ScaffoldingLlm:
+    def __init__(
+        self,
+        prototype_controller: Controller,
+        workers: Mapping[str, Worker],  # map of worker tag to worker instance
+        max_parallel_requests: int = 64,
+    ):
+        self.prototype_controller = prototype_controller
+        self.workers = workers
+
+        # Always create a dedicated event loop in a separate thread.
+        # This avoids deadlocks when ScaffoldingLlm is used inside another
+        # event loop (e.g. AsyncTaskRunner's uvloop), where fire-and-forget
+        # tasks created via create_task() would never get executed.
+        #
+        # asyncio primitives (Queue, Event) are created inside the loop
+        # thread to ensure they are bound to the correct event loop.
+        self.loop = asyncio.new_event_loop()
+        self._ready = threading.Event()
+        self._run_main_loop_thread()
+        self._ready.wait()
+
+        # For top scheduler
+        self.running_req_count = 0
+        self.max_parallel_requests = max_parallel_requests
+        self.pending_queue = deque()
+
+        self.output_task_collection = False
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.shutdown()
+
+    async def _handle_controller_generator(
+        self, gen: Generator, request: ScaffoldingRequest = None
+    ):
+        """Handle a controller generator, processing tasks and parallel processes."""
+        for obj in gen:
+            if isinstance(obj, ParallelProcess):
+                await self._handle_parallel_process(obj, request)
+            else:
+                await self._handle_task_list(obj, request)
+
+    async def _handle_task_list(
+        self, tasks: list[Task], request: ScaffoldingRequest = None
+    ):
+        """Execute a list of tasks concurrently."""
+        async_tasks = [
+            asyncio.create_task(self.workers[task.worker_tag].run_task(task))
+            for task in tasks
+        ]
+        await asyncio.gather(*async_tasks)
+        for task in tasks:
+            if task.streaming_output_flag:
+                for output in task.streaming_output_list:
+                    request.result.set_output_streaming(output)
+                task.streaming_output_list = []
+
+    async def _handle_parallel_process(
+        self, tasks: ParallelProcess, request: ScaffoldingRequest = None
+    ):
+        """Handle parallel execution of multiple generators."""
+        async_tasks = [
+            asyncio.create_task(self._handle_controller_generator(sub_gen, request))
+            for sub_gen in tasks.sub_gens
+        ]
+        await asyncio.gather(*async_tasks)
+
+    async def _handle_single_request(self, request: ScaffoldingRequest):
+        """Process a single scaffolding request."""
+        try:
+            gen = self._create_controller_generator(request)
+            await self._handle_controller_generator(gen, request)
+        except Exception as e:
+            print(f"ScaffoldingLLM request exception: {e}")
+            traceback.print_exc()
+            request.result.set_output(None)
+            raise
+        finally:
+            self.running_req_count -= 1
+            self._maybe_schedule()
+
+    def _create_controller_generator(self, request: ScaffoldingRequest):
+        """Create a generator wrapper for the controller."""
+        scaffolding_output = yield from request.controller.generate(
+            request.prompt, **request.kwargs
+        )
+
+        if self.output_task_collection:
+            request.result.set_task_collections(request.controller.task_collections)
+
request.result.set_output(scaffolding_output) + + def _schedule_request(self, request: ScaffoldingRequest): + """Schedule a single request for execution.""" + asyncio.create_task(self._handle_single_request(request)) + self.running_req_count += 1 + + def _maybe_schedule(self, request: ScaffoldingRequest = None): + """Schedule pending requests if capacity allows.""" + if self.shutdown_event.is_set(): + return + + if request is not None: + self.pending_queue.append(request) + + while ( + self.running_req_count < self.max_parallel_requests and self.pending_queue + ): + next_request = self.pending_queue.popleft() + self._schedule_request(next_request) + + async def _handle_event_loop(self): + """Main event handling loop.""" + while True: + item = await self.task_queue.get() + + if item is None: + return + elif isinstance(item, ScaffoldingRequest): + self._maybe_schedule(item) + else: + raise ValueError(f"Unsupported task_queue item type: {type(item)}") + + async def _main_loop_async_func(self): + """Main async loop function.""" + handle_event_task = asyncio.create_task(self._handle_event_loop()) + await handle_event_task + self.main_loop_stop_event.set() + + def _run_main_loop_thread(self): + def main_loop_thread(): + asyncio.set_event_loop(self.loop) + # Create asyncio primitives inside the loop thread. + self.task_queue = asyncio.Queue() + self.main_loop_stop_event = asyncio.Event() + self.shutdown_event = asyncio.Event() + self._ready.set() + self.loop.run_until_complete(self._main_loop_async_func()) + + self.main_loop_thread = threading.Thread(target=main_loop_thread, daemon=True) + self.main_loop_thread.start() + + def generate_async(self, prompt: str) -> ScaffoldingResult: + result = ScaffoldingResult() + # Clone synchronously here (before any async handoff) to avoid race + # conditions where concurrent callers mutate prototype_controller state + # between this call and when put_request actually runs on self.loop. 
+ cloned_controller = self.prototype_controller.clone() + + async def put_request(): + try: + request = ScaffoldingRequest( + prompt=prompt, + kwargs={}, + result=result, + controller=cloned_controller, + ) + except Exception as e: + await self.task_queue.put(None) + print( + f"Error: build ScaffoldingRequest failed: {e} \n {traceback.format_exc()}" + ) + else: + await self.task_queue.put(request) + + asyncio.run_coroutine_threadsafe(put_request(), self.loop) + + return result + + def generate( + self, prompts: str | list[str] + ) -> ScaffoldingResult | list[ScaffoldingResult]: + unbatched = not isinstance(prompts, list) + batched_prompts = [prompts] if unbatched else prompts + + scaffolding_results = [] + for prompt in batched_prompts: + scaffolding_results.append(self.generate_async(prompt)) + + for scaffolding_result in scaffolding_results: + scaffolding_result.result() + + return scaffolding_results[0] if unbatched else scaffolding_results + + def enable_output_task_collection(self): + self.output_task_collection = True + + def shutdown(self, shutdown_workers=False): + def shutdown_workers_func(): + for worker in self.workers.values(): + worker.shutdown() + + async def stop_task_on_loop(): + await self.task_queue.put(None) + await self.main_loop_stop_event.wait() + for worker in self.workers.values(): + await worker.async_shutdown() + + asyncio.run_coroutine_threadsafe(stop_task_on_loop(), self.loop) + self.main_loop_thread.join() + + if shutdown_workers: + shutdown_workers_func() diff --git a/examples/scaffolding/core/task.py b/examples/scaffolding/core/task.py new file mode 100644 index 0000000000..0e2b0c6bb4 --- /dev/null +++ b/examples/scaffolding/core/task.py @@ -0,0 +1,270 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Vendored from tensorrt_llm.scaffolding.task + +import json +from collections.abc import Mapping +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from .result import ScaffoldingOutput + + +@dataclass +class Task: + # Scaffolding delivers the task to the Worker by worker_tag. + worker_tag: str = field(default=None) + + # For streaming output. + streaming_output_flag: bool = field(default=False) + streaming_output_list: list[Any] = field(default_factory=list) + + # Reserve for custom input params. + custom_input_params: dict | None = None + + # Reserve for custom output params. 
+ custom_output_params: dict | None = None + + @staticmethod + def create_from_prompt(prompt: str) -> "Task": + pass + + def create_scaffolding_output(self) -> ScaffoldingOutput: + pass + + def create_scaffolding_output_stream(self) -> list[ScaffoldingOutput]: + pass + + +class TaskStatus(Enum): + SUCCESS = "success" + WORKER_NOT_SUPPORTED = "worker_not_supported" + WORKER_EXECEPTION = "worker_exception" + + +@dataclass +class GenerationTask(Task): + # input field + input_tokens: list[int] | None = None + input_str: str | None = None + skip_tokenizer: bool = False + skip_detokenizer: bool = False + + # sampling params for openai + best_of: int | None = None + echo: bool | None = False + frequency_penalty: float | None = 0.0 + logit_bias: dict[str, float] | None = None + num_logprobs: int | None = None + max_tokens: int | None = None + n: int = 1 + presence_penalty: float | None = 0.0 + seed: int | None = None + stop: str | list[str] | None = field(default_factory=list) + suffix: str | None = None + temperature: float | None = None + top_p: float | None = None + user: str | None = None + ignore_eos: bool = False + + # sampling params + top_k: int | None = None + return_context_logits: bool | None = False + + # suggest to use Controller.WorkerTag + worker_tag: str | None = None + + # result field + output_str: str | None = None + output_tokens: list[int] | None = None + finish_reason: str | None = None + context_logits: Any = None + logprobs: Any = None + customized_result_fields: dict[str, Any] = field(default_factory=dict) + + perf_metrics: dict[str, float] | None = None + + @staticmethod + def create_from_prompt(prompt: str) -> "GenerationTask": + task = GenerationTask() + task.input_str = prompt + task.skip_tokenizer = False + task.skip_detokenizer = False + return task + + def create_scaffolding_output(self) -> ScaffoldingOutput: + return ScaffoldingOutput(self.output_str, self.output_tokens) + + +@dataclass +class StreamGenerationTask(GenerationTask): + # input field + cancel_flag: bool | None = field(default=False) + streaming_step: int | None = field(default=1) + + # result field + request_handle: Any = field(default=None) + end_flag: bool = field(default=False) + + @staticmethod + def create_from_generation_task( + task: GenerationTask, streaming_step + ) -> "StreamGenerationTask": + stream_task = StreamGenerationTask() + for k, v in task.__dict__.items(): + stream_task.__dict__[k] = v + stream_task.streaming_step = streaming_step + return stream_task + + +@dataclass +class RewardTask(Task): + # input field + input_tokens: list[int] | None = field(default=None) + input_str: str | None = field(default=None) + + +@dataclass +class RoleMessage: + role: str | None = field(default=None) + content: str | None = field(default=None) + prefix: str | None = field(default=None) + + def __str__(self) -> str: + return json.dumps( + { + "role": self.role, + "content": self.content, + } + ) + + def __repr__(self) -> str: + return f"{self.role}: {self.content}\n" + + def to_dict(self) -> dict[str, Any]: + return {"role": self.role, "content": self.content} + + @classmethod + def from_dict(cls, data: dict[str, Any]): + return cls(role=data["role"], content=data["content"]) + + +@dataclass +class UserMessage(RoleMessage): + def __init__(self, content: str, prefix: str | None = None): + super().__init__(role="user", content=content, prefix=prefix) + + +@dataclass +class AssistantMessage(RoleMessage): + reasoning: str | None = field(default=None) + reasoning_content: str | None = field(default=None) + 
tool_calls: list[Any] | None = field(default=None) + + def __init__( + self, + content: str, + reasoning: str | None = None, + reasoning_content: str | None = None, + tool_calls: list[Any] | None = None, + ): + super().__init__(role="assistant", content=content) + self.reasoning = reasoning + self.reasoning_content = reasoning_content + self.tool_calls = tool_calls + + def __str__(self) -> str: + return json.dumps( + { + "role": "assistant", + "content": self.content, + "reasoning": self.reasoning, + "reasoning_content": self.reasoning_content, + "tool_calls": [str(tool) for tool in self.tool_calls] + if self.tool_calls is not None + else None, + } + ) + + +@dataclass +class SystemMessage(RoleMessage): + def __init__(self, content: str, prefix: str | None = None): + super().__init__(role="system", content=content, prefix=prefix) + + +class ToolDescription: + def __init__(self, name: str, description: str, parameters: dict[str, Any]): + self.name = name + self.description = description + self.parameters = parameters + + def to_dict(self) -> dict[str, Any]: + pass + + +class OpenAIToolDescription(ToolDescription): + def to_dict(self) -> dict[str, Any]: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": self.parameters, + }, + }, + } + + +@dataclass +class ChatTask(StreamGenerationTask): + messages: list[RoleMessage] = field(default_factory=list) + tools: Any = field(default=None) + + # for token counting + enable_token_counting: bool = field(default=False) + prompt_tokens_num: int = field(default=0) + completion_tokens_num: int = field(default=0) + reasoning_tokens_num: int = field(default=0) + + # for sub request marker + sub_request_markers: list[tuple[str, int]] = field(default_factory=list) + unique_id: int | None = field(default=None) + + def messages_to_dict_content(self, start_index: int = 0) -> list[Mapping[str, str]]: + ret = [] + for message in self.messages[start_index:]: + if message.content is not None: + ret.append(message.to_dict()) + return ret + + def add_message(self, message: RoleMessage): + self.messages.append(message) + + def add_messages(self, messages: list[RoleMessage]): + self.messages.extend(messages) + + @staticmethod + def create_from_prompt( + user_prompt: str | None, + system_prompts: list[SystemMessage] | None = None, + tools: Any | None = None, + ) -> "ChatTask": + task = ChatTask() + if system_prompts is not None: + task.messages.extend(system_prompts) + if user_prompt is not None: + task.add_message(UserMessage(user_prompt)) + task.tools = tools + return task + + @staticmethod + def create_from_messages( + messages: list[RoleMessage], tools: Any | None = None + ) -> "ChatTask": + task = ChatTask() + task.messages = messages + task.tools = tools + return task diff --git a/examples/scaffolding/core/task_collection.py b/examples/scaffolding/core/task_collection.py new file mode 100644 index 0000000000..6ba6c303e7 --- /dev/null +++ b/examples/scaffolding/core/task_collection.py @@ -0,0 +1,492 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
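+# Usage sketch for the ``with_task_collection`` decorator defined below
+# (controller and collection names are illustrative):
+#
+#   @with_task_collection("token_counter", GenerationTokenCounter)
+#   class MyController(Controller):
+#       def process(self, tasks, **kwargs):
+#           yield tasks
+#
+# After a run, ``controller.task_collections["token_counter"]`` exposes
+# ``generation_token_count`` for the processed tasks.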
+# Vendored from tensorrt_llm.scaffolding.task_collection + +import json +import time +from typing import Any + +from .controller import ParallelProcess +from .task import ChatTask, GenerationTask, Task + + +class TaskCollection: + def __init__(self): + # reserved for future use + pass + + def before_yield(self, tasks: list[Task]): + pass + + def after_yield(self, tasks: list[Task]): + pass + + @staticmethod + def get_global_info() -> Any: + pass + + +def with_task_collection( + name: str, task_collection_cls: type[TaskCollection], **task_collection_kwargs +): + def decorator(controller_cls: type): + original_init = controller_cls.__init__ + original_process = controller_cls.process + + # add task collection to controller + def new_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + self.task_collections[name] = task_collection_cls(**task_collection_kwargs) + + def new_process(self, tasks: list[Task], **kwargs): + class TaskCollectionWrapper: + def __init__(self, task_collection, gen): + self.task_collection = task_collection + self.gen = gen + + def __call__(self): + for obj in self.gen: + if isinstance(obj, ParallelProcess): + new_sub_gens = [] + for sub_gen in obj.sub_gens: + new_sub_gen = TaskCollectionWrapper( + self.task_collection, sub_gen + ) + new_sub_gens.append(new_sub_gen) + obj.sub_gens = new_sub_gens + + yield obj + else: # obj is a list of tasks + self.task_collection.before_yield(obj) + yield obj + self.task_collection.after_yield(obj) + + def __iter__(self): + return self.__call__() + + original_gen = original_process(self, tasks, **kwargs) + new_gen = TaskCollectionWrapper(self.task_collections[name], original_gen) + return new_gen() + + controller_cls.__init__ = new_init + controller_cls.process = new_process + + return controller_cls + + return decorator + + +class GenerationTokenCounter(TaskCollection): + def __init__(self): + super().__init__() + self.generation_token_count = 0 + self.pre_worker_token_sum = 0 + + def before_yield(self, tasks: list[Task]): + self.pre_worker_token_sum = 0 + for task in tasks: + if isinstance(task, GenerationTask) or issubclass( + type(task), GenerationTask + ): + if task.output_tokens: + self.pre_worker_token_sum += len(task.output_tokens) + + def after_yield(self, tasks: list[Task]): + post_worker_token_sum = 0 + for task in tasks: + if isinstance(task, GenerationTask) or issubclass( + type(task), GenerationTask + ): + if task.output_tokens: + post_worker_token_sum += len(task.output_tokens) + self.generation_token_count += post_worker_token_sum - self.pre_worker_token_sum + + +class ChatTokenCounter(TaskCollection): + # prompt tokens, completion tokens + statistics: dict[str, list[tuple[int, int]]] = {} + + def __init__(self, statistics_name: str): + super().__init__() + self.statistics_name = statistics_name + if statistics_name not in ChatTokenCounter.statistics: + ChatTokenCounter.statistics[statistics_name] = [] + + def before_yield(self, tasks: list[Task]): + for task in tasks: + if not isinstance(task, ChatTask): + continue + task.enable_token_counting = True + + def after_yield(self, tasks: list[Task]): + for task in tasks: + if not isinstance(task, ChatTask): + continue + ChatTokenCounter.statistics[self.statistics_name].append( + (task.prompt_tokens_num, task.completion_tokens_num) + ) + + def get_global_info() -> Any: + return ChatTokenCounter.statistics + + +class TaskTimer(TaskCollection): + statistics: dict[str, dict[type, list[float]]] = {} + + def __init__(self, statistics_name: str, task_types: 
list[type[Task]]): + super().__init__() + self.statistics_name = statistics_name + self.task_types = task_types + self.start_time_map = {} + if statistics_name not in TaskTimer.statistics: + TaskTimer.statistics[statistics_name] = {} + for task_type in task_types: + if task_type not in TaskTimer.statistics[statistics_name]: + TaskTimer.statistics[statistics_name][task_type] = [] + + def before_yield(self, tasks: list[Task]): + for task in tasks: + if type(task) not in self.task_types: + continue + + self.start_time_map[id(task)] = time.time() + + def after_yield(self, tasks: list[Task]): + for task in tasks: + if type(task) not in self.task_types: + continue + + end_time = time.time() + TaskTimer.statistics[self.statistics_name][type(task)].append( + end_time - self.start_time_map[id(task)] + ) + del self.start_time_map[id(task)] + + def get_global_info() -> Any: + return TaskTimer.statistics + + +class TaskMetricsCollector(TaskCollection): + """Task profiler that captures tasks at yield points.""" + + # Global statistics: controller_name -> List[task_info_dict] + statistics: dict[str, list[dict[str, Any]]] = {} + + def __init__( + self, + controller_name: str, + task_types: list[type[Task]] = None, + enable_print: bool = True, + capture_messages: bool = False, + ): + super().__init__() + self.controller_name = controller_name + self.task_types = task_types + self.enable_print = enable_print + self.capture_messages = capture_messages + self.start_time_map: dict[int, float] = {} + self.pre_message_count_map: dict[int, int] = {} + + if controller_name not in TaskMetricsCollector.statistics: + TaskMetricsCollector.statistics[controller_name] = [] + + def _should_process_task(self, task: Task) -> bool: + if self.task_types is not None and type(task) not in self.task_types: + return False + return True + + def _is_task_already_profiled(self, task: Task) -> bool: + return getattr(task, "_profiling_in_progress", False) + + def _mark_task_profiling_start(self, task: Task): + task._profiling_in_progress = True + + def _mark_task_profiling_end(self, task: Task): + task._profiling_in_progress = False + + def before_yield(self, tasks: list[Task]): + for task in tasks: + if not self._should_process_task(task): + continue + if self._is_task_already_profiled(task): + continue + + self._mark_task_profiling_start(task) + task_id = id(task) + self.start_time_map[task_id] = time.time() + + if isinstance(task, ChatTask): + task.enable_token_counting = True + if self.capture_messages: + self.pre_message_count_map[task_id] = len(task.messages) + + def after_yield(self, tasks: list[Task]): + for task in tasks: + task_id = id(task) + if task_id not in self.start_time_map: + continue + + end_time = time.time() + duration = end_time - self.start_time_map[task_id] + del self.start_time_map[task_id] + self._mark_task_profiling_end(task) + + task_info = { + "controller": self.controller_name, + "task_type": type(task).__name__, + "duration_ms": duration * 1000, + "timestamp": end_time, + } + + if isinstance(task, ChatTask): + task_info["prompt_tokens"] = getattr(task, "prompt_tokens_num", 0) + task_info["completion_tokens"] = getattr( + task, "completion_tokens_num", 0 + ) + task_info["reasoning_tokens"] = getattr(task, "reasoning_tokens_num", 0) + task_info["total_tokens"] = ( + task_info["prompt_tokens"] + task_info["completion_tokens"] + ) + task_info["finish_reason"] = getattr(task, "finish_reason", None) + task_info["unique_id"] = getattr(task, "unique_id", None) + task_info["sub_request_markers"] = getattr( + 
task, "sub_request_markers", [] + ) + task_info["perf_metrics"] = getattr(task, "perf_metrics", None) + + if self.capture_messages: + pre_message_count = self.pre_message_count_map.get(task_id, 0) + if task_id in self.pre_message_count_map: + del self.pre_message_count_map[task_id] + + task_info["message_count_before"] = pre_message_count + task_info["message_count_after"] = len(task.messages) + task_info["messages"] = [ + self._serialize_message(msg) for msg in task.messages + ] + if len(task.messages) > pre_message_count: + task_info["new_messages"] = [ + self._serialize_message(msg) + for msg in task.messages[pre_message_count:] + ] + else: + task_info["new_messages"] = [] + + TaskMetricsCollector.statistics[self.controller_name].append(task_info) + + if self.enable_print: + self._print_task_info(task_info) + + def _serialize_message(self, message) -> dict[str, Any]: + """Serialize a RoleMessage to a dictionary.""" + result = { + "role": getattr(message, "role", None), + "content": getattr(message, "content", None), + } + if hasattr(message, "reasoning") and message.reasoning is not None: + result["reasoning"] = message.reasoning + if ( + hasattr(message, "reasoning_content") + and message.reasoning_content is not None + ): + result["reasoning_content"] = message.reasoning_content + if hasattr(message, "tool_calls") and message.tool_calls is not None: + result["tool_calls"] = [str(tc) for tc in message.tool_calls] + return result + + def _print_task_info(self, task_info: dict[str, Any]): + log_parts = [ + f"[{task_info['controller']}]", + f"{task_info['task_type']}", + f"duration={task_info['duration_ms']:.2f}ms", + ] + + if "prompt_tokens" in task_info: + log_parts.append( + f"prompt={task_info['prompt_tokens']} " + f"completion={task_info['completion_tokens']} " + f"reasoning={task_info['reasoning_tokens']} " + f"total={task_info['total_tokens']}" + ) + + if task_info.get("perf_metrics"): + perf_str = ", ".join( + f"{k}={v:.2f}" if isinstance(v, float) else f"{k}={v}" + for k, v in task_info["perf_metrics"].items() + ) + log_parts.append(f"perf: {perf_str}") + + print(" | ".join(log_parts)) + + if "new_messages" in task_info and task_info["new_messages"]: + print( + f" Messages: {task_info['message_count_before']} -> {task_info['message_count_after']}" + ) + print(" New Messages:") + for msg in task_info["new_messages"]: + role = msg.get("role", "unknown") + content = msg.get("content", "") + if content and len(content) > 200: + content = content[:200] + "..." 
+ print(f" [{role}]: {content}") + + @staticmethod + def _compute_stats(values: list[float]) -> dict[str, float]: + """Compute avg, median, min, max, sum for a list of values.""" + if not values: + return {"avg": 0, "median": 0, "min": 0, "max": 0, "sum": 0} + sorted_vals = sorted(values) + n = len(sorted_vals) + median = ( + sorted_vals[n // 2] + if n % 2 == 1 + else (sorted_vals[n // 2 - 1] + sorted_vals[n // 2]) / 2 + ) + return { + "avg": sum(values) / n, + "median": median, + "min": min(values), + "max": max(values), + "sum": sum(values), + } + + @staticmethod + def print_summary(): + """Print summary statistics for all controllers.""" + print("\n" + "=" * 80) + print("TASK METRICS SUMMARY") + print("=" * 80) + + for controller_name, task_list in TaskMetricsCollector.statistics.items(): + if not task_list: + continue + + print(f"\n{controller_name} ({len(task_list)} records)") + print("-" * 70) + + task_type_data: dict[str, dict[str, list[float]]] = {} + perf_metrics_agg: dict[str, dict[str, list[float]]] = {} + + for task_info in task_list: + task_type = task_info["task_type"] + if task_type not in task_type_data: + task_type_data[task_type] = { + "duration_ms": [], + "prompt_tokens": [], + "completion_tokens": [], + "reasoning_tokens": [], + "total_tokens": [], + } + perf_metrics_agg[task_type] = {} + + data = task_type_data[task_type] + data["duration_ms"].append(task_info["duration_ms"]) + data["prompt_tokens"].append(task_info.get("prompt_tokens", 0)) + data["completion_tokens"].append(task_info.get("completion_tokens", 0)) + data["reasoning_tokens"].append(task_info.get("reasoning_tokens", 0)) + data["total_tokens"].append(task_info.get("total_tokens", 0)) + + if task_info.get("perf_metrics"): + for key, value in task_info["perf_metrics"].items(): + if isinstance(value, (int, float)): + if key not in perf_metrics_agg[task_type]: + perf_metrics_agg[task_type][key] = [] + perf_metrics_agg[task_type][key].append(float(value)) + + for task_type, data in task_type_data.items(): + count = len(data["duration_ms"]) + print(f"\n {task_type} (count: {count})") + + duration_stats = TaskMetricsCollector._compute_stats( + data["duration_ms"] + ) + print( + f" Duration (ms): sum={duration_stats['sum']:.2f}, " + f"avg={duration_stats['avg']:.2f}, " + f"median={duration_stats['median']:.2f}, " + f"min={duration_stats['min']:.2f}, max={duration_stats['max']:.2f}" + ) + + if sum(data["total_tokens"]) > 0: + prompt_stats = TaskMetricsCollector._compute_stats( + data["prompt_tokens"] + ) + completion_stats = TaskMetricsCollector._compute_stats( + data["completion_tokens"] + ) + total_stats = TaskMetricsCollector._compute_stats( + data["total_tokens"] + ) + + print( + f" Prompt tokens: sum={prompt_stats['sum']:.0f}, " + f"avg={prompt_stats['avg']:.1f}, " + f"min={prompt_stats['min']:.0f}, max={prompt_stats['max']:.0f}" + ) + print( + f" Completion tokens: sum={completion_stats['sum']:.0f}, " + f"avg={completion_stats['avg']:.1f}, " + f"min={completion_stats['min']:.0f}, max={completion_stats['max']:.0f}" + ) + print( + f" Total tokens: sum={total_stats['sum']:.0f}, " + f"avg={total_stats['avg']:.1f}, " + f"min={total_stats['min']:.0f}, max={total_stats['max']:.0f}" + ) + + if perf_metrics_agg[task_type]: + print("\n Perf Metrics:") + for metric_name, values in sorted( + perf_metrics_agg[task_type].items() + ): + stats = TaskMetricsCollector._compute_stats(values) + print( + f" {metric_name}: sum={stats['sum']:.2f}, " + f"avg={stats['avg']:.2f}, " + f"min={stats['min']:.2f}, " + 
f"max={stats['max']:.2f}" + ) + + print("\n" + "=" * 80 + "\n") + + @staticmethod + def get_statistics(controller_name: str = None) -> dict[str, list[dict[str, Any]]]: + """Get statistics for a specific controller or all controllers.""" + if controller_name is not None: + return { + controller_name: TaskMetricsCollector.statistics.get( + controller_name, [] + ) + } + return TaskMetricsCollector.statistics + + @staticmethod + def get_all_records() -> list[dict[str, Any]]: + """Get all records across all controllers as a flat list.""" + all_records = [] + for records in TaskMetricsCollector.statistics.values(): + all_records.extend(records) + all_records.sort(key=lambda x: x.get("timestamp", 0)) + return all_records + + @staticmethod + def export_to_json(file_path: str, controller_name: str = None): + """Export metrics to a JSON file.""" + if controller_name is not None: + data = TaskMetricsCollector.statistics.get(controller_name, []) + else: + data = TaskMetricsCollector.statistics + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False, default=str) + + @staticmethod + def reset(controller_name: str = None): + """Reset statistics for a specific controller or all controllers.""" + if controller_name is not None: + if controller_name in TaskMetricsCollector.statistics: + TaskMetricsCollector.statistics[controller_name] = [] + else: + TaskMetricsCollector.statistics.clear() + + @staticmethod + def get_global_info() -> Any: + return TaskMetricsCollector.statistics diff --git a/examples/scaffolding/core/worker.py b/examples/scaffolding/core/worker.py new file mode 100644 index 0000000000..c2f76d06c7 --- /dev/null +++ b/examples/scaffolding/core/worker.py @@ -0,0 +1,171 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Vendored from tensorrt_llm.scaffolding.worker +# TRTLLMWorker and MCPWorker omitted (require tensorrt_llm runtime). 
+
+import os
+from abc import ABC
+from collections.abc import Callable
+
+import openai
+
+from .task import AssistantMessage, ChatTask, GenerationTask, Task, TaskStatus
+
+
+class Worker(ABC):
+    def register_task_handler(
+        self, task_cls: type[Task], handler: Callable[[object, Task], TaskStatus]
+    ):
+        worker_cls = type(self)
+        worker_cls.task_handlers[task_cls] = handler
+
+    async def run_task(self, task: Task) -> TaskStatus:
+        worker_cls = type(self)
+        if type(task) not in worker_cls.task_handlers:
+            return TaskStatus.WORKER_NOT_SUPPORTED
+        return await worker_cls.task_handlers[type(task)](self, task)
+
+    task_handlers = {}
+
+    def shutdown(self):
+        pass
+
+    async def async_shutdown(self):
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.shutdown()
+
+
+# helper function
+def add_param_if_not_none(params, key, candidate_values):
+    for value in candidate_values:
+        if value is not None:
+            params[key] = value
+            return
+
+
+# helper function
+def add_attr_if_not_none(obj, attr, candidate_values):
+    for value in candidate_values:
+        if value is not None:
+            setattr(obj, attr, value)
+            return
+
+
+def is_deterministic_mode():
+    """Check if SCAFFOLDING_DETERMINISTIC environment variable is set."""
+    return int(os.environ.get("SCAFFOLDING_DETERMINISTIC", 0)) == 1
+
+
+class OpenaiWorker(Worker):
+    def __init__(
+        self,
+        async_client: openai.AsyncOpenAI,
+        model: str,
+        kv_cache_hint_enabled: bool = False,
+    ):
+        self.model = model
+        self.async_client = async_client
+        self.kv_cache_hint_enabled = kv_cache_hint_enabled
+
+    def convert_task_params(self, task: GenerationTask | ChatTask):
+        params = {
+            "model": self.model,
+            "extra_body": {},
+        }
+
+        if not isinstance(task, ChatTask):
+            params["prompt"] = task.input_str
+            add_param_if_not_none(params, "echo", [task.echo])
+
+        add_param_if_not_none(params, "best_of", [task.best_of])
+        add_param_if_not_none(params, "frequency_penalty", [task.frequency_penalty])
+        add_param_if_not_none(params, "logit_bias", [task.logit_bias])
+        add_param_if_not_none(params, "logprobs", [task.num_logprobs])
+        add_param_if_not_none(params, "max_tokens", [task.max_tokens])
+        add_param_if_not_none(params, "n", [task.n])
+        add_param_if_not_none(params, "presence_penalty", [task.presence_penalty])
+        add_param_if_not_none(params, "seed", [task.seed])
+        add_param_if_not_none(params, "stop", [task.stop])
+        add_param_if_not_none(params, "suffix", [task.suffix])
+        add_param_if_not_none(params, "temperature", [task.temperature])
+        add_param_if_not_none(params, "top_p", [task.top_p])
+        add_param_if_not_none(params, "user", [task.user])
+
+        # Override parameters for deterministic inference
+        if is_deterministic_mode():
+            params["temperature"] = 0.0
+            params["top_p"] = 1.0
+            params["n"] = 1
+            if "seed" not in params or params["seed"] is None:
+                params["seed"] = 42
+
+        if hasattr(task, "sub_request_markers") and len(task.sub_request_markers) > 0:
+            params["extra_body"]["agent_hierarchy"] = [task.sub_request_markers[-1]]
+
+        return params
+
+    def fill_generation_task_with_response(
+        self, task: GenerationTask, response: openai.Completion
+    ):
+        task.output_str = response.choices[0].text
+        task.output_tokens = response.choices[0].token_ids
+        task.finish_reason = response.choices[0].finish_reason
+        task.logprobs = response.choices[0].logprobs
+        task.perf_metrics = response.perf_metrics
+
+    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
+        params = self.convert_task_params(task)
+
+        try:
+            response = await
self.async_client.completions.create(**params) + self.fill_generation_task_with_response(task, response) + + return TaskStatus.SUCCESS + + except Exception as e: + print("Openai client get exception: " + str(e)) + return TaskStatus.WORKER_EXECEPTION + + async def chat_handler(self, task: ChatTask) -> TaskStatus: + params = self.convert_task_params(task) + params["messages"] = task.messages_to_dict_content() + params["model"] = self.model + if task.tools is not None: + params["tools"] = [tool.to_dict() for tool in task.tools] + + try: + response = await self.async_client.chat.completions.create(**params) + task.finish_reason = response.choices[0].finish_reason + task.perf_metrics = response.perf_metrics + content = response.choices[0].message.content + reasoning = response.choices[0].message.reasoning + reasoning_content = response.choices[0].message.reasoning_content + tool_calls = response.choices[0].message.tool_calls + task.messages.append( + AssistantMessage(content, reasoning, reasoning_content, tool_calls) + ) + if task.enable_token_counting: + task.prompt_tokens_num = response.usage.prompt_tokens + task.completion_tokens_num = response.usage.completion_tokens + if ( + hasattr(response.usage, "completion_tokens_details") + and response.usage.completion_tokens_details is not None + ): + task.reasoning_tokens_num = ( + response.usage.completion_tokens_details.reasoning_tokens + ) + + return TaskStatus.SUCCESS + + except Exception as e: + print("Openai chat client get exception: " + str(e)) + return TaskStatus.WORKER_EXECEPTION + + task_handlers = { + GenerationTask: generation_handler, + ChatTask: chat_handler, + } diff --git a/examples/scaffolding/fake_tools.py b/examples/scaffolding/fake_tools.py new file mode 100644 index 0000000000..4aaa14da20 --- /dev/null +++ b/examples/scaffolding/fake_tools.py @@ -0,0 +1,72 @@ +""" +Fake search and visit tools for the search scaffolding example. + +These tools return canned results so the example can run without +external API keys (no SERPER_KEY or JINA_API_KEYS required). +""" + + +async def fake_search(queries: list[str]) -> str: + """Return fake search results for each query. + + Parameters + ---------- + queries : list[str] + List of search query strings. + + Returns + ------- + str + Fake search results formatted like the real search tool output. + """ + results = [] + for query in queries: + snippet = ( + f"A Google search for '{query}' found 3 results:\n\n" + f"## Web Results\n" + f"1. [Wikipedia - {query}](https://en.wikipedia.org/wiki/{query.replace(' ', '_')})\n" + f"Source: Wikipedia\n" + f"This article provides an overview of {query}.\n\n" + f"2. [Britannica - {query}](https://www.britannica.com/topic/{query.replace(' ', '-')})\n" + f"Source: Britannica\n" + f"A comprehensive reference on {query} with detailed analysis.\n\n" + f"3. [Research Paper on {query}](https://arxiv.org/abs/2401.00001)\n" + f"Source: arXiv\n" + f"Recent academic research related to {query}." + ) + results.append(snippet) + return "\n=======\n".join(results) + + +async def fake_visit(urls: list[str], goal: str) -> str: + """Return fake webpage content summaries. + + Parameters + ---------- + urls : list[str] + List of URLs to visit. + goal : str + The information goal for visiting the webpages. + + Returns + ------- + str + Fake webpage content summaries for each URL. 
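+
+    Examples
+    --------
+    A minimal shape check (the returned text is canned, so only its prefix
+    is asserted)::
+
+        >>> import asyncio
+        >>> out = asyncio.run(fake_visit(["https://example.com"], "test goal"))
+        >>> out.startswith("The useful information in https://example.com")
+        True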
+ """ + results = [] + for url in urls: + summary = ( + f"The useful information in {url} for user goal {goal} as follows: \n\n" + f"Evidence in page: \n" + f"The webpage discusses topics related to {goal}. " + f"Key findings include several relevant data points and references " + f"that contribute to understanding the subject matter. " + f"The content covers historical context, current developments, " + f"and expert opinions on the topic.\n\n" + f"Summary: \n" + f"This source provides relevant background information about {goal}. " + f"The main conclusions support a factual understanding of the topic " + f"based on available evidence.\n\n" + ) + results.append(summary) + return "\n=======\n".join(results) diff --git a/examples/scaffolding/gsm8k_rlvr_scaffolding.py b/examples/scaffolding/gsm8k_rlvr_scaffolding.py new file mode 100644 index 0000000000..327541e02d --- /dev/null +++ b/examples/scaffolding/gsm8k_rlvr_scaffolding.py @@ -0,0 +1,97 @@ +""" +RLVR (Reinforcement Learning with Verifiable Rewards) Example using Scaffolding Framework. + +This example demonstrates how to use the scaffolding framework for RLVR training +on the GSM8K math dataset. The ScaffoldingWorkflow uses AReaL's engine for +generation and scaffolding controllers for reward computation. + +Usage: + python -m examples.scaffolding.gsm8k_rlvr_scaffolding \ + --config examples/scaffolding/gsm8k_rlvr_scaffolding.yaml \ + +scheduler.type=local experiment_name=areal trial_name=scaffolding +""" + +import sys + +from areal.api.cli_args import GRPOConfig, load_expr_config +from areal.api.engine_api import InferenceEngine +from areal.dataset import get_custom_dataset +from areal.trainer import PPOTrainer +from areal.utils.hf_utils import load_hf_tokenizer + +from ._compat import ( + NativeGenerationController, + ScaffoldingLlm, +) +from .controllers import ( + PipelineTrajectoryMaker, + RLVRRewardController, +) +from .workflow import ScaffoldingWorkflow + + +class GSM8KScaffoldingWorkflow(ScaffoldingWorkflow): + """ScaffoldingWorkflow customized for GSM8K RLVR training. + + Demonstrates overriding ``build_scaffolding_llm`` to control how the + ScaffoldingLlm is constructed (e.g., swap controllers or workers). 
+ """ + + def build_scaffolding_llm(self, engine: InferenceEngine) -> ScaffoldingLlm: + sampling_params = { + "max_tokens": self.gconfig.max_new_tokens, + "temperature": self.gconfig.temperature or 1.0, + } + self.gen_controller = NativeGenerationController( + sampling_params=sampling_params + ) + self.reward_controller = RLVRRewardController(self.reward_fn) + self.trajectory_maker = PipelineTrajectoryMaker( + self.gen_controller, self.reward_controller + ) + return ScaffoldingLlm( + self.trajectory_maker, + {NativeGenerationController.WorkerTag.GENERATION: self.worker}, + ) + + +def main(args): + """Main entry point for RLVR training with scaffolding.""" + config, _ = load_expr_config(args, GRPOConfig) + tokenizer = load_hf_tokenizer(config.tokenizer_path) + + train_dataset = get_custom_dataset( + split="train", + dataset_config=config.train_dataset, + tokenizer=tokenizer, + ) + valid_dataset = get_custom_dataset( + split="test", + dataset_config=config.valid_dataset, + tokenizer=tokenizer, + ) + + workflow_kwargs = dict( + reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", + gconfig=config.gconfig, + tokenizer=config.tokenizer_path, + enable_thinking=False, + ) + eval_workflow_kwargs = workflow_kwargs.copy() + eval_workflow_kwargs["gconfig"] = config.gconfig.new(temperature=0.6) + + with PPOTrainer( + config, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + ) as trainer: + trainer.train( + workflow="examples.scaffolding.gsm8k_rlvr_scaffolding.GSM8KScaffoldingWorkflow", + workflow_kwargs=workflow_kwargs, + eval_workflow="examples.scaffolding.gsm8k_rlvr_scaffolding.GSM8KScaffoldingWorkflow", + eval_workflow_kwargs=eval_workflow_kwargs, + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/scaffolding/gsm8k_rlvr_scaffolding.yaml b/examples/scaffolding/gsm8k_rlvr_scaffolding.yaml new file mode 100644 index 0000000000..18ce7ac759 --- /dev/null +++ b/examples/scaffolding/gsm8k_rlvr_scaffolding.yaml @@ -0,0 +1,178 @@ +# RLVR Scaffolding Example Configuration for GSM8K +# Compatible with GRPOConfig, single-GPU setup (inference + training share 1 GPU) + +experiment_name: gsm8k-rlvr-scaffolding +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 1 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: sglang:d1+d1 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 64 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 8 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-3B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.0e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.2 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + 
behav_imp_weight_cap: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + weight_update_mode: disk + max_new_tokens: ${gconfig.max_new_tokens} + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_running_requests: null + context_length: 2048 + mem_fraction_static: 0.3 + attention_backend: flashinfer + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 4096 + gpu_memory_utilization: 0.8 + +# Datasets +train_dataset: + batch_size: 64 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 2048 + +valid_dataset: + batch_size: 64 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/examples/scaffolding/gsm8k_rlvr_scaffolding_2nodes.yaml b/examples/scaffolding/gsm8k_rlvr_scaffolding_2nodes.yaml new file mode 100644 index 0000000000..2394371567 --- /dev/null +++ b/examples/scaffolding/gsm8k_rlvr_scaffolding_2nodes.yaml @@ -0,0 +1,187 @@ +# RLVR Scaffolding Example Configuration for GSM8K +# Multi-node setup: 2 nodes x 8 A100-80GB GPUs (Ray scheduler) +# allocation: 8 GPUs for SGLang inference, 8 GPUs for actor+ref training + +experiment_name: gsm8k-rlvr-scaffolding +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 2 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: ray + ray_actor_name: ray_kv_store + +allocation_mode: sglang:d8+d8 + +scheduler: + type: ray + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 64 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: false + +gconfig: + n_samples: 8 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: 
Qwen/Qwen2.5-3B-Instruct + init_from_scratch: false + attn_impl: sdpa + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.0e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.2 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behav_imp_weight_cap: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + weight_update_mode: xccl + max_new_tokens: ${gconfig.max_new_tokens} + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: + NCCL_SOCKET_FAMILY: AF_INET6 + NCCL_SOCKET_IFNAME: eth0 + TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: "3600" + NCCL_IB_DISABLE: "1" + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_running_requests: null + context_length: 2048 + mem_fraction_static: 0.5 + attention_backend: flashinfer + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 4096 + gpu_memory_utilization: 0.8 + +# Datasets +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: false + num_workers: 0 + path: openai/gsm8k + type: rl + max_length: 2048 + +valid_dataset: + batch_size: 256 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/examples/scaffolding/real_tools.py b/examples/scaffolding/real_tools.py new file mode 100644 index 0000000000..e21b3f92a9 --- /dev/null +++ b/examples/scaffolding/real_tools.py @@ -0,0 +1,161 @@ +""" +Real search and visit tools for the search scaffolding example. + +Uses the Serper API (SERPER_KEY_ID env var) for web search and +basic aiohttp fetching for page visits. +""" + +import asyncio +import json +import os +import re + +import aiohttp + +SERPER_KEY = os.environ.get("SERPER_KEY_ID", "") + +# Maximum characters to keep from a fetched webpage. +_MAX_PAGE_CHARS = 8000 + + +async def real_search(queries: list[str]) -> str: + """Perform Google searches via the Serper API. 
+
+    Parameters
+    ----------
+    queries : list[str]
+        List of search query strings.
+
+    Returns
+    -------
+    str
+        Formatted search results for each query, separated by ``=======``.
+    """
+    tasks = [_search_single(q) for q in queries]
+    results = await asyncio.gather(*tasks)
+    return "\n=======\n".join(results)
+
+
+async def _search_single(query: str) -> str:
+    """Search a single query via Serper and return formatted results."""
+
+    def _contains_chinese(text: str) -> bool:
+        return any("\u4e00" <= ch <= "\u9fff" for ch in text)
+
+    if _contains_chinese(query):
+        payload = {"q": query, "location": "China", "gl": "cn", "hl": "zh-cn"}
+    else:
+        payload = {"q": query, "location": "United States", "gl": "us", "hl": "en"}
+
+    headers = {"X-API-KEY": SERPER_KEY, "Content-Type": "application/json"}
+
+    last_exc: Exception | None = None
+    async with aiohttp.ClientSession() as session:
+        for _attempt in range(5):
+            try:
+                async with session.post(
+                    "https://google.serper.dev/search",
+                    json=payload,
+                    headers=headers,
+                ) as resp:
+                    text = await resp.text()
+                    try:
+                        results = json.loads(text)
+                    except Exception:
+                        return f"[Search] Failed to parse response for '{query}'."
+
+                    if "organic" not in results:
+                        return (
+                            f"No results found for query: '{query}'. "
+                            "Use a less specific query."
+                        )
+
+                    web_snippets = []
+                    for idx, page in enumerate(results.get("organic", []), start=1):
+                        date_published = (
+                            f"\nDate published: {page['date']}"
+                            if page.get("date")
+                            else ""
+                        )
+                        source = (
+                            f"\nSource: {page['source']}" if page.get("source") else ""
+                        )
+                        snippet = f"\n{page['snippet']}" if page.get("snippet") else ""
+                        entry = (
+                            f"{idx}. [{page.get('title', '')}]"
+                            f"({page.get('link', '')})"
+                            f"{date_published}{source}\n{snippet}"
+                        )
+                        entry = entry.replace("Your browser can't play this video.", "")
+                        web_snippets.append(entry)
+
+                    return (
+                        f"A Google search for '{query}' found "
+                        f"{len(web_snippets)} results:\n\n## Web Results\n"
+                        + "\n\n".join(web_snippets)
+                    )
+            except Exception as e:
+                last_exc = e
+                await asyncio.sleep(0.5)
+                continue
+
+    return (
+        f"Google search timed out or failed ({last_exc}); "
+        "please try again later."
+    )
+
+
+def _html_to_text(html: str) -> str:
+    """Very basic HTML tag stripping."""
+    text = re.sub(r"<script[^>]*>[\s\S]*?</script>", "", html, flags=re.I)
+    text = re.sub(r"<style[^>]*>[\s\S]*?</style>", "", text, flags=re.I)
+    text = re.sub(r"<[^>]+>", " ", text)
+    text = re.sub(r"\s+", " ", text).strip()
+    return text
+
+
+async def real_visit(urls: list[str], goal: str) -> str:
+    """Fetch webpages and return truncated text content.
+
+    Parameters
+    ----------
+    urls : list[str]
+        List of URLs to visit.
+    goal : str
+        The information goal for visiting the webpages.
+
+    Returns
+    -------
+    str
+        Text content summaries for each URL, separated by ``=======``.
+    """
+    results = []
+    for url in urls:
+        content = await _fetch_page(url)
+        summary = (
+            f"The useful information in {url} for user goal {goal} is as follows: \n\n"
+            f"Evidence in page: \n{content}\n\n"
+            f"Summary: \nContent fetched from {url} related to {goal}.\n\n"
+        )
+        results.append(summary)
+    return "\n=======\n".join(results)
+
+
+async def _fetch_page(url: str) -> str:
+    """Fetch a URL and return plain-text content (truncated)."""
+    timeout = aiohttp.ClientTimeout(total=30)
+    try:
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            async with session.get(url) as resp:
+                if resp.status != 200:
+                    return (
+                        f"The provided webpage returned HTTP {resp.status}. "
+                        "Please check the URL or try another source."
+                    )
+                html = await resp.text()
+                text = _html_to_text(html)
+                if len(text) > _MAX_PAGE_CHARS:
+                    text = text[:_MAX_PAGE_CHARS] + "\n... [truncated]"
+                return text if text else "The webpage returned empty content."
+    except Exception as e:
+        return f"Failed to fetch {url}: {e}"
diff --git a/examples/scaffolding/search_agent_controller.py b/examples/scaffolding/search_agent_controller.py
new file mode 100644
index 0000000000..4ca04624ff
--- /dev/null
+++ b/examples/scaffolding/search_agent_controller.py
@@ -0,0 +1,241 @@
+"""
+SearchAgentController — multi-turn tool-calling loop as a scaffolding Controller.
+
+This is the scaffolding-framework equivalent of ``MultiTurnReactAgent.run_agent``
+from the tongyi_deepresearch example. Instead of calling an ``ArealOpenAI``
+client directly, it yields ``ChatTask`` objects through a
+``NativeGenerationController`` so that the ``ScaffoldingLlm`` dispatches them
+to the SGLang worker.
+
+Tool execution (search / visit) happens **locally** inside ``process()`` —
+only LLM generation goes through a Worker.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+
+import json5
+
+from areal.utils import logging
+
+from ._compat import (
+    ChatTask,
+    Controller,
+    NativeGenerationController,
+    RoleMessage,
+    Task,
+    UserMessage,
+)
+from .real_tools import real_search, real_visit
+
+logger = logging.getLogger("SearchAgentController")
+
+OBS_START = "<tool_response>"
+OBS_END = "</tool_response>\n"
+
+
+class SearchAgentController(Controller):
+    """Multi-turn search-agent controller for the scaffolding framework.
+
+    Each call to :meth:`process` runs a loop of up to *max_turns* turns:
+
+    1. Yield the ``ChatTask`` to the generation controller (LLM call).
+    2. Parse the assistant reply for ``<tool_call>...</tool_call>``.
+    3. If a tool call is found, execute it locally and append the tool
+       response as a user message.
+    4. If an ``<answer>`` tag is found, stop.
+    5. If the token budget is exhausted, append a "please answer" nudge,
+       do one final generation, and stop.
+
+    Parameters
+    ----------
+    generation_controller : Controller
+        Typically a ``NativeGenerationController`` that yields ``ChatTask``
+        to the worker.
+    tokenizer
+        HuggingFace tokenizer for token counting.
+    max_turns : int
+        Maximum number of LLM calls per episode.
+    max_total_tokens : int
+        Soft token budget for the conversation.
+    messages : list[dict] | None
+        Initial chat messages (set per-episode before ``generate``).
+    input_tokens : list[int] | None
+        Tokenised input IDs (set per-episode before ``generate``).
+    """
+
+    # Re-use the generation worker tag from NativeGenerationController
+    WorkerTag = NativeGenerationController.WorkerTag
+
+    def __init__(
+        self,
+        generation_controller: Controller,
+        tokenizer: Any,
+        max_turns: int = 20,
+        max_total_tokens: int = 32768,
+        messages: list[dict] | None = None,
+        input_tokens: list[int] | None = None,
+    ):
+        super().__init__()
+        self.generation_controller = generation_controller
+        self.tokenizer = tokenizer
+        self.max_turns = max_turns
+        self.max_total_tokens = max_total_tokens
+        self.max_total_tokens_before_finishing = int(max_total_tokens * 0.8)
+        # Safety margin to account for _count_tokens approximation errors
+        # (the simple template-based count can underestimate by ~200 tokens)
+        # and SGLang's >= check (exactly max_context_length is rejected too).
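+        # Illustrative numbers (not from the source): with the defaults
+        # max_total_tokens=32768 and a 2048-token per-call generation cap,
+        # the pre-generation budget check in process() fires once the
+        # conversation exceeds 30720 tokens (30721 + 2048 > 32768), and the
+        # final call is capped at 32768 - prompt_len - 256 tokens, floored
+        # at 256.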
+        self._token_safety_margin = 256
+        self.messages = messages if messages is not None else []
+        self.input_tokens = input_tokens if input_tokens is not None else []
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _count_tokens(self, messages: list[RoleMessage]) -> int:
+        """Approximate token count for a list of ``RoleMessage`` objects."""
+        parts = []
+        for msg in messages:
+            d = msg.to_dict() if hasattr(msg, "to_dict") else msg
+            parts.append(f"<|im_start|>{d['role']}\n{d['content']}<|im_end|>\n")
+        parts.append("<|im_start|>assistant\n")
+        return len(self.tokenizer.encode("".join(parts)))
+
+    async def _execute_tool(self, tool_name: str, tool_args: dict) -> str:
+        """Dispatch a tool call to the appropriate tool."""
+        if tool_name == "search":
+            queries = tool_args.get("query", [])
+            if isinstance(queries, str):
+                queries = [queries]
+            return await real_search(queries)
+        if tool_name == "visit":
+            urls = tool_args.get("url", [])
+            if isinstance(urls, str):
+                urls = [urls]
+            goal = tool_args.get("goal", "")
+            return await real_visit(urls, goal)
+        return f"Error: Tool {tool_name} not found"
+
+    # ------------------------------------------------------------------
+    # Controller interface
+    # ------------------------------------------------------------------
+
+    def process(self, tasks: list[Task], **kwargs) -> Any:  # noqa: C901
+        """Run the multi-turn search-agent loop.
+
+        Parameters
+        ----------
+        tasks : list[Task]
+            Ignored — the ``ChatTask`` is built from ``self.messages``.
+        **kwargs
+            Forwarded to the generation controller.
+
+        Yields
+        ------
+        list[Task]
+            Task lists for worker execution (one per LLM call).
+        """
+        # Build the ChatTask from stored messages
+        role_messages = [RoleMessage.from_dict(m) for m in self.messages]
+        chat_task = ChatTask.create_from_messages(role_messages)
+        chat_task.stop = ["</tool_call>\n", "</answer>"]
+        if self.input_tokens:
+            chat_task.input_tokens = self.input_tokens
+
+        for turn in range(self.max_turns):
+            # --- Token budget check (before generation) -----------------------
+            token_count = self._count_tokens(chat_task.messages)
+            # Reserve room for max_new_tokens so the request won't exceed
+            # the SGLang context window.
+            max_new = self.generation_controller.sampling_params.get("max_tokens", 2048)
+            if token_count + max_new > self.max_total_tokens:
+                logger.info(
+                    "Token budget approaching limit (%d + %d > %d); "
+                    "requesting final answer.",
+                    token_count,
+                    max_new,
+                    self.max_total_tokens,
+                )
+                chat_task.add_message(
+                    UserMessage(
+                        "You have now reached the maximum context length you can handle. "
+                        "You should stop making tool calls and, based on all the information above, "
+                        "think again and provide what you consider the most likely answer "
+                        "in the following format:"
+                        "<think>your final thinking</think>\n"
+                        "<answer>your answer</answer>"
+                    )
+                )
+                # Cap max_tokens directly on the task to stay within budget.
+                # NativeGenerationController.process() only sets values when
+                # task.max_tokens is None, so we must set it on the task itself.
+                remaining = max(
+                    256,
+                    self.max_total_tokens
+                    - self._count_tokens(chat_task.messages)
+                    - self._token_safety_margin,
+                )
+                chat_task.max_tokens = remaining
+                yield from self.generation_controller.process([chat_task], **kwargs)
+                break
+
+            # --- LLM generation ------------------------------------------------
+            yield from self.generation_controller.process([chat_task], **kwargs)
+
+            # Extract the assistant reply
+            last_msg = chat_task.messages[-1]
+            content = last_msg.content or ""
+
+            # --- Tool call handling -------------------------------------------
+            if "<tool_call>" in content and "</tool_call>" in content:
+                tool_call_str = content.split("<tool_call>")[1].split("</tool_call>")[0]
+                try:
+                    tool_call = json5.loads(tool_call_str)
+                    tool_name = tool_call["name"]
+                    tool_args = tool_call.get("arguments", {})
+                    # Execute tool (async → sync bridge)
+                    loop = asyncio.new_event_loop()
+                    try:
+                        result = loop.run_until_complete(
+                            self._execute_tool(tool_name, tool_args)
+                        )
+                    finally:
+                        loop.close()
+                except Exception as e:
+                    result = (
+                        f"Error: {e} Tool call must be valid JSON with "
+                        f'"name" and "arguments" fields.'
+                    )
+                tool_response = f"{OBS_START}\n{result}{OBS_END}"
+                chat_task.add_message(UserMessage(tool_response))
+
+            # --- Check for final answer ---------------------------------------
+            if "<answer>" in content and "</answer>" in content:
+                break
+
+        # --- Turn limit reached without answer --------------------------------
+        if turn == self.max_turns - 1:
+            last_content = chat_task.messages[-1].content or ""
+            if "<answer>" not in last_content:
+                chat_task.add_message(
+                    UserMessage(
+                        "Sorry, the number of LLM calls exceeds the limit. "
+                        "You should stop making tool calls and, based on all "
+                        "the information above, think again and provide what "
+                        "you consider the most likely answer in the following format:"
+                        "<think>your final thinking</think>\n"
+                        "<answer>your answer</answer>"
+                    )
+                )
+                # Cap max_tokens directly on the task to stay within budget.
+                remaining = max(
+                    256,
+                    self.max_total_tokens
+                    - self._count_tokens(chat_task.messages)
+                    - self._token_safety_margin,
+                )
+                chat_task.max_tokens = remaining
+                yield from self.generation_controller.process([chat_task], **kwargs)
diff --git a/examples/scaffolding/search_reward.py b/examples/scaffolding/search_reward.py
new file mode 100644
index 0000000000..fcab3774fa
--- /dev/null
+++ b/examples/scaffolding/search_reward.py
@@ -0,0 +1,71 @@
+"""
+Reward function for the search scaffolding example.
+
+Checks whether the model produced an ``<answer>...</answer>`` tag and does
+basic string matching against the ground truth. For production training
+use an LLM-as-judge (as in tongyi_deepresearch); this simple version is
+sufficient for this demo.
+"""
+
+import re
+
+
+def search_reward_fn(
+    prompt_str: str,
+    completion_str: str,
+    input_tokens: list[int],
+    output_tokens: list[int],
+    **data,
+) -> float:
+    """Compute reward for a search agent trajectory.
+
+    The function extracts the text inside ``<answer>...</answer>`` from
+    *completion_str* and compares it against ``data["answer"]``. A reward
+    of 1.0 is given when the ground-truth answer string appears
+    (case-insensitive) inside the predicted answer, or vice versa; a
+    missing or empty ``<answer>`` tag yields 0.0.
+
+    A small bonus (0.1) is awarded if an ``<answer>`` tag is present but
+    the content does not match, to encourage the model to at least produce
+    a structured answer.
+
+    Parameters
+    ----------
+    prompt_str : str
+        The prompt string (unused beyond signature compatibility).
+    completion_str : str
+        The model's full completion text.
+    input_tokens : list[int]
+        Input token IDs (unused beyond signature compatibility).
+    output_tokens : list[int]
+        Output token IDs (unused beyond signature compatibility).
+    **data
+        Must contain an ``"answer"`` key with the ground-truth answer.
+
+    Returns
+    -------
+    float
+        Reward value: 1.0 for correct, 0.1 for structured but wrong, 0.0
+        for missing answer tag.
+    """
+    ground_truth = data.get("answer", "")
+    if isinstance(ground_truth, list):
+        ground_truth = str(ground_truth[0]) if ground_truth else ""
+    ground_truth = str(ground_truth).strip()
+
+    # Extract predicted answer from <answer>...</answer> tags
+    match = re.search(r"<answer>(.*?)</answer>", completion_str, re.DOTALL)
+    if match is None:
+        return 0.0
+
+    predicted = match.group(1).strip()
+    if not predicted:
+        return 0.0
+
+    # Case-insensitive containment check (either direction)
+    if ground_truth.lower() in predicted.lower():
+        return 1.0
+    if predicted.lower() in ground_truth.lower():
+        return 1.0
+
+    # Structured answer present but incorrect
+    return 0.1
diff --git a/examples/scaffolding/search_scaffolding.py b/examples/scaffolding/search_scaffolding.py
new file mode 100644
index 0000000000..4b3344d9db
--- /dev/null
+++ b/examples/scaffolding/search_scaffolding.py
@@ -0,0 +1,321 @@
+"""
+Search Agent Scaffolding Example.
+
+This example demonstrates multi-turn search-based RL training using
+the scaffolding framework. A ``SearchAgentController`` drives a
+tool-calling loop (search + visit) expressed as a scaffolding
+``Controller``, while ``TraceTrajectoryMaker`` traces each LLM call
+for PPO training.
+
+The example uses real web search (via Serper API, requires SERPER_KEY_ID
+env var) and basic HTTP fetching for page visits. An LLM judge is used
+for reward computation (the same inference engine is used for both agent
+generation and judging).
+
+Usage:
+    python -m examples.scaffolding.search_scaffolding \\
+        --config examples/scaffolding/search_scaffolding.yaml \\
+        +scheduler.type=local experiment_name=areal trial_name=search_scaffolding
+"""
+
+import datetime
+import sys
+from collections.abc import Callable
+from typing import Any
+
+import torch
+from transformers import PreTrainedTokenizerFast
+
+from areal.api.cli_args import GenerationHyperparameters, GRPOConfig, load_expr_config
+from areal.api.engine_api import InferenceEngine
+from areal.dataset import get_custom_dataset
+from areal.trainer import PPOTrainer
+from areal.utils import logging
+from areal.utils.hf_utils import load_hf_tokenizer
+
+from ._compat import (
+    NativeGenerationController,
+    ScaffoldingLlm,
+)
+from .controllers import (
+    LLMJudgeController,
+    TraceTrajectoryMaker,
+)
+from .search_agent_controller import SearchAgentController
+from .workflow import ScaffoldingWorkflow
+
+logger = logging.getLogger("SearchScaffoldingWorkflow")
+
+# Reuse the system prompt from tongyi_deepresearch (search + visit only).
+SYSTEM_PROMPT = (
+    "You are a deep research assistant. Your core function is to conduct "
+    "thorough, multi-source investigations into any topic. You must handle "
+    "both broad, open-domain inquiries and queries within specialized academic "
+    "fields. For every request, synthesize information from credible, diverse "
+    "sources to deliver a comprehensive, accurate, and objective response. "
" + "When you have gathered sufficient information and are ready to provide " + "the definitive response, you must enclose the entire final answer within " + " tags.\n\n" + "# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "search", "description": ' + '"Perform Google web searches then returns a string of the top search ' + 'results. Accepts multiple queries.", "parameters": {"type": "object", ' + '"properties": {"query": {"type": "array", "items": {"type": "string", ' + '"description": "The search query."}, "minItems": 1, "description": ' + '"The list of search queries."}}, "required": ["query"]}}}\n' + '{"type": "function", "function": {"name": "visit", "description": ' + '"Visit webpage(s) and return the summary of the content.", "parameters": ' + '{"type": "object", "properties": {"url": {"type": "array", "items": ' + '{"type": "string"}, "description": "The URL(s) of the webpage(s) to ' + 'visit. Can be a single URL or an array of URLs."}, "goal": {"type": ' + '"string", "description": "The specific information goal for visiting ' + 'webpage(s)."}}, "required": ["url", "goal"]}}}\n' + "\n\n" + "For each function call, return a json object with function name and " + "arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "\n\n" + "Current date: " +) + + +class SearchScaffoldingWorkflow(ScaffoldingWorkflow): + """ScaffoldingWorkflow for multi-turn search-agent RL training. + + The episode loop delegates to ``SearchAgentController`` (multi-turn + tool calling) composed with ``TraceTrajectoryMaker`` (trajectory + tracing) and ``LLMJudgeController`` (LLM-as-judge reward). + + Parameters + ---------- + reward_fn : Callable | str + Fallback reward function or importable path (used by parent class + for non-LLM-judge scenarios; the LLM judge is the primary reward). + gconfig : GenerationHyperparameters + Generation hyperparameters. + tokenizer : PreTrainedTokenizerFast | str + Tokenizer or path. + enable_thinking : bool + Whether to enable thinking tokens. + max_turns : int + Maximum number of LLM calls per episode. + max_total_tokens : int + Soft token budget for the conversation. + max_judge_tokens : int + Maximum tokens for the LLM judge response. + """ + + def __init__( + self, + reward_fn: Callable[..., Any] | str, + gconfig: GenerationHyperparameters, + tokenizer: PreTrainedTokenizerFast | str, + enable_thinking: bool = False, + max_turns: int = 20, + max_total_tokens: int = 32768, + max_judge_tokens: int = 8192, + ): + super().__init__( + reward_fn=reward_fn, + gconfig=gconfig, + tokenizer=tokenizer, + enable_thinking=enable_thinking, + ) + self.max_turns = max_turns + self.max_total_tokens = max_total_tokens + self.max_judge_tokens = max_judge_tokens + + # ------------------------------------------------------------------ + # Scaffolding construction + # ------------------------------------------------------------------ + + def build_scaffolding_llm(self, engine: InferenceEngine) -> ScaffoldingLlm: + """Build ``ScaffoldingLlm`` with ``SearchAgentController`` + ``TraceTrajectoryMaker``. + + Uses ``LLMJudgeController`` as the reward controller so that + answer correctness is determined by the same LLM (via a judge + prompt) rather than a deterministic string-matching function. + + Parameters + ---------- + engine : InferenceEngine + The inference engine (worker already initialised by parent). 
+
+        Returns
+        -------
+        ScaffoldingLlm
+        """
+        stop_strings = ["</tool_call>\n", "</answer>"]
+        sampling_params: dict[str, Any] = {
+            "max_tokens": self.gconfig.max_new_tokens,
+            "temperature": self.gconfig.temperature or 1.0,
+            "stop": stop_strings,
+        }
+
+        self.gen_controller = NativeGenerationController(
+            sampling_params=sampling_params,
+        )
+        self.reward_controller = LLMJudgeController(
+            max_judge_tokens=self.max_judge_tokens,
+        )
+
+        self.search_controller = SearchAgentController(
+            generation_controller=self.gen_controller,
+            tokenizer=self.tokenizer,
+            max_turns=self.max_turns,
+            max_total_tokens=self.max_total_tokens,
+        )
+
+        self.trajectory_maker = TraceTrajectoryMaker(
+            rollout_controller=self.search_controller,
+            reward_controller=self.reward_controller,
+        )
+
+        return ScaffoldingLlm(
+            self.trajectory_maker,
+            {NativeGenerationController.WorkerTag.GENERATION: self.worker},
+        )
+
+    # ------------------------------------------------------------------
+    # Episode
+    # ------------------------------------------------------------------
+
+    async def arun_episode(
+        self, engine: InferenceEngine, data: dict[str, Any]
+    ) -> dict[str, torch.Tensor]:
+        """Run a single search-agent episode.
+
+        Parameters
+        ----------
+        engine : InferenceEngine
+            The inference engine.
+        data : dict[str, Any]
+            Must contain ``"question"`` and ``"answer"`` keys.
+
+        Returns
+        -------
+        dict[str, torch.Tensor]
+            Trajectory tensors for PPO training.
+        """
+        if self.worker is None:
+            self._lazy_init_scaffolding(engine)
+
+        # Build messages: system prompt + user question
+        system_prompt = SYSTEM_PROMPT + datetime.date.today().strftime("%Y-%m-%d")
+        question = data.get("question", "")
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": question},
+        ]
+
+        # Tokenize the original prompt
+        input_ids = list(
+            self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                add_generation_prompt=True,
+                enable_thinking=self.enable_thinking,
+            )
+        )
+        prompt_str = self.tokenizer.decode(input_ids)
+
+        # Configure per-episode state on the search controller and reward controller.
+        # ScaffoldingLlm.clone() deep-copies the controllers for each request,
+        # so the per-episode data is isolated across concurrent episodes.
+ self.search_controller.messages = messages + self.search_controller.input_tokens = input_ids + self.reward_controller.task_data = data + + # Run the full pipeline (generation + LLM judge reward) + result = self.scaffolding_llm.generate_async(prompt_str) + await result + + # Extract trace results + scaffolding_output = result.outputs[0] + trace_results = scaffolding_output.data + + # Get the final output text from the last traced interaction + if trace_results: + last_interaction = list(trace_results.values())[-1] + output_str = "" + if last_interaction.completion is not None: + output_str = ( + last_interaction.completion.choices[0].message.content or "" + ) + # Reward is set by LLMJudgeController via TraceTrajectoryMaker + reward = float(last_interaction.reward or 0.0) + else: + output_str = scaffolding_output.text or "" + reward = 0.0 + + output_tokens = self.tokenizer.encode(output_str, add_special_tokens=False) + + # Build tensor dict for PPO training + seq = input_ids + output_tokens + logprobs = [0.0] * len(seq) + loss_mask = [0] * len(input_ids) + [1] * len(output_tokens) + versions = [-1] * len(seq) + + res = { + "input_ids": torch.tensor(seq, dtype=torch.int32), + "loss_mask": torch.tensor(loss_mask, dtype=torch.int32), + "logprobs": torch.tensor(logprobs, dtype=torch.float32), + "versions": torch.tensor(versions, dtype=torch.int32), + "attention_mask": torch.ones(len(seq), dtype=torch.bool), + "rewards": torch.tensor(reward, dtype=torch.float32), + } + return {k: v.unsqueeze(0) for k, v in res.items()} + + +# ---------------------------------------------------------------------- +# Entry point +# ---------------------------------------------------------------------- + + +def main(args): + """Main entry point for search scaffolding training.""" + config, _ = load_expr_config(args, GRPOConfig) + tokenizer = load_hf_tokenizer(config.tokenizer_path) + + train_dataset = get_custom_dataset( + split="train", + dataset_config=config.train_dataset, + tokenizer=tokenizer, + ) + valid_dataset = get_custom_dataset( + split="test", + dataset_config=config.valid_dataset, + tokenizer=tokenizer, + ) + + workflow_kwargs = dict( + reward_fn="examples.scaffolding.search_reward.search_reward_fn", + gconfig=config.gconfig, + tokenizer=config.tokenizer_path, + enable_thinking=False, + max_turns=10, + max_total_tokens=8192, + max_judge_tokens=2048, + ) + eval_workflow_kwargs = workflow_kwargs.copy() + eval_workflow_kwargs["gconfig"] = config.gconfig.new(temperature=0.6) + + with PPOTrainer( + config, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + ) as trainer: + trainer.train( + workflow="examples.scaffolding.search_scaffolding.SearchScaffoldingWorkflow", + workflow_kwargs=workflow_kwargs, + eval_workflow="examples.scaffolding.search_scaffolding.SearchScaffoldingWorkflow", + eval_workflow_kwargs=eval_workflow_kwargs, + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/scaffolding/search_scaffolding.yaml b/examples/scaffolding/search_scaffolding.yaml new file mode 100644 index 0000000000..b40f9742f1 --- /dev/null +++ b/examples/scaffolding/search_scaffolding.yaml @@ -0,0 +1,178 @@ +# Search Agent Scaffolding Example Configuration +# Uses SearchScaffoldingWorkflow with TraceTrajectoryMaker + SearchAgentController +# Compatible with GRPOConfig, single-GPU setup + +experiment_name: search-scaffolding +trial_name: trial0 + +seed: 1 +enable_offload: true +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 1 + 
fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: sglang:d1+d1 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 32 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 4 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 8 + min_new_tokens: 0 + max_new_tokens: 2048 + greedy: false + temperature: 1.0 + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-3B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 8192 + optimizer: + type: adam + lr: 1.0e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.2 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + weight_update_mode: disk + max_new_tokens: ${gconfig.max_new_tokens} + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 4096 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang - larger context for multi-turn search conversations +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_running_requests: null + context_length: 8192 + mem_fraction_static: 0.3 + attention_backend: flashinfer + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 8192 + gpu_memory_utilization: 0.8 + +# Datasets - uses GSM8K with question/answer fields for the demo +train_dataset: + batch_size: 16 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 8192 + +valid_dataset: + batch_size: 16 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/examples/scaffolding/task.py 
b/examples/scaffolding/task.py new file mode 100644 index 0000000000..7c6456b9ac --- /dev/null +++ b/examples/scaffolding/task.py @@ -0,0 +1,233 @@ +""" +RLVR Tasks for Scaffolding Framework. + +This module provides task definitions for RLVR (Reinforcement Learning with +Verifiable Rewards) that integrate with the scaffolding framework. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from ._compat import ( + ChatTask, + GenerationTask, + ScaffoldingOutput, + Task, +) + +if TYPE_CHECKING: + from areal.experimental.openai.types import InteractionWithTokenLogpReward + + +@dataclass +class RLVRRewardTask(Task): + """Task for computing RLVR (verifiable) rewards. + + This task contains the necessary information to verify whether a generated + response is correct and compute the corresponding reward. + + Attributes + ---------- + prompt_str : str + The prompt string that was used for generation. + completion_str : str + The generated completion string to verify. + input_tokens : list[int] + The input token IDs. + output_tokens : list[int] + The output token IDs. + output_logprobs : list[float] + The log probabilities of output tokens. + output_versions : list[int] + The weight versions for output tokens. + task_data : dict[str, Any] + Additional task data containing ground truth (e.g., "answer" field). + interaction : InteractionWithTokenLogpReward + The interaction object to store the computed reward. + reward : float + The computed reward value (output field, set after processing). + """ + + # Input fields + prompt_str: str = field(default="") + completion_str: str = field(default="") + input_tokens: list[int] = field(default_factory=list) + output_tokens: list[int] = field(default_factory=list) + output_logprobs: list[float] = field(default_factory=list) + output_versions: list[int] = field(default_factory=list) + task_data: dict[str, Any] = field(default_factory=dict) + + # The interaction object to update with the reward + interaction: InteractionWithTokenLogpReward | None = None + + # Output field + reward: float | None = None + + @staticmethod + def create_from_generation_task( + gen_task: GenerationTask, + prompt_str: str, + task_data: dict[str, Any], + interaction: InteractionWithTokenLogpReward | None = None, + ) -> RLVRRewardTask: + """Create a reward task from a completed generation task. + + Parameters + ---------- + gen_task : GenerationTask + The completed generation task with output. + prompt_str : str + The original prompt string. + task_data : dict[str, Any] + Task data containing ground truth answer. + interaction : InteractionWithTokenLogpReward, optional + The interaction object to update with reward. + + Returns + ------- + RLVRRewardTask + The reward task ready for processing. + """ + reward_task = RLVRRewardTask( + prompt_str=prompt_str, + completion_str=gen_task.output_str or "", + input_tokens=list(gen_task.input_tokens or []), + output_tokens=list(gen_task.output_tokens or []), + output_logprobs=list( + gen_task.customized_result_fields.get("output_logprobs", []) + ), + output_versions=list( + gen_task.customized_result_fields.get("output_versions", []) + ), + task_data=task_data, + interaction=interaction, + ) + return reward_task + + +@dataclass +class TraceGenerationTask(Task): + """Task for tracing multi-turn generation with ChatTracer. + + This task wraps a ChatTask (or GenerationTask) for tracing purposes. + The trace results are stored after processing. 
+ + Attributes + ---------- + generation_task : ChatTask | GenerationTask + The underlying task to be processed and traced. + trace_results : dict[str, InteractionWithTokenLogpReward] + The traced interaction results (output field, set after processing). + """ + + # The underlying generation/chat task + generation_task: ChatTask | GenerationTask | None = None + + # Output field - trace results after processing + trace_results: dict[str, InteractionWithTokenLogpReward] | None = None + + @staticmethod + def create_from_prompt(prompt: str) -> TraceGenerationTask: + """Create a TraceGenerationTask from a prompt string. + + Parameters + ---------- + prompt : str + The input prompt string. + + Returns + ------- + TraceGenerationTask + The task ready for processing. + """ + # Create underlying ChatTask + chat_task = ChatTask.create_from_prompt(prompt) + return TraceGenerationTask(generation_task=chat_task) + + @staticmethod + def create_from_chat_task(chat_task: ChatTask) -> TraceGenerationTask: + """Create a TraceGenerationTask from an existing ChatTask. + + Parameters + ---------- + chat_task : ChatTask + The ChatTask to wrap. + + Returns + ------- + TraceGenerationTask + The task ready for processing. + """ + return TraceGenerationTask(generation_task=chat_task) + + def create_scaffolding_output(self) -> ScaffoldingOutput: + """Create a ScaffoldingOutput from the trace results. + + Returns + ------- + ScaffoldingOutput + The output containing traced results. + """ + # Return the trace results as the output + if self.generation_task is not None and hasattr( + self.generation_task, "output_str" + ): + return ScaffoldingOutput( + text=self.generation_task.output_str or "", + token_ids=list(self.generation_task.output_tokens or []), + data=self.trace_results, + ) + return ScaffoldingOutput(text="", token_ids=[], data=self.trace_results) + + +@dataclass +class ChatRewardTask(Task): + """Task for computing rewards on traced chat interactions. + + This task contains a traced InteractionWithTokenLogpReward and is used + by the reward controller to compute and set rewards. + + Attributes + ---------- + interaction : InteractionWithTokenLogpReward + The traced interaction to compute reward for. + interaction_id : str + The ID of the interaction. + reward : float + The computed reward value (output field, set after processing). + """ + + # The traced interaction + interaction: InteractionWithTokenLogpReward | None = None + + # Interaction ID for reference + interaction_id: str = field(default="") + + # Output field + reward: float | None = None + + @staticmethod + def create_from_trace_result( + interaction_id: str, + interaction: InteractionWithTokenLogpReward, + ) -> ChatRewardTask: + """Create a ChatRewardTask from a trace result. + + Parameters + ---------- + interaction_id : str + The ID of the interaction. + interaction : InteractionWithTokenLogpReward + The traced interaction. + + Returns + ------- + ChatRewardTask + The reward task ready for processing. 
+ """ + return ChatRewardTask( + interaction=interaction, + interaction_id=interaction_id, + ) diff --git a/examples/scaffolding/tests/__init__.py b/examples/scaffolding/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/scaffolding/tests/test_controllers.py b/examples/scaffolding/tests/test_controllers.py new file mode 100644 index 0000000000..c1e216f28d --- /dev/null +++ b/examples/scaffolding/tests/test_controllers.py @@ -0,0 +1,935 @@ +"""Unit tests for scaffolding controllers (TraceTrajectoryMaker, PipelineTrajectoryMaker). + +Tests use fake workers/controllers that simulate LLM inference responses +without requiring an SGLang backend or GPU. + +Design Notes +------------ +- ``FakeGenerationController`` fills ``GenerationTask`` fields in-memory + (single-turn generation). +- ``FakeChatRolloutController`` appends assistant messages to ``ChatTask`` + across multiple yields (multi-turn chat). It manually calls + ``ChatTracer.before_yield / after_yield`` because the lightweight + ``with_task_collection`` decorator in ``_compat.py`` only attaches the + ``ChatTracer`` instance to the class; it does NOT wrap ``process`` to + invoke hooks automatically (that is the tensorrt_llm implementation's + responsibility). +- ``FakeChatRewardController`` assigns predetermined rewards to + ``ChatRewardTask`` objects. +- The lightweight ``with_task_collection`` creates the ``ChatTracer`` as a + **class-level** attribute, so all ``TraceTrajectoryMaker`` instances share + one tracer. Each test that uses it must call ``tracer.clear()`` (or use a + fresh instance) to avoid inter-test pollution. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from examples.scaffolding._compat import ( + AssistantMessage, + ChatTask, + Controller, + GenerationTask, + Task, +) +from examples.scaffolding.controllers import ( + ChatTracer, + PipelineTrajectoryMaker, + RLVRRewardController, + TraceTrajectoryMaker, +) +from examples.scaffolding.task import ( + ChatRewardTask, + RLVRRewardTask, + TraceGenerationTask, +) + +from areal.api.io_struct import ModelResponse +from areal.experimental.openai.types import InteractionWithTokenLogpReward + +# --------------------------------------------------------------------------- +# Fake / stub helpers +# --------------------------------------------------------------------------- + +FAKE_INPUT_TOKENS = [101, 102, 103] +FAKE_OUTPUT_TOKENS = [201, 202, 203, 204] +FAKE_OUTPUT_STR = "42" +FAKE_PROMPT_STR = "What is the answer to life?" 
+ + +def _simple_reward_fn( + prompt: str, + completions: str, + prompt_ids: list[int], + completion_ids: list[int], + **kwargs, +) -> float: + """Deterministic reward: 1.0 if completion contains the ground-truth answer, else 0.0.""" + answer = kwargs.get("answer", "") + return 1.0 if answer and answer in completions else 0.0 + + +class FakeGenerationController(Controller): + """Fills ``GenerationTask`` fields without calling any LLM backend.""" + + def __init__( + self, + output_str: str = FAKE_OUTPUT_STR, + output_tokens: list[int] | None = None, + input_tokens: list[int] | None = None, + output_logprobs: list[float] | None = None, + output_versions: list[int] | None = None, + ): + super().__init__() + self.output_str = output_str + self.output_tokens = output_tokens or FAKE_OUTPUT_TOKENS + self.input_tokens = input_tokens or FAKE_INPUT_TOKENS + self.output_logprobs = output_logprobs + self.output_versions = output_versions + + def process(self, tasks: list[Task], **kwargs) -> Any: + for task in tasks: + if isinstance(task, GenerationTask): + task.output_str = self.output_str + task.output_tokens = self.output_tokens + if task.input_tokens is None: + task.input_tokens = self.input_tokens + if self.output_logprobs is not None: + task.customized_result_fields["output_logprobs"] = ( + self.output_logprobs + ) + if self.output_versions is not None: + task.customized_result_fields["output_versions"] = ( + self.output_versions + ) + yield tasks + + +def _make_fake_completion(completion_id: str = "cmpl-001") -> MagicMock: + """Create a minimal fake ``ChatCompletion`` object.""" + completion = MagicMock() + completion.id = completion_id + completion.created = 1000 + completion.choices = [MagicMock()] + completion.choices[0].message.content = FAKE_OUTPUT_STR + completion.choices[0].finish_reason = "stop" + return completion + + +class FakeChatRolloutController(Controller): + """Simulates a multi-turn chat rollout by appending assistant messages. + + Each turn: + 1. Populates the ``ChatTask`` with a fake completion and tokens. + 2. Calls ``ChatTracer.before_yield`` / ``after_yield`` manually. + 3. Yields the tasks. + 4. Appends a follow-up user message before the next turn. + """ + + def __init__( + self, + n_turns: int = 2, + responses: list[str] | None = None, + output_tokens_per_turn: list[list[int]] | None = None, + input_tokens_per_turn: list[list[int]] | None = None, + ): + super().__init__() + self.n_turns = n_turns + self.responses = responses or [f"response_{i}" for i in range(n_turns)] + self.output_tokens_per_turn = output_tokens_per_turn or [ + [300 + i * 10 + j for j in range(4)] for i in range(n_turns) + ] + self.input_tokens_per_turn = input_tokens_per_turn or [ + FAKE_INPUT_TOKENS for _ in range(n_turns) + ] + # Set by the test to allow manual ChatTracer hook invocation. 
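+        # (For example, tests wire it up as
+        # ``rollout_ctrl._tracer = maker.task_collections["chat_tracer"]``;
+        # see ``TestTraceTrajectoryMaker._make_trace_maker`` below.)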
+ self._tracer: ChatTracer | None = None + + def process(self, tasks: list[Task], **kwargs) -> Any: + for turn_idx in range(self.n_turns): + for task in tasks: + if isinstance(task, ChatTask): + task.messages.append( + AssistantMessage(content=self.responses[turn_idx]) + ) + task.output_tokens = self.output_tokens_per_turn[turn_idx] + task.input_tokens = self.input_tokens_per_turn[turn_idx] + task.completion = _make_fake_completion( + completion_id=f"cmpl-turn-{turn_idx}" + ) + + if self._tracer is not None: + self._tracer.before_yield(tasks) + + yield tasks + + if self._tracer is not None: + self._tracer.after_yield(tasks) + + if turn_idx < self.n_turns - 1: + for task in tasks: + if isinstance(task, ChatTask): + task.messages.append( + {"role": "user", "content": f"follow-up-{turn_idx}"} + ) + + +class FakeChatRewardController(Controller): + """Assigns predetermined rewards to ``ChatRewardTask`` objects.""" + + def __init__(self, rewards: list[float] | None = None, default_reward: float = 1.0): + super().__init__() + self.rewards = rewards + self.default_reward = default_reward + + def process(self, tasks: list[Task], **kwargs) -> Any: + for i, task in enumerate(tasks): + if isinstance(task, ChatRewardTask): + reward = ( + self.rewards[i] + if self.rewards is not None and i < len(self.rewards) + else self.default_reward + ) + task.reward = reward + if task.interaction is not None: + task.interaction.reward = reward + yield tasks + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_shared_tracer(): + """Reset the class-level ChatTracer before each test. + + The lightweight ``with_task_collection`` stores a single ``ChatTracer`` + instance on ``TraceTrajectoryMaker`` (class-level). We must clear it + between tests so that cached interactions don't leak. 
+ """ + yield + tracer = TraceTrajectoryMaker.task_collections.get("chat_tracer") + if tracer is not None: + tracer.clear() + + +# =========================================================================== +# PipelineTrajectoryMaker tests +# =========================================================================== + + +class TestPipelineTrajectoryMaker: + """Tests for PipelineTrajectoryMaker (single-turn generation + reward).""" + + def test_basic_pipeline_correct_answer(self): + """Pipeline should produce interaction with reward=1.0 for a correct answer.""" + gen_ctrl = FakeGenerationController(output_str="The answer is 42.") + reward_ctrl = RLVRRewardController(_simple_reward_fn) + + maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + ) + + task = GenerationTask(input_str=FAKE_PROMPT_STR, input_tokens=FAKE_INPUT_TOKENS) + results = list(maker.process([task])) + + # Only generation yields (reward computed locally, no dict yield) + assert len(results) == 1 + + # Interactions stored on task + interactions = task.customized_result_fields["interactions"] + assert isinstance(interactions, dict) + assert len(interactions) == 1 + + interaction = list(interactions.values())[0] + assert isinstance(interaction, InteractionWithTokenLogpReward) + assert interaction.reward == 1.0 + + def test_basic_pipeline_wrong_answer(self): + """Pipeline should produce interaction with reward=0.0 for a wrong answer.""" + gen_ctrl = FakeGenerationController(output_str="I don't know") + reward_ctrl = RLVRRewardController(_simple_reward_fn) + + maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + ) + + task = GenerationTask(input_str=FAKE_PROMPT_STR, input_tokens=FAKE_INPUT_TOKENS) + list(maker.process([task])) + + interactions = task.customized_result_fields["interactions"] + interaction = list(interactions.values())[0] + assert interaction.reward == 0.0 + + def test_pipeline_multiple_tasks(self): + """Pipeline should handle multiple GenerationTasks in a single batch.""" + gen_ctrl = FakeGenerationController(output_str="42") + reward_ctrl = RLVRRewardController(_simple_reward_fn) + + maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + ) + + tasks = [ + GenerationTask(input_str=FAKE_PROMPT_STR, input_tokens=FAKE_INPUT_TOKENS), + GenerationTask(input_str="Another prompt", input_tokens=[111, 112]), + ] + list(maker.process(tasks)) + + # Both tasks should have the same interactions dict + interactions = tasks[0].customized_result_fields["interactions"] + assert len(interactions) == 2 + for interaction in interactions.values(): + assert interaction.reward == 1.0 + + def test_pipeline_interaction_has_model_response(self): + """Each interaction should contain a valid ModelResponse with tokens.""" + gen_ctrl = FakeGenerationController( + output_str="42", + output_tokens=[201, 202], + output_logprobs=[0.1, 0.2], + output_versions=[1, 1], + ) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + + maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + ) + + task = GenerationTask(input_str=FAKE_PROMPT_STR, input_tokens=FAKE_INPUT_TOKENS) + list(maker.process([task])) + + interaction = 
list(task.customized_result_fields["interactions"].values())[0] + mr = interaction.model_response + assert isinstance(mr, ModelResponse) + assert mr.input_tokens == FAKE_INPUT_TOKENS + assert mr.output_tokens == [201, 202] + assert mr.output_logprobs == [0.1, 0.2] + assert mr.output_versions == [1, 1] + + def test_pipeline_uses_constructor_prompt_str(self): + """Reward controller should receive the prompt_str provided at construction.""" + received_prompts = [] + + def _capture(prompt, completions, prompt_ids, completion_ids, **kw): + received_prompts.append(prompt) + return 1.0 + + gen_ctrl = FakeGenerationController(output_str="42") + reward_ctrl = RLVRRewardController(_capture) + + maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str="custom prompt", + ) + + task = GenerationTask(input_str="input str", input_tokens=FAKE_INPUT_TOKENS) + list(maker.process([task])) + + assert received_prompts == ["custom prompt"] + + def test_pipeline_falls_back_to_input_str(self): + """When prompt_str is empty, should fall back to task.input_str.""" + received_prompts = [] + + def _capture(prompt, completions, prompt_ids, completion_ids, **kw): + received_prompts.append(prompt) + return 1.0 + + gen_ctrl = FakeGenerationController(output_str="42") + reward_ctrl = RLVRRewardController(_capture) + + maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str="", + ) + + task = GenerationTask( + input_str="fallback prompt", input_tokens=FAKE_INPUT_TOKENS + ) + list(maker.process([task])) + + assert received_prompts[0] == "fallback prompt" + + def test_pipeline_reward_scores_tracked(self): + """RLVRRewardController should track scores in self.scores.""" + gen_ctrl = FakeGenerationController(output_str="42") + reward_ctrl = RLVRRewardController(_simple_reward_fn) + + maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + ) + + tasks = [ + GenerationTask(input_str=FAKE_PROMPT_STR, input_tokens=FAKE_INPUT_TOKENS), + GenerationTask(input_str=FAKE_PROMPT_STR, input_tokens=FAKE_INPUT_TOKENS), + ] + list(maker.process(tasks)) + + assert reward_ctrl.scores == [1.0, 1.0] + + def test_pipeline_default_logprobs_and_versions(self): + """When no logprobs/versions provided, interaction should use placeholders.""" + gen_ctrl = FakeGenerationController( + output_str="42", + output_tokens=[201, 202], + # No logprobs/versions supplied + ) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + + maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + ) + + task = GenerationTask(input_str=FAKE_PROMPT_STR, input_tokens=FAKE_INPUT_TOKENS) + list(maker.process([task])) + + interaction = list(task.customized_result_fields["interactions"].values())[0] + mr = interaction.model_response + # Placeholders: [0.0] * output_len + assert mr.output_logprobs == [0.0, 0.0] + assert mr.output_versions == [-1, -1] + + +# =========================================================================== +# ChatTracer tests +# =========================================================================== + + +class TestChatTracer: + """Tests for ChatTracer (TaskCollection for tracing multi-turn chats).""" + + def test_after_yield_creates_interaction(self): + """after_yield 
should create an interaction for each ChatTask.""" + tracer = ChatTracer(export_style="individual") + task = ChatTask(messages=[{"role": "user", "content": "hello"}]) + task.completion = _make_fake_completion("cmpl-001") + task.input_tokens = [1, 2, 3] + task.output_tokens = [4, 5] + + tracer.after_yield([task]) + + results = tracer.get_trace_results() + assert len(results) == 1 + assert "cmpl-001" in results + + def test_multiple_turns_traced(self): + """Calling after_yield with different completions should trace all.""" + tracer = ChatTracer(export_style="individual") + task = ChatTask(messages=[{"role": "user", "content": "hello"}]) + + # Turn 1 + task.completion = _make_fake_completion("cmpl-turn-0") + task.input_tokens = [1, 2] + task.output_tokens = [3, 4] + tracer.after_yield([task]) + + # Turn 2 (same ChatTask, new completion) + task.messages.append(AssistantMessage(content="first response")) + task.completion = _make_fake_completion("cmpl-turn-1") + task.input_tokens = [1, 2, 3, 4] + task.output_tokens = [5, 6] + tracer.after_yield([task]) + + results = tracer.get_trace_results() + assert len(results) == 2 + assert "cmpl-turn-0" in results + assert "cmpl-turn-1" in results + + def test_tracer_interaction_has_model_response(self): + """Traced interactions should have ModelResponse with correct tokens.""" + tracer = ChatTracer(export_style="individual") + task = ChatTask(messages=[{"role": "user", "content": "hello"}]) + task.completion = _make_fake_completion("cmpl-001") + task.input_tokens = [10, 20] + task.output_tokens = [30, 40, 50] + + tracer.after_yield([task]) + + interaction = tracer.get_trace_results()["cmpl-001"] + assert interaction.model_response is not None + assert interaction.model_response.input_tokens == [10, 20] + assert interaction.model_response.output_tokens == [30, 40, 50] + + def test_tracer_clear(self): + """clear() should remove all traced data.""" + tracer = ChatTracer(export_style="individual") + task = ChatTask(messages=[{"role": "user", "content": "hello"}]) + task.completion = _make_fake_completion("cmpl-001") + task.input_tokens = [1] + task.output_tokens = [2] + + tracer.after_yield([task]) + assert len(tracer.get_trace_results()) == 1 + + tracer.clear() + assert tracer.get_trace_results() == {} + + def test_tracer_ignores_non_chat_tasks(self): + """after_yield should skip non-ChatTask objects.""" + tracer = ChatTracer(export_style="individual") + tracer.after_yield([GenerationTask(input_str="hello")]) + assert tracer.get_trace_results() == {} + + def test_tracer_empty_returns_empty(self): + """get_trace_results on a fresh tracer should return empty dict.""" + tracer = ChatTracer(export_style="individual") + assert tracer.get_trace_results() == {} + + +# =========================================================================== +# TraceTrajectoryMaker tests +# =========================================================================== + + +class TestTraceTrajectoryMaker: + """Tests for TraceTrajectoryMaker (multi-turn tracing + reward pipeline).""" + + @staticmethod + def _make_trace_maker( + n_turns: int = 2, + responses: list[str] | None = None, + rewards: list[float] | None = None, + default_reward: float = 1.0, + ) -> tuple[TraceTrajectoryMaker, FakeChatRolloutController]: + """Build a ``TraceTrajectoryMaker`` with fake sub-controllers. + + Also wires the class-level ``ChatTracer`` into the fake rollout + controller so it can call ``before_yield`` / ``after_yield`` hooks. 
+ """ + rollout_ctrl = FakeChatRolloutController(n_turns=n_turns, responses=responses) + reward_ctrl = FakeChatRewardController( + rewards=rewards, default_reward=default_reward + ) + maker = TraceTrajectoryMaker( + rollout_controller=rollout_ctrl, + reward_controller=reward_ctrl, + ) + chat_tracer = maker.task_collections["chat_tracer"] + rollout_ctrl._tracer = chat_tracer + return maker, rollout_ctrl + + def test_basic_trace_single_turn(self): + """Single-turn trace should produce one traced interaction with reward.""" + maker, _ = self._make_trace_maker( + n_turns=1, + responses=["The answer is 42"], + default_reward=1.0, + ) + + task = TraceGenerationTask.create_from_prompt("What is 6*7?") + list(maker.process([task])) + + assert task.trace_results is not None + assert len(task.trace_results) == 1 + interaction = list(task.trace_results.values())[0] + assert interaction.reward == 1.0 + + def test_multi_turn_trace(self): + """Multi-turn trace should produce one interaction per turn.""" + maker, _ = self._make_trace_maker( + n_turns=3, + responses=["step 1", "step 2", "final answer: 42"], + rewards=[0.0, 0.0, 1.0], + ) + + task = TraceGenerationTask.create_from_prompt("Solve step by step") + list(maker.process([task])) + + assert task.trace_results is not None + assert len(task.trace_results) == 3 + + def test_trace_rewards_assigned_correctly(self): + """Each traced interaction should get its designated reward.""" + maker, _ = self._make_trace_maker( + n_turns=2, + responses=["thinking...", "42"], + rewards=[0.5, 1.0], + ) + + task = TraceGenerationTask.create_from_prompt("What is the answer?") + list(maker.process([task])) + + assert task.trace_results is not None + rewards = [i.reward for i in task.trace_results.values()] + assert rewards == [0.5, 1.0] + + def test_trace_results_stored_in_task(self): + """trace_results should be set on the TraceGenerationTask after processing.""" + maker, _ = self._make_trace_maker(n_turns=1, default_reward=0.0) + + task = TraceGenerationTask.create_from_prompt("hello") + assert task.trace_results is None + + list(maker.process([task])) + + assert task.trace_results is not None + assert isinstance(task.trace_results, dict) + + def test_trace_with_plain_task_fallback(self): + """process should not crash when given a plain ChatTask.""" + maker, _ = self._make_trace_maker(n_turns=1, default_reward=1.0) + + chat_task = ChatTask.create_from_prompt("direct chat task") + list(maker.process([chat_task])) + # No assertion on trace_results — plain ChatTask doesn't store them. 
+ + def test_trace_generation_task_create_from_chat_task(self): + """TraceGenerationTask.create_from_chat_task should wrap correctly.""" + chat_task = ChatTask.create_from_prompt("hello") + trace_task = TraceGenerationTask.create_from_chat_task(chat_task) + + assert trace_task.generation_task is chat_task + assert trace_task.trace_results is None + + def test_trace_generation_task_scaffolding_output(self): + """create_scaffolding_output should reflect generation_task fields and trace_results.""" + gen_task = GenerationTask(output_str="result text", output_tokens=[10, 20, 30]) + trace_task = TraceGenerationTask(generation_task=gen_task) + trace_task.trace_results = {"id-1": "fake_interaction"} + + output = trace_task.create_scaffolding_output() + assert output.text == "result text" + assert output.token_ids == [10, 20, 30] + assert output.data == {"id-1": "fake_interaction"} + + def test_trace_generation_task_scaffolding_output_empty(self): + """create_scaffolding_output with no generation_task should return empty.""" + trace_task = TraceGenerationTask() + output = trace_task.create_scaffolding_output() + assert output.text == "" + assert output.token_ids == [] + assert output.data is None + + def test_no_reward_tasks_when_no_traces(self): + """If rollout produces no traceable outputs, reward step should be skipped.""" + + class EmptyRolloutController(Controller): + def process(self, tasks, **kwargs): + yield tasks + + reward_ctrl = FakeChatRewardController(default_reward=1.0) + maker = TraceTrajectoryMaker( + rollout_controller=EmptyRolloutController(), + reward_controller=reward_ctrl, + ) + assert maker.task_collections.get("chat_tracer") is not None + + task = TraceGenerationTask.create_from_prompt("hello") + list(maker.process([task])) + + assert task.trace_results is not None + assert len(task.trace_results) == 0 + + def test_trace_interaction_model_response_tokens(self): + """Traced interactions should carry correct per-turn tokens.""" + maker, _ = self._make_trace_maker( + n_turns=2, + responses=["r0", "r1"], + default_reward=1.0, + ) + # Use distinct per-turn tokens + rollout_ctrl = maker.rollout_controller + rollout_ctrl.output_tokens_per_turn = [[10, 11], [20, 21, 22]] + rollout_ctrl.input_tokens_per_turn = [[1, 2], [1, 2, 3]] + + task = TraceGenerationTask.create_from_prompt("prompt") + list(maker.process([task])) + + interactions = list(task.trace_results.values()) + assert interactions[0].model_response.output_tokens == [10, 11] + assert interactions[0].model_response.input_tokens == [1, 2] + assert interactions[1].model_response.output_tokens == [20, 21, 22] + assert interactions[1].model_response.input_tokens == [1, 2, 3] + + +# =========================================================================== +# RLVRRewardController tests +# =========================================================================== + + +class TestRLVRRewardController: + """Tests for RLVRRewardController (reward computation).""" + + def test_compute_reward_correct(self): + """Should return 1.0 when completion contains the answer.""" + ctrl = RLVRRewardController(_simple_reward_fn) + task = RLVRRewardTask( + prompt_str="What is 2+2?", + completion_str="The answer is 4", + input_tokens=[1, 2, 3], + output_tokens=[4, 5], + task_data={"answer": "4"}, + ) + + list(ctrl.process([task])) + assert task.reward == 1.0 + assert ctrl.scores == [1.0] + + def test_compute_reward_wrong(self): + """Should return 0.0 when completion does not contain the answer.""" + ctrl = RLVRRewardController(_simple_reward_fn) + 
task = RLVRRewardTask( + prompt_str="What is 2+2?", + completion_str="I think it's 5", + input_tokens=[1, 2, 3], + output_tokens=[4, 5], + task_data={"answer": "4"}, + ) + + list(ctrl.process([task])) + assert task.reward == 0.0 + + def test_compute_reward_updates_interaction(self): + """Reward should be propagated to the attached interaction object.""" + ctrl = RLVRRewardController(_simple_reward_fn) + interaction = InteractionWithTokenLogpReward( + model_response=ModelResponse( + input_tokens=[1], + output_tokens=[2], + output_logprobs=[0.0], + output_versions=[-1], + ), + messages=[], + ) + task = RLVRRewardTask( + prompt_str="Q", + completion_str="42", + input_tokens=[1], + output_tokens=[2], + task_data={"answer": "42"}, + interaction=interaction, + ) + + list(ctrl.process([task])) + assert interaction.reward == 1.0 + + def test_compute_reward_batch(self): + """Should process multiple tasks and track all scores.""" + ctrl = RLVRRewardController(_simple_reward_fn) + tasks = [ + RLVRRewardTask( + prompt_str="Q1", + completion_str="42", + task_data={"answer": "42"}, + ), + RLVRRewardTask( + prompt_str="Q2", + completion_str="wrong", + task_data={"answer": "42"}, + ), + RLVRRewardTask( + prompt_str="Q3", + completion_str="also 42", + task_data={"answer": "42"}, + ), + ] + + list(ctrl.process(tasks)) + assert ctrl.scores == [1.0, 0.0, 1.0] + + def test_reward_from_generation_task(self): + """Should handle GenerationTask via customized_result_fields path.""" + ctrl = RLVRRewardController(_simple_reward_fn) + gen_task = GenerationTask( + input_str="What is 2+2?", + output_str="4", + input_tokens=[1, 2], + output_tokens=[3], + ) + + list( + ctrl.process( + [gen_task], + task_data={"answer": "4"}, + prompt_str="What is 2+2?", + ) + ) + assert gen_task.customized_result_fields["reward"] == 1.0 + + +# =========================================================================== +# RLVRRewardTask tests +# =========================================================================== + + +class TestRLVRRewardTask: + """Tests for RLVRRewardTask creation and data flow.""" + + def test_create_from_generation_task(self): + """create_from_generation_task should correctly populate all fields.""" + gen_task = GenerationTask( + input_str="prompt", + output_str="completion text", + input_tokens=[1, 2, 3], + output_tokens=[4, 5], + ) + gen_task.customized_result_fields["output_logprobs"] = [0.1, 0.2] + gen_task.customized_result_fields["output_versions"] = [1, 1] + + reward_task = RLVRRewardTask.create_from_generation_task( + gen_task=gen_task, + prompt_str="original prompt", + task_data={"answer": "42"}, + ) + + assert reward_task.prompt_str == "original prompt" + assert reward_task.completion_str == "completion text" + assert reward_task.input_tokens == [1, 2, 3] + assert reward_task.output_tokens == [4, 5] + assert reward_task.output_logprobs == [0.1, 0.2] + assert reward_task.output_versions == [1, 1] + assert reward_task.task_data == {"answer": "42"} + assert reward_task.reward is None + + def test_create_from_generation_task_no_output(self): + """Should handle GenerationTask with None output gracefully.""" + gen_task = GenerationTask() + + reward_task = RLVRRewardTask.create_from_generation_task( + gen_task=gen_task, + prompt_str="prompt", + task_data={}, + ) + + assert reward_task.completion_str == "" + assert reward_task.input_tokens == [] + assert reward_task.output_tokens == [] + + +# =========================================================================== +# ChatRewardTask tests +# 
=========================================================================== + + +class TestChatRewardTask: + """Tests for ChatRewardTask creation.""" + + def test_create_from_trace_result(self): + """create_from_trace_result should wrap an interaction correctly.""" + interaction = InteractionWithTokenLogpReward( + model_response=ModelResponse( + input_tokens=[1, 2], + output_tokens=[3, 4], + output_logprobs=[0.0, 0.0], + output_versions=[-1, -1], + ), + messages=[{"role": "user", "content": "hello"}], + ) + + task = ChatRewardTask.create_from_trace_result("id-001", interaction) + + assert task.interaction is interaction + assert task.interaction_id == "id-001" + assert task.reward is None + + +# =========================================================================== +# End-to-end integration tests +# =========================================================================== + + +class TestEndToEnd: + """Integration tests that exercise the full scaffolding rollout pipeline.""" + + def test_pipeline_e2e_tensor_dict_compatible(self): + """PipelineTrajectoryMaker interactions should be convertible to tensor dicts.""" + gen_ctrl = FakeGenerationController( + output_str="42", + output_tokens=[201, 202], + ) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + + maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + ) + + task = GenerationTask(input_str=FAKE_PROMPT_STR, input_tokens=FAKE_INPUT_TOKENS) + list(maker.process([task])) + interaction = list(task.customized_result_fields["interactions"].values())[0] + + td = interaction.to_tensor_dict() + assert "input_ids" in td + assert "loss_mask" in td + assert "logprobs" in td + assert "rewards" in td + assert td["rewards"].item() == 1.0 + + def test_trace_e2e_multi_turn_with_rewards(self): + """Full multi-turn TraceTrajectoryMaker E2E with per-turn rewards.""" + rollout_ctrl = FakeChatRolloutController( + n_turns=3, + responses=["Let me think...", "Calculating...", "The answer is 42"], + ) + reward_ctrl = FakeChatRewardController(rewards=[0.0, 0.5, 1.0]) + + maker = TraceTrajectoryMaker( + rollout_controller=rollout_ctrl, + reward_controller=reward_ctrl, + ) + chat_tracer = maker.task_collections["chat_tracer"] + rollout_ctrl._tracer = chat_tracer + + task = TraceGenerationTask.create_from_prompt("Solve step by step: 6*7") + list(maker.process([task])) + + assert task.trace_results is not None + assert len(task.trace_results) == 3 + rewards = [i.reward for i in task.trace_results.values()] + assert rewards == [0.0, 0.5, 1.0] + + def test_trace_e2e_single_turn_generates_output(self): + """Single-turn trace should produce a valid interaction with reward.""" + rollout_ctrl = FakeChatRolloutController(n_turns=1, responses=["42"]) + reward_ctrl = FakeChatRewardController(default_reward=1.0) + + maker = TraceTrajectoryMaker( + rollout_controller=rollout_ctrl, + reward_controller=reward_ctrl, + ) + chat_tracer = maker.task_collections["chat_tracer"] + rollout_ctrl._tracer = chat_tracer + + task = TraceGenerationTask.create_from_prompt("What is 6*7?") + list(maker.process([task])) + + assert task.trace_results is not None + assert len(task.trace_results) == 1 + interaction = list(task.trace_results.values())[0] + assert interaction.reward == 1.0 + assert interaction.model_response is not None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/scaffolding/tests/test_scaffolding_llm_integration.py 
b/examples/scaffolding/tests/test_scaffolding_llm_integration.py new file mode 100644 index 0000000000..5210e94765 --- /dev/null +++ b/examples/scaffolding/tests/test_scaffolding_llm_integration.py @@ -0,0 +1,1060 @@ +"""Integration tests for ScaffoldingLlm with full controller pipelines. + +Tests verify that ScaffoldingLlm can correctly drive: +1. PipelineTrajectoryMaker (single-turn generation + reward) +2. TraceTrajectoryMaker + MultiTurnChatController (multi-turn chat + tracing + reward) + +These tests use a FakeChatWorker that handles ChatTask and GenerationTask +via the Worker.task_handlers dispatch mechanism, simulating an LLM backend. +""" + +from __future__ import annotations + +import asyncio +from unittest.mock import MagicMock + +import pytest + +from examples.scaffolding._compat import ( + AssistantMessage, + ChatTask, + GenerationTask, + NativeGenerationController, + ScaffoldingLlm, + TaskStatus, + Worker, +) +from examples.scaffolding.controllers import ( + MultiTurnChatController, + PipelineTrajectoryMaker, + RLVRRewardController, + TraceTrajectoryMaker, +) + +# --------------------------------------------------------------------------- +# Test constants +# --------------------------------------------------------------------------- + +FAKE_INPUT_TOKENS = [101, 102, 103] +FAKE_OUTPUT_TOKENS = [201, 202, 203, 204] +FAKE_OUTPUT_STR = "42" +FAKE_PROMPT_STR = "What is the answer to life?" + + +def _simple_reward_fn( + prompt: str, + completions: str, + prompt_ids: list[int], + completion_ids: list[int], + **kwargs, +) -> float: + """Deterministic reward: 1.0 if completion contains the answer, else 0.0.""" + answer = kwargs.get("answer", "") + return 1.0 if answer and answer in completions else 0.0 + + +# --------------------------------------------------------------------------- +# Fake Worker +# --------------------------------------------------------------------------- + + +def _make_fake_completion(completion_id: str = "cmpl-001") -> MagicMock: + """Create a minimal fake ChatCompletion object.""" + completion = MagicMock() + completion.id = completion_id + completion.created = 1000 + completion.choices = [MagicMock()] + completion.choices[0].message.content = FAKE_OUTPUT_STR + completion.choices[0].finish_reason = "stop" + return completion + + +_chat_handler_call_count = 0 + + +class FakeChatWorker(Worker): + """Worker that handles ChatTask and GenerationTask without any backend. + + For ChatTask: appends an AssistantMessage and sets completion/tokens. + For GenerationTask: fills output_str and output_tokens. 
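+
+ Dispatch goes through the class-level ``task_handlers`` mapping defined
+ below, mirroring how ``Worker.run_task`` routes each task to a handler
+ by task type.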
+ """ + + def __init__( + self, + response_text: str = FAKE_OUTPUT_STR, + output_tokens: list[int] | None = None, + ): + self.response_text = response_text + self.output_tokens = output_tokens or FAKE_OUTPUT_TOKENS + + async def _handle_generation_task(self, task: GenerationTask) -> TaskStatus: + task.output_str = self.response_text + task.output_tokens = list(self.output_tokens) + if task.input_tokens is None: + task.input_tokens = FAKE_INPUT_TOKENS + task.finish_reason = "stop" + return TaskStatus.SUCCESS + + async def _handle_chat_task(self, task: ChatTask) -> TaskStatus: + global _chat_handler_call_count + _chat_handler_call_count += 1 + + completion_id = f"cmpl-{_chat_handler_call_count:03d}" + task.completion = _make_fake_completion(completion_id) + task.completion.choices[0].message.content = self.response_text + task.output_tokens = list(self.output_tokens) + if task.input_tokens is None: + task.input_tokens = FAKE_INPUT_TOKENS + task.finish_reason = "stop" + + # Mimic what OpenaiWorker.chat_handler does: append assistant message + task.messages.append(AssistantMessage(content=self.response_text)) + return TaskStatus.SUCCESS + + task_handlers = { + GenerationTask: _handle_generation_task, + ChatTask: _handle_chat_task, + } + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_chat_handler_count(): + """Reset the global call counter before each test.""" + global _chat_handler_call_count + _chat_handler_call_count = 0 + yield + + +@pytest.fixture(autouse=True) +def _reset_shared_tracer(): + """Reset the class-level ChatTracer before each test.""" + yield + tracer = TraceTrajectoryMaker.task_collections.get("chat_tracer") + if tracer is not None: + tracer.clear() + + +# =========================================================================== +# PipelineTrajectoryMaker + ScaffoldingLlm +# =========================================================================== + + +class TestPipelineViaScaffoldingLlm: + """Test PipelineTrajectoryMaker running through ScaffoldingLlm.""" + + def test_single_generation_sync(self): + """ScaffoldingLlm.generate() should produce a result with interactions.""" + worker = FakeChatWorker(response_text="The answer is 42.") + reward_ctrl = RLVRRewardController(_simple_reward_fn) + gen_ctrl = NativeGenerationController( + sampling_params={"max_tokens": 100, "temperature": 1.0} + ) + trajectory_maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + input_tokens=FAKE_INPUT_TOKENS, + ) + + llm = ScaffoldingLlm( + trajectory_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + result = llm.generate(FAKE_PROMPT_STR) + + assert result is not None + assert result._done is True + output = result.outputs[0] + assert output.data is not None + + interactions = output.data + assert len(interactions) == 1 + interaction = list(interactions.values())[0] + assert interaction.reward == 1.0 + assert interaction.model_response is not None + assert interaction.model_response.input_tokens == FAKE_INPUT_TOKENS + finally: + llm.shutdown() + + def test_single_generation_async(self): + """ScaffoldingLlm.generate_async() + await should work.""" + worker = FakeChatWorker(response_text="42") + reward_ctrl = RLVRRewardController(_simple_reward_fn) + gen_ctrl = 
NativeGenerationController(sampling_params={"max_tokens": 100}) + trajectory_maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + input_tokens=FAKE_INPUT_TOKENS, + ) + + llm = ScaffoldingLlm( + trajectory_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + result = llm.generate_async(FAKE_PROMPT_STR) + # Use result.result() which blocks via the event loop + result.result(timeout=10.0) + + assert result._done is True + output = result.outputs[0] + interactions = output.data + assert len(interactions) == 1 + interaction = list(interactions.values())[0] + assert interaction.reward == 1.0 + finally: + llm.shutdown() + + def test_batch_generation(self): + """ScaffoldingLlm.generate() with a list of prompts should work.""" + worker = FakeChatWorker(response_text="42") + reward_ctrl = RLVRRewardController(_simple_reward_fn) + gen_ctrl = NativeGenerationController(sampling_params={"max_tokens": 50}) + trajectory_maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + input_tokens=FAKE_INPUT_TOKENS, + ) + + llm = ScaffoldingLlm( + trajectory_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + results = llm.generate(["prompt1", "prompt2", "prompt3"]) + assert len(results) == 3 + for result in results: + assert result._done is True + assert result.outputs[0].data is not None + finally: + llm.shutdown() + + +# =========================================================================== +# MultiTurnChatController + TraceTrajectoryMaker + ScaffoldingLlm +# =========================================================================== + + +class TestMultiTurnViaScaffoldingLlm: + """Test MultiTurnChatController + TraceTrajectoryMaker via ScaffoldingLlm. + + This reproduces the exact architecture used by ChatScaffoldingWorkflow. 
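+
+ Pipeline shape: MultiTurnChatController drives the turn loop,
+ TraceTrajectoryMaker records each turn as an interaction through the
+ shared "chat_tracer" collection, and RLVRRewardController scores the
+ traced interactions.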
+ """ + + @staticmethod + def _build_llm( + response_text: str = FAKE_OUTPUT_STR, + max_turns: int = 2, + reward_fn=None, + ) -> tuple[ScaffoldingLlm, MultiTurnChatController, TraceTrajectoryMaker]: + """Build the full ChatScaffoldingWorkflow pipeline.""" + worker = FakeChatWorker(response_text=response_text) + gen_ctrl = NativeGenerationController( + sampling_params={"max_tokens": 100, "temperature": 1.0} + ) + reward_ctrl = RLVRRewardController(reward_fn or _simple_reward_fn) + multi_turn_ctrl = MultiTurnChatController( + generation_controller=gen_ctrl, + max_turns=max_turns, + reflection_message="Try again.", + ) + trace_maker = TraceTrajectoryMaker( + rollout_controller=multi_turn_ctrl, + reward_controller=reward_ctrl, + ) + llm = ScaffoldingLlm( + trace_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + return llm, multi_turn_ctrl, trace_maker + + def test_single_turn_sync(self): + """Single-turn chat via ScaffoldingLlm.generate() should work.""" + llm, multi_turn_ctrl, _ = self._build_llm(response_text="42", max_turns=1) + try: + # Set per-episode data (like ChatScaffoldingWorkflow.arun_episode does) + multi_turn_ctrl.messages = [{"role": "user", "content": "What is 6*7?"}] + multi_turn_ctrl.input_tokens = FAKE_INPUT_TOKENS + + result = llm.generate(FAKE_PROMPT_STR) + + assert result is not None + assert result._done is True + output = result.outputs[0] + # TraceTrajectoryMaker.generate() returns ScaffoldingOutput with + # trace_results in data + assert output.data is not None or output.text is not None + finally: + llm.shutdown() + + def test_multi_turn_sync(self): + """Multi-turn (2-turn) chat via ScaffoldingLlm.generate() should work.""" + llm, multi_turn_ctrl, _ = self._build_llm(response_text="42", max_turns=2) + try: + multi_turn_ctrl.messages = [{"role": "user", "content": "Solve: 6*7"}] + multi_turn_ctrl.input_tokens = FAKE_INPUT_TOKENS + + result = llm.generate(FAKE_PROMPT_STR) + + assert result is not None + assert result._done is True + + # Verify the worker was called twice (2 turns) + global _chat_handler_call_count + assert _chat_handler_call_count >= 2 + finally: + llm.shutdown() + + def test_multi_turn_async_result(self): + """generate_async() + result() should complete for multi-turn.""" + llm, multi_turn_ctrl, _ = self._build_llm(response_text="42", max_turns=2) + try: + multi_turn_ctrl.messages = [{"role": "user", "content": "Solve: 6*7"}] + multi_turn_ctrl.input_tokens = FAKE_INPUT_TOKENS + + result = llm.generate_async(FAKE_PROMPT_STR) + result.result(timeout=10.0) + + assert result._done is True + finally: + llm.shutdown() + + def test_trace_results_available(self): + """Trace results should be accessible via ScaffoldingOutput.data.""" + llm, multi_turn_ctrl, _ = self._build_llm(response_text="42", max_turns=1) + try: + multi_turn_ctrl.messages = [{"role": "user", "content": "What is 6*7?"}] + multi_turn_ctrl.input_tokens = FAKE_INPUT_TOKENS + + result = llm.generate(FAKE_PROMPT_STR) + output = result.outputs[0] + + # TraceTrajectoryMaker stores trace_results in + # TraceGenerationTask.trace_results, which is returned via + # create_scaffolding_output().data + trace_data = output.data + if trace_data is not None: + # If tracing worked, we should have at least one interaction + assert len(trace_data) >= 1 + finally: + llm.shutdown() + + def test_multiple_concurrent_requests(self): + """Multiple concurrent requests should all complete.""" + llm, multi_turn_ctrl, _ = self._build_llm(response_text="42", max_turns=1) + try: + 
multi_turn_ctrl.messages = [{"role": "user", "content": "What is 6*7?"}] + multi_turn_ctrl.input_tokens = FAKE_INPUT_TOKENS + + results = llm.generate([f"prompt_{i}" for i in range(5)]) + + assert len(results) == 5 + for result in results: + assert result._done is True + finally: + llm.shutdown() + + def test_clone_isolation(self): + """Each request should get an independent controller clone.""" + call_messages = [] + + def _capture_reward(prompt, completions, prompt_ids, completion_ids, **kw): + call_messages.append(completions) + return 1.0 + + llm, multi_turn_ctrl, _ = self._build_llm( + response_text="42", max_turns=1, reward_fn=_capture_reward + ) + try: + multi_turn_ctrl.messages = [{"role": "user", "content": "Q1"}] + multi_turn_ctrl.input_tokens = FAKE_INPUT_TOKENS + + # Two sequential requests + r1 = llm.generate("prompt1") + r2 = llm.generate("prompt2") + + assert r1._done is True + assert r2._done is True + finally: + llm.shutdown() + + +# =========================================================================== +# Edge cases +# =========================================================================== + + +class TestEdgeCases: + """Edge cases and error scenarios.""" + + def test_empty_messages(self): + """Controller with empty messages should not crash.""" + worker = FakeChatWorker(response_text="hello") + gen_ctrl = NativeGenerationController(sampling_params={"max_tokens": 10}) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + multi_turn_ctrl = MultiTurnChatController( + generation_controller=gen_ctrl, + max_turns=1, + reflection_message="retry", + messages=[], + input_tokens=[], + ) + trace_maker = TraceTrajectoryMaker( + rollout_controller=multi_turn_ctrl, + reward_controller=reward_ctrl, + ) + llm = ScaffoldingLlm( + trace_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + result = llm.generate("test prompt") + assert result._done is True + finally: + llm.shutdown() + + def test_shutdown_is_safe(self): + """Calling shutdown() should not raise.""" + worker = FakeChatWorker() + gen_ctrl = NativeGenerationController(sampling_params={"max_tokens": 10}) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + trajectory_maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + ) + llm = ScaffoldingLlm( + trajectory_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + llm.shutdown() + # No exception = success + + +# =========================================================================== +# Async context tests (reproducing the training framework scenario) +# =========================================================================== + + +class TestAsyncContextScaffoldingLlm: + """Tests that run ScaffoldingLlm from within an existing asyncio event loop. + + This reproduces the actual deployment scenario: the rollout framework's + arun_episode is called from within an asyncio context that already has a + running event loop. ScaffoldingLlm._get_loop() will detect the running + loop (own_loop=False) and schedule its main loop on it. 
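+
+ A minimal sketch of the call pattern exercised below (same names as in
+ the tests):
+
+ result = llm.generate_async(prompt)
+ await asyncio.wait_for(result, timeout=10.0)
+ assert result._done is True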
+ """ + + @pytest.mark.asyncio + async def test_pipeline_in_async_context(self): + """PipelineTrajectoryMaker should work when called from async context.""" + worker = FakeChatWorker(response_text="The answer is 42.") + gen_ctrl = NativeGenerationController(sampling_params={"max_tokens": 100}) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + trajectory_maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + input_tokens=FAKE_INPUT_TOKENS, + ) + + llm = ScaffoldingLlm( + trajectory_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + result = llm.generate_async(FAKE_PROMPT_STR) + await asyncio.wait_for(result, timeout=10.0) + + assert result._done is True + output = result.outputs[0] + assert output.data is not None + interactions = output.data + assert len(interactions) == 1 + interaction = list(interactions.values())[0] + assert interaction.reward == 1.0 + finally: + llm.shutdown() + + @pytest.mark.asyncio + async def test_multi_turn_in_async_context(self): + """MultiTurnChatController + TraceTrajectoryMaker in async context.""" + worker = FakeChatWorker(response_text="42") + gen_ctrl = NativeGenerationController(sampling_params={"max_tokens": 100}) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + multi_turn_ctrl = MultiTurnChatController( + generation_controller=gen_ctrl, + max_turns=2, + reflection_message="Try again.", + messages=[{"role": "user", "content": "What is 6*7?"}], + input_tokens=FAKE_INPUT_TOKENS, + ) + trace_maker = TraceTrajectoryMaker( + rollout_controller=multi_turn_ctrl, + reward_controller=reward_ctrl, + ) + + llm = ScaffoldingLlm( + trace_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + result = llm.generate_async(FAKE_PROMPT_STR) + await asyncio.wait_for(result, timeout=10.0) + + assert result._done is True + finally: + llm.shutdown() + + @pytest.mark.asyncio + async def test_multiple_async_requests(self): + """Multiple concurrent async requests from within async context.""" + worker = FakeChatWorker(response_text="42") + gen_ctrl = NativeGenerationController(sampling_params={"max_tokens": 100}) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + trajectory_maker = PipelineTrajectoryMaker( + generation_controller=gen_ctrl, + reward_controller=reward_ctrl, + task_data={"answer": "42"}, + prompt_str=FAKE_PROMPT_STR, + input_tokens=FAKE_INPUT_TOKENS, + ) + + llm = ScaffoldingLlm( + trajectory_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + # Launch multiple requests concurrently (like the rollout framework) + results = [] + for i in range(5): + results.append(llm.generate_async(f"prompt_{i}")) + + # Await all concurrently + await asyncio.wait_for( + asyncio.gather(*[r.aresult() for r in results]), + timeout=10.0, + ) + + for result in results: + assert result._done is True + assert result.outputs[0].data is not None + finally: + llm.shutdown() + + +# =========================================================================== +# Uvloop thread tests (reproducing the exact AsyncTaskRunner scenario) +# =========================================================================== + + +class TestUvloopThreadScaffoldingLlm: + """Tests that reproduce the exact production deployment architecture. + + AsyncTaskRunner runs a uvloop in a separate thread. Multiple coroutines + (one per arun_episode) run concurrently on this loop. 
They share one + ScaffoldingLlm instance. The ScaffoldingLlm is lazily initialized inside + the first coroutine (so it captures the uvloop as its event loop). + + This is the exact scenario where the deadlock was observed. + """ + + @staticmethod + def _build_shared_components( + response_text: str = FAKE_OUTPUT_STR, + max_turns: int = 2, + ): + """Build shared components (simulating ChatScaffoldingWorkflow.__init__).""" + worker = FakeChatWorker(response_text=response_text) + gen_ctrl = NativeGenerationController( + sampling_params={"max_tokens": 100, "temperature": 1.0} + ) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + multi_turn_ctrl = MultiTurnChatController( + generation_controller=gen_ctrl, + max_turns=max_turns, + reflection_message="Try again.", + ) + trace_maker = TraceTrajectoryMaker( + rollout_controller=multi_turn_ctrl, + reward_controller=reward_ctrl, + ) + return worker, multi_turn_ctrl, trace_maker + + def test_uvloop_lazy_init_single_request(self): + """ScaffoldingLlm lazily initialized on uvloop should handle one request.""" + import threading + + import uvloop + + worker, multi_turn_ctrl, trace_maker = self._build_shared_components( + max_turns=1 + ) + multi_turn_ctrl.messages = [{"role": "user", "content": "What is 6*7?"}] + multi_turn_ctrl.input_tokens = FAKE_INPUT_TOKENS + + result_holder = {} + error_holder = {} + + async def run_on_uvloop(): + try: + # Lazy init ScaffoldingLlm inside the uvloop (like _lazy_init_scaffolding) + llm = ScaffoldingLlm( + trace_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + result = llm.generate_async(FAKE_PROMPT_STR) + await asyncio.wait_for(result, timeout=10.0) + result_holder["result"] = result + finally: + llm.shutdown() + except Exception as e: + error_holder["error"] = e + + def thread_fn(): + loop = uvloop.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run_on_uvloop()) + loop.close() + + t = threading.Thread(target=thread_fn) + t.start() + t.join(timeout=30) + assert not t.is_alive(), "Thread deadlocked!" + + if "error" in error_holder: + raise error_holder["error"] + assert "result" in result_holder + assert result_holder["result"]._done is True + + def test_uvloop_concurrent_coroutines_shared_llm(self): + """Multiple concurrent coroutines on uvloop sharing one ScaffoldingLlm. + + This is the exact scenario from the training pipeline: + - AsyncTaskRunner creates a uvloop in a background thread + - Multiple _execute_workflow coroutines run concurrently + - They all share the same ChatScaffoldingWorkflow instance + - The ScaffoldingLlm is lazily initialized on first call + """ + import threading + + import uvloop + + worker, multi_turn_ctrl, trace_maker = self._build_shared_components( + max_turns=2 + ) + + num_concurrent = 10 + results = {} + errors = {} + + async def run_concurrent(): + # Lazy init ScaffoldingLlm inside uvloop (like _lazy_init_scaffolding) + llm = ScaffoldingLlm( + trace_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + + async def simulate_arun_episode(idx: int): + """Simulate what arun_episode does.""" + try: + # Each coroutine sets per-episode data and calls generate_async + # Note: in real code, clone() inside ScaffoldingLlm deep-copies + # the prototype controller, so the race on shared state is safe + # as long as the prototype is set before generate_async. 
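+ # (The residual hazard is another coroutine overwriting the prototype
+ # between this assignment and the generate_async call; the assertions
+ # below only check completion, so that interleaving cannot fail them.)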
+ multi_turn_ctrl.messages = [ + {"role": "user", "content": f"Question {idx}"} + ] + multi_turn_ctrl.input_tokens = FAKE_INPUT_TOKENS + + result = llm.generate_async(f"prompt_{idx}") + await asyncio.wait_for(result, timeout=10.0) + results[idx] = result + except Exception as e: + errors[idx] = e + + # Run all concurrently (like AsyncTaskRunner does) + await asyncio.gather( + *[simulate_arun_episode(i) for i in range(num_concurrent)] + ) + finally: + llm.shutdown() + + def thread_fn(): + loop = uvloop.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run_concurrent()) + loop.close() + + t = threading.Thread(target=thread_fn) + t.start() + t.join(timeout=60) + assert not t.is_alive(), "Thread deadlocked with concurrent coroutines!" + + if errors: + first_error = next(iter(errors.values())) + raise first_error + + assert len(results) == num_concurrent + for idx, result in results.items(): + assert result._done is True, f"Result {idx} not done" + + def test_uvloop_high_concurrency(self): + """Stress test: 50 concurrent coroutines on uvloop (closer to 256 in prod).""" + import threading + + import uvloop + + worker, multi_turn_ctrl, trace_maker = self._build_shared_components( + max_turns=1 + ) + multi_turn_ctrl.messages = [{"role": "user", "content": "test"}] + multi_turn_ctrl.input_tokens = FAKE_INPUT_TOKENS + + num_concurrent = 50 + done_count = {"value": 0} + errors = [] + + async def run_stress(): + llm = ScaffoldingLlm( + trace_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + try: + + async def single_episode(idx: int): + try: + result = llm.generate_async(f"prompt_{idx}") + await asyncio.wait_for(result, timeout=15.0) + assert result._done is True + done_count["value"] += 1 + except Exception as e: + errors.append((idx, e)) + + await asyncio.gather( + *[single_episode(i) for i in range(num_concurrent)] + ) + finally: + llm.shutdown() + + def thread_fn(): + loop = uvloop.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run_stress()) + loop.close() + + t = threading.Thread(target=thread_fn) + t.start() + t.join(timeout=60) + assert not t.is_alive(), "Thread deadlocked under high concurrency!" + + if errors: + raise errors[0][1] + assert done_count["value"] == num_concurrent + + +# =========================================================================== +# AsyncTaskRunner integration test (exact production scenario) +# =========================================================================== + + +class TestAsyncTaskRunnerScaffoldingLlm: + """Tests using the real AsyncTaskRunner to reproduce the production architecture. 
+ + This is the closest reproduction of the actual training pipeline: + - AsyncTaskRunner runs a uvloop in a background thread + - A shared workflow object is used across all tasks + - ScaffoldingLlm is lazily initialized inside the first coroutine + - Multiple concurrent coroutines share one ScaffoldingLlm instance + """ + + def test_async_task_runner_with_scaffolding_llm(self): + """Full integration test using AsyncTaskRunner + shared ScaffoldingLlm.""" + from areal.infra.async_task_runner import AsyncTaskRunner + + worker = FakeChatWorker(response_text="42") + gen_ctrl = NativeGenerationController( + sampling_params={"max_tokens": 100, "temperature": 1.0} + ) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + multi_turn_ctrl = MultiTurnChatController( + generation_controller=gen_ctrl, + max_turns=2, + reflection_message="Try again.", + ) + trace_maker = TraceTrajectoryMaker( + rollout_controller=multi_turn_ctrl, + reward_controller=reward_ctrl, + ) + + # Shared mutable state, like in the real workflow + shared_state = { + "scaffolding_llm": None, + "worker": worker, + "multi_turn_ctrl": multi_turn_ctrl, + "trace_maker": trace_maker, + } + + async def simulate_arun_episode(episode_idx: int) -> dict: + """Simulate ChatScaffoldingWorkflow.arun_episode.""" + # Lazy init (like _lazy_init_scaffolding) + if shared_state["scaffolding_llm"] is None: + shared_state["scaffolding_llm"] = ScaffoldingLlm( + shared_state["trace_maker"], + { + NativeGenerationController.WorkerTag.GENERATION: shared_state[ + "worker" + ] + }, + ) + + llm = shared_state["scaffolding_llm"] + + # Set per-episode data (race condition in production!) + shared_state["multi_turn_ctrl"].messages = [ + {"role": "user", "content": f"Question {episode_idx}"} + ] + shared_state["multi_turn_ctrl"].input_tokens = FAKE_INPUT_TOKENS + + # Generate + result = llm.generate_async(f"prompt_{episode_idx}") + await result + + return {"episode": episode_idx, "done": result._done} + + # Use AsyncTaskRunner like the real WorkflowExecutor + runner = AsyncTaskRunner(max_queue_size=64) + runner.initialize() + + num_tasks = 10 + try: + for i in range(num_tasks): + runner.submit(simulate_arun_episode, i, task_id=i) + + results = runner.wait(count=num_tasks, timeout=30.0) + + assert len(results) == num_tasks + for result in results: + assert result is not None + assert result["done"] is True + finally: + # Shutdown ScaffoldingLlm if initialized + if shared_state["scaffolding_llm"] is not None: + shared_state["scaffolding_llm"].shutdown() + runner.destroy() + + def test_async_task_runner_high_concurrency(self): + """Stress test: 50 concurrent tasks via AsyncTaskRunner.""" + from areal.infra.async_task_runner import AsyncTaskRunner + + worker = FakeChatWorker(response_text="42") + gen_ctrl = NativeGenerationController(sampling_params={"max_tokens": 100}) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + multi_turn_ctrl = MultiTurnChatController( + generation_controller=gen_ctrl, + max_turns=1, + reflection_message="Try again.", + ) + trace_maker = TraceTrajectoryMaker( + rollout_controller=multi_turn_ctrl, + reward_controller=reward_ctrl, + ) + + shared_state = { + "scaffolding_llm": None, + "worker": worker, + "multi_turn_ctrl": multi_turn_ctrl, + "trace_maker": trace_maker, + } + + async def simulate_arun_episode(episode_idx: int) -> dict: + if shared_state["scaffolding_llm"] is None: + shared_state["scaffolding_llm"] = ScaffoldingLlm( + shared_state["trace_maker"], + { + NativeGenerationController.WorkerTag.GENERATION: shared_state[ + 
"worker" + ] + }, + ) + + llm = shared_state["scaffolding_llm"] + + shared_state["multi_turn_ctrl"].messages = [ + {"role": "user", "content": f"Q{episode_idx}"} + ] + shared_state["multi_turn_ctrl"].input_tokens = FAKE_INPUT_TOKENS + + result = llm.generate_async(f"prompt_{episode_idx}") + await result + + return {"episode": episode_idx, "done": result._done} + + runner = AsyncTaskRunner(max_queue_size=128) + runner.initialize() + + num_tasks = 50 + try: + for i in range(num_tasks): + runner.submit(simulate_arun_episode, i, task_id=i) + + results = runner.wait(count=num_tasks, timeout=60.0) + + assert len(results) == num_tasks + for result in results: + assert result is not None + assert result["done"] is True + finally: + if shared_state["scaffolding_llm"] is not None: + shared_state["scaffolding_llm"].shutdown() + runner.destroy() + + +# =========================================================================== +# Per-task ScaffoldingLlm instances (RolloutController path) +# =========================================================================== + + +class TestPerTaskScaffoldingLlm: + """Tests where EACH task creates its own ScaffoldingLlm instance. + + In the RolloutController path, RemoteInfEngine._resolve_workflow() creates + a NEW ChatScaffoldingWorkflow per submit() call. So 256 tasks create 256 + workflow instances, each with its own ScaffoldingLlm. All 256 ScaffoldingLlm + instances share the same uvloop (own_loop=False) and each schedules its own + _main_loop_async_func on it. + + This is the ACTUAL production architecture with scheduler.type=local. + """ + + def test_multiple_llm_instances_on_async_task_runner(self): + """Each task creates its own ScaffoldingLlm — deadlock reproduction.""" + from areal.infra.async_task_runner import AsyncTaskRunner + + worker = FakeChatWorker(response_text="42") + llm_instances = [] + + async def simulate_arun_episode_per_instance(episode_idx: int) -> dict: + """Each task creates its OWN ScaffoldingLlm (like _resolve_workflow).""" + # Create fresh controller hierarchy (like workflow.__init__ + build_scaffolding_llm) + gen_ctrl = NativeGenerationController( + sampling_params={"max_tokens": 100, "temperature": 1.0} + ) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + multi_turn_ctrl = MultiTurnChatController( + generation_controller=gen_ctrl, + max_turns=1, + reflection_message="Try again.", + messages=[{"role": "user", "content": f"Q{episode_idx}"}], + input_tokens=FAKE_INPUT_TOKENS, + ) + trace_maker = TraceTrajectoryMaker( + rollout_controller=multi_turn_ctrl, + reward_controller=reward_ctrl, + ) + + # Each task creates its OWN ScaffoldingLlm + llm = ScaffoldingLlm( + trace_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + llm_instances.append(llm) + + result = llm.generate_async(f"prompt_{episode_idx}") + await result + + return {"episode": episode_idx, "done": result._done} + + runner = AsyncTaskRunner(max_queue_size=64) + runner.initialize() + + num_tasks = 10 + try: + for i in range(num_tasks): + runner.submit(simulate_arun_episode_per_instance, i, task_id=i) + + results = runner.wait(count=num_tasks, timeout=30.0) + + assert len(results) == num_tasks + for result in results: + assert result is not None + assert result["done"] is True + finally: + for llm in llm_instances: + try: + llm.shutdown() + except Exception: + pass + runner.destroy() + + def test_many_llm_instances_on_async_task_runner(self): + """50 per-task ScaffoldingLlm instances — stress test.""" + from areal.infra.async_task_runner import 
AsyncTaskRunner + + worker = FakeChatWorker(response_text="42") + llm_instances = [] + + async def simulate_arun_episode_per_instance(episode_idx: int) -> dict: + gen_ctrl = NativeGenerationController(sampling_params={"max_tokens": 50}) + reward_ctrl = RLVRRewardController(_simple_reward_fn) + multi_turn_ctrl = MultiTurnChatController( + generation_controller=gen_ctrl, + max_turns=1, + reflection_message="retry", + messages=[{"role": "user", "content": f"Q{episode_idx}"}], + input_tokens=FAKE_INPUT_TOKENS, + ) + trace_maker = TraceTrajectoryMaker( + rollout_controller=multi_turn_ctrl, + reward_controller=reward_ctrl, + ) + + llm = ScaffoldingLlm( + trace_maker, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + llm_instances.append(llm) + + result = llm.generate_async(f"prompt_{episode_idx}") + await result + + return {"episode": episode_idx, "done": result._done} + + runner = AsyncTaskRunner(max_queue_size=128) + runner.initialize() + + num_tasks = 50 + try: + for i in range(num_tasks): + runner.submit(simulate_arun_episode_per_instance, i, task_id=i) + + results = runner.wait(count=num_tasks, timeout=60.0) + + assert len(results) == num_tasks + for result in results: + assert result is not None + assert result["done"] is True + finally: + for llm in llm_instances: + try: + llm.shutdown() + except Exception: + pass + runner.destroy() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/scaffolding/tests/test_self_contained.py b/examples/scaffolding/tests/test_self_contained.py new file mode 100644 index 0000000000..8d40875a1e --- /dev/null +++ b/examples/scaffolding/tests/test_self_contained.py @@ -0,0 +1,795 @@ +"""Tests proving the vendored scaffolding modules are self-contained. + +These tests verify that: +1. All scaffolding core modules import without tensorrt_llm installed. +2. Every public symbol in core/__init__.py and _compat.py is importable. +3. Core primitives (Task, Controller, Worker, ScaffoldingLlm, TaskCollection, + math_utils) function correctly in isolation. +4. AReaL wrapper modules (controllers, task, worker, workflow) import cleanly. +5. No source file under examples/scaffolding/ contains a live + ``import tensorrt_llm`` or ``from tensorrt_llm`` statement. +""" + +from __future__ import annotations + +import ast +import asyncio +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest + +# ============================================================================ +# 1. Import isolation — tensorrt_llm must NOT be importable +# ============================================================================ + + +class TestNoTensorRTLLMDependency: + """Verify that tensorrt_llm is not required at runtime.""" + + def test_tensorrt_llm_not_installed(self): + """tensorrt_llm should not be importable in the test environment.""" + assert "tensorrt_llm" not in sys.modules or sys.modules["tensorrt_llm"] is None + + def test_no_live_tensorrt_llm_imports_in_source(self): + """No .py file under scaffolding/ should have a live import of tensorrt_llm. + + Comments and docstrings are allowed; only top-level or function-level + ``import tensorrt_llm`` / ``from tensorrt_llm import ...`` are flagged. 
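+
+ The scan parses each file with ``ast``, so a tensorrt_llm mention inside
+ a string or comment never produces an Import/ImportFrom node and is
+ skipped without any special-casing.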
+ """ + scaffolding_root = Path(__file__).resolve().parents[1] + violations = [] + for py_file in scaffolding_root.rglob("*.py"): + try: + tree = ast.parse(py_file.read_text(), filename=str(py_file)) + except SyntaxError: + continue + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name == "tensorrt_llm" or alias.name.startswith( + "tensorrt_llm." + ): + violations.append( + f"{py_file.relative_to(scaffolding_root)}:{node.lineno}" + ) + elif isinstance(node, ast.ImportFrom): + if node.module and ( + node.module == "tensorrt_llm" + or node.module.startswith("tensorrt_llm.") + ): + violations.append( + f"{py_file.relative_to(scaffolding_root)}:{node.lineno}" + ) + + assert violations == [], "Found live tensorrt_llm imports in:\n" + "\n".join( + violations + ) + + +# ============================================================================ +# 2. Core module imports — every public symbol is importable +# ============================================================================ + + +class TestCoreImports: + """Verify every public symbol in core/ is importable.""" + + def test_core_init_imports(self): + """All symbols in core/__init__.py should import successfully.""" + from examples.scaffolding.core import ( # noqa: F401 + AssistantMessage, + BestOfNController, + ChatTask, + Controller, + GenerationTask, + MajorityVoteController, + NativeChatController, + NativeGenerationController, + NativeRewardController, + OpenAIToolDescription, + OpenaiWorker, + ParallelProcess, + RoleMessage, + ScaffoldingLlm, + ScaffoldingOutput, + StreamGenerationTask, + SystemMessage, + Task, + TaskCollection, + TaskStatus, + UserMessage, + Worker, + extract_answer_from_boxed, + extract_answer_with_regex, + get_digit_majority_vote_result, + with_task_collection, + ) + + # Spot-check a few are real classes/functions + assert callable(Controller) + assert callable(extract_answer_from_boxed) + assert callable(with_task_collection) + + def test_compat_reexports_match_core(self): + """_compat.py should re-export the same symbols as core/__init__.py + (minus math_utils which is not in _compat).""" + from examples.scaffolding import _compat, core + + compat_all = set(_compat.__all__) + # _compat intentionally omits math_utils functions + math_utils_names = { + "extract_answer_from_boxed", + "extract_answer_with_regex", + "get_digit_majority_vote_result", + } + core_all_minus_math = set(core.__all__) - math_utils_names + assert compat_all == core_all_minus_math + + def test_core_submodule_imports(self): + """Each core submodule should import independently.""" + import examples.scaffolding.core.controller # noqa: F401 + import examples.scaffolding.core.math_utils # noqa: F401 + import examples.scaffolding.core.result # noqa: F401 + import examples.scaffolding.core.scaffolding_llm # noqa: F401 + import examples.scaffolding.core.task # noqa: F401 + import examples.scaffolding.core.task_collection # noqa: F401 + import examples.scaffolding.core.worker # noqa: F401 + + +# ============================================================================ +# 3. 
Core primitives — functional tests +# ============================================================================ + + +class TestTask: + """Tests for Task, GenerationTask, ChatTask, and related dataclasses.""" + + def test_task_creation(self): + from examples.scaffolding.core.task import Task + + t = Task() + assert t.worker_tag is None + assert t.streaming_output_flag is False + assert t.streaming_output_list == [] + + def test_generation_task_create_from_prompt(self): + from examples.scaffolding.core.task import GenerationTask + + t = GenerationTask.create_from_prompt("Hello world") + assert t.input_str == "Hello world" + assert t.skip_tokenizer is False + assert t.skip_detokenizer is False + + def test_generation_task_scaffolding_output(self): + from examples.scaffolding.core.task import GenerationTask + + t = GenerationTask(output_str="result", output_tokens=[10, 20]) + output = t.create_scaffolding_output() + assert output.text == "result" + assert output.token_ids == [10, 20] + + def test_stream_generation_task_from_generation(self): + from examples.scaffolding.core.task import ( + GenerationTask, + StreamGenerationTask, + ) + + gen = GenerationTask(input_str="prompt", max_tokens=100) + stream = StreamGenerationTask.create_from_generation_task(gen, streaming_step=5) + assert stream.input_str == "prompt" + assert stream.max_tokens == 100 + assert stream.streaming_step == 5 + + def test_chat_task_create_from_prompt(self): + from examples.scaffolding.core.task import ( + ChatTask, + SystemMessage, + ) + + t = ChatTask.create_from_prompt( + "What is 2+2?", + system_prompts=[SystemMessage("You are a math tutor")], + ) + assert len(t.messages) == 2 + assert t.messages[0].role == "system" + assert t.messages[1].role == "user" + + def test_chat_task_create_from_messages(self): + from examples.scaffolding.core.task import ( + ChatTask, + UserMessage, + ) + + msgs = [UserMessage("hi"), UserMessage("hello")] + t = ChatTask.create_from_messages(msgs) + assert len(t.messages) == 2 + + def test_chat_task_add_message(self): + from examples.scaffolding.core.task import ( + AssistantMessage, + ChatTask, + UserMessage, + ) + + t = ChatTask() + t.add_message(UserMessage("Q")) + t.add_message(AssistantMessage("A")) + assert len(t.messages) == 2 + assert t.messages[0].role == "user" + assert t.messages[1].role == "assistant" + + def test_chat_task_messages_to_dict(self): + from examples.scaffolding.core.task import ( + ChatTask, + UserMessage, + ) + + t = ChatTask() + t.add_message(UserMessage("hello")) + dicts = t.messages_to_dict_content() + assert dicts == [{"role": "user", "content": "hello"}] + + def test_role_message_str_repr(self): + from examples.scaffolding.core.task import UserMessage + + m = UserMessage("hi") + assert '"role": "user"' in str(m) + assert "user" in repr(m) + + def test_role_message_from_dict(self): + from examples.scaffolding.core.task import RoleMessage + + m = RoleMessage.from_dict({"role": "user", "content": "test"}) + assert m.role == "user" + assert m.content == "test" + + def test_assistant_message_with_reasoning(self): + from examples.scaffolding.core.task import AssistantMessage + + m = AssistantMessage("answer", reasoning="chain of thought") + assert m.role == "assistant" + assert m.reasoning == "chain of thought" + assert '"reasoning"' in str(m) + + def test_openai_tool_description(self): + from examples.scaffolding.core.task import OpenAIToolDescription + + tool = OpenAIToolDescription( + "my_func", "Does something", {"x": {"type": "int"}} + ) + d = tool.to_dict() + 
assert d["type"] == "function" + assert d["function"]["name"] == "my_func" + assert d["function"]["description"] == "Does something" + + def test_task_status_values(self): + from examples.scaffolding.core.task import TaskStatus + + assert TaskStatus.SUCCESS.value == "success" + assert TaskStatus.WORKER_NOT_SUPPORTED.value == "worker_not_supported" + assert TaskStatus.WORKER_EXECEPTION.value == "worker_exception" + + +class TestController: + """Tests for Controller and built-in controller subclasses.""" + + def test_controller_is_abstract(self): + from examples.scaffolding.core.controller import Controller + + ctrl = Controller() + assert hasattr(ctrl, "task_collections") + with pytest.raises(NotImplementedError): + list(ctrl.process([])) + + def test_controller_clone(self): + from examples.scaffolding.core.controller import Controller + + class MyCtrl(Controller): + def __init__(self, val): + super().__init__() + self.val = val + + def process(self, tasks, **kw): + yield tasks + + original = MyCtrl(42) + cloned = original.clone() + assert cloned.val == 42 + assert cloned is not original + + def test_native_generation_controller(self): + from examples.scaffolding.core.controller import ( + NativeGenerationController, + ) + from examples.scaffolding.core.task import GenerationTask + + ctrl = NativeGenerationController( + sampling_params={"temperature": 0.7, "max_tokens": 100} + ) + task = GenerationTask(input_str="test") + results = list(ctrl.process([task])) + + assert len(results) == 1 + assert task.worker_tag == NativeGenerationController.WorkerTag.GENERATION + assert task.temperature == 0.7 + assert task.max_tokens == 100 + + def test_native_generation_controller_ignores_invalid_params(self): + from examples.scaffolding.core.controller import ( + NativeGenerationController, + ) + + ctrl = NativeGenerationController( + sampling_params={"invalid_param_xyz": 999, "temperature": 0.5} + ) + assert "temperature" in ctrl.sampling_params + assert "invalid_param_xyz" not in ctrl.sampling_params + + def test_native_chat_controller(self): + from examples.scaffolding.core.controller import ( + NativeChatController, + ) + from examples.scaffolding.core.task import ( + ChatTask, + GenerationTask, + ) + + ctrl = NativeChatController() + task = GenerationTask(input_str="What is 2+2?") + results = list(ctrl.process([task])) + + assert len(results) == 1 + # NativeChatController wraps in ChatTask + yielded_tasks = results[0] + assert isinstance(yielded_tasks[0], ChatTask) + + def test_native_reward_controller(self): + from examples.scaffolding.core.controller import ( + NativeRewardController, + ) + from examples.scaffolding.core.task import GenerationTask + + ctrl = NativeRewardController() + task = GenerationTask(input_str="test") + results = list(ctrl.process([task])) + + assert len(results) == 1 + assert task.worker_tag == NativeRewardController.WorkerTag.REWARD + + def test_parallel_process_creation(self): + from examples.scaffolding.core.controller import ( + NativeGenerationController, + ParallelProcess, + ) + from examples.scaffolding.core.task import GenerationTask + + ctrl1 = NativeGenerationController() + ctrl2 = NativeGenerationController() + tasks1 = [GenerationTask(input_str="a")] + tasks2 = [GenerationTask(input_str="b")] + + pp = ParallelProcess( + controllers=[ctrl1, ctrl2], + tasks_list=[tasks1, tasks2], + kwargs_list=[{}, {}], + ) + assert len(pp.sub_gens) == 2 + + +class TestWorker: + """Tests for Worker base class.""" + + def test_worker_is_abstract(self): + from 
examples.scaffolding.core.worker import Worker + + w = Worker() + assert hasattr(w, "task_handlers") + + @pytest.mark.asyncio + async def test_worker_unsupported_task(self): + from examples.scaffolding.core.task import Task, TaskStatus + from examples.scaffolding.core.worker import Worker + + class EmptyWorker(Worker): + task_handlers = {} + + w = EmptyWorker() + status = await w.run_task(Task()) + assert status == TaskStatus.WORKER_NOT_SUPPORTED + + @pytest.mark.asyncio + async def test_worker_register_handler(self): + from examples.scaffolding.core.task import ( + GenerationTask, + TaskStatus, + ) + from examples.scaffolding.core.worker import Worker + + class MyWorker(Worker): + task_handlers = {} + + async def my_handler(self, task): + task.output_str = "handled" + return TaskStatus.SUCCESS + + w = MyWorker() + w.register_task_handler(GenerationTask, my_handler) + task = GenerationTask(input_str="test") + status = await w.run_task(task) + assert status == TaskStatus.SUCCESS + assert task.output_str == "handled" + + def test_worker_context_manager(self): + from examples.scaffolding.core.worker import Worker + + w = Worker() + result = w.__enter__() + assert result is w + + def test_is_deterministic_mode(self): + from examples.scaffolding.core.worker import is_deterministic_mode + + assert is_deterministic_mode() is False + + with patch.dict("os.environ", {"SCAFFOLDING_DETERMINISTIC": "1"}): + assert is_deterministic_mode() is True + + +class TestResult: + """Tests for ScaffoldingOutput and ScaffoldingResult.""" + + def test_scaffolding_output(self): + from examples.scaffolding.core.result import ScaffoldingOutput + + o = ScaffoldingOutput(text="hello", token_ids=[1, 2, 3]) + assert o.text == "hello" + assert o.token_ids == [1, 2, 3] + assert o.data is None + + def test_scaffolding_output_with_data(self): + from examples.scaffolding.core.result import ScaffoldingOutput + + payload = {"key": "value", "nested": [1, 2, 3]} + o = ScaffoldingOutput(text="hello", token_ids=[1, 2, 3], data=payload) + assert o.text == "hello" + assert o.token_ids == [1, 2, 3] + assert o.data is payload + + def test_scaffolding_result_set_output(self): + from examples.scaffolding.core.result import ( + ScaffoldingOutput, + ScaffoldingResult, + ) + + result = ScaffoldingResult() + assert not result._done + + output = ScaffoldingOutput("text", [1, 2]) + result.set_output(output) + + # After set_output, we should be able to get the result + # by draining the queue + loop = asyncio.new_event_loop() + try: + done = loop.run_until_complete(result.aresult()) + assert done._done + assert done.outputs[0].text == "text" + assert done.outputs[0].token_ids == [1, 2] + finally: + loop.close() + + def test_scaffolding_result_set_output_none(self): + from examples.scaffolding.core.result import ScaffoldingResult + + result = ScaffoldingResult() + result.set_output(None) + + loop = asyncio.new_event_loop() + try: + done = loop.run_until_complete(result.aresult()) + assert done._done + finally: + loop.close() + + def test_scaffolding_result_task_collections(self): + from examples.scaffolding.core.result import ScaffoldingResult + + result = ScaffoldingResult() + assert result.task_collections is None + result.set_task_collections({"key": "value"}) + assert result.task_collections == {"key": "value"} + + +class TestTaskCollection: + """Tests for TaskCollection and the with_task_collection decorator.""" + + def test_task_collection_base(self): + from examples.scaffolding.core.task_collection import TaskCollection + + tc = 
TaskCollection() + tc.before_yield([]) # Should not raise + tc.after_yield([]) # Should not raise + assert TaskCollection.get_global_info() is None + + def test_with_task_collection_decorator(self): + from examples.scaffolding.core.controller import Controller + from examples.scaffolding.core.task import Task + from examples.scaffolding.core.task_collection import ( + TaskCollection, + with_task_collection, + ) + + class MyCollection(TaskCollection): + def __init__(self): + super().__init__() + self.before_count = 0 + self.after_count = 0 + + def before_yield(self, tasks): + self.before_count += 1 + + def after_yield(self, tasks): + self.after_count += 1 + + @with_task_collection("my_tc", MyCollection) + class MyController(Controller): + def process(self, tasks, **kwargs): + yield tasks + + ctrl = MyController() + assert "my_tc" in ctrl.task_collections + tc = ctrl.task_collections["my_tc"] + assert isinstance(tc, MyCollection) + + list(ctrl.process([Task()])) + assert tc.before_count == 1 + assert tc.after_count == 1 + + def test_generation_token_counter(self): + from examples.scaffolding.core.task import GenerationTask + from examples.scaffolding.core.task_collection import ( + GenerationTokenCounter, + ) + + counter = GenerationTokenCounter() + task = GenerationTask(output_tokens=[1, 2, 3]) + + counter.before_yield([task]) + # Simulate worker adding tokens + task.output_tokens = [1, 2, 3, 4, 5] + counter.after_yield([task]) + + assert counter.generation_token_count == 2 # 5 - 3 = 2 new tokens + + def test_task_metrics_collector_reset(self): + from examples.scaffolding.core.task_collection import ( + TaskMetricsCollector, + ) + + TaskMetricsCollector.statistics["test_ctrl"] = [{"data": 1}] + TaskMetricsCollector.reset("test_ctrl") + assert TaskMetricsCollector.statistics["test_ctrl"] == [] + TaskMetricsCollector.reset() + assert TaskMetricsCollector.statistics == {} + + +class TestMathUtils: + """Tests for math_utils functions.""" + + def test_extract_answer_from_boxed_simple(self): + from examples.scaffolding.core.math_utils import ( + extract_answer_from_boxed, + ) + + assert extract_answer_from_boxed("The answer is \\boxed{42}") == "42" + + def test_extract_answer_from_boxed_nested(self): + from examples.scaffolding.core.math_utils import ( + extract_answer_from_boxed, + ) + + assert extract_answer_from_boxed("\\boxed{x^{2}}") == "x^{2}" + + def test_extract_answer_from_boxed_none(self): + from examples.scaffolding.core.math_utils import ( + extract_answer_from_boxed, + ) + + assert extract_answer_from_boxed("No boxed answer here") is None + + def test_extract_answer_with_regex(self): + from examples.scaffolding.core.math_utils import ( + extract_answer_with_regex, + ) + + result = extract_answer_with_regex("The final answer is 42") + assert result == "42" + + def test_extract_answer_with_regex_no_match(self): + from examples.scaffolding.core.math_utils import ( + extract_answer_with_regex, + ) + + assert extract_answer_with_regex("Nothing relevant") is None + + def test_get_digit_majority_vote_result(self): + from examples.scaffolding.core.math_utils import ( + get_digit_majority_vote_result, + ) + + results = [ + "The answer is \\boxed{42}", + "Therefore \\boxed{42}", + "I get \\boxed{99}", + ] + index, answer = get_digit_majority_vote_result(results) + assert answer == "42" + + def test_get_digit_majority_vote_no_valid(self): + from examples.scaffolding.core.math_utils import ( + get_digit_majority_vote_result, + ) + + results = ["no boxed", "nothing here"] + index, answer = 
get_digit_majority_vote_result(results) + assert answer is None + + +# ============================================================================ +# 4. AReaL wrapper modules — import and basic function +# ============================================================================ + + +class TestWrapperImports: + """Verify AReaL wrapper modules import without tensorrt_llm.""" + + def test_compat_module_imports(self): + from examples.scaffolding._compat import ( # noqa: F401 + AssistantMessage, + BestOfNController, + ChatTask, + Controller, + GenerationTask, + MajorityVoteController, + NativeChatController, + NativeGenerationController, + NativeRewardController, + OpenAIToolDescription, + OpenaiWorker, + ParallelProcess, + RoleMessage, + ScaffoldingLlm, + ScaffoldingOutput, + StreamGenerationTask, + SystemMessage, + Task, + TaskCollection, + TaskStatus, + UserMessage, + Worker, + with_task_collection, + ) + + def test_controllers_module_imports(self): + from examples.scaffolding.controllers import ( # noqa: F401 + ChatTracer, + PipelineTrajectoryMaker, + RLVRRewardController, + TraceTrajectoryMaker, + ) + + def test_task_module_imports(self): + from examples.scaffolding.task import ( # noqa: F401 + ChatRewardTask, + RLVRRewardTask, + TraceGenerationTask, + ) + + def test_worker_module_imports(self): + from examples.scaffolding.worker import ( # noqa: F401 + CreateWorkerFromEngine, + SGLangWorker, + ) + + def test_workflow_module_imports(self): + from examples.scaffolding.workflow import ( + ScaffoldingWorkflow, # noqa: F401 + ) + + def test_top_level_package_imports(self): + from examples.scaffolding import ( # noqa: F401 + ChatRewardTask, + ChatTracer, + CreateWorkerFromEngine, + PipelineTrajectoryMaker, + RLVRRewardController, + RLVRRewardTask, + ScaffoldingWorkflow, + SGLangWorker, + TraceGenerationTask, + TraceTrajectoryMaker, + ) + + +# ============================================================================ +# 5. 
Cross-module integration — core + wrappers work together +# ============================================================================ + + +class TestCrossModuleIntegration: + """Verify that core primitives and AReaL wrappers interoperate.""" + + def test_rlvr_reward_task_inherits_from_core_task(self): + from examples.scaffolding.core.task import Task + from examples.scaffolding.task import RLVRRewardTask + + t = RLVRRewardTask(prompt_str="Q", completion_str="A") + assert isinstance(t, Task) + + def test_sglang_worker_inherits_from_core_openai_worker(self): + from examples.scaffolding.core.worker import OpenaiWorker + from examples.scaffolding.worker import SGLangWorker + + assert issubclass(SGLangWorker, OpenaiWorker) + + def test_rlvr_reward_controller_inherits_from_core_controller(self): + from examples.scaffolding.controllers import RLVRRewardController + from examples.scaffolding.core.controller import Controller + + assert issubclass(RLVRRewardController, Controller) + + def test_pipeline_trajectory_maker_inherits_from_core_controller(self): + from examples.scaffolding.controllers import PipelineTrajectoryMaker + from examples.scaffolding.core.controller import Controller + + assert issubclass(PipelineTrajectoryMaker, Controller) + + def test_native_gen_controller_process_with_compat_import(self): + """Using _compat imports should produce the same result as core imports.""" + from examples.scaffolding._compat import ( + GenerationTask, + NativeGenerationController, + ) + + ctrl = NativeGenerationController(sampling_params={"temperature": 0.5}) + task = GenerationTask.create_from_prompt("test") + results = list(ctrl.process([task])) + assert len(results) == 1 + assert task.temperature == 0.5 + + def test_scaffolding_workflow_is_rollout_workflow(self): + from examples.scaffolding.workflow import ScaffoldingWorkflow + + from areal.api.workflow_api import RolloutWorkflow + + assert issubclass(ScaffoldingWorkflow, RolloutWorkflow) + + def test_compat_classes_are_same_as_core_classes(self): + """Verify _compat re-exports are the exact same class objects as core.""" + from examples.scaffolding._compat import ( + Controller as CompatController, + ) + from examples.scaffolding._compat import ( + GenerationTask as CompatGenTask, + ) + from examples.scaffolding._compat import ( + ScaffoldingLlm as CompatLlm, + ) + from examples.scaffolding._compat import Task as CompatTask + from examples.scaffolding._compat import Worker as CompatWorker + from examples.scaffolding.core.controller import Controller + from examples.scaffolding.core.scaffolding_llm import ScaffoldingLlm + from examples.scaffolding.core.task import GenerationTask, Task + from examples.scaffolding.core.worker import Worker + + assert CompatTask is Task + assert CompatGenTask is GenerationTask + assert CompatController is Controller + assert CompatWorker is Worker + assert CompatLlm is ScaffoldingLlm + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/examples/scaffolding/tests/test_worker.py b/examples/scaffolding/tests/test_worker.py new file mode 100644 index 0000000000..cfff281a82 --- /dev/null +++ b/examples/scaffolding/tests/test_worker.py @@ -0,0 +1,205 @@ +"""Tests for SGLangWorker with a real SGLang server (requires GPU).""" + +from __future__ import annotations + +import subprocess +import sys +import time +from unittest.mock import MagicMock + +import openai +import pytest +import requests + +from examples.scaffolding._compat import ( + GenerationTask, + TaskStatus, +) +from 
examples.scaffolding.worker import SGLangWorker
+
+from areal.api.cli_args import SGLangConfig
+from areal.tests.utils import get_model_path
+from areal.utils import network, seeding
+from areal.utils.hf_utils import load_hf_tokenizer
+from areal.utils.proc import kill_process_tree
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
+EXPR_NAME = "test_scaffolding_worker"
+MODEL_PATH = get_model_path(
+    "/storage/openpsi/models/Qwen__Qwen3-0.6B/", "Qwen/Qwen3-0.6B"
+)
+PORT, DIST_PORT = network.find_free_ports(2)
+HOST = network.gethostip()
+RUN_SERVER_TIMEOUT = 180
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+def _check_server_health(base_url: str) -> bool:
+    try:
+        response = requests.get(f"{base_url}/health", timeout=30)
+        return response.status_code == 200
+    except requests.exceptions.RequestException:
+        return False
+
+
+@pytest.fixture(scope="module")
+def sglang_server():
+    """Launch a real SGLang server.
+
+    Uses skip_tokenizer_init=False (overriding the AReaL default) so the
+    server's OpenAI-compatible API accepts string prompts. The chat template
+    is applied client-side via the tokenizer, matching ScaffoldingWorkflow
+    behavior.
+    """
+    seeding.set_random_seed(1, EXPR_NAME)
+    cmd = SGLangConfig.build_cmd(
+        sglang_config=SGLangConfig(
+            skip_tokenizer_init=False,
+            model_path=MODEL_PATH,
+            mem_fraction_static=0.3,
+        ),
+        host=HOST,
+        port=PORT,
+        tp_size=1,
+        base_gpu_id=0,
+        dist_init_addr=f"{HOST}:{DIST_PORT}",
+    )
+    process = subprocess.Popen(
+        cmd,
+        stdout=sys.stdout,
+        stderr=sys.stdout,
+    )
+    base_url = f"http://{HOST}:{PORT}"
+    tik = time.time()
+    healthy = False
+    while time.time() - tik < RUN_SERVER_TIMEOUT:
+        if _check_server_health(base_url):
+            healthy = True
+            break
+        time.sleep(1)
+    # Track success explicitly: re-checking elapsed time here could misreport
+    # a launch that succeeded just before the deadline as a timeout.
+    if not healthy:
+        kill_process_tree(process.pid, graceful=True)
+        raise RuntimeError("SGLang server launch timed out")
+    yield base_url
+    kill_process_tree(process.pid, graceful=True)
+
+
+@pytest.fixture(scope="module")
+def tokenizer():
+    """Load the tokenizer for client-side chat template application."""
+    return load_hf_tokenizer(MODEL_PATH)
+
+
+@pytest.fixture(scope="module")
+def sglang_worker(sglang_server):
+    """Create an SGLangWorker connected to the real SGLang server."""
+    base_url = sglang_server
+    if not base_url.endswith("/v1"):
+        base_url = f"{base_url}/v1"
+    async_client = openai.AsyncOpenAI(base_url=base_url, api_key="EMPTY")
+    mock_engine = MagicMock()
+    return SGLangWorker(
+        async_client=async_client,
+        model="default",
+        engine=mock_engine,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Tests — generation_handler (uses /v1/completions, no tokenizer needed)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_generation_handler(sglang_worker):
+    """generation_handler should generate text from the real server."""
+    task = GenerationTask(
+        input_str="What is 1 + 1? 
Answer briefly:", + input_tokens=[], + ) + status = await sglang_worker.generation_handler(task) + + assert status == TaskStatus.SUCCESS + assert task.output_str is not None + assert len(task.output_str) > 0 + assert task.finish_reason in ("stop", "length") + + +@pytest.mark.asyncio +async def test_generation_max_tokens(sglang_worker): + """generation_handler should respect max_tokens and finish with 'length'.""" + task = GenerationTask( + input_str="Write a very long essay about the history of mathematics.", + input_tokens=[], + ) + original_create = sglang_worker.async_client.completions.create + + async def _create_with_limit(**kwargs): + kwargs["max_tokens"] = 5 + return await original_create(**kwargs) + + sglang_worker.async_client.completions.create = _create_with_limit + try: + status = await sglang_worker.generation_handler(task) + assert status == TaskStatus.SUCCESS + assert task.finish_reason == "length" + assert task.output_str is not None + finally: + sglang_worker.async_client.completions.create = original_create + + +# --------------------------------------------------------------------------- +# Tests — generation with chat template (client-side, like ScaffoldingWorkflow) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_generation_with_chat_template(sglang_worker, tokenizer): + """Client-side chat template + completions API, matching ScaffoldingWorkflow.""" + messages = [{"role": "user", "content": "What is the capital of France?"}] + input_ids = tokenizer.apply_chat_template( + messages, tokenize=True, add_generation_prompt=True + ) + prompt_str = tokenizer.decode(input_ids) + + task = GenerationTask( + input_str=prompt_str, + input_tokens=input_ids, + ) + status = await sglang_worker.generation_handler(task) + + assert status == TaskStatus.SUCCESS + assert task.output_str is not None + assert len(task.output_str) > 0 + + +@pytest.mark.asyncio +async def test_multi_turn_generation(sglang_worker, tokenizer): + """Multi-turn via client-side chat template + completions API.""" + # Turn 1 + messages = [{"role": "user", "content": "My name is Alice."}] + input_ids = tokenizer.apply_chat_template( + messages, tokenize=True, add_generation_prompt=True + ) + prompt_str = tokenizer.decode(input_ids) + + task = GenerationTask(input_str=prompt_str, input_tokens=input_ids) + status = await sglang_worker.generation_handler(task) + assert status == TaskStatus.SUCCESS + assert task.output_str is not None + + # Turn 2 — append assistant reply and new user message + messages.append({"role": "assistant", "content": task.output_str}) + messages.append({"role": "user", "content": "What is my name?"}) + input_ids_2 = tokenizer.apply_chat_template( + messages, tokenize=True, add_generation_prompt=True + ) + prompt_str_2 = tokenizer.decode(input_ids_2) + + task2 = GenerationTask(input_str=prompt_str_2, input_tokens=input_ids_2) + status = await sglang_worker.generation_handler(task2) + assert status == TaskStatus.SUCCESS + assert task2.output_str is not None diff --git a/examples/scaffolding/worker.py b/examples/scaffolding/worker.py new file mode 100644 index 0000000000..9eb46c2cbe --- /dev/null +++ b/examples/scaffolding/worker.py @@ -0,0 +1,249 @@ +""" +Worker implementations for Scaffolding Framework. + +This module provides Worker implementations that wrap AReaL inference engines +for use with the scaffolding framework. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import openai + +from areal.utils import logging + +from ._compat import ( + AssistantMessage, + ChatTask, + GenerationTask, + OpenaiWorker, + TaskStatus, +) + +if TYPE_CHECKING: + from areal.engine.sglang_remote import RemoteSGLangEngine + +worker_logger = logging.getLogger("SGLangWorker") + + +class SGLangWorker(OpenaiWorker): + """Worker that wraps an SGLang engine for scaffolding. + + This worker connects to an SGLang server via its OpenAI-compatible API + and handles generation and chat tasks. + + Parameters + ---------- + async_client : openai.AsyncOpenAI + The OpenAI async client configured to connect to SGLang server. + model : str + The model name to use for requests. + engine : RemoteSGLangEngine + The underlying SGLang engine (kept for reference and potential future use). + """ + + def __init__( + self, + async_client: openai.AsyncOpenAI, + model: str, + engine: RemoteSGLangEngine, + ): + super().__init__(async_client, model, kv_cache_hint_enabled=False) + self.engine = engine + + async def chat_handler(self, task: ChatTask) -> TaskStatus: + """Handle chat completion requests. + + This method extends the base OpenaiWorker's chat handler to also + store the ChatCompletion object in the task for tracing purposes. + + Parameters + ---------- + task : ChatTask + The chat task to process. + + Returns + ------- + TaskStatus + The status of the task execution. + """ + params = self.convert_task_params(task) + params["messages"] = task.messages_to_dict_content() + params["model"] = self.model + if task.tools is not None: + params["tools"] = [tool.to_dict() for tool in task.tools] + + try: + worker_logger.info( + "Sending chat request to %s (messages=%d) ...", + self.async_client.base_url, + len(params.get("messages", [])), + ) + response = await self.async_client.chat.completions.create(**params) + worker_logger.info("Chat response received.") + + # Store the completion in the task for tracing + task.completion = response + + task.finish_reason = response.choices[0].finish_reason + if hasattr(response, "perf_metrics"): + task.perf_metrics = response.perf_metrics + + content = response.choices[0].message.content + reasoning = getattr(response.choices[0].message, "reasoning", None) + reasoning_content = getattr( + response.choices[0].message, "reasoning_content", None + ) + tool_calls = response.choices[0].message.tool_calls + + task.messages.append( + AssistantMessage(content, reasoning, reasoning_content, tool_calls) + ) + + if task.enable_token_counting and response.usage: + task.prompt_tokens_num = response.usage.prompt_tokens + task.completion_tokens_num = response.usage.completion_tokens + if ( + hasattr(response.usage, "completion_tokens_details") + and response.usage.completion_tokens_details is not None + ): + task.reasoning_tokens_num = ( + response.usage.completion_tokens_details.reasoning_tokens + ) + + return TaskStatus.SUCCESS + + except Exception as e: + worker_logger.error("SGLang chat client exception: %s", e) + return TaskStatus.WORKER_EXECEPTION + + async def generation_handler(self, task: GenerationTask) -> TaskStatus: + """Handle text generation requests. + + Parameters + ---------- + task : GenerationTask + The generation task to process. + + Returns + ------- + TaskStatus + The status of the task execution. 
+ """ + params = self.convert_task_params(task) + params["model"] = self.model + if task.input_str is not None: + params["prompt"] = task.input_str + + try: + worker_logger.info( + "Sending generation request to %s ...", + self.async_client.base_url, + ) + response = await self.async_client.completions.create(**params) + worker_logger.info("Generation response received.") + + task.output_str = response.choices[0].text + if hasattr(response.choices[0], "token_ids"): + task.output_tokens = response.choices[0].token_ids + task.finish_reason = response.choices[0].finish_reason + if hasattr(response.choices[0], "logprobs"): + task.logprobs = response.choices[0].logprobs + if hasattr(response, "perf_metrics"): + task.perf_metrics = response.perf_metrics + + return TaskStatus.SUCCESS + + except Exception as e: + worker_logger.error("SGLang completion client exception: %s", e) + return TaskStatus.WORKER_EXECEPTION + + # Register task handlers + task_handlers = { + GenerationTask: generation_handler, + ChatTask: chat_handler, + } + + +def CreateWorkerFromEngine( + engine: RemoteSGLangEngine, + model: str = "default", +) -> SGLangWorker: + """Create a scaffolding Worker from an AReaL SGLang engine. + + This function creates a Worker that wraps the given SGLang engine, + allowing it to be used with the scaffolding framework. + The worker uses the SGLang server's OpenAI-compatible API. + + Parameters + ---------- + engine : RemoteSGLangEngine + The AReaL SGLang inference engine (must be initialized). + model : str, optional + The model name to use for API requests. Defaults to "default". + + Returns + ------- + SGLangWorker + A Worker instance that can be used with ScaffoldingLlm. + + Example + ------- + ```python + from areal.engine.sglang_remote import RemoteSGLangEngine + from examples.scaffolding import CreateWorkerFromEngine + + # Initialize the engine + engine = RemoteSGLangEngine(config) + engine.initialize() + + # Create a worker + worker = CreateWorkerFromEngine(engine) + + # Use with ScaffoldingLlm + from examples.scaffolding._compat import ScaffoldingLlm, NativeGenerationController + llm = ScaffoldingLlm( + controller, + {NativeGenerationController.WorkerTag.GENERATION: worker}, + ) + ``` + + Raises + ------ + RuntimeError + If the engine is not initialized. + """ + if not engine.initialized: + raise RuntimeError( + "Engine must be initialized before creating a worker. " + "Call engine.initialize() first." 
+ ) + + # Get the server address from the engine + # The internal engine stores server info + internal_engine = engine._engine + server_addrs = internal_engine._server_addrs + + if not server_addrs: + raise RuntimeError("No server addresses found in engine.") + + # Use the first server address for the OpenAI client + # SGLang servers support OpenAI-compatible API at /v1/ + base_url = server_addrs[0] + if not base_url.startswith("http"): + base_url = f"http://{base_url}" + if not base_url.endswith("/v1"): + base_url = f"{base_url}/v1" + + # Create an async OpenAI client pointing to the SGLang server + async_client = openai.AsyncOpenAI( + base_url=base_url, + api_key="EMPTY", # SGLang doesn't require API key by default + ) + + return SGLangWorker( + async_client=async_client, + model=model, + engine=engine, + ) diff --git a/examples/scaffolding/workflow.py b/examples/scaffolding/workflow.py new file mode 100644 index 0000000000..0b530ead88 --- /dev/null +++ b/examples/scaffolding/workflow.py @@ -0,0 +1,217 @@ +""" +ScaffoldingWorkflow - RolloutWorkflow with generation and reward via Scaffolding. + +Architecture +------------ +- Generation: via scaffolding Worker (SGLangWorker calls SGLang OpenAI API) +- Reward: via scaffolding RLVRRewardController +- Logprobs: placeholder (0.0) since recompute_logprob=true in training config + causes the actor to recompute exact logprobs during PPO update. +- Worker & ScaffoldingLlm: lazily created from engine server addresses, + exposed for subclasses (e.g., multi-turn workflows). +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import torch +from transformers import PreTrainedTokenizerFast + +from areal.api.cli_args import GenerationHyperparameters +from areal.api.engine_api import InferenceEngine +from areal.api.workflow_api import RolloutWorkflow +from areal.utils import logging +from areal.utils.dynamic_import import import_from_string + +from ._compat import ( + NativeGenerationController, + ScaffoldingLlm, +) +from .controllers import ( + PipelineTrajectoryMaker, + RLVRRewardController, +) +from .worker import SGLangWorker + +logger = logging.getLogger("ScaffoldingWorkflow") + + +class ScaffoldingWorkflow(RolloutWorkflow): + """RolloutWorkflow with generation and reward via scaffolding components. + + Both generation and reward computation go through scaffolding: + - Generation: SGLangWorker calls SGLang's OpenAI-compatible completions API + - Reward: RLVRRewardController computes verifiable rewards + + Since the OpenAI API does not return per-token logprobs in AReaL's format, + placeholder logprobs are used. Set ``recompute_logprob: true`` in the actor + config so the training engine recomputes exact logprobs during PPO update. + + Parameters + ---------- + reward_fn : Callable | str + The reward function, or an importable string path. + gconfig : GenerationHyperparameters + Generation hyperparameters. + tokenizer : PreTrainedTokenizerFast | str + Tokenizer or path to load it. + enable_thinking : bool + Whether to enable thinking tokens. 
+ """ + + def __init__( + self, + reward_fn: Callable[..., Any] | str, + gconfig: GenerationHyperparameters, + tokenizer: PreTrainedTokenizerFast | str, + enable_thinking: bool = False, + ): + if isinstance(reward_fn, str): + reward_fn = import_from_string(reward_fn) + self.reward_fn = reward_fn + + self.tokenizer = tokenizer + if isinstance(self.tokenizer, str): + from areal.utils.hf_utils import load_hf_tokenizer + + self.tokenizer = load_hf_tokenizer(self.tokenizer) + self.gconfig = gconfig.new_with_stop_and_pad_token_ids(self.tokenizer) + self.enable_thinking = enable_thinking + + # Lazily created from engine server addresses via build_scaffolding_llm + self.worker: SGLangWorker | None = None + self.gen_controller: NativeGenerationController | None = None + self.reward_controller: RLVRRewardController | None = None + self.trajectory_maker: PipelineTrajectoryMaker | None = None + self.scaffolding_llm: ScaffoldingLlm | None = None + + def _lazy_init_scaffolding(self, engine: InferenceEngine) -> None: + """Create Worker, PipelineTrajectoryMaker, and ScaffoldingLlm.""" + import openai + + addr = engine.addresses[0] + base_url = f"http://{addr}" + if not base_url.endswith("/v1"): + base_url = f"{base_url}/v1" + + async_client = openai.AsyncOpenAI(base_url=base_url, api_key="EMPTY") + self.worker = SGLangWorker( + async_client=async_client, model="default", engine=engine + ) + + self.scaffolding_llm = self.build_scaffolding_llm(engine) + logger.info("Initialized scaffolding components with server at %s", addr) + + def build_scaffolding_llm(self, engine: InferenceEngine) -> ScaffoldingLlm: + """Build the ScaffoldingLlm instance. + + Override this method in subclasses to use different scaffolding + controllers or worker configurations. Subclasses should set + ``self.gen_controller`` and ``self.reward_controller`` here. + + When this method is called, ``self.worker`` is already initialized. + + Parameters + ---------- + engine : InferenceEngine + The inference engine (available for address lookup if needed). + + Returns + ------- + ScaffoldingLlm + The constructed ScaffoldingLlm instance. + """ + # Convert gconfig to sampling params for NativeGenerationController + stop_strings = [] + if self.gconfig.stop_token_ids: + for tid in self.gconfig.stop_token_ids: + decoded = self.tokenizer.decode([tid]) + if decoded: + stop_strings.append(decoded) + + sampling_params = { + "max_tokens": self.gconfig.max_new_tokens, + "temperature": self.gconfig.temperature or 1.0, + } + if stop_strings: + sampling_params["stop"] = stop_strings + + self.gen_controller = NativeGenerationController( + sampling_params=sampling_params + ) + self.reward_controller = RLVRRewardController(self.reward_fn) + self.trajectory_maker = PipelineTrajectoryMaker( + self.gen_controller, self.reward_controller + ) + return ScaffoldingLlm( + self.trajectory_maker, + {NativeGenerationController.WorkerTag.GENERATION: self.worker}, + ) + + async def arun_episode( + self, engine: InferenceEngine, data: dict[str, Any] + ) -> dict[str, torch.Tensor]: + """Run a single episode via scaffolding pipeline. + + Delegates the full episode (generation + reward) to + ``self.scaffolding_llm``, which wraps a ``PipelineTrajectoryMaker``. + The result is an ``InteractionWithTokenLogpReward`` whose + ``to_tensor_dict()`` produces the training tensors. + + Note: logprobs are placeholders (0.0). Set ``recompute_logprob: true`` + in actor config so the training engine computes exact logprobs. 
+ + Parameters + ---------- + engine : InferenceEngine + The inference engine (used for server addresses on first call). + data : dict[str, Any] + Input data containing messages and ground truth. + + Returns + ------- + dict[str, torch.Tensor] + Trajectory tensors for PPO training. + """ + if self.worker is None: + self._lazy_init_scaffolding(engine) + + # Tokenize prompt + input_ids = list( + self.tokenizer.apply_chat_template( + data["messages"], + tokenize=True, + add_generation_prompt=True, + enable_thinking=self.enable_thinking, + ) + ) + prompt_str = self.tokenizer.decode(input_ids) + + # Configure per-episode data on trajectory maker + # (clone() in scaffolding_llm will deep-copy these) + self.trajectory_maker.task_data = data + self.trajectory_maker.prompt_str = prompt_str + self.trajectory_maker.input_tokens = input_ids + + # Run full pipeline via scaffolding_llm + result = self.scaffolding_llm.generate_async(prompt_str) + await result + + # Extract interaction and convert to tensor dict + scaffolding_output = result.outputs[0] + interactions = scaffolding_output.data + interaction = next(iter(interactions.values())) + + # If output_tokens is missing (e.g., SGLang didn't return token_ids), + # tokenize the output_str as a fallback so loss_mask is non-zero. + resp = interaction.model_response + if resp is not None and not resp.output_tokens: + output_text = scaffolding_output.text or "" + output_tokens = self.tokenizer.encode(output_text, add_special_tokens=False) + resp.output_tokens = output_tokens + resp.output_logprobs = [0.0] * len(output_tokens) + resp.output_versions = [-1] * len(output_tokens) + + return interaction.to_tensor_dict() From bc9f00988579bac7d3cefd1967caa2cad4e02151 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E9=97=AE=E7=A5=9E=E5=A5=87=E6=B5=B7=E8=9E=BA?= Date: Mon, 20 Apr 2026 14:47:37 +0800 Subject: [PATCH 3/3] feat(api): add unified RejectionSamplingConfig for async training (#1088) Replace behave_imp_weight_cap/behave_imp_weight_mode with unified RejectionSamplingConfig supporting multiple metrics (ratio, kl_k1, kl_k2, kl_k3), levels (token/sequence), and actions (mask/clamp). 
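For example, the previous default (behave_imp_weight_mode='token_mask' with
behave_imp_weight_cap=5.0) maps onto the new sub-config as follows (mapping
reproduced from the migration error message added by this patch):

    behave_imp_weight_mode='token_mask', behave_imp_weight_cap=5.0
      -> rejection_sampling: {level: token, action: mask, metric: ratio, upper: 5.0}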
Key changes: - Add RejectionSamplingConfig dataclass with comprehensive validation - Implement apply_rejection_sampling for 1D packed and 2D padded formats - Fix loss denominator scaling bug in mask mode (save count before filtering) - Use geometric mean for sequence-level ratio aggregation (matching GSPO) - Broadcast sequence-level geometric mean as uniform behave_imp_weight - Warn when use_decoupled_loss=True but rejection_sampling is None - Update ppo_actor_loss_fn and grpo_loss_fn to use new config - Migrate 40 example configs to new rejection_sampling field - Add 43 unit tests covering all modes, metrics, and edge cases Refs: #1052 --- areal/api/cli_args.py | 282 ++++-- areal/trainer/ppo/actor.py | 29 +- areal/utils/functional/__init__.py | 6 +- areal/utils/functional/functional.py | 380 ++++++-- docs/en/best_practices/algo_perf.md | 2 +- docs/en/cli_reference.md | 308 ++++--- docs/zh/cli_reference.md | 308 ++++--- examples/agent_workflow/config.yaml | 4 +- examples/agent_workflow/config_claude.yaml | 4 +- examples/camel/config.yaml | 4 +- examples/countdown/train_config.yaml | 4 +- examples/distillation/gsm8k_grpo_distill.yaml | 4 +- .../prox_approx/gsm8k_grpo_prox_approx.yaml | 4 +- .../gsm8k_grpo_prox_approx_eval.yaml | 4 +- examples/math/boba_grpo.yaml | 4 +- examples/math/gsm8k_dapo_dynamic_bs.yaml | 4 +- examples/math/gsm8k_drgrpo.yaml | 4 +- examples/math/gsm8k_grpo.yaml | 4 +- examples/math/gsm8k_grpo_cpu.yaml | 4 +- examples/math/gsm8k_grpo_lora.yaml | 4 +- examples/math/gsm8k_grpo_megatron.yaml | 4 +- examples/math/gsm8k_grpo_megatron_fp8.yaml | 4 +- examples/math/gsm8k_grpo_npu.yaml | 4 +- examples/math/gsm8k_gspo.yaml | 4 +- examples/math/gsm8k_liteppo.yaml | 4 +- examples/math/gsm8k_m2po.yaml | 4 +- examples/math/gsm8k_ppo.yaml | 4 +- examples/math/gsm8k_ppo_megatron.yaml | 4 +- examples/math/gsm8k_reinforce.yaml | 4 +- examples/math/gsm8k_reinforce_baseline.yaml | 4 +- examples/math/gsm8k_rloo.yaml | 4 +- examples/math/gsm8k_sapo.yaml | 4 +- examples/multi_turn_math/gsm8k_grpo_mt.yaml | 4 +- examples/openai_agents/config.yaml | 4 +- examples/openclaw/config.yaml | 4 +- examples/search_agent/local_1.5b_example.yaml | 4 +- .../tongyi_deepresearch/config.yaml | 4 +- examples/skypilot/gsm8k_grpo_ray.yaml | 4 +- examples/tau2/config_1.7b_airline.yaml | 4 +- examples/tau2/config_235b_moe_airline.yaml | 4 +- examples/tau2/config_30b_moe_airline.yaml | 4 +- examples/tau2/config_8b_airline.yaml | 4 +- examples/tir/tir_math_config.yaml | 4 +- examples/vlm/clevr_count_70k_grpo.yaml | 4 +- examples/vlm/geometry3k_grpo.yaml | 4 +- .../qwen2_5_vl_3b_geometry3k_grpo.yaml | 4 +- .../vlm_npu/qwen3_vl_2b_geometry3k_grpo.yaml | 4 +- tests/experimental/archon/test_grpo.py | 5 - tests/test_functional.py | 140 +-- tests/test_ppo_stats.py | 2 - tests/test_prox_approx.py | 10 - tests/test_rejection_sampling.py | 839 ++++++++++++++++++ 52 files changed, 1911 insertions(+), 560 deletions(-) create mode 100644 tests/test_rejection_sampling.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 91f9d3ba13..23b5bd96eb 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1191,6 +1191,171 @@ def __post_init__(self): ) +@dataclass +class RejectionSamplingConfig: + """Unified configuration for sample filtering based on policy divergence. 
+
+    Filters tokens/sequences where the divergence between proximal policy
+    and behavior policy exceeds a threshold, via two action modes:
+    - 'mask': zero out loss_mask (rejection, exclude from gradient)
+    - 'clamp': clamp importance weight to bounds (truncation, bounded gradient)
+
+    Supports direct ratio bounds and KL divergence estimators (K1/K2/K3),
+    at both token-level and sequence-level granularity.
+
+    Replaces the removed ``behave_imp_weight_cap`` and ``behave_imp_weight_mode``.
+
+    Attributes:
+        level: Filtering granularity ('token' or 'sequence'). When ``level='sequence'``
+            and ``metric='ratio'``, both the filtering decision and the correction
+            weight (behave_imp_weight) use the sequence-level geometric mean,
+            matching the old ``sequence_mask``/``sequence_truncate`` semantics.
+        action: Action mode ('mask' or 'clamp').
+        metric: Divergence metric ('ratio', 'kl_k1', 'kl_k2', 'kl_k3').
+        agg: Aggregation method for sequence-level ('sum', 'mean', 'max').
+            For 'ratio' metric, aggregation is performed in log space (geometric
+            mean/sum) to avoid the "length trap" and match GSPO semantics.
+            For KL metrics, aggregation is arithmetic.
+        upper: Upper bound for filtering.
+        lower: Lower bound for filtering (optional).
+    """
+
+    level: str = field(
+        default="token",
+        metadata={
+            "help": "Filtering granularity. "
+            "'token': per-token filtering (each token judged independently). "
+            "'sequence': per-sequence filtering (all tokens in a sequence share the same fate). "
+            "When metric='ratio', both the filtering decision and the correction weight "
+            "(behave_imp_weight) operate at sequence level using the geometric mean.",
+            "choices": ["token", "sequence"],
+        },
+    )
+    action: str = field(
+        default="mask",
+        metadata={
+            "help": "Action to take when metric exceeds threshold. "
+            "'mask': zero out loss_mask for filtered tokens/sequences (rejection, "
+            "completely excludes from gradient computation). "
+            "'clamp': clamp importance weight to [lower, upper] bounds (truncation, "
+            "tokens still participate in gradient but with bounded weight).",
+            "choices": ["mask", "clamp"],
+        },
+    )
+    metric: str = field(
+        default="ratio",
+        metadata={
+            "help": "Divergence metric for filtering. "
+            "'ratio': direct importance ratio π_proximal/π_behave. "
+            "'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). "
+            "'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. "
+            "'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator.",
+            "choices": ["ratio", "kl_k1", "kl_k2", "kl_k3"],
+        },
+    )
+    agg: str = field(
+        default="mean",
+        metadata={
+            "help": "Aggregation method for sequence-level filtering. "
+            "Only used when level='sequence'. "
+            "For 'ratio' metric, aggregation is in log space: "
+            "'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean "
+            "(length-invariant, consistent with GSPO). "
+            "For KL metrics, aggregation is arithmetic: "
+            "'sum' = sum(kl_i), 'mean' = mean(kl_i). "
+            "'max': max of per-token metric values (most conservative).",
+            "choices": ["sum", "mean", "max"],
+        },
+    )
+    upper: float = field(
+        default=5.0,
+        metadata={
+            "help": "Upper bound for filtering. "
+            "Tokens/sequences with metric > upper are rejected (action='mask', "
+            "loss_mask zeroed) or clamped to the bounds (action='clamp'). "
+            "For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. "
+            "For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0."
+ }, + ) + lower: float | None = field( + default=None, + metadata={ + "help": "Lower bound for filtering (optional). " + "None means no lower bound. " + "For 'ratio' metric: typical value is 0.5 (filter out tokens where policy " + "probability dropped significantly). Must be > 0. " + "For 'kl_k1' metric: can be used to filter negative KL estimates." + }, + ) + + def __post_init__(self): + """Validate configuration.""" + import warnings + + _VALID_LEVELS = ("token", "sequence") + _VALID_ACTIONS = ("mask", "clamp") + _VALID_METRICS = ("ratio", "kl_k1", "kl_k2", "kl_k3") + _VALID_AGGS = ("sum", "mean", "max") + + # Validate enum-like fields. + if self.level not in _VALID_LEVELS: + raise ValueError( + f"level must be one of {_VALID_LEVELS}, got '{self.level}'" + ) + if self.action not in _VALID_ACTIONS: + raise ValueError( + f"action must be one of {_VALID_ACTIONS}, got '{self.action}'" + ) + if self.metric not in _VALID_METRICS: + raise ValueError( + f"metric must be one of {_VALID_METRICS}, got '{self.metric}'" + ) + if self.agg not in _VALID_AGGS: + raise ValueError(f"agg must be one of {_VALID_AGGS}, got '{self.agg}'") + + # Validate lower <= upper when both are set. + if self.lower is not None and self.lower > self.upper: + raise ValueError( + f"lower ({self.lower}) cannot be greater than upper ({self.upper})" + ) + + # For ratio metric, upper must be > 1.0 (otherwise all non-identical policy tokens are filtered). + if self.metric == "ratio": + if self.upper <= 1.0: + raise ValueError( + f"upper must be > 1.0 for 'ratio' metric (otherwise all non-identical " + f"policy tokens will be filtered), got {self.upper}" + ) + if self.lower is not None and self.lower <= 0: + raise ValueError( + f"lower must be positive for 'ratio' metric, got {self.lower}" + ) + # For KL metrics, upper must be > 0. + # Note: kl_k1 is excluded because it is a forward KL unbiased estimator that + # can produce negative values, so requiring upper > 0 would be too restrictive. + if self.metric in ("kl_k2", "kl_k3") and self.upper <= 0: + raise ValueError( + f"upper must be positive for '{self.metric}' metric, got {self.upper}" + ) + # Clamp action only supports ratio metric (direct importance weight truncation). + if self.action == "clamp" and self.metric != "ratio": + raise ValueError( + f"action='clamp' only supports metric='ratio' (direct importance weight " + f"truncation). Got metric='{self.metric}'. " + f"Use action='mask' for KL-based filtering." + ) + # Clamp action defaults lower to 0.0 (consistent with old truncate behavior). + if self.action == "clamp" and self.lower is None: + self.lower = 0.0 + # Validate sequence-level aggregation. + if self.level == "token" and self.agg != "mean": + warnings.warn( + f"agg='{self.agg}' is ignored when level='token'. " + "Aggregation is only used for sequence-level filtering.", + UserWarning, + stacklevel=2, + ) + + @dataclass class PPOActorConfig(TrainEngineConfig): """Configuration for PPO actor model, a subclass of a TrainEngine.""" @@ -1294,32 +1459,12 @@ class PPOActorConfig(TrainEngineConfig): "help": "Use the decoupled loss. Implicitly enables recompute_logprob." }, ) - behave_imp_weight_cap: float | None = field( - default=5.0, - metadata={ - "help": "Filter out tokens/sequences where behave_imp_weight exceeds this cap when computing loss. " - "Only effective when use_decoupled_loss=True (decoupled/async training). " - "Must be > 1.0 when mode is not 'disabled'. " - "Mode controlled by behave_imp_weight_mode (mask/truncate/disabled)." 
- }, - ) - behave_imp_weight_mode: str = field( - default="token_mask", + rejection_sampling: RejectionSamplingConfig | None = field( + default=None, metadata={ - "help": "Mode for importance weight filtering. " - "Only effective when use_decoupled_loss=True (decoupled/async training). " - "'token_truncate': clamp token ratio to [0, cap]. " - "'token_mask': set token ratio to 0 where ratio > cap. " - "'sequence_truncate': clamp sequence ratio to [0, cap]. " - "'sequence_mask': set sequence ratio to 0 where ratio > cap. " - "'disabled': disable importance weight correction.", - "choices": [ - "token_truncate", - "token_mask", - "sequence_truncate", - "sequence_mask", - "disabled", - ], + "help": "Rejection sampling configuration for filtering stale samples. " + "None disables filtering (equivalent to old behave_imp_weight_mode='disabled'). " + "Only effective when use_decoupled_loss=True." }, ) importance_sampling_level: str = field( @@ -1372,34 +1517,28 @@ def should_compute_prox_logp(self) -> bool: def __post_init__(self): """Validate PPO actor configuration.""" - # Validate MIS/TIS configuration - if self.behave_imp_weight_mode == "disabled": - if self.behave_imp_weight_cap is not None: - raise ValueError( - f"behave_imp_weight_cap must be None when behave_imp_weight_mode is 'disabled', " - f"got {self.behave_imp_weight_cap}." - ) - else: - if ( - self.behave_imp_weight_cap is not None - and self.behave_imp_weight_cap <= 1.0 - ): - raise ValueError( - f"behave_imp_weight_cap must be > 1.0 when behave_imp_weight_mode is not 'disabled', " - f"got {self.behave_imp_weight_cap}." - ) - - # Warn if behave_imp_weight settings are configured but use_decoupled_loss is False - if not self.use_decoupled_loss: - if ( - self.behave_imp_weight_cap is not None - or self.behave_imp_weight_mode != "disabled" - ): - logger.warning( - "behave_imp_weight_cap and behave_imp_weight_mode are configured but " - "use_decoupled_loss=False. These settings will be ignored. " - "Set use_decoupled_loss=True to enable decoupled loss with importance weight correction." - ) + # Warn if rejection_sampling is configured but use_decoupled_loss is False + if not self.use_decoupled_loss and self.rejection_sampling is not None: + logger.warning( + "rejection_sampling is configured but use_decoupled_loss=False. " + "Filtering will be ignored. Set use_decoupled_loss=True to enable." + ) + # Warn if decoupled loss is enabled but no rejection sampling configured. + # The old default (behave_imp_weight_cap=5.0, mode=token_mask) enabled + # filtering implicitly; the new default (rejection_sampling=None) disables + # it. This warning helps users who relied on the old defaults. + if self.use_decoupled_loss and self.rejection_sampling is None: + logger.warning( + "use_decoupled_loss=True with rejection_sampling=None: " + "staleness filtering is disabled. If you previously relied on " + "the default behave_imp_weight_cap=5.0 with token_mask mode, " + "restore equivalent behavior with:\n" + " rejection_sampling:\n" + " level: token\n" + " action: mask\n" + " metric: ratio\n" + " upper: 5.0" + ) # Validate SAPO configuration if self.use_sapo_loss: @@ -2562,7 +2701,44 @@ def parse_cli_args(argv: list[str]): return cfg, config_file +_LEGACY_REJECTION_SAMPLING_KEYS = { + "behave_imp_weight_cap", + "behave_imp_weight_mode", +} + +_LEGACY_MIGRATION_MESSAGE = ( + "Config keys 'behave_imp_weight_cap' and 'behave_imp_weight_mode' have been " + "removed. 
Use 'rejection_sampling' sub-config instead.\n" + "Migration mapping:\n" + " behave_imp_weight_mode='disabled' -> rejection_sampling: null\n" + " behave_imp_weight_mode='token_mask', behave_imp_weight_cap=X\n" + " -> rejection_sampling: {level: token, action: mask, metric: ratio, upper: X}\n" + " behave_imp_weight_mode='token_truncate', behave_imp_weight_cap=X\n" + " -> rejection_sampling: {level: token, action: clamp, metric: ratio, upper: X}\n" +) + + +def _migrate_legacy_rejection_sampling(cfg: DictConfig) -> DictConfig: + """Intercept removed behave_imp_weight_* keys and raise actionable error.""" + # Walk top-level and known nested actor/teacher configs for legacy keys. + sections_to_check = ["actor", "teacher"] + for section in sections_to_check: + if not OmegaConf.is_missing(cfg, section) and section in cfg: + sub = cfg[section] + if sub is None or not isinstance(sub, DictConfig): + continue + found = _LEGACY_REJECTION_SAMPLING_KEYS.intersection(sub.keys()) + if found: + raise ValueError( + f"Found removed config key(s) {found} under '{section}'. " + + _LEGACY_MIGRATION_MESSAGE + ) + return cfg + + def to_structured_cfg(cfg, config_cls): + # Intercept legacy config keys before merge to give actionable error. + _migrate_legacy_rejection_sampling(cfg) # Merge with the default configuration. # The yaml and commandline can omit some default values defined in python dataclasses. default_cfg = OmegaConf.structured(config_cls) diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index c9e61dd290..07944a31a5 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -6,7 +6,7 @@ import torch from areal.api import TrainEngine -from areal.api.cli_args import MicroBatchSpec, PPOActorConfig +from areal.api.cli_args import MicroBatchSpec, PPOActorConfig, RejectionSamplingConfig from areal.experimental.training_service.controller.controller import ( GatewayTrainController, ) @@ -102,10 +102,13 @@ def _log_configuration(self): logger.info(" log_p_theta (π_θ): TRAINING FORWARD PASS (current policy)") - if config.behave_imp_weight_cap: + if config.rejection_sampling is not None: + rs = config.rejection_sampling logger.info( - f" Importance weight cap: {config.behave_imp_weight_cap:.1f} " - "(filters out tokens with extreme weights)" + f" Rejection sampling: level={rs.level}, metric={rs.metric}, " + f"action={rs.action}, upper={rs.upper}" + + (f", lower={rs.lower}" if rs.lower is not None else "") + + (f", agg={rs.agg}" if rs.level == "sequence" else "") ) # Log other critical config @@ -310,8 +313,11 @@ def _ppo_update(self, data: dict[str, Any]) -> None: scalars["use_dual_clip"] = 1 else: scalars["use_dual_clip"] = 0 - if self.config.behave_imp_weight_cap is not None: - scalars["behave_imp_weight_cap"] = self.config.behave_imp_weight_cap + if self.config.rejection_sampling is not None: + rs = self.config.rejection_sampling + scalars["rs_upper"] = rs.upper + if rs.lower is not None: + scalars["rs_lower"] = rs.lower stats_tracker.scalar(**scalars) if self.config.log_agent_stats: @@ -344,7 +350,7 @@ def _ppo_update(self, data: dict[str, Any]) -> None: eps_clip=self.config.eps_clip, eps_clip_higher=self.config.eps_clip_higher, c_clip=self.config.c_clip, - behave_imp_weight_cap=self.config.behave_imp_weight_cap, + rejection_sampling=self.config.rejection_sampling, m2_threshold=self.m2_threshold, importance_sampling_level=self.config.importance_sampling_level, current_version=current_version, @@ -353,7 +359,6 @@ def _ppo_update(self, data: dict[str, Any]) -> None: 
sapo_tau_pos=self.config.sapo_tau_pos, sapo_tau_neg=self.config.sapo_tau_neg, use_decoupled_loss=self.config.use_decoupled_loss, - behave_imp_weight_mode=self.config.behave_imp_weight_mode, ), loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), ) @@ -407,7 +412,7 @@ def grpo_loss_fn( eps_clip: float, eps_clip_higher: float | None, c_clip: float | None, - behave_imp_weight_cap: float | None, + rejection_sampling: RejectionSamplingConfig | None = None, m2_threshold: float | None = None, importance_sampling_level: str = "token", current_version: int | None = None, @@ -416,7 +421,6 @@ def grpo_loss_fn( sapo_tau_pos: float = 1.0, sapo_tau_neg: float = 1.05, use_decoupled_loss: bool = False, - behave_imp_weight_mode: str = "token_mask", vocab_min_logits: torch.Tensor | None = None, vocab_max_logits: torch.Tensor | None = None, ): @@ -470,10 +474,9 @@ def grpo_loss_fn( loss_mask=loss_mask, c_clip=c_clip, proximal_logprobs=prox_logp, - behave_imp_weight_cap=behave_imp_weight_cap, + rejection_sampling=rejection_sampling, importance_sampling_level=importance_sampling_level, cu_seqlens=input_data.get("cu_seqlens"), - behave_imp_weight_mode=behave_imp_weight_mode, ) # Joint Distillation KL Loss @@ -540,6 +543,8 @@ def grpo_loss_fn( behave_approx_kl=stat["behave_approx_kl"], denominator="unclipped_behave_tokens", ) + if "filtered_fraction" in stat: + stats_tracker.scalar(rs_filtered_fraction=stat["filtered_fraction"]) if vocab_min_logits is not None and vocab_max_logits is not None: stats_tracker.stat( diff --git a/areal/utils/functional/__init__.py b/areal/utils/functional/__init__.py index f0563ff3b4..369eaf9149 100644 --- a/areal/utils/functional/__init__.py +++ b/areal/utils/functional/__init__.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from areal.utils.functional.functional import ( - compute_behave_imp_weight, + RejectionSamplingResult, + apply_rejection_sampling, masked_normalization, ppo_actor_loss_fn, ppo_critic_loss_fn, @@ -15,7 +16,8 @@ __all__ = [ # functional.py - "compute_behave_imp_weight", + "RejectionSamplingResult", + "apply_rejection_sampling", "masked_normalization", "ppo_actor_loss_fn", "ppo_critic_loss_fn", diff --git a/areal/utils/functional/functional.py b/areal/utils/functional/functional.py index d2c2db0f5b..a79aab6024 100644 --- a/areal/utils/functional/functional.py +++ b/areal/utils/functional/functional.py @@ -1,12 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import functools +from dataclasses import dataclass from typing import Any import numpy as np import torch import torch.distributed as dist +from areal.api.cli_args import RejectionSamplingConfig +from areal.utils.data import KLEstimator + @torch.no_grad() def masked_normalization( @@ -143,73 +147,281 @@ def _compute_sequence_level_ratio_and_advantages( return ratio, advantages -def compute_behave_imp_weight( +@dataclass +class RejectionSamplingResult: + """Result of rejection sampling, used by ppo_actor_loss_fn. + + Attributes: + loss_mask: Updated loss mask (mask mode) or original loss mask (clamp mode). + behave_imp_weight: Importance weight (clamped in clamp mode, raw in mask mode). + filtered_fraction: Fraction of valid tokens that were filtered/clamped (for logging). + """ + + loss_mask: torch.Tensor + behave_imp_weight: torch.Tensor + filtered_fraction: float + + +def _check_bounds( + metric: torch.Tensor, config: RejectionSamplingConfig +) -> torch.Tensor: + """Check if metric values are within configured bounds. + + Args: + metric: Per-token or per-sequence metric values. 
+ config: Rejection sampling configuration with upper and optional lower bounds. + + Returns: + Boolean tensor, True where metric is within bounds. + """ + if config.lower is not None: + return (metric >= config.lower) & (metric <= config.upper) + else: + return metric <= config.upper + + +def apply_rejection_sampling( proximal_logprobs: torch.Tensor, old_logprobs: torch.Tensor, loss_mask: torch.Tensor, cu_seqlens: torch.Tensor | None, - behave_imp_weight_mode: str, - behave_imp_weight_cap: float | None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute behavioural importance weight for decoupled loss correction. + config: RejectionSamplingConfig, +) -> RejectionSamplingResult: + """Apply rejection sampling based on divergence between proximal and behavior policy. + + Supports two action modes: + - 'mask': zero out loss_mask for tokens/sequences exceeding threshold (rejection) + - 'clamp': clamp importance weight to bounds for tokens/sequences exceeding + threshold (truncation, tokens still participate in gradient) Args: - proximal_logprobs: Recomputed log probabilities from reference model - old_logprobs: Log probabilities from inference engine - loss_mask: Boolean mask indicating valid tokens - cu_seqlens: Cumulative sequence lengths for packed sequences - behave_imp_weight_mode: Mode for importance weight correction - - 'token_truncate': clamp token ratio to [0, cap] - - 'token_mask': set token ratio to 0 where ratio > cap - - 'sequence_truncate': clamp sequence ratio to [0, cap] - - 'sequence_mask': set sequence ratio to 0 where ratio > cap - - 'disabled': skip importance weight correction - behave_imp_weight_cap: Cap value for importance weights + proximal_logprobs: Proximal policy log-probabilities, + shape [batch, seq_len] (2D padded) or [total_tokens] (1D packed). + old_logprobs: Behavior policy log-probabilities from inference engine, + same shape as proximal_logprobs. + loss_mask: Original loss mask (1 for valid tokens), same shape as proximal_logprobs. + cu_seqlens: Cumulative sequence lengths for 1D packed format. Shape: [batch_size + 1]. + Required when inputs are 1D. None for 2D padded inputs. + config: Configuration for rejection sampling. Returns: - Tuple of (behave_imp_weight, behave_approx_kl, behave_mask) + RejectionSamplingResult with updated loss_mask, behave_imp_weight, and filtered_fraction. """ - if behave_imp_weight_mode == "disabled": + # Step 0: Validate input shapes. + if proximal_logprobs.shape != old_logprobs.shape: raise ValueError( - "compute_behave_imp_weight should not be called with mode='disabled'. " - "The caller should guard this call with 'if behave_imp_weight_mode != \"disabled\"'." 
+ f"proximal_logprobs shape {proximal_logprobs.shape} != " + f"old_logprobs shape {old_logprobs.shape}" + ) + if proximal_logprobs.shape != loss_mask.shape: + raise ValueError( + f"proximal_logprobs shape {proximal_logprobs.shape} != " + f"loss_mask shape {loss_mask.shape}" + ) + if proximal_logprobs.ndim not in (1, 2): + raise ValueError( + f"Expected 1D (packed) or 2D (padded) tensors, " + f"got ndim={proximal_logprobs.ndim}" ) - is_sequence_level = "sequence" in behave_imp_weight_mode - behave_approx_kl = proximal_logprobs - old_logprobs - behave_imp_weight_log_ratio = behave_approx_kl - - if is_sequence_level: - # Compute sequence-level geometric mean importance weights - dummy_advantages = torch.zeros_like(behave_imp_weight_log_ratio) - behave_imp_weight_seq, _ = _compute_sequence_level_ratio_and_advantages( - behave_imp_weight_log_ratio, - dummy_advantages, - loss_mask, - cu_seqlens, + # Step 1: Compute log ratio = log(π_proximal / π_behave) + # Upcast operands to fp32 before subtraction to avoid precision loss in bf16/fp16. + log_ratio = proximal_logprobs.detach().float() - old_logprobs.detach().float() + # Sanitize non-finite values (e.g. -inf - (-inf) = NaN) to prevent NaN propagation. + log_ratio = torch.where(torch.isfinite(log_ratio), log_ratio, 0.0) + + # Step 2: Compute metric value (reuse existing KLEstimator sign conventions) + if config.metric == "ratio": + # Direct ratio π_proximal / π_behave + metric = torch.exp(log_ratio) + elif config.metric in ("kl_k1", "kl_k2", "kl_k3"): + # Use existing KLEstimator (note: _compute_approx_kl takes log_probs, log_probs_base) + estimator_name = config.metric.replace("kl_", "") # "k1", "k2", "k3" + metric = KLEstimator._compute_approx_kl( + log_probs=proximal_logprobs.detach(), + log_probs_base=old_logprobs.detach(), + kl_estimator=estimator_name, + apply_clamp=False, # Don't clamp; threshold check handles bounds ) - behave_imp_weight = behave_imp_weight_seq else: - # Token-level importance weights (default) - behave_imp_weight = behave_imp_weight_log_ratio.exp() + raise ValueError(f"Unknown metric: {config.metric}") + + # Step 3: Compute behave_imp_weight (needed for both modes) + behave_imp_weight = torch.exp(log_ratio) + # Save original weight before any clamping, to compute clamped fraction later. + original_weight = behave_imp_weight + + # Step 4: Aggregate and filter + # + # For ratio metric, aggregate in log space (geometric mean) to match GSPO + # semantics and avoid the "length trap" where arithmetic mean inflates + # sequence-level ratios. For KL metrics, aggregate in metric space + # (arithmetic) since KL divergence is additive. + _use_log_agg = config.metric == "ratio" + + if config.level == "sequence": + # Pre-compute sequence indexing (shared by filtering and weight broadcast). + if loss_mask.ndim == 1: + # 1D packed format: use cu_seqlens + if cu_seqlens is None: + raise ValueError( + "cu_seqlens is required for 1D packed tensors " + "in sequence-level filtering." + ) + batch_size = cu_seqlens.shape[0] - 1 + seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + sequence_idx = torch.arange( + batch_size, device=metric.device + ).repeat_interleave(seq_lengths) + + # For ratio metric: aggregate log_ratio (geometric); else: aggregate metric (arithmetic). 
+ agg_values = log_ratio if _use_log_agg else metric + masked_agg = torch.where(loss_mask.bool(), agg_values, 0.0) + valid_count_per_seq = ( + torch.zeros(batch_size, device=loss_mask.device, dtype=torch.int32) + .scatter_add_(0, sequence_idx, loss_mask.int()) + .clamp(min=1) + ) - # Apply cap (truncate or mask) based on mode - if behave_imp_weight_cap is not None: - if "truncate" in behave_imp_weight_mode: + # Ratio metric + sequence level: use geometric mean as uniform weight + # for all tokens (matches old sequence_mask/sequence_truncate semantics). + if _use_log_agg: + # masked_agg is already log_ratio masked by loss_mask (computed above). + seq_log_sum = torch.zeros( + batch_size, device=log_ratio.device, dtype=log_ratio.dtype + ).scatter_add_(0, sequence_idx, masked_agg) + seq_log_mean = seq_log_sum / valid_count_per_seq.to(log_ratio.dtype) + behave_imp_weight = torch.exp(seq_log_mean)[sequence_idx] + original_weight = behave_imp_weight + + if config.agg == "sum": + seq_agg = torch.zeros( + batch_size, device=metric.device, dtype=agg_values.dtype + ).scatter_add_(0, sequence_idx, masked_agg) + elif config.agg == "mean": + seq_agg_sum = torch.zeros( + batch_size, device=metric.device, dtype=agg_values.dtype + ).scatter_add_(0, sequence_idx, masked_agg) + seq_agg = seq_agg_sum / valid_count_per_seq.to(agg_values.dtype) + elif config.agg == "max": + agg_for_max = agg_values.masked_fill(~loss_mask.bool(), float("-inf")) + seq_agg = torch.full( + (batch_size,), + float("-inf"), + device=metric.device, + dtype=agg_values.dtype, + ).scatter_reduce_(0, sequence_idx, agg_for_max, reduce="amax") + # All-masked sequences stay -inf; treat them as in-bounds (no valid + # tokens to filter, and their loss_mask is already all-zero). + # Recompute from raw counts to detect true zero. + raw_valid = torch.zeros( + batch_size, device=loss_mask.device, dtype=torch.int32 + ).scatter_add_(0, sequence_idx, loss_mask.int()) + all_masked = raw_valid == 0 + seq_agg = torch.where(all_masked, torch.zeros_like(seq_agg), seq_agg) + else: + raise ValueError(f"Unknown agg method: {config.agg}") + + # Convert back to metric space for threshold comparison. + seq_metric = torch.exp(seq_agg) if _use_log_agg else seq_agg + + # Check each sequence against bounds + in_bounds_per_seq = _check_bounds(seq_metric, config) + + if config.action == "mask": + # Broadcast back to token level, filter entire sequence + in_bounds = in_bounds_per_seq[sequence_idx] + else: + # clamp mode: clamp tokens in out-of-bounds sequences + out_of_bounds = (~in_bounds_per_seq)[sequence_idx] + behave_imp_weight = torch.where( + out_of_bounds, + behave_imp_weight.clamp( + min=config.lower if config.lower is not None else 0.0, + max=config.upper, + ), + behave_imp_weight, + ) + else: + # 2D padded format + agg_values = log_ratio if _use_log_agg else metric + masked_agg = torch.where(loss_mask.bool(), agg_values, 0.0) + valid_count = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1) + + # Ratio metric + sequence level: geometric mean as uniform weight. 
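+            # (seq_log_mean below has shape [batch, 1]; expand_as broadcasts
+            # the per-sequence geometric mean to every token in the row.)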
+ if _use_log_agg: + seq_log_mean = masked_agg.sum(dim=-1, keepdim=True) / valid_count + behave_imp_weight = torch.exp(seq_log_mean).expand_as(log_ratio) + original_weight = behave_imp_weight + + if config.agg == "sum": + seq_agg = masked_agg.sum(dim=-1, keepdim=True) + elif config.agg == "mean": + seq_agg = masked_agg.sum(dim=-1, keepdim=True) / valid_count + elif config.agg == "max": + agg_for_max = agg_values.masked_fill(~loss_mask.bool(), float("-inf")) + seq_agg = agg_for_max.max(dim=-1, keepdim=True)[0] + # All-masked sequences stay -inf; treat them as in-bounds (no valid + # tokens to filter, and their loss_mask is already all-zero). + all_masked = loss_mask.sum(dim=-1, keepdim=True) == 0 + seq_agg = torch.where(all_masked, torch.zeros_like(seq_agg), seq_agg) + else: + raise ValueError(f"Unknown agg method: {config.agg}") + + # Convert back to metric space for threshold comparison. + seq_metric = torch.exp(seq_agg) if _use_log_agg else seq_agg + + if config.action == "mask": + in_bounds = _check_bounds(seq_metric, config).expand_as(loss_mask) + else: + # clamp mode: clamp tokens in out-of-bounds sequences + out_of_bounds = (~_check_bounds(seq_metric, config)).expand_as( + loss_mask + ) + behave_imp_weight = torch.where( + out_of_bounds, + behave_imp_weight.clamp( + min=config.lower if config.lower is not None else 0.0, + max=config.upper, + ), + behave_imp_weight, + ) + else: + # Token level + if config.action == "mask": + in_bounds = _check_bounds(metric, config) + else: + # clamp mode: directly clamp importance weight behave_imp_weight = behave_imp_weight.clamp( - min=0.0, max=behave_imp_weight_cap - ) - else: # mask - behave_imp_weight = torch.where( - behave_imp_weight > behave_imp_weight_cap, 0.0, behave_imp_weight + min=config.lower if config.lower is not None else 0.0, + max=config.upper, ) - # Apply loss_mask - behave_imp_weight = torch.where(loss_mask, behave_imp_weight, 0.0) - behave_mask = (behave_imp_weight > 0).logical_and(loss_mask) - behave_approx_kl = torch.where(behave_mask, behave_approx_kl, 0.0) + # Step 5: Update loss_mask or keep it based on action mode + if config.action == "mask": + candidates = loss_mask.bool() + updated_mask = (candidates & in_bounds).to(loss_mask.dtype) + filtered_count = (candidates & ~in_bounds).sum().item() + total_count = candidates.sum().item() + filtered_fraction = filtered_count / max(total_count, 1) + else: + # clamp mode: loss_mask unchanged + updated_mask = loss_mask + # Report fraction of clamped tokens (for logging) + clamped_count = ( + (loss_mask.bool() & (original_weight != behave_imp_weight)).sum().item() + ) + total_count = loss_mask.bool().sum().item() + filtered_fraction = clamped_count / max(total_count, 1) + + # Apply loss_mask to behave_imp_weight + behave_imp_weight = torch.where(updated_mask.bool(), behave_imp_weight, 0.0) - return behave_imp_weight, behave_approx_kl, behave_mask + return RejectionSamplingResult( + loss_mask=updated_mask, + behave_imp_weight=behave_imp_weight, + filtered_fraction=filtered_fraction, + ) def ppo_actor_loss_fn( @@ -221,12 +433,18 @@ def ppo_actor_loss_fn( loss_mask: torch.Tensor, eps_clip_higher: float | None = None, c_clip: float | None = None, - behave_imp_weight_cap: float | None = None, + rejection_sampling: RejectionSamplingConfig | None = None, importance_sampling_level: str = "token", cu_seqlens: torch.Tensor | None = None, - behave_imp_weight_mode: str = "token_mask", ) -> tuple[torch.Tensor, dict]: - """ + """PPO actor loss function with optional rejection sampling. 
+
+    The ``rejection_sampling`` parameter replaces the removed
+    ``behave_imp_weight_cap`` / ``behave_imp_weight_mode``.
+
+    - ``action='mask'``: modifies loss_mask before loss computation (rejection)
+    - ``action='clamp'``: clamps importance weight to bounds (truncation)
+
     When decoupled loss is disabled:
     1. if recompute logp, both old_logprobs and proximal_logprobs are recomputed logp;
     2. if no recomputation, both old_logprobs and proximal_logprobs are produced by the inference backend.
@@ -234,23 +452,50 @@
     When decoupled loss is enabled, proximal_logprobs is the recomputed logp,
     old_logprobs is produced by the inference engine.
 
+    Note: ``importance_sampling_level`` controls aggregation of the PPO ratio
+    (π_θ/π_proximal), i.e. GSPO, and is orthogonal to ``rejection_sampling.level``,
+    which controls the granularity of staleness filtering (π_proximal/π_behave).
+
     Args:
+        logprobs: Current policy log-probabilities (π_θ).
+        proximal_logprobs: Proximal policy log-probabilities (π_proximal).
+        old_logprobs: Behavior policy log-probabilities from inference (π_behave).
+        advantages: Per-token advantage estimates.
+        eps_clip: PPO clipping factor for policy ratio.
+        loss_mask: Mask for valid tokens (1 = valid).
+        eps_clip_higher: Upper clipping factor (decoupled clipping). None = use eps_clip.
+        c_clip: Dual clipping factor, must be > 1.0. None disables dual clipping.
+        rejection_sampling: Rejection sampling configuration. None disables filtering.
         importance_sampling_level: Level at which to compute importance sampling ratios.
-            - 'token': Per-token ratios
+            - 'token': Per-token ratios (standard PPO)
             - 'sequence': Sequence-level geometric mean of per-token ratios (GSPO)
         cu_seqlens: Cumulative sequence lengths for packed sequences (1D tensors).
             Required when inputs are 1D and importance_sampling_level='sequence'.
             Shape: [batch_size + 1], where cu_seqlens[i] marks the start of sequence i.
            Not needed for 2D padded inputs (sequences identified by batch dimension).
-        behave_imp_weight_mode: Mode for importance weight correction (mask or truncate).
-            - 'token_truncate': clamp token ratio to [0, cap]
-            - 'token_mask': set token ratio to 0 where ratio > cap
-            - 'sequence_truncate': clamp sequence ratio to [0, cap]
-            - 'sequence_mask': set sequence ratio to 0 where ratio > cap
-            - 'disabled': skip importance weight correction
     """
+    # Save original count BEFORE rejection sampling may modify loss_mask.
+    # This keeps the denominator consistent with loss_weight_fn in actor.py,
+    # which always uses the original loss_mask from input_data. Without this,
+    # mask mode would inflate per-token gradients by N_original / N_kept.
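+    # (Worked example: if mask mode rejects 4 of 10 valid tokens, dividing by the
+    # post-rejection count 6 instead of the original 10 would scale every
+    # surviving token's gradient by 10/6 ≈ 1.67.)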
loss_mask_count = loss_mask.count_nonzero() or 1 + # === Apply rejection sampling (replaces old compute_behave_imp_weight) === + if rejection_sampling is not None: + rs_result = apply_rejection_sampling( + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, + loss_mask=loss_mask, + cu_seqlens=cu_seqlens, + config=rejection_sampling, + ) + # mask mode updates loss_mask; clamp mode keeps it unchanged + loss_mask = rs_result.loss_mask + behave_imp_weight = rs_result.behave_imp_weight + filtered_fraction = rs_result.filtered_fraction + else: + filtered_fraction = 0.0 + if importance_sampling_level == "sequence": # GSPO: Compute sequence-level geometric mean of probability ratios log_ratio = logprobs - proximal_logprobs @@ -284,17 +529,11 @@ def ppo_actor_loss_fn( else: dual_clip_mask = torch.zeros_like(clip_mask) - # Compute behavioural importance weight only when not disabled - # When disabled, pg_loss remains unchanged (no behavioural correction applied) - if behave_imp_weight_mode != "disabled": - behave_imp_weight, behave_approx_kl, behave_mask = compute_behave_imp_weight( - proximal_logprobs=proximal_logprobs, - old_logprobs=old_logprobs, - loss_mask=loss_mask, - cu_seqlens=cu_seqlens, - behave_imp_weight_mode=behave_imp_weight_mode, - behave_imp_weight_cap=behave_imp_weight_cap, - ) + # Apply behavioural importance weight from rejection sampling + if rejection_sampling is not None: + behave_approx_kl = proximal_logprobs.detach() - old_logprobs.detach() + behave_mask = (behave_imp_weight > 0).logical_and(loss_mask.bool()) + behave_approx_kl = torch.where(behave_mask, behave_approx_kl, 0.0) pg_loss = pg_loss * behave_imp_weight logging_loss = pg_loss.detach() @@ -308,11 +547,12 @@ def ppo_actor_loss_fn( clip_mask=clip_mask, dual_clip_mask=dual_clip_mask, ) - if proximal_logprobs is not None and behave_imp_weight_mode != "disabled": + if rejection_sampling is not None: stat.update( behave_approx_kl=behave_approx_kl.detach(), behave_imp_weight=behave_imp_weight.detach(), behave_mask=behave_mask, + filtered_fraction=filtered_fraction, ) return pg_loss, stat diff --git a/docs/en/best_practices/algo_perf.md b/docs/en/best_practices/algo_perf.md index c67f33e022..a6c61989d6 100644 --- a/docs/en/best_practices/algo_perf.md +++ b/docs/en/best_practices/algo_perf.md @@ -98,7 +98,7 @@ A, \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{proximal}}}, 1-\epsilon, **Troubleshooting `behave_imp_weight` deviations:** -- Ensure `behave_imp_weight_cap` is set (recommended value: 5). +- Ensure `rejection_sampling` is configured (e.g., `rejection_sampling: {level: token, action: mask, metric: ratio, upper: 5.0}`). - If deviation persists, reduce `max_head_offpolicyness` to decrease sample staleness. ### Sequence Length Metrics diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 9ccf70e141..03134d717b 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -78,6 +78,7 @@ For detailed examples, see the experiment configurations in the `examples/` dire - [MegatronEngine Configuration](section-megatron-engine) - [OpenAIProxy Configuration](section-open-ai-proxy) - [PerfTracer Configuration](section-perf-tracer) +- [RejectionSampling Configuration](section-rejection-sampling) - [Scheduler Configuration](section-scheduler) - [Scheduling Specification](section-scheduling) - [SchedulingStrategy](section-scheduling-strategy) @@ -332,72 +333,71 @@ Configuration for model optimization during training. Configuration for PPO actor model, a subclass of a TrainEngine. 
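+
+A minimal sketch of how the new `rejection_sampling` field is configured and what
+its divergence metrics compute (the import path and tensor values are illustrative
+assumptions; the formulas follow the
+[RejectionSampling Configuration](section-rejection-sampling) table documented below):
+
+```python
+import torch
+
+from areal.api.cli_args import RejectionSamplingConfig  # assumed module path
+
+cfg = RejectionSamplingConfig(level="token", action="mask", metric="ratio", upper=5.0)
+
+# r = pi_proximal / pi_behave, computed from log-probability differences.
+log_r = torch.tensor([0.1, -0.3, 2.0])  # proximal_logprobs - old_logprobs (toy values)
+r = log_r.exp()
+
+kl_k1 = log_r            # unbiased forward-KL estimate; can be negative
+kl_k2 = 0.5 * log_r**2   # non-negative quadratic approximation
+kl_k3 = r - log_r - 1    # non-negative exact forward-KL estimator
+
+# mask mode with metric='ratio': tokens with r > cfg.upper are rejected.
+keep = r <= cfg.upper    # tensor([True, True, False]) for these toy values
+```
+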
-| Parameter | Type | Default | Description | -| --------------------------- | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 
'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | -| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | -| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | -| `behave_imp_weight_cap` | float \| None | `5.0` | Filter out tokens/sequences where behave_imp_weight exceeds this cap when computing loss. Only effective when use_decoupled_loss=True (decoupled/async training). Must be > 1.0 when mode is not 'disabled'. Mode controlled by behave_imp_weight_mode (mask/truncate/disabled). | -| `behave_imp_weight_mode` | string | `"token_mask"` | Mode for importance weight filtering. Only effective when use_decoupled_loss=True (decoupled/async training). 
'token_truncate': clamp token ratio to \[0, cap\]. 'token_mask': set token ratio to 0 where ratio > cap. 'sequence_truncate': clamp sequence ratio to \[0, cap\]. 'sequence_mask': set sequence ratio to 0 where ratio > cap. 'disabled': disable importance weight correction. **Choices:** `token_truncate`, `token_mask`, `sequence_truncate`, `sequence_mask`, `disabled` | -| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | -| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. **Choices:** `recompute`, `loglinear`, `metrics` | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| Parameter | Type | Default | Description | +| --------------------------- | --------------------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. 
| +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | +| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | +| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | +| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | +| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | +| `reward_scaling` | float | `1.0` | Reward scaling factor | +| `reward_bias` | float | `0.0` | Reward bias | +| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | +| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. 
| +| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | +| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| `discount` | float | `1.0` | Discount factor for future rewards | +| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | +| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | +| `kl_ctl` | float | `0.1` | KL divergence coefficient | +| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | +| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | +| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | +| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | +| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | +| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | +| `rejection_sampling` | [`RejectionSamplingConfig`](section-rejection-sampling) \| None | `None` | Rejection sampling configuration for filtering stale samples. None disables filtering (equivalent to old behave_imp_weight_mode='disabled'). Only effective when use_decoupled_loss=True. | +| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | +| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. **Choices:** `recompute`, `loglinear`, `metrics` | +| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | +| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | +| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | (section-ppo-critic)= @@ -1003,6 +1003,47 @@ Configuration for perf tracer emission. | `profile_steps` | list of integer \| None | `None` | List of step numbers at which to capture detailed profiling traces. If None, no detailed profiling traces are captured. | | `session_tracer` | [`SessionTracerConfig`](section-session-tracer) \| None | `None` | Session tracing configuration. | +(section-rejection-sampling)= + +## RejectionSampling Configuration + +Unified configuration for sample filtering based on policy divergence. + +Filters tokens/sequences where the divergence between proximal policy and behavior +policy exceeds a threshold, via two action modes: - 'mask': zero out loss_mask +(rejection, exclude from gradient) - 'clamp': clamp importance weight to bounds +(truncation, bounded gradient) + +``` +Supports direct ratio bounds and KL divergence estimators (K1/K2/K3), +at both token-level and sequence-level granularity. + +Replaces the removed ``behave_imp_weight_cap`` and ``behave_imp_weight_mode``. 
+ +Attributes: + level: Filtering granularity ('token' or 'sequence'). When ``level='sequence'`` + and ``metric='ratio'``, both the filtering decision and the correction + weight (behave_imp_weight) use the sequence-level geometric mean, + matching the old ``sequence_mask``/``sequence_truncate`` semantics. + action: Action mode ('mask' or 'clamp'). + metric: Divergence metric ('ratio', 'kl_k1', 'kl_k2', 'kl_k3'). + agg: Aggregation method for sequence-level ('sum', 'mean', 'max'). + For 'ratio' metric, aggregation is performed in log space (geometric + mean/sum) to avoid the "length trap" and match GSPO semantics. + For KL metrics, aggregation is arithmetic. + upper: Upper bound for filtering. + lower: Lower bound for filtering (optional). +``` + +| Parameter | Type | Default | Description | +| --------- | ------------- | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `level` | string | `"token"` | Filtering granularity. 'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | +| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | +| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | +| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | +| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | +| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. 
| + (section-scheduler)= ## Scheduler Configuration @@ -1072,71 +1113,70 @@ Configuration for per-session lifecycle tracing. Configuration class: TeacherConfig -| Parameter | Type | Default | Description | -| --------------------------- | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. 
Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | -| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | -| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | -| `behave_imp_weight_cap` | float \| None | `5.0` | Filter out tokens/sequences where behave_imp_weight exceeds this cap when computing loss. Only effective when use_decoupled_loss=True (decoupled/async training). Must be > 1.0 when mode is not 'disabled'. Mode controlled by behave_imp_weight_mode (mask/truncate/disabled). 
| -| `behave_imp_weight_mode` | string | `"token_mask"` | Mode for importance weight filtering. Only effective when use_decoupled_loss=True (decoupled/async training). 'token_truncate': clamp token ratio to \[0, cap\]. 'token_mask': set token ratio to 0 where ratio > cap. 'sequence_truncate': clamp sequence ratio to \[0, cap\]. 'sequence_mask': set sequence ratio to 0 where ratio > cap. 'disabled': disable importance weight correction. **Choices:** `token_truncate`, `token_mask`, `sequence_truncate`, `sequence_mask`, `disabled` | -| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | -| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. **Choices:** `recompute`, `loglinear`, `metrics` | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | -| `rl_loss_weight` | float | `1.0` | RL loss weight | -| `distill_loss_weight` | float | `0.005` | Distillation loss weight | +| Parameter | Type | Default | Description | +| --------------------------- | --------------------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. 
| +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | +| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | +| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | +| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. 
| +| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | +| `reward_scaling` | float | `1.0` | Reward scaling factor | +| `reward_bias` | float | `0.0` | Reward bias | +| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | +| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | +| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | +| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| `discount` | float | `1.0` | Discount factor for future rewards | +| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | +| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | +| `kl_ctl` | float | `0.1` | KL divergence coefficient | +| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | +| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | +| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | +| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | +| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | +| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | +| `rejection_sampling` | [`RejectionSamplingConfig`](section-rejection-sampling) \| None | `None` | Rejection sampling configuration for filtering stale samples. None disables filtering (equivalent to old behave_imp_weight_mode='disabled'). Only effective when use_decoupled_loss=True. | +| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | +| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. 
**Choices:** `recompute`, `loglinear`, `metrics` | +| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | +| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | +| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| `rl_loss_weight` | float | `1.0` | RL loss weight | +| `distill_loss_weight` | float | `0.005` | Distillation loss weight | diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index b382cd26ba..a4ab9238e0 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -76,6 +76,7 @@ python3 train.py --config path/to/config.yaml actor.lr=1e-4 seed=42 - [MegatronEngine Configuration](section-megatron-engine) - [OpenAIProxy Configuration](section-open-ai-proxy) - [PerfTracer Configuration](section-perf-tracer) +- [RejectionSampling Configuration](section-rejection-sampling) - [Scheduler Configuration](section-scheduler) - [Scheduling Specification](section-scheduling) - [SchedulingStrategy](section-scheduling-strategy) @@ -330,72 +331,71 @@ Configuration for model optimization during training. Configuration for PPO actor model, a subclass of a TrainEngine. -| Parameter | Type | Default | Description | -| --------------------------- | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. 
**Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. 
| -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | -| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | -| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | -| `behave_imp_weight_cap` | float \| None | `5.0` | Filter out tokens/sequences where behave_imp_weight exceeds this cap when computing loss. Only effective when use_decoupled_loss=True (decoupled/async training). Must be > 1.0 when mode is not 'disabled'. Mode controlled by behave_imp_weight_mode (mask/truncate/disabled). | -| `behave_imp_weight_mode` | string | `"token_mask"` | Mode for importance weight filtering. Only effective when use_decoupled_loss=True (decoupled/async training). 'token_truncate': clamp token ratio to \[0, cap\]. 'token_mask': set token ratio to 0 where ratio > cap. 'sequence_truncate': clamp sequence ratio to \[0, cap\]. 'sequence_mask': set sequence ratio to 0 where ratio > cap. 'disabled': disable importance weight correction. **Choices:** `token_truncate`, `token_mask`, `sequence_truncate`, `sequence_mask`, `disabled` | -| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | -| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. 
**Choices:** `recompute`, `loglinear`, `metrics` | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| Parameter | Type | Default | Description | +| --------------------------- | --------------------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. 
Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | +| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | +| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | +| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | +| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | +| `reward_scaling` | float | `1.0` | Reward scaling factor | +| `reward_bias` | float | `0.0` | Reward bias | +| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | +| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | +| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | +| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| `discount` | float | `1.0` | Discount factor for future rewards | +| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | +| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | +| `kl_ctl` | float | `0.1` | KL divergence coefficient | +| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | +| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | +| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | +| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | +| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | +| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | +| `rejection_sampling` | [`RejectionSamplingConfig`](section-rejection-sampling) \| None | `None` | Rejection sampling configuration for filtering stale samples. 
None disables filtering (equivalent to old behave_imp_weight_mode='disabled'). Only effective when use_decoupled_loss=True. |
+| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` |
+| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. **Choices:** `recompute`, `loglinear`, `metrics` |
+| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories |
+| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics |
+| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate |

(section-ppo-critic)=

@@ -1001,6 +1001,47 @@ Configuration for perf tracer emission.

| `profile_steps` | list of integer \| None | `None` | List of step numbers at which to capture detailed profiling traces. If None, no detailed profiling traces are captured. |
| `session_tracer` | [`SessionTracerConfig`](section-session-tracer) \| None | `None` | Session tracing configuration. |

+(section-rejection-sampling)=
+
+## RejectionSampling Configuration
+
+Unified configuration for sample filtering based on policy divergence.
+
+Filters tokens/sequences where the divergence between proximal policy and behavior
+policy exceeds a threshold, via two action modes:
+
+- 'mask': zero out loss_mask (rejection, exclude from gradient)
+- 'clamp': clamp importance weight to bounds (truncation, bounded gradient)
+
+Supports direct ratio bounds and KL divergence estimators (K1/K2/K3),
+at both token-level and sequence-level granularity.
+
+Replaces the removed `behave_imp_weight_cap` and `behave_imp_weight_mode`.
+
+Attributes:
+
+- `level`: Filtering granularity ('token' or 'sequence'). When `level='sequence'`
+  and `metric='ratio'`, both the filtering decision and the correction weight
+  (behave_imp_weight) use the sequence-level geometric mean, matching the old
+  `sequence_mask`/`sequence_truncate` semantics.
+- `action`: Action mode ('mask' or 'clamp').
+- `metric`: Divergence metric ('ratio', 'kl_k1', 'kl_k2', 'kl_k3').
+- `agg`: Aggregation method for sequence-level filtering ('sum', 'mean', 'max').
+  For the 'ratio' metric, aggregation is performed in log space (geometric
+  mean/sum) to avoid the "length trap" and match GSPO semantics. For KL
+  metrics, aggregation is arithmetic.
+- `upper`: Upper bound for filtering.
+- `lower`: Lower bound for filtering (optional).
+
+| Parameter | Type | Default | Description |
+| --------- | ------------- | --------- | ----------- |
+| `level` | string | `"token"` | Filtering granularity.
'token': per-token filtering (each token judged independently). 'sequence': per-sequence filtering (all tokens in a sequence share the same fate). When metric='ratio', both the filtering decision and the correction weight (behave_imp_weight) operate at sequence level using the geometric mean. **Choices:** `token`, `sequence` | +| `action` | string | `"mask"` | Action to take when metric exceeds threshold. 'mask': zero out loss_mask for filtered tokens/sequences (rejection, completely excludes from gradient computation). 'clamp': clamp importance weight to \[lower, upper\] bounds (truncation, tokens still participate in gradient but with bounded weight). **Choices:** `mask`, `clamp` | +| `metric` | string | `"ratio"` | Divergence metric for filtering. 'ratio': direct importance ratio π_proximal/π_behave. 'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). 'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. 'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator. **Choices:** `ratio`, `kl_k1`, `kl_k2`, `kl_k3` | +| `agg` | string | `"mean"` | Aggregation method for sequence-level filtering. Only used when level='sequence'. For 'ratio' metric, aggregation is in log space: 'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean (length-invariant, consistent with GSPO). For KL metrics, aggregation is arithmetic: 'sum' = sum(kl_i), 'mean' = mean(kl_i). 'max': max of per-token metric values (most conservative). **Choices:** `sum`, `mean`, `max` | +| `upper` | float | `5.0` | Upper bound for filtering. Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0. | +| `lower` | float \| None | `None` | Lower bound for filtering (optional). None means no lower bound. For 'ratio' metric: typical value is 0.5 (filter out tokens where policy probability dropped significantly). Must be > 0. For 'kl_k1' metric: can be used to filter negative KL estimates. | + (section-scheduler)= ## Scheduler Configuration @@ -1070,71 +1111,70 @@ Configuration for per-session lifecycle tracing. Configuration class: TeacherConfig -| Parameter | Type | Default | Description | -| --------------------------- | --------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. 
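For reference, here is a minimal usage sketch of the `rejection_sampling` path added by this patch, condensed from `tests/test_rejection_sampling.py` further below; the call signature of `apply_rejection_sampling` and the result fields (`loss_mask`, `behave_imp_weight`, `filtered_fraction`) are taken from those tests rather than from separate API docs:

```python
import torch

from areal.api.cli_args import RejectionSamplingConfig
from areal.utils.functional import apply_rejection_sampling

# Equivalent of the legacy behave_imp_weight_cap=5.0 with
# behave_imp_weight_mode="token_mask": reject tokens whose importance
# ratio pi_proximal / pi_behave exceeds 5.0.
config = RejectionSamplingConfig(level="token", metric="ratio", upper=5.0)

# One sequence, three tokens; per-token ratios are 1.0, exp(2) ~ 7.39, exp(-0.5) ~ 0.61.
proximal_logprobs = torch.tensor([[0.0, 2.0, -0.5]])
old_logprobs = torch.zeros(1, 3)
loss_mask = torch.ones(1, 3)

result = apply_rejection_sampling(
    proximal_logprobs,
    old_logprobs,
    loss_mask,
    cu_seqlens=None,  # 2D padded inputs; pass cu_seqlens for 1D packed data
    config=config,
)

# Only the middle token (ratio ~ 7.39 > 5.0) is rejected.
assert result.loss_mask[0, 1] == 0.0
assert result.loss_mask[0, 0] == 1.0 and result.loss_mask[0, 2] == 1.0
assert result.filtered_fraction > 0
```

In the training YAMLs touched below, the same change is mechanical: `behave_imp_weight_cap: 5.0` under `actor` becomes a `rejection_sampling` block with `metric: ratio` and `upper: 5.0`.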
| -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Parameter data type. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. 
| -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | -| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | -| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | -| `behave_imp_weight_cap` | float \| None | `5.0` | Filter out tokens/sequences where behave_imp_weight exceeds this cap when computing loss. Only effective when use_decoupled_loss=True (decoupled/async training). Must be > 1.0 when mode is not 'disabled'. Mode controlled by behave_imp_weight_mode (mask/truncate/disabled). | -| `behave_imp_weight_mode` | string | `"token_mask"` | Mode for importance weight filtering. Only effective when use_decoupled_loss=True (decoupled/async training). 'token_truncate': clamp token ratio to \[0, cap\]. 'token_mask': set token ratio to 0 where ratio > cap. 'sequence_truncate': clamp sequence ratio to \[0, cap\]. 'sequence_mask': set sequence ratio to 0 where ratio > cap. 'disabled': disable importance weight correction. **Choices:** `token_truncate`, `token_mask`, `sequence_truncate`, `sequence_mask`, `disabled` | -| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | -| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. 
**Choices:** `recompute`, `loglinear`, `metrics` | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | -| `rl_loss_weight` | float | `1.0` | RL loss weight | -| `distill_loss_weight` | float | `0.005` | Distillation loss weight | +| Parameter | Type | Default | Description | +| --------------------------- | --------------------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `experiment_name` | string | **Required** | - | +| `trial_name` | string | **Required** | - | +| `path` | string | `""` | Path to HuggingFace checkpoint | +| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | +| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | +| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | +| `is_critic` | boolean | `False` | Whether to use a critic/reward model | +| `temperature` | float | `1.0` | Temperature during generation. | +| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | +| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | +| `disable_dropout` | boolean | `False` | Disable dropout layers during training | +| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | +| `dtype` | string | `"bfloat16"` | Parameter data type. | +| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | +| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | +| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | +| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | +| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | +| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | +| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | +| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | +| `lora_rank` | integer | `32` | lora rank | +| `lora_alpha` | integer | `16` | lora alpha | +| `target_modules` | list of string | **Required** | lora target_modules. | +| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | +| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. 
| +| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | +| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | +| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | +| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | +| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | +| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | +| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | +| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | +| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | +| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | +| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | +| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | +| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | +| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | +| `reward_scaling` | float | `1.0` | Reward scaling factor | +| `reward_bias` | float | `0.0` | Reward bias | +| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | +| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | +| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | +| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | +| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | +| `discount` | float | `1.0` | Discount factor for future rewards | +| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | +| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | +| `kl_ctl` | float | `0.1` | KL divergence coefficient | +| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | +| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | +| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | +| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | +| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | +| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. 
| +| `rejection_sampling` | [`RejectionSamplingConfig`](section-rejection-sampling) \| None | `None` | Rejection sampling configuration for filtering stale samples. None disables filtering (equivalent to old behave_imp_weight_mode='disabled'). Only effective when use_decoupled_loss=True. | +| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | +| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. **Choices:** `recompute`, `loglinear`, `metrics` | +| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | +| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | +| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | +| `rl_loss_weight` | float | `1.0` | RL loss weight | +| `distill_loss_weight` | float | `0.005` | Distillation loss weight | diff --git a/examples/agent_workflow/config.yaml b/examples/agent_workflow/config.yaml index 4cf35018ac..c09278170e 100644 --- a/examples/agent_workflow/config.yaml +++ b/examples/agent_workflow/config.yaml @@ -78,7 +78,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: null adv_norm: mean_level: batch diff --git a/examples/agent_workflow/config_claude.yaml b/examples/agent_workflow/config_claude.yaml index 7570876be7..1531c37606 100644 --- a/examples/agent_workflow/config_claude.yaml +++ b/examples/agent_workflow/config_claude.yaml @@ -79,7 +79,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: null adv_norm: mean_level: batch diff --git a/examples/camel/config.yaml b/examples/camel/config.yaml index ce5d9f1586..1c720f3fe4 100644 --- a/examples/camel/config.yaml +++ b/examples/camel/config.yaml @@ -71,7 +71,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 adv_norm: mean_level: batch std_level: batch diff --git a/examples/countdown/train_config.yaml b/examples/countdown/train_config.yaml index 598bbe4620..7694d1ec17 100644 --- a/examples/countdown/train_config.yaml +++ b/examples/countdown/train_config.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/distillation/gsm8k_grpo_distill.yaml b/examples/distillation/gsm8k_grpo_distill.yaml index 9c666bae9d..564537de30 100644 --- a/examples/distillation/gsm8k_grpo_distill.yaml +++ b/examples/distillation/gsm8k_grpo_distill.yaml @@ -67,7 +67,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: 
ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/experimental/prox_approx/gsm8k_grpo_prox_approx.yaml b/examples/experimental/prox_approx/gsm8k_grpo_prox_approx.yaml index 13cdf14d0a..e0be7fc388 100644 --- a/examples/experimental/prox_approx/gsm8k_grpo_prox_approx.yaml +++ b/examples/experimental/prox_approx/gsm8k_grpo_prox_approx.yaml @@ -67,7 +67,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: false use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/experimental/prox_approx/gsm8k_grpo_prox_approx_eval.yaml b/examples/experimental/prox_approx/gsm8k_grpo_prox_approx_eval.yaml index b067512313..85b9f4e419 100644 --- a/examples/experimental/prox_approx/gsm8k_grpo_prox_approx_eval.yaml +++ b/examples/experimental/prox_approx/gsm8k_grpo_prox_approx_eval.yaml @@ -67,7 +67,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/boba_grpo.yaml b/examples/math/boba_grpo.yaml index 840846f5ad..694512fd43 100644 --- a/examples/math/boba_grpo.yaml +++ b/examples/math/boba_grpo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 4 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_dapo_dynamic_bs.yaml b/examples/math/gsm8k_dapo_dynamic_bs.yaml index 67cff469b3..ef5074f118 100644 --- a/examples/math/gsm8k_dapo_dynamic_bs.yaml +++ b/examples/math/gsm8k_dapo_dynamic_bs.yaml @@ -73,7 +73,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_drgrpo.yaml b/examples/math/gsm8k_drgrpo.yaml index b81de4f7d9..2991358056 100644 --- a/examples/math/gsm8k_drgrpo.yaml +++ b/examples/math/gsm8k_drgrpo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: null diff --git a/examples/math/gsm8k_grpo.yaml b/examples/math/gsm8k_grpo.yaml index 2fb7c001ce..23a01ca219 100644 --- a/examples/math/gsm8k_grpo.yaml +++ b/examples/math/gsm8k_grpo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_grpo_cpu.yaml b/examples/math/gsm8k_grpo_cpu.yaml index f484cd5911..add22b8ae6 100644 --- a/examples/math/gsm8k_grpo_cpu.yaml +++ b/examples/math/gsm8k_grpo_cpu.yaml @@ -72,7 +72,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_grpo_lora.yaml b/examples/math/gsm8k_grpo_lora.yaml index 851463e054..4a4473efd4 100644 --- a/examples/math/gsm8k_grpo_lora.yaml +++ b/examples/math/gsm8k_grpo_lora.yaml @@ -70,7 +70,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - 
behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_grpo_megatron.yaml b/examples/math/gsm8k_grpo_megatron.yaml index 68484ee695..2482b297bc 100644 --- a/examples/math/gsm8k_grpo_megatron.yaml +++ b/examples/math/gsm8k_grpo_megatron.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_grpo_megatron_fp8.yaml b/examples/math/gsm8k_grpo_megatron_fp8.yaml index 55079e6373..376a54ab1d 100644 --- a/examples/math/gsm8k_grpo_megatron_fp8.yaml +++ b/examples/math/gsm8k_grpo_megatron_fp8.yaml @@ -64,7 +64,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_grpo_npu.yaml b/examples/math/gsm8k_grpo_npu.yaml index da1db38b57..112e5fce05 100644 --- a/examples/math/gsm8k_grpo_npu.yaml +++ b/examples/math/gsm8k_grpo_npu.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_gspo.yaml b/examples/math/gsm8k_gspo.yaml index 2da7012ef5..6caf80a5e3 100644 --- a/examples/math/gsm8k_gspo.yaml +++ b/examples/math/gsm8k_gspo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_liteppo.yaml b/examples/math/gsm8k_liteppo.yaml index a26fc869f3..40499d232c 100644 --- a/examples/math/gsm8k_liteppo.yaml +++ b/examples/math/gsm8k_liteppo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: batch diff --git a/examples/math/gsm8k_m2po.yaml b/examples/math/gsm8k_m2po.yaml index 9fe1a0853f..ae8fd03641 100644 --- a/examples/math/gsm8k_m2po.yaml +++ b/examples/math/gsm8k_m2po.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/math/gsm8k_ppo.yaml b/examples/math/gsm8k_ppo.yaml index af65623f61..f3544db2c8 100644 --- a/examples/math/gsm8k_ppo.yaml +++ b/examples/math/gsm8k_ppo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 adv_norm: mean_level: batch std_level: batch diff --git a/examples/math/gsm8k_ppo_megatron.yaml b/examples/math/gsm8k_ppo_megatron.yaml index 8afe47ba33..e75a341906 100644 --- a/examples/math/gsm8k_ppo_megatron.yaml +++ b/examples/math/gsm8k_ppo_megatron.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 adv_norm: mean_level: batch std_level: batch diff --git 
a/examples/math/gsm8k_reinforce.yaml b/examples/math/gsm8k_reinforce.yaml index 88d5dc91bc..6944cd3bdf 100644 --- a/examples/math/gsm8k_reinforce.yaml +++ b/examples/math/gsm8k_reinforce.yaml @@ -69,7 +69,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 adv_norm: mean_level: batch std_level: batch diff --git a/examples/math/gsm8k_reinforce_baseline.yaml b/examples/math/gsm8k_reinforce_baseline.yaml index 5661925dbe..cfc144b92a 100644 --- a/examples/math/gsm8k_reinforce_baseline.yaml +++ b/examples/math/gsm8k_reinforce_baseline.yaml @@ -69,7 +69,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: null diff --git a/examples/math/gsm8k_rloo.yaml b/examples/math/gsm8k_rloo.yaml index 40cf0dd0ba..867a6552f7 100644 --- a/examples/math/gsm8k_rloo.yaml +++ b/examples/math/gsm8k_rloo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group mean_leave1out: true diff --git a/examples/math/gsm8k_sapo.yaml b/examples/math/gsm8k_sapo.yaml index 343313b0e0..da7825e150 100644 --- a/examples/math/gsm8k_sapo.yaml +++ b/examples/math/gsm8k_sapo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: false use_decoupled_loss: false # SAPO requires decoupled loss to be disabled - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/multi_turn_math/gsm8k_grpo_mt.yaml b/examples/multi_turn_math/gsm8k_grpo_mt.yaml index e22dac1c6d..64050b3fdf 100644 --- a/examples/multi_turn_math/gsm8k_grpo_mt.yaml +++ b/examples/multi_turn_math/gsm8k_grpo_mt.yaml @@ -72,7 +72,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/openai_agents/config.yaml b/examples/openai_agents/config.yaml index 5b618cc0b8..be10f2cecc 100644 --- a/examples/openai_agents/config.yaml +++ b/examples/openai_agents/config.yaml @@ -71,7 +71,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 adv_norm: mean_level: batch std_level: batch diff --git a/examples/openclaw/config.yaml b/examples/openclaw/config.yaml index 0bd034d59d..e84e5cb297 100644 --- a/examples/openclaw/config.yaml +++ b/examples/openclaw/config.yaml @@ -75,7 +75,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/search_agent/local_1.5b_example.yaml b/examples/search_agent/local_1.5b_example.yaml index f9100a92e6..eef671a21c 100644 --- a/examples/search_agent/local_1.5b_example.yaml +++ b/examples/search_agent/local_1.5b_example.yaml @@ -69,7 +69,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git 
a/examples/search_agent/tongyi_deepresearch/config.yaml b/examples/search_agent/tongyi_deepresearch/config.yaml index 86d50347f9..a347c3c846 100644 --- a/examples/search_agent/tongyi_deepresearch/config.yaml +++ b/examples/search_agent/tongyi_deepresearch/config.yaml @@ -71,7 +71,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 megatron: use_deterministic_algorithms: true recompute_granularity: full diff --git a/examples/skypilot/gsm8k_grpo_ray.yaml b/examples/skypilot/gsm8k_grpo_ray.yaml index e92c647247..0865a300b7 100644 --- a/examples/skypilot/gsm8k_grpo_ray.yaml +++ b/examples/skypilot/gsm8k_grpo_ray.yaml @@ -65,7 +65,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/tau2/config_1.7b_airline.yaml b/examples/tau2/config_1.7b_airline.yaml index 3e51476a16..2db6facb2e 100644 --- a/examples/tau2/config_1.7b_airline.yaml +++ b/examples/tau2/config_1.7b_airline.yaml @@ -77,7 +77,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: null adv_norm: mean_level: batch diff --git a/examples/tau2/config_235b_moe_airline.yaml b/examples/tau2/config_235b_moe_airline.yaml index c7d7744a39..86f06c695d 100644 --- a/examples/tau2/config_235b_moe_airline.yaml +++ b/examples/tau2/config_235b_moe_airline.yaml @@ -113,7 +113,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 discount: 1.0 gae_lambda: 1.0 adv_norm: diff --git a/examples/tau2/config_30b_moe_airline.yaml b/examples/tau2/config_30b_moe_airline.yaml index 326bd7dc41..618c15ee02 100644 --- a/examples/tau2/config_30b_moe_airline.yaml +++ b/examples/tau2/config_30b_moe_airline.yaml @@ -113,7 +113,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 discount: 1.0 gae_lambda: 1.0 adv_norm: diff --git a/examples/tau2/config_8b_airline.yaml b/examples/tau2/config_8b_airline.yaml index de5b55f4f5..756038c324 100644 --- a/examples/tau2/config_8b_airline.yaml +++ b/examples/tau2/config_8b_airline.yaml @@ -77,7 +77,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: null adv_norm: mean_level: batch diff --git a/examples/tir/tir_math_config.yaml b/examples/tir/tir_math_config.yaml index 4e93e3925c..6d2047f8b2 100644 --- a/examples/tir/tir_math_config.yaml +++ b/examples/tir/tir_math_config.yaml @@ -65,7 +65,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/vlm/clevr_count_70k_grpo.yaml b/examples/vlm/clevr_count_70k_grpo.yaml index be4b29d87c..1b18f2081d 100644 --- a/examples/vlm/clevr_count_70k_grpo.yaml +++ b/examples/vlm/clevr_count_70k_grpo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: 
mean_level: group std_level: group diff --git a/examples/vlm/geometry3k_grpo.yaml b/examples/vlm/geometry3k_grpo.yaml index ab56c7b84c..6bb1e8028a 100644 --- a/examples/vlm/geometry3k_grpo.yaml +++ b/examples/vlm/geometry3k_grpo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/vlm_npu/qwen2_5_vl_3b_geometry3k_grpo.yaml b/examples/vlm_npu/qwen2_5_vl_3b_geometry3k_grpo.yaml index ab56c7b84c..6bb1e8028a 100644 --- a/examples/vlm_npu/qwen2_5_vl_3b_geometry3k_grpo.yaml +++ b/examples/vlm_npu/qwen2_5_vl_3b_geometry3k_grpo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/examples/vlm_npu/qwen3_vl_2b_geometry3k_grpo.yaml b/examples/vlm_npu/qwen3_vl_2b_geometry3k_grpo.yaml index 3e7bbbad5c..7529f02994 100644 --- a/examples/vlm_npu/qwen3_vl_2b_geometry3k_grpo.yaml +++ b/examples/vlm_npu/qwen3_vl_2b_geometry3k_grpo.yaml @@ -68,7 +68,9 @@ actor: ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group diff --git a/tests/experimental/archon/test_grpo.py b/tests/experimental/archon/test_grpo.py index c1054c0344..d1c1e58c4e 100644 --- a/tests/experimental/archon/test_grpo.py +++ b/tests/experimental/archon/test_grpo.py @@ -342,7 +342,6 @@ def test_grpo_loss_fn_consistency(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, importance_sampling_level="token", current_version=1, prox_logp_method="recompute", @@ -360,7 +359,6 @@ def test_grpo_loss_fn_consistency(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, importance_sampling_level="token", current_version=1, prox_logp_method="recompute", @@ -431,7 +429,6 @@ def test_ppo_loss_edge_cases(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, importance_sampling_level="token", current_version=1, prox_logp_method="recompute", @@ -601,7 +598,6 @@ def test_reward_signal_propagation(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, importance_sampling_level="token", current_version=1, prox_logp_method="recompute", @@ -625,7 +621,6 @@ def test_reward_signal_propagation(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, importance_sampling_level="token", current_version=1, prox_logp_method="recompute", diff --git a/tests/test_functional.py b/tests/test_functional.py index 09c638a1e7..1fff734b81 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,8 +1,8 @@ import pytest import torch +from areal.api.cli_args import RejectionSamplingConfig from areal.utils.functional import ( - compute_behave_imp_weight, ppo_actor_loss_fn, sapo_loss_fn, ) @@ -422,8 +422,8 @@ def test_sequence_level_with_dual_clip(self): assert not torch.isnan(loss) assert not torch.isinf(loss) - def test_sequence_level_with_behave_imp_weight_cap(self): - """Test sequence-level with behavior importance weight capping.""" + def test_sequence_level_with_rejection_sampling(self): + """Test sequence-level with rejection sampling (replaces behave_imp_weight_cap).""" batch_size = 2 seq_len 
= 4 @@ -433,6 +433,7 @@ def test_sequence_level_with_behave_imp_weight_cap(self): advantages = torch.randn(batch_size, seq_len) loss_mask = torch.ones(batch_size, seq_len, dtype=torch.bool) + rs_config = RejectionSamplingConfig(metric="ratio", upper=10.0) loss, stat = ppo_actor_loss_fn( logprobs=logprobs, proximal_logprobs=proximal_logprobs, @@ -440,7 +441,7 @@ def test_sequence_level_with_behave_imp_weight_cap(self): advantages=advantages, eps_clip=0.2, loss_mask=loss_mask, - behave_imp_weight_cap=10.0, + rejection_sampling=rs_config, importance_sampling_level="sequence", ) @@ -448,11 +449,49 @@ def test_sequence_level_with_behave_imp_weight_cap(self): assert "behave_imp_weight" in stat assert "behave_approx_kl" in stat assert "behave_mask" in stat + assert "filtered_fraction" in stat # Verify loss is finite assert not torch.isnan(loss) assert not torch.isinf(loss) + def test_no_rejection_sampling_excludes_behave_stats(self): + """When rejection_sampling=None, stat should not contain behave keys.""" + batch_size = 2 + seq_len = 4 + + logprobs = torch.randn(batch_size, seq_len) + proximal_logprobs = torch.randn(batch_size, seq_len) + old_logprobs = torch.randn(batch_size, seq_len) + advantages = torch.randn(batch_size, seq_len) + loss_mask = torch.ones(batch_size, seq_len, dtype=torch.bool) + + loss, stat = ppo_actor_loss_fn( + logprobs=logprobs, + proximal_logprobs=proximal_logprobs, + old_logprobs=old_logprobs, + advantages=advantages, + eps_clip=0.2, + loss_mask=loss_mask, + rejection_sampling=None, + importance_sampling_level="sequence", + ) + + # Core PPO stats should be present + assert "importance_weight" in stat + assert "approx_kl" in stat + assert "clip_mask" in stat + + # Rejection sampling stats should NOT be present + assert "behave_imp_weight" not in stat + assert "behave_approx_kl" not in stat + assert "behave_mask" not in stat + assert "filtered_fraction" not in stat + + # Loss should be finite + assert not torch.isnan(loss) + assert not torch.isinf(loss) + def test_sequence_level_vs_token_level_different(self): """Verify that sequence-level and token-level produce different results.""" batch_size = 2 @@ -1143,96 +1182,3 @@ def test_loss_mask(self): expected_loss = -(gate * advantages[:, :2]).sum() / 2 assert torch.allclose(loss, expected_loss, atol=1e-5) - - -class TestComputeBehaveImpWeight: - """Test cases for compute_behave_imp_weight function.""" - - @pytest.fixture - def basic_2d_data(self): - """Basic 2D tensor test data.""" - batch_size = 4 - seq_len = 8 - - return { - "proximal_logprobs": torch.randn(batch_size, seq_len), - "old_logprobs": torch.randn(batch_size, seq_len), - "loss_mask": torch.ones(batch_size, seq_len, dtype=torch.bool), - "cu_seqlens": None, - } - - def test_disabled_mode_raises_error(self, basic_2d_data): - """Test that disabled mode raises ValueError.""" - with pytest.raises( - ValueError, match="should not be called with mode='disabled'" - ): - compute_behave_imp_weight( - proximal_logprobs=basic_2d_data["proximal_logprobs"], - old_logprobs=basic_2d_data["old_logprobs"], - loss_mask=basic_2d_data["loss_mask"], - cu_seqlens=basic_2d_data["cu_seqlens"], - behave_imp_weight_mode="disabled", - behave_imp_weight_cap=None, - ) - - def test_token_mask_mode(self, basic_2d_data): - """Test token_mask mode computes correct weights.""" - behave_imp_weight, behave_approx_kl, behave_mask = compute_behave_imp_weight( - proximal_logprobs=basic_2d_data["proximal_logprobs"], - old_logprobs=basic_2d_data["old_logprobs"], - loss_mask=basic_2d_data["loss_mask"], - 
cu_seqlens=basic_2d_data["cu_seqlens"], - behave_imp_weight_mode="token_mask", - behave_imp_weight_cap=5.0, - ) - - assert behave_imp_weight.shape == basic_2d_data["loss_mask"].shape - assert behave_approx_kl.shape == basic_2d_data["proximal_logprobs"].shape - assert behave_mask.shape == basic_2d_data["loss_mask"].shape - assert behave_mask.dtype == torch.bool - - def test_token_truncate_mode(self, basic_2d_data): - """Test token_truncate mode clamps weights correctly.""" - cap = 3.0 - behave_imp_weight, _, _ = compute_behave_imp_weight( - proximal_logprobs=basic_2d_data["proximal_logprobs"], - old_logprobs=basic_2d_data["old_logprobs"], - loss_mask=basic_2d_data["loss_mask"], - cu_seqlens=basic_2d_data["cu_seqlens"], - behave_imp_weight_mode="token_truncate", - behave_imp_weight_cap=cap, - ) - - # All weights should be clamped to [0, cap] - assert (behave_imp_weight >= 0).all() - assert (behave_imp_weight <= cap).all() - - def test_sequence_mask_mode(self, basic_2d_data): - """Test sequence_mask mode computes sequence-level weights.""" - behave_imp_weight, behave_approx_kl, behave_mask = compute_behave_imp_weight( - proximal_logprobs=basic_2d_data["proximal_logprobs"], - old_logprobs=basic_2d_data["old_logprobs"], - loss_mask=basic_2d_data["loss_mask"], - cu_seqlens=basic_2d_data["cu_seqlens"], - behave_imp_weight_mode="sequence_mask", - behave_imp_weight_cap=5.0, - ) - - assert behave_imp_weight.shape == basic_2d_data["loss_mask"].shape - assert behave_approx_kl.shape == basic_2d_data["proximal_logprobs"].shape - assert behave_mask.shape == basic_2d_data["loss_mask"].shape - - def test_without_cap(self, basic_2d_data): - """Test that mode works without cap (cap=None).""" - behave_imp_weight, _, _ = compute_behave_imp_weight( - proximal_logprobs=basic_2d_data["proximal_logprobs"], - old_logprobs=basic_2d_data["old_logprobs"], - loss_mask=basic_2d_data["loss_mask"], - cu_seqlens=basic_2d_data["cu_seqlens"], - behave_imp_weight_mode="token_mask", - behave_imp_weight_cap=None, - ) - - assert behave_imp_weight.shape == basic_2d_data["loss_mask"].shape - assert not torch.isnan(behave_imp_weight).any() - assert not torch.isinf(behave_imp_weight).any() diff --git a/tests/test_ppo_stats.py b/tests/test_ppo_stats.py index b0ad4427ce..f2a05cb2e2 100644 --- a/tests/test_ppo_stats.py +++ b/tests/test_ppo_stats.py @@ -76,7 +76,6 @@ def test_grpo_loss_fn_uses_full_cu_seqlens_for_n_tokens(): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, ) n_tokens = next( @@ -130,7 +129,6 @@ def test_grpo_loss_fn_uses_packed_denominator_for_tree_vocab_stats(): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, vocab_min_logits=torch.zeros(3), vocab_max_logits=torch.zeros(3), ) diff --git a/tests/test_prox_approx.py b/tests/test_prox_approx.py index 4d0a42c5cd..03ed2dcc11 100644 --- a/tests/test_prox_approx.py +++ b/tests/test_prox_approx.py @@ -259,7 +259,6 @@ def test_approximation_metrics_only_with_metrics_method(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, prox_logp_method="recompute", # Not metrics ) @@ -451,7 +450,6 @@ def test_grpo_loss_fn_detects_none_prox_logp(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, prox_logp_method="recompute", ) @@ -485,7 +483,6 @@ def test_grpo_loss_fn_requires_versions_when_prox_logp_none(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, 
prox_logp_method="loglinear", ) @@ -515,7 +512,6 @@ def test_grpo_loss_fn_computes_approximation_when_prox_logp_none(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, prox_logp_method="loglinear", ) @@ -549,7 +545,6 @@ def test_grpo_loss_fn_works_with_tensor_prox_logp(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, prox_logp_method="loglinear", ) @@ -582,7 +577,6 @@ def test_grpo_loss_fn_metrics_disabled_when_prox_logp_none(self): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, prox_logp_method="loglinear", ) @@ -807,7 +801,6 @@ def mock_stat(**kwargs): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, prox_logp_method="loglinear", ) @@ -860,7 +853,6 @@ def mock_stat(**kwargs): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, prox_logp_method="recompute", ) @@ -911,7 +903,6 @@ def mock_stat(**kwargs): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, prox_logp_method="metrics", ) @@ -976,7 +967,6 @@ def mock_stat(**kwargs): eps_clip=0.2, eps_clip_higher=None, c_clip=None, - behave_imp_weight_cap=None, current_version=5, prox_logp_method="metrics", ) diff --git a/tests/test_rejection_sampling.py b/tests/test_rejection_sampling.py new file mode 100644 index 0000000000..4c1dbb8c40 --- /dev/null +++ b/tests/test_rejection_sampling.py @@ -0,0 +1,839 @@ +import pytest +import torch + +from areal.api.cli_args import RejectionSamplingConfig +from areal.utils.functional import apply_rejection_sampling + + +class TestRejectionSamplingConfig: + """Tests for RejectionSamplingConfig validation.""" + + def test_ratio_upper_must_exceed_one(self): + """ratio metric with upper <= 1.0 should raise ValueError.""" + with pytest.raises(ValueError, match="upper must be > 1.0"): + RejectionSamplingConfig(metric="ratio", upper=1.0) + + def test_ratio_lower_must_be_positive(self): + """ratio metric with lower <= 0 should raise ValueError.""" + with pytest.raises(ValueError, match="lower must be positive"): + RejectionSamplingConfig(metric="ratio", lower=-0.1, upper=5.0) + + def test_kl_upper_must_be_positive(self): + """KL metrics with upper <= 0 should raise ValueError.""" + with pytest.raises(ValueError, match="upper must be positive"): + RejectionSamplingConfig(metric="kl_k2", upper=0.0) + + def test_agg_warning_for_token_level(self): + """agg != 'mean' with level='token' should warn.""" + with pytest.warns(UserWarning, match="agg=.*is ignored"): + RejectionSamplingConfig(level="token", agg="max", metric="ratio", upper=5.0) + + def test_clamp_only_supports_ratio_metric(self): + """action='clamp' with non-ratio metric should raise ValueError.""" + with pytest.raises( + ValueError, match="action='clamp' only supports metric='ratio'" + ): + RejectionSamplingConfig(action="clamp", metric="kl_k2", upper=1.0) + + def test_clamp_sets_default_lower_to_zero(self): + """action='clamp' without explicit lower should default to 0.0.""" + config = RejectionSamplingConfig(action="clamp", metric="ratio", upper=5.0) + assert config.lower == 0.0 + + +class TestRejectionSamplingMask: + """Tests for apply_rejection_sampling with action='mask'.""" + + def test_ratio_upper_bound_filters_high_ratio(self): + """Token with ratio > upper should be filtered.""" + config = RejectionSamplingConfig(level="token", metric="ratio", upper=2.0) 
+ # ratio = exp(1) ~ 2.72 > 2.0 + proximal_logprobs = torch.tensor([[0.0, 1.0, 0.0]]) + old_logprobs = torch.tensor([[0.0, 0.0, 0.0]]) + loss_mask = torch.ones(1, 3) + + result = apply_rejection_sampling( + proximal_logprobs, + old_logprobs, + loss_mask, + cu_seqlens=None, + config=config, + ) + + assert result.loss_mask[0, 0] == 1.0 # ratio = 1, keep + assert result.loss_mask[0, 1] == 0.0 # ratio ~ 2.72 > 2.0, filter + assert result.loss_mask[0, 2] == 1.0 # ratio = 1, keep + assert result.filtered_fraction > 0 + + def test_ratio_lower_bound_filters_low_ratio(self): + """Token with ratio < lower should be filtered.""" + config = RejectionSamplingConfig( + level="token", metric="ratio", lower=0.5, upper=2.0 + ) + # ratio = exp(-1) ~ 0.37 < 0.5 + proximal_logprobs = torch.tensor([[0.0, -1.0, 0.0]]) + old_logprobs = torch.tensor([[0.0, 0.0, 0.0]]) + loss_mask = torch.ones(1, 3) + + result = apply_rejection_sampling( + proximal_logprobs, + old_logprobs, + loss_mask, + cu_seqlens=None, + config=config, + ) + + assert result.loss_mask[0, 0] == 1.0 # ratio = 1, keep + assert result.loss_mask[0, 1] == 0.0 # ratio ~ 0.37 < 0.5, filter + assert result.loss_mask[0, 2] == 1.0 # ratio = 1, keep + + def test_sequence_ratio_mask_uniform_weight(self): + """Sequence-level ratio mask: behave_imp_weight = geometric mean for all tokens.""" + config = RejectionSamplingConfig( + level="sequence", agg="mean", metric="ratio", upper=5.0 + ) + # log_ratios [0.0, 0.5, 1.0], geo_mean = exp(mean([0, 0.5, 1.0])) = exp(0.5) + proximal_logprobs = torch.tensor([[0.0, 0.5, 1.0]]) + old_logprobs = torch.tensor([[0.0, 0.0, 0.0]]) + loss_mask = torch.ones(1, 3) + + result = apply_rejection_sampling( + proximal_logprobs, old_logprobs, loss_mask, cu_seqlens=None, config=config + ) + + expected_weight = torch.exp(torch.tensor(0.5)) + torch.testing.assert_close( + result.behave_imp_weight[0], + expected_weight.expand(3), + rtol=1e-5, + atol=1e-5, + ) + + def test_sequence_ratio_mask_uniform_weight_1d_packed(self): + """1D packed: sequence-level ratio mask should also use uniform weight.""" + config = RejectionSamplingConfig( + level="sequence", agg="mean", metric="ratio", upper=5.0 + ) + # Seq 0 (3 tokens): log_ratios [0.0, 0.6, 0.3], geo_mean = exp(0.3) + # Seq 1 (2 tokens): log_ratios [0.0, 0.0], geo_mean = exp(0) = 1.0 + proximal_logprobs = torch.tensor([0.0, 0.6, 0.3, 0.0, 0.0]) + old_logprobs = torch.zeros(5) + loss_mask = torch.ones(5) + cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32) + + result = apply_rejection_sampling( + proximal_logprobs, + old_logprobs, + loss_mask, + cu_seqlens=cu_seqlens, + config=config, + ) + + geo_mean_seq0 = torch.exp(torch.tensor(0.3)) + torch.testing.assert_close( + result.behave_imp_weight[:3], + geo_mean_seq0.expand(3), + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + result.behave_imp_weight[3:], + torch.ones(2), + rtol=1e-5, + atol=1e-5, + ) + + def test_sequence_kl_metric_keeps_per_token_weight(self): + """Sequence-level KL metric: behave_imp_weight stays per-token (not uniform).""" + config = RejectionSamplingConfig( + level="sequence", agg="mean", metric="kl_k2", upper=5.0 + ) + # log_ratios [0.0, 1.0, 0.0] -> per-token ratios [1.0, e, 1.0] + proximal_logprobs = torch.tensor([[0.0, 1.0, 0.0]]) + old_logprobs = torch.tensor([[0.0, 0.0, 0.0]]) + loss_mask = torch.ones(1, 3) + + result = apply_rejection_sampling( + proximal_logprobs, old_logprobs, loss_mask, cu_seqlens=None, config=config + ) + + expected = torch.exp(torch.tensor([0.0, 1.0, 0.0])) + 
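+        # For KL metrics the correction weight is not broadcast per sequence:
+        # behave_imp_weight stays the per-token ratio exp(logp_prox - logp_old).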
torch.testing.assert_close( + result.behave_imp_weight[0], expected, rtol=1e-5, atol=1e-5 + ) + + def test_kl_k2_sequence_mean_keeps_clean_sequences(self): + """Sequence-level mean KL K2 should keep sequences below threshold.""" + config = RejectionSamplingConfig( + level="sequence", agg="mean", metric="kl_k2", upper=0.5 + ) + # Sequence 0: one token with log_ratio=1, KL_k2 = 0.5*1^2 = 0.5 + # mean KL = (0 + 0.5 + 0) / 3 = 0.167 < 0.5, keep + # Sequence 1: all zeros, mean KL = 0, keep + proximal_logprobs = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]) + old_logprobs = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + loss_mask = torch.ones(2, 3) + + result = apply_rejection_sampling( + proximal_logprobs, + old_logprobs, + loss_mask, + cu_seqlens=None, + config=config, + ) + + assert torch.all(result.loss_mask[0] == 1.0) + assert torch.all(result.loss_mask[1] == 1.0) + + def test_kl_k2_sequence_mean_filters_stale_sequence(self): + """Sequence with high mean KL should be fully filtered.""" + config = RejectionSamplingConfig( + level="sequence", agg="mean", metric="kl_k2", upper=0.1 + ) + # Sequence 0: token with log_ratio=2, KL_k2 = 0.5*4 = 2.0 + # mean KL = (0 + 2.0 + 0) / 3 = 0.667 > 0.1, filter entire sequence + proximal_logprobs = torch.tensor([[0.0, 2.0, 0.0], [0.0, 0.0, 0.0]]) + old_logprobs = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + loss_mask = torch.ones(2, 3) + + result = apply_rejection_sampling( + proximal_logprobs, + old_logprobs, + loss_mask, + cu_seqlens=None, + config=config, + ) + + assert torch.all(result.loss_mask[0] == 0.0) # entire sequence filtered + assert torch.all(result.loss_mask[1] == 1.0) # clean sequence kept + + def test_packed_1d_format_with_cu_seqlens(self): + """1D packed format should work with cu_seqlens.""" + config = RejectionSamplingConfig( + level="sequence", agg="mean", metric="kl_k2", upper=0.1 + ) + # Two sequences packed: [seq0: 3 tokens, seq1: 2 tokens] + proximal_logprobs = torch.tensor([0.0, 2.0, 0.0, 0.0, 0.0]) + old_logprobs = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]) + loss_mask = torch.ones(5) + cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32) + + result = apply_rejection_sampling( + proximal_logprobs, + old_logprobs, + loss_mask, + cu_seqlens=cu_seqlens, + config=config, + ) + + # Sequence 0: mean KL > 0.1, entire sequence filtered + assert torch.all(result.loss_mask[:3] == 0.0) + # Sequence 1: mean KL = 0, kept + assert torch.all(result.loss_mask[3:] == 1.0) + + def test_padding_tokens_not_counted(self): + """Padding tokens (loss_mask=0) should not affect filtering.""" + config = RejectionSamplingConfig( + level="sequence", agg="mean", metric="kl_k2", upper=0.5 + ) + # Only first token is valid, with KL_k2 = 0.5 * 2^2 = 2.0 > 0.5 + proximal_logprobs = torch.tensor([[2.0, 0.0, 0.0]]) + old_logprobs = torch.tensor([[0.0, 0.0, 0.0]]) + loss_mask = torch.tensor([[1.0, 0.0, 0.0]]) # only first token valid + + result = apply_rejection_sampling( + proximal_logprobs, + old_logprobs, + loss_mask, + cu_seqlens=None, + config=config, + ) + + # mean KL = 2.0 / 1 = 2.0 > 0.5, filter + assert result.loss_mask[0, 0] == 0.0 + + +class TestRejectionSamplingClamp: + """Tests for apply_rejection_sampling with action='clamp'.""" + + def test_clamp_does_not_modify_loss_mask(self): + """Clamp mode should never modify loss_mask.""" + config = RejectionSamplingConfig( + level="token", metric="ratio", action="clamp", upper=2.0 + ) + proximal_logprobs = torch.tensor([[0.0, 5.0, 0.0]]) # ratio ~ 148.4 + old_logprobs = torch.tensor([[0.0, 0.0, 0.0]]) 
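+        # exp(5) ~ 148 is far above upper=2.0, but clamp only bounds the weight;
+        # zeroing loss_mask is exclusive to action='mask'.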
+
+
+class TestRejectionSamplingClamp:
+    """Tests for apply_rejection_sampling with action='clamp'."""
+
+    def test_clamp_does_not_modify_loss_mask(self):
+        """Clamp mode should never modify loss_mask."""
+        config = RejectionSamplingConfig(
+            level="token", metric="ratio", action="clamp", upper=2.0
+        )
+        proximal_logprobs = torch.tensor([[0.0, 5.0, 0.0]])  # ratio ~ 148.4
+        old_logprobs = torch.tensor([[0.0, 0.0, 0.0]])
+        loss_mask = torch.ones(1, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        # loss_mask unchanged -- all tokens still participate
+        assert torch.all(result.loss_mask == 1.0)
+
+    def test_clamp_truncates_high_ratio(self):
+        """Token with ratio > upper should have weight clamped to upper."""
+        config = RejectionSamplingConfig(
+            level="token", metric="ratio", action="clamp", upper=5.0
+        )
+        # ratio = exp(2) ~ 7.39 > 5.0
+        proximal_logprobs = torch.tensor([[0.0, 2.0, 0.0]])
+        old_logprobs = torch.tensor([[0.0, 0.0, 0.0]])
+        loss_mask = torch.ones(1, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        # Token 0: ratio=1.0, not clamped
+        torch.testing.assert_close(
+            result.behave_imp_weight[0, 0], torch.tensor(1.0), rtol=1e-5, atol=1e-5
+        )
+        # Token 1: ratio~7.39, clamped to 5.0
+        torch.testing.assert_close(
+            result.behave_imp_weight[0, 1], torch.tensor(5.0), rtol=1e-5, atol=1e-5
+        )
+        # Token 2: ratio=1.0, not clamped
+        torch.testing.assert_close(
+            result.behave_imp_weight[0, 2], torch.tensor(1.0), rtol=1e-5, atol=1e-5
+        )
+
+    def test_clamp_sequence_level(self):
+        """Sequence-level clamp with ratio metric uses geometric mean as uniform weight.
+
+        geo_mean = exp(mean(log_ratio)), broadcast to all tokens in the sequence.
+        When geo_mean > upper, the uniform weight is clamped to upper for all tokens.
+        """
+        config = RejectionSamplingConfig(
+            level="sequence",
+            metric="ratio",
+            action="clamp",
+            agg="mean",
+            upper=3.0,
+        )
+        # Sequence 0: log_ratios [0, 4, 0], geo_mean = exp(4/3) ≈ 3.79 > 3.0 -> clamp
+        # Sequence 1: log_ratios [0, 0, 0], geo_mean = exp(0) = 1.0 <= 3.0 -> no clamp
+        proximal_logprobs = torch.tensor([[0.0, 4.0, 0.0], [0.0, 0.0, 0.0]])
+        old_logprobs = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+        loss_mask = torch.ones(2, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        # Sequence 0: geo_mean ≈ 3.79 > 3.0, all tokens get clamped uniform weight = 3.0
+        torch.testing.assert_close(
+            result.behave_imp_weight[0],
+            torch.tensor([3.0, 3.0, 3.0]),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+        # Sequence 1: geo_mean = 1.0 <= 3.0, uniform weight = 1.0 (no clamp)
+        torch.testing.assert_close(
+            result.behave_imp_weight[1],
+            torch.ones(3),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+        # loss_mask unchanged for both sequences
+        assert torch.all(result.loss_mask == 1.0)
+
+    def test_clamp_reports_clamped_fraction(self):
+        """filtered_fraction should report proportion of clamped tokens."""
+        config = RejectionSamplingConfig(
+            level="token", metric="ratio", action="clamp", upper=2.0
+        )
+        # 1 of 3 tokens has ratio > 2.0
+        proximal_logprobs = torch.tensor([[0.0, 1.0, 0.0]])  # ratios: 1.0, 2.72, 1.0
+        old_logprobs = torch.tensor([[0.0, 0.0, 0.0]])
+        loss_mask = torch.ones(1, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert result.filtered_fraction > 0  # at least 1 token clamped
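+
+
+# Editor's note: a hypothetical sketch of what the clamp tests above imply for
+# action="clamp" at token level with the ratio metric: loss_mask is left
+# untouched, and the importance weight exp(log_ratio) is clamped into
+# [lower, upper] (lower defaulting to 0 for the ratio metric, per the
+# backward-compatibility tests). Names are illustrative, not the real code.
+def _sketch_token_ratio_clamp(
+    proximal_logprobs, old_logprobs, loss_mask, upper, lower=0.0
+):
+    ratio = torch.exp(proximal_logprobs - old_logprobs)
+    # Weights are clamped, then zeroed wherever the token is already masked out.
+    behave_imp_weight = ratio.clamp(min=lower, max=upper) * loss_mask
+    return behave_imp_weight, loss_mask  # loss_mask is never modified by clamp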
+
+
+class TestBackwardCompatibility:
+    """Verify new config reproduces old behave_imp_weight_cap behavior."""
+
+    def test_equivalent_to_legacy_token_mask(self):
+        """New ratio/token/mask config should match old token_mask behavior."""
+        config = RejectionSamplingConfig(level="token", metric="ratio", upper=5.0)
+
+        proximal_logprobs = torch.tensor([[0.0, 2.0, -0.5]])  # ratios: 1.0, 7.39, 0.61
+        old_logprobs = torch.tensor([[0.0, 0.0, 0.0]])
+        loss_mask = torch.ones(1, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        # ratio=7.39 > 5.0 -> filtered, others kept
+        assert result.loss_mask[0, 0] == 1.0
+        assert result.loss_mask[0, 1] == 0.0
+        assert result.loss_mask[0, 2] == 1.0
+
+    def test_equivalent_to_legacy_token_truncate(self):
+        """New ratio/token/clamp config should match old token_truncate behavior."""
+        config = RejectionSamplingConfig(
+            level="token", metric="ratio", action="clamp", upper=5.0
+        )
+
+        proximal_logprobs = torch.tensor([[0.0, 2.0, -0.5]])  # ratios: 1.0, 7.39, 0.61
+        old_logprobs = torch.tensor([[0.0, 0.0, 0.0]])
+        loss_mask = torch.ones(1, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        # All tokens kept in loss_mask
+        assert torch.all(result.loss_mask == 1.0)
+        # ratio=7.39 > 5.0 -> clamped to 5.0
+        torch.testing.assert_close(
+            result.behave_imp_weight[0, 1], torch.tensor(5.0), rtol=1e-5, atol=1e-5
+        )
+        # ratio=0.61 within [0, 5.0] -> not clamped
+        torch.testing.assert_close(
+            result.behave_imp_weight[0, 2],
+            torch.exp(torch.tensor(-0.5)),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+
+
+class TestKLK1Metric:
+    """Tests for kl_k1 metric (forward KL unbiased estimator, can be negative)."""
+
+    def test_kl_k1_can_be_negative(self):
+        """kl_k1 = log(r) can be negative when proximal < old."""
+        config = RejectionSamplingConfig(
+            level="token", metric="kl_k1", upper=1.0, lower=-0.5
+        )
+        # log_ratio = -1.0, so kl_k1 = -1.0 < lower=-0.5 -> filtered
+        # log_ratio = -0.3, so kl_k1 = -0.3, within [-0.5, 1.0] -> kept
+        proximal_logprobs = torch.tensor([[0.0, -1.0, -0.3]])
+        old_logprobs = torch.tensor([[0.0, 0.0, 0.0]])
+        loss_mask = torch.ones(1, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert result.loss_mask[0, 0] == 1.0  # kl_k1=0, within bounds
+        assert result.loss_mask[0, 1] == 0.0  # kl_k1=-1.0 < -0.5, filtered
+        assert result.loss_mask[0, 2] == 1.0  # kl_k1=-0.3, within bounds
+
+    def test_kl_k1_filters_high_positive(self):
+        """kl_k1 with high positive value should be filtered."""
+        config = RejectionSamplingConfig(level="token", metric="kl_k1", upper=0.5)
+        # log_ratio = 1.0, kl_k1 = 1.0 > 0.5 -> filtered
+        proximal_logprobs = torch.tensor([[0.0, 1.0]])
+        old_logprobs = torch.tensor([[0.0, 0.0]])
+        loss_mask = torch.ones(1, 2)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert result.loss_mask[0, 0] == 1.0  # kl_k1=0
+        assert result.loss_mask[0, 1] == 0.0  # kl_k1=1.0 > 0.5
+
+    def test_kl_k1_sequence_level(self):
+        """kl_k1 should work at sequence level with mean aggregation."""
+        config = RejectionSamplingConfig(
+            level="sequence", agg="mean", metric="kl_k1", upper=0.5
+        )
+        # Seq 0: log_ratios [0, 1.0, -0.5], kl_k1 values [0, 1.0, -0.5]
+        # mean = (0 + 1.0 + (-0.5)) / 3 = 0.167 < 0.5, keep
+        # Seq 1: log_ratios [2.0, 2.0, 0], kl_k1 values [2.0, 2.0, 0]
+        # mean = (2.0 + 2.0 + 0) / 3 = 1.33 > 0.5, filter
+        proximal_logprobs = torch.tensor([[0.0, 1.0, -0.5], [2.0, 2.0, 0.0]])
+        old_logprobs = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+        loss_mask = torch.ones(2, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert torch.all(result.loss_mask[0] == 1.0)  # seq 0 kept
+        assert torch.all(result.loss_mask[1] == 0.0)  # seq 1 filtered
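+
+
+# Editor's note: the per-token metrics the tests in this section assume, written
+# out as one sketch. With log_ratio = proximal_logprobs - old_logprobs, the
+# k1/k2/k3 forms follow the standard KL estimator family: k1 = log_ratio is
+# unbiased but can be negative, k2 = 0.5*log_ratio^2, and
+# k3 = exp(-log_ratio) - 1 + log_ratio is non-negative. This helper is
+# illustrative only and not part of the tested API.
+def _sketch_metric(log_ratio, metric):
+    if metric == "ratio":
+        return torch.exp(log_ratio)
+    if metric == "kl_k1":
+        return log_ratio
+    if metric == "kl_k2":
+        return 0.5 * log_ratio.pow(2)
+    if metric == "kl_k3":
+        return torch.exp(-log_ratio) - 1.0 + log_ratio
+    raise ValueError(f"unknown metric: {metric}")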
+
+
+class TestKLK3Metric:
+    """Tests for kl_k3 metric (unbiased, non-negative forward KL estimator)."""
+
+    def test_kl_k3_is_non_negative(self):
+        """kl_k3 = exp(-log_ratio) - 1 + log_ratio should be >= 0."""
+        config = RejectionSamplingConfig(level="token", metric="kl_k3", upper=0.5)
+        # log_ratio = 0.5, kl_k3 = exp(-0.5) - 1 + 0.5 = 0.6065 - 0.5 = 0.1065
+        # log_ratio = -0.5, kl_k3 = exp(0.5) - 1 - 0.5 = 1.6487 - 1.5 = 0.1487
+        # Both below upper=0.5 -> kept
+        proximal_logprobs = torch.tensor([[0.5, -0.5, 0.0]])
+        old_logprobs = torch.tensor([[0.0, 0.0, 0.0]])
+        loss_mask = torch.ones(1, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert torch.all(result.loss_mask == 1.0)
+
+    def test_kl_k3_filters_stale_tokens(self):
+        """kl_k3 with large divergence should exceed threshold."""
+        config = RejectionSamplingConfig(level="token", metric="kl_k3", upper=0.5)
+        # log_ratio = 2.0, kl_k3 = exp(-2) - 1 + 2 = 0.1353 + 1 = 1.1353
+        # kl_k3 = 1.1353 > 0.5 -> filtered
+        proximal_logprobs = torch.tensor([[0.0, 2.0]])
+        old_logprobs = torch.tensor([[0.0, 0.0]])
+        loss_mask = torch.ones(1, 2)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert result.loss_mask[0, 0] == 1.0  # kl_k3=0
+        assert result.loss_mask[0, 1] == 0.0  # kl_k3 > 0.5
+
+
+class TestAggregationMethods:
+    """Tests for sum and max aggregation methods (Issue 10)."""
+
+    def test_sequence_sum_is_length_sensitive(self):
+        """Sum aggregation should be sensitive to sequence length."""
+        config = RejectionSamplingConfig(
+            level="sequence", agg="sum", metric="kl_k2", upper=1.0
+        )
+        # Seq 0: 4 tokens each with kl_k2 = 0.5*0.5^2 = 0.125
+        # sum = 4 * 0.125 = 0.5 < 1.0, keep
+        # Seq 1: 4 tokens each with kl_k2 = 0.5*1.0^2 = 0.5
+        # sum = 4 * 0.5 = 2.0 > 1.0, filter
+        proximal_logprobs = torch.tensor([[0.5, 0.5, 0.5, 0.5], [1.0, 1.0, 1.0, 1.0]])
+        old_logprobs = torch.zeros(2, 4)
+        loss_mask = torch.ones(2, 4)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert torch.all(result.loss_mask[0] == 1.0)  # sum < 1.0, kept
+        assert torch.all(result.loss_mask[1] == 0.0)  # sum > 1.0, filtered
+
+    def test_sequence_max_filters_single_high_token(self):
+        """Max aggregation should filter based on the worst token in a sequence."""
+        config = RejectionSamplingConfig(
+            level="sequence", agg="max", metric="ratio", upper=3.0
+        )
+        # Seq 0: ratios [1, 1, exp(2)~7.39], max=7.39 > 3.0 -> filter entire seq
+        # Seq 1: ratios [1, exp(0.5)~1.65, 1], max=1.65 < 3.0 -> keep
+        proximal_logprobs = torch.tensor([[0.0, 0.0, 2.0], [0.0, 0.5, 0.0]])
+        old_logprobs = torch.zeros(2, 3)
+        loss_mask = torch.ones(2, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert torch.all(result.loss_mask[0] == 0.0)  # max > 3.0, filtered
+        assert torch.all(result.loss_mask[1] == 1.0)  # max < 3.0, kept
+
+    def test_packed_sum_aggregation(self):
+        """Sum aggregation should work with 1D packed format."""
+        config = RejectionSamplingConfig(
+            level="sequence", agg="sum", metric="kl_k2", upper=0.5
+        )
+        # Seq 0 (3 tokens): log_ratios [0, 1.0, 0], kl_k2 = [0, 0.5, 0], sum=0.5
+        # Seq 1 (2 tokens): log_ratios [2.0, 0], kl_k2 = [2.0, 0], sum=2.0 > 0.5
+        proximal_logprobs = torch.tensor([0.0, 1.0, 0.0, 2.0, 0.0])
+        old_logprobs = torch.zeros(5)
+        loss_mask = torch.ones(5)
+        cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=cu_seqlens,
+            config=config,
+        )
+
+        assert torch.all(result.loss_mask[:3] == 1.0)  # seq 0: sum=0.5, kept
+        assert torch.all(result.loss_mask[3:] == 0.0)  # seq 1: sum=2.0 > 0.5, filtered
+
+    def test_packed_max_aggregation(self):
+        """Max aggregation should work with 1D packed format."""
+        config = RejectionSamplingConfig(
+            level="sequence", agg="max", metric="ratio", upper=3.0
+        )
+        # Seq 0 (3 tokens): ratios [1, exp(2)~7.39, 1], max=7.39 > 3.0 -> filter
+        # Seq 1 (2 tokens): ratios [1, 1], max=1 <= 3.0 -> keep
+        proximal_logprobs = torch.tensor([0.0, 2.0, 0.0, 0.0, 0.0])
+        old_logprobs = torch.zeros(5)
+        loss_mask = torch.ones(5)
+        cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=cu_seqlens,
+            config=config,
+        )
+
+        assert torch.all(result.loss_mask[:3] == 0.0)  # seq 0: max > 3.0
+        assert torch.all(result.loss_mask[3:] == 1.0)  # seq 1: max <= 3.0
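+
+
+# Editor's note: a sketch of how sequence-level aggregation over the packed 1D
+# layout might look, assuming cu_seqlens follows the usual FlashAttention-style
+# convention (cu_seqlens[i]:cu_seqlens[i+1] delimits sequence i). Illustrative
+# only; the real implementation may vectorize this differently.
+def _sketch_packed_aggregate(values, loss_mask, cu_seqlens, agg):
+    out = torch.zeros(cu_seqlens.numel() - 1, dtype=values.dtype)
+    for i in range(cu_seqlens.numel() - 1):
+        s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
+        v, m = values[s:e], loss_mask[s:e]
+        if agg == "mean":
+            out[i] = (v * m).sum() / m.sum().clamp(min=1.0)
+        elif agg == "sum":
+            out[i] = (v * m).sum()
+        elif agg == "max":
+            # Masked max: padding tokens are ignored entirely.
+            out[i] = v[m > 0].max() if (m > 0).any() else 0.0
+    return out  # one aggregate per sequence, compared against [lower, upper]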
+
+
+class TestEdgeCases:
+    """Tests for edge cases (Issue 15)."""
+
+    def test_empty_loss_mask(self):
+        """All-zero loss_mask should produce no filtering and zero fraction."""
+        config = RejectionSamplingConfig(level="token", metric="ratio", upper=2.0)
+        proximal_logprobs = torch.tensor([[5.0, 5.0, 5.0]])  # huge ratios
+        old_logprobs = torch.zeros(1, 3)
+        loss_mask = torch.zeros(1, 3)  # all masked
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        # loss_mask stays all-zero (nothing to filter)
+        assert torch.all(result.loss_mask == 0.0)
+        # behave_imp_weight should be zeroed where mask is 0
+        assert torch.all(result.behave_imp_weight == 0.0)
+        assert result.filtered_fraction == 0.0
+
+    def test_single_token_sequence(self):
+        """Single-token sequences should work correctly."""
+        config = RejectionSamplingConfig(
+            level="sequence", agg="mean", metric="ratio", upper=2.0
+        )
+        proximal_logprobs = torch.tensor([[1.0]])  # ratio=exp(1)~2.72 > 2.0
+        old_logprobs = torch.tensor([[0.0]])
+        loss_mask = torch.ones(1, 1)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert result.loss_mask[0, 0] == 0.0  # filtered
+
+    def test_all_tokens_filtered(self):
+        """When all tokens exceed the threshold, everything should be filtered."""
+        config = RejectionSamplingConfig(level="token", metric="ratio", upper=1.5)
+        # All ratios > 1.5
+        proximal_logprobs = torch.tensor([[1.0, 2.0, 3.0]])
+        old_logprobs = torch.zeros(1, 3)
+        loss_mask = torch.ones(1, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert torch.all(result.loss_mask == 0.0)
+        assert result.filtered_fraction == 1.0
+
+    def test_clamp_with_lower_bound(self):
+        """Clamp mode with explicit lower bound should clamp from both sides."""
+        config = RejectionSamplingConfig(
+            level="token", metric="ratio", action="clamp", upper=3.0, lower=0.5
+        )
+        # ratio = exp(-2) ~ 0.135 < 0.5 -> clamped to 0.5
+        # ratio = exp(2) ~ 7.39 > 3.0 -> clamped to 3.0
+        # ratio = exp(0) = 1.0, within bounds -> unchanged
+        proximal_logprobs = torch.tensor([[-2.0, 2.0, 0.0]])
+        old_logprobs = torch.zeros(1, 3)
+        loss_mask = torch.ones(1, 3)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        torch.testing.assert_close(
+            result.behave_imp_weight[0, 0], torch.tensor(0.5), rtol=1e-5, atol=1e-5
+        )
+        torch.testing.assert_close(
+            result.behave_imp_weight[0, 1], torch.tensor(3.0), rtol=1e-5, atol=1e-5
+        )
+        torch.testing.assert_close(
+            result.behave_imp_weight[0, 2], torch.tensor(1.0), rtol=1e-5, atol=1e-5
+        )
+
+    def test_ratio_exactly_at_upper_bound(self):
+        """Token with ratio exactly equal to upper should pass (<=)."""
+        config = RejectionSamplingConfig(level="token", metric="ratio", upper=2.0)
+        # log_ratio = ln(2.0) ~ 0.6931
+        log_2 = torch.tensor(2.0).log()
+        proximal_logprobs = torch.tensor([[log_2.item()]])
+        old_logprobs = torch.tensor([[0.0]])
+        loss_mask = torch.ones(1, 1)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        assert result.loss_mask[0, 0] == 1.0  # exactly at boundary, kept
+
+    def test_nan_from_inf_logprobs(self):
+        """Non-finite log-probs (both -inf) should not produce NaN."""
+        config = RejectionSamplingConfig(level="token", metric="ratio", upper=5.0)
+        proximal_logprobs = torch.tensor([[0.0, float("-inf")]])
+        old_logprobs = torch.tensor([[0.0, float("-inf")]])
+        loss_mask = torch.ones(1, 2)
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        # No NaN in outputs
+        assert not torch.isnan(result.behave_imp_weight).any()
+        assert not torch.isnan(result.loss_mask).any()
+
+    def test_all_masked_sequence_with_max_agg(self):
+        """Sequence with all tokens masked should pass bounds check with max agg."""
+        config = RejectionSamplingConfig(
+            level="sequence", agg="max", metric="ratio", upper=2.0
+        )
+        # Seq 0: all masked, should be treated as in-bounds
+        # Seq 1: valid tokens, ratio=1.0 within bounds
+        proximal_logprobs = torch.tensor([[5.0, 5.0], [0.0, 0.0]])
+        old_logprobs = torch.zeros(2, 2)
+        loss_mask = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
+
+        result = apply_rejection_sampling(
+            proximal_logprobs,
+            old_logprobs,
+            loss_mask,
+            cu_seqlens=None,
+            config=config,
+        )
+
+        # Seq 0: all-masked, loss_mask stays all-zero (no change)
+        assert torch.all(result.loss_mask[0] == 0.0)
+        # Seq 1: within bounds, kept
+        assert torch.all(result.loss_mask[1] == 1.0)
+
+    def test_shape_mismatch_raises(self):
+        """Mismatched tensor shapes should raise ValueError."""
+        config = RejectionSamplingConfig(level="token", metric="ratio", upper=5.0)
+        proximal_logprobs = torch.tensor([[0.0, 1.0]])
+        old_logprobs = torch.tensor([[0.0, 1.0, 2.0]])
+        loss_mask = torch.ones(1, 2)
+
+        with pytest.raises(ValueError, match="shape"):
+            apply_rejection_sampling(
+                proximal_logprobs,
+                old_logprobs,
+                loss_mask,
+                cu_seqlens=None,
+                config=config,
+            )
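+
+
+# Editor's note: the edge cases above pin down input handling a conforming
+# implementation presumably needs: a shape check that raises ValueError, and
+# sanitization so non-finite log-ratios (e.g. -inf minus -inf) never leak NaN
+# into the outputs. A hypothetical guard, for illustration only:
+def _sketch_validate_and_sanitize(proximal_logprobs, old_logprobs, loss_mask):
+    if not (proximal_logprobs.shape == old_logprobs.shape == loss_mask.shape):
+        raise ValueError(
+            f"shape mismatch: {proximal_logprobs.shape} vs "
+            f"{old_logprobs.shape} vs {loss_mask.shape}"
+        )
+    log_ratio = proximal_logprobs - old_logprobs
+    # Replace NaN/inf (e.g. from two -inf logprobs) with 0 before filtering.
+    return torch.nan_to_num(log_ratio, nan=0.0, posinf=0.0, neginf=0.0)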
+
+
+class TestConfigValidation:
+    """Tests for new config validation rules (Issues 1, 9)."""
+
+    def test_lower_greater_than_upper_raises(self):
+        """lower > upper should raise ValueError."""
+        with pytest.raises(ValueError, match="lower.*cannot be greater than upper"):
+            RejectionSamplingConfig(metric="ratio", lower=3.0, upper=2.0)
+
+    def test_invalid_level_raises(self):
+        """Invalid level should raise ValueError."""
+        with pytest.raises(ValueError, match="level must be one of"):
+            RejectionSamplingConfig(level="invalid")
+
+    def test_invalid_action_raises(self):
+        """Invalid action should raise ValueError."""
+        with pytest.raises(ValueError, match="action must be one of"):
+            RejectionSamplingConfig(action="invalid")
+
+    def test_invalid_metric_raises(self):
+        """Invalid metric should raise ValueError."""
+        with pytest.raises(ValueError, match="metric must be one of"):
+            RejectionSamplingConfig(metric="invalid")
+
+    def test_invalid_agg_raises(self):
+        """Invalid agg should raise ValueError."""
+        with pytest.raises(ValueError, match="agg must be one of"):
+            RejectionSamplingConfig(agg="invalid")
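+
+
+# Editor's note: the validation tests above imply roughly this shape of checks
+# in RejectionSamplingConfig's __post_init__. The option tuples for level,
+# metric, and agg are taken from the tests; the ("mask", "clamp") action set
+# is inferred from usage and may differ in the real code. The error strings
+# only need to contain the substrings the tests match on. Sketch only.
+def _sketch_validate_config(level, action, metric, agg, lower, upper):
+    if level not in ("token", "sequence"):
+        raise ValueError(f"level must be one of ('token', 'sequence'), got {level!r}")
+    if action not in ("mask", "clamp"):
+        raise ValueError(f"action must be one of ('mask', 'clamp'), got {action!r}")
+    if metric not in ("ratio", "kl_k1", "kl_k2", "kl_k3"):
+        raise ValueError(
+            f"metric must be one of ('ratio', 'kl_k1', 'kl_k2', 'kl_k3'), got {metric!r}"
+        )
+    if agg not in ("mean", "sum", "max"):
+        raise ValueError(f"agg must be one of ('mean', 'sum', 'max'), got {agg!r}")
+    if lower is not None and upper is not None and lower > upper:
+        raise ValueError("lower cannot be greater than upper")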