diff --git a/ax_cli/commands/gateway.py b/ax_cli/commands/gateway.py index 90b82f8..48948c6 100644 --- a/ax_cli/commands/gateway.py +++ b/ax_cli/commands/gateway.py @@ -203,6 +203,18 @@ def _validate_runtime_registration(runtime_type: str, exec_cmd: str | None) -> N raise ValueError("This runtime does not accept --exec.") +def _normalize_timeout_seconds(timeout_seconds: int | None) -> int | None: + if timeout_seconds is None: + return None + try: + normalized = int(timeout_seconds) + except (TypeError, ValueError) as exc: + raise ValueError("Timeout must be a whole number of seconds.") from exc + if normalized < 1: + raise ValueError("Timeout must be at least 1 second.") + return normalized + + def _register_managed_agent( *, name: str, @@ -215,6 +227,7 @@ def _register_managed_agent( audience: str = "both", description: str | None = None, model: str | None = None, + timeout_seconds: int | None = None, start: bool = True, ) -> dict: name = name.strip() @@ -241,6 +254,7 @@ def _register_managed_agent( if template_effective_id == "ollama" and not normalized_ollama_model: normalized_ollama_model = str(ollama_setup_status().get("recommended_model") or "").strip() or None _validate_runtime_registration(runtime_type, exec_cmd) + timeout_effective = _normalize_timeout_seconds(timeout_seconds) session = _load_gateway_session_or_exit() selected_space = space_id or session.get("space_id") @@ -289,6 +303,7 @@ def _register_managed_agent( "exec_command": exec_cmd, "workdir": workdir, "ollama_model": normalized_ollama_model, + "timeout_seconds": timeout_effective, "token_file": str(token_file), "desired_state": "running" if start else "stopped", "effective_state": "stopped", @@ -334,6 +349,7 @@ def _update_managed_agent( ollama_model: str | object = _UNSET, description: str | None = None, model: str | None = None, + timeout_seconds: int | object = _UNSET, desired_state: str | None = None, ) -> dict: name = name.strip() @@ -399,6 +415,8 @@ def _update_managed_agent( if 
normalized_desired not in {"running", "stopped"}: raise ValueError("Desired state must be running or stopped.") entry["desired_state"] = normalized_desired + if timeout_seconds is not _UNSET: + entry["timeout_seconds"] = _normalize_timeout_seconds(timeout_seconds) # type: ignore[arg-type] session = _load_gateway_session_or_exit() if description or model: @@ -444,6 +462,7 @@ def _update_managed_agent( workdir=workdir_effective, exec_command=exec_effective, desired_state=entry.get("desired_state"), + timeout_seconds=entry.get("timeout_seconds"), ) return annotate_runtime_health(entry, registry=registry) @@ -3043,6 +3062,7 @@ def do_POST(self) -> None: # noqa: N802 audience=str(body.get("audience") or "both"), description=str(body.get("description") or "").strip() or None, model=str(body.get("model") or "").strip() or None, + timeout_seconds=body.get("timeout_seconds", body.get("timeout")), start=bool(body.get("start", True)), ) _write_json_response(self, payload, status=HTTPStatus.CREATED) @@ -3111,6 +3131,9 @@ def do_PUT(self) -> None: # noqa: N802 ollama_model=str(body.get("ollama_model") or "") if "ollama_model" in body else _UNSET, description=str(body.get("description") or "").strip() or None, model=str(body.get("model") or "").strip() or None, + timeout_seconds=body.get("timeout_seconds", body.get("timeout")) + if "timeout_seconds" in body or "timeout" in body + else _UNSET, desired_state=str(body.get("desired_state") or "").strip() or None, ) _write_json_response(self, payload) @@ -3215,9 +3238,15 @@ def _render_agent_detail(entry: dict, *, activity: list[dict]) -> Group: overview.add_row( "Phase", str(entry.get("current_status") or "-"), "Activity", str(entry.get("current_activity") or "-") ) - overview.add_row("Tool", str(entry.get("current_tool") or "-"), "Adapter", str(entry.get("runtime_type") or "-")) overview.add_row( - "Cred Source", str(entry.get("credential_source") or "-"), "Space", str(entry.get("space_id") or "-") + "Tool", + 
str(entry.get("current_tool") or "-"), + "Timeout", + f"{entry.get('timeout_seconds')}s" if entry.get("timeout_seconds") else "-", + ) + overview.add_row("Adapter", str(entry.get("runtime_type") or "-"), "Space", str(entry.get("space_id") or "-")) + overview.add_row( + "Cred Source", str(entry.get("credential_source") or "-"), "Token", str(entry.get("token_file") or "-") ) overview.add_row( "Agent ID", str(entry.get("agent_id") or "-"), "Last Reply", str(entry.get("last_reply_preview") or "-") @@ -3905,6 +3934,9 @@ def add_agent( audience: str = typer.Option("both", "--audience", help="Minted PAT audience"), description: str = typer.Option(None, "--description", help="Create/update description"), model: str = typer.Option(None, "--model", help="Create/update model"), + timeout_seconds: int = typer.Option( + None, "--timeout", "--timeout-seconds", help="Max seconds a runtime may process one message" + ), start: bool = typer.Option(True, "--start/--no-start", help="Desired running state after registration"), as_json: bool = JSON_OPTION, ): @@ -3922,6 +3954,7 @@ def add_agent( audience=audience, description=description, model=model, + timeout_seconds=timeout_seconds, start=start, ) except (ValueError, LookupError) as exc: @@ -3937,6 +3970,8 @@ def add_agent( if entry.get("asset_type_label"): err_console.print(f" asset = {entry['asset_type_label']}") err_console.print(f" desired_state = {entry['desired_state']}") + if entry.get("timeout_seconds"): + err_console.print(f" timeout = {entry.get('timeout_seconds')}s") err_console.print(f" token_file = {entry['token_file']}") @@ -3954,6 +3989,9 @@ def update_agent( ollama_model: str = typer.Option(None, "--ollama-model", help="Ollama model override for the Ollama template"), description: str = typer.Option(None, "--description", help="Update platform agent description"), model: str = typer.Option(None, "--model", help="Update platform agent model"), + timeout_seconds: int = typer.Option( + None, "--timeout", 
"--timeout-seconds", help="Max seconds a runtime may process one message" + ), desired_state: str = typer.Option(None, "--desired-state", help="running | stopped"), as_json: bool = JSON_OPTION, ): @@ -3968,6 +4006,7 @@ def update_agent( ollama_model=ollama_model if ollama_model is not None else _UNSET, description=description, model=model, + timeout_seconds=timeout_seconds if timeout_seconds is not None else _UNSET, desired_state=desired_state, ) except (LookupError, ValueError) as exc: @@ -3980,6 +4019,8 @@ def update_agent( err_console.print(f"[green]Managed agent updated:[/green] @{name}") err_console.print(f" type = {entry.get('template_label') or entry.get('runtime_type')}") err_console.print(f" desired_state = {entry.get('desired_state')}") + if entry.get("timeout_seconds"): + err_console.print(f" timeout = {entry.get('timeout_seconds')}s") @agents_app.command("list") diff --git a/ax_cli/gateway.py b/ax_cli/gateway.py index 6b7da5f..cd8bde9 100644 --- a/ax_cli/gateway.py +++ b/ax_cli/gateway.py @@ -46,6 +46,7 @@ DEFAULT_QUEUE_SIZE = 50 DEFAULT_ACTIVITY_LIMIT = 10 DEFAULT_HANDLER_TIMEOUT_SECONDS = 900 +MIN_HANDLER_TIMEOUT_SECONDS = 1 SSE_IDLE_TIMEOUT_SECONDS = 45.0 RUNTIME_STALE_AFTER_SECONDS = 75.0 GATEWAY_EVENT_PREFIX = "AX_GATEWAY_EVENT " @@ -63,6 +64,18 @@ "AX_USER_ENV", "AX_USER_TOKEN", } + + +class GatewayRuntimeTimeoutError(TimeoutError): + """Raised when a managed runtime exceeds its per-message timeout.""" + + def __init__(self, timeout_seconds: int, *, runtime_type: str | None = None) -> None: + self.timeout_seconds = timeout_seconds + self.runtime_type = runtime_type + label = f" {runtime_type}" if runtime_type else "" + super().__init__(f"Gateway{label} runtime timed out after {timeout_seconds}s.") + + _ACTIVITY_LOCK = threading.Lock() _GATEWAY_PROCESS_RE = re.compile( r"(?:uv\s+run\s+ax\s+gateway\s+run|(?:^|\s).+?/ax(?:ctl)?\s+gateway\s+run(?:\s|$)|-m\s+ax_cli\.main\s+gateway\s+run(?:\s|$))" @@ -2749,6 +2762,18 @@ def 
def runtime_timeout_seconds(entry: dict[str, Any]) -> int:
    """Resolve a safe per-message runtime timeout for Gateway-managed agents.

    Args:
        entry: Managed-agent registry entry. ``"timeout_seconds"`` is read
            first; ``"timeout"`` is used only when the former is missing or
            ``None``.

    Returns:
        The configured timeout as an ``int``, never below
        ``MIN_HANDLER_TIMEOUT_SECONDS``. ``DEFAULT_HANDLER_TIMEOUT_SECONDS``
        is used when no value is configured or the stored value cannot be
        converted to an integer.
    """
    raw_value = entry.get("timeout_seconds")
    if raw_value is None:
        raw_value = entry.get("timeout")
    if raw_value is None:
        return max(MIN_HANDLER_TIMEOUT_SECONDS, DEFAULT_HANDLER_TIMEOUT_SECONDS)
    try:
        timeout = int(raw_value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) -- json.loads accepts "Infinity" and
        # out-of-range literals, so a hand-edited registry can contain it.
        timeout = DEFAULT_HANDLER_TIMEOUT_SECONDS
    return max(MIN_HANDLER_TIMEOUT_SECONDS, timeout)
def _consume_stderr() -> None: @@ -3896,7 +3921,7 @@ def _timeout_watchdog() -> None: final = accumulated_text.strip() stderr = "".join(stderr_lines).strip() if exit_reason == "timeout": - return final or f"Timed out after {timeout_seconds}s with no output." + raise GatewayRuntimeTimeoutError(timeout_seconds, runtime_type=runtime_name) if exit_reason == "crashed": if final: return final @@ -3973,6 +3998,7 @@ def _handle_prompt(self, prompt: str, *, message_id: str, data: dict[str, Any] | self.entry, message_id=message_id or None, space_id=self.space_id, + timeout_seconds=runtime_timeout_seconds(self.entry), on_event=lambda event: self._handle_exec_event(event, message_id=message_id), ) raise ValueError(f"Unsupported runtime_type: {runtime_type}") @@ -4072,6 +4098,33 @@ def _worker_loop(self) -> None: last_work_completed_at=_now_iso(), backlog_depth=self._queue.qsize(), ) + except GatewayRuntimeTimeoutError as exc: + activity = f"Timed out after {exc.timeout_seconds}s" + self._update_state( + current_status="error", + current_activity=activity, + current_tool=None, + current_tool_call_id=None, + last_error=str(exc)[:400], + backlog_depth=self._queue.qsize(), + ) + if message_id: + self._publish_processing_status( + message_id, + "error", + activity=activity, + reason="runtime_timeout", + error_message=str(exc)[:400], + detail={"timeout_seconds": exc.timeout_seconds, "runtime_type": exc.runtime_type}, + ) + record_gateway_activity( + "runtime_timeout", + entry=self.entry, + message_id=message_id or None, + timeout_seconds=exc.timeout_seconds, + runtime_type=exc.runtime_type, + ) + self._log(f"worker timeout: {exc}") except Exception as exc: self._update_state( current_status="error", diff --git a/docs/gateway-agent-runtimes.md b/docs/gateway-agent-runtimes.md index dc419d6..5e06ff6 100644 --- a/docs/gateway-agent-runtimes.md +++ b/docs/gateway-agent-runtimes.md @@ -139,12 +139,17 @@ Use command bridges for simple adapters, demos, and smoke tests. 
ax gateway agents add echo-bot --type echo ax gateway agents add probe \ --type exec \ - --exec "python3 examples/gateway_probe/probe_bridge.py" + --exec "python3 examples/gateway_probe/probe_bridge.py" \ + --timeout 120 ``` Command bridges are valuable for probes and simple integrations. They are not the preferred shape for coding sentinels because a per-message command loses important in-process state unless the bridge explicitly persists and resumes it. +Use `--timeout` / `--timeout-seconds` to cap per-message runtime work. On +timeout, Gateway publishes a terminal `error` processing signal with +`reason=runtime_timeout` and does not mark the message completed or send a fake +success reply. ## Signal Contract @@ -159,7 +164,7 @@ Minimum signals: when available. - `completed`: the runtime finished and either replied or explicitly queued the work. -- `error`: the runtime failed and the operator should inspect logs. +- `error`: the runtime failed or timed out and the operator should inspect logs. 
def test_managed_exec_runtime_marks_message_timed_out(tmp_path, monkeypatch):
    """An exec runtime that outlives its timeout reports ``runtime_timeout`` and sends no reply."""
    # Local import: the exec handler tokenizes the command with shlex.split, so
    # the command must be built with matching quoting.
    import shlex

    config_dir = tmp_path / "config"
    config_dir.mkdir()
    monkeypatch.setenv("AX_CONFIG_DIR", str(config_dir))
    token_file = tmp_path / "token"
    token_file.write_text("axp_a_agent.secret")
    script = tmp_path / "slow_bridge.py"
    script.write_text(
        """
import time

time.sleep(5)
print("too late", flush=True)
""".strip()
    )
    payload = {
        "id": "msg-1",
        "content": "@exec-bot run slow job",
        "author": {"id": "user-1", "name": "madtank", "type": "user"},
        "mentions": ["exec-bot"],
    }
    shared = _SharedRuntimeClient(payload)

    runtime = gateway_core.ManagedAgentRuntime(
        {
            "name": "exec-bot",
            "agent_id": "agent-1",
            "space_id": "space-1",
            "base_url": "https://paxai.app",
            "runtime_type": "exec",
            # Quote both parts so interpreter/tmp paths containing spaces survive shlex.split.
            "exec_command": f"{shlex.quote(sys.executable)} {shlex.quote(str(script))}",
            "timeout_seconds": 1,
            "token_file": str(token_file),
        },
        client_factory=lambda **kwargs: shared,
    )

    runtime.start()
    try:
        # Monotonic clock: the poll window must not shrink or grow with wall-clock adjustments.
        deadline = time.monotonic() + 4.0
        while time.monotonic() < deadline and not any(
            row.get("reason") == "runtime_timeout" for row in shared.processing
        ):
            time.sleep(0.05)
        snapshot = runtime.snapshot()
    finally:
        # Always stop the runtime so a failure here cannot leak it into later tests.
        runtime.stop()

    assert not shared.sent
    assert [row["status"] for row in shared.processing] == ["started", "processing", "error"]
    timeout_status = shared.processing[-1]
    assert timeout_status["activity"] == "Timed out after 1s"
    assert timeout_status["reason"] == "runtime_timeout"
    assert timeout_status["detail"] == {"timeout_seconds": 1, "runtime_type": "exec"}
    assert "timed out after 1s" in timeout_status["error_message"]
    assert snapshot["current_status"] == "error"
    assert snapshot["current_activity"] == "Timed out after 1s"
    recent = gateway_core.load_recent_gateway_activity()
    events = [row["event"] for row in recent]
    assert "runtime_timeout" in events
    assert "reply_sent" not in events
test_gateway_agents_update_changes_template_and_workdir(monkeypatch, tmp_pat assert payload["template_id"] == "ollama" assert payload["runtime_type"] == "exec" assert payload["workdir"] == str(tmp_path) + assert payload["timeout_seconds"] == 120 stored = gateway_core.load_gateway_registry()["agents"][0] assert stored["template_id"] == "ollama" assert stored["workdir"] == str(tmp_path) + assert stored["timeout_seconds"] == 120 registry_after = gateway_core.load_gateway_registry() binding = registry_after["bindings"][0] assert binding["launch_spec"]["runtime_type"] == "exec"