diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 6716091d6f..91f9d3ba13 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1507,6 +1507,8 @@ def build_args( host: str | None = None, port: int | None = None, dist_init_addr: str | None = None, + n_nodes: int = 1, + node_rank: int = 0, ): args: dict = conf_as_dict(vllm_config) args = dict( @@ -1522,6 +1524,18 @@ def build_args( args["port"] = port if host is not None: args["host"] = host + # Multi-node support + if n_nodes > 1: + args["nnodes"] = n_nodes + args["node_rank"] = node_rank + if dist_init_addr is not None: + from areal.utils.network import split_hostport + + master_host, master_port = split_hostport(dist_init_addr) + args["master_addr"] = master_host + args["master_port"] = str(master_port) + if node_rank > 0: + args["headless"] = True return args @staticmethod @@ -1536,6 +1550,8 @@ def build_cmd( host: str | None = None, port: int | None = None, dist_init_addr: str | None = None, + n_nodes: int = 1, + node_rank: int = 0, ): args = vLLMConfig.build_args( vllm_config=vllm_config, @@ -1544,6 +1560,8 @@ def build_cmd( host=host, port=port, dist_init_addr=dist_init_addr, + n_nodes=n_nodes, + node_rank=node_rank, ) return vLLMConfig.build_cmd_from_args(args) diff --git a/areal/experimental/agent_service/README.md b/areal/experimental/agent_service/README.md index 46a8ae876f..f3dc0f839f 100644 --- a/areal/experimental/agent_service/README.md +++ b/areal/experimental/agent_service/README.md @@ -153,29 +153,45 @@ Turn 2: ``` areal/experimental/agent_service/ -├── __init__.py # Public exports +├── __init__.py # Public exports (AgentRequest, AgentResponse, etc.) 
├── README.md # This document +├── auth.py # Admin key auth helpers (hmac-safe comparison) ├── protocol.py # Gateway protocol frame types ├── types.py # AgentRequest, AgentResponse, EventEmitter, AgentRunnable +├── controller/ +│ ├── __init__.py # AgentServiceController, AgentServiceControllerConfig +│ ├── config.py # AgentServiceControllerConfig dataclass +│ └── controller.py # AgentServiceController orchestrator +├── guard/ +│ ├── __init__.py # Module docstring +│ ├── __main__.py # python -m areal.experimental.agent_service.guard +│ └── app.py # Guard Flask app (pass-through to areal.infra.rpc.guard) ├── gateway/ +│ ├── __init__.py # Public exports │ ├── __main__.py # python -m areal.experimental.agent_service.gateway │ ├── app.py # create_gateway_app() -│ └── bridge.py # OpenResponsesBridge, mount_bridge() +│ ├── bridge.py # OpenResponsesBridge, mount_bridge() +│ └── config.py # GatewayConfig dataclass ├── router/ +│ ├── __init__.py # Public exports │ ├── __main__.py # python -m areal.experimental.agent_service.router │ ├── app.py # create_router_app() -│ └── client.py # RouterClient +│ ├── client.py # RouterClient +│ └── config.py # RouterConfig dataclass ├── data_proxy/ +│ ├── __init__.py # Public exports │ ├── __main__.py # python -m areal.experimental.agent_service.data_proxy │ ├── app.py # create_data_proxy_app() -│ └── client.py # DataProxyClient +│ ├── client.py # DataProxyClient +│ └── config.py # DataProxyConfig dataclass └── worker/ + ├── __init__.py # Public exports ├── __main__.py # python -m areal.experimental.agent_service.worker - └── app.py # create_worker_app() + ├── app.py # create_worker_app() + └── config.py # WorkerConfig dataclass examples/agent_service/ -├── agent.py # Tau2Agent (PydanticAI) -├── config.yaml # Demo configuration -├── run_demo.py # One-click demo -└── README.md # Example documentation +├── agent.py # ClaudeAgent (Claude Agent SDK) +├── run_agent_service.py # Controller-based launcher + interactive demo +└── README.md # 
Example documentation ``` diff --git a/areal/experimental/agent_service/__init__.py b/areal/experimental/agent_service/__init__.py index 28059732d0..3858d5c133 100644 --- a/areal/experimental/agent_service/__init__.py +++ b/areal/experimental/agent_service/__init__.py @@ -5,83 +5,22 @@ Exposes complete agent sessions (autonomous multi-step reasoning, tool use, memory) via independent HTTP microservices: Gateway, Router, DataProxy, and Worker. -""" - -from __future__ import annotations -import importlib -from typing import TYPE_CHECKING +Submodules +---------- +- ``controller`` — :class:`AgentServiceController` orchestrator +- ``gateway`` — public HTTP/WebSocket entry point +- ``router`` — session-affine routing +- ``data_proxy`` — stateful session proxy +- ``worker`` — stateless agent execution +- ``protocol`` — WebSocket frame types and helpers +""" -from .protocol import ( - EventFrame, - Frame, - FrameType, - QueueMode, - RequestFrame, - RequestMethod, - ResponseFrame, - RunStatus, - generate_run_id, - make_complete_response, - make_delta_event, - make_failed_response, - make_tool_call_event, - parse_frame, - serialize_frame, -) from .types import AgentRequest, AgentResponse, AgentRunnable, EventEmitter -if TYPE_CHECKING: - from .data_proxy import DataProxyClient, create_data_proxy_app - from .gateway import OpenResponsesBridge, create_gateway_app, mount_bridge - from .router import RouterClient, create_router_app - from .worker import create_worker_app - -_LAZY_IMPORTS: dict[str, str] = { - "DataProxyClient": ".data_proxy", - "OpenResponsesBridge": ".gateway", - "RouterClient": ".router", - "create_data_proxy_app": ".data_proxy", - "create_gateway_app": ".gateway", - "create_router_app": ".router", - "create_worker_app": ".worker", - "mount_bridge": ".gateway", -} - - -def __getattr__(name: str): - if name in _LAZY_IMPORTS: - module = importlib.import_module(_LAZY_IMPORTS[name], __package__) - return getattr(module, name) - raise AttributeError(f"module 
{__name__!r} has no attribute {name!r}") - - __all__ = [ "AgentRequest", "AgentResponse", "AgentRunnable", - "DataProxyClient", "EventEmitter", - "EventFrame", - "Frame", - "FrameType", - "OpenResponsesBridge", - "QueueMode", - "RequestFrame", - "RequestMethod", - "ResponseFrame", - "RouterClient", - "RunStatus", - "create_data_proxy_app", - "create_gateway_app", - "create_router_app", - "create_worker_app", - "generate_run_id", - "make_complete_response", - "make_delta_event", - "make_failed_response", - "make_tool_call_event", - "mount_bridge", - "parse_frame", - "serialize_frame", ] diff --git a/areal/experimental/agent_service/auth.py b/areal/experimental/agent_service/auth.py index 0f01da4bd2..b3893f5bf2 100644 --- a/areal/experimental/agent_service/auth.py +++ b/areal/experimental/agent_service/auth.py @@ -4,9 +4,11 @@ from __future__ import annotations +import hmac + from fastapi import Header, HTTPException -DEFAULT_ADMIN_KEY = "areal-agent-admin" +DEFAULT_ADMIN_API_KEY = "areal-agent-admin" async def verify_admin_key( @@ -15,16 +17,16 @@ async def verify_admin_key( expected_key: str, ) -> None: expected = f"Bearer {expected_key}" - if authorization != expected: + if not hmac.compare_digest(authorization, expected): raise HTTPException(status_code=401, detail="Invalid admin key") -def make_admin_dependency(admin_key: str): +def make_admin_dependency(admin_api_key: str): async def _dep(authorization: str = Header(alias="Authorization")) -> None: - await verify_admin_key(authorization, expected_key=admin_key) + await verify_admin_key(authorization, expected_key=admin_api_key) return _dep -def admin_headers(admin_key: str) -> dict[str, str]: - return {"Authorization": f"Bearer {admin_key}"} +def admin_headers(admin_api_key: str) -> dict[str, str]: + return {"Authorization": f"Bearer {admin_api_key}"} diff --git a/areal/experimental/agent_service/controller/__init__.py b/areal/experimental/agent_service/controller/__init__.py new file mode 100644 index 
0000000000..3150205885 --- /dev/null +++ b/areal/experimental/agent_service/controller/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Agent Service Controller — orchestrator for agent micro-services.""" + +from .config import AgentServiceControllerConfig +from .controller import AgentServiceController + +__all__ = [ + "AgentServiceController", + "AgentServiceControllerConfig", +] diff --git a/areal/experimental/agent_service/controller/config.py b/areal/experimental/agent_service/controller/config.py new file mode 100644 index 0000000000..c316d58227 --- /dev/null +++ b/areal/experimental/agent_service/controller/config.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Configuration for the AgentServiceController.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from ..auth import DEFAULT_ADMIN_API_KEY + + +@dataclass +class AgentServiceControllerConfig: + """Unified configuration for AgentServiceController. + + Consolidates settings for the guard, router, gateway, worker, and + data proxy micro-services launched by the controller. + """ + + # -- Agent class ------------------------------------------------------- + agent_cls_path: str = "" + """Fully-qualified import path for the ``AgentRunnable`` implementation + (e.g. 
``examples.agent_service.agent.Tau2Agent``).""" + + # -- Authentication ---------------------------------------------------- + admin_api_key: str = DEFAULT_ADMIN_API_KEY + """Shared admin API key for inter-service Bearer auth.""" + + # -- Scaling ----------------------------------------------------------- + num_pairs: int = 1 + """Number of Worker+DataProxy pairs to launch on initialize.""" + + # -- Timeouts ---------------------------------------------------------- + setup_timeout: float = 120.0 + """Timeout (seconds) waiting for each service to become healthy.""" + + health_poll_interval: float = 5.0 + """Seconds between health polls for crash detection (0 = disabled).""" + + drain_timeout: float = 30.0 + """Seconds to wait for active sessions to drain before force-killing a pair.""" + + # -- Log level --------------------------------------------------------- + log_level: str = "info" + """Log level for spawned micro-services.""" + + # -- Environment ------------------------------------------------------- + env: dict[str, str] = field(default_factory=dict) + """Extra environment variables to pass to all forked child processes.""" + + def __post_init__(self) -> None: + if not self.agent_cls_path: + raise ValueError("agent_cls_path must be a non-empty import path") + if self.num_pairs < 0: + raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}") + if self.setup_timeout <= 0: + raise ValueError( + f"setup_timeout must be positive, got {self.setup_timeout}" + ) + if self.drain_timeout < 0: + raise ValueError( + f"drain_timeout must be non-negative, got {self.drain_timeout}" + ) diff --git a/areal/experimental/agent_service/controller/controller.py b/areal/experimental/agent_service/controller/controller.py new file mode 100644 index 0000000000..21b12851bb --- /dev/null +++ b/areal/experimental/agent_service/controller/controller.py @@ -0,0 +1,577 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""AgentServiceController — orchestrates agent service 
micro-services via Guards. + +Mirrors the architecture of +:class:`~areal.experimental.inference_service.controller.controller.GatewayInferenceController`: +Guard workers are created via the Scheduler, then the controller forks +Router, Worker+DataProxy pairs, and Gateway onto them via HTTP API. + +Lifecycle:: + + from areal.infra.scheduler.local import LocalScheduler + + scheduler = LocalScheduler(...) + controller = AgentServiceController(config, scheduler) + controller.initialize() + # ... run traffic ... + controller.scale_up(2) # add 2 Worker+DataProxy pairs + controller.scale_down(1) # drain + remove 1 pair + controller.destroy() +""" + +from __future__ import annotations + +import sys +import threading +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import requests + +from areal.experimental.agent_service.controller.config import ( + AgentServiceControllerConfig, +) +from areal.utils import logging +from areal.utils.network import format_hostport + +if TYPE_CHECKING: + from areal.api.scheduler_api import Scheduler, Worker + +logger = logging.getLogger("AgentServiceController") + +_GUARD_ROLE = "agent-guard" +_UNREGISTER_RETRIES = 3 +_HEALTH_CHECK_WORKERS = 4 + + +@dataclass +class _WorkerPair: + pair_index: int + guard_addr: str + worker_host: str + worker_port: int + proxy_host: str + proxy_port: int + proxy_addr: str + worker_addr: str + + +class AgentServiceController: + """Orchestrator for the Agent Service micro-service stack. + + Parameters + ---------- + config: + Controller configuration. + scheduler: + Scheduler instance used to create and manage Guard workers. 
+ """ + + def __init__( + self, + config: AgentServiceControllerConfig, + scheduler: Scheduler, + ) -> None: + self.config = config + self.scheduler = scheduler + + self._guard_addrs: list[str] = [] + self._workers: list[Worker] = [] + self._service_roles: list[str] = [] + + self._router_addr: str = "" + self._gateway_addr: str = "" + + self._pairs: dict[int, _WorkerPair] = {} + self._pairs_lock = threading.Lock() + self._next_pair_index: int = 0 + + self._forked_services: list[tuple[str, str, int]] = [] + + self._health_stop = threading.Event() + self._health_thread: threading.Thread | None = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def initialize(self) -> None: + """Launch the full micro-service stack. + + Order: Guards (via scheduler) → Router → Worker+DataProxy pairs → + register → Gateway → health monitor. + On failure, already-forked services are cleaned up via destroy(). 
+ """ + try: + self._do_initialize() + except Exception: + logger.error("initialize() failed, rolling back...") + self.destroy() + raise + + def _do_initialize(self) -> None: + from areal.api.cli_args import SchedulingSpec, SchedulingStrategy + from areal.api.scheduler_api import Job + + cfg = self.config + + # Step 1: Create Guard workers via scheduler + guard_spec = SchedulingSpec( + gpu=0, cmd=f"{sys.executable} -m areal.experimental.agent_service.guard" + ) + num_guards = max(cfg.num_pairs, 1) + guard_job = Job( + role=_GUARD_ROLE, + replicas=num_guards, + tasks=[guard_spec for _ in range(num_guards)], + scheduling_strategy=SchedulingStrategy(), + ) + self.scheduler.create_workers(job=guard_job) + self._service_roles.append(_GUARD_ROLE) + + self._workers = self.scheduler.get_workers(role=_GUARD_ROLE) + self._guard_addrs = [ + f"http://{format_hostport(w.ip, int(w.worker_ports[0]))}" + for w in self._workers + ] + logger.info("Guards ready: %s", self._guard_addrs) + + # Step 2: Fork Router on guard[0] + guard_0 = self._guard_addrs[0] + router_cmd = [ + sys.executable, + "-m", + "areal.experimental.agent_service.router", + "--admin-api-key", + cfg.admin_api_key, + ] + router_host, router_port = self._fork_on_guard( + guard_addr=guard_0, + role="agent-router", + worker_index=0, + raw_cmd=router_cmd, + ) + self._router_addr = f"http://{format_hostport(router_host, router_port)}" + logger.info("Router: %s", self._router_addr) + + # Step 3: Fork Worker+DataProxy pairs + self.scale_up(cfg.num_pairs) + + # Step 4: Fork Gateway on guard[0] + gw_cmd = [ + sys.executable, + "-m", + "areal.experimental.agent_service.gateway", + "--router-addr", + self._router_addr, + "--admin-api-key", + cfg.admin_api_key, + ] + gw_host, gw_port = self._fork_on_guard( + guard_addr=guard_0, + role="agent-gateway", + worker_index=0, + raw_cmd=gw_cmd, + ) + self._gateway_addr = f"http://{format_hostport(gw_host, gw_port)}" + logger.info("Gateway: %s", self._gateway_addr) + + # Step 5: Start 
health monitor + if cfg.health_poll_interval > 0: + self._health_stop.clear() + self._health_thread = threading.Thread( + target=self._health_monitor_loop, daemon=True + ) + self._health_thread.start() + + def destroy(self) -> None: + """Tear down all services in reverse order.""" + self._stop_health_monitor() + + for guard_addr, role, worker_index in reversed(self._forked_services): + try: + self._kill_forked_service(guard_addr, role, worker_index) + except requests.RequestException: + logger.error( + "Error killing forked service %s/%d: %s", + role, + worker_index, + traceback.format_exc(), + ) + self._forked_services.clear() + + for role in reversed(self._service_roles): + try: + self.scheduler.delete_workers(role=role) + logger.info("Workers deleted for role: %s", role) + except Exception: + logger.error( + "Error deleting workers for role %s: %s", + role, + traceback.format_exc(), + ) + self._service_roles.clear() + self._workers.clear() + self._guard_addrs.clear() + with self._pairs_lock: + self._pairs.clear() + self._router_addr = "" + self._gateway_addr = "" + + def scale_up(self, count: int) -> list[int]: + """Add *count* Worker+DataProxy pairs. + + Pairs are distributed across guards round-robin. + Returns the pair indices that were created. 
+ """ + cfg = self.config + created: list[int] = [] + + for _ in range(count): + pair_index = self._next_pair_index + self._next_pair_index += 1 + + guard_addr = self._guard_addrs[pair_index % len(self._guard_addrs)] + + worker_cmd = [ + sys.executable, + "-m", + "areal.experimental.agent_service.worker", + "--agent", + cfg.agent_cls_path, + "--log-level", + cfg.log_level, + ] + worker_host, worker_port = self._fork_on_guard( + guard_addr=guard_addr, + role=f"agent-worker-{pair_index}", + worker_index=pair_index, + raw_cmd=worker_cmd, + ) + worker_addr = f"http://{format_hostport(worker_host, worker_port)}" + + proxy_cmd = [ + sys.executable, + "-m", + "areal.experimental.agent_service.data_proxy", + "--worker-addr", + worker_addr, + ] + proxy_host, proxy_port = self._fork_on_guard( + guard_addr=guard_addr, + role=f"agent-proxy-{pair_index}", + worker_index=pair_index, + raw_cmd=proxy_cmd, + ) + proxy_addr = f"http://{format_hostport(proxy_host, proxy_port)}" + + pair = _WorkerPair( + pair_index=pair_index, + guard_addr=guard_addr, + worker_host=worker_host, + worker_port=worker_port, + proxy_host=proxy_host, + proxy_port=proxy_port, + proxy_addr=proxy_addr, + worker_addr=worker_addr, + ) + + try: + self._register_proxy(proxy_addr) + except Exception: + logger.error( + "Failed to register pair %d, cleaning up forked processes", + pair_index, + ) + self._cleanup_pair_forks(pair_index, guard_addr) + raise + + with self._pairs_lock: + self._pairs[pair_index] = pair + created.append(pair_index) + + logger.info( + "Pair %d: worker=%s proxy=%s", pair_index, worker_addr, proxy_addr + ) + + return created + + def scale_down(self, count: int) -> list[int]: + """Remove *count* pairs (LIFO order). + + For each pair: unregister from Router (with retry) → drain active + sessions → kill DataProxy → kill Worker. + Returns the pair indices that were removed. 
+ """ + removed: list[int] = [] + + with self._pairs_lock: + indices = sorted(self._pairs.keys(), reverse=True)[:count] + + for pair_index in indices: + with self._pairs_lock: + pair = self._pairs.get(pair_index) + if pair is None: + continue + + try: + self._unregister_proxy(pair.proxy_addr) + except requests.RequestException: + logger.error( + "Unregister failed for pair %d after retries, skipping", + pair_index, + ) + continue + + self._drain_proxy(pair.proxy_addr) + + with self._pairs_lock: + self._pairs.pop(pair_index, None) + + proxy_key = (pair.guard_addr, f"agent-proxy-{pair_index}", pair_index) + worker_key = (pair.guard_addr, f"agent-worker-{pair_index}", pair_index) + + for guard_addr, role, wi in [proxy_key, worker_key]: + try: + self._kill_forked_service(guard_addr, role, wi) + entry = (guard_addr, role, wi) + if entry in self._forked_services: + self._forked_services.remove(entry) + except requests.RequestException: + logger.warning( + "Failed to kill %s/%d: %s", + role, + wi, + traceback.format_exc(), + ) + + removed.append(pair_index) + logger.info("Removed pair %d", pair_index) + + return removed + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def router_addr(self) -> str: + return self._router_addr + + @property + def gateway_addr(self) -> str: + return self._gateway_addr + + @property + def pairs(self) -> dict[int, _WorkerPair]: + with self._pairs_lock: + return dict(self._pairs) + + # ------------------------------------------------------------------ + # Guard interaction helpers + # ------------------------------------------------------------------ + + def _fork_on_guard( + self, + guard_addr: str, + role: str, + worker_index: int, + raw_cmd: list[str], + health_path: str = "/health", + env: dict[str, str] | None = None, + ) -> tuple[str, int]: + resp = requests.post( + f"{guard_addr}/alloc_ports", + json={"count": 1}, 
+ timeout=30, + ) + resp.raise_for_status() + port_data = resp.json() + host = port_data["host"] + port = port_data["ports"][0] + + cmd = list(raw_cmd) + ["--host", host, "--port", str(port)] + + merged_env = {**self.config.env, **(env or {})} + + fork_payload: dict[str, Any] = { + "role": role, + "worker_index": worker_index, + "raw_cmd": cmd, + } + if merged_env: + fork_payload["env"] = merged_env + + resp = requests.post( + f"{guard_addr}/fork", + json=fork_payload, + timeout=30, + ) + resp.raise_for_status() + + self._forked_services.append((guard_addr, role, worker_index)) + + addr = f"http://{format_hostport(host, port)}" + self._wait_for_service(f"{addr}{health_path}", role) + + return host, port + + def _cleanup_pair_forks(self, pair_index: int, guard_addr: str) -> None: + for role_prefix in ("agent-proxy-", "agent-worker-"): + role = f"{role_prefix}{pair_index}" + entry = (guard_addr, role, pair_index) + if entry in self._forked_services: + try: + self._kill_forked_service(guard_addr, role, pair_index) + except requests.RequestException: + pass + self._forked_services.remove(entry) + + def _kill_forked_service( + self, guard_addr: str, role: str, worker_index: int + ) -> None: + try: + resp = requests.post( + f"{guard_addr}/kill_forked_worker", + json={"role": role, "worker_index": worker_index}, + timeout=10, + ) + if resp.status_code == 200: + logger.info("Killed forked service %s/%d", role, worker_index) + else: + logger.warning( + "Failed to kill forked service %s/%d: %s", + role, + worker_index, + resp.text, + ) + except requests.RequestException as exc: + logger.error( + "Error killing forked service %s/%d: %s", role, worker_index, exc + ) + + def _wait_for_service( + self, url: str, name: str, timeout: float | None = None + ) -> None: + timeout = timeout or self.config.setup_timeout + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = requests.get(url, timeout=2) + if resp.status_code == 200: + logger.info("%s 
healthy at %s", name, url) + return + except requests.RequestException: + pass + time.sleep(0.5) + raise TimeoutError(f"{name} did not become healthy at {url} within {timeout}s") + + def _register_proxy(self, proxy_addr: str) -> None: + """Raises on failure so that ``scale_up`` callers know the pair is + non-functional. + """ + if not self._router_addr: + return + resp = requests.post( + f"{self._router_addr}/register", + json={"addr": proxy_addr}, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, + timeout=10, + ) + resp.raise_for_status() + logger.info("Registered proxy %s with Router", proxy_addr) + + def _drain_proxy(self, proxy_addr: str) -> None: + timeout = self.config.drain_timeout + if timeout <= 0: + return + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = requests.get(f"{proxy_addr}/health", timeout=2) + if resp.status_code == 200: + active = resp.json().get("active_sessions", 0) + if active == 0: + logger.info("Proxy %s drained", proxy_addr) + return + logger.debug( + "Proxy %s draining: %d active sessions", proxy_addr, active + ) + except requests.RequestException: + break + time.sleep(1.0) + logger.warning( + "Proxy %s drain timed out after %.0fs, force-killing", proxy_addr, timeout + ) + + def _check_pair_health(self, pair_index: int, proxy_addr: str) -> None: + try: + resp = requests.get(f"{proxy_addr}/health", timeout=2) + if resp.status_code != 200: + logger.warning( + "Pair %d proxy %s returned %d", + pair_index, + proxy_addr, + resp.status_code, + ) + except requests.RequestException: + logger.warning("Pair %d proxy %s unreachable", pair_index, proxy_addr) + + def _health_monitor_loop(self) -> None: + interval = self.config.health_poll_interval + while not self._health_stop.wait(timeout=interval): + with self._pairs_lock: + snapshot = list(self._pairs.items()) + if not snapshot: + continue + with ThreadPoolExecutor( + max_workers=min(_HEALTH_CHECK_WORKERS, len(snapshot)) + ) as pool: + 
futures = { + pool.submit(self._check_pair_health, idx, pair.proxy_addr): idx + for idx, pair in snapshot + } + for future in as_completed(futures, timeout=10): + try: + future.result() + except Exception: + pass + + def _stop_health_monitor(self) -> None: + self._health_stop.set() + if self._health_thread is not None: + self._health_thread.join(timeout=5) + self._health_thread = None + + def _unregister_proxy(self, proxy_addr: str) -> None: + """Unregister with retry. Raises after all retries exhausted.""" + if not self._router_addr: + return + last_exc: Exception | None = None + for attempt in range(_UNREGISTER_RETRIES): + try: + resp = requests.post( + f"{self._router_addr}/unregister", + json={"addr": proxy_addr}, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, + timeout=5, + ) + resp.raise_for_status() + logger.info("Unregistered proxy %s", proxy_addr) + return + except requests.RequestException as exc: + last_exc = exc + logger.warning( + "Unregister proxy %s attempt %d/%d failed: %s", + proxy_addr, + attempt + 1, + _UNREGISTER_RETRIES, + exc, + ) + if attempt < _UNREGISTER_RETRIES - 1: + time.sleep(1.0) + raise last_exc # type: ignore[misc] diff --git a/areal/experimental/agent_service/data_proxy/__main__.py b/areal/experimental/agent_service/data_proxy/__main__.py index bda8e6164e..c856bac91b 100644 --- a/areal/experimental/agent_service/data_proxy/__main__.py +++ b/areal/experimental/agent_service/data_proxy/__main__.py @@ -2,21 +2,41 @@ """``python -m areal.experimental.agent_service.data_proxy``""" -from .app import create_data_proxy_app +import argparse -if __name__ == "__main__": - import argparse +import uvicorn + +from .app import create_data_proxy_app +from .config import DataProxyConfig - import uvicorn +def main() -> None: parser = argparse.ArgumentParser(description="Agent DataProxy") parser.add_argument("--worker-addr", required=True, help="Worker HTTP address") parser.add_argument("--host", default="0.0.0.0") 
parser.add_argument("--port", type=int, default=9100) + parser.add_argument("--request-timeout", type=float, default=600.0) + parser.add_argument("--session-timeout", type=int, default=3600) + parser.add_argument( + "--log-level", choices=["debug", "info", "warning", "error"], default="info" + ) args = parser.parse_args() - uvicorn.run( - create_data_proxy_app(worker_addr=args.worker_addr), + config = DataProxyConfig( host=args.host, port=args.port, + worker_addr=args.worker_addr, + request_timeout=args.request_timeout, + session_timeout=args.session_timeout, + log_level=args.log_level, ) + uvicorn.run( + create_data_proxy_app(config), + host=config.host, + port=config.port, + log_level=config.log_level, + ) + + +if __name__ == "__main__": + main() diff --git a/areal/experimental/agent_service/data_proxy/app.py b/areal/experimental/agent_service/data_proxy/app.py index 0e05c4b392..fbf51646ba 100644 --- a/areal/experimental/agent_service/data_proxy/app.py +++ b/areal/experimental/agent_service/data_proxy/app.py @@ -14,6 +14,8 @@ from areal.utils import logging +from .config import DataProxyConfig + logger = logging.getLogger("AgentDataProxy") @@ -24,23 +26,31 @@ class _SessionData: last_active: float = field(default_factory=time.monotonic) -def create_data_proxy_app( - worker_addr: str, - session_timeout: int = 3600, -) -> FastAPI: +def create_data_proxy_app(config: DataProxyConfig) -> FastAPI: app = FastAPI(title="AReaL Data Proxy") sessions: dict[str, _SessionData] = {} - http_client = httpx.AsyncClient(timeout=600.0) + http_client = httpx.AsyncClient(timeout=config.request_timeout) + + async def _close_worker_session(session_key: str) -> None: + try: + await http_client.post( + f"{config.worker_addr}/session/{session_key}/close", timeout=5 + ) + except Exception: + logger.debug("Failed to close worker session %s", session_key) async def _reap_idle_sessions() -> None: while True: await asyncio.sleep(60) now = time.monotonic() stale = [ - k for k, s in 
sessions.items() if now - s.last_active > session_timeout + k + for k, s in sessions.items() + if now - s.last_active > config.session_timeout ] for k in stale: del sessions[k] + await _close_worker_session(k) if stale: logger.info("Reaped %d idle sessions", len(stale)) @@ -57,17 +67,11 @@ async def health(): return { "status": "ok", "active_sessions": len(sessions), - "worker_addr": worker_addr, + "worker_addr": config.worker_addr, } @app.post("/session/{session_key}/turn") async def turn(session_key: str, body: dict[str, Any]): - """Process one turn. session_key must be unique per agent session. - - When used with the rollout service, uniqueness is ensured by - ``/rl/start_session``. When used standalone, callers must - generate unique keys (e.g. ``f"{model}:{user_id}"``). - """ session = sessions.get(session_key) if session is None: session = _SessionData() @@ -87,7 +91,7 @@ async def turn(session_key: str, body: dict[str, Any]): "metadata": metadata, } - resp = await http_client.post(f"{worker_addr}/run", json=worker_request) + resp = await http_client.post(f"{config.worker_addr}/run", json=worker_request) resp.raise_for_status() result = resp.json() @@ -138,6 +142,7 @@ async def turn(session_key: str, body: dict[str, Any]): @app.post("/session/{session_key}/close") async def close_session(session_key: str): sessions.pop(session_key, None) + await _close_worker_session(session_key) return {"status": "ok"} @app.get("/session/{session_key}/history") diff --git a/areal/experimental/agent_service/data_proxy/config.py b/areal/experimental/agent_service/data_proxy/config.py new file mode 100644 index 0000000000..45e5a994ee --- /dev/null +++ b/areal/experimental/agent_service/data_proxy/config.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class DataProxyConfig: + host: str = "0.0.0.0" + port: int = 9100 + worker_addr: str = "http://localhost:9000" + 
request_timeout: float = 600.0 + session_timeout: int = 3600 + log_level: str = "info" diff --git a/areal/experimental/agent_service/gateway/__main__.py b/areal/experimental/agent_service/gateway/__main__.py index a8923d42ef..bc02f532f2 100644 --- a/areal/experimental/agent_service/gateway/__main__.py +++ b/areal/experimental/agent_service/gateway/__main__.py @@ -6,8 +6,10 @@ import uvicorn +from ..auth import DEFAULT_ADMIN_API_KEY from .app import create_gateway_app from .bridge import OpenResponsesBridge, mount_bridge +from .config import GatewayConfig def main() -> None: @@ -15,16 +17,32 @@ def main() -> None: parser.add_argument("--router-addr", required=True, help="Router HTTP address") parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", type=int, default=8080) - parser.add_argument("--admin-key", default="areal-agent-admin") + parser.add_argument("--admin-api-key", default=DEFAULT_ADMIN_API_KEY) + parser.add_argument("--router-timeout", type=float, default=2.0) + parser.add_argument("--forward-timeout", type=float, default=120.0) + parser.add_argument( + "--log-level", choices=["debug", "info", "warning", "error"], default="info" + ) args = parser.parse_args() - app = create_gateway_app(router_addr=args.router_addr, admin_key=args.admin_key) + config = GatewayConfig( + host=args.host, + port=args.port, + admin_api_key=args.admin_api_key, + router_addr=args.router_addr, + router_timeout=args.router_timeout, + forward_timeout=args.forward_timeout, + log_level=args.log_level, + ) + app = create_gateway_app(config) mount_bridge( app, - OpenResponsesBridge(router_addr=args.router_addr, admin_key=args.admin_key), - admin_key=args.admin_key, + OpenResponsesBridge( + router_addr=config.router_addr, admin_api_key=config.admin_api_key + ), + admin_api_key=config.admin_api_key, ) - uvicorn.run(app, host=args.host, port=args.port) + uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level) if __name__ == "__main__": diff 
--git a/areal/experimental/agent_service/gateway/app.py b/areal/experimental/agent_service/gateway/app.py index 0f1c3d1dc0..b3043ad57a 100644 --- a/areal/experimental/agent_service/gateway/app.py +++ b/areal/experimental/agent_service/gateway/app.py @@ -4,6 +4,7 @@ from __future__ import annotations +import hmac import json import traceback @@ -12,7 +13,7 @@ from areal.utils import logging -from ..auth import DEFAULT_ADMIN_KEY, admin_headers +from ..auth import admin_headers from ..protocol import ( FrameType, RequestFrame, @@ -26,6 +27,7 @@ parse_frame, serialize_frame, ) +from .config import GatewayConfig logger = logging.getLogger("AgentGateway") @@ -41,16 +43,17 @@ def _make_accepted_json(request_id: str, run_id: str) -> str: ) -def create_gateway_app(router_addr: str, admin_key: str = DEFAULT_ADMIN_KEY) -> FastAPI: +def create_gateway_app(config: GatewayConfig) -> FastAPI: app = FastAPI(title="AReaL Agent Gateway") - http_client = httpx.AsyncClient(timeout=600.0) - _auth_headers = admin_headers(admin_key) + http_client = httpx.AsyncClient(timeout=config.forward_timeout) + _auth_headers = admin_headers(config.admin_api_key) async def _route(session_key: str) -> str: resp = await http_client.post( - f"{router_addr}/route", + f"{config.router_addr}/route", json={"session_key": session_key}, headers=_auth_headers, + timeout=config.router_timeout, ) resp.raise_for_status() return resp.json()["data_proxy_addr"] @@ -81,7 +84,7 @@ async def health(): @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, token: str = Query(default="")): - if token != admin_key: + if not hmac.compare_digest(token, config.admin_api_key): await websocket.close(code=4001, reason="Invalid admin key") return await websocket.accept() diff --git a/areal/experimental/agent_service/gateway/bridge.py b/areal/experimental/agent_service/gateway/bridge.py index a29c342f29..e47c8d8f35 100644 --- a/areal/experimental/agent_service/gateway/bridge.py +++ 
b/areal/experimental/agent_service/gateway/bridge.py @@ -14,7 +14,7 @@ from areal.utils import logging -from ..auth import DEFAULT_ADMIN_KEY, admin_headers, make_admin_dependency +from ..auth import DEFAULT_ADMIN_API_KEY, admin_headers, make_admin_dependency from ..protocol import generate_run_id logger = logging.getLogger("AgentBridge") @@ -26,9 +26,11 @@ async def handle_request(self, request: Request) -> Any: ... class OpenResponsesBridge(AgentBridge): - def __init__(self, router_addr: str, admin_key: str = DEFAULT_ADMIN_KEY) -> None: + def __init__( + self, router_addr: str, admin_api_key: str = DEFAULT_ADMIN_API_KEY + ) -> None: self._router_addr = router_addr - self._auth_headers = admin_headers(admin_key) + self._auth_headers = admin_headers(admin_api_key) self._http = httpx.AsyncClient(timeout=600.0) async def close(self) -> None: @@ -157,9 +159,9 @@ def _derive_session_key(user: str, model: str) -> str: def mount_bridge( app: FastAPI, bridge: OpenResponsesBridge, - admin_key: str = DEFAULT_ADMIN_KEY, + admin_api_key: str = DEFAULT_ADMIN_API_KEY, ) -> None: - auth = make_admin_dependency(admin_key) + auth = make_admin_dependency(admin_api_key) @app.post("/v1/responses", dependencies=[Depends(auth)]) async def responses_endpoint(request: Request): diff --git a/areal/experimental/agent_service/gateway/config.py b/areal/experimental/agent_service/gateway/config.py new file mode 100644 index 0000000000..f7ec950fc2 --- /dev/null +++ b/areal/experimental/agent_service/gateway/config.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +from ..auth import DEFAULT_ADMIN_API_KEY + + +@dataclass +class GatewayConfig: + host: str = "0.0.0.0" + port: int = 8080 + admin_api_key: str = DEFAULT_ADMIN_API_KEY + router_addr: str = "http://localhost:8081" + router_timeout: float = 2.0 + forward_timeout: float = 120.0 + log_level: str = "info" diff --git 
a/areal/experimental/agent_service/guard/__init__.py b/areal/experimental/agent_service/guard/__init__.py new file mode 100644 index 0000000000..57f50162ec --- /dev/null +++ b/areal/experimental/agent_service/guard/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Agent Service Guard — process supervisor backed by the shared guard. + +Pure pass-through to ``areal.infra.rpc.guard``. All orchestration logic +(launching Router, Gateway, Worker+DataProxy pairs) lives in the +:mod:`~areal.experimental.agent_service.controller` module. + +Quick start:: + + python -m areal.experimental.agent_service.guard \\ + --experiment-name demo --trial-name run0 \\ + --role agent-guard --worker-index 0 +""" diff --git a/areal/experimental/agent_service/guard/__main__.py b/areal/experimental/agent_service/guard/__main__.py new file mode 100644 index 0000000000..d311f6023b --- /dev/null +++ b/areal/experimental/agent_service/guard/__main__.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""CLI entrypoint: ``python -m areal.experimental.agent_service.guard``""" + +from __future__ import annotations + +from areal.experimental.agent_service.guard.app import ( + _state, + app, +) +from areal.infra.rpc.guard.app import ( + configure_state_from_args, + make_base_parser, + run_server, +) + + +def main(): + parser = make_base_parser( + description="AReaL Agent Service Guard — process supervisor for agent workers" + ) + args, _ = parser.parse_known_args() + + bind_host = configure_state_from_args(_state, args) + + run_server(_state, app, bind_host, args.port) + + +if __name__ == "__main__": + main() diff --git a/areal/experimental/agent_service/guard/app.py b/areal/experimental/agent_service/guard/app.py new file mode 100644 index 0000000000..b137feef91 --- /dev/null +++ b/areal/experimental/agent_service/guard/app.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Agent Service Guard backed by the shared guard infrastructure. 
+ +All core guard functionality (port allocation, process forking, health +checks, cleanup) is provided by ``areal.infra.rpc.guard``. This module +creates and exposes the Flask app and shared state instance, following +the same pattern as ``areal.experimental.inference_service.guard``. +""" + +from __future__ import annotations + +from areal.infra.rpc.guard.app import ( + GuardState, + create_app, +) +from areal.infra.rpc.guard.app import ( + cleanup_forked_children as _cleanup_impl, +) +from areal.utils import logging + +logger = logging.getLogger("AgentGuard") + +_state = GuardState() + +app = create_app(_state) + + +def cleanup_forked_children() -> None: + _cleanup_impl(_state) diff --git a/areal/experimental/agent_service/router/__main__.py b/areal/experimental/agent_service/router/__main__.py index d52f77392d..f0203a2f37 100644 --- a/areal/experimental/agent_service/router/__main__.py +++ b/areal/experimental/agent_service/router/__main__.py @@ -6,18 +6,36 @@ import uvicorn +from ..auth import DEFAULT_ADMIN_API_KEY from .app import create_router_app +from .config import RouterConfig def main() -> None: parser = argparse.ArgumentParser(description="Agent Router") parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", type=int, default=8081) - parser.add_argument("--admin-key", default="areal-agent-admin") + parser.add_argument("--admin-api-key", default=DEFAULT_ADMIN_API_KEY) + parser.add_argument("--poll-interval", type=float, default=5.0) + parser.add_argument("--worker-health-timeout", type=float, default=2.0) + parser.add_argument( + "--log-level", choices=["debug", "info", "warning", "error"], default="info" + ) args = parser.parse_args() + config = RouterConfig( + host=args.host, + port=args.port, + admin_api_key=args.admin_api_key, + poll_interval=args.poll_interval, + worker_health_timeout=args.worker_health_timeout, + log_level=args.log_level, + ) uvicorn.run( - create_router_app(admin_key=args.admin_key), host=args.host, 
port=args.port + create_router_app(config), + host=config.host, + port=config.port, + log_level=config.log_level, ) diff --git a/areal/experimental/agent_service/router/app.py b/areal/experimental/agent_service/router/app.py index 980ee3b0ff..222c8a6c17 100644 --- a/areal/experimental/agent_service/router/app.py +++ b/areal/experimental/agent_service/router/app.py @@ -12,14 +12,15 @@ from areal.utils import logging -from ..auth import DEFAULT_ADMIN_KEY, make_admin_dependency +from ..auth import make_admin_dependency +from .config import RouterConfig logger = logging.getLogger("AgentRouter") -def create_router_app(admin_key: str = DEFAULT_ADMIN_KEY) -> FastAPI: +def create_router_app(config: RouterConfig) -> FastAPI: app = FastAPI(title="AReaL Agent Router") - auth = make_admin_dependency(admin_key) + auth = make_admin_dependency(config.admin_api_key) registered_proxies: list[str] = [] session_map: dict[str, str] = {} diff --git a/areal/experimental/agent_service/router/client.py b/areal/experimental/agent_service/router/client.py index 7d96a646be..4c5e129be9 100644 --- a/areal/experimental/agent_service/router/client.py +++ b/areal/experimental/agent_service/router/client.py @@ -4,13 +4,15 @@ import httpx -from ..auth import DEFAULT_ADMIN_KEY, admin_headers +from ..auth import DEFAULT_ADMIN_API_KEY, admin_headers class RouterClient: - def __init__(self, router_addr: str, admin_key: str = DEFAULT_ADMIN_KEY) -> None: + def __init__( + self, router_addr: str, admin_api_key: str = DEFAULT_ADMIN_API_KEY + ) -> None: self._addr = router_addr - self._headers = admin_headers(admin_key) + self._headers = admin_headers(admin_api_key) self._http = httpx.AsyncClient(timeout=30.0) async def register(self, addr: str) -> None: diff --git a/areal/experimental/agent_service/router/config.py b/areal/experimental/agent_service/router/config.py new file mode 100644 index 0000000000..ed36e50244 --- /dev/null +++ b/areal/experimental/agent_service/router/config.py @@ -0,0 +1,17 @@ +# 
SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +from ..auth import DEFAULT_ADMIN_API_KEY + + +@dataclass +class RouterConfig: + host: str = "0.0.0.0" + port: int = 8081 + admin_api_key: str = DEFAULT_ADMIN_API_KEY + poll_interval: float = 5.0 + worker_health_timeout: float = 2.0 + log_level: str = "info" diff --git a/areal/experimental/agent_service/worker/__main__.py b/areal/experimental/agent_service/worker/__main__.py index 6077390589..c14d52eba7 100644 --- a/areal/experimental/agent_service/worker/__main__.py +++ b/areal/experimental/agent_service/worker/__main__.py @@ -1,54 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 -"""``python -m areal.experimental.agent_service.worker``""" +"""``python -m areal.experimental.agent_service.worker`` + +Start a standalone Agent Worker process. The Controller forks this +via Guard to create Worker+DataProxy pairs. + + python -m areal.experimental.agent_service.worker \ + --agent examples.agent_service.agent.ClaudeAgent \ + --host 127.0.0.1 --port 9000 +""" import argparse -import asyncio -import threading -import httpx import uvicorn -from areal.utils.network import format_hostport - from .app import create_worker_app def main() -> None: - from ..data_proxy import create_data_proxy_app - - parser = argparse.ArgumentParser(description="Agent Worker + DataProxy") + parser = argparse.ArgumentParser(description="Agent Worker") parser.add_argument("--agent", required=True, help="Agent import path") - parser.add_argument("--router-addr", required=True, help="Router HTTP address") - parser.add_argument("--worker-port", type=int, default=9000) - parser.add_argument("--proxy-port", type=int, default=9100) parser.add_argument("--host", default="0.0.0.0") - parser.add_argument("--admin-key", default="areal-agent-admin") + parser.add_argument("--port", type=int, default=9000) + parser.add_argument( + "--log-level", choices=["debug", "info", "warning", "error"], 
default="info" + ) args = parser.parse_args() - worker_addr = f"http://{format_hostport(args.host, args.worker_port)}" - proxy_addr = f"http://{format_hostport(args.host, args.proxy_port)}" - - worker_app = create_worker_app(args.agent) - proxy_app = create_data_proxy_app(worker_addr=worker_addr) - - def run_worker(): - uvicorn.run(worker_app, host=args.host, port=args.worker_port, log_level="info") - - threading.Thread(target=run_worker, daemon=True).start() - - from ..auth import admin_headers - - async def register(): - async with httpx.AsyncClient() as client: - await client.post( - f"{args.router_addr}/register", - json={"addr": proxy_addr}, - headers=admin_headers(args.admin_key), - ) - - asyncio.run(register()) - uvicorn.run(proxy_app, host=args.host, port=args.proxy_port, log_level="info") + app = create_worker_app(args.agent) + uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level) if __name__ == "__main__": diff --git a/areal/experimental/agent_service/worker/app.py b/areal/experimental/agent_service/worker/app.py index 13086653e9..55507bc558 100644 --- a/areal/experimental/agent_service/worker/app.py +++ b/areal/experimental/agent_service/worker/app.py @@ -52,6 +52,19 @@ def create_worker_app( async def health(): return {"status": "ok"} + @app.post("/session/{session_key}/close") + async def close_session(session_key: str): + close_fn = getattr(agent, "close_session", None) + if close_fn is not None: + await close_fn(session_key) + return {"status": "ok"} + + @app.on_event("shutdown") + async def shutdown(): + close_all_fn = getattr(agent, "close_all_sessions", None) + if close_all_fn is not None: + await close_all_fn() + @app.post("/run") async def run(body: dict[str, Any]): request = AgentRequest( diff --git a/areal/experimental/agent_service/worker/config.py b/areal/experimental/agent_service/worker/config.py new file mode 100644 index 0000000000..f3704f0420 --- /dev/null +++ b/areal/experimental/agent_service/worker/config.py @@ 
-0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class WorkerConfig: + host: str = "0.0.0.0" + port: int = 9000 + agent_cls_path: str = "" + log_level: str = "info" diff --git a/areal/experimental/inference_service/controller/config.py b/areal/experimental/inference_service/controller/config.py index 4f0574e212..2b8832a722 100644 --- a/areal/experimental/inference_service/controller/config.py +++ b/areal/experimental/inference_service/controller/config.py @@ -51,6 +51,7 @@ class GatewayControllerConfig: backend: str = "sglang:d1" scheduling_spec: tuple = field(default_factory=tuple) pause_grace_period: float = 0.5 + n_gpus_per_node: int | None = None # GPUs per physical node; None = single-node # -- OpenAI proxy configuration (for agent-like workflows) --------------- openai: OpenAIProxyConfig = field(default_factory=lambda: OpenAIProxyConfig()) diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index ecb2d67ab8..70bdfba148 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ -90,6 +90,24 @@ def __init__( # Parse allocation from config.backend self.rollout_alloc = ModelAllocation.from_str(config.backend) + # Multi-node: derive nnodes_per_instance from n_gpus_per_node + total_gpus = ( + self.rollout_alloc.parallel.tp_size * self.rollout_alloc.parallel.pp_size + ) + n_gpus_per_node = config.n_gpus_per_node + if n_gpus_per_node is None: + nnodes_per_instance = 1 + else: + if n_gpus_per_node < 1: + raise ValueError(f"n_gpus_per_node must be >= 1, got {n_gpus_per_node}") + if total_gpus % n_gpus_per_node != 0: + raise ValueError( + f"tp_size * pp_size ({total_gpus}) must be divisible " + f"by n_gpus_per_node ({n_gpus_per_node})" + ) + nnodes_per_instance = total_gpus // n_gpus_per_node + 
self._nnodes_per_instance = nnodes_per_instance + # Worker management self.workers: list[Worker] = [] self.server_infos: list[LocalInfServerInfo] = [] @@ -224,27 +242,32 @@ async def _async_initialize( inf_backend = alloc.backend # ================================================================== - # Step 0: Always create dp_size RPCGuard workers + # Step 0: Create RPCGuard workers (dp_size × nnodes_per_instance) # ================================================================== inf_spec = SchedulingSpec(**asdict(cfg.scheduling_spec[0])) instance_size = alloc.parallel.tp_size * alloc.parallel.pp_size + nnodes_per_instance = self._nnodes_per_instance + gpus_per_worker = instance_size // nnodes_per_instance + if server_infos is not None: - # Pre-existing inference servers — RPCGuard workers only host - # CPU services (data proxy, router, gateway), no GPUs needed. + # Pre-existing inference servers — only need dp_size workers + # for CPU services (data proxy, router, gateway), no GPUs. + total_workers = dp_size inf_spec.gpu = 0 else: - inf_spec.cpu *= instance_size - inf_spec.mem *= instance_size + total_workers = dp_size * nnodes_per_instance + inf_spec.cpu *= gpus_per_worker + inf_spec.mem *= gpus_per_worker if inf_spec.gpu > 0: - inf_spec.gpu = instance_size + inf_spec.gpu = gpus_per_worker # Override cmd to launch RPCGuard instead of RPC server inf_spec.cmd = "python -m areal.experimental.inference_service.guard" inf_role = f"{self._worker_role}{self._INF_SUFFIX}" inf_job = Job( - replicas=dp_size, - tasks=[inf_spec for _ in range(dp_size)], + replicas=total_workers, + tasks=[inf_spec for _ in range(total_workers)], scheduling_strategy=SchedulingStrategy(), role=inf_role, ) @@ -252,6 +275,11 @@ async def _async_initialize( self.scheduler.create_workers(job=inf_job) self._service_roles.append(inf_role) inf_workers = self.scheduler.get_workers(role=inf_role) + if len(inf_workers) != total_workers: + raise RuntimeError( + f"Expected {total_workers} workers for 
role {inf_role!r}, " + f"got {len(inf_workers)}" + ) self.workers = inf_workers logger.info("RPCGuard workers ready: %s", [w.id for w in inf_workers]) @@ -291,13 +319,22 @@ async def _async_initialize( v, ) - def _build_launch_cmd(host: str, port: int) -> list[str]: + def _build_launch_cmd( + host: str | None, + port: int | None, + n_nodes: int = 1, + node_rank: int = 0, + dist_init_addr: str | None = None, + ) -> list[str]: return SGLangConfig.build_cmd( sglang_config=sglang_config, tp_size=tp_size, base_gpu_id=0, host=host, port=port, + dist_init_addr=dist_init_addr, + n_nodes=n_nodes, + node_rank=node_rank, ) elif inf_backend == "vllm": @@ -315,79 +352,138 @@ def _build_launch_cmd(host: str, port: int) -> list[str]: v, ) - def _build_launch_cmd(host: str, port: int) -> list[str]: + def _build_launch_cmd( + host: str | None, + port: int | None, + n_nodes: int = 1, + node_rank: int = 0, + dist_init_addr: str | None = None, + ) -> list[str]: return vLLMConfig.build_cmd( vllm_config=vllm_config, tp_size=tp_size, pp_size=alloc.parallel.pp_size, host=host, port=port, + dist_init_addr=dist_init_addr, + n_nodes=n_nodes, + node_rank=node_rank, ) else: raise ValueError(f"Unsupported inference backend: {inf_backend!r}") - # For each RPCGuard worker: alloc port, build cmd, fork server - for rank, worker in enumerate(inf_workers): - guard_addr = ( - f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" - ) - - resp = requests.post( - f"{guard_addr}/alloc_ports", - json={"count": 1}, - timeout=30, - ) - resp.raise_for_status() - port_data = resp.json() - inf_host = port_data["host"] - inf_port = port_data["ports"][0] - - cmd = _build_launch_cmd(inf_host, inf_port) - - fork_payload: dict[str, Any] = { - "role": "inf-server", - "worker_index": rank, - "raw_cmd": cmd, - } - if inf_backend == "vllm": - from areal.infra.utils.launcher import ( - TRITON_CACHE_PATH as _TRITON_CACHE, + # For each inference instance group: alloc ports, build cmd, fork servers + for 
group_idx in range(dp_size): + group_workers = inf_workers[ + group_idx * nnodes_per_instance : (group_idx + 1) + * nnodes_per_instance + ] + head_worker = group_workers[0] + head_guard_addr = f"http://{format_hostport(head_worker.ip, int(head_worker.worker_ports[0]))}" + + # Allocate rendezvous port on head node for distributed init + dist_init_addr = None + if nnodes_per_instance > 1: + resp = requests.post( + f"{head_guard_addr}/alloc_ports", + json={"count": 1}, + timeout=30, ) - from areal.infra.utils.launcher import ( - VLLM_CACHE_ROOT as _VLLM_CACHE, + resp.raise_for_status() + rendezvous_data = resp.json() + rendezvous_host = rendezvous_data["host"] + rendezvous_port = rendezvous_data["ports"][0] + dist_init_addr = format_hostport(rendezvous_host, rendezvous_port) + + head_inf_host = None + head_inf_port = None + + for node_rank, worker in enumerate(group_workers): + guard_addr = f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" + + # Allocate port for inference server on this node + resp = requests.post( + f"{guard_addr}/alloc_ports", + json={"count": 1}, + timeout=30, + ) + resp.raise_for_status() + port_data = resp.json() + inf_host = port_data["host"] + inf_port = port_data["ports"][0] + + # Worker nodes (rank > 0) don't need to serve HTTP, + # but we still pass host/port for the server to bind + cmd = _build_launch_cmd( + host=inf_host, + port=inf_port, + n_nodes=nnodes_per_instance, + node_rank=node_rank, + dist_init_addr=dist_init_addr, ) - fork_payload["env"] = { - "TRITON_CACHE_PATH": os.path.join( - os.environ.get("TRITON_CACHE_PATH", _TRITON_CACHE), - str(uuid.uuid4()), - ), - "VLLM_CACHE_ROOT": os.path.join( - os.environ.get("VLLM_CACHE_ROOT", _VLLM_CACHE), - str(uuid.uuid4()), - ), - "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True", + fork_payload: dict[str, Any] = { + "role": "inf-server", + "worker_index": group_idx * nnodes_per_instance + node_rank, + "raw_cmd": cmd, } + if inf_backend == "vllm": + from areal.infra.utils.launcher 
import ( + TRITON_CACHE_PATH as _TRITON_CACHE, + ) + from areal.infra.utils.launcher import ( + VLLM_CACHE_ROOT as _VLLM_CACHE, + ) - resp = requests.post( - f"{guard_addr}/fork", - json=fork_payload, - timeout=30, - ) - resp.raise_for_status() + fork_payload["env"] = { + "TRITON_CACHE_PATH": os.path.join( + os.environ.get("TRITON_CACHE_PATH", _TRITON_CACHE), + str(uuid.uuid4()), + ), + "VLLM_CACHE_ROOT": os.path.join( + os.environ.get("VLLM_CACHE_ROOT", _VLLM_CACHE), + str(uuid.uuid4()), + ), + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True", + } + + resp = requests.post( + f"{guard_addr}/fork", + json=fork_payload, + timeout=30, + ) + resp.raise_for_status() + self._forked_services.append( + ( + guard_addr, + "inf-server", + group_idx * nnodes_per_instance + node_rank, + ) + ) + + if node_rank == 0: + head_inf_host = inf_host + head_inf_port = inf_port - addr = f"http://{format_hostport(inf_host, inf_port)}" + if head_inf_host is None or head_inf_port is None: + raise RuntimeError( + f"No head worker resolved for group {group_idx}; " + f"expected {nnodes_per_instance} workers per group" + ) + + # Only record the head node's address as the inference endpoint + addr = f"http://{format_hostport(head_inf_host, head_inf_port)}" self._inf_addrs.append(addr) self.server_infos.append( LocalInfServerInfo( - host=inf_host, - port=inf_port, + host=head_inf_host, + port=head_inf_port, process=None, # type: ignore[arg-type] # RPCGuard manages process ) ) - # Wait for inference servers to be healthy + # Wait for inference servers to be healthy (only head nodes) for i, addr in enumerate(self._inf_addrs): self._wait_for_service( f"{addr}/health", f"InfServer-{i}", timeout=cfg.setup_timeout @@ -442,21 +538,24 @@ def _build_launch_cmd(host: str, port: int) -> list[str]: f"http://{self.callback_addr}", ] - for rank, worker in enumerate(inf_workers): - guard_addr = ( - f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" - ) - # Each data proxy connects to its 
corresponding inference server + for group_idx in range(dp_size): + head_worker = inf_workers[ + group_idx + if server_infos is not None + else group_idx * nnodes_per_instance + ] + guard_addr = f"http://{format_hostport(head_worker.ip, int(head_worker.worker_ports[0]))}" + # Each data proxy connects to its group's head inference server data_proxy_cmd = data_proxy_base_cmd + [ "--backend-addr", - self._inf_addrs[rank], + self._inf_addrs[group_idx], "--backend-type", inf_backend or "sglang", ] data_proxy_host, data_proxy_port = self._fork_on_guard( guard_addr=guard_addr, role="data-proxy", - worker_index=rank, + worker_index=group_idx, raw_cmd=data_proxy_cmd, ) self._data_proxy_addrs.append( diff --git a/areal/infra/utils/proc.py b/areal/infra/utils/proc.py index 3f93a08ec1..22576e3263 100644 --- a/areal/infra/utils/proc.py +++ b/areal/infra/utils/proc.py @@ -4,6 +4,7 @@ import os import shlex +import shutil import signal import subprocess import sys @@ -61,23 +62,30 @@ def build_streaming_log_cmd( else: cmd_str = cmd + # Check if stdbuf is available (not present on macOS by default) + _has_stdbuf = shutil.which("stdbuf") is not None + # Build prefix with env vars if provided prefix_parts = [] if env_vars: prefix_parts.append( " ".join(f"{k}={shlex.quote(str(v))}" for k, v in env_vars.items()) ) - prefix_parts.append(f"stdbuf -oL {cmd_str}") + if _has_stdbuf: + prefix_parts.append(f"stdbuf -oL {cmd_str}") + else: + prefix_parts.append(cmd_str) full_cmd = " ".join(prefix_parts) # Build log prefix for merged log log_prefix = f"[{role}]".ljust(LOG_PREFIX_WIDTH) # Construct tee/sed pipeline - shell_cmd = ( - f"{full_cmd} 2>&1 " - f"| tee -a {log_file} >(stdbuf -oL sed 's/^/{log_prefix}/' >> {merged_log})" - ) + if _has_stdbuf: + sed_prefix = f"stdbuf -oL sed 's/^/{log_prefix}/'" + else: + sed_prefix = f"sed 's/^/{log_prefix}/'" + shell_cmd = f"{full_cmd} 2>&1 | tee -a {log_file} >({sed_prefix} >> {merged_log})" return shell_cmd diff --git 
a/examples/agent_service/README.md b/examples/agent_service/README.md index 04fe12975a..563064b4e5 100644 --- a/examples/agent_service/README.md +++ b/examples/agent_service/README.md @@ -1,177 +1,114 @@ -# Agent Service Demo — Tau2 with PydanticAI +# Agent Service — Claude Agent SDK ## Overview -This example demonstrates AReaL's Agent Service running a **tau2 customer-service -agent** powered by **PydanticAI**. The agent handles multi-turn conversations, calls -tau2 environment tools (e.g. flight lookup, reservation booking), and maintains -conversation history across turns. +This example demonstrates AReaL's Agent Service running the **Claude Agent SDK** +(`claude-agent-sdk`) as a scalable HTTP micro-service. It turns Claude's autonomous +agent capabilities — multi-turn conversations, tool use, file editing, web search — into +a production-deployable service with session management, load balancing, and dynamic +scaling. -The Agent Service consists of four independent HTTP services: +**Why this matters**: Projects like +[claude-agent-acp](https://github.com/agentclientprotocol/claude-agent-acp) expose +Claude Agent SDK via custom protocols (ACP) for editor integration. AReaL takes a +different approach — it wraps Claude Agent SDK into standard HTTP micro-services with +session-affine routing, so you can **scale, orchestrate, and train** Claude agents using +AReaL's RL infrastructure. 
``` -Client → Gateway (8080) → Router (8081) → DataProxy (9100) → Worker (9000) +Client → Gateway (HTTP) → Router → DataProxy (session state) → Worker (ClaudeSDKClient) ``` -- **Gateway**: public entry point (WebSocket + OpenResponses HTTP bridge) -- **Router**: session-affine routing (DataProxy registration, round-robin) -- **DataProxy**: stateful session proxy (conversation history, forwards to Worker) -- **Worker**: stateless agent execution (loads AgentRunnable, runs one turn) - -## Architecture - -``` -Client (HTTP/WS) - │ - ▼ -┌──────────┐ POST /route ┌──────────┐ -│ Gateway │ ──────────────▶ │ Router │ -│ :8080 │ ◀────────────── │ :8081 │ -└──────────┘ DataProxy addr └──────────┘ - │ - │ POST /session/{key}/turn - ▼ -┌──────────┐ -│ DataProxy│ -│ :9100 │ POST /run ┌──────────┐ -│ (history)│ ────────────▶│ Worker │ -└──────────┘ │ :9000 │ - │ (agent) │ - └──────────┘ -``` - -## Files - -| File | Description | -| ------------- | ----------------------------------------------------- | -| `agent.py` | `Tau2Agent` — PydanticAI agent with tau2 domain tools | -| `config.yaml` | Configuration: LLM endpoints, tau2 domain, data path | -| `run_demo.py` | One-click: starts all services, runs tau2 demo | - ## Prerequisites ```bash -pip install pydantic-ai -pip install git+https://github.com/dhh1995/tau2-bench.git@dhh/async-and-custom-completion +uv pip install claude-agent-sdk +export ANTHROPIC_API_KEY=sk-... ``` -## Configuration - -Edit `config.yaml` to set your LLM endpoints and tau2 settings: - -```yaml -tau2: - domain: airline - data_dir: /path/to/tau2-bench/data - -agent_llm: - model: openai:your-model-name - base_url: http://localhost:8000/v1 - api_key: unused - -user_llm: - model: null # set for user simulator, null for scripted messages - base_url: null - api_key: unused -``` - -Alternatively, set `TAU2_DATA_DIR` as an environment variable. 
- ## Quick Start -### One-click demo - ```bash -python examples/agent_service/run_demo.py # single task, airline -python examples/agent_service/run_demo.py --domain telecom # different domain -python examples/agent_service/run_demo.py --full # all tasks -python examples/agent_service/run_demo.py --config my.yaml # custom config +python examples/agent_service/run_agent_service.py ``` -This starts all four services in background threads and runs a multi-turn conversation -showing tool calls and history accumulation. +The script creates a `LocalScheduler`, launches Guard workers, then forks Router → +Worker+DataProxy → Gateway. An interactive prompt lets you chat with the Claude agent. -### Manual startup (separate terminals) +### Options ```bash -# Terminal 1: Router -python -m areal.experimental.agent_service.router --port 8081 - -# Terminal 2: Worker + DataProxy -python -m areal.experimental.agent_service.worker \ - --agent examples.agent_service.agent.Tau2Agent \ - --router-addr http://localhost:8081 \ - --worker-port 9000 \ - --proxy-port 9100 - -# Terminal 3: Gateway -python -m areal.experimental.agent_service.gateway \ - --router-addr http://localhost:8081 \ - --port 8080 +python examples/agent_service/run_agent_service.py --num-pairs 4 ``` -### Send a request +### Send requests directly The gateway port is assigned dynamically; read it from `ctrl.gateway_addr` and substitute it for `<gateway-port>` below. ```bash -curl -X POST http://localhost:8080/v1/responses \ +curl -X POST http://localhost:<gateway-port>/v1/responses \ -H "Content-Type: application/json" \ + -H "Authorization: Bearer areal-agent-admin" \ -d '{ - "input": [{"type": "message", "content": "I need to change my flight AA123"}], - "model": "tau2-agent", + "input": [{"type": "message", "content": "Explain RLHF in simple terms"}], + "model": "claude-agent", "user": "my-session" }' ``` -## Implementing Your Own Agent +## Configuration -Create a class that satisfies the `AgentRunnable` protocol: +Claude Agent SDK settings are controlled via environment variables: -```python -from areal.experimental.agent_service.agent_worker import
( - AgentRequest, AgentResponse, EventEmitter, -) +| Variable | Default | Description | +| ---------------------- | ------------------- | --------------------------- | +| `ANTHROPIC_API_KEY` | (required) | Anthropic API key | +| `CLAUDE_MODEL` | `claude-sonnet-4-6` | Model to use | +| `CLAUDE_SYSTEM_PROMPT` | (none) | Optional system prompt | +| `CLAUDE_MAX_TURNS` | `20` | Max agentic turns per query | -class MyAgent: - def __init__(self, **kwargs): - # Configure LLM client, tools, etc. - pass - - async def run( - self, - request: AgentRequest, - *, - emitter: EventEmitter, - ) -> AgentResponse: - # request.message — current user message - # request.history — prior conversation turns - # emitter — stream events back to client - await emitter.emit_delta("Hello!") - return AgentResponse(summary="Hello!") -``` +## Architecture -Then start a worker with your agent: +The Worker maintains a **session-persistent `ClaudeSDKClient`** per session key. Unlike +stateless wrappers, the SDK's internal session retains the full conversation transcript +— no need to re-send history on each turn. -```bash -python -m areal.experimental.agent_service.worker \ - --agent mypackage.myagent.MyAgent \ - --router-addr http://localhost:8081 +``` +Turn 1: Client → Gateway → Router → DataProxy → Worker + Worker: creates ClaudeSDKClient for session "abc" + Claude Agent SDK runs autonomously (tool calls, file ops, etc.) + Response streams back through the chain + +Turn 2: Client → Gateway → Router (same DataProxy) → DataProxy → Worker + Worker: reuses ClaudeSDKClient for session "abc" + SDK remembers full context from Turn 1 ``` -## Multi-turn Conversations - -The DataProxy automatically manages conversation history. Each turn: - -1. DataProxy reads history for the session -1. Builds `AgentRequest` with `history` field populated -1. Forwards to Worker → Agent sees full conversation context -1. Appends user message + agent response to history -1. 
Tool calls and results are also recorded in history - -The agent accesses history via `request.history`: +## Programmatic Usage ```python -async def run(self, request, *, emitter): - for msg in request.history: - print(f"{msg['role']}: {msg['content']}") - # ... generate response using full context +from areal.experimental.agent_service.controller import ( + AgentServiceController, + AgentServiceControllerConfig, +) +from areal.infra.scheduler.local import LocalScheduler + +scheduler = LocalScheduler(experiment_name="demo", trial_name="run0", gpu_devices=[]) +ctrl = AgentServiceController( + config=AgentServiceControllerConfig( + agent_cls_path="examples.agent_service.agent.ClaudeAgent", + num_pairs=2, + ), + scheduler=scheduler, +) +ctrl.initialize() +# ctrl.gateway_addr → "http://10.0.0.1:9005" +# ctrl.scale_up(2) → add 2 more pairs +# ctrl.scale_down(1) → remove 1 pair (with graceful drain) +ctrl.destroy() ``` + +## Files + +| File | Description | +| ---------------------- | ----------------------------------------------------------- | +| `agent.py` | `ClaudeAgent` — session-persistent Claude Agent SDK wrapper | +| `run_agent_service.py` | Controller-based launcher + interactive conversation | diff --git a/examples/agent_service/agent.py b/examples/agent_service/agent.py index 2c2beb5df8..c05f3bebe5 100644 --- a/examples/agent_service/agent.py +++ b/examples/agent_service/agent.py @@ -1,26 +1,34 @@ -"""Tau2 Agent for AReaL Agent Service (PydanticAI). +"""Claude Agent for AReaL Agent Service. -Implements :class:`AgentRunnable` using PydanticAI. Each call to ``run()`` -handles a **single turn** of a tau2 customer-service dialogue. The agent -uses tau2 environment tools (registered as PydanticAI function tools) and -maintains conversation context via ``request.history``. +Implements :class:`AgentRunnable` using the Claude Agent SDK +(``claude-agent-sdk``). 
Each Worker instance holds a pool of +:class:`ClaudeSDKClient` sessions keyed by ``session_key``, so multi-turn +conversations preserve full context without re-sending history. -Requires: ``pip install pydantic-ai tau2-bench`` +Requires:: + + pip install claude-agent-sdk + +Environment variables: + ANTHROPIC_API_KEY — Anthropic API key (required) + CLAUDE_MODEL — model name (default: claude-sonnet-4-6) + CLAUDE_SYSTEM_PROMPT — optional system prompt override + CLAUDE_MAX_TURNS — max agentic turns per query (default: 20) """ from __future__ import annotations -import inspect -import json import os -from typing import Any - -from pydantic_ai import Agent -from pydantic_ai.models.openai import OpenAIChatModel -from pydantic_ai.providers.openai import OpenAIProvider -from tau2.environment.environment import Environment -from tau2.environment.tool import Tool as Tau2Tool -from tau2.registry import registry +from typing import Any, Literal + +from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + ClaudeSDKClient, + ResultMessage, + TextBlock, + ToolUseBlock, +) from areal.experimental.agent_service.types import ( AgentRequest, @@ -29,101 +37,66 @@ ) from areal.utils import logging -logger = logging.getLogger("Tau2Agent") +logger = logging.getLogger("ClaudeAgent") +PermissionMode = Literal["default", "acceptEdits", "plan", "bypassPermissions"] -def _make_pydantic_tool(tau2_tool: Tau2Tool): - """Create a plain async function from a tau2 Tool for PydanticAI.""" - fn = tau2_tool._func # noqa: SLF001 - name = tau2_tool.name - doc = tau2_tool.openai_schema["function"].get("description", name) +_DEFAULT_PERMISSION_MODE: PermissionMode = "bypassPermissions" - async def _wrapper(**kwargs: Any) -> str: - result = fn(**kwargs) - if not isinstance(result, str): - result = json.dumps(result, default=str) - return result - - _wrapper.__name__ = name - _wrapper.__qualname__ = name - _wrapper.__doc__ = doc - - sig = inspect.signature(fn) - params = [ - 
inspect.Parameter( - pname, - kind=inspect.Parameter.KEYWORD_ONLY, - default=param.default, - annotation=param.annotation, - ) - for pname, param in sig.parameters.items() - ] - _wrapper.__signature__ = inspect.Signature(params) # type: ignore[attr-defined] - if hasattr(fn, "__annotations__"): - _wrapper.__annotations__ = { - k: v for k, v in fn.__annotations__.items() if k != "return" - } - return _wrapper +class ClaudeAgent: + """AgentRunnable backed by the Claude Agent SDK. -def _think_tool_fn(thoughts: str) -> str: - """Use this tool to think. Only use when necessary.""" - return "Your thoughts are recorded. Please continue your work." - - -class Tau2Agent: - """AgentRunnable that wraps a PydanticAI Agent with tau2 tools. - - Accepts a ``config`` dict (loaded from config.yaml by run_demo.py). - Falls back to environment variables if config is not provided. + Maintains a ``ClaudeSDKClient`` per session for true multi-turn + continuity — the SDK's internal session keeps the full transcript, + so ``request.history`` is only used for the very first turn of a + new session (to seed context if provided by the caller). 
""" - def __init__(self, config: dict | None = None, **kwargs: Any) -> None: - config = config or {} - tau2_cfg = config.get("tau2", {}) - agent_llm_cfg = config.get("agent_llm", {}) - - self._domain = tau2_cfg.get("domain") or os.environ.get( - "TAU2_DOMAIN", "airline" - ) - add_thinking = tau2_cfg.get("add_thinking_tool", False) - - data_dir = tau2_cfg.get("data_dir") or os.environ.get("TAU2_DATA_DIR") - if data_dir: - os.environ["TAU2_DATA_DIR"] = data_dir + def __init__(self, **kwargs: Any) -> None: + self._model = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-6") + self._system_prompt = os.environ.get("CLAUDE_SYSTEM_PROMPT", "") + self._max_turns = int(os.environ.get("CLAUDE_MAX_TURNS", "20")) + self._permission_mode: PermissionMode = _DEFAULT_PERMISSION_MODE - env = self._build_environment() - tau2_tools: list[Tau2Tool] = env.get_tools() - if add_thinking: - tau2_tools.append(Tau2Tool(_think_tool_fn)) - - tools = [_make_pydantic_tool(t) for t in tau2_tools] - system_prompt = env.get_policy() - - model_name = agent_llm_cfg.get("model", "openai:default") - base_url = agent_llm_cfg.get("base_url") - api_key = agent_llm_cfg.get("api_key", "unused") - - if base_url: - model: Any = OpenAIChatModel( - model_name.replace("openai:", ""), - provider=OpenAIProvider(base_url=base_url, api_key=api_key), - ) - else: - model = model_name - - self._agent = Agent(model, system_prompt=system_prompt, tools=tools) + self._sessions: dict[str, ClaudeSDKClient] = {} logger.info( - "Tau2Agent initialized (domain=%s, tools=%d, model=%s)", - self._domain, - len(tools), - model_name, + "ClaudeAgent initialized (model=%s, max_turns=%d)", + self._model, + self._max_turns, ) - def _build_environment(self) -> Environment: - constructor = registry.get_env_constructor(self._domain) - return constructor(solo_mode=False) + def _make_options(self) -> ClaudeAgentOptions: + opts = ClaudeAgentOptions( + model=self._model, + max_turns=self._max_turns, + permission_mode=self._permission_mode, + ) 
+ if self._system_prompt: + opts.system_prompt = self._system_prompt + return opts + + async def _get_or_create_client(self, session_key: str) -> ClaudeSDKClient: + if session_key not in self._sessions: + client = ClaudeSDKClient(options=self._make_options()) + await client.__aenter__() + self._sessions[session_key] = client + logger.info("New session: %s", session_key) + return self._sessions[session_key] + + async def close_session(self, session_key: str) -> None: + client = self._sessions.pop(session_key, None) + if client is not None: + try: + await client.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing session %s", session_key, exc_info=True) + + async def close_all_sessions(self) -> None: + keys = list(self._sessions.keys()) + for key in keys: + await self.close_session(key) async def run( self, @@ -131,87 +104,36 @@ async def run( *, emitter: EventEmitter, ) -> AgentResponse: - from pydantic_ai.messages import ( - ModelRequest, - TextPart, - ToolCallPart, - ToolReturnPart, - UserPromptPart, - ) - from pydantic_ai.messages import ( - ModelResponse as PAModelResponse, - ) - - message_history: list[ModelRequest | PAModelResponse] = [] - for msg in request.history: - role = msg.get("role", "user") - content = msg.get("content", "") - - if role == "user": - message_history.append( - ModelRequest(parts=[UserPromptPart(content=content or "")]) - ) - elif role == "assistant": - tool_calls = msg.get("tool_calls") - if tool_calls: - parts = [] - for tc in tool_calls: - fn = tc.get("function", tc) - parts.append( - ToolCallPart( - tool_name=fn.get("name", ""), - args=fn.get("arguments", ""), - tool_call_id=tc.get("id", ""), + client = await self._get_or_create_client(request.session_key) + + try: + await client.query(request.message) + + text_parts: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if 
isinstance(block, TextBlock): + await emitter.emit_delta(block.text) + text_parts.append(block.text) + elif isinstance(block, ToolUseBlock): + await emitter.emit_tool_call( + name=block.name, + args=str(block.input), ) - ) - message_history.append(PAModelResponse(parts=parts)) - elif content: - message_history.append( - PAModelResponse(parts=[TextPart(content=content)]) - ) - elif role == "tool": - tool_call_id = msg.get("tool_call_id", "") - message_history.append( - ModelRequest( - parts=[ - ToolReturnPart( - tool_name=tool_call_id, - content=content or "", - tool_call_id=tool_call_id, + tool_calls.append( + {"name": block.name, "input": block.input} ) - ] - ) - ) - - result = await self._agent.run( - request.message, - message_history=message_history, - ) + elif isinstance(msg, ResultMessage): + break - final_text = str(result.output) if result.output else "" - - tool_calls: list[dict[str, Any]] = [] - for msg in result.new_messages(): - if not hasattr(msg, "parts"): - continue - for part in msg.parts: - kind = getattr(part, "part_kind", "") - if kind == "tool-call": - name = getattr(part, "tool_name", "") - args = getattr(part, "args", "") - if isinstance(args, dict): - args = json.dumps(args) - await emitter.emit_tool_call(name=name, args=str(args)) - tool_calls.append({"name": name, "arguments": args}) - elif kind == "tool-return": - name = getattr(part, "tool_name", "") - content = str(getattr(part, "content", "")) - await emitter.emit_tool_result(name=name, result=content) - - if final_text: - await emitter.emit_delta(final_text) - - return AgentResponse( - summary=final_text[:200], - metadata={"tool_calls": tool_calls}, - ) + summary = "".join(text_parts) + return AgentResponse( + summary=summary[:200], + metadata={"tool_calls": tool_calls}, + ) + except Exception: + await self.close_session(request.session_key) + raise diff --git a/examples/agent_service/config.yaml b/examples/agent_service/config.yaml deleted file mode 100644 index 23cebf9f40..0000000000 
--- a/examples/agent_service/config.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# Agent Service Demo Configuration - -# Admin key for inter-service authentication (Router, Gateway, Worker). -# Change from default for non-local deployments. -admin_key: areal-agent-admin - -# tau2 environment settings -tau2: - domain: airline # airline | retail | telecom - data_dir: None # path to tau2 data dir (or set TAU2_DATA_DIR env var) - add_thinking_tool: false - -# Agent LLM — the model the agent uses for reasoning + tool calls. -# For the demo this points to a local/self-hosted model. -agent_llm: - model: openai:Ling-2.6-1T # PydanticAI model string - base_url: None # e.g. http://localhost:8000/v1 - api_key: None - -# User simulator LLM — drives the simulated customer. -# Set to null to use scripted user messages instead. -user_llm: - model: openai:GLM-5 # e.g. openai:Qwen2.5-72B - base_url: None # e.g. http://localhost:8001/v1 - api_key: None diff --git a/examples/agent_service/run_agent_service.py b/examples/agent_service/run_agent_service.py new file mode 100644 index 0000000000..e96f83f501 --- /dev/null +++ b/examples/agent_service/run_agent_service.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Launch the Agent Service with Claude Agent SDK. + +Usage:: + + python examples/agent_service/run_agent_service.py + python examples/agent_service/run_agent_service.py --num-pairs 2 + +Requires:: + + uv pip install claude-agent-sdk + export ANTHROPIC_API_KEY=sk-... 
+""" + +from __future__ import annotations + +import argparse +import asyncio +import time + +import httpx + +from areal.experimental.agent_service.controller import ( + AgentServiceController, + AgentServiceControllerConfig, +) + + +async def _wait_healthy(url: str, timeout: float = 60.0) -> None: + async with httpx.AsyncClient() as client: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = await client.get(url) + if resp.status_code == 200: + return + except httpx.ConnectError: + pass + await asyncio.sleep(0.5) + raise TimeoutError(f"Service at {url} did not become healthy") + + +async def interactive_loop(gateway_addr: str, admin_key: str) -> None: + session_key = f"session-{int(time.time())}" + print("Type your message (or 'quit' to exit):\n") + + async with httpx.AsyncClient(timeout=120.0) as client: + while True: + try: + user_input = input("You: ") + except (EOFError, KeyboardInterrupt): + break + if user_input.strip().lower() in ("quit", "exit", "q"): + break + if not user_input.strip(): + continue + + resp = await client.post( + f"{gateway_addr}/v1/responses", + json={ + "input": [{"type": "message", "content": user_input}], + "model": "claude-agent", + "user": session_key, + }, + headers={"Authorization": f"Bearer {admin_key}"}, + ) + data = resp.json() + + if data.get("status") == "completed": + for item in data.get("output", []): + if item.get("type") == "message": + for block in item.get("content", []): + if block.get("type") == "output_text": + print(f"Agent: {block['text']}") + elif item.get("type") == "function_call": + print(f"[tool] {item.get('name', '')}") + print() + elif data.get("error"): + print(f"Error: {data['error'].get('message', '')[:200]}\n") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Agent Service — Claude Agent SDK") + parser.add_argument( + "--num-pairs", + type=int, + default=1, + help="Number of Worker+DataProxy pairs (default: 1)", + ) + parser.add_argument( + 
"--admin-api-key", + default="areal-agent-admin", + help="Admin API key for inter-service auth", + ) + args = parser.parse_args() + + from areal.infra.scheduler.local import LocalScheduler + + scheduler = LocalScheduler( + experiment_name="agent-service-demo", + trial_name="run0", + gpu_devices=[], + ) + + ctrl_config = AgentServiceControllerConfig( + agent_cls_path="examples.agent_service.agent.ClaudeAgent", + admin_api_key=args.admin_api_key, + num_pairs=args.num_pairs, + ) + ctrl = AgentServiceController(config=ctrl_config, scheduler=scheduler) + + try: + print(f"Initializing with {args.num_pairs} pair(s) ...") + ctrl.initialize() + print(f" Router: {ctrl.router_addr}") + print(f" Gateway: {ctrl.gateway_addr}") + print(f" Pairs: {len(ctrl.pairs)}") + + asyncio.run(_wait_healthy(f"{ctrl.gateway_addr}/health")) + print("All services ready.\n") + + asyncio.run(interactive_loop(ctrl.gateway_addr, admin_key=args.admin_api_key)) + finally: + print("\nShutting down ...") + ctrl.destroy() + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/agent_service/run_demo.py b/examples/agent_service/run_demo.py deleted file mode 100644 index 707812461c..0000000000 --- a/examples/agent_service/run_demo.py +++ /dev/null @@ -1,298 +0,0 @@ -"""One-click demo: Agent Service + Tau2 (PydanticAI). 
- -Usage:: - - python examples/agent_service/run_demo.py # single task - python examples/agent_service/run_demo.py --domain telecom # different domain - python examples/agent_service/run_demo.py --full # all tasks - python examples/agent_service/run_demo.py --config my.yaml # custom config - -Requires:: - - pip install pydantic-ai - pip install git+https://github.com/dhh1995/tau2-bench.git@dhh/async-and-custom-completion -""" - -from __future__ import annotations - -import argparse -import asyncio -import os -import threading -import time -from pathlib import Path -from typing import Any -from unittest.mock import patch - -import httpx -import uvicorn -import yaml - -from areal.experimental.agent_service import ( - OpenResponsesBridge, - create_data_proxy_app, - create_gateway_app, - create_router_app, - create_worker_app, - mount_bridge, -) - -ROUTER_PORT = 18081 -WORKER_PORT = 19000 -PROXY_PORT = 19100 -GATEWAY_PORT = 18080 - -DEFAULT_CONFIG = Path(__file__).parent / "config.yaml" - - -def _load_config(path: str | Path) -> dict[str, Any]: - with open(path) as f: - return yaml.safe_load(f) or {} - - -def _start_in_thread(app, port: int, name: str) -> threading.Thread: - def run(): - uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning") - - t = threading.Thread(target=run, daemon=True, name=name) - t.start() - return t - - -async def _wait_healthy(url: str, timeout: float = 10.0) -> None: - async with httpx.AsyncClient() as client: - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - try: - resp = await client.get(url) - if resp.status_code == 200: - return - except httpx.ConnectError: - pass - await asyncio.sleep(0.2) - raise TimeoutError(f"Service at {url} did not become healthy") - - -async def run_task(gateway_addr: str, task, domain: str, admin_key: str) -> float: - """Run a single tau2 task. 
Returns the reward.""" - from tau2.data_model.message import AssistantMessage, UserMessage - from tau2.data_model.simulation import SimulationRun, TerminationReason - from tau2.evaluator.evaluator import EvaluationType, evaluate_simulation - - session_key = f"tau2-{domain}-{task.id}" - print(f"\n Task: {task.id}") - print(f" Scenario: {str(task.user_scenario)[:120]}...") - - scripted_messages = [ - str(task.user_scenario), - "Yes, please go ahead and help me with that.", - "Can you check the status of my request?", - "Thank you, that's all I need.", - ] - - tau2_messages = [] - error_occurred = False - - async with httpx.AsyncClient(timeout=120.0) as client: - for i, msg in enumerate(scripted_messages, 1): - resp = await client.post( - f"{gateway_addr}/v1/responses", - json={ - "input": [{"type": "message", "content": msg}], - "model": "tau2-agent", - "user": session_key, - }, - headers={"Authorization": f"Bearer {admin_key}"}, - ) - data = resp.json() - - tau2_messages.append( - UserMessage(role="user", content=msg, turn_idx=len(tau2_messages)) - ) - - if data.get("status") == "completed": - agent_text = "" - for item in data.get("output", []): - if item.get("type") == "message": - for block in item.get("content", []): - if block.get("type") == "output_text": - agent_text += block["text"] - print(f" [Turn {i}] Agent: {block['text'][:150]}") - elif item.get("type") == "function_call": - print(f" [Turn {i}] [tool] {item.get('name', '')}") - - tau2_messages.append( - AssistantMessage( - role="assistant", - content=agent_text or "(no response)", - turn_idx=len(tau2_messages), - ) - ) - elif data.get("error"): - err = data["error"].get("message", "")[:100] - print(f" [Turn {i}] Error: {err}") - tau2_messages.append( - AssistantMessage( - role="assistant", - content=f"Error: {err}", - turn_idx=len(tau2_messages), - ) - ) - error_occurred = True - break - - reward = 0.0 - if not error_occurred: - try: - simulation = SimulationRun( - id=f"demo-{task.id}", - 
task_id=task.id, - messages=tau2_messages, - start_time="", - end_time="", - duration=0.0, - termination_reason=TerminationReason.USER_STOP, - ) - reward_info = evaluate_simulation( - simulation=simulation, - task=task, - evaluation_type=EvaluationType.ALL, - solo_mode=False, - domain=domain, - ) - reward = reward_info.reward - except Exception as e: - print(f" Eval error: {e}") - - print(f" Reward: {reward:.3f}") - return reward - - -async def run_demo(gateway_addr: str, domain: str, full: bool, admin_key: str) -> None: - from tau2.registry import registry - - print(f"\n{'=' * 60}") - print(f" Tau2 Agent Service Demo — domain: {domain}") - print(f"{'=' * 60}") - - tasks = registry.get_tasks_loader(domain)(None) - total = len(tasks) - - if not full: - tasks = tasks[:1] - print(f" Running 1 task (use --full for all {total} tasks)") - else: - print(f" Running all {total} tasks") - - rewards = [] - for task in tasks: - reward = await run_task(gateway_addr, task, domain, admin_key=admin_key) - rewards.append((task.id, reward)) - - print(f"\n{'=' * 60}") - print(f" Results — {len(rewards)} task(s)") - print(f"{'=' * 60}") - for task_id, reward in rewards: - print(f" Task {task_id}: reward = {reward:.3f}") - if rewards: - avg = sum(r for _, r in rewards) / len(rewards) - print(f"\n Average reward: {avg:.3f}") - print(f"{'=' * 60}") - - -def main() -> None: - parser = argparse.ArgumentParser(description="Tau2 Agent Service Demo") - parser.add_argument( - "--config", - default=str(DEFAULT_CONFIG), - help=f"Config YAML path (default: {DEFAULT_CONFIG})", - ) - parser.add_argument( - "--domain", - choices=["airline", "retail", "telecom"], - help="Override tau2.domain from config", - ) - parser.add_argument( - "--full", - action="store_true", - help="Run all tasks (default: single task)", - ) - args = parser.parse_args() - - config = _load_config(args.config) - tau2_cfg = config.setdefault("tau2", {}) - - domain = args.domain or tau2_cfg.get("domain", "airline") - 
tau2_cfg["domain"] = domain - - data_dir = tau2_cfg.get("data_dir") or os.environ.get("TAU2_DATA_DIR") - if data_dir: - os.environ["TAU2_DATA_DIR"] = data_dir - - admin_key = config.get("admin_key", "areal-agent-admin") - - router_addr = f"http://127.0.0.1:{ROUTER_PORT}" - worker_addr = f"http://127.0.0.1:{WORKER_PORT}" - proxy_addr = f"http://127.0.0.1:{PROXY_PORT}" - gateway_addr = f"http://127.0.0.1:{GATEWAY_PORT}" - - # 1. Router - _start_in_thread(create_router_app(admin_key=admin_key), ROUTER_PORT, "router") - - # 2. Worker (Tau2Agent with PydanticAI + tau2 tools) - def _make_agent_cls(): - from examples.agent_service.agent import Tau2Agent - - class _Configured(Tau2Agent): - def __init__(self, **kw: Any): - super().__init__(config=config, **kw) - - return _Configured - - with patch( - "areal.experimental.agent_service.worker.app.import_from_string", - return_value=_make_agent_cls(), - ): - worker_app = create_worker_app("examples.agent_service.agent.Tau2Agent") - _start_in_thread(worker_app, WORKER_PORT, "worker") - - # 3. DataProxy - _start_in_thread( - create_data_proxy_app(worker_addr=worker_addr), PROXY_PORT, "proxy" - ) - - # 4. Gateway + Bridge - gw_app = create_gateway_app(router_addr=router_addr, admin_key=admin_key) - mount_bridge( - gw_app, - OpenResponsesBridge(router_addr=router_addr, admin_key=admin_key), - admin_key=admin_key, - ) - _start_in_thread(gw_app, GATEWAY_PORT, "gateway") - - # 5. Wait + register - async def setup(): - await _wait_healthy(f"{router_addr}/health") - await _wait_healthy(f"{worker_addr}/health") - await _wait_healthy(f"{proxy_addr}/health") - await _wait_healthy(f"{gateway_addr}/health") - from areal.experimental.agent_service.auth import admin_headers - - async with httpx.AsyncClient() as client: - await client.post( - f"{router_addr}/register", - json={"addr": proxy_addr}, - headers=admin_headers(admin_key), - ) - - asyncio.run(setup()) - print("All services started.") - - # 6. 
Run demo - asyncio.run( - run_demo(gateway_addr, domain=domain, full=args.full, admin_key=admin_key) - ) - - -if __name__ == "__main__": - main() diff --git a/tests/experimental/agent_service/test_agent_router.py b/tests/experimental/agent_service/test_agent_router.py index 214d7e9d0c..13e682de23 100644 --- a/tests/experimental/agent_service/test_agent_router.py +++ b/tests/experimental/agent_service/test_agent_router.py @@ -4,16 +4,18 @@ import pytest -from areal.experimental.agent_service.auth import DEFAULT_ADMIN_KEY, admin_headers +from areal.experimental.agent_service.auth import DEFAULT_ADMIN_API_KEY, admin_headers from areal.experimental.agent_service.router.app import create_router_app +from areal.experimental.agent_service.router.config import RouterConfig httpx = pytest.importorskip("httpx") -_AUTH = admin_headers(DEFAULT_ADMIN_KEY) +_AUTH = admin_headers(DEFAULT_ADMIN_API_KEY) def _make_client(): - app = create_router_app(admin_key=DEFAULT_ADMIN_KEY) + config = RouterConfig(admin_api_key=DEFAULT_ADMIN_API_KEY) + app = create_router_app(config) transport = httpx.ASGITransport(app=app) return httpx.AsyncClient(transport=transport, base_url="http://router") diff --git a/tests/experimental/agent_service/test_controller.py b/tests/experimental/agent_service/test_controller.py new file mode 100644 index 0000000000..376bed71c6 --- /dev/null +++ b/tests/experimental/agent_service/test_controller.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for AgentServiceController. + +All Guard HTTP interactions are mocked — no real processes or servers. +Tests cover: initialize, destroy, scale_up, scale_down, and error handling. 
+""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from unittest.mock import MagicMock, patch + +import pytest + +from areal.experimental.agent_service.controller.config import ( + AgentServiceControllerConfig, +) +from areal.experimental.agent_service.controller.controller import ( + AgentServiceController, +) + +CTRL = "areal.experimental.agent_service.controller.controller" + + +@dataclass +class _FakeWorker: + id: str + ip: str + worker_ports: list[str] + engine_ports: list[str] + + +def _make_scheduler(*guard_specs: tuple[str, str]) -> MagicMock: + """Return a mock Scheduler whose get_workers returns _FakeWorkers.""" + workers = [ + _FakeWorker(id=f"agent-guard/{i}", ip=ip, worker_ports=[port], engine_ports=[]) + for i, (ip, port) in enumerate(guard_specs) + ] + scheduler = MagicMock() + scheduler.get_workers.return_value = workers + return scheduler + + +def _mock_alloc_ports_response(host: str, ports: list[int]) -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"status": "success", "host": host, "ports": ports} + resp.raise_for_status = MagicMock() + return resp + + +def _mock_fork_response(host: str, pid: int) -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"status": "success", "host": host, "pid": pid} + resp.raise_for_status = MagicMock() + return resp + + +def _mock_kill_response() -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"status": "success"} + resp.text = '{"status": "success"}' + return resp + + +def _mock_register_response() -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.raise_for_status = MagicMock() + return resp + + +def _mock_health_response(active_sessions: int = 0) -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"status": "ok", "active_sessions": active_sessions} + return resp + + +@pytest.fixture() +def config(): + 
return AgentServiceControllerConfig( + agent_cls_path="my.Agent", + admin_api_key="test-key", + num_pairs=2, + setup_timeout=1.0, + health_poll_interval=0, + ) + + +def _setup_mock_requests(mock_requests, port_start=9001): + port_counter = iter(range(port_start, port_start + 100)) + + def mock_post(url, **kwargs): + if "/alloc_ports" in url: + return _mock_alloc_ports_response("10.0.0.1", [next(port_counter)]) + if "/fork" in url: + return _mock_fork_response("10.0.0.1", 100) + if "/register" in url: + return _mock_register_response() + if "/kill_forked_worker" in url: + return _mock_kill_response() + if "/unregister" in url: + return _mock_register_response() + return MagicMock(status_code=404) + + mock_requests.post = mock_post + mock_requests.get = lambda url, **kw: _mock_health_response() + mock_requests.RequestException = Exception + + +class TestConstruction: + def test_construction(self, config): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + assert ctrl.router_addr == "" + assert ctrl.gateway_addr == "" + assert ctrl.pairs == {} + + +class TestInitialize: + @patch(f"{CTRL}.requests") + def test_initialize_forks_router_pairs_gateway(self, mock_requests, config): + """Initialize should create guards via scheduler, then fork services.""" + _setup_mock_requests(mock_requests) + + scheduler = _make_scheduler(("10.0.0.1", "8090"), ("10.0.0.2", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + + scheduler.create_workers.assert_called_once() + scheduler.get_workers.assert_called_once() + + assert "http://" in ctrl.router_addr + assert "http://" in ctrl.gateway_addr + assert len(ctrl.pairs) == 2 + assert len(ctrl._forked_services) == 6 + + +class TestScaleUp: + @patch(f"{CTRL}.requests") + def test_scale_up_adds_pairs(self, mock_requests, config): + config.num_pairs = 0 + _setup_mock_requests(mock_requests) + + scheduler = 
class TestScaleDown:
    @patch(f"{CTRL}.requests")
    def test_scale_down_removes_newest_first(self, mock_requests, config):
        """Pairs are removed newest-first; the oldest pair survives."""
        config.num_pairs = 3
        _setup_mock_requests(mock_requests)

        scheduler = _make_scheduler(("10.0.0.1", "8090"))
        ctrl = AgentServiceController(config=config, scheduler=scheduler)
        ctrl.initialize()
        assert len(ctrl.pairs) == 3

        removed = ctrl.scale_down(2)
        # Indices 2 and 1 (the most recently created) go first; pair 0 remains.
        assert set(removed) == {2, 1}
        assert len(ctrl.pairs) == 1
        assert 0 in ctrl.pairs


class TestDestroy:
    @patch(f"{CTRL}.requests")
    def test_destroy_clears_everything(self, mock_requests, config):
        """destroy() resets addresses, pairs, forked services and deletes workers."""
        config.num_pairs = 1
        _setup_mock_requests(mock_requests)

        scheduler = _make_scheduler(("10.0.0.1", "8090"))
        ctrl = AgentServiceController(config=config, scheduler=scheduler)
        ctrl.initialize()
        assert len(ctrl._forked_services) > 0

        ctrl.destroy()
        assert ctrl.router_addr == ""
        assert ctrl.gateway_addr == ""
        assert ctrl.pairs == {}
        assert ctrl._forked_services == []
        scheduler.delete_workers.assert_called()

    @patch(f"{CTRL}.requests")
    def test_destroy_tolerates_kill_errors(self, mock_requests, config):
        """destroy() keeps going even when the guard's kill endpoint errors."""
        config.num_pairs = 0
        kill_count = 0

        def mock_post(url, **kwargs):
            nonlocal kill_count
            if "/alloc_ports" in url:
                return _mock_alloc_ports_response("10.0.0.1", [9001])
            if "/fork" in url:
                return _mock_fork_response("10.0.0.1", 100)
            if "/kill_forked_worker" in url:
                kill_count += 1
                raise ConnectionError("Guard down")
            return MagicMock(status_code=404)

        mock_requests.post = mock_post
        mock_requests.get = lambda url, **kw: _mock_health_response()
        mock_requests.RequestException = Exception

        scheduler = _make_scheduler(("10.0.0.1", "8090"))
        ctrl = AgentServiceController(config=config, scheduler=scheduler)
        ctrl.initialize()

        ctrl.destroy()
        # Both forked services were attempted despite the raised errors.
        assert kill_count == 2
        assert ctrl._forked_services == []


class TestDrain:
    @patch(f"{CTRL}.requests")
    def test_scale_down_waits_for_drain(self, mock_requests, config):
        """scale_down should poll DataProxy health until active_sessions reaches 0."""
        config.num_pairs = 1
        config.drain_timeout = 5.0

        _setup_mock_requests(mock_requests)
        health_call_count = 0

        def mock_get(url, **kwargs):
            nonlocal health_call_count
            health_call_count += 1
            # First five health polls report in-flight sessions, then drained.
            if "/health" in url and health_call_count <= 5:
                return _mock_health_response(active_sessions=2)
            return _mock_health_response(active_sessions=0)

        mock_requests.get = mock_get

        scheduler = _make_scheduler(("10.0.0.1", "8090"))
        ctrl = AgentServiceController(config=config, scheduler=scheduler)
        ctrl.initialize()

        health_call_count = 0
        with patch(f"{CTRL}.time") as mock_time:
            # Keep real monotonic clock but make sleep a no-op so the test is fast.
            mock_time.monotonic = time.monotonic
            mock_time.sleep = MagicMock()
            ctrl.scale_down(1)

        assert len(ctrl.pairs) == 0
        assert health_call_count > 1

    @patch(f"{CTRL}.requests")
    def test_drain_skipped_when_timeout_zero(self, mock_requests, config):
        """With drain_timeout=0 no health polling happens during scale_down."""
        config.num_pairs = 1
        config.drain_timeout = 0
        _setup_mock_requests(mock_requests)
        get_count = 0

        def counting_get(url, **kwargs):
            nonlocal get_count
            get_count += 1
            return _mock_health_response(active_sessions=5)

        mock_requests.get = counting_get

        scheduler = _make_scheduler(("10.0.0.1", "8090"))
        ctrl = AgentServiceController(config=config, scheduler=scheduler)
        ctrl.initialize()

        pre_get_count = get_count
        ctrl.scale_down(1)
        drain_gets = get_count - pre_get_count
        assert drain_gets == 0


class TestHealthMonitor:
    @patch(f"{CTRL}.requests")
    def test_health_monitor_starts_and_stops(self, mock_requests, config):
        """A positive poll interval spawns the health thread; destroy() stops it."""
        config.num_pairs = 0
        config.health_poll_interval = 0.1
        _setup_mock_requests(mock_requests)

        scheduler = _make_scheduler(("10.0.0.1", "8090"))
        ctrl = AgentServiceController(config=config, scheduler=scheduler)
        ctrl.initialize()
        assert ctrl._health_thread is not None
        assert ctrl._health_thread.is_alive()

        ctrl.destroy()
        assert ctrl._health_thread is None

    @patch(f"{CTRL}.requests")
    def test_health_monitor_disabled_when_interval_zero(self, mock_requests, config):
        """A zero poll interval disables the health monitor thread entirely."""
        config.num_pairs = 0
        config.health_poll_interval = 0
        _setup_mock_requests(mock_requests)

        scheduler = _make_scheduler(("10.0.0.1", "8090"))
        ctrl = AgentServiceController(config=config, scheduler=scheduler)
        ctrl.initialize()
        assert ctrl._health_thread is None

        ctrl.destroy()
pass-through). + +Tests that the base guard routes are available on the agent guard app. +The agent_blueprint has been removed in v2 — all orchestration logic +now lives in AgentServiceController. + +Test structure mirrors ``tests/experimental/inference_service/test_guard.py``. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from areal.experimental.agent_service.guard import app as guard_module +from areal.experimental.agent_service.guard.app import app + +GUARD_APP = "areal.infra.rpc.guard.app" + + +@pytest.fixture(autouse=True) +def _reset_guard_globals(): + """Reset all guard state between tests.""" + guard_module._state.allocated_ports = set() + guard_module._state.forked_children = [] + guard_module._state.forked_children_map = {} + guard_module._state.server_host = "10.0.0.1" + guard_module._state.experiment_name = "test-exp" + guard_module._state.trial_name = "test-trial" + guard_module._state.fileroot = None + yield + guard_module._state.allocated_ports = set() + guard_module._state.forked_children = [] + guard_module._state.forked_children_map = {} + + +@pytest.fixture() +def client(): + app.config["TESTING"] = True + with app.test_client() as c: + yield c + + +class TestHealth: + def test_health_returns_200(self, client): + resp = client.get("/health") + assert resp.status_code == 200 + data = resp.get_json() + assert data["status"] == "healthy" + assert data["forked_children"] == 0 + + def test_health_counts_forked_children(self, client): + guard_module._state.forked_children = [MagicMock(), MagicMock()] + resp = client.get("/health") + data = resp.get_json() + assert data["forked_children"] == 2 + + +class TestAllocPorts: + @patch(f"{GUARD_APP}.find_free_ports") + def test_alloc_ports_success(self, mock_find, client): + mock_find.return_value = [9001, 9002] + resp = client.post("/alloc_ports", json={"count": 2}) + assert resp.status_code == 200 + data = resp.get_json() + assert data["status"] 
== "success" + assert data["ports"] == [9001, 9002] + assert guard_module._state.allocated_ports == {9001, 9002} + + def test_alloc_ports_missing_count(self, client): + resp = client.post("/alloc_ports", json={}) + assert resp.status_code == 400 diff --git a/tests/experimental/agent_service/test_integration.py b/tests/experimental/agent_service/test_integration.py index 7c3dc9ba19..f815c5ca25 100644 --- a/tests/experimental/agent_service/test_integration.py +++ b/tests/experimental/agent_service/test_integration.py @@ -10,11 +10,14 @@ import pytest -from areal.experimental.agent_service.auth import DEFAULT_ADMIN_KEY, admin_headers +from areal.experimental.agent_service.auth import DEFAULT_ADMIN_API_KEY, admin_headers from areal.experimental.agent_service.data_proxy.app import create_data_proxy_app +from areal.experimental.agent_service.data_proxy.config import DataProxyConfig from areal.experimental.agent_service.gateway.app import create_gateway_app from areal.experimental.agent_service.gateway.bridge import OpenResponsesBridge +from areal.experimental.agent_service.gateway.config import GatewayConfig from areal.experimental.agent_service.router.app import create_router_app +from areal.experimental.agent_service.router.config import RouterConfig from areal.experimental.agent_service.types import ( AgentRequest, AgentResponse, @@ -24,7 +27,7 @@ httpx = pytest.importorskip("httpx") -_AUTH = admin_headers(DEFAULT_ADMIN_KEY) +_AUTH = admin_headers(DEFAULT_ADMIN_API_KEY) class _EchoAgent: @@ -88,7 +91,7 @@ async def test_data_proxy_manages_history(self): worker_transport = httpx.ASGITransport(app=worker_app) # Create DataProxy pointing to worker - proxy_app = create_data_proxy_app(worker_addr="http://worker") + proxy_app = create_data_proxy_app(DataProxyConfig(worker_addr="http://worker")) # Patch DataProxy's httpx client to use worker's ASGITransport original_post = httpx.AsyncClient.post @@ -132,7 +135,7 @@ async def patched_post(self, url, **kwargs): async def 
test_close_session_clears_history(self): worker_app = _make_worker_app(_EchoAgent) worker_transport = httpx.ASGITransport(app=worker_app) - proxy_app = create_data_proxy_app(worker_addr="http://worker") + proxy_app = create_data_proxy_app(DataProxyConfig(worker_addr="http://worker")) original_post = httpx.AsyncClient.post @@ -163,7 +166,9 @@ async def patched_post(self, url, **kwargs): class TestRouterIntegration: @pytest.mark.asyncio async def test_register_and_route(self): - router_app = create_router_app(admin_key=DEFAULT_ADMIN_KEY) + router_app = create_router_app( + RouterConfig(admin_api_key=DEFAULT_ADMIN_API_KEY) + ) transport = httpx.ASGITransport(app=router_app) async with httpx.AsyncClient( @@ -190,7 +195,7 @@ class TestToolCallFlow: async def test_tool_events_through_proxy(self): worker_app = _make_worker_app(_ToolAgent) worker_transport = httpx.ASGITransport(app=worker_app) - proxy_app = create_data_proxy_app(worker_addr="http://worker") + proxy_app = create_data_proxy_app(DataProxyConfig(worker_addr="http://worker")) original_post = httpx.AsyncClient.post @@ -231,7 +236,7 @@ async def patched_post(self, url, **kwargs): class TestGatewayHealth: @pytest.mark.asyncio async def test_health(self): - app = create_gateway_app(router_addr="http://fake-router") + app = create_gateway_app(GatewayConfig(router_addr="http://fake-router")) transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient( transport=transport, base_url="http://gw" diff --git a/tests/experimental/inference_service/test_controller.py b/tests/experimental/inference_service/test_controller.py index 99a64a71ae..13b8f43e95 100644 --- a/tests/experimental/inference_service/test_controller.py +++ b/tests/experimental/inference_service/test_controller.py @@ -615,3 +615,199 @@ async def run(self, data, **kwargs): workflow._export_interactions.assert_awaited_once_with( mock_http_session, "sess-1", trajectory_id=None ) + + +# 
# =============================================================================
# Multi-node inference configuration
# =============================================================================


class TestMultiNodeConfig:
    """Validation and initialization behavior for multi-node inference."""

    def test_n_gpus_per_node_default_is_none(self):
        cfg = GatewayControllerConfig()
        assert cfg.n_gpus_per_node is None

    def test_n_gpus_per_node_custom(self):
        cfg = GatewayControllerConfig(n_gpus_per_node=4)
        assert cfg.n_gpus_per_node == 4

    def test_n_gpus_per_node_zero_raises(self):
        cfg = GatewayControllerConfig(n_gpus_per_node=0, backend="sglang:d1t8")
        with pytest.raises(ValueError, match="n_gpus_per_node must be >= 1"):
            GatewayInferenceController(config=cfg, scheduler=MagicMock())

    def test_gpus_not_divisible_raises(self):
        cfg = GatewayControllerConfig(n_gpus_per_node=3, backend="sglang:d1t8")
        with pytest.raises(ValueError, match="must be divisible by n_gpus_per_node"):
            GatewayInferenceController(config=cfg, scheduler=MagicMock())

    def test_single_node_backward_compat(self):
        cfg = GatewayControllerConfig(backend="sglang:d2t4")
        controller = GatewayInferenceController(config=cfg, scheduler=MagicMock())
        assert controller._nnodes_per_instance == 1

    def test_multi_node_valid_config(self):
        # tp=16, n_gpus_per_node=8 → nnodes_per_instance=2
        cfg = GatewayControllerConfig(n_gpus_per_node=8, backend="sglang:d1t16")
        controller = GatewayInferenceController(config=cfg, scheduler=MagicMock())
        assert controller._nnodes_per_instance == 2

    @pytest.mark.asyncio
    async def test_async_initialize_multinode_worker_count(self):
        """With multi-node and pre-existing server_infos, should create dp_size workers."""
        from areal.api.cli_args import SchedulingSpec
        from areal.api.io_struct import LocalInfServerInfo

        worker0 = MagicMock()
        worker0.ip = "10.0.0.1"
        worker0.worker_ports = [18000]
        worker0.id = "w0"

        worker1 = MagicMock()
        worker1.ip = "10.0.0.2"
        worker1.worker_ports = [18000]
        worker1.id = "w1"

        scheduler = MagicMock()
        scheduler.get_workers.return_value = [worker0]

        # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2
        cfg = GatewayControllerConfig(
            tokenizer_path="mock-tokenizer",
            backend="sglang:d1t8",
            n_gpus_per_node=4,
            scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),),
            openai=OpenAIProxyConfig(admin_api_key="test-key"),
        )
        controller = GatewayInferenceController(config=cfg, scheduler=scheduler)
        controller._callback_host = "127.0.0.1"
        controller._callback_port = 19000

        with patch.object(controller, "_fork_on_guard") as mock_fork:
            mock_fork.side_effect = [
                ("127.0.0.1", 18081),  # router
                ("127.0.0.1", 18082),  # data proxy (only 1, on head)
                ("127.0.0.1", 18080),  # gateway
            ]

            await controller._async_initialize(
                server_args=None,
                server_infos=[
                    LocalInfServerInfo(
                        host="10.0.0.1", port=30000, process=MagicMock()
                    ),
                ],
            )

        # With server_infos, total_workers = dp_size = 1 (not dp_size * nnodes_per_instance)
        create_call = scheduler.create_workers.call_args
        job = create_call.kwargs.get("job") or create_call.args[0]
        assert job.replicas == 1

        # 3 forks: router + data-proxy + gateway (all on head worker)
        assert mock_fork.call_count == 3
        data_proxy_calls = [
            c for c in mock_fork.call_args_list if c.kwargs.get("role") == "data-proxy"
        ]
        assert len(data_proxy_calls) == 1

    @pytest.mark.asyncio
    async def test_async_initialize_multinode_fork_path(self):
        """Exercise the full multi-node fork path (server_infos=None)."""
        from areal.api.cli_args import SchedulingSpec

        worker0 = MagicMock()
        worker0.ip = "10.0.0.1"
        worker0.worker_ports = [18000]
        worker0.id = "w0"

        worker1 = MagicMock()
        worker1.ip = "10.0.0.2"
        worker1.worker_ports = [18000]
        worker1.id = "w1"

        scheduler = MagicMock()
        scheduler.get_workers.return_value = [worker0, worker1]

        # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2
        cfg = GatewayControllerConfig(
            tokenizer_path="mock-tokenizer",
            backend="sglang:d1t8",
            n_gpus_per_node=4,
            scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),),
            openai=OpenAIProxyConfig(admin_api_key="test-key"),
        )
        controller = GatewayInferenceController(config=cfg, scheduler=scheduler)
        controller._callback_host = "127.0.0.1"
        controller._callback_port = 19000

        # Track requests.post calls to /alloc_ports and /fork
        alloc_port_counter = 0
        fork_calls = []

        def mock_requests_post(url, json=None, timeout=None):
            nonlocal alloc_port_counter
            resp = MagicMock()
            resp.status_code = 200
            if "/alloc_ports" in url:
                alloc_port_counter += 1
                resp.json.return_value = {
                    "status": "success",
                    "host": url.split("//")[1].split(":")[0],
                    "ports": [30000 + alloc_port_counter],
                }
            elif "/fork" in url:
                fork_calls.append(json)
                resp.json.return_value = {"status": "success"}
            return resp

        with (
            patch("requests.post", side_effect=mock_requests_post) as mock_post,
            patch.object(controller, "_fork_on_guard") as mock_fork,
            patch.object(controller, "_wait_for_service"),
            patch(
                "areal.api.cli_args.pkg_version.is_version_greater_or_equal",
                return_value=True,
            ),
            patch("areal.api.cli_args.is_version_less", return_value=False),
        ):
            mock_fork.side_effect = [
                ("10.0.0.1", 18081),  # router
                ("10.0.0.1", 18082),  # data proxy
                ("10.0.0.1", 18080),  # gateway
            ]

            await controller._async_initialize(
                server_args=None,
                server_infos=None,
            )

        # dp_size=1, nnodes_per_instance=2: total_workers = 2
        create_call = scheduler.create_workers.call_args
        job = create_call.kwargs.get("job") or create_call.args[0]
        assert job.replicas == 2

        # requests.post calls:
        # 1 rendezvous alloc (nnodes_per_instance > 1) + 2 node allocs + 2 forks = 5
        post_calls = mock_post.call_args_list
        alloc_calls = [c for c in post_calls if "/alloc_ports" in str(c)]
        fork_post_calls = [c for c in post_calls if "/fork" in str(c)]
        assert len(alloc_calls) == 3  # 1 rendezvous + 2 per-node
        assert len(fork_post_calls) == 2  # 1 per node in the group

        # Verify fork payloads have correct worker_index and role
        assert fork_calls[0]["role"] == "inf-server"
        assert fork_calls[0]["worker_index"] == 0
        assert fork_calls[1]["role"] == "inf-server"
        assert fork_calls[1]["worker_index"] == 1

        # Verify dist_init_addr propagated to fork commands
        for fc in fork_calls:
            cmd_str = " ".join(fc["raw_cmd"])
            assert "--dist-init-addr" in cmd_str or "--dist_init_addr" in cmd_str

        # Only 1 data proxy (dp_size=1, on head worker only)
        data_proxy_calls = [
            c for c in mock_fork.call_args_list if c.kwargs.get("role") == "data-proxy"
        ]
        assert len(data_proxy_calls) == 1


