Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions ax_cli/commands/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,18 @@ def _validate_runtime_registration(runtime_type: str, exec_cmd: str | None) -> N
raise ValueError("This runtime does not accept --exec.")


def _normalize_timeout_seconds(timeout_seconds: int | None) -> int | None:
if timeout_seconds is None:
return None
try:
normalized = int(timeout_seconds)
except (TypeError, ValueError) as exc:
raise ValueError("Timeout must be a whole number of seconds.") from exc
if normalized < 1:
raise ValueError("Timeout must be at least 1 second.")
return normalized


def _register_managed_agent(
*,
name: str,
Expand All @@ -215,6 +227,7 @@ def _register_managed_agent(
audience: str = "both",
description: str | None = None,
model: str | None = None,
timeout_seconds: int | None = None,
start: bool = True,
) -> dict:
name = name.strip()
Expand All @@ -241,6 +254,7 @@ def _register_managed_agent(
if template_effective_id == "ollama" and not normalized_ollama_model:
normalized_ollama_model = str(ollama_setup_status().get("recommended_model") or "").strip() or None
_validate_runtime_registration(runtime_type, exec_cmd)
timeout_effective = _normalize_timeout_seconds(timeout_seconds)

session = _load_gateway_session_or_exit()
selected_space = space_id or session.get("space_id")
Expand Down Expand Up @@ -289,6 +303,7 @@ def _register_managed_agent(
"exec_command": exec_cmd,
"workdir": workdir,
"ollama_model": normalized_ollama_model,
"timeout_seconds": timeout_effective,
"token_file": str(token_file),
"desired_state": "running" if start else "stopped",
"effective_state": "stopped",
Expand Down Expand Up @@ -334,6 +349,7 @@ def _update_managed_agent(
ollama_model: str | object = _UNSET,
description: str | None = None,
model: str | None = None,
timeout_seconds: int | object = _UNSET,
desired_state: str | None = None,
) -> dict:
name = name.strip()
Expand Down Expand Up @@ -399,6 +415,8 @@ def _update_managed_agent(
if normalized_desired not in {"running", "stopped"}:
raise ValueError("Desired state must be running or stopped.")
entry["desired_state"] = normalized_desired
if timeout_seconds is not _UNSET:
entry["timeout_seconds"] = _normalize_timeout_seconds(timeout_seconds) # type: ignore[arg-type]

session = _load_gateway_session_or_exit()
if description or model:
Expand Down Expand Up @@ -444,6 +462,7 @@ def _update_managed_agent(
workdir=workdir_effective,
exec_command=exec_effective,
desired_state=entry.get("desired_state"),
timeout_seconds=entry.get("timeout_seconds"),
)
return annotate_runtime_health(entry, registry=registry)

Expand Down Expand Up @@ -3043,6 +3062,7 @@ def do_POST(self) -> None: # noqa: N802
audience=str(body.get("audience") or "both"),
description=str(body.get("description") or "").strip() or None,
model=str(body.get("model") or "").strip() or None,
timeout_seconds=body.get("timeout_seconds", body.get("timeout")),
start=bool(body.get("start", True)),
)
_write_json_response(self, payload, status=HTTPStatus.CREATED)
Expand Down Expand Up @@ -3111,6 +3131,9 @@ def do_PUT(self) -> None: # noqa: N802
ollama_model=str(body.get("ollama_model") or "") if "ollama_model" in body else _UNSET,
description=str(body.get("description") or "").strip() or None,
model=str(body.get("model") or "").strip() or None,
timeout_seconds=body.get("timeout_seconds", body.get("timeout"))
if "timeout_seconds" in body or "timeout" in body
else _UNSET,
desired_state=str(body.get("desired_state") or "").strip() or None,
)
_write_json_response(self, payload)
Expand Down Expand Up @@ -3215,9 +3238,15 @@ def _render_agent_detail(entry: dict, *, activity: list[dict]) -> Group:
overview.add_row(
"Phase", str(entry.get("current_status") or "-"), "Activity", str(entry.get("current_activity") or "-")
)
overview.add_row("Tool", str(entry.get("current_tool") or "-"), "Adapter", str(entry.get("runtime_type") or "-"))
overview.add_row(
"Cred Source", str(entry.get("credential_source") or "-"), "Space", str(entry.get("space_id") or "-")
"Tool",
str(entry.get("current_tool") or "-"),
"Timeout",
f"{entry.get('timeout_seconds')}s" if entry.get("timeout_seconds") else "-",
)
overview.add_row("Adapter", str(entry.get("runtime_type") or "-"), "Space", str(entry.get("space_id") or "-"))
overview.add_row(
"Cred Source", str(entry.get("credential_source") or "-"), "Token", str(entry.get("token_file") or "-")
)
overview.add_row(
"Agent ID", str(entry.get("agent_id") or "-"), "Last Reply", str(entry.get("last_reply_preview") or "-")
Expand Down Expand Up @@ -3905,6 +3934,9 @@ def add_agent(
audience: str = typer.Option("both", "--audience", help="Minted PAT audience"),
description: str = typer.Option(None, "--description", help="Create/update description"),
model: str = typer.Option(None, "--model", help="Create/update model"),
timeout_seconds: int = typer.Option(
None, "--timeout", "--timeout-seconds", help="Max seconds a runtime may process one message"
),
start: bool = typer.Option(True, "--start/--no-start", help="Desired running state after registration"),
as_json: bool = JSON_OPTION,
):
Expand All @@ -3922,6 +3954,7 @@ def add_agent(
audience=audience,
description=description,
model=model,
timeout_seconds=timeout_seconds,
start=start,
)
except (ValueError, LookupError) as exc:
Expand All @@ -3937,6 +3970,8 @@ def add_agent(
if entry.get("asset_type_label"):
err_console.print(f" asset = {entry['asset_type_label']}")
err_console.print(f" desired_state = {entry['desired_state']}")
if entry.get("timeout_seconds"):
err_console.print(f" timeout = {entry.get('timeout_seconds')}s")
err_console.print(f" token_file = {entry['token_file']}")


Expand All @@ -3954,6 +3989,9 @@ def update_agent(
ollama_model: str = typer.Option(None, "--ollama-model", help="Ollama model override for the Ollama template"),
description: str = typer.Option(None, "--description", help="Update platform agent description"),
model: str = typer.Option(None, "--model", help="Update platform agent model"),
timeout_seconds: int = typer.Option(
None, "--timeout", "--timeout-seconds", help="Max seconds a runtime may process one message"
),
desired_state: str = typer.Option(None, "--desired-state", help="running | stopped"),
as_json: bool = JSON_OPTION,
):
Expand All @@ -3968,6 +4006,7 @@ def update_agent(
ollama_model=ollama_model if ollama_model is not None else _UNSET,
description=description,
model=model,
timeout_seconds=timeout_seconds if timeout_seconds is not None else _UNSET,
desired_state=desired_state,
)
except (LookupError, ValueError) as exc:
Expand All @@ -3980,6 +4019,8 @@ def update_agent(
err_console.print(f"[green]Managed agent updated:[/green] @{name}")
err_console.print(f" type = {entry.get('template_label') or entry.get('runtime_type')}")
err_console.print(f" desired_state = {entry.get('desired_state')}")
if entry.get("timeout_seconds"):
err_console.print(f" timeout = {entry.get('timeout_seconds')}s")


@agents_app.command("list")
Expand Down
65 changes: 59 additions & 6 deletions ax_cli/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
DEFAULT_QUEUE_SIZE = 50
DEFAULT_ACTIVITY_LIMIT = 10
DEFAULT_HANDLER_TIMEOUT_SECONDS = 900
MIN_HANDLER_TIMEOUT_SECONDS = 1
SSE_IDLE_TIMEOUT_SECONDS = 45.0
RUNTIME_STALE_AFTER_SECONDS = 75.0
GATEWAY_EVENT_PREFIX = "AX_GATEWAY_EVENT "
Expand All @@ -63,6 +64,18 @@
"AX_USER_ENV",
"AX_USER_TOKEN",
}


class GatewayRuntimeTimeoutError(TimeoutError):
"""Raised when a managed runtime exceeds its per-message timeout."""

def __init__(self, timeout_seconds: int, *, runtime_type: str | None = None) -> None:
self.timeout_seconds = timeout_seconds
self.runtime_type = runtime_type
label = f" {runtime_type}" if runtime_type else ""
super().__init__(f"Gateway{label} runtime timed out after {timeout_seconds}s.")


_ACTIVITY_LOCK = threading.Lock()
_GATEWAY_PROCESS_RE = re.compile(
r"(?:uv\s+run\s+ax\s+gateway\s+run|(?:^|\s).+?/ax(?:ctl)?\s+gateway\s+run(?:\s|$)|-m\s+ax_cli\.main\s+gateway\s+run(?:\s|$))"
Expand Down Expand Up @@ -2749,13 +2762,26 @@ def _hash_tool_arguments(arguments: dict[str, Any] | None) -> str | None:
return hashlib.sha256(encoded).hexdigest()


def runtime_timeout_seconds(entry: dict[str, Any]) -> int:
"""Resolve a safe per-message runtime timeout for Gateway-managed agents."""
raw_value = entry.get("timeout_seconds")
if raw_value is None:
raw_value = entry.get("timeout")
try:
timeout = int(raw_value) if raw_value is not None else DEFAULT_HANDLER_TIMEOUT_SECONDS
except (TypeError, ValueError):
timeout = DEFAULT_HANDLER_TIMEOUT_SECONDS
return max(MIN_HANDLER_TIMEOUT_SECONDS, timeout)


def _run_exec_handler(
command: str,
prompt: str,
entry: dict[str, Any],
*,
message_id: str | None = None,
space_id: str | None = None,
timeout_seconds: int | None = None,
on_event: Callable[[dict[str, Any]], None] | None = None,
) -> str:
argv = [*shlex.split(command), prompt]
Expand Down Expand Up @@ -2805,9 +2831,10 @@ def _consume_stderr() -> None:
stdout_thread.start()
stderr_thread.start()

timeout_seconds = max(MIN_HANDLER_TIMEOUT_SECONDS, int(timeout_seconds or runtime_timeout_seconds(entry)))
timed_out = False
try:
process.wait(timeout=DEFAULT_HANDLER_TIMEOUT_SECONDS)
process.wait(timeout=timeout_seconds)
except subprocess.TimeoutExpired:
timed_out = True
process.kill()
Expand All @@ -2820,7 +2847,7 @@ def _consume_stderr() -> None:
process.stderr.close()

if timed_out:
return f"(handler timed out after {DEFAULT_HANDLER_TIMEOUT_SECONDS}s)"
raise GatewayRuntimeTimeoutError(timeout_seconds, runtime_type="exec")

output = "".join(stdout_lines).strip()
stderr = "".join(stderr_lines).strip()
Expand Down Expand Up @@ -3782,9 +3809,7 @@ def _handle_sentinel_cli_prompt(self, prompt: str, *, message_id: str, data: dic
new_session_id: str | None = None
last_activity_time = time.time()
exit_reason = "done"
timeout_seconds = int(
self.entry.get("timeout_seconds") or self.entry.get("timeout") or DEFAULT_HANDLER_TIMEOUT_SECONDS
)
timeout_seconds = runtime_timeout_seconds(self.entry)
finished = threading.Event()

def _consume_stderr() -> None:
Expand Down Expand Up @@ -3896,7 +3921,7 @@ def _timeout_watchdog() -> None:
final = accumulated_text.strip()
stderr = "".join(stderr_lines).strip()
if exit_reason == "timeout":
return final or f"Timed out after {timeout_seconds}s with no output."
raise GatewayRuntimeTimeoutError(timeout_seconds, runtime_type=runtime_name)
if exit_reason == "crashed":
if final:
return final
Expand Down Expand Up @@ -3973,6 +3998,7 @@ def _handle_prompt(self, prompt: str, *, message_id: str, data: dict[str, Any] |
self.entry,
message_id=message_id or None,
space_id=self.space_id,
timeout_seconds=runtime_timeout_seconds(self.entry),
on_event=lambda event: self._handle_exec_event(event, message_id=message_id),
)
raise ValueError(f"Unsupported runtime_type: {runtime_type}")
Expand Down Expand Up @@ -4072,6 +4098,33 @@ def _worker_loop(self) -> None:
last_work_completed_at=_now_iso(),
backlog_depth=self._queue.qsize(),
)
except GatewayRuntimeTimeoutError as exc:
activity = f"Timed out after {exc.timeout_seconds}s"
self._update_state(
current_status="error",
current_activity=activity,
current_tool=None,
current_tool_call_id=None,
last_error=str(exc)[:400],
backlog_depth=self._queue.qsize(),
)
if message_id:
self._publish_processing_status(
message_id,
"error",
activity=activity,
reason="runtime_timeout",
error_message=str(exc)[:400],
detail={"timeout_seconds": exc.timeout_seconds, "runtime_type": exc.runtime_type},
)
record_gateway_activity(
"runtime_timeout",
entry=self.entry,
message_id=message_id or None,
timeout_seconds=exc.timeout_seconds,
runtime_type=exc.runtime_type,
)
self._log(f"worker timeout: {exc}")
except Exception as exc:
self._update_state(
current_status="error",
Expand Down
9 changes: 7 additions & 2 deletions docs/gateway-agent-runtimes.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,17 @@ Use command bridges for simple adapters, demos, and smoke tests.
ax gateway agents add echo-bot --type echo
ax gateway agents add probe \
--type exec \
--exec "python3 examples/gateway_probe/probe_bridge.py"
--exec "python3 examples/gateway_probe/probe_bridge.py" \
--timeout 120
```

Command bridges are valuable for probes and simple integrations. They are not
the preferred shape for coding sentinels because a per-message command loses
important in-process state unless the bridge explicitly persists and resumes it.
Use `--timeout` / `--timeout-seconds` to cap per-message runtime work. On
timeout, Gateway publishes a terminal `error` processing signal with
`reason=runtime_timeout` and does not mark the message completed or send a fake
success reply.

## Signal Contract

Expand All @@ -159,7 +164,7 @@ Minimum signals:
when available.
- `completed`: the runtime finished and either replied or explicitly queued the
work.
- `error`: the runtime failed and the operator should inspect logs.
- `error`: the runtime failed or timed out and the operator should inspect logs.

Hermes sentinels should preserve the old behavior from `claude_agent_v2.py`:
tool callbacks update the activity bubble with real work, such as reading a
Expand Down
Loading