From b60407987feb7c87bf9720c85ab1ce682869a11c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=97=AE=E9=97=AE=E7=A5=9E=E5=A5=87=E6=B5=B7=E8=9E=BA?= Date: Wed, 15 Apr 2026 10:55:43 +0800 Subject: [PATCH 1/2] feat(service): support multi-node inference in gateway controller (#1178) * feat(service): support multi-node inference in gateway controller Enable inference instances to span multiple physical nodes for large models (e.g. Llama-3.1 405B with tp_size=16 across 2x 8-GPU nodes). When nnodes > 1, the controller groups every nnodes workers into one inference instance, allocates rendezvous ports for distributed init, forks inference servers on all nodes per group, and only records the head node as the HTTP endpoint. When nnodes=1 (default), behavior is identical to the existing code. Key changes: - Add nnodes field to GatewayControllerConfig (default 1) - Extend vLLMConfig.build_args()/build_cmd() with n_nodes/node_rank - Restructure controller fork loop with grouped multi-node support - Register forked inference processes in _forked_services for cleanup - Fix server_infos indexing to avoid IndexError when nnodes > 1 - Data proxies only fork on head nodes (one per DP group) - Defensive checks: worker count validation, RuntimeError over assert Refs: #1149 * refactor(inference): replace nnodes with n_gpus_per_node for multi-node configuration --- areal/api/cli_args.py | 18 ++ .../inference_service/controller/config.py | 1 + .../controller/controller.py | 229 +++++++++++++----- .../inference_service/test_controller.py | 196 +++++++++++++++ .../test_sglang_multinode.py | 102 ++++++++ .../inference_service/test_vllm_multinode.py | 84 +++++++ 6 files changed, 565 insertions(+), 65 deletions(-) create mode 100644 tests/experimental/inference_service/test_sglang_multinode.py create mode 100644 tests/experimental/inference_service/test_vllm_multinode.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 6716091d6f..91f9d3ba13 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1507,6 +1507,8 @@ def build_args( host: str | None = None, port: int | None = None, dist_init_addr: str | None = None, + n_nodes: int = 1, + node_rank: int = 0, ): args: dict = conf_as_dict(vllm_config) args = dict( @@ -1522,6 +1524,18 @@ def build_args( args["port"] = port if host is not None: args["host"] = host + # Multi-node support + if n_nodes > 1: + args["nnodes"] = n_nodes + args["node_rank"] = node_rank + if dist_init_addr is not None: + from areal.utils.network import split_hostport + + master_host, master_port = split_hostport(dist_init_addr) + args["master_addr"] = master_host + args["master_port"] = str(master_port) + if node_rank > 0: + args["headless"] = True return args @staticmethod @@ -1536,6 +1550,8 @@ def build_cmd( host: str | None = None, port: int | None = None, dist_init_addr: str | None = None, + n_nodes: int = 1, + node_rank: int = 0, ): args = vLLMConfig.build_args( vllm_config=vllm_config, @@ -1544,6 +1560,8 @@ def build_cmd( host=host, port=port, dist_init_addr=dist_init_addr, + n_nodes=n_nodes, + node_rank=node_rank, ) return vLLMConfig.build_cmd_from_args(args) diff --git a/areal/experimental/inference_service/controller/config.py b/areal/experimental/inference_service/controller/config.py index 4f0574e212..2b8832a722 100644 --- a/areal/experimental/inference_service/controller/config.py +++ b/areal/experimental/inference_service/controller/config.py @@ -51,6 +51,7 @@ class GatewayControllerConfig: backend: str = "sglang:d1" scheduling_spec: tuple = field(default_factory=tuple) pause_grace_period: float = 0.5 + n_gpus_per_node: int | None = None # GPUs per physical node; None = single-node # -- OpenAI proxy configuration (for agent-like workflows) --------------- openai: OpenAIProxyConfig = field(default_factory=lambda: OpenAIProxyConfig()) diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index ecb2d67ab8..70bdfba148 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ -90,6 +90,24 @@ def __init__( # Parse allocation from config.backend self.rollout_alloc = ModelAllocation.from_str(config.backend) + # Multi-node: derive nnodes_per_instance from n_gpus_per_node + total_gpus = ( + self.rollout_alloc.parallel.tp_size * self.rollout_alloc.parallel.pp_size + ) + n_gpus_per_node = config.n_gpus_per_node + if n_gpus_per_node is None: + nnodes_per_instance = 1 + else: + if n_gpus_per_node < 1: + raise ValueError(f"n_gpus_per_node must be >= 1, got {n_gpus_per_node}") + if total_gpus % n_gpus_per_node != 0: + raise ValueError( + f"tp_size * pp_size ({total_gpus}) must be divisible " + f"by n_gpus_per_node ({n_gpus_per_node})" + ) + nnodes_per_instance = total_gpus // n_gpus_per_node + self._nnodes_per_instance = nnodes_per_instance + # Worker management self.workers: list[Worker] = [] self.server_infos: list[LocalInfServerInfo] = [] @@ -224,27 +242,32 @@ async def _async_initialize( inf_backend = alloc.backend # ================================================================== - # Step 0: Always create dp_size RPCGuard workers + # Step 0: Create RPCGuard workers (dp_size × nnodes_per_instance) # ================================================================== inf_spec = SchedulingSpec(**asdict(cfg.scheduling_spec[0])) instance_size = alloc.parallel.tp_size * alloc.parallel.pp_size + nnodes_per_instance = self._nnodes_per_instance + gpus_per_worker = instance_size // nnodes_per_instance + if server_infos is not None: - # Pre-existing inference servers — RPCGuard workers only host - # CPU services (data proxy, router, gateway), no GPUs needed. + # Pre-existing inference servers — only need dp_size workers + # for CPU services (data proxy, router, gateway), no GPUs. + total_workers = dp_size inf_spec.gpu = 0 else: - inf_spec.cpu *= instance_size - inf_spec.mem *= instance_size + total_workers = dp_size * nnodes_per_instance + inf_spec.cpu *= gpus_per_worker + inf_spec.mem *= gpus_per_worker if inf_spec.gpu > 0: - inf_spec.gpu = instance_size + inf_spec.gpu = gpus_per_worker # Override cmd to launch RPCGuard instead of RPC server inf_spec.cmd = "python -m areal.experimental.inference_service.guard" inf_role = f"{self._worker_role}{self._INF_SUFFIX}" inf_job = Job( - replicas=dp_size, - tasks=[inf_spec for _ in range(dp_size)], + replicas=total_workers, + tasks=[inf_spec for _ in range(total_workers)], scheduling_strategy=SchedulingStrategy(), role=inf_role, ) @@ -252,6 +275,11 @@ async def _async_initialize( self.scheduler.create_workers(job=inf_job) self._service_roles.append(inf_role) inf_workers = self.scheduler.get_workers(role=inf_role) + if len(inf_workers) != total_workers: + raise RuntimeError( + f"Expected {total_workers} workers for role {inf_role!r}, " + f"got {len(inf_workers)}" + ) self.workers = inf_workers logger.info("RPCGuard workers ready: %s", [w.id for w in inf_workers]) @@ -291,13 +319,22 @@ async def _async_initialize( v, ) - def _build_launch_cmd(host: str, port: int) -> list[str]: + def _build_launch_cmd( + host: str | None, + port: int | None, + n_nodes: int = 1, + node_rank: int = 0, + dist_init_addr: str | None = None, + ) -> list[str]: return SGLangConfig.build_cmd( sglang_config=sglang_config, tp_size=tp_size, base_gpu_id=0, host=host, port=port, + dist_init_addr=dist_init_addr, + n_nodes=n_nodes, + node_rank=node_rank, ) elif inf_backend == "vllm": @@ -315,79 +352,138 @@ def _build_launch_cmd(host: str, port: int) -> list[str]: v, ) - def _build_launch_cmd(host: str, port: int) -> list[str]: + def _build_launch_cmd( + host: str | None, + port: int | None, + n_nodes: int = 1, + node_rank: int = 0, + dist_init_addr: str | None = None, + ) -> list[str]: return vLLMConfig.build_cmd( vllm_config=vllm_config, tp_size=tp_size, pp_size=alloc.parallel.pp_size, host=host, port=port, + dist_init_addr=dist_init_addr, + n_nodes=n_nodes, + node_rank=node_rank, ) else: raise ValueError(f"Unsupported inference backend: {inf_backend!r}") - # For each RPCGuard worker: alloc port, build cmd, fork server - for rank, worker in enumerate(inf_workers): - guard_addr = ( - f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" - ) - - resp = requests.post( - f"{guard_addr}/alloc_ports", - json={"count": 1}, - timeout=30, - ) - resp.raise_for_status() - port_data = resp.json() - inf_host = port_data["host"] - inf_port = port_data["ports"][0] - - cmd = _build_launch_cmd(inf_host, inf_port) - - fork_payload: dict[str, Any] = { - "role": "inf-server", - "worker_index": rank, - "raw_cmd": cmd, - } - if inf_backend == "vllm": - from areal.infra.utils.launcher import ( - TRITON_CACHE_PATH as _TRITON_CACHE, + # For each inference instance group: alloc ports, build cmd, fork servers + for group_idx in range(dp_size): + group_workers = inf_workers[ + group_idx * nnodes_per_instance : (group_idx + 1) + * nnodes_per_instance + ] + head_worker = group_workers[0] + head_guard_addr = f"http://{format_hostport(head_worker.ip, int(head_worker.worker_ports[0]))}" + + # Allocate rendezvous port on head node for distributed init + dist_init_addr = None + if nnodes_per_instance > 1: + resp = requests.post( + f"{head_guard_addr}/alloc_ports", + json={"count": 1}, + timeout=30, ) - from areal.infra.utils.launcher import ( - VLLM_CACHE_ROOT as _VLLM_CACHE, + resp.raise_for_status() + rendezvous_data = resp.json() + rendezvous_host = rendezvous_data["host"] + rendezvous_port = rendezvous_data["ports"][0] + dist_init_addr = format_hostport(rendezvous_host, rendezvous_port) + + head_inf_host = None + head_inf_port = None + + for node_rank, worker in enumerate(group_workers): + guard_addr = f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" + + # Allocate port for inference server on this node + resp = requests.post( + f"{guard_addr}/alloc_ports", + json={"count": 1}, + timeout=30, + ) + resp.raise_for_status() + port_data = resp.json() + inf_host = port_data["host"] + inf_port = port_data["ports"][0] + + # Worker nodes (rank > 0) don't need to serve HTTP, + # but we still pass host/port for the server to bind + cmd = _build_launch_cmd( + host=inf_host, + port=inf_port, + n_nodes=nnodes_per_instance, + node_rank=node_rank, + dist_init_addr=dist_init_addr, ) - fork_payload["env"] = { - "TRITON_CACHE_PATH": os.path.join( - os.environ.get("TRITON_CACHE_PATH", _TRITON_CACHE), - str(uuid.uuid4()), - ), - "VLLM_CACHE_ROOT": os.path.join( - os.environ.get("VLLM_CACHE_ROOT", _VLLM_CACHE), - str(uuid.uuid4()), - ), - "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True", + fork_payload: dict[str, Any] = { + "role": "inf-server", + "worker_index": group_idx * nnodes_per_instance + node_rank, + "raw_cmd": cmd, } + if inf_backend == "vllm": + from areal.infra.utils.launcher import ( + TRITON_CACHE_PATH as _TRITON_CACHE, + ) + from areal.infra.utils.launcher import ( + VLLM_CACHE_ROOT as _VLLM_CACHE, + ) - resp = requests.post( - f"{guard_addr}/fork", - json=fork_payload, - timeout=30, - ) - resp.raise_for_status() + fork_payload["env"] = { + "TRITON_CACHE_PATH": os.path.join( + os.environ.get("TRITON_CACHE_PATH", _TRITON_CACHE), + str(uuid.uuid4()), + ), + "VLLM_CACHE_ROOT": os.path.join( + os.environ.get("VLLM_CACHE_ROOT", _VLLM_CACHE), + str(uuid.uuid4()), + ), + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True", + } + + resp = requests.post( + f"{guard_addr}/fork", + json=fork_payload, + timeout=30, + ) + resp.raise_for_status() + self._forked_services.append( + ( + guard_addr, + "inf-server", + group_idx * nnodes_per_instance + node_rank, + ) + ) + + if node_rank == 0: + head_inf_host = inf_host + head_inf_port = inf_port - addr = f"http://{format_hostport(inf_host, inf_port)}" + if head_inf_host is None or head_inf_port is None: + raise RuntimeError( + f"No head worker resolved for group {group_idx}; " + f"expected {nnodes_per_instance} workers per group" + ) + + # Only record the head node's address as the inference endpoint + addr = f"http://{format_hostport(head_inf_host, head_inf_port)}" self._inf_addrs.append(addr) self.server_infos.append( LocalInfServerInfo( - host=inf_host, - port=inf_port, + host=head_inf_host, + port=head_inf_port, process=None, # type: ignore[arg-type] # RPCGuard manages process ) ) - # Wait for inference servers to be healthy + # Wait for inference servers to be healthy (only head nodes) for i, addr in enumerate(self._inf_addrs): self._wait_for_service( f"{addr}/health", f"InfServer-{i}", timeout=cfg.setup_timeout @@ -442,21 +538,24 @@ def _build_launch_cmd(host: str, port: int) -> list[str]: f"http://{self.callback_addr}", ] - for rank, worker in enumerate(inf_workers): - guard_addr = ( - f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" - ) - # Each data proxy connects to its corresponding inference server + for group_idx in range(dp_size): + head_worker = inf_workers[ + group_idx + if server_infos is not None + else group_idx * nnodes_per_instance + ] + guard_addr = f"http://{format_hostport(head_worker.ip, int(head_worker.worker_ports[0]))}" + # Each data proxy connects to its group's head inference server data_proxy_cmd = data_proxy_base_cmd + [ "--backend-addr", - self._inf_addrs[rank], + self._inf_addrs[group_idx], "--backend-type", inf_backend or "sglang", ] data_proxy_host, data_proxy_port = self._fork_on_guard( guard_addr=guard_addr, role="data-proxy", - worker_index=rank, + worker_index=group_idx, raw_cmd=data_proxy_cmd, ) self._data_proxy_addrs.append( diff --git a/tests/experimental/inference_service/test_controller.py b/tests/experimental/inference_service/test_controller.py index 99a64a71ae..13b8f43e95 100644 --- a/tests/experimental/inference_service/test_controller.py +++ b/tests/experimental/inference_service/test_controller.py @@ -615,3 +615,199 @@ async def run(self, data, **kwargs): workflow._export_interactions.assert_awaited_once_with( mock_http_session, "sess-1", trajectory_id=None ) + + +# ============================================================================= +# Multi-node inference configuration +# ============================================================================= + + +class TestMultiNodeConfig: + def test_n_gpus_per_node_default_is_none(self): + cfg = GatewayControllerConfig() + assert cfg.n_gpus_per_node is None + + def test_n_gpus_per_node_custom(self): + cfg = GatewayControllerConfig(n_gpus_per_node=4) + assert cfg.n_gpus_per_node == 4 + + def test_n_gpus_per_node_zero_raises(self): + cfg = GatewayControllerConfig(n_gpus_per_node=0, backend="sglang:d1t8") + with pytest.raises(ValueError, match="n_gpus_per_node must be >= 1"): + GatewayInferenceController(config=cfg, scheduler=MagicMock()) + + def test_gpus_not_divisible_raises(self): + cfg = GatewayControllerConfig(n_gpus_per_node=3, backend="sglang:d1t8") + with pytest.raises(ValueError, match="must be divisible by n_gpus_per_node"): + GatewayInferenceController(config=cfg, scheduler=MagicMock()) + + def test_single_node_backward_compat(self): + cfg = GatewayControllerConfig(backend="sglang:d2t4") + controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) + assert controller._nnodes_per_instance == 1 + + def test_multi_node_valid_config(self): + # tp=16, n_gpus_per_node=8 → nnodes_per_instance=2 + cfg = GatewayControllerConfig(n_gpus_per_node=8, backend="sglang:d1t16") + controller = GatewayInferenceController(config=cfg, scheduler=MagicMock()) + assert controller._nnodes_per_instance == 2 + + @pytest.mark.asyncio + async def test_async_initialize_multinode_worker_count(self): + """With multi-node and pre-existing server_infos, should create dp_size workers.""" + from areal.api.cli_args import SchedulingSpec + from areal.api.io_struct import LocalInfServerInfo + + worker0 = MagicMock() + worker0.ip = "10.0.0.1" + worker0.worker_ports = [18000] + worker0.id = "w0" + + worker1 = MagicMock() + worker1.ip = "10.0.0.2" + worker1.worker_ports = [18000] + worker1.id = "w1" + + scheduler = MagicMock() + scheduler.get_workers.return_value = [worker0] + + # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2 + cfg = GatewayControllerConfig( + tokenizer_path="mock-tokenizer", + backend="sglang:d1t8", + n_gpus_per_node=4, + scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), + openai=OpenAIProxyConfig(admin_api_key="test-key"), + ) + controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller._callback_host = "127.0.0.1" + controller._callback_port = 19000 + + with patch.object(controller, "_fork_on_guard") as mock_fork: + mock_fork.side_effect = [ + ("127.0.0.1", 18081), # router + ("127.0.0.1", 18082), # data proxy (only 1, on head) + ("127.0.0.1", 18080), # gateway + ] + + await controller._async_initialize( + server_args=None, + server_infos=[ + LocalInfServerInfo( + host="10.0.0.1", port=30000, process=MagicMock() + ), + ], + ) + + # With server_infos, total_workers = dp_size = 1 (not dp_size * nnodes_per_instance) + create_call = scheduler.create_workers.call_args + job = create_call.kwargs.get("job") or create_call.args[0] + assert job.replicas == 1 + + # 3 forks: router + data-proxy + gateway (all on head worker) + assert mock_fork.call_count == 3 + data_proxy_calls = [ + c for c in mock_fork.call_args_list if c.kwargs.get("role") == "data-proxy" + ] + assert len(data_proxy_calls) == 1 + + @pytest.mark.asyncio + async def test_async_initialize_multinode_fork_path(self): + """Exercise the full multi-node fork path (server_infos=None).""" + from areal.api.cli_args import SchedulingSpec + + worker0 = MagicMock() + worker0.ip = "10.0.0.1" + worker0.worker_ports = [18000] + worker0.id = "w0" + + worker1 = MagicMock() + worker1.ip = "10.0.0.2" + worker1.worker_ports = [18000] + worker1.id = "w1" + + scheduler = MagicMock() + scheduler.get_workers.return_value = [worker0, worker1] + + # tp=8, n_gpus_per_node=4 → nnodes_per_instance=2 + cfg = GatewayControllerConfig( + tokenizer_path="mock-tokenizer", + backend="sglang:d1t8", + n_gpus_per_node=4, + scheduling_spec=(SchedulingSpec(gpu=1, cpu=1, mem=1, cmd="mock"),), + openai=OpenAIProxyConfig(admin_api_key="test-key"), + ) + controller = GatewayInferenceController(config=cfg, scheduler=scheduler) + controller._callback_host = "127.0.0.1" + controller._callback_port = 19000 + + # Track requests.post calls to /alloc_ports and /fork + alloc_port_counter = 0 + fork_calls = [] + + def mock_requests_post(url, json=None, timeout=None): + nonlocal alloc_port_counter + resp = MagicMock() + resp.status_code = 200 + if "/alloc_ports" in url: + alloc_port_counter += 1 + resp.json.return_value = { + "status": "success", + "host": url.split("//")[1].split(":")[0], + "ports": [30000 + alloc_port_counter], + } + elif "/fork" in url: + fork_calls.append(json) + resp.json.return_value = {"status": "success"} + return resp + + with ( + patch("requests.post", side_effect=mock_requests_post) as mock_post, + patch.object(controller, "_fork_on_guard") as mock_fork, + patch.object(controller, "_wait_for_service"), + patch( + "areal.api.cli_args.pkg_version.is_version_greater_or_equal", + return_value=True, + ), + patch("areal.api.cli_args.is_version_less", return_value=False), + ): + mock_fork.side_effect = [ + ("10.0.0.1", 18081), # router + ("10.0.0.1", 18082), # data proxy + ("10.0.0.1", 18080), # gateway + ] + + await controller._async_initialize( + server_args=None, + server_infos=None, + ) + + # dp_size=1, nnodes_per_instance=2: total_workers = 2 + create_call = scheduler.create_workers.call_args + job = create_call.kwargs.get("job") or create_call.args[0] + assert job.replicas == 2 + + # requests.post calls: + # 1 rendezvous alloc (nnodes_per_instance > 1) + 2 node allocs + 2 forks = 5 + post_calls = mock_post.call_args_list + alloc_calls = [c for c in post_calls if "/alloc_ports" in str(c)] + fork_post_calls = [c for c in post_calls if "/fork" in str(c)] + assert len(alloc_calls) == 3 # 1 rendezvous + 2 per-node + assert len(fork_post_calls) == 2 # 1 per node in the group + + # Verify fork payloads have correct worker_index and role + assert fork_calls[0]["role"] == "inf-server" + assert fork_calls[0]["worker_index"] == 0 + assert fork_calls[1]["role"] == "inf-server" + assert fork_calls[1]["worker_index"] == 1 + + # Verify dist_init_addr propagated to fork commands + for fc in fork_calls: + cmd_str = " ".join(fc["raw_cmd"]) + assert "--dist-init-addr" in cmd_str or "--dist_init_addr" in cmd_str + + # Only 1 data proxy (dp_size=1, on head worker only) + data_proxy_calls = [ + c for c in mock_fork.call_args_list if c.kwargs.get("role") == "data-proxy" + ] + assert len(data_proxy_calls) == 1 diff --git a/tests/experimental/inference_service/test_sglang_multinode.py b/tests/experimental/inference_service/test_sglang_multinode.py new file mode 100644 index 0000000000..dfc2d2e3ce --- /dev/null +++ b/tests/experimental/inference_service/test_sglang_multinode.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for SGLang multi-node CLI generation.""" + +from __future__ import annotations + +from unittest.mock import patch + +from areal.api.cli_args import SGLangConfig + + +class TestSGLangMultiNode: + """Mirror of TestVLLMMultiNode for the SGLang backend.""" + + def _build_args(self, **kwargs): + """Helper that patches sglang version checks away.""" + defaults = dict( + sglang_config=SGLangConfig(model_path="test-model"), + tp_size=8, + base_gpu_id=0, + ) + defaults.update(kwargs) + with ( + patch( + "areal.api.cli_args.pkg_version.is_version_greater_or_equal", + return_value=True, + ), + patch("areal.api.cli_args.is_version_less", return_value=False), + ): + return SGLangConfig.build_args(**defaults) + + def _build_cmd(self, **kwargs): + """Helper that patches sglang version checks away.""" + defaults = dict( + sglang_config=SGLangConfig(model_path="test-model"), + tp_size=8, + base_gpu_id=0, + ) + defaults.update(kwargs) + with ( + patch( + "areal.api.cli_args.pkg_version.is_version_greater_or_equal", + return_value=True, + ), + patch("areal.api.cli_args.is_version_less", return_value=False), + ): + return SGLangConfig.build_cmd(**defaults) + + def test_build_args_single_node_defaults(self): + """Single-node (default) should have nnodes=1, node_rank=0.""" + args = self._build_args() + assert args["nnodes"] == 1 + assert args["node_rank"] == 0 + assert args.get("dist_init_addr") is None + + def test_build_args_multi_node_head(self): + """Head node (rank 0) with n_nodes > 1 should set nnodes and dist_init_addr.""" + args = self._build_args( + tp_size=16, + n_nodes=2, + node_rank=0, + dist_init_addr="10.0.0.1:29500", + ) + assert args["nnodes"] == 2 + assert args["node_rank"] == 0 + assert args["dist_init_addr"] == "10.0.0.1:29500" + + def test_build_args_multi_node_worker(self): + """Worker node (rank > 0) should set nnodes and node_rank.""" + args = self._build_args( + tp_size=16, + n_nodes=2, + node_rank=1, + dist_init_addr="10.0.0.1:29500", + ) + assert args["nnodes"] == 2 + assert args["node_rank"] == 1 + assert args["dist_init_addr"] == "10.0.0.1:29500" + + def test_build_args_multi_node_no_dist_init_addr(self): + """Multi-node without dist_init_addr should have dist_init_addr=None.""" + args = self._build_args( + tp_size=16, + n_nodes=2, + node_rank=0, + ) + assert args["nnodes"] == 2 + assert args["node_rank"] == 0 + assert args.get("dist_init_addr") is None + + def test_build_cmd_multi_node_produces_flags(self): + """build_cmd with multi-node should produce CLI flags for nnodes and node-rank.""" + cmd = self._build_cmd( + tp_size=16, + n_nodes=2, + node_rank=1, + dist_init_addr="10.0.0.1:29500", + ) + cmd_str = " ".join(cmd) + assert "--nnodes" in cmd_str + assert "--node-rank" in cmd_str + assert "--dist-init-addr" in cmd_str diff --git a/tests/experimental/inference_service/test_vllm_multinode.py b/tests/experimental/inference_service/test_vllm_multinode.py new file mode 100644 index 0000000000..969a390247 --- /dev/null +++ b/tests/experimental/inference_service/test_vllm_multinode.py @@ -0,0 +1,84 @@ +"""Tests for vLLM multi-node CLI generation.""" + +from __future__ import annotations + +from areal.api.cli_args import vLLMConfig + + +class TestVLLMMultiNode: + def test_build_args_single_node_no_extra_flags(self): + """Single-node (default) should not add nnodes/node_rank/headless.""" + cfg = vLLMConfig(model="test-model") + args = vLLMConfig.build_args(cfg, tp_size=8, pp_size=1) + assert "nnodes" not in args + assert "node_rank" not in args + assert "headless" not in args + assert "master_addr" not in args + assert "master_port" not in args + + def test_build_args_multi_node_head(self): + """Head node (rank 0) with n_nodes > 1 should add nnodes/node_rank but NOT headless.""" + cfg = vLLMConfig(model="test-model") + args = vLLMConfig.build_args( + cfg, + tp_size=16, + pp_size=1, + n_nodes=2, + node_rank=0, + dist_init_addr="10.0.0.1:29500", + ) + assert args["nnodes"] == 2 + assert args["node_rank"] == 0 + assert "headless" not in args + assert args["master_addr"] == "10.0.0.1" + assert args["master_port"] == "29500" + + def test_build_args_multi_node_worker(self): + """Worker node (rank > 0) should add headless=True.""" + cfg = vLLMConfig(model="test-model") + args = vLLMConfig.build_args( + cfg, + tp_size=16, + pp_size=1, + n_nodes=2, + node_rank=1, + dist_init_addr="10.0.0.1:29500", + ) + assert args["nnodes"] == 2 + assert args["node_rank"] == 1 + assert args["headless"] is True + assert args["master_addr"] == "10.0.0.1" + assert args["master_port"] == "29500" + + def test_build_args_multi_node_no_dist_init_addr(self): + """Multi-node without dist_init_addr should not add master_addr/master_port.""" + cfg = vLLMConfig(model="test-model") + args = vLLMConfig.build_args( + cfg, + tp_size=16, + pp_size=1, + n_nodes=2, + node_rank=0, + ) + assert args["nnodes"] == 2 + assert args["node_rank"] == 0 + assert "master_addr" not in args + assert "master_port" not in args + + def test_build_cmd_multi_node_produces_flags(self): + """build_cmd with multi-node should produce CLI flags for nnodes and node-rank.""" + cfg = vLLMConfig(model="test-model") + cmd = vLLMConfig.build_cmd( + cfg, + tp_size=16, + pp_size=1, + n_nodes=2, + node_rank=1, + dist_init_addr="10.0.0.1:29500", + ) + cmd_str = " ".join(cmd) + assert "--nnodes" in cmd_str + assert "--node-rank" in cmd_str + assert "--headless" in cmd_str + assert "--master-addr" in cmd_str + assert "--master-port" in cmd_str From f7e690a45c7c03b8ee609cc5a05e7116e59bf513 Mon Sep 17 00:00:00 2001 From: KennyMcCormick Date: Wed, 15 Apr 2026 14:07:58 +0800 Subject: [PATCH 2/2] feat(service): add agent service Controller, Guard, and Claude Agent SDK example (#1177) Add AgentServiceController and Guard for production-style orchestration, replace tau2/PydanticAI demo with Claude Agent SDK integration. Key changes: - Add controller/ with scheduler-based Guard creation (mirrors GatewayInferenceController) - Add guard/ module (pass-through to areal.infra.rpc.guard) - Add config dataclasses with __post_init__ validation - Replace Tau2Agent with ClaudeAgent (session-persistent ClaudeSDKClient) - Session lifecycle: close_session, Worker endpoint, DataProxy propagation - Initialize rollback on failure, register-before-commit in scale_up - Unregister with retry in scale_down, skip pair on failure - Lock-protected _pairs, ThreadPoolExecutor health monitor - Timing-safe WebSocket auth via hmac.compare_digest BREAKING CHANGE: areal.experimental.agent_service.__init__.py no longer re-exports symbols. Import from submodules directly. --- areal/experimental/agent_service/README.md | 34 +- areal/experimental/agent_service/__init__.py | 79 +-- areal/experimental/agent_service/auth.py | 14 +- .../agent_service/controller/__init__.py | 11 + .../agent_service/controller/config.py | 63 ++ .../agent_service/controller/controller.py | 577 ++++++++++++++++++ .../agent_service/data_proxy/__main__.py | 32 +- .../agent_service/data_proxy/app.py | 33 +- .../agent_service/data_proxy/config.py | 15 + .../agent_service/gateway/__main__.py | 28 +- .../experimental/agent_service/gateway/app.py | 15 +- .../agent_service/gateway/bridge.py | 12 +- .../agent_service/gateway/config.py | 18 + .../agent_service/guard/__init__.py | 14 + .../agent_service/guard/__main__.py | 30 + areal/experimental/agent_service/guard/app.py | 30 + .../agent_service/router/__main__.py | 22 +- .../experimental/agent_service/router/app.py | 7 +- .../agent_service/router/client.py | 8 +- .../agent_service/router/config.py | 17 + .../agent_service/worker/__main__.py | 52 +- .../experimental/agent_service/worker/app.py | 13 + .../agent_service/worker/config.py | 13 + areal/infra/utils/proc.py | 18 +- examples/agent_service/README.md | 209 +++---- examples/agent_service/agent.py | 282 ++++----- examples/agent_service/config.yaml | 25 - examples/agent_service/run_agent_service.py | 131 ++++ examples/agent_service/run_demo.py | 298 --------- .../agent_service/test_agent_router.py | 8 +- .../agent_service/test_controller.py | 340 +++++++++++ .../experimental/agent_service/test_guard.py | 73 +++ .../agent_service/test_integration.py | 19 +- 33 files changed, 1721 insertions(+), 819 deletions(-) create mode 100644 areal/experimental/agent_service/controller/__init__.py create mode 100644 areal/experimental/agent_service/controller/config.py create mode 100644 areal/experimental/agent_service/controller/controller.py create mode 100644 areal/experimental/agent_service/data_proxy/config.py create mode 100644 areal/experimental/agent_service/gateway/config.py create mode 100644 areal/experimental/agent_service/guard/__init__.py create mode 100644 areal/experimental/agent_service/guard/__main__.py create mode 100644 areal/experimental/agent_service/guard/app.py create mode 100644 areal/experimental/agent_service/router/config.py create mode 100644 areal/experimental/agent_service/worker/config.py delete mode 100644 examples/agent_service/config.yaml create mode 100644 examples/agent_service/run_agent_service.py delete mode 100644 examples/agent_service/run_demo.py create mode 100644 tests/experimental/agent_service/test_controller.py create mode 100644 tests/experimental/agent_service/test_guard.py diff --git a/areal/experimental/agent_service/README.md b/areal/experimental/agent_service/README.md index 46a8ae876f..f3dc0f839f 100644 --- a/areal/experimental/agent_service/README.md +++ b/areal/experimental/agent_service/README.md @@ -153,29 +153,45 @@ Turn 2: ``` areal/experimental/agent_service/ -├── __init__.py # Public exports +├── __init__.py # Public exports (AgentRequest, AgentResponse, etc.) ├── README.md # This document +├── auth.py # Admin key auth helpers (hmac-safe comparison) ├── protocol.py # Gateway protocol frame types ├── types.py # AgentRequest, AgentResponse, EventEmitter, AgentRunnable +├── controller/ +│ ├── __init__.py # AgentServiceController, AgentServiceControllerConfig +│ ├── config.py # AgentServiceControllerConfig dataclass +│ └── controller.py # AgentServiceController orchestrator +├── guard/ +│ ├── __init__.py # Module docstring +│ ├── __main__.py # python -m areal.experimental.agent_service.guard +│ └── app.py # Guard Flask app (pass-through to areal.infra.rpc.guard) ├── gateway/ +│ ├── __init__.py # Public exports │ ├── __main__.py # python -m areal.experimental.agent_service.gateway │ ├── app.py # create_gateway_app() -│ └── bridge.py # OpenResponsesBridge, mount_bridge() +│ ├── bridge.py # OpenResponsesBridge, mount_bridge() +│ └── config.py # GatewayConfig dataclass ├── router/ +│ ├── __init__.py # Public exports │ ├── __main__.py # python -m areal.experimental.agent_service.router │ ├── app.py # create_router_app() -│ └── client.py # RouterClient +│ ├── client.py # RouterClient +│ └── config.py # RouterConfig dataclass ├── data_proxy/ +│ ├── __init__.py # Public exports │ ├── __main__.py # python -m areal.experimental.agent_service.data_proxy │ ├── app.py # create_data_proxy_app() -│ └── client.py # DataProxyClient +│ ├── client.py # DataProxyClient +│ └── config.py # DataProxyConfig dataclass └── worker/ + ├── __init__.py # Public exports ├── __main__.py # python -m areal.experimental.agent_service.worker - └── app.py # create_worker_app() + ├── app.py # create_worker_app() + └── config.py # WorkerConfig dataclass examples/agent_service/ -├── agent.py # Tau2Agent (PydanticAI) -├── config.yaml # Demo configuration -├── run_demo.py # One-click demo -└── README.md # Example documentation +├── agent.py # ClaudeAgent (Claude Agent SDK) +├── run_agent_service.py # Controller-based launcher + interactive demo +└── README.md # Example documentation ``` diff --git a/areal/experimental/agent_service/__init__.py b/areal/experimental/agent_service/__init__.py index 28059732d0..3858d5c133 100644 --- a/areal/experimental/agent_service/__init__.py +++ b/areal/experimental/agent_service/__init__.py @@ -5,83 +5,22 @@ Exposes complete agent sessions (autonomous multi-step reasoning, tool use, memory) via independent HTTP microservices: Gateway, Router, DataProxy, and Worker. -""" - -from __future__ import annotations -import importlib -from typing import TYPE_CHECKING +Submodules +---------- +- ``controller`` — :class:`AgentServiceController` orchestrator +- ``gateway`` — public HTTP/WebSocket entry point +- ``router`` — session-affine routing +- ``data_proxy`` — stateful session proxy +- ``worker`` — stateless agent execution +- ``protocol`` — WebSocket frame types and helpers +""" -from .protocol import ( - EventFrame, - Frame, - FrameType, - QueueMode, - RequestFrame, - RequestMethod, - ResponseFrame, - RunStatus, - generate_run_id, - make_complete_response, - make_delta_event, - make_failed_response, - make_tool_call_event, - parse_frame, - serialize_frame, -) from .types import AgentRequest, AgentResponse, AgentRunnable, EventEmitter -if TYPE_CHECKING: - from .data_proxy import DataProxyClient, create_data_proxy_app - from .gateway import OpenResponsesBridge, create_gateway_app, mount_bridge - from .router import RouterClient, create_router_app - from .worker import create_worker_app - -_LAZY_IMPORTS: dict[str, str] = { - "DataProxyClient": ".data_proxy", - "OpenResponsesBridge": ".gateway", - "RouterClient": ".router", - "create_data_proxy_app": ".data_proxy", - "create_gateway_app": ".gateway", - "create_router_app": ".router", - "create_worker_app": ".worker", - "mount_bridge": ".gateway", -} - - -def __getattr__(name: str): - if name in _LAZY_IMPORTS: - module = importlib.import_module(_LAZY_IMPORTS[name], __package__) - return getattr(module, name) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - __all__ = [ "AgentRequest", "AgentResponse", "AgentRunnable", - "DataProxyClient", "EventEmitter", - "EventFrame", - "Frame", - "FrameType", - "OpenResponsesBridge", - "QueueMode", - "RequestFrame", - "RequestMethod", - "ResponseFrame", - "RouterClient", - "RunStatus", - "create_data_proxy_app", - "create_gateway_app", - "create_router_app", - "create_worker_app", - "generate_run_id", - "make_complete_response", - "make_delta_event", - "make_failed_response", - "make_tool_call_event", - "mount_bridge", - "parse_frame", - "serialize_frame", ] diff --git a/areal/experimental/agent_service/auth.py b/areal/experimental/agent_service/auth.py index 0f01da4bd2..b3893f5bf2 100644 --- a/areal/experimental/agent_service/auth.py +++ b/areal/experimental/agent_service/auth.py @@ -4,9 +4,11 @@ from __future__ import annotations +import hmac + from fastapi import Header, HTTPException -DEFAULT_ADMIN_KEY = "areal-agent-admin" +DEFAULT_ADMIN_API_KEY = "areal-agent-admin" async def verify_admin_key( @@ -15,16 +17,16 @@ async def verify_admin_key( expected_key: str, ) -> None: expected = f"Bearer {expected_key}" - if authorization != expected: + if not hmac.compare_digest(authorization, expected): raise HTTPException(status_code=401, detail="Invalid admin key") -def make_admin_dependency(admin_key: str): +def make_admin_dependency(admin_api_key: str): async def _dep(authorization: str = Header(alias="Authorization")) -> None: - await verify_admin_key(authorization, expected_key=admin_key) + await verify_admin_key(authorization, expected_key=admin_api_key) return _dep -def admin_headers(admin_key: str) -> dict[str, str]: - return {"Authorization": f"Bearer {admin_key}"} +def admin_headers(admin_api_key: str) -> dict[str, str]: + return {"Authorization": f"Bearer {admin_api_key}"} diff --git a/areal/experimental/agent_service/controller/__init__.py b/areal/experimental/agent_service/controller/__init__.py new file mode 100644 index 0000000000..3150205885 --- /dev/null +++ b/areal/experimental/agent_service/controller/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Agent Service Controller — orchestrator for agent micro-services.""" + +from .config import AgentServiceControllerConfig +from .controller import AgentServiceController + +__all__ = [ + "AgentServiceController", + "AgentServiceControllerConfig", +] diff --git a/areal/experimental/agent_service/controller/config.py b/areal/experimental/agent_service/controller/config.py new file mode 100644 index 0000000000..c316d58227 --- /dev/null +++ b/areal/experimental/agent_service/controller/config.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Configuration for the AgentServiceController.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from ..auth import DEFAULT_ADMIN_API_KEY + + +@dataclass +class AgentServiceControllerConfig: + """Unified configuration for AgentServiceController. + + Consolidates settings for the guard, router, gateway, worker, and + data proxy micro-services launched by the controller. + """ + + # -- Agent class ------------------------------------------------------- + agent_cls_path: str = "" + """Fully-qualified import path for the ``AgentRunnable`` implementation + (e.g. ``examples.agent_service.agent.Tau2Agent``).""" + + # -- Authentication ---------------------------------------------------- + admin_api_key: str = DEFAULT_ADMIN_API_KEY + """Shared admin API key for inter-service Bearer auth.""" + + # -- Scaling ----------------------------------------------------------- + num_pairs: int = 1 + """Number of Worker+DataProxy pairs to launch on initialize.""" + + # -- Timeouts ---------------------------------------------------------- + setup_timeout: float = 120.0 + """Timeout (seconds) waiting for each service to become healthy.""" + + health_poll_interval: float = 5.0 + """Seconds between health polls for crash detection (0 = disabled).""" + + drain_timeout: float = 30.0 + """Seconds to wait for active sessions to drain before force-killing a pair.""" + + # -- Log level --------------------------------------------------------- + log_level: str = "info" + """Log level for spawned micro-services.""" + + # -- Environment ------------------------------------------------------- + env: dict[str, str] = field(default_factory=dict) + """Extra environment variables to pass to all forked child processes.""" + + def __post_init__(self) -> None: + if not self.agent_cls_path: + raise ValueError("agent_cls_path must be a non-empty import path") + if self.num_pairs < 0: + raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}") + if self.setup_timeout <= 0: + raise ValueError( + f"setup_timeout must be positive, got {self.setup_timeout}" + ) + if self.drain_timeout < 0: + raise ValueError( + f"drain_timeout must be non-negative, got {self.drain_timeout}" + ) diff --git a/areal/experimental/agent_service/controller/controller.py b/areal/experimental/agent_service/controller/controller.py new file mode 100644 index 0000000000..21b12851bb --- /dev/null +++ b/areal/experimental/agent_service/controller/controller.py @@ -0,0 +1,577 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""AgentServiceController — orchestrates agent service micro-services via Guards. + +Mirrors the architecture of +:class:`~areal.experimental.inference_service.controller.controller.GatewayInferenceController`: +Guard workers are created via the Scheduler, then the controller forks +Router, Worker+DataProxy pairs, and Gateway onto them via HTTP API. + +Lifecycle:: + + from areal.infra.scheduler.local import LocalScheduler + + scheduler = LocalScheduler(...) + controller = AgentServiceController(config, scheduler) + controller.initialize() + # ... run traffic ... + controller.scale_up(2) # add 2 Worker+DataProxy pairs + controller.scale_down(1) # drain + remove 1 pair + controller.destroy() +""" + +from __future__ import annotations + +import sys +import threading +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import requests + +from areal.experimental.agent_service.controller.config import ( + AgentServiceControllerConfig, +) +from areal.utils import logging +from areal.utils.network import format_hostport + +if TYPE_CHECKING: + from areal.api.scheduler_api import Scheduler, Worker + +logger = logging.getLogger("AgentServiceController") + +_GUARD_ROLE = "agent-guard" +_UNREGISTER_RETRIES = 3 +_HEALTH_CHECK_WORKERS = 4 + + +@dataclass +class _WorkerPair: + pair_index: int + guard_addr: str + worker_host: str + worker_port: int + proxy_host: str + proxy_port: int + proxy_addr: str + worker_addr: str + + +class AgentServiceController: + """Orchestrator for the Agent Service micro-service stack. + + Parameters + ---------- + config: + Controller configuration. + scheduler: + Scheduler instance used to create and manage Guard workers. + """ + + def __init__( + self, + config: AgentServiceControllerConfig, + scheduler: Scheduler, + ) -> None: + self.config = config + self.scheduler = scheduler + + self._guard_addrs: list[str] = [] + self._workers: list[Worker] = [] + self._service_roles: list[str] = [] + + self._router_addr: str = "" + self._gateway_addr: str = "" + + self._pairs: dict[int, _WorkerPair] = {} + self._pairs_lock = threading.Lock() + self._next_pair_index: int = 0 + + self._forked_services: list[tuple[str, str, int]] = [] + + self._health_stop = threading.Event() + self._health_thread: threading.Thread | None = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def initialize(self) -> None: + """Launch the full micro-service stack. + + Order: Guards (via scheduler) → Router → Worker+DataProxy pairs → + register → Gateway → health monitor. + On failure, already-forked services are cleaned up via destroy(). + """ + try: + self._do_initialize() + except Exception: + logger.error("initialize() failed, rolling back...") + self.destroy() + raise + + def _do_initialize(self) -> None: + from areal.api.cli_args import SchedulingSpec, SchedulingStrategy + from areal.api.scheduler_api import Job + + cfg = self.config + + # Step 1: Create Guard workers via scheduler + guard_spec = SchedulingSpec( + gpu=0, cmd=f"{sys.executable} -m areal.experimental.agent_service.guard" + ) + num_guards = max(cfg.num_pairs, 1) + guard_job = Job( + role=_GUARD_ROLE, + replicas=num_guards, + tasks=[guard_spec for _ in range(num_guards)], + scheduling_strategy=SchedulingStrategy(), + ) + self.scheduler.create_workers(job=guard_job) + self._service_roles.append(_GUARD_ROLE) + + self._workers = self.scheduler.get_workers(role=_GUARD_ROLE) + self._guard_addrs = [ + f"http://{format_hostport(w.ip, int(w.worker_ports[0]))}" + for w in self._workers + ] + logger.info("Guards ready: %s", self._guard_addrs) + + # Step 2: Fork Router on guard[0] + guard_0 = self._guard_addrs[0] + router_cmd = [ + sys.executable, + "-m", + "areal.experimental.agent_service.router", + "--admin-api-key", + cfg.admin_api_key, + ] + router_host, router_port = self._fork_on_guard( + guard_addr=guard_0, + role="agent-router", + worker_index=0, + raw_cmd=router_cmd, + ) + self._router_addr = f"http://{format_hostport(router_host, router_port)}" + logger.info("Router: %s", self._router_addr) + + # Step 3: Fork Worker+DataProxy pairs + self.scale_up(cfg.num_pairs) + + # Step 4: Fork Gateway on guard[0] + gw_cmd = [ + sys.executable, + "-m", + "areal.experimental.agent_service.gateway", + "--router-addr", + self._router_addr, + "--admin-api-key", + cfg.admin_api_key, + ] + gw_host, gw_port = self._fork_on_guard( + guard_addr=guard_0, + role="agent-gateway", + worker_index=0, + raw_cmd=gw_cmd, + ) + self._gateway_addr = f"http://{format_hostport(gw_host, gw_port)}" + logger.info("Gateway: %s", self._gateway_addr) + + # Step 5: Start health monitor + if cfg.health_poll_interval > 0: + self._health_stop.clear() + self._health_thread = threading.Thread( + target=self._health_monitor_loop, daemon=True + ) + self._health_thread.start() + + def destroy(self) -> None: + """Tear down all services in reverse order.""" + self._stop_health_monitor() + + for guard_addr, role, worker_index in reversed(self._forked_services): + try: + self._kill_forked_service(guard_addr, role, worker_index) + except requests.RequestException: + logger.error( + "Error killing forked service %s/%d: %s", + role, + worker_index, + traceback.format_exc(), + ) + self._forked_services.clear() + + for role in reversed(self._service_roles): + try: + self.scheduler.delete_workers(role=role) + logger.info("Workers deleted for role: %s", role) + except Exception: + logger.error( + "Error deleting workers for role %s: %s", + role, + traceback.format_exc(), + ) + self._service_roles.clear() + self._workers.clear() + self._guard_addrs.clear() + with self._pairs_lock: + self._pairs.clear() + self._router_addr = "" + self._gateway_addr = "" + + def scale_up(self, count: int) -> list[int]: + """Add *count* Worker+DataProxy pairs. + + Pairs are distributed across guards round-robin. + Returns the pair indices that were created. + """ + cfg = self.config + created: list[int] = [] + + for _ in range(count): + pair_index = self._next_pair_index + self._next_pair_index += 1 + + guard_addr = self._guard_addrs[pair_index % len(self._guard_addrs)] + + worker_cmd = [ + sys.executable, + "-m", + "areal.experimental.agent_service.worker", + "--agent", + cfg.agent_cls_path, + "--log-level", + cfg.log_level, + ] + worker_host, worker_port = self._fork_on_guard( + guard_addr=guard_addr, + role=f"agent-worker-{pair_index}", + worker_index=pair_index, + raw_cmd=worker_cmd, + ) + worker_addr = f"http://{format_hostport(worker_host, worker_port)}" + + proxy_cmd = [ + sys.executable, + "-m", + "areal.experimental.agent_service.data_proxy", + "--worker-addr", + worker_addr, + ] + proxy_host, proxy_port = self._fork_on_guard( + guard_addr=guard_addr, + role=f"agent-proxy-{pair_index}", + worker_index=pair_index, + raw_cmd=proxy_cmd, + ) + proxy_addr = f"http://{format_hostport(proxy_host, proxy_port)}" + + pair = _WorkerPair( + pair_index=pair_index, + guard_addr=guard_addr, + worker_host=worker_host, + worker_port=worker_port, + proxy_host=proxy_host, + proxy_port=proxy_port, + proxy_addr=proxy_addr, + worker_addr=worker_addr, + ) + + try: + self._register_proxy(proxy_addr) + except Exception: + logger.error( + "Failed to register pair %d, cleaning up forked processes", + pair_index, + ) + self._cleanup_pair_forks(pair_index, guard_addr) + raise + + with self._pairs_lock: + self._pairs[pair_index] = pair + created.append(pair_index) + + logger.info( + "Pair %d: worker=%s proxy=%s", pair_index, worker_addr, proxy_addr + ) + + return created + + def scale_down(self, count: int) -> list[int]: + """Remove *count* pairs (LIFO order). + + For each pair: unregister from Router (with retry) → drain active + sessions → kill DataProxy → kill Worker. + Returns the pair indices that were removed. + """ + removed: list[int] = [] + + with self._pairs_lock: + indices = sorted(self._pairs.keys(), reverse=True)[:count] + + for pair_index in indices: + with self._pairs_lock: + pair = self._pairs.get(pair_index) + if pair is None: + continue + + try: + self._unregister_proxy(pair.proxy_addr) + except requests.RequestException: + logger.error( + "Unregister failed for pair %d after retries, skipping", + pair_index, + ) + continue + + self._drain_proxy(pair.proxy_addr) + + with self._pairs_lock: + self._pairs.pop(pair_index, None) + + proxy_key = (pair.guard_addr, f"agent-proxy-{pair_index}", pair_index) + worker_key = (pair.guard_addr, f"agent-worker-{pair_index}", pair_index) + + for guard_addr, role, wi in [proxy_key, worker_key]: + try: + self._kill_forked_service(guard_addr, role, wi) + entry = (guard_addr, role, wi) + if entry in self._forked_services: + self._forked_services.remove(entry) + except requests.RequestException: + logger.warning( + "Failed to kill %s/%d: %s", + role, + wi, + traceback.format_exc(), + ) + + removed.append(pair_index) + logger.info("Removed pair %d", pair_index) + + return removed + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def router_addr(self) -> str: + return self._router_addr + + @property + def gateway_addr(self) -> str: + return self._gateway_addr + + @property + def pairs(self) -> dict[int, _WorkerPair]: + with self._pairs_lock: + return dict(self._pairs) + + # ------------------------------------------------------------------ + # Guard interaction helpers + # ------------------------------------------------------------------ + + def _fork_on_guard( + self, + guard_addr: str, + role: str, + worker_index: int, + raw_cmd: list[str], + health_path: str = "/health", + env: dict[str, str] | None = None, + ) -> tuple[str, int]: + resp = requests.post( + f"{guard_addr}/alloc_ports", + json={"count": 1}, + timeout=30, + ) + resp.raise_for_status() + port_data = resp.json() + host = port_data["host"] + port = port_data["ports"][0] + + cmd = list(raw_cmd) + ["--host", host, "--port", str(port)] + + merged_env = {**self.config.env, **(env or {})} + + fork_payload: dict[str, Any] = { + "role": role, + "worker_index": worker_index, + "raw_cmd": cmd, + } + if merged_env: + fork_payload["env"] = merged_env + + resp = requests.post( + f"{guard_addr}/fork", + json=fork_payload, + timeout=30, + ) + resp.raise_for_status() + + self._forked_services.append((guard_addr, role, worker_index)) + + addr = f"http://{format_hostport(host, port)}" + self._wait_for_service(f"{addr}{health_path}", role) + + return host, port + + def _cleanup_pair_forks(self, pair_index: int, guard_addr: str) -> None: + for role_prefix in ("agent-proxy-", "agent-worker-"): + role = f"{role_prefix}{pair_index}" + entry = (guard_addr, role, pair_index) + if entry in self._forked_services: + try: + self._kill_forked_service(guard_addr, role, pair_index) + except requests.RequestException: + pass + self._forked_services.remove(entry) + + def _kill_forked_service( + self, guard_addr: str, role: str, worker_index: int + ) -> None: + try: + resp = requests.post( + f"{guard_addr}/kill_forked_worker", + json={"role": role, "worker_index": worker_index}, + timeout=10, + ) + if resp.status_code == 200: + logger.info("Killed forked service %s/%d", role, worker_index) + else: + logger.warning( + "Failed to kill forked service %s/%d: %s", + role, + worker_index, + resp.text, + ) + except requests.RequestException as exc: + logger.error( + "Error killing forked service %s/%d: %s", role, worker_index, exc + ) + + def _wait_for_service( + self, url: str, name: str, timeout: float | None = None + ) -> None: + timeout = timeout or self.config.setup_timeout + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = requests.get(url, timeout=2) + if resp.status_code == 200: + logger.info("%s healthy at %s", name, url) + return + except requests.RequestException: + pass + time.sleep(0.5) + raise TimeoutError(f"{name} did not become healthy at {url} within {timeout}s") + + def _register_proxy(self, proxy_addr: str) -> None: + """Raises on failure so that ``scale_up`` callers know the pair is + non-functional. + """ + if not self._router_addr: + return + resp = requests.post( + f"{self._router_addr}/register", + json={"addr": proxy_addr}, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, + timeout=10, + ) + resp.raise_for_status() + logger.info("Registered proxy %s with Router", proxy_addr) + + def _drain_proxy(self, proxy_addr: str) -> None: + timeout = self.config.drain_timeout + if timeout <= 0: + return + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = requests.get(f"{proxy_addr}/health", timeout=2) + if resp.status_code == 200: + active = resp.json().get("active_sessions", 0) + if active == 0: + logger.info("Proxy %s drained", proxy_addr) + return + logger.debug( + "Proxy %s draining: %d active sessions", proxy_addr, active + ) + except requests.RequestException: + break + time.sleep(1.0) + logger.warning( + "Proxy %s drain timed out after %.0fs, force-killing", proxy_addr, timeout + ) + + def _check_pair_health(self, pair_index: int, proxy_addr: str) -> None: + try: + resp = requests.get(f"{proxy_addr}/health", timeout=2) + if resp.status_code != 200: + logger.warning( + "Pair %d proxy %s returned %d", + pair_index, + proxy_addr, + resp.status_code, + ) + except requests.RequestException: + logger.warning("Pair %d proxy %s unreachable", pair_index, proxy_addr) + + def _health_monitor_loop(self) -> None: + interval = self.config.health_poll_interval + while not self._health_stop.wait(timeout=interval): + with self._pairs_lock: + snapshot = list(self._pairs.items()) + if not snapshot: + continue + with ThreadPoolExecutor( + max_workers=min(_HEALTH_CHECK_WORKERS, len(snapshot)) + ) as pool: + futures = { + pool.submit(self._check_pair_health, idx, pair.proxy_addr): idx + for idx, pair in snapshot + } + for future in as_completed(futures, timeout=10): + try: + future.result() + except Exception: + pass + + def _stop_health_monitor(self) -> None: + self._health_stop.set() + if self._health_thread is not None: + self._health_thread.join(timeout=5) + self._health_thread = None + + def _unregister_proxy(self, proxy_addr: str) -> None: + """Unregister with retry. Raises after all retries exhausted.""" + if not self._router_addr: + return + last_exc: Exception | None = None + for attempt in range(_UNREGISTER_RETRIES): + try: + resp = requests.post( + f"{self._router_addr}/unregister", + json={"addr": proxy_addr}, + headers={"Authorization": f"Bearer {self.config.admin_api_key}"}, + timeout=5, + ) + resp.raise_for_status() + logger.info("Unregistered proxy %s", proxy_addr) + return + except requests.RequestException as exc: + last_exc = exc + logger.warning( + "Unregister proxy %s attempt %d/%d failed: %s", + proxy_addr, + attempt + 1, + _UNREGISTER_RETRIES, + exc, + ) + if attempt < _UNREGISTER_RETRIES - 1: + time.sleep(1.0) + raise last_exc # type: ignore[misc] diff --git a/areal/experimental/agent_service/data_proxy/__main__.py b/areal/experimental/agent_service/data_proxy/__main__.py index bda8e6164e..c856bac91b 100644 --- a/areal/experimental/agent_service/data_proxy/__main__.py +++ b/areal/experimental/agent_service/data_proxy/__main__.py @@ -2,21 +2,41 @@ """``python -m areal.experimental.agent_service.data_proxy``""" -from .app import create_data_proxy_app +import argparse -if __name__ == "__main__": - import argparse +import uvicorn + +from .app import create_data_proxy_app +from .config import DataProxyConfig - import uvicorn +def main() -> None: parser = argparse.ArgumentParser(description="Agent DataProxy") parser.add_argument("--worker-addr", required=True, help="Worker HTTP address") parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", type=int, default=9100) + parser.add_argument("--request-timeout", type=float, default=600.0) + parser.add_argument("--session-timeout", type=int, default=3600) + parser.add_argument( + "--log-level", choices=["debug", "info", "warning", "error"], default="info" + ) args = parser.parse_args() - uvicorn.run( - create_data_proxy_app(worker_addr=args.worker_addr), + config = DataProxyConfig( host=args.host, port=args.port, + worker_addr=args.worker_addr, + request_timeout=args.request_timeout, + session_timeout=args.session_timeout, + log_level=args.log_level, ) + uvicorn.run( + create_data_proxy_app(config), + host=config.host, + port=config.port, + log_level=config.log_level, + ) + + +if __name__ == "__main__": + main() diff --git a/areal/experimental/agent_service/data_proxy/app.py b/areal/experimental/agent_service/data_proxy/app.py index 0e05c4b392..fbf51646ba 100644 --- a/areal/experimental/agent_service/data_proxy/app.py +++ b/areal/experimental/agent_service/data_proxy/app.py @@ -14,6 +14,8 @@ from areal.utils import logging +from .config import DataProxyConfig + logger = logging.getLogger("AgentDataProxy") @@ -24,23 +26,31 @@ class _SessionData: last_active: float = field(default_factory=time.monotonic) -def create_data_proxy_app( - worker_addr: str, - session_timeout: int = 3600, -) -> FastAPI: +def create_data_proxy_app(config: DataProxyConfig) -> FastAPI: app = FastAPI(title="AReaL Data Proxy") sessions: dict[str, _SessionData] = {} - http_client = httpx.AsyncClient(timeout=600.0) + http_client = httpx.AsyncClient(timeout=config.request_timeout) + + async def _close_worker_session(session_key: str) -> None: + try: + await http_client.post( + f"{config.worker_addr}/session/{session_key}/close", timeout=5 + ) + except Exception: + logger.debug("Failed to close worker session %s", session_key) async def _reap_idle_sessions() -> None: while True: await asyncio.sleep(60) now = time.monotonic() stale = [ - k for k, s in sessions.items() if now - s.last_active > session_timeout + k + for k, s in sessions.items() + if now - s.last_active > config.session_timeout ] for k in stale: del sessions[k] + await _close_worker_session(k) if stale: logger.info("Reaped %d idle sessions", len(stale)) @@ -57,17 +67,11 @@ async def health(): return { "status": "ok", "active_sessions": len(sessions), - "worker_addr": worker_addr, + "worker_addr": config.worker_addr, } @app.post("/session/{session_key}/turn") async def turn(session_key: str, body: dict[str, Any]): - """Process one turn. session_key must be unique per agent session. - - When used with the rollout service, uniqueness is ensured by - ``/rl/start_session``. When used standalone, callers must - generate unique keys (e.g. ``f"{model}:{user_id}"``). - """ session = sessions.get(session_key) if session is None: session = _SessionData() @@ -87,7 +91,7 @@ async def turn(session_key: str, body: dict[str, Any]): "metadata": metadata, } - resp = await http_client.post(f"{worker_addr}/run", json=worker_request) + resp = await http_client.post(f"{config.worker_addr}/run", json=worker_request) resp.raise_for_status() result = resp.json() @@ -138,6 +142,7 @@ async def turn(session_key: str, body: dict[str, Any]): @app.post("/session/{session_key}/close") async def close_session(session_key: str): sessions.pop(session_key, None) + await _close_worker_session(session_key) return {"status": "ok"} @app.get("/session/{session_key}/history") diff --git a/areal/experimental/agent_service/data_proxy/config.py b/areal/experimental/agent_service/data_proxy/config.py new file mode 100644 index 0000000000..45e5a994ee --- /dev/null +++ b/areal/experimental/agent_service/data_proxy/config.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class DataProxyConfig: + host: str = "0.0.0.0" + port: int = 9100 + worker_addr: str = "http://localhost:9000" + request_timeout: float = 600.0 + session_timeout: int = 3600 + log_level: str = "info" diff --git a/areal/experimental/agent_service/gateway/__main__.py b/areal/experimental/agent_service/gateway/__main__.py index a8923d42ef..bc02f532f2 100644 --- a/areal/experimental/agent_service/gateway/__main__.py +++ b/areal/experimental/agent_service/gateway/__main__.py @@ -6,8 +6,10 @@ import uvicorn +from ..auth import DEFAULT_ADMIN_API_KEY from .app import create_gateway_app from .bridge import OpenResponsesBridge, mount_bridge +from .config import GatewayConfig def main() -> None: @@ -15,16 +17,32 @@ def main() -> None: parser.add_argument("--router-addr", required=True, help="Router HTTP address") parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", type=int, default=8080) - parser.add_argument("--admin-key", default="areal-agent-admin") + parser.add_argument("--admin-api-key", default=DEFAULT_ADMIN_API_KEY) + parser.add_argument("--router-timeout", type=float, default=2.0) + parser.add_argument("--forward-timeout", type=float, default=120.0) + parser.add_argument( + "--log-level", choices=["debug", "info", "warning", "error"], default="info" + ) args = parser.parse_args() - app = create_gateway_app(router_addr=args.router_addr, admin_key=args.admin_key) + config = GatewayConfig( + host=args.host, + port=args.port, + admin_api_key=args.admin_api_key, + router_addr=args.router_addr, + router_timeout=args.router_timeout, + forward_timeout=args.forward_timeout, + log_level=args.log_level, + ) + app = create_gateway_app(config) mount_bridge( app, - OpenResponsesBridge(router_addr=args.router_addr, admin_key=args.admin_key), - admin_key=args.admin_key, + OpenResponsesBridge( + router_addr=config.router_addr, admin_api_key=config.admin_api_key + ), + admin_api_key=config.admin_api_key, ) - uvicorn.run(app, host=args.host, port=args.port) + uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level) if __name__ == "__main__": diff --git a/areal/experimental/agent_service/gateway/app.py b/areal/experimental/agent_service/gateway/app.py index 0f1c3d1dc0..b3043ad57a 100644 --- a/areal/experimental/agent_service/gateway/app.py +++ b/areal/experimental/agent_service/gateway/app.py @@ -4,6 +4,7 @@ from __future__ import annotations +import hmac import json import traceback @@ -12,7 +13,7 @@ from areal.utils import logging -from ..auth import DEFAULT_ADMIN_KEY, admin_headers +from ..auth import admin_headers from ..protocol import ( FrameType, RequestFrame, @@ -26,6 +27,7 @@ parse_frame, serialize_frame, ) +from .config import GatewayConfig logger = logging.getLogger("AgentGateway") @@ -41,16 +43,17 @@ def _make_accepted_json(request_id: str, run_id: str) -> str: ) -def create_gateway_app(router_addr: str, admin_key: str = DEFAULT_ADMIN_KEY) -> FastAPI: +def create_gateway_app(config: GatewayConfig) -> FastAPI: app = FastAPI(title="AReaL Agent Gateway") - http_client = httpx.AsyncClient(timeout=600.0) - _auth_headers = admin_headers(admin_key) + http_client = httpx.AsyncClient(timeout=config.forward_timeout) + _auth_headers = admin_headers(config.admin_api_key) async def _route(session_key: str) -> str: resp = await http_client.post( - f"{router_addr}/route", + f"{config.router_addr}/route", json={"session_key": session_key}, headers=_auth_headers, + timeout=config.router_timeout, ) resp.raise_for_status() return resp.json()["data_proxy_addr"] @@ -81,7 +84,7 @@ async def health(): @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, token: str = Query(default="")): - if token != admin_key: + if not hmac.compare_digest(token, config.admin_api_key): await websocket.close(code=4001, reason="Invalid admin key") return await websocket.accept() diff --git a/areal/experimental/agent_service/gateway/bridge.py b/areal/experimental/agent_service/gateway/bridge.py index a29c342f29..e47c8d8f35 100644 --- a/areal/experimental/agent_service/gateway/bridge.py +++ b/areal/experimental/agent_service/gateway/bridge.py @@ -14,7 +14,7 @@ from areal.utils import logging -from ..auth import DEFAULT_ADMIN_KEY, admin_headers, make_admin_dependency +from ..auth import DEFAULT_ADMIN_API_KEY, admin_headers, make_admin_dependency from ..protocol import generate_run_id logger = logging.getLogger("AgentBridge") @@ -26,9 +26,11 @@ async def handle_request(self, request: Request) -> Any: ... class OpenResponsesBridge(AgentBridge): - def __init__(self, router_addr: str, admin_key: str = DEFAULT_ADMIN_KEY) -> None: + def __init__( + self, router_addr: str, admin_api_key: str = DEFAULT_ADMIN_API_KEY + ) -> None: self._router_addr = router_addr - self._auth_headers = admin_headers(admin_key) + self._auth_headers = admin_headers(admin_api_key) self._http = httpx.AsyncClient(timeout=600.0) async def close(self) -> None: @@ -157,9 +159,9 @@ def _derive_session_key(user: str, model: str) -> str: def mount_bridge( app: FastAPI, bridge: OpenResponsesBridge, - admin_key: str = DEFAULT_ADMIN_KEY, + admin_api_key: str = DEFAULT_ADMIN_API_KEY, ) -> None: - auth = make_admin_dependency(admin_key) + auth = make_admin_dependency(admin_api_key) @app.post("/v1/responses", dependencies=[Depends(auth)]) async def responses_endpoint(request: Request): diff --git a/areal/experimental/agent_service/gateway/config.py b/areal/experimental/agent_service/gateway/config.py new file mode 100644 index 0000000000..f7ec950fc2 --- /dev/null +++ b/areal/experimental/agent_service/gateway/config.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +from ..auth import DEFAULT_ADMIN_API_KEY + + +@dataclass +class GatewayConfig: + host: str = "0.0.0.0" + port: int = 8080 + admin_api_key: str = DEFAULT_ADMIN_API_KEY + router_addr: str = "http://localhost:8081" + router_timeout: float = 2.0 + forward_timeout: float = 120.0 + log_level: str = "info" diff --git a/areal/experimental/agent_service/guard/__init__.py b/areal/experimental/agent_service/guard/__init__.py new file mode 100644 index 0000000000..57f50162ec --- /dev/null +++ b/areal/experimental/agent_service/guard/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Agent Service Guard — process supervisor backed by the shared guard. + +Pure pass-through to ``areal.infra.rpc.guard``. All orchestration logic +(launching Router, Gateway, Worker+DataProxy pairs) lives in the +:mod:`~areal.experimental.agent_service.controller` module. + +Quick start:: + + python -m areal.experimental.agent_service.guard \\ + --experiment-name demo --trial-name run0 \\ + --role agent-guard --worker-index 0 +""" diff --git a/areal/experimental/agent_service/guard/__main__.py b/areal/experimental/agent_service/guard/__main__.py new file mode 100644 index 0000000000..d311f6023b --- /dev/null +++ b/areal/experimental/agent_service/guard/__main__.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""CLI entrypoint: ``python -m areal.experimental.agent_service.guard``""" + +from __future__ import annotations + +from areal.experimental.agent_service.guard.app import ( + _state, + app, +) +from areal.infra.rpc.guard.app import ( + configure_state_from_args, + make_base_parser, + run_server, +) + + +def main(): + parser = make_base_parser( + description="AReaL Agent Service Guard — process supervisor for agent workers" + ) + args, _ = parser.parse_known_args() + + bind_host = configure_state_from_args(_state, args) + + run_server(_state, app, bind_host, args.port) + + +if __name__ == "__main__": + main() diff --git a/areal/experimental/agent_service/guard/app.py b/areal/experimental/agent_service/guard/app.py new file mode 100644 index 0000000000..b137feef91 --- /dev/null +++ b/areal/experimental/agent_service/guard/app.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Agent Service Guard backed by the shared guard infrastructure. + +All core guard functionality (port allocation, process forking, health +checks, cleanup) is provided by ``areal.infra.rpc.guard``. This module +creates and exposes the Flask app and shared state instance, following +the same pattern as ``areal.experimental.inference_service.guard``. +""" + +from __future__ import annotations + +from areal.infra.rpc.guard.app import ( + GuardState, + create_app, +) +from areal.infra.rpc.guard.app import ( + cleanup_forked_children as _cleanup_impl, +) +from areal.utils import logging + +logger = logging.getLogger("AgentGuard") + +_state = GuardState() + +app = create_app(_state) + + +def cleanup_forked_children() -> None: + _cleanup_impl(_state) diff --git a/areal/experimental/agent_service/router/__main__.py b/areal/experimental/agent_service/router/__main__.py index d52f77392d..f0203a2f37 100644 --- a/areal/experimental/agent_service/router/__main__.py +++ b/areal/experimental/agent_service/router/__main__.py @@ -6,18 +6,36 @@ import uvicorn +from ..auth import DEFAULT_ADMIN_API_KEY from .app import create_router_app +from .config import RouterConfig def main() -> None: parser = argparse.ArgumentParser(description="Agent Router") parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", type=int, default=8081) - parser.add_argument("--admin-key", default="areal-agent-admin") + parser.add_argument("--admin-api-key", default=DEFAULT_ADMIN_API_KEY) + parser.add_argument("--poll-interval", type=float, default=5.0) + parser.add_argument("--worker-health-timeout", type=float, default=2.0) + parser.add_argument( + "--log-level", choices=["debug", "info", "warning", "error"], default="info" + ) args = parser.parse_args() + config = RouterConfig( + host=args.host, + port=args.port, + admin_api_key=args.admin_api_key, + poll_interval=args.poll_interval, + worker_health_timeout=args.worker_health_timeout, + log_level=args.log_level, + ) uvicorn.run( - create_router_app(admin_key=args.admin_key), host=args.host, port=args.port + create_router_app(config), + host=config.host, + port=config.port, + log_level=config.log_level, ) diff --git a/areal/experimental/agent_service/router/app.py b/areal/experimental/agent_service/router/app.py index 980ee3b0ff..222c8a6c17 100644 --- a/areal/experimental/agent_service/router/app.py +++ b/areal/experimental/agent_service/router/app.py @@ -12,14 +12,15 @@ from areal.utils import logging -from ..auth import DEFAULT_ADMIN_KEY, make_admin_dependency +from ..auth import make_admin_dependency +from .config import RouterConfig logger = logging.getLogger("AgentRouter") -def create_router_app(admin_key: str = DEFAULT_ADMIN_KEY) -> FastAPI: +def create_router_app(config: RouterConfig) -> FastAPI: app = FastAPI(title="AReaL Agent Router") - auth = make_admin_dependency(admin_key) + auth = make_admin_dependency(config.admin_api_key) registered_proxies: list[str] = [] session_map: dict[str, str] = {} diff --git a/areal/experimental/agent_service/router/client.py b/areal/experimental/agent_service/router/client.py index 7d96a646be..4c5e129be9 100644 --- a/areal/experimental/agent_service/router/client.py +++ b/areal/experimental/agent_service/router/client.py @@ -4,13 +4,15 @@ import httpx -from ..auth import DEFAULT_ADMIN_KEY, admin_headers +from ..auth import DEFAULT_ADMIN_API_KEY, admin_headers class RouterClient: - def __init__(self, router_addr: str, admin_key: str = DEFAULT_ADMIN_KEY) -> None: + def __init__( + self, router_addr: str, admin_api_key: str = DEFAULT_ADMIN_API_KEY + ) -> None: self._addr = router_addr - self._headers = admin_headers(admin_key) + self._headers = admin_headers(admin_api_key) self._http = httpx.AsyncClient(timeout=30.0) async def register(self, addr: str) -> None: diff --git a/areal/experimental/agent_service/router/config.py b/areal/experimental/agent_service/router/config.py new file mode 100644 index 0000000000..ed36e50244 --- /dev/null +++ b/areal/experimental/agent_service/router/config.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +from ..auth import DEFAULT_ADMIN_API_KEY + + +@dataclass +class RouterConfig: + host: str = "0.0.0.0" + port: int = 8081 + admin_api_key: str = DEFAULT_ADMIN_API_KEY + poll_interval: float = 5.0 + worker_health_timeout: float = 2.0 + log_level: str = "info" diff --git a/areal/experimental/agent_service/worker/__main__.py b/areal/experimental/agent_service/worker/__main__.py index 6077390589..c14d52eba7 100644 --- a/areal/experimental/agent_service/worker/__main__.py +++ b/areal/experimental/agent_service/worker/__main__.py @@ -1,54 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 -"""``python -m areal.experimental.agent_service.worker``""" +"""``python -m areal.experimental.agent_service.worker`` + +Start a standalone Agent Worker process. The Controller forks this +via Guard to create Worker+DataProxy pairs. + + python -m areal.experimental.agent_service.worker \ + --agent examples.agent_service.agent.ClaudeAgent \ + --host 127.0.0.1 --port 9000 +""" import argparse -import asyncio -import threading -import httpx import uvicorn -from areal.utils.network import format_hostport - from .app import create_worker_app def main() -> None: - from ..data_proxy import create_data_proxy_app - - parser = argparse.ArgumentParser(description="Agent Worker + DataProxy") + parser = argparse.ArgumentParser(description="Agent Worker") parser.add_argument("--agent", required=True, help="Agent import path") - parser.add_argument("--router-addr", required=True, help="Router HTTP address") - parser.add_argument("--worker-port", type=int, default=9000) - parser.add_argument("--proxy-port", type=int, default=9100) parser.add_argument("--host", default="0.0.0.0") - parser.add_argument("--admin-key", default="areal-agent-admin") + parser.add_argument("--port", type=int, default=9000) + parser.add_argument( + "--log-level", choices=["debug", "info", "warning", "error"], default="info" + ) args = parser.parse_args() - worker_addr = f"http://{format_hostport(args.host, args.worker_port)}" - proxy_addr = f"http://{format_hostport(args.host, args.proxy_port)}" - - worker_app = create_worker_app(args.agent) - proxy_app = create_data_proxy_app(worker_addr=worker_addr) - - def run_worker(): - uvicorn.run(worker_app, host=args.host, port=args.worker_port, log_level="info") - - threading.Thread(target=run_worker, daemon=True).start() - - from ..auth import admin_headers - - async def register(): - async with httpx.AsyncClient() as client: - await client.post( - f"{args.router_addr}/register", - json={"addr": proxy_addr}, - headers=admin_headers(args.admin_key), - ) - - asyncio.run(register()) - uvicorn.run(proxy_app, host=args.host, port=args.proxy_port, log_level="info") + app = create_worker_app(args.agent) + uvicorn.run(app, host=args.host, port=args.port, log_level=args.log_level) if __name__ == "__main__": diff --git a/areal/experimental/agent_service/worker/app.py b/areal/experimental/agent_service/worker/app.py index 13086653e9..55507bc558 100644 --- a/areal/experimental/agent_service/worker/app.py +++ b/areal/experimental/agent_service/worker/app.py @@ -52,6 +52,19 @@ def create_worker_app( async def health(): return {"status": "ok"} + @app.post("/session/{session_key}/close") + async def close_session(session_key: str): + close_fn = getattr(agent, "close_session", None) + if close_fn is not None: + await close_fn(session_key) + return {"status": "ok"} + + @app.on_event("shutdown") + async def shutdown(): + close_all_fn = getattr(agent, "close_all_sessions", None) + if close_all_fn is not None: + await close_all_fn() + @app.post("/run") async def run(body: dict[str, Any]): request = AgentRequest( diff --git a/areal/experimental/agent_service/worker/config.py b/areal/experimental/agent_service/worker/config.py new file mode 100644 index 0000000000..f3704f0420 --- /dev/null +++ b/areal/experimental/agent_service/worker/config.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class WorkerConfig: + host: str = "0.0.0.0" + port: int = 9000 + agent_cls_path: str = "" + log_level: str = "info" diff --git a/areal/infra/utils/proc.py b/areal/infra/utils/proc.py index 3f93a08ec1..22576e3263 100644 --- a/areal/infra/utils/proc.py +++ b/areal/infra/utils/proc.py @@ -4,6 +4,7 @@ import os import shlex +import shutil import signal import subprocess import sys @@ -61,23 +62,30 @@ def build_streaming_log_cmd( else: cmd_str = cmd + # Check if stdbuf is available (not present on macOS by default) + _has_stdbuf = shutil.which("stdbuf") is not None + # Build prefix with env vars if provided prefix_parts = [] if env_vars: prefix_parts.append( " ".join(f"{k}={shlex.quote(str(v))}" for k, v in env_vars.items()) ) - prefix_parts.append(f"stdbuf -oL {cmd_str}") + if _has_stdbuf: + prefix_parts.append(f"stdbuf -oL {cmd_str}") + else: + prefix_parts.append(cmd_str) full_cmd = " ".join(prefix_parts) # Build log prefix for merged log log_prefix = f"[{role}]".ljust(LOG_PREFIX_WIDTH) # Construct tee/sed pipeline - shell_cmd = ( - f"{full_cmd} 2>&1 " - f"| tee -a {log_file} >(stdbuf -oL sed 's/^/{log_prefix}/' >> {merged_log})" - ) + if _has_stdbuf: + sed_prefix = f"stdbuf -oL sed 's/^/{log_prefix}/'" + else: + sed_prefix = f"sed 's/^/{log_prefix}/'" + shell_cmd = f"{full_cmd} 2>&1 | tee -a {log_file} >({sed_prefix} >> {merged_log})" return shell_cmd diff --git a/examples/agent_service/README.md b/examples/agent_service/README.md index 04fe12975a..563064b4e5 100644 --- a/examples/agent_service/README.md +++ b/examples/agent_service/README.md @@ -1,177 +1,114 @@ -# Agent Service Demo — Tau2 with PydanticAI +# Agent Service — Claude Agent SDK ## Overview -This example demonstrates AReaL's Agent Service running a **tau2 customer-service -agent** powered by **PydanticAI**. The agent handles multi-turn conversations, calls -tau2 environment tools (e.g. flight lookup, reservation booking), and maintains -conversation history across turns. +This example demonstrates AReaL's Agent Service running the **Claude Agent SDK** +(`claude-agent-sdk`) as a scalable HTTP micro-service. It turns Claude's autonomous +agent capabilities — multi-turn conversations, tool use, file editing, web search — into +a production-deployable service with session management, load balancing, and dynamic +scaling. -The Agent Service consists of four independent HTTP services: +**Why this matters**: Projects like +[claude-agent-acp](https://github.com/agentclientprotocol/claude-agent-acp) expose +Claude Agent SDK via custom protocols (ACP) for editor integration. AReaL takes a +different approach — it wraps Claude Agent SDK into standard HTTP micro-services with +session-affine routing, so you can **scale, orchestrate, and train** Claude agents using +AReaL's RL infrastructure. ``` -Client → Gateway (8080) → Router (8081) → DataProxy (9100) → Worker (9000) +Client → Gateway (HTTP) → Router → DataProxy (session state) → Worker (ClaudeSDKClient) ``` -- **Gateway**: public entry point (WebSocket + OpenResponses HTTP bridge) -- **Router**: session-affine routing (DataProxy registration, round-robin) -- **DataProxy**: stateful session proxy (conversation history, forwards to Worker) -- **Worker**: stateless agent execution (loads AgentRunnable, runs one turn) - -## Architecture - -``` -Client (HTTP/WS) - │ - ▼ -┌──────────┐ POST /route ┌──────────┐ -│ Gateway │ ──────────────▶ │ Router │ -│ :8080 │ ◀────────────── │ :8081 │ -└──────────┘ DataProxy addr └──────────┘ - │ - │ POST /session/{key}/turn - ▼ -┌──────────┐ -│ DataProxy│ -│ :9100 │ POST /run ┌──────────┐ -│ (history)│ ────────────▶│ Worker │ -└──────────┘ │ :9000 │ - │ (agent) │ - └──────────┘ -``` - -## Files - -| File | Description | -| ------------- | ----------------------------------------------------- | -| `agent.py` | `Tau2Agent` — PydanticAI agent with tau2 domain tools | -| `config.yaml` | Configuration: LLM endpoints, tau2 domain, data path | -| `run_demo.py` | One-click: starts all services, runs tau2 demo | - ## Prerequisites ```bash -pip install pydantic-ai -pip install git+https://github.com/dhh1995/tau2-bench.git@dhh/async-and-custom-completion +uv pip install claude-agent-sdk +export ANTHROPIC_API_KEY=sk-... ``` -## Configuration - -Edit `config.yaml` to set your LLM endpoints and tau2 settings: - -```yaml -tau2: - domain: airline - data_dir: /path/to/tau2-bench/data - -agent_llm: - model: openai:your-model-name - base_url: http://localhost:8000/v1 - api_key: unused - -user_llm: - model: null # set for user simulator, null for scripted messages - base_url: null - api_key: unused -``` - -Alternatively, set `TAU2_DATA_DIR` as an environment variable. - ## Quick Start -### One-click demo - ```bash -python examples/agent_service/run_demo.py # single task, airline -python examples/agent_service/run_demo.py --domain telecom # different domain -python examples/agent_service/run_demo.py --full # all tasks -python examples/agent_service/run_demo.py --config my.yaml # custom config +python examples/agent_service/run_agent_service.py ``` -This starts all four services in background threads and runs a multi-turn conversation -showing tool calls and history accumulation. +The script creates a `LocalScheduler`, launches Guard workers, then forks Router → +Worker+DataProxy → Gateway. An interactive prompt lets you chat with the Claude agent. -### Manual startup (separate terminals) +### Options ```bash -# Terminal 1: Router -python -m areal.experimental.agent_service.router --port 8081 - -# Terminal 2: Worker + DataProxy -python -m areal.experimental.agent_service.worker \ - --agent examples.agent_service.agent.Tau2Agent \ - --router-addr http://localhost:8081 \ - --worker-port 9000 \ - --proxy-port 9100 - -# Terminal 3: Gateway -python -m areal.experimental.agent_service.gateway \ - --router-addr http://localhost:8081 \ - --port 8080 +python examples/agent_service/run_agent_service.py --num-pairs 4 ``` -### Send a request +### Send requests directly ```bash -curl -X POST http://localhost:8080/v1/responses \ +curl -X POST http://localhost:/v1/responses \ -H "Content-Type: application/json" \ + -H "Authorization: Bearer areal-agent-admin" \ -d '{ - "input": [{"type": "message", "content": "I need to change my flight AA123"}], - "model": "tau2-agent", + "input": [{"type": "message", "content": "Explain RLHF in simple terms"}], + "model": "claude-agent", "user": "my-session" }' ``` -## Implementing Your Own Agent +## Configuration -Create a class that satisfies the `AgentRunnable` protocol: +Claude Agent SDK settings are controlled via environment variables: -```python -from areal.experimental.agent_service.agent_worker import ( - AgentRequest, AgentResponse, EventEmitter, -) +| Variable | Default | Description | +| ---------------------- | ------------------- | --------------------------- | +| `ANTHROPIC_API_KEY` | (required) | Anthropic API key | +| `CLAUDE_MODEL` | `claude-sonnet-4-6` | Model to use | +| `CLAUDE_SYSTEM_PROMPT` | (none) | Optional system prompt | +| `CLAUDE_MAX_TURNS` | `20` | Max agentic turns per query | -class MyAgent: - def __init__(self, **kwargs): - # Configure LLM client, tools, etc. - pass - - async def run( - self, - request: AgentRequest, - *, - emitter: EventEmitter, - ) -> AgentResponse: - # request.message — current user message - # request.history — prior conversation turns - # emitter — stream events back to client - await emitter.emit_delta("Hello!") - return AgentResponse(summary="Hello!") -``` +## Architecture -Then start a worker with your agent: +The Worker maintains a **session-persistent `ClaudeSDKClient`** per session key. Unlike +stateless wrappers, the SDK's internal session retains the full conversation transcript +— no need to re-send history on each turn. -```bash -python -m areal.experimental.agent_service.worker \ - --agent mypackage.myagent.MyAgent \ - --router-addr http://localhost:8081 +``` +Turn 1: Client → Gateway → Router → DataProxy → Worker + Worker: creates ClaudeSDKClient for session "abc" + Claude Agent SDK runs autonomously (tool calls, file ops, etc.) + Response streams back through the chain + +Turn 2: Client → Gateway → Router (same DataProxy) → DataProxy → Worker + Worker: reuses ClaudeSDKClient for session "abc" + SDK remembers full context from Turn 1 ``` -## Multi-turn Conversations - -The DataProxy automatically manages conversation history. Each turn: - -1. DataProxy reads history for the session -1. Builds `AgentRequest` with `history` field populated -1. Forwards to Worker → Agent sees full conversation context -1. Appends user message + agent response to history -1. Tool calls and results are also recorded in history - -The agent accesses history via `request.history`: +## Programmatic Usage ```python -async def run(self, request, *, emitter): - for msg in request.history: - print(f"{msg['role']}: {msg['content']}") - # ... generate response using full context +from areal.experimental.agent_service.controller import ( + AgentServiceController, + AgentServiceControllerConfig, +) +from areal.infra.scheduler.local import LocalScheduler + +scheduler = LocalScheduler(experiment_name="demo", trial_name="run0", gpu_devices=[]) +ctrl = AgentServiceController( + config=AgentServiceControllerConfig( + agent_cls_path="examples.agent_service.agent.ClaudeAgent", + num_pairs=2, + ), + scheduler=scheduler, +) +ctrl.initialize() +# ctrl.gateway_addr → "http://10.0.0.1:9005" +# ctrl.scale_up(2) → add 2 more pairs +# ctrl.scale_down(1) → remove 1 pair (with graceful drain) +ctrl.destroy() ``` + +## Files + +| File | Description | +| ---------------------- | ----------------------------------------------------------- | +| `agent.py` | `ClaudeAgent` — session-persistent Claude Agent SDK wrapper | +| `run_agent_service.py` | Controller-based launcher + interactive conversation | diff --git a/examples/agent_service/agent.py b/examples/agent_service/agent.py index 2c2beb5df8..c05f3bebe5 100644 --- a/examples/agent_service/agent.py +++ b/examples/agent_service/agent.py @@ -1,26 +1,34 @@ -"""Tau2 Agent for AReaL Agent Service (PydanticAI). +"""Claude Agent for AReaL Agent Service. -Implements :class:`AgentRunnable` using PydanticAI. Each call to ``run()`` -handles a **single turn** of a tau2 customer-service dialogue. The agent -uses tau2 environment tools (registered as PydanticAI function tools) and -maintains conversation context via ``request.history``. +Implements :class:`AgentRunnable` using the Claude Agent SDK +(``claude-agent-sdk``). Each Worker instance holds a pool of +:class:`ClaudeSDKClient` sessions keyed by ``session_key``, so multi-turn +conversations preserve full context without re-sending history. -Requires: ``pip install pydantic-ai tau2-bench`` +Requires:: + + pip install claude-agent-sdk + +Environment variables: + ANTHROPIC_API_KEY — Anthropic API key (required) + CLAUDE_MODEL — model name (default: claude-sonnet-4-6) + CLAUDE_SYSTEM_PROMPT — optional system prompt override + CLAUDE_MAX_TURNS — max agentic turns per query (default: 20) """ from __future__ import annotations -import inspect -import json import os -from typing import Any - -from pydantic_ai import Agent -from pydantic_ai.models.openai import OpenAIChatModel -from pydantic_ai.providers.openai import OpenAIProvider -from tau2.environment.environment import Environment -from tau2.environment.tool import Tool as Tau2Tool -from tau2.registry import registry +from typing import Any, Literal + +from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + ClaudeSDKClient, + ResultMessage, + TextBlock, + ToolUseBlock, +) from areal.experimental.agent_service.types import ( AgentRequest, @@ -29,101 +37,66 @@ ) from areal.utils import logging -logger = logging.getLogger("Tau2Agent") +logger = logging.getLogger("ClaudeAgent") +PermissionMode = Literal["default", "acceptEdits", "plan", "bypassPermissions"] -def _make_pydantic_tool(tau2_tool: Tau2Tool): - """Create a plain async function from a tau2 Tool for PydanticAI.""" - fn = tau2_tool._func # noqa: SLF001 - name = tau2_tool.name - doc = tau2_tool.openai_schema["function"].get("description", name) +_DEFAULT_PERMISSION_MODE: PermissionMode = "bypassPermissions" - async def _wrapper(**kwargs: Any) -> str: - result = fn(**kwargs) - if not isinstance(result, str): - result = json.dumps(result, default=str) - return result - - _wrapper.__name__ = name - _wrapper.__qualname__ = name - _wrapper.__doc__ = doc - - sig = inspect.signature(fn) - params = [ - inspect.Parameter( - pname, - kind=inspect.Parameter.KEYWORD_ONLY, - default=param.default, - annotation=param.annotation, - ) - for pname, param in sig.parameters.items() - ] - _wrapper.__signature__ = inspect.Signature(params) # type: ignore[attr-defined] - if hasattr(fn, "__annotations__"): - _wrapper.__annotations__ = { - k: v for k, v in fn.__annotations__.items() if k != "return" - } - return _wrapper +class ClaudeAgent: + """AgentRunnable backed by the Claude Agent SDK. -def _think_tool_fn(thoughts: str) -> str: - """Use this tool to think. Only use when necessary.""" - return "Your thoughts are recorded. Please continue your work." - - -class Tau2Agent: - """AgentRunnable that wraps a PydanticAI Agent with tau2 tools. - - Accepts a ``config`` dict (loaded from config.yaml by run_demo.py). - Falls back to environment variables if config is not provided. + Maintains a ``ClaudeSDKClient`` per session for true multi-turn + continuity — the SDK's internal session keeps the full transcript, + so ``request.history`` is only used for the very first turn of a + new session (to seed context if provided by the caller). """ - def __init__(self, config: dict | None = None, **kwargs: Any) -> None: - config = config or {} - tau2_cfg = config.get("tau2", {}) - agent_llm_cfg = config.get("agent_llm", {}) - - self._domain = tau2_cfg.get("domain") or os.environ.get( - "TAU2_DOMAIN", "airline" - ) - add_thinking = tau2_cfg.get("add_thinking_tool", False) - - data_dir = tau2_cfg.get("data_dir") or os.environ.get("TAU2_DATA_DIR") - if data_dir: - os.environ["TAU2_DATA_DIR"] = data_dir + def __init__(self, **kwargs: Any) -> None: + self._model = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-6") + self._system_prompt = os.environ.get("CLAUDE_SYSTEM_PROMPT", "") + self._max_turns = int(os.environ.get("CLAUDE_MAX_TURNS", "20")) + self._permission_mode: PermissionMode = _DEFAULT_PERMISSION_MODE - env = self._build_environment() - tau2_tools: list[Tau2Tool] = env.get_tools() - if add_thinking: - tau2_tools.append(Tau2Tool(_think_tool_fn)) - - tools = [_make_pydantic_tool(t) for t in tau2_tools] - system_prompt = env.get_policy() - - model_name = agent_llm_cfg.get("model", "openai:default") - base_url = agent_llm_cfg.get("base_url") - api_key = agent_llm_cfg.get("api_key", "unused") - - if base_url: - model: Any = OpenAIChatModel( - model_name.replace("openai:", ""), - provider=OpenAIProvider(base_url=base_url, api_key=api_key), - ) - else: - model = model_name - - self._agent = Agent(model, system_prompt=system_prompt, tools=tools) + self._sessions: dict[str, ClaudeSDKClient] = {} logger.info( - "Tau2Agent initialized (domain=%s, tools=%d, model=%s)", - self._domain, - len(tools), - model_name, + "ClaudeAgent initialized (model=%s, max_turns=%d)", + self._model, + self._max_turns, ) - def _build_environment(self) -> Environment: - constructor = registry.get_env_constructor(self._domain) - return constructor(solo_mode=False) + def _make_options(self) -> ClaudeAgentOptions: + opts = ClaudeAgentOptions( + model=self._model, + max_turns=self._max_turns, + permission_mode=self._permission_mode, + ) + if self._system_prompt: + opts.system_prompt = self._system_prompt + return opts + + async def _get_or_create_client(self, session_key: str) -> ClaudeSDKClient: + if session_key not in self._sessions: + client = ClaudeSDKClient(options=self._make_options()) + await client.__aenter__() + self._sessions[session_key] = client + logger.info("New session: %s", session_key) + return self._sessions[session_key] + + async def close_session(self, session_key: str) -> None: + client = self._sessions.pop(session_key, None) + if client is not None: + try: + await client.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing session %s", session_key, exc_info=True) + + async def close_all_sessions(self) -> None: + keys = list(self._sessions.keys()) + for key in keys: + await self.close_session(key) async def run( self, @@ -131,87 +104,36 @@ async def run( *, emitter: EventEmitter, ) -> AgentResponse: - from pydantic_ai.messages import ( - ModelRequest, - TextPart, - ToolCallPart, - ToolReturnPart, - UserPromptPart, - ) - from pydantic_ai.messages import ( - ModelResponse as PAModelResponse, - ) - - message_history: list[ModelRequest | PAModelResponse] = [] - for msg in request.history: - role = msg.get("role", "user") - content = msg.get("content", "") - - if role == "user": - message_history.append( - ModelRequest(parts=[UserPromptPart(content=content or "")]) - ) - elif role == "assistant": - tool_calls = msg.get("tool_calls") - if tool_calls: - parts = [] - for tc in tool_calls: - fn = tc.get("function", tc) - parts.append( - ToolCallPart( - tool_name=fn.get("name", ""), - args=fn.get("arguments", ""), - tool_call_id=tc.get("id", ""), + client = await self._get_or_create_client(request.session_key) + + try: + await client.query(request.message) + + text_parts: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + + async for msg in client.receive_response(): + if isinstance(msg, AssistantMessage): + for block in msg.content: + if isinstance(block, TextBlock): + await emitter.emit_delta(block.text) + text_parts.append(block.text) + elif isinstance(block, ToolUseBlock): + await emitter.emit_tool_call( + name=block.name, + args=str(block.input), ) - ) - message_history.append(PAModelResponse(parts=parts)) - elif content: - message_history.append( - PAModelResponse(parts=[TextPart(content=content)]) - ) - elif role == "tool": - tool_call_id = msg.get("tool_call_id", "") - message_history.append( - ModelRequest( - parts=[ - ToolReturnPart( - tool_name=tool_call_id, - content=content or "", - tool_call_id=tool_call_id, + tool_calls.append( + {"name": block.name, "input": block.input} ) - ] - ) - ) - - result = await self._agent.run( - request.message, - message_history=message_history, - ) + elif isinstance(msg, ResultMessage): + break - final_text = str(result.output) if result.output else "" - - tool_calls: list[dict[str, Any]] = [] - for msg in result.new_messages(): - if not hasattr(msg, "parts"): - continue - for part in msg.parts: - kind = getattr(part, "part_kind", "") - if kind == "tool-call": - name = getattr(part, "tool_name", "") - args = getattr(part, "args", "") - if isinstance(args, dict): - args = json.dumps(args) - await emitter.emit_tool_call(name=name, args=str(args)) - tool_calls.append({"name": name, "arguments": args}) - elif kind == "tool-return": - name = getattr(part, "tool_name", "") - content = str(getattr(part, "content", "")) - await emitter.emit_tool_result(name=name, result=content) - - if final_text: - await emitter.emit_delta(final_text) - - return AgentResponse( - summary=final_text[:200], - metadata={"tool_calls": tool_calls}, - ) + summary = "".join(text_parts) + return AgentResponse( + summary=summary[:200], + metadata={"tool_calls": tool_calls}, + ) + except Exception: + await self.close_session(request.session_key) + raise diff --git a/examples/agent_service/config.yaml b/examples/agent_service/config.yaml deleted file mode 100644 index 23cebf9f40..0000000000 --- a/examples/agent_service/config.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# Agent Service Demo Configuration - -# Admin key for inter-service authentication (Router, Gateway, Worker). -# Change from default for non-local deployments. -admin_key: areal-agent-admin - -# tau2 environment settings -tau2: - domain: airline # airline | retail | telecom - data_dir: None # path to tau2 data dir (or set TAU2_DATA_DIR env var) - add_thinking_tool: false - -# Agent LLM — the model the agent uses for reasoning + tool calls. -# For the demo this points to a local/self-hosted model. -agent_llm: - model: openai:Ling-2.6-1T # PydanticAI model string - base_url: None # e.g. http://localhost:8000/v1 - api_key: None - -# User simulator LLM — drives the simulated customer. -# Set to null to use scripted user messages instead. -user_llm: - model: openai:GLM-5 # e.g. openai:Qwen2.5-72B - base_url: None # e.g. http://localhost:8001/v1 - api_key: None diff --git a/examples/agent_service/run_agent_service.py b/examples/agent_service/run_agent_service.py new file mode 100644 index 0000000000..e96f83f501 --- /dev/null +++ b/examples/agent_service/run_agent_service.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Launch the Agent Service with Claude Agent SDK. + +Usage:: + + python examples/agent_service/run_agent_service.py + python examples/agent_service/run_agent_service.py --num-pairs 2 + +Requires:: + + uv pip install claude-agent-sdk + export ANTHROPIC_API_KEY=sk-... +""" + +from __future__ import annotations + +import argparse +import asyncio +import time + +import httpx + +from areal.experimental.agent_service.controller import ( + AgentServiceController, + AgentServiceControllerConfig, +) + + +async def _wait_healthy(url: str, timeout: float = 60.0) -> None: + async with httpx.AsyncClient() as client: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = await client.get(url) + if resp.status_code == 200: + return + except httpx.ConnectError: + pass + await asyncio.sleep(0.5) + raise TimeoutError(f"Service at {url} did not become healthy") + + +async def interactive_loop(gateway_addr: str, admin_key: str) -> None: + session_key = f"session-{int(time.time())}" + print("Type your message (or 'quit' to exit):\n") + + async with httpx.AsyncClient(timeout=120.0) as client: + while True: + try: + user_input = input("You: ") + except (EOFError, KeyboardInterrupt): + break + if user_input.strip().lower() in ("quit", "exit", "q"): + break + if not user_input.strip(): + continue + + resp = await client.post( + f"{gateway_addr}/v1/responses", + json={ + "input": [{"type": "message", "content": user_input}], + "model": "claude-agent", + "user": session_key, + }, + headers={"Authorization": f"Bearer {admin_key}"}, + ) + data = resp.json() + + if data.get("status") == "completed": + for item in data.get("output", []): + if item.get("type") == "message": + for block in item.get("content", []): + if block.get("type") == "output_text": + print(f"Agent: {block['text']}") + elif item.get("type") == "function_call": + print(f"[tool] {item.get('name', '')}") + print() + elif data.get("error"): + print(f"Error: {data['error'].get('message', '')[:200]}\n") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Agent Service — Claude Agent SDK") + parser.add_argument( + "--num-pairs", + type=int, + default=1, + help="Number of Worker+DataProxy pairs (default: 1)", + ) + parser.add_argument( + "--admin-api-key", + default="areal-agent-admin", + help="Admin API key for inter-service auth", + ) + args = parser.parse_args() + + from areal.infra.scheduler.local import LocalScheduler + + scheduler = LocalScheduler( + experiment_name="agent-service-demo", + trial_name="run0", + gpu_devices=[], + ) + + ctrl_config = AgentServiceControllerConfig( + agent_cls_path="examples.agent_service.agent.ClaudeAgent", + admin_api_key=args.admin_api_key, + num_pairs=args.num_pairs, + ) + ctrl = AgentServiceController(config=ctrl_config, scheduler=scheduler) + + try: + print(f"Initializing with {args.num_pairs} pair(s) ...") + ctrl.initialize() + print(f" Router: {ctrl.router_addr}") + print(f" Gateway: {ctrl.gateway_addr}") + print(f" Pairs: {len(ctrl.pairs)}") + + asyncio.run(_wait_healthy(f"{ctrl.gateway_addr}/health")) + print("All services ready.\n") + + asyncio.run(interactive_loop(ctrl.gateway_addr, admin_key=args.admin_api_key)) + finally: + print("\nShutting down ...") + ctrl.destroy() + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/agent_service/run_demo.py b/examples/agent_service/run_demo.py deleted file mode 100644 index 707812461c..0000000000 --- a/examples/agent_service/run_demo.py +++ /dev/null @@ -1,298 +0,0 @@ -"""One-click demo: Agent Service + Tau2 (PydanticAI). - -Usage:: - - python examples/agent_service/run_demo.py # single task - python examples/agent_service/run_demo.py --domain telecom # different domain - python examples/agent_service/run_demo.py --full # all tasks - python examples/agent_service/run_demo.py --config my.yaml # custom config - -Requires:: - - pip install pydantic-ai - pip install git+https://github.com/dhh1995/tau2-bench.git@dhh/async-and-custom-completion -""" - -from __future__ import annotations - -import argparse -import asyncio -import os -import threading -import time -from pathlib import Path -from typing import Any -from unittest.mock import patch - -import httpx -import uvicorn -import yaml - -from areal.experimental.agent_service import ( - OpenResponsesBridge, - create_data_proxy_app, - create_gateway_app, - create_router_app, - create_worker_app, - mount_bridge, -) - -ROUTER_PORT = 18081 -WORKER_PORT = 19000 -PROXY_PORT = 19100 -GATEWAY_PORT = 18080 - -DEFAULT_CONFIG = Path(__file__).parent / "config.yaml" - - -def _load_config(path: str | Path) -> dict[str, Any]: - with open(path) as f: - return yaml.safe_load(f) or {} - - -def _start_in_thread(app, port: int, name: str) -> threading.Thread: - def run(): - uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning") - - t = threading.Thread(target=run, daemon=True, name=name) - t.start() - return t - - -async def _wait_healthy(url: str, timeout: float = 10.0) -> None: - async with httpx.AsyncClient() as client: - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - try: - resp = await client.get(url) - if resp.status_code == 200: - return - except httpx.ConnectError: - pass - await asyncio.sleep(0.2) - raise TimeoutError(f"Service at {url} did not become healthy") - - -async def run_task(gateway_addr: str, task, domain: str, admin_key: str) -> float: - """Run a single tau2 task. Returns the reward.""" - from tau2.data_model.message import AssistantMessage, UserMessage - from tau2.data_model.simulation import SimulationRun, TerminationReason - from tau2.evaluator.evaluator import EvaluationType, evaluate_simulation - - session_key = f"tau2-{domain}-{task.id}" - print(f"\n Task: {task.id}") - print(f" Scenario: {str(task.user_scenario)[:120]}...") - - scripted_messages = [ - str(task.user_scenario), - "Yes, please go ahead and help me with that.", - "Can you check the status of my request?", - "Thank you, that's all I need.", - ] - - tau2_messages = [] - error_occurred = False - - async with httpx.AsyncClient(timeout=120.0) as client: - for i, msg in enumerate(scripted_messages, 1): - resp = await client.post( - f"{gateway_addr}/v1/responses", - json={ - "input": [{"type": "message", "content": msg}], - "model": "tau2-agent", - "user": session_key, - }, - headers={"Authorization": f"Bearer {admin_key}"}, - ) - data = resp.json() - - tau2_messages.append( - UserMessage(role="user", content=msg, turn_idx=len(tau2_messages)) - ) - - if data.get("status") == "completed": - agent_text = "" - for item in data.get("output", []): - if item.get("type") == "message": - for block in item.get("content", []): - if block.get("type") == "output_text": - agent_text += block["text"] - print(f" [Turn {i}] Agent: {block['text'][:150]}") - elif item.get("type") == "function_call": - print(f" [Turn {i}] [tool] {item.get('name', '')}") - - tau2_messages.append( - AssistantMessage( - role="assistant", - content=agent_text or "(no response)", - turn_idx=len(tau2_messages), - ) - ) - elif data.get("error"): - err = data["error"].get("message", "")[:100] - print(f" [Turn {i}] Error: {err}") - tau2_messages.append( - AssistantMessage( - role="assistant", - content=f"Error: {err}", - turn_idx=len(tau2_messages), - ) - ) - error_occurred = True - break - - reward = 0.0 - if not error_occurred: - try: - simulation = SimulationRun( - id=f"demo-{task.id}", - task_id=task.id, - messages=tau2_messages, - start_time="", - end_time="", - duration=0.0, - termination_reason=TerminationReason.USER_STOP, - ) - reward_info = evaluate_simulation( - simulation=simulation, - task=task, - evaluation_type=EvaluationType.ALL, - solo_mode=False, - domain=domain, - ) - reward = reward_info.reward - except Exception as e: - print(f" Eval error: {e}") - - print(f" Reward: {reward:.3f}") - return reward - - -async def run_demo(gateway_addr: str, domain: str, full: bool, admin_key: str) -> None: - from tau2.registry import registry - - print(f"\n{'=' * 60}") - print(f" Tau2 Agent Service Demo — domain: {domain}") - print(f"{'=' * 60}") - - tasks = registry.get_tasks_loader(domain)(None) - total = len(tasks) - - if not full: - tasks = tasks[:1] - print(f" Running 1 task (use --full for all {total} tasks)") - else: - print(f" Running all {total} tasks") - - rewards = [] - for task in tasks: - reward = await run_task(gateway_addr, task, domain, admin_key=admin_key) - rewards.append((task.id, reward)) - - print(f"\n{'=' * 60}") - print(f" Results — {len(rewards)} task(s)") - print(f"{'=' * 60}") - for task_id, reward in rewards: - print(f" Task {task_id}: reward = {reward:.3f}") - if rewards: - avg = sum(r for _, r in rewards) / len(rewards) - print(f"\n Average reward: {avg:.3f}") - print(f"{'=' * 60}") - - -def main() -> None: - parser = argparse.ArgumentParser(description="Tau2 Agent Service Demo") - parser.add_argument( - "--config", - default=str(DEFAULT_CONFIG), - help=f"Config YAML path (default: {DEFAULT_CONFIG})", - ) - parser.add_argument( - "--domain", - choices=["airline", "retail", "telecom"], - help="Override tau2.domain from config", - ) - parser.add_argument( - "--full", - action="store_true", - help="Run all tasks (default: single task)", - ) - args = parser.parse_args() - - config = _load_config(args.config) - tau2_cfg = config.setdefault("tau2", {}) - - domain = args.domain or tau2_cfg.get("domain", "airline") - tau2_cfg["domain"] = domain - - data_dir = tau2_cfg.get("data_dir") or os.environ.get("TAU2_DATA_DIR") - if data_dir: - os.environ["TAU2_DATA_DIR"] = data_dir - - admin_key = config.get("admin_key", "areal-agent-admin") - - router_addr = f"http://127.0.0.1:{ROUTER_PORT}" - worker_addr = f"http://127.0.0.1:{WORKER_PORT}" - proxy_addr = f"http://127.0.0.1:{PROXY_PORT}" - gateway_addr = f"http://127.0.0.1:{GATEWAY_PORT}" - - # 1. Router - _start_in_thread(create_router_app(admin_key=admin_key), ROUTER_PORT, "router") - - # 2. Worker (Tau2Agent with PydanticAI + tau2 tools) - def _make_agent_cls(): - from examples.agent_service.agent import Tau2Agent - - class _Configured(Tau2Agent): - def __init__(self, **kw: Any): - super().__init__(config=config, **kw) - - return _Configured - - with patch( - "areal.experimental.agent_service.worker.app.import_from_string", - return_value=_make_agent_cls(), - ): - worker_app = create_worker_app("examples.agent_service.agent.Tau2Agent") - _start_in_thread(worker_app, WORKER_PORT, "worker") - - # 3. DataProxy - _start_in_thread( - create_data_proxy_app(worker_addr=worker_addr), PROXY_PORT, "proxy" - ) - - # 4. Gateway + Bridge - gw_app = create_gateway_app(router_addr=router_addr, admin_key=admin_key) - mount_bridge( - gw_app, - OpenResponsesBridge(router_addr=router_addr, admin_key=admin_key), - admin_key=admin_key, - ) - _start_in_thread(gw_app, GATEWAY_PORT, "gateway") - - # 5. Wait + register - async def setup(): - await _wait_healthy(f"{router_addr}/health") - await _wait_healthy(f"{worker_addr}/health") - await _wait_healthy(f"{proxy_addr}/health") - await _wait_healthy(f"{gateway_addr}/health") - from areal.experimental.agent_service.auth import admin_headers - - async with httpx.AsyncClient() as client: - await client.post( - f"{router_addr}/register", - json={"addr": proxy_addr}, - headers=admin_headers(admin_key), - ) - - asyncio.run(setup()) - print("All services started.") - - # 6. Run demo - asyncio.run( - run_demo(gateway_addr, domain=domain, full=args.full, admin_key=admin_key) - ) - - -if __name__ == "__main__": - main() diff --git a/tests/experimental/agent_service/test_agent_router.py b/tests/experimental/agent_service/test_agent_router.py index 214d7e9d0c..13e682de23 100644 --- a/tests/experimental/agent_service/test_agent_router.py +++ b/tests/experimental/agent_service/test_agent_router.py @@ -4,16 +4,18 @@ import pytest -from areal.experimental.agent_service.auth import DEFAULT_ADMIN_KEY, admin_headers +from areal.experimental.agent_service.auth import DEFAULT_ADMIN_API_KEY, admin_headers from areal.experimental.agent_service.router.app import create_router_app +from areal.experimental.agent_service.router.config import RouterConfig httpx = pytest.importorskip("httpx") -_AUTH = admin_headers(DEFAULT_ADMIN_KEY) +_AUTH = admin_headers(DEFAULT_ADMIN_API_KEY) def _make_client(): - app = create_router_app(admin_key=DEFAULT_ADMIN_KEY) + config = RouterConfig(admin_api_key=DEFAULT_ADMIN_API_KEY) + app = create_router_app(config) transport = httpx.ASGITransport(app=app) return httpx.AsyncClient(transport=transport, base_url="http://router") diff --git a/tests/experimental/agent_service/test_controller.py b/tests/experimental/agent_service/test_controller.py new file mode 100644 index 0000000000..376bed71c6 --- /dev/null +++ b/tests/experimental/agent_service/test_controller.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for AgentServiceController. + +All Guard HTTP interactions are mocked — no real processes or servers. +Tests cover: initialize, destroy, scale_up, scale_down, and error handling. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from unittest.mock import MagicMock, patch + +import pytest + +from areal.experimental.agent_service.controller.config import ( + AgentServiceControllerConfig, +) +from areal.experimental.agent_service.controller.controller import ( + AgentServiceController, +) + +CTRL = "areal.experimental.agent_service.controller.controller" + + +@dataclass +class _FakeWorker: + id: str + ip: str + worker_ports: list[str] + engine_ports: list[str] + + +def _make_scheduler(*guard_specs: tuple[str, str]) -> MagicMock: + """Return a mock Scheduler whose get_workers returns _FakeWorkers.""" + workers = [ + _FakeWorker(id=f"agent-guard/{i}", ip=ip, worker_ports=[port], engine_ports=[]) + for i, (ip, port) in enumerate(guard_specs) + ] + scheduler = MagicMock() + scheduler.get_workers.return_value = workers + return scheduler + + +def _mock_alloc_ports_response(host: str, ports: list[int]) -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"status": "success", "host": host, "ports": ports} + resp.raise_for_status = MagicMock() + return resp + + +def _mock_fork_response(host: str, pid: int) -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"status": "success", "host": host, "pid": pid} + resp.raise_for_status = MagicMock() + return resp + + +def _mock_kill_response() -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"status": "success"} + resp.text = '{"status": "success"}' + return resp + + +def _mock_register_response() -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.raise_for_status = MagicMock() + return resp + + +def _mock_health_response(active_sessions: int = 0) -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"status": "ok", "active_sessions": active_sessions} + return resp + + +@pytest.fixture() +def config(): + return AgentServiceControllerConfig( + agent_cls_path="my.Agent", + admin_api_key="test-key", + num_pairs=2, + setup_timeout=1.0, + health_poll_interval=0, + ) + + +def _setup_mock_requests(mock_requests, port_start=9001): + port_counter = iter(range(port_start, port_start + 100)) + + def mock_post(url, **kwargs): + if "/alloc_ports" in url: + return _mock_alloc_ports_response("10.0.0.1", [next(port_counter)]) + if "/fork" in url: + return _mock_fork_response("10.0.0.1", 100) + if "/register" in url: + return _mock_register_response() + if "/kill_forked_worker" in url: + return _mock_kill_response() + if "/unregister" in url: + return _mock_register_response() + return MagicMock(status_code=404) + + mock_requests.post = mock_post + mock_requests.get = lambda url, **kw: _mock_health_response() + mock_requests.RequestException = Exception + + +class TestConstruction: + def test_construction(self, config): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + assert ctrl.router_addr == "" + assert ctrl.gateway_addr == "" + assert ctrl.pairs == {} + + +class TestInitialize: + @patch(f"{CTRL}.requests") + def test_initialize_forks_router_pairs_gateway(self, mock_requests, config): + """Initialize should create guards via scheduler, then fork services.""" + _setup_mock_requests(mock_requests) + + scheduler = _make_scheduler(("10.0.0.1", "8090"), ("10.0.0.2", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + + scheduler.create_workers.assert_called_once() + scheduler.get_workers.assert_called_once() + + assert "http://" in ctrl.router_addr + assert "http://" in ctrl.gateway_addr + assert len(ctrl.pairs) == 2 + assert len(ctrl._forked_services) == 6 + + +class TestScaleUp: + @patch(f"{CTRL}.requests") + def test_scale_up_adds_pairs(self, mock_requests, config): + config.num_pairs = 0 + _setup_mock_requests(mock_requests) + + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + assert len(ctrl.pairs) == 0 + + created = ctrl.scale_up(3) + assert created == [0, 1, 2] + assert len(ctrl.pairs) == 3 + + @patch(f"{CTRL}.requests") + def test_scale_up_round_robins_guards(self, mock_requests, config): + config.num_pairs = 0 + guards_called: list[str] = [] + + def mock_post(url, **kwargs): + if "/alloc_ports" in url: + guards_called.append(url.split("/alloc_ports")[0]) + return _mock_alloc_ports_response("10.0.0.1", [9001]) + if "/fork" in url: + return _mock_fork_response("10.0.0.1", 100) + if "/register" in url: + return _mock_register_response() + if "/kill_forked_worker" in url: + return _mock_kill_response() + return MagicMock(status_code=404) + + mock_requests.post = mock_post + mock_requests.get = lambda url, **kw: _mock_health_response() + mock_requests.RequestException = Exception + + scheduler = _make_scheduler(("g0", "8090"), ("g1", "8091")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + guards_called.clear() + + ctrl.scale_up(4) + + g0_calls = [g for g in guards_called if "g0" in g] + g1_calls = [g for g in guards_called if "g1" in g] + assert len(g0_calls) == 4 + assert len(g1_calls) == 4 + + +class TestScaleDown: + @patch(f"{CTRL}.requests") + def test_scale_down_removes_newest_first(self, mock_requests, config): + config.num_pairs = 3 + _setup_mock_requests(mock_requests) + + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + assert len(ctrl.pairs) == 3 + + removed = ctrl.scale_down(2) + assert set(removed) == {2, 1} + assert len(ctrl.pairs) == 1 + assert 0 in ctrl.pairs + + +class TestDestroy: + @patch(f"{CTRL}.requests") + def test_destroy_clears_everything(self, mock_requests, config): + config.num_pairs = 1 + _setup_mock_requests(mock_requests) + + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + assert len(ctrl._forked_services) > 0 + + ctrl.destroy() + assert ctrl.router_addr == "" + assert ctrl.gateway_addr == "" + assert ctrl.pairs == {} + assert ctrl._forked_services == [] + scheduler.delete_workers.assert_called() + + @patch(f"{CTRL}.requests") + def test_destroy_tolerates_kill_errors(self, mock_requests, config): + config.num_pairs = 0 + kill_count = 0 + + def mock_post(url, **kwargs): + nonlocal kill_count + if "/alloc_ports" in url: + return _mock_alloc_ports_response("10.0.0.1", [9001]) + if "/fork" in url: + return _mock_fork_response("10.0.0.1", 100) + if "/kill_forked_worker" in url: + kill_count += 1 + raise ConnectionError("Guard down") + return MagicMock(status_code=404) + + mock_requests.post = mock_post + mock_requests.get = lambda url, **kw: _mock_health_response() + mock_requests.RequestException = Exception + + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + + ctrl.destroy() + assert kill_count == 2 + assert ctrl._forked_services == [] + + +class TestDrain: + @patch(f"{CTRL}.requests") + def test_scale_down_waits_for_drain(self, mock_requests, config): + """scale_down should poll DataProxy health until active_sessions reaches 0.""" + config.num_pairs = 1 + config.drain_timeout = 5.0 + + _setup_mock_requests(mock_requests) + health_call_count = 0 + + def mock_get(url, **kwargs): + nonlocal health_call_count + health_call_count += 1 + if "/health" in url and health_call_count <= 5: + return _mock_health_response(active_sessions=2) + return _mock_health_response(active_sessions=0) + + mock_requests.get = mock_get + + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + + health_call_count = 0 + with patch(f"{CTRL}.time") as mock_time: + mock_time.monotonic = time.monotonic + mock_time.sleep = MagicMock() + ctrl.scale_down(1) + + assert len(ctrl.pairs) == 0 + assert health_call_count > 1 + + @patch(f"{CTRL}.requests") + def test_drain_skipped_when_timeout_zero(self, mock_requests, config): + config.num_pairs = 1 + config.drain_timeout = 0 + _setup_mock_requests(mock_requests) + get_count = 0 + + def counting_get(url, **kwargs): + nonlocal get_count + get_count += 1 + return _mock_health_response(active_sessions=5) + + mock_requests.get = counting_get + + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + + pre_get_count = get_count + ctrl.scale_down(1) + drain_gets = get_count - pre_get_count + assert drain_gets == 0 + + +class TestHealthMonitor: + @patch(f"{CTRL}.requests") + def test_health_monitor_starts_and_stops(self, mock_requests, config): + config.num_pairs = 0 + config.health_poll_interval = 0.1 + _setup_mock_requests(mock_requests) + + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + assert ctrl._health_thread is not None + assert ctrl._health_thread.is_alive() + + ctrl.destroy() + assert ctrl._health_thread is None + + @patch(f"{CTRL}.requests") + def test_health_monitor_disabled_when_interval_zero(self, mock_requests, config): + config.num_pairs = 0 + config.health_poll_interval = 0 + _setup_mock_requests(mock_requests) + + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + ctrl.initialize() + assert ctrl._health_thread is None + + ctrl.destroy() diff --git a/tests/experimental/agent_service/test_guard.py b/tests/experimental/agent_service/test_guard.py new file mode 100644 index 0000000000..e5a82deac3 --- /dev/null +++ b/tests/experimental/agent_service/test_guard.py @@ -0,0 +1,73 @@ +"""Unit tests for Agent Service Guard (pure pass-through). + +Tests that the base guard routes are available on the agent guard app. +The agent_blueprint has been removed in v2 — all orchestration logic +now lives in AgentServiceController. + +Test structure mirrors ``tests/experimental/inference_service/test_guard.py``. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from areal.experimental.agent_service.guard import app as guard_module +from areal.experimental.agent_service.guard.app import app + +GUARD_APP = "areal.infra.rpc.guard.app" + + +@pytest.fixture(autouse=True) +def _reset_guard_globals(): + """Reset all guard state between tests.""" + guard_module._state.allocated_ports = set() + guard_module._state.forked_children = [] + guard_module._state.forked_children_map = {} + guard_module._state.server_host = "10.0.0.1" + guard_module._state.experiment_name = "test-exp" + guard_module._state.trial_name = "test-trial" + guard_module._state.fileroot = None + yield + guard_module._state.allocated_ports = set() + guard_module._state.forked_children = [] + guard_module._state.forked_children_map = {} + + +@pytest.fixture() +def client(): + app.config["TESTING"] = True + with app.test_client() as c: + yield c + + +class TestHealth: + def test_health_returns_200(self, client): + resp = client.get("/health") + assert resp.status_code == 200 + data = resp.get_json() + assert data["status"] == "healthy" + assert data["forked_children"] == 0 + + def test_health_counts_forked_children(self, client): + guard_module._state.forked_children = [MagicMock(), MagicMock()] + resp = client.get("/health") + data = resp.get_json() + assert data["forked_children"] == 2 + + +class TestAllocPorts: + @patch(f"{GUARD_APP}.find_free_ports") + def test_alloc_ports_success(self, mock_find, client): + mock_find.return_value = [9001, 9002] + resp = client.post("/alloc_ports", json={"count": 2}) + assert resp.status_code == 200 + data = resp.get_json() + assert data["status"] == "success" + assert data["ports"] == [9001, 9002] + assert guard_module._state.allocated_ports == {9001, 9002} + + def test_alloc_ports_missing_count(self, client): + resp = client.post("/alloc_ports", json={}) + assert resp.status_code == 400 diff --git a/tests/experimental/agent_service/test_integration.py b/tests/experimental/agent_service/test_integration.py index 7c3dc9ba19..f815c5ca25 100644 --- a/tests/experimental/agent_service/test_integration.py +++ b/tests/experimental/agent_service/test_integration.py @@ -10,11 +10,14 @@ import pytest -from areal.experimental.agent_service.auth import DEFAULT_ADMIN_KEY, admin_headers +from areal.experimental.agent_service.auth import DEFAULT_ADMIN_API_KEY, admin_headers from areal.experimental.agent_service.data_proxy.app import create_data_proxy_app +from areal.experimental.agent_service.data_proxy.config import DataProxyConfig from areal.experimental.agent_service.gateway.app import create_gateway_app from areal.experimental.agent_service.gateway.bridge import OpenResponsesBridge +from areal.experimental.agent_service.gateway.config import GatewayConfig from areal.experimental.agent_service.router.app import create_router_app +from areal.experimental.agent_service.router.config import RouterConfig from areal.experimental.agent_service.types import ( AgentRequest, AgentResponse, @@ -24,7 +27,7 @@ httpx = pytest.importorskip("httpx") -_AUTH = admin_headers(DEFAULT_ADMIN_KEY) +_AUTH = admin_headers(DEFAULT_ADMIN_API_KEY) class _EchoAgent: @@ -88,7 +91,7 @@ async def test_data_proxy_manages_history(self): worker_transport = httpx.ASGITransport(app=worker_app) # Create DataProxy pointing to worker - proxy_app = create_data_proxy_app(worker_addr="http://worker") + proxy_app = create_data_proxy_app(DataProxyConfig(worker_addr="http://worker")) # Patch DataProxy's httpx client to use worker's ASGITransport original_post = httpx.AsyncClient.post @@ -132,7 +135,7 @@ async def patched_post(self, url, **kwargs): async def test_close_session_clears_history(self): worker_app = _make_worker_app(_EchoAgent) worker_transport = httpx.ASGITransport(app=worker_app) - proxy_app = create_data_proxy_app(worker_addr="http://worker") + proxy_app = create_data_proxy_app(DataProxyConfig(worker_addr="http://worker")) original_post = httpx.AsyncClient.post @@ -163,7 +166,9 @@ async def patched_post(self, url, **kwargs): class TestRouterIntegration: @pytest.mark.asyncio async def test_register_and_route(self): - router_app = create_router_app(admin_key=DEFAULT_ADMIN_KEY) + router_app = create_router_app( + RouterConfig(admin_api_key=DEFAULT_ADMIN_API_KEY) + ) transport = httpx.ASGITransport(app=router_app) async with httpx.AsyncClient( @@ -190,7 +195,7 @@ class TestToolCallFlow: async def test_tool_events_through_proxy(self): worker_app = _make_worker_app(_ToolAgent) worker_transport = httpx.ASGITransport(app=worker_app) - proxy_app = create_data_proxy_app(worker_addr="http://worker") + proxy_app = create_data_proxy_app(DataProxyConfig(worker_addr="http://worker")) original_post = httpx.AsyncClient.post @@ -231,7 +236,7 @@ async def patched_post(self, url, **kwargs): class TestGatewayHealth: @pytest.mark.asyncio async def test_health(self): - app = create_gateway_app(router_addr="http://fake-router") + app = create_gateway_app(GatewayConfig(router_addr="http://fake-router")) transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient( transport=transport, base_url="http://gw"