# ---------------------------------------------------------------------------
# From tests/experimental/inference_service/test_sglang_multinode.py
# "Tests for SGLang multi-node CLI generation."
# ---------------------------------------------------------------------------

from areal.api.cli_args import SGLangConfig


class TestSGLangMultiNode:
    """Mirror of TestVLLMMultiNode for the SGLang backend."""

    def _build(self, builder, **overrides):
        """Call *builder* with default single-instance kwargs, version checks patched.

        Shared by :meth:`_build_args` and :meth:`_build_cmd` so the patch
        setup and defaults are defined once instead of being duplicated.
        """
        call_kwargs = dict(
            sglang_config=SGLangConfig(model_path="test-model"),
            tp_size=8,
            base_gpu_id=0,
        )
        call_kwargs.update(overrides)
        with (
            patch(
                "areal.api.cli_args.pkg_version.is_version_greater_or_equal",
                return_value=True,
            ),
            patch("areal.api.cli_args.is_version_less", return_value=False),
        ):
            return builder(**call_kwargs)

    def _build_args(self, **kwargs):
        """Helper that patches sglang version checks away."""
        return self._build(SGLangConfig.build_args, **kwargs)

    def _build_cmd(self, **kwargs):
        """Helper that patches sglang version checks away."""
        return self._build(SGLangConfig.build_cmd, **kwargs)

    def test_build_args_single_node_defaults(self):
        """Single-node (default) should have nnodes=1, node_rank=0."""
        args = self._build_args()
        assert args["nnodes"] == 1
        assert args["node_rank"] == 0
        assert args.get("dist_init_addr") is None

    def test_build_args_multi_node_head(self):
        """Head node (rank 0) with n_nodes > 1 should set nnodes and dist_init_addr."""
        args = self._build_args(
            tp_size=16,
            n_nodes=2,
            node_rank=0,
            dist_init_addr="10.0.0.1:29500",
        )
        assert args["nnodes"] == 2
        assert args["node_rank"] == 0
        assert args["dist_init_addr"] == "10.0.0.1:29500"

    def test_build_args_multi_node_worker(self):
        """Worker node (rank > 0) should set nnodes and node_rank."""
        args = self._build_args(
            tp_size=16,
            n_nodes=2,
            node_rank=1,
            dist_init_addr="10.0.0.1:29500",
        )
        assert args["nnodes"] == 2
        assert args["node_rank"] == 1
        assert args["dist_init_addr"] == "10.0.0.1:29500"

    def test_build_args_multi_node_no_dist_init_addr(self):
        """Multi-node without dist_init_addr should have dist_init_addr=None."""
        args = self._build_args(
            tp_size=16,
            n_nodes=2,
            node_rank=0,
        )
        assert args["nnodes"] == 2
        assert args["node_rank"] == 0
        assert args.get("dist_init_addr") is None

    def test_build_cmd_multi_node_produces_flags(self):
        """build_cmd with multi-node should produce CLI flags for nnodes and node-rank."""
        cmd = self._build_cmd(
            tp_size=16,
            n_nodes=2,
            node_rank=1,
            dist_init_addr="10.0.0.1:29500",
        )
        cmd_str = " ".join(cmd)
        assert "--nnodes" in cmd_str
        assert "--node-rank" in cmd_str
        assert "--dist-init-addr" in cmd_str
