From 93b22c3c73d56d57c675dc683b8c0a9c6f0f8f9f Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 24 Apr 2026 13:53:58 +0000 Subject: [PATCH 1/3] Add messaging gateway and notify tool Co-authored-by: OpenAI Codex --- agent/config.py | 3 + agent/core/agent_loop.py | 5 + agent/core/session.py | 96 +++++++++- agent/core/tools.py | 7 + agent/main.py | 11 ++ agent/messaging/__init__.py | 15 ++ agent/messaging/base.py | 27 +++ agent/messaging/gateway.py | 159 ++++++++++++++++ agent/messaging/models.py | 114 ++++++++++++ agent/messaging/slack.py | 96 ++++++++++ agent/prompts/system_prompt_v3.yaml | 1 + agent/tools/notify_tool.py | 108 +++++++++++ backend/main.py | 3 + backend/models.py | 9 +- backend/routes/agent.py | 22 ++- backend/session_manager.py | 44 +++++ configs/main_agent_config.json | 5 + pyproject.toml | 1 + tests/unit/test_messaging.py | 277 ++++++++++++++++++++++++++++ uv.lock | 16 ++ 20 files changed, 1012 insertions(+), 7 deletions(-) create mode 100644 agent/messaging/__init__.py create mode 100644 agent/messaging/base.py create mode 100644 agent/messaging/gateway.py create mode 100644 agent/messaging/models.py create mode 100644 agent/messaging/slack.py create mode 100644 agent/tools/notify_tool.py create mode 100644 tests/unit/test_messaging.py diff --git a/agent/config.py b/agent/config.py index b7e698ad..6c2a096d 100644 --- a/agent/config.py +++ b/agent/config.py @@ -6,6 +6,8 @@ from dotenv import load_dotenv +from agent.messaging.models import MessagingConfig + # Project root: two levels up from this file (agent/config.py -> project root) _PROJECT_ROOT = Path(__file__).resolve().parent.parent from fastmcp.mcp_config import ( @@ -42,6 +44,7 @@ class Config(BaseModel): # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off. # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max" reasoning_effort: str | None = "max" + messaging: MessagingConfig = MessagingConfig() def substitute_env_vars(obj: Any) -> Any: diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index c3fd88bc..46ce41a2 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -12,6 +12,7 @@ from litellm.exceptions import ContextWindowExceededError from agent.config import Config +from agent.messaging.gateway import NotificationGateway from agent.core.doom_loop import check_for_doom_loop from agent.core.llm_params import _resolve_llm_params from agent.core.prompt_caching import with_prompt_caching @@ -1204,6 +1205,8 @@ async def submission_loop( hf_token: str | None = None, local_mode: bool = False, stream: bool = True, + notification_gateway: NotificationGateway | None = None, + notification_destinations: list[str] | None = None, ) -> None: """ Main agent loop - processes submissions and dispatches to handlers. @@ -1214,6 +1217,8 @@ async def submission_loop( session = Session( event_queue, config=config, tool_router=tool_router, hf_token=hf_token, local_mode=local_mode, stream=stream, + notification_gateway=notification_gateway, + notification_destinations=notification_destinations, ) if session_holder is not None: session_holder[0] = session diff --git a/agent/core/session.py b/agent/core/session.py index 4b6390d8..1680e8d6 100644 --- a/agent/core/session.py +++ b/agent/core/session.py @@ -12,6 +12,8 @@ from agent.config import Config from agent.context_manager.manager import ContextManager +from agent.messaging.gateway import NotificationGateway +from agent.messaging.models import NotificationRequest logger = logging.getLogger(__name__) @@ -79,13 +81,19 @@ def __init__( hf_token: str | None = None, local_mode: bool = False, stream: bool = True, + notification_gateway: NotificationGateway | None = None, + notification_destinations: list[str] | None = None, + session_id: str | None = None, ): self.hf_token: Optional[str] = hf_token self.tool_router = tool_router self.stream = stream tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else [] + effective_config = config or Config( + model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0", + ) self.context_manager = context_manager or ContextManager( - model_max_tokens=_get_max_tokens_safe(config.model_name), + model_max_tokens=_get_max_tokens_safe(effective_config.model_name), compact_size=0.1, untouched_messages=5, tool_specs=tool_specs, @@ -93,15 +101,15 @@ def __init__( local_mode=local_mode, ) self.event_queue = event_queue - self.session_id = str(uuid.uuid4()) - self.config = config or Config( - model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0", - ) + self.session_id = session_id or str(uuid.uuid4()) + self.config = effective_config self.is_running = True self._cancelled = asyncio.Event() self.pending_approval: Optional[dict[str, Any]] = None self.sandbox = None self._running_job_ids: set[str] = set() # HF job IDs currently executing + self.notification_gateway = notification_gateway + self.notification_destinations = list(notification_destinations or []) # Session trajectory logging self.logged_events: list[dict] = [] @@ -131,6 +139,84 @@ async def send_event(self, event: Event) -> None: "data": event.data, } ) + await self._send_auto_notification(event) + + def set_notification_destinations(self, destinations: list[str]) -> None: + """Replace the session's opted-in auto-notification destinations.""" + deduped: list[str] = [] + seen: set[str] = set() + for destination in destinations: + if destination not in seen: + deduped.append(destination) + seen.add(destination) + self.notification_destinations = deduped + + async def _send_auto_notification(self, event: Event) -> None: + if self.notification_gateway is None: + return + if not self.notification_destinations: + return + if event.event_type not in self.config.messaging.auto_event_types: + return + + requests = self._build_auto_notification_requests(event) + for request in requests: + await self.notification_gateway.enqueue(request) + + def _build_auto_notification_requests( + self, event: Event + ) -> list[NotificationRequest]: + metadata = { + "session_id": self.session_id, + "model": self.config.model_name, + "event_type": event.event_type, + } + + title: str | None = None + message: str | None = None + severity = "info" + data = event.data or {} + if event.event_type == "approval_required": + tools = data.get("tools", []) + tool_names = [] + for tool in tools if isinstance(tools, list) else []: + if isinstance(tool, dict): + tool_name = str(tool.get("tool") or "").strip() + if tool_name and tool_name not in tool_names: + tool_names.append(tool_name) + count = len(tools) if isinstance(tools, list) else 0 + title = "Agent approval required" + message = ( + f"Session {self.session_id} is waiting for approval " + f"for {count} tool call(s)." + ) + if tool_names: + message += " Tools: " + ", ".join(tool_names) + severity = "warning" + elif event.event_type == "error": + title = "Agent error" + error = str(data.get("error") or "Unknown error") + message = f"Session {self.session_id} hit an error.\n{error[:500]}" + severity = "error" + + if message is None: + return [] + + requests: list[NotificationRequest] = [] + for destination in self.notification_destinations: + if not self.config.messaging.can_auto_send(destination): + continue + requests.append( + NotificationRequest( + destination=destination, + title=title, + message=message, + severity=severity, + metadata=metadata, + event_type=event.event_type, + ) + ) + return requests def cancel(self) -> None: """Signal cancellation to the running agent loop.""" diff --git a/agent/core/tools.py b/agent/core/tools.py index 9bbf91d7..f54163cc 100644 --- a/agent/core/tools.py +++ b/agent/core/tools.py @@ -46,6 +46,7 @@ hf_repo_git_handler, ) from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler +from agent.tools.notify_tool import NOTIFY_TOOL_SPEC, notify_handler from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler @@ -324,6 +325,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: parameters=PLAN_TOOL_SPEC["parameters"], handler=plan_tool_handler, ), + ToolSpec( + name=NOTIFY_TOOL_SPEC["name"], + description=NOTIFY_TOOL_SPEC["description"], + parameters=NOTIFY_TOOL_SPEC["parameters"], + handler=notify_handler, + ), ToolSpec( name=HF_JOBS_TOOL_SPEC["name"], description=HF_JOBS_TOOL_SPEC["description"], diff --git a/agent/main.py b/agent/main.py index 4ecbefc5..bf347209 100644 --- a/agent/main.py +++ b/agent/main.py @@ -23,6 +23,7 @@ from agent.config import load_config from agent.core.agent_loop import submission_loop from agent.core import model_switcher +from agent.messaging.gateway import NotificationGateway from agent.core.session import OpType from agent.core.tools import ToolRouter from agent.utils.reliability_checks import check_training_script_save_pattern @@ -847,6 +848,8 @@ async def main(): # Start agent loop in background config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json" config = load_config(config_path) + notification_gateway = NotificationGateway(config.messaging) + await notification_gateway.start() # Create tool router with local mode tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True) @@ -864,6 +867,7 @@ async def main(): hf_token=hf_token, local_mode=True, stream=True, + notification_gateway=notification_gateway, ) ) @@ -1019,6 +1023,8 @@ def _install_sigint() -> bool: agent_task.cancel() # Agent didn't shut down cleanly — close MCP explicitly await tool_router.__aexit__(None, None, None) + finally: + await notification_gateway.close() # Now safe to cancel the listener (agent is done emitting events) listener_task.cancel() @@ -1047,6 +1053,8 @@ async def headless_main( config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json" config = load_config(config_path) config.yolo_mode = True # Auto-approve everything in headless mode + notification_gateway = NotificationGateway(config.messaging) + await notification_gateway.start() if model: config.model_name = model @@ -1075,6 +1083,7 @@ async def headless_main( hf_token=hf_token, local_mode=True, stream=stream, + notification_gateway=notification_gateway, ) ) @@ -1213,6 +1222,8 @@ async def headless_main( except asyncio.TimeoutError: agent_task.cancel() await tool_router.__aexit__(None, None, None) + finally: + await notification_gateway.close() def cli(): diff --git a/agent/messaging/__init__.py b/agent/messaging/__init__.py new file mode 100644 index 00000000..c399d254 --- /dev/null +++ b/agent/messaging/__init__.py @@ -0,0 +1,15 @@ +from agent.messaging.gateway import NotificationGateway +from agent.messaging.models import ( + MessagingConfig, + NotificationRequest, + NotificationResult, + SUPPORTED_AUTO_EVENT_TYPES, +) + +__all__ = [ + "MessagingConfig", + "NotificationGateway", + "NotificationRequest", + "NotificationResult", + "SUPPORTED_AUTO_EVENT_TYPES", +] diff --git a/agent/messaging/base.py b/agent/messaging/base.py new file mode 100644 index 00000000..bf1d7389 --- /dev/null +++ b/agent/messaging/base.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod + +import httpx + +from agent.messaging.models import DestinationConfig, NotificationRequest, NotificationResult + + +class NotificationError(Exception): + """Delivery failed and should not be retried.""" + + +class RetryableNotificationError(NotificationError): + """Delivery failed transiently and can be retried.""" + + +class NotificationProvider(ABC): + provider_name: str + + @abstractmethod + async def send( + self, + client: httpx.AsyncClient, + destination_name: str, + destination: DestinationConfig, + request: NotificationRequest, + ) -> NotificationResult: + """Deliver a notification to one destination.""" diff --git a/agent/messaging/gateway.py b/agent/messaging/gateway.py new file mode 100644 index 00000000..bfc9741d --- /dev/null +++ b/agent/messaging/gateway.py @@ -0,0 +1,159 @@ +import asyncio +import logging +from collections.abc import Iterable + +import httpx + +from agent.messaging.base import ( + NotificationError, + NotificationProvider, + RetryableNotificationError, +) +from agent.messaging.models import ( + MessagingConfig, + NotificationRequest, + NotificationResult, +) +from agent.messaging.slack import SlackProvider + +logger = logging.getLogger(__name__) + +_RETRY_DELAYS = (1, 2, 4) + + +class NotificationGateway: + def __init__(self, config: MessagingConfig): + self.config = config + self._providers: dict[str, NotificationProvider] = { + "slack": SlackProvider(), + } + self._queue: asyncio.Queue[NotificationRequest] = asyncio.Queue() + self._worker_task: asyncio.Task | None = None + self._client: httpx.AsyncClient | None = None + + @property + def enabled(self) -> bool: + return self.config.enabled + + async def start(self) -> None: + if not self.enabled or self._worker_task is not None: + return + self._client = httpx.AsyncClient(timeout=10.0) + self._worker_task = asyncio.create_task(self._worker(), name="notification-gateway") + + async def flush(self) -> None: + if not self.enabled: + return + await self._queue.join() + + async def close(self) -> None: + if not self.enabled: + return + await self.flush() + if self._worker_task is not None: + self._worker_task.cancel() + try: + await self._worker_task + except asyncio.CancelledError: + pass + self._worker_task = None + if self._client is not None: + await self._client.aclose() + self._client = None + + async def send(self, request: NotificationRequest) -> NotificationResult: + if not self.enabled: + return NotificationResult( + destination=request.destination, + ok=False, + provider="disabled", + error="Messaging is disabled", + ) + + destination = self.config.get_destination(request.destination) + if destination is None: + return NotificationResult( + destination=request.destination, + ok=False, + provider="unknown", + error=f"Unknown destination '{request.destination}'", + ) + + provider = self._providers[destination.provider] + return await self._send_with_retries(provider, request.destination, destination, request) + + async def send_many( + self, requests: Iterable[NotificationRequest] + ) -> list[NotificationResult]: + results: list[NotificationResult] = [] + for request in requests: + results.append(await self.send(request)) + return results + + async def enqueue(self, request: NotificationRequest) -> bool: + if not self.enabled or self._worker_task is None: + return False + await self._queue.put(request) + return True + + async def _worker(self) -> None: + while True: + request = await self._queue.get() + try: + result = await self.send(request) + if not result.ok: + logger.warning( + "Notification delivery failed for %s: %s", + request.destination, + result.error, + ) + except Exception: + logger.exception("Unexpected notification worker failure") + finally: + self._queue.task_done() + + async def _send_with_retries( + self, + provider: NotificationProvider, + destination_name: str, + destination, + request: NotificationRequest, + ) -> NotificationResult: + client = self._client or httpx.AsyncClient(timeout=10.0) + owns_client = self._client is None + try: + for attempt in range(len(_RETRY_DELAYS) + 1): + try: + return await provider.send(client, destination_name, destination, request) + except RetryableNotificationError as exc: + if attempt >= len(_RETRY_DELAYS): + return NotificationResult( + destination=destination_name, + ok=False, + provider=provider.provider_name, + error=str(exc), + ) + delay = _RETRY_DELAYS[attempt] + logger.warning( + "Retrying notification to %s in %ss after transient error: %s", + destination_name, + delay, + exc, + ) + await asyncio.sleep(delay) + except NotificationError as exc: + return NotificationResult( + destination=destination_name, + ok=False, + provider=provider.provider_name, + error=str(exc), + ) + return NotificationResult( + destination=destination_name, + ok=False, + provider=provider.provider_name, + error="Notification delivery exhausted retries", + ) + finally: + if owns_client: + await client.aclose() diff --git a/agent/messaging/models.py b/agent/messaging/models.py new file mode 100644 index 00000000..c4f03c63 --- /dev/null +++ b/agent/messaging/models.py @@ -0,0 +1,114 @@ +from typing import Annotated, Literal + +from pydantic import BaseModel, Field, field_validator, model_validator + +_DESTINATION_NAME_CHARS = set("abcdefghijklmnopqrstuvwxyz0123456789._-") +SUPPORTED_AUTO_EVENT_TYPES = {"approval_required", "error"} + + +class SlackDestinationConfig(BaseModel): + provider: Literal["slack"] = "slack" + token: str + channel: str + allow_agent_tool: bool = False + allow_auto_events: bool = False + username: str | None = None + icon_emoji: str | None = None + + @field_validator("token", "channel") + @classmethod + def _require_non_empty(cls, value: str) -> str: + value = value.strip() + if not value: + raise ValueError("must not be empty") + return value + + +DestinationConfig = Annotated[SlackDestinationConfig, Field(discriminator="provider")] + + +class MessagingConfig(BaseModel): + enabled: bool = False + auto_event_types: list[str] = Field( + default_factory=lambda: ["approval_required", "error"] + ) + destinations: dict[str, DestinationConfig] = Field(default_factory=dict) + + @field_validator("destinations") + @classmethod + def _validate_destination_names( + cls, destinations: dict[str, DestinationConfig] + ) -> dict[str, DestinationConfig]: + for name in destinations: + if not name or any(char not in _DESTINATION_NAME_CHARS for char in name): + raise ValueError( + "destination names must use lowercase letters, digits, '.', '_' or '-'" + ) + return destinations + + @field_validator("auto_event_types") + @classmethod + def _validate_auto_event_types(cls, event_types: list[str]) -> list[str]: + if not event_types: + return [] + normalized: list[str] = [] + seen: set[str] = set() + for event_type in event_types: + if event_type not in SUPPORTED_AUTO_EVENT_TYPES: + raise ValueError( + f"unsupported auto event type '{event_type}'" + ) + if event_type not in seen: + normalized.append(event_type) + seen.add(event_type) + return normalized + + @model_validator(mode="after") + def _require_destinations_when_enabled(self) -> "MessagingConfig": + if self.enabled and not self.destinations: + raise ValueError("messaging.enabled requires at least one destination") + return self + + def get_destination(self, name: str) -> DestinationConfig | None: + return self.destinations.get(name) + + def can_agent_tool_send(self, name: str) -> bool: + destination = self.get_destination(name) + return bool(destination and destination.allow_agent_tool) + + def can_auto_send(self, name: str) -> bool: + destination = self.get_destination(name) + return bool(destination and destination.allow_auto_events) + + +class NotificationRequest(BaseModel): + destination: str + title: str | None = None + message: str + severity: Literal["info", "success", "warning", "error"] = "info" + metadata: dict[str, str] = Field(default_factory=dict) + event_type: str | None = None + + @field_validator("destination", "message") + @classmethod + def _require_text(cls, value: str) -> str: + value = value.strip() + if not value: + raise ValueError("must not be empty") + return value + + @field_validator("title") + @classmethod + def _normalize_title(cls, value: str | None) -> str | None: + if value is None: + return None + value = value.strip() + return value or None + + +class NotificationResult(BaseModel): + destination: str + ok: bool + provider: str + error: str | None = None + external_id: str | None = None diff --git a/agent/messaging/slack.py b/agent/messaging/slack.py new file mode 100644 index 00000000..c21ac98f --- /dev/null +++ b/agent/messaging/slack.py @@ -0,0 +1,96 @@ +import json + +import httpx + +from agent.messaging.base import ( + NotificationError, + NotificationProvider, + RetryableNotificationError, +) +from agent.messaging.models import ( + NotificationRequest, + NotificationResult, + SlackDestinationConfig, +) + +_SEVERITY_PREFIX = { + "info": "[INFO]", + "success": "[SUCCESS]", + "warning": "[WARNING]", + "error": "[ERROR]", +} + + +def _format_text(request: NotificationRequest) -> str: + lines: list[str] = [] + prefix = _SEVERITY_PREFIX[request.severity] + if request.title: + lines.append(f"{prefix} {request.title}") + else: + lines.append(prefix) + lines.append(request.message) + for key, value in request.metadata.items(): + lines.append(f"{key}: {value}") + return "\n".join(lines) + + +class SlackProvider(NotificationProvider): + provider_name = "slack" + + async def send( + self, + client: httpx.AsyncClient, + destination_name: str, + destination: SlackDestinationConfig, + request: NotificationRequest, + ) -> NotificationResult: + payload = { + "channel": destination.channel, + "text": _format_text(request), + "unfurl_links": False, + "unfurl_media": False, + } + if destination.username: + payload["username"] = destination.username + if destination.icon_emoji: + payload["icon_emoji"] = destination.icon_emoji + + try: + response = await client.post( + "https://slack.com/api/chat.postMessage", + headers={ + "Authorization": f"Bearer {destination.token}", + "Content-Type": "application/json; charset=utf-8", + }, + content=json.dumps(payload), + ) + except httpx.TimeoutException as exc: + raise RetryableNotificationError("Slack request timed out") from exc + except httpx.TransportError as exc: + raise RetryableNotificationError("Slack transport error") from exc + + if response.status_code == 429 or response.status_code >= 500: + raise RetryableNotificationError( + f"Slack HTTP {response.status_code}" + ) + if response.status_code >= 400: + raise NotificationError(f"Slack HTTP {response.status_code}") + + try: + data = response.json() + except ValueError as exc: + raise RetryableNotificationError("Slack returned invalid JSON") from exc + + if not data.get("ok"): + error = str(data.get("error") or "unknown_error") + if error == "ratelimited": + raise RetryableNotificationError(error) + raise NotificationError(error) + + return NotificationResult( + destination=destination_name, + ok=True, + provider=self.provider_name, + external_id=str(data.get("ts") or ""), + error=None, + ) diff --git a/agent/prompts/system_prompt_v3.yaml b/agent/prompts/system_prompt_v3.yaml index befa56bf..7c29a392 100644 --- a/agent/prompts/system_prompt_v3.yaml +++ b/agent/prompts/system_prompt_v3.yaml @@ -156,6 +156,7 @@ system_prompt: | - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs. - For errors: state what went wrong, why, and what you're doing to fix it. - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity. + - Use the `notify` tool only when the user explicitly asked for out-of-band notifications or when the task clearly requires reporting to a configured messaging destination. Do not use it for routine chat updates. # Tool usage diff --git a/agent/tools/notify_tool.py b/agent/tools/notify_tool.py new file mode 100644 index 00000000..f926d5a5 --- /dev/null +++ b/agent/tools/notify_tool.py @@ -0,0 +1,108 @@ +from typing import Any + +from agent.messaging.models import NotificationRequest + +NOTIFY_TOOL_SPEC = { + "name": "notify", + "description": ( + "Send an out-of-band notification to configured messaging destinations. " + "Use this only when the user explicitly asked for proactive notifications " + "or when the task requires reporting progress outside the chat. " + "Destinations must be named server-side configs such as 'slack.ops'." + ), + "parameters": { + "type": "object", + "properties": { + "destinations": { + "type": "array", + "description": "Named messaging destinations to notify.", + "items": {"type": "string"}, + "minItems": 1, + }, + "message": { + "type": "string", + "description": "Main notification body.", + }, + "title": { + "type": "string", + "description": "Optional short title line.", + }, + "severity": { + "type": "string", + "enum": ["info", "success", "warning", "error"], + "description": "Notification severity label.", + }, + }, + "required": ["destinations", "message"], + }, +} + + +async def notify_handler( + arguments: dict[str, Any], session=None, **_kwargs +) -> tuple[str, bool]: + if session is None or session.notification_gateway is None: + return "Messaging is not configured for this session.", False + + raw_destinations = arguments.get("destinations", []) + if not isinstance(raw_destinations, list) or not raw_destinations: + return "destinations must be a non-empty array of destination names.", False + + destinations: list[str] = [] + seen: set[str] = set() + for raw_name in raw_destinations: + if not isinstance(raw_name, str): + return "Each destination must be a string.", False + name = raw_name.strip() + if not name: + return "Destination names must not be empty.", False + if name not in seen: + destinations.append(name) + seen.add(name) + + disallowed = [ + name + for name in destinations + if not session.config.messaging.can_agent_tool_send(name) + ] + if disallowed: + return ( + "These destinations are unavailable for the notify tool: " + + ", ".join(disallowed) + ), False + + message = arguments.get("message", "") + if not isinstance(message, str) or not message.strip(): + return "message must be a non-empty string.", False + + title = arguments.get("title") + severity = arguments.get("severity", "info") + if title is not None and not isinstance(title, str): + return "title must be a string when provided.", False + if severity not in {"info", "success", "warning", "error"}: + return "severity must be one of: info, success, warning, error.", False + + requests = [ + NotificationRequest( + destination=name, + title=title, + message=message, + severity=severity, + metadata={ + "session_id": session.session_id, + "model": session.config.model_name, + }, + ) + for name in destinations + ] + results = await session.notification_gateway.send_many(requests) + + lines = [] + all_ok = True + for result in results: + if result.ok: + lines.append(f"{result.destination}: sent") + else: + all_ok = False + lines.append(f"{result.destination}: failed ({result.error})") + return "\n".join(lines), all_ok diff --git a/backend/main.py b/backend/main.py index 888740e5..38004132 100644 --- a/backend/main.py +++ b/backend/main.py @@ -11,6 +11,7 @@ from fastapi.staticfiles import StaticFiles from routes.agent import router as agent_router from routes.auth import router as auth_router +from session_manager import session_manager # Load .env from project root (parent directory) load_dotenv(Path(__file__).parent.parent / ".env") @@ -27,8 +28,10 @@ async def lifespan(app: FastAPI): """Application lifespan handler.""" logger.info("Starting HF Agent backend...") + await session_manager.start() yield logger.info("Shutting down HF Agent backend...") + await session_manager.close() app = FastAPI( diff --git a/backend/models.py b/backend/models.py index 954779f6..e5ae3562 100644 --- a/backend/models.py +++ b/backend/models.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field class OpType(str, Enum): @@ -86,6 +86,13 @@ class SessionInfo(BaseModel): user_id: str = "dev" pending_approval: list[PendingApprovalTool] | None = None model: str | None = None + notification_destinations: list[str] = Field(default_factory=list) + + +class SessionNotificationsRequest(BaseModel): + """Replace the session's auto-notification destinations.""" + + destinations: list[str] class HealthResponse(BaseModel): diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 7f577995..50b2923d 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -24,6 +24,7 @@ HealthResponse, LLMHealthResponse, SessionInfo, + SessionNotificationsRequest, SessionResponse, SubmitRequest, TruncateRequest, @@ -428,6 +429,26 @@ async def set_session_model( return {"session_id": session_id, "model": model_id} +@router.post("/session/{session_id}/notifications") +async def set_session_notifications( + session_id: str, + body: SessionNotificationsRequest, + user: dict = Depends(get_current_user), +) -> dict: + """Replace the session's auto-notification destinations.""" + _check_session_access(session_id, user) + try: + destinations = session_manager.set_notification_destinations( + session_id, body.destinations + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return { + "session_id": session_id, + "notification_destinations": destinations, + } + + @router.get("/user/quota") async def get_user_quota(user: dict = Depends(get_current_user)) -> dict: """Return the user's plan tier and today's Claude-session quota state.""" @@ -692,4 +713,3 @@ async def shutdown_session( raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "shutdown_requested", "session_id": session_id} - diff --git a/backend/session_manager.py b/backend/session_manager.py index 7293f9cf..239e3aa7 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -10,6 +10,7 @@ from agent.config import load_config from agent.core.agent_loop import process_submission +from agent.messaging.gateway import NotificationGateway from agent.core.session import Event, OpType, Session from agent.core.tools import ToolRouter @@ -119,9 +120,18 @@ class SessionManager: def __init__(self, config_path: str | None = None) -> None: self.config = load_config(config_path or DEFAULT_CONFIG_PATH) + self.messaging_gateway = NotificationGateway(self.config.messaging) self.sessions: dict[str, AgentSession] = {} self._lock = asyncio.Lock() + async def start(self) -> None: + """Start shared background resources.""" + await self.messaging_gateway.start() + + async def close(self) -> None: + """Flush and close shared background resources.""" + await self.messaging_gateway.close() + def _count_user_sessions(self, user_id: str) -> int: """Count active sessions owned by a specific user.""" return sum( @@ -193,6 +203,9 @@ def _create_session_sync(): session = Session( event_queue, config=session_config, tool_router=tool_router, hf_token=hf_token, + notification_gateway=self.messaging_gateway, + notification_destinations=[], + session_id=session_id, ) t1 = _time.monotonic() logger.info(f"Session initialized in {t1 - t0:.2f}s") @@ -506,8 +519,39 @@ def get_session_info(self, session_id: str) -> dict[str, Any] | None: "user_id": agent_session.user_id, "pending_approval": pending_approval, "model": agent_session.session.config.model_name, + "notification_destinations": list( + agent_session.session.notification_destinations + ), } + def set_notification_destinations( + self, session_id: str, destinations: list[str] + ) -> list[str]: + """Replace the session's opted-in auto-notification destinations.""" + agent_session = self.sessions.get(session_id) + if not agent_session or not agent_session.is_active: + raise ValueError("Session not found or inactive") + + normalized: list[str] = [] + seen: set[str] = set() + for raw_name in destinations: + name = raw_name.strip() + if not name: + raise ValueError("Destination names must not be empty") + destination = self.config.messaging.get_destination(name) + if destination is None: + raise ValueError(f"Unknown destination '{name}'") + if not destination.allow_auto_events: + raise ValueError( + f"Destination '{name}' is not enabled for auto events" + ) + if name not in seen: + normalized.append(name) + seen.add(name) + + agent_session.session.set_notification_destinations(normalized) + return normalized + def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: """List sessions, optionally filtered by user. diff --git a/configs/main_agent_config.json b/configs/main_agent_config.json index af76608f..1270155c 100644 --- a/configs/main_agent_config.json +++ b/configs/main_agent_config.json @@ -5,6 +5,11 @@ "yolo_mode": false, "confirm_cpu_jobs": true, "auto_file_upload": true, + "messaging": { + "enabled": false, + "auto_event_types": ["approval_required", "error"], + "destinations": {} + }, "mcpServers": { "hf-mcp-server": { "transport": "http", diff --git a/pyproject.toml b/pyproject.toml index 52e147ac..255f6fb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ eval = [ # Development and testing dependencies dev = [ "pytest>=9.0.2", + "pytest-asyncio>=1.2.0", ] # All dependencies (eval + dev) diff --git a/tests/unit/test_messaging.py b/tests/unit/test_messaging.py new file mode 100644 index 00000000..bb9e13d9 --- /dev/null +++ b/tests/unit/test_messaging.py @@ -0,0 +1,277 @@ +import asyncio +from pathlib import Path +from types import SimpleNamespace + +import httpx +import pytest +from pydantic import ValidationError + +from agent.config import Config +from agent.core.session import Event, Session +from agent.messaging.gateway import NotificationGateway +from agent.messaging.models import NotificationRequest, NotificationResult +from agent.messaging.slack import SlackProvider +from agent.tools.notify_tool import notify_handler +from backend.session_manager import AgentSession, SessionManager + + +class DummyToolRouter: + def get_tool_specs_for_llm(self) -> list[dict]: + return [] + + +class RecordingGateway: + def __init__(self): + self.enqueued: list[NotificationRequest] = [] + self.sent: list[NotificationRequest] = [] + + async def enqueue(self, request: NotificationRequest) -> bool: + self.enqueued.append(request) + return True + + async def send_many( + self, requests: list[NotificationRequest] + ) -> list[NotificationResult]: + self.sent.extend(requests) + return [ + NotificationResult( + destination=request.destination, + ok=True, + provider="test", + ) + for request in requests + ] + + +def _config_with_messaging(**destination_overrides) -> Config: + destination = { + "provider": "slack", + "token": "xoxb-test", + "channel": "C123", + **destination_overrides, + } + return Config.model_validate( + { + "model_name": "moonshotai/Kimi-K2.6", + "messaging": { + "enabled": True, + "destinations": { + "slack.ops": destination, + }, + }, + } + ) + + +def _test_session( + config: Config, gateway, session_id: str = "session-test" +) -> Session: + return Session( + asyncio.Queue(), + config=config, + tool_router=DummyToolRouter(), + context_manager=SimpleNamespace(items=[]), + notification_gateway=gateway, + session_id=session_id, + ) + + +def test_messaging_config_validates_destination_names(): + with pytest.raises(ValidationError): + Config.model_validate( + { + "model_name": "moonshotai/Kimi-K2.6", + "messaging": { + "enabled": True, + "destinations": { + "Slack Ops": { + "provider": "slack", + "token": "x", + "channel": "C123", + } + }, + }, + } + ) + + config = _config_with_messaging(allow_agent_tool=True, allow_auto_events=True) + assert config.messaging.can_agent_tool_send("slack.ops") + assert config.messaging.can_auto_send("slack.ops") + + +@pytest.mark.asyncio +async def test_slack_provider_formats_and_sends_payload(): + seen: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + seen["auth"] = request.headers["Authorization"] + seen["content_type"] = request.headers["Content-Type"] + seen["json"] = request.read().decode("utf-8") + return httpx.Response(200, json={"ok": True, "ts": "123.456"}) + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + provider = SlackProvider() + result = await provider.send( + client, + "slack.ops", + _config_with_messaging().messaging.destinations["slack.ops"], + NotificationRequest( + destination="slack.ops", + title="Approval required", + message="A run is waiting.", + severity="warning", + metadata={"session_id": "sess-1"}, + ), + ) + + assert result.ok + assert result.external_id == "123.456" + assert seen["auth"] == "Bearer xoxb-test" + assert seen["content_type"].startswith("application/json") + assert '"channel": "C123"' in seen["json"] + assert "[WARNING] Approval required\\nA run is waiting.\\nsession_id: sess-1" in seen["json"] + + +@pytest.mark.asyncio +async def test_notification_gateway_retries_transient_failures(monkeypatch): + attempts = {"count": 0} + + def handler(_request: httpx.Request) -> httpx.Response: + attempts["count"] += 1 + if attempts["count"] == 1: + return httpx.Response(503, json={"ok": False}) + return httpx.Response(200, json={"ok": True, "ts": "999.1"}) + + async def fake_sleep(_delay: float) -> None: + return None + + monkeypatch.setattr("agent.messaging.gateway.asyncio.sleep", fake_sleep) + + config = _config_with_messaging(allow_agent_tool=True) + gateway = NotificationGateway(config.messaging) + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + gateway._client = client + result = await gateway.send( + NotificationRequest( + destination="slack.ops", + message="hello", + ) + ) + gateway._client = None + + assert attempts["count"] == 2 + assert result.ok + + +@pytest.mark.asyncio +async def test_notify_tool_rejects_non_allowlisted_destinations(): + config = _config_with_messaging(allow_agent_tool=False) + gateway = RecordingGateway() + session = _test_session(config, gateway) + + output, ok = await notify_handler( + {"destinations": ["slack.ops"], "message": "done"}, + session=session, + ) + + assert not ok + assert "unavailable for the notify tool" in output + assert gateway.sent == [] + + +@pytest.mark.asyncio +async def test_notify_tool_sends_to_allowlisted_destinations(): + config = _config_with_messaging(allow_agent_tool=True) + gateway = RecordingGateway() + session = _test_session(config, gateway, session_id="sess-42") + + output, ok = await notify_handler( + { + "destinations": ["slack.ops"], + "title": "Training complete", + "message": "The run finished successfully.", + "severity": "success", + }, + session=session, + ) + + assert ok + assert output == "slack.ops: sent" + assert len(gateway.sent) == 1 + sent = gateway.sent[0] + assert sent.metadata["session_id"] == "sess-42" + assert sent.metadata["model"] == "moonshotai/Kimi-K2.6" + + +@pytest.mark.asyncio +async def test_session_auto_notifications_only_send_opted_in_auto_destinations(): + config = Config.model_validate( + { + "model_name": "moonshotai/Kimi-K2.6", + "messaging": { + "enabled": True, + "destinations": { + "slack.ops": { + "provider": "slack", + "token": "xoxb-test", + "channel": "C123", + "allow_auto_events": True, + }, + "slack.tool": { + "provider": "slack", + "token": "xoxb-test", + "channel": "C999", + "allow_agent_tool": True, + }, + }, + }, + } + ) + gateway = RecordingGateway() + session = _test_session(config, gateway, session_id="sess-auto") + session.set_notification_destinations(["slack.ops", "slack.tool"]) + + await session.send_event( + Event( + event_type="approval_required", + data={"tools": [{"tool": "hf_jobs", "tool_call_id": "tc-1"}]}, + ) + ) + await session.send_event( + Event(event_type="assistant_message", data={"content": "normal message"}) + ) + + assert len(gateway.enqueued) == 1 + request = gateway.enqueued[0] + assert request.destination == "slack.ops" + assert request.severity == "warning" + assert request.event_type == "approval_required" + assert "hf_jobs" in request.message + + +def test_session_manager_updates_notification_destinations_in_session_info(): + config = _config_with_messaging(allow_auto_events=True) + manager = SessionManager(str(Path(__file__).resolve().parents[2] / "configs" / "main_agent_config.json")) + manager.config = config + manager.sessions = {} + + session = _test_session(config, RecordingGateway(), session_id="sess-manager") + manager.sessions["sess-manager"] = AgentSession( + session_id="sess-manager", + session=session, + tool_router=DummyToolRouter(), + submission_queue=asyncio.Queue(), + ) + + updated = manager.set_notification_destinations( + "sess-manager", + ["slack.ops", "slack.ops"], + ) + + assert updated == ["slack.ops"] + info = manager.get_session_info("sess-manager") + assert info is not None + assert info["notification_destinations"] == ["slack.ops"] + + with pytest.raises(ValueError): + manager.set_notification_destinations("sess-manager", ["slack.unknown"]) diff --git a/uv.lock b/uv.lock index 7546e793..b4c751f3 100644 --- a/uv.lock +++ b/uv.lock @@ -1023,10 +1023,12 @@ all = [ { name = "inspect-ai" }, { name = "pandas" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "tenacity" }, ] dev = [ { name = "pytest" }, + { name = "pytest-asyncio" }, ] eval = [ { name = "datasets" }, @@ -1053,6 +1055,7 @@ requires-dist = [ { name = "prompt-toolkit", specifier = ">=3.0.0" }, { name = "pydantic", specifier = ">=2.12.3" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" }, + { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.2.0" }, { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "requests", specifier = ">=2.33.0" }, { name = "rich", specifier = ">=13.0.0" }, @@ -2775,6 +2778,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" From b74ce59e6d194357b8849e21534ef3ea0c4baa90 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sat, 25 Apr 2026 15:45:44 +0000 Subject: [PATCH 2/3] Handle Bedrock streaming permission denials Co-authored-by: OpenAI Codex --- agent/core/agent_loop.py | 57 +++++++++++++++++++++++ tests/unit/test_agent_loop.py | 87 +++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+) create mode 100644 tests/unit/test_agent_loop.py diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index 46ce41a2..1da737c5 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -150,6 +150,29 @@ def _is_effort_config_error(error: Exception) -> bool: return _is_thinking_unsupported(error) or _is_invalid_effort(error) +def _is_bedrock_streaming_permission_error(error: Exception) -> bool: + """Return True when Bedrock rejects streaming due to missing IAM permission.""" + err_str = str(error).lower() + if "invokemodelwithresponsestream" not in err_str: + return False + + auth_markers = ["not authorized", "accessdenied", "forbidden", "permission"] + bedrock_markers = ["bedrock", "inference-profile", "converse-stream"] + return any(marker in err_str for marker in auth_markers) and any( + marker in err_str for marker in bedrock_markers + ) + + +def _is_bedrock_invoke_permission_error(error: Exception) -> bool: + """Return True when Bedrock rejects model invocation due to IAM policy.""" + err_str = str(error).lower() + if "bedrock" not in err_str or "invokemodel" not in err_str: + return False + + auth_markers = ["not authorized", "accessdenied", "forbidden", "permission"] + return any(marker in err_str for marker in auth_markers) + + async def _heal_effort_and_rebuild_params( session: Session, error: Exception, llm_params: dict, ) -> dict: @@ -194,6 +217,23 @@ def _friendly_error_message(error: Exception) -> str | None: """Return a user-friendly message for known error types, or None to fall back to traceback.""" err_str = str(error).lower() + if _is_bedrock_streaming_permission_error(error): + return ( + "Bedrock denied streaming for this model.\n\n" + "Your AWS role is missing `bedrock:InvokeModelWithResponseStream` for " + "the selected inference profile. Run `ml-intern --no-stream ...` or " + "update the IAM policy to allow that action." + ) + + if _is_bedrock_invoke_permission_error(error): + return ( + "Bedrock access was denied for this model.\n\n" + "Your AWS role needs permission to invoke the selected Bedrock model " + "or inference profile. Ensure the policy allows `bedrock:InvokeModel`, " + "and if you want token streaming, also allow " + "`bedrock:InvokeModelWithResponseStream`." + ) + if "authentication" in err_str or "unauthorized" in err_str or "invalid x-api-key" in err_str: return ( "Authentication failed — your API key is missing or invalid.\n\n" @@ -322,6 +362,23 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, )) continue + if _is_bedrock_streaming_permission_error(e): + session.stream = False + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": ( + "Bedrock rejected streaming for this role — " + "switching this session to non-streaming mode." + ), + }, + ) + ) + return await _call_llm_non_streaming( + session, messages, tools, llm_params + ) if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): _delay = _LLM_RETRY_DELAYS[_llm_attempt] logger.warning( diff --git a/tests/unit/test_agent_loop.py b/tests/unit/test_agent_loop.py new file mode 100644 index 00000000..7e70440e --- /dev/null +++ b/tests/unit/test_agent_loop.py @@ -0,0 +1,87 @@ +from types import SimpleNamespace + +import pytest + +from agent.core.agent_loop import ( + _call_llm_streaming, + _friendly_error_message, + _is_bedrock_streaming_permission_error, +) + + +class RecordingSession: + def __init__(self): + self.events = [] + self.stream = True + + async def send_event(self, event): + self.events.append(event) + + +def test_detects_bedrock_streaming_permission_error(): + error = Exception( + 'litellm.APIConnectionError: BedrockException - {"Message":"User is not ' + "authorized to perform: bedrock:InvokeModelWithResponseStream on " + 'resource: arn:aws:bedrock:us-west-2:123456789012:inference-profile/' + 'us.anthropic.claude-opus-4-6-v1"}' + ) + + assert _is_bedrock_streaming_permission_error(error) is True + + +def test_friendly_error_message_for_bedrock_streaming_permission(): + error = Exception( + "BedrockException: not authorized to perform " + "bedrock:InvokeModelWithResponseStream" + ) + + message = _friendly_error_message(error) + + assert message is not None + assert "Bedrock denied streaming" in message + assert "--no-stream" in message + + +@pytest.mark.asyncio +async def test_streaming_call_falls_back_to_non_streaming_for_bedrock_permission_error( + monkeypatch, +): + calls: list[bool] = [] + + async def fake_acompletion(*, stream, **_kwargs): + calls.append(stream) + if stream: + raise Exception( + "BedrockException: User is not authorized to perform " + "bedrock:InvokeModelWithResponseStream on resource " + "arn:aws:bedrock:us-west-2:123456789012:inference-profile/" + "us.anthropic.claude-opus-4-6-v1" + ) + + return SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace(content="fallback response", tool_calls=None), + finish_reason="stop", + ) + ], + usage=SimpleNamespace(total_tokens=17), + ) + + monkeypatch.setattr("agent.core.agent_loop.acompletion", fake_acompletion) + + session = RecordingSession() + result = await _call_llm_streaming( + session, + messages=[], + tools=[], + llm_params={"model": "bedrock/us.anthropic.claude-opus-4-6-v1"}, + ) + + assert calls == [True, False] + assert session.stream is False + assert result.content == "fallback response" + assert [event.event_type for event in session.events] == [ + "tool_log", + "assistant_message", + ] From f9f14e181b5337445ac299411aeb4b378667eb4c Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Sat, 25 Apr 2026 15:48:00 +0000 Subject: [PATCH 3/3] Revert "Handle Bedrock streaming permission denials" Co-authored-by: OpenAI Codex --- agent/core/agent_loop.py | 57 ----------------------- tests/unit/test_agent_loop.py | 87 ----------------------------------- 2 files changed, 144 deletions(-) delete mode 100644 tests/unit/test_agent_loop.py diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index 1da737c5..46ce41a2 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -150,29 +150,6 @@ def _is_effort_config_error(error: Exception) -> bool: return _is_thinking_unsupported(error) or _is_invalid_effort(error) -def _is_bedrock_streaming_permission_error(error: Exception) -> bool: - """Return True when Bedrock rejects streaming due to missing IAM permission.""" - err_str = str(error).lower() - if "invokemodelwithresponsestream" not in err_str: - return False - - auth_markers = ["not authorized", "accessdenied", "forbidden", "permission"] - bedrock_markers = ["bedrock", "inference-profile", "converse-stream"] - return any(marker in err_str for marker in auth_markers) and any( - marker in err_str for marker in bedrock_markers - ) - - -def _is_bedrock_invoke_permission_error(error: Exception) -> bool: - """Return True when Bedrock rejects model invocation due to IAM policy.""" - err_str = str(error).lower() - if "bedrock" not in err_str or "invokemodel" not in err_str: - return False - - auth_markers = ["not authorized", "accessdenied", "forbidden", "permission"] - return any(marker in err_str for marker in auth_markers) - - async def _heal_effort_and_rebuild_params( session: Session, error: Exception, llm_params: dict, ) -> dict: @@ -217,23 +194,6 @@ def _friendly_error_message(error: Exception) -> str | None: """Return a user-friendly message for known error types, or None to fall back to traceback.""" err_str = str(error).lower() - if _is_bedrock_streaming_permission_error(error): - return ( - "Bedrock denied streaming for this model.\n\n" - "Your AWS role is missing `bedrock:InvokeModelWithResponseStream` for " - "the selected inference profile. Run `ml-intern --no-stream ...` or " - "update the IAM policy to allow that action." - ) - - if _is_bedrock_invoke_permission_error(error): - return ( - "Bedrock access was denied for this model.\n\n" - "Your AWS role needs permission to invoke the selected Bedrock model " - "or inference profile. Ensure the policy allows `bedrock:InvokeModel`, " - "and if you want token streaming, also allow " - "`bedrock:InvokeModelWithResponseStream`." - ) - if "authentication" in err_str or "unauthorized" in err_str or "invalid x-api-key" in err_str: return ( "Authentication failed — your API key is missing or invalid.\n\n" @@ -362,23 +322,6 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, )) continue - if _is_bedrock_streaming_permission_error(e): - session.stream = False - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": ( - "Bedrock rejected streaming for this role — " - "switching this session to non-streaming mode." - ), - }, - ) - ) - return await _call_llm_non_streaming( - session, messages, tools, llm_params - ) if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): _delay = _LLM_RETRY_DELAYS[_llm_attempt] logger.warning( diff --git a/tests/unit/test_agent_loop.py b/tests/unit/test_agent_loop.py deleted file mode 100644 index 7e70440e..00000000 --- a/tests/unit/test_agent_loop.py +++ /dev/null @@ -1,87 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from agent.core.agent_loop import ( - _call_llm_streaming, - _friendly_error_message, - _is_bedrock_streaming_permission_error, -) - - -class RecordingSession: - def __init__(self): - self.events = [] - self.stream = True - - async def send_event(self, event): - self.events.append(event) - - -def test_detects_bedrock_streaming_permission_error(): - error = Exception( - 'litellm.APIConnectionError: BedrockException - {"Message":"User is not ' - "authorized to perform: bedrock:InvokeModelWithResponseStream on " - 'resource: arn:aws:bedrock:us-west-2:123456789012:inference-profile/' - 'us.anthropic.claude-opus-4-6-v1"}' - ) - - assert _is_bedrock_streaming_permission_error(error) is True - - -def test_friendly_error_message_for_bedrock_streaming_permission(): - error = Exception( - "BedrockException: not authorized to perform " - "bedrock:InvokeModelWithResponseStream" - ) - - message = _friendly_error_message(error) - - assert message is not None - assert "Bedrock denied streaming" in message - assert "--no-stream" in message - - -@pytest.mark.asyncio -async def test_streaming_call_falls_back_to_non_streaming_for_bedrock_permission_error( - monkeypatch, -): - calls: list[bool] = [] - - async def fake_acompletion(*, stream, **_kwargs): - calls.append(stream) - if stream: - raise Exception( - "BedrockException: User is not authorized to perform " - "bedrock:InvokeModelWithResponseStream on resource " - "arn:aws:bedrock:us-west-2:123456789012:inference-profile/" - "us.anthropic.claude-opus-4-6-v1" - ) - - return SimpleNamespace( - choices=[ - SimpleNamespace( - message=SimpleNamespace(content="fallback response", tool_calls=None), - finish_reason="stop", - ) - ], - usage=SimpleNamespace(total_tokens=17), - ) - - monkeypatch.setattr("agent.core.agent_loop.acompletion", fake_acompletion) - - session = RecordingSession() - result = await _call_llm_streaming( - session, - messages=[], - tools=[], - llm_params={"model": "bedrock/us.anthropic.claude-opus-4-6-v1"}, - ) - - assert calls == [True, False] - assert session.stream is False - assert result.content == "fallback response" - assert [event.event_type for event in session.events] == [ - "tool_log", - "assistant_message", - ]