"""Tests for vLLM multi-node CLI generation."""

from __future__ import annotations

from areal.api.cli_args import vLLMConfig


class TestVLLMMultiNode:
    def test_build_args_single_node_no_extra_flags(self):
        """Single-node (default) should not add nnodes/node_rank/headless."""
        cfg = vLLMConfig(model="test-model")
        args = vLLMConfig.build_args(cfg, tp_size=8, pp_size=1)
        for absent_key in ("nnodes", "node_rank", "headless", "master_addr", "master_port"):
            assert absent_key not in args

    def test_build_args_multi_node_head(self):
        """Head node (rank 0) with n_nodes > 1 should add nnodes/node_rank but NOT headless."""
        cfg = vLLMConfig(model="test-model")
        args = vLLMConfig.build_args(
            cfg,
            tp_size=16,
            pp_size=1,
            n_nodes=2,
            node_rank=0,
            dist_init_addr="10.0.0.1:29500",
        )
        assert args["nnodes"] == 2
        assert args["node_rank"] == 0
        # Rank 0 serves requests itself, so it must not run headless.
        assert "headless" not in args
        assert args["master_addr"] == "10.0.0.1"
        assert args["master_port"] == "29500"

    def test_build_args_multi_node_worker(self):
        """Worker node (rank > 0) should add headless=True."""
        cfg = vLLMConfig(model="test-model")
        args = vLLMConfig.build_args(
            cfg,
            tp_size=16,
            pp_size=1,
            n_nodes=2,
            node_rank=1,
            dist_init_addr="10.0.0.1:29500",
        )
        assert args["nnodes"] == 2
        assert args["node_rank"] == 1
        assert args["headless"] is True
        assert args["master_addr"] == "10.0.0.1"
        assert args["master_port"] == "29500"

    def test_build_args_multi_node_no_dist_init_addr(self):
        """Multi-node without dist_init_addr should not add master_addr/master_port."""
        cfg = vLLMConfig(model="test-model")
        args = vLLMConfig.build_args(
            cfg,
            tp_size=16,
            pp_size=1,
            n_nodes=2,
            node_rank=0,
        )
        assert args["nnodes"] == 2
        assert args["node_rank"] == 0
        assert "master_addr" not in args
        assert "master_port" not in args

    def test_build_cmd_multi_node_produces_flags(self):
        """build_cmd with multi-node should produce CLI flags for nnodes and node-rank."""
        cfg = vLLMConfig(model="test-model")
        cmd = vLLMConfig.build_cmd(
            cfg,
            tp_size=16,
            pp_size=1,
            n_nodes=2,
            node_rank=1,
            dist_init_addr="10.0.0.1:29500",
        )
        joined = " ".join(cmd)
        for flag in ("--nnodes", "--node-rank", "--headless", "--master-addr", "--master-port"):
            assert flag in joined