Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from dotenv import load_dotenv

from agent.messaging.models import MessagingConfig

# Project root: two levels up from this file (agent/config.py -> project root)
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
from fastmcp.mcp_config import (
Expand Down Expand Up @@ -47,6 +49,7 @@ class Config(BaseModel):
# ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
# Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
reasoning_effort: str | None = "max"
messaging: MessagingConfig = MessagingConfig()


def substitute_env_vars(obj: Any) -> Any:
Expand Down
5 changes: 5 additions & 0 deletions agent/core/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from litellm.exceptions import ContextWindowExceededError

from agent.config import Config
from agent.messaging.gateway import NotificationGateway
from agent.core import telemetry
from agent.core.doom_loop import check_for_doom_loop
from agent.core.llm_params import _resolve_llm_params
Expand Down Expand Up @@ -1230,6 +1231,8 @@ async def submission_loop(
hf_token: str | None = None,
local_mode: bool = False,
stream: bool = True,
notification_gateway: NotificationGateway | None = None,
notification_destinations: list[str] | None = None,
) -> None:
"""
Main agent loop - processes submissions and dispatches to handlers.
Expand All @@ -1240,6 +1243,8 @@ async def submission_loop(
session = Session(
event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
local_mode=local_mode, stream=stream,
notification_gateway=notification_gateway,
notification_destinations=notification_destinations,
)
if session_holder is not None:
session_holder[0] = session
Expand Down
97 changes: 92 additions & 5 deletions agent/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from agent.config import Config
from agent.context_manager.manager import ContextManager
from agent.messaging.gateway import NotificationGateway
from agent.messaging.models import NotificationRequest

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -79,29 +81,35 @@ def __init__(
hf_token: str | None = None,
local_mode: bool = False,
stream: bool = True,
notification_gateway: NotificationGateway | None = None,
notification_destinations: list[str] | None = None,
session_id: str | None = None,
):
self.hf_token: Optional[str] = hf_token
self.tool_router = tool_router
self.stream = stream
tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
effective_config = config or Config(
model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0",
)
self.context_manager = context_manager or ContextManager(
model_max_tokens=_get_max_tokens_safe(config.model_name),
model_max_tokens=_get_max_tokens_safe(effective_config.model_name),
compact_size=0.1,
untouched_messages=5,
tool_specs=tool_specs,
hf_token=hf_token,
local_mode=local_mode,
)
self.event_queue = event_queue
self.session_id = str(uuid.uuid4())
self.config = config or Config(
model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0",
)
self.session_id = session_id or str(uuid.uuid4())
self.config = effective_config
self.is_running = True
self._cancelled = asyncio.Event()
self.pending_approval: Optional[dict[str, Any]] = None
self.sandbox = None
self._running_job_ids: set[str] = set() # HF job IDs currently executing
self.notification_gateway = notification_gateway
self.notification_destinations = list(notification_destinations or [])

# Session trajectory logging
self.logged_events: list[dict] = []
Expand Down Expand Up @@ -136,11 +144,90 @@ async def send_event(self, event: Event) -> None:
"data": event.data,
}
)
await self._send_auto_notification(event)

# Mid-turn heartbeat flush (owned by telemetry module).
from agent.core.telemetry import HeartbeatSaver

HeartbeatSaver.maybe_fire(self)

def set_notification_destinations(self, destinations: list[str]) -> None:
    """Replace the session's opted-in auto-notification destinations.

    Duplicate names are dropped; the first occurrence of each name keeps
    its position, so the stored list preserves the caller's ordering. A
    new list is always stored, never the caller's own list object.

    Args:
        destinations: Destination names to opt in. ``None`` or an empty
            list clears the current destinations.
    """
    # dicts keep insertion order, so fromkeys() is an order-stable dedupe.
    self.notification_destinations = list(dict.fromkeys(destinations or ()))

async def _send_auto_notification(self, event: Event) -> None:
    """Queue opted-in notifications for ``event`` on the notification gateway.

    Does nothing unless the session has a gateway, at least one opted-in
    destination, and ``event.event_type`` is listed in
    ``self.config.messaging.auto_event_types``.

    Notifications are a best-effort side channel: a failure to enqueue one
    request is logged and skipped, so it can neither abort the caller nor
    prevent the remaining requests from being queued.

    Args:
        event: The session event that was just emitted.
    """
    gateway = self.notification_gateway
    if gateway is None or not self.notification_destinations:
        return
    if event.event_type not in self.config.messaging.auto_event_types:
        return

    for request in self._build_auto_notification_requests(event):
        try:
            await gateway.enqueue(request)
        except Exception:
            # asyncio.CancelledError derives from BaseException, so task
            # cancellation still propagates through this handler.
            logger.exception(
                "Failed to enqueue auto-notification for %s event (session %s)",
                event.event_type,
                self.session_id,
            )

def _build_auto_notification_requests(
    self, event: Event
) -> list[NotificationRequest]:
    """Build one notification request per eligible destination for ``event``.

    Only ``approval_required`` and ``error`` events produce a message; any
    other event type yields an empty list. Destinations rejected by
    ``self.config.messaging.can_auto_send`` are skipped.

    Args:
        event: The session event being reported.

    Returns:
        The requests to enqueue, in ``self.notification_destinations`` order.
    """
    # Tolerate a missing or non-mapping payload instead of failing on .get().
    data = event.data if isinstance(event.data, dict) else {}

    if event.event_type == "approval_required":
        tools = data.get("tools", [])
        if not isinstance(tools, list):
            tools = []
        candidate_names = (
            str(tool.get("tool") or "").strip()
            for tool in tools
            if isinstance(tool, dict)
        )
        # Distinct, non-blank tool names in first-seen order.
        tool_names = list(dict.fromkeys(name for name in candidate_names if name))
        title = "Agent approval required"
        # The count covers every entry in the list, including unnamed ones.
        message = (
            f"Session {self.session_id} is waiting for approval "
            f"for {len(tools)} tool call(s)."
        )
        if tool_names:
            message += " Tools: " + ", ".join(tool_names)
        severity = "warning"
    elif event.event_type == "error":
        title = "Agent error"
        error = str(data.get("error") or "Unknown error")
        # Cap the error text at 500 characters to keep the message short.
        message = f"Session {self.session_id} hit an error.\n{error[:500]}"
        severity = "error"
    else:
        # No automatic notification is defined for other event types.
        return []

    metadata = {
        "session_id": self.session_id,
        "model": self.config.model_name,
        "event_type": event.event_type,
    }
    return [
        NotificationRequest(
            destination=destination,
            title=title,
            message=message,
            severity=severity,
            # Fresh dict per request so no two requests share mutable state.
            metadata=dict(metadata),
            event_type=event.event_type,
        )
        for destination in self.notification_destinations
        if self.config.messaging.can_auto_send(destination)
    ]

def cancel(self) -> None:
"""Signal cancellation to the running agent loop."""
self._cancelled.set()
Expand Down
7 changes: 7 additions & 0 deletions agent/core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
hf_repo_git_handler,
)
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
from agent.tools.notify_tool import NOTIFY_TOOL_SPEC, notify_handler
from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler
Expand Down Expand Up @@ -324,6 +325,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]:
parameters=PLAN_TOOL_SPEC["parameters"],
handler=plan_tool_handler,
),
ToolSpec(
name=NOTIFY_TOOL_SPEC["name"],
description=NOTIFY_TOOL_SPEC["description"],
parameters=NOTIFY_TOOL_SPEC["parameters"],
handler=notify_handler,
),
ToolSpec(
name=HF_JOBS_TOOL_SPEC["name"],
description=HF_JOBS_TOOL_SPEC["description"],
Expand Down
11 changes: 11 additions & 0 deletions agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from agent.config import load_config
from agent.core.agent_loop import submission_loop
from agent.core import model_switcher
from agent.messaging.gateway import NotificationGateway
from agent.core.session import OpType
from agent.core.tools import ToolRouter
from agent.utils.reliability_checks import check_training_script_save_pattern
Expand Down Expand Up @@ -848,6 +849,8 @@ async def main():
# Start agent loop in background
config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json"
config = load_config(config_path)
notification_gateway = NotificationGateway(config.messaging)
await notification_gateway.start()

# Create tool router with local mode
tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
Expand All @@ -865,6 +868,7 @@ async def main():
hf_token=hf_token,
local_mode=True,
stream=True,
notification_gateway=notification_gateway,
)
)

Expand Down Expand Up @@ -1020,6 +1024,8 @@ def _install_sigint() -> bool:
agent_task.cancel()
# Agent didn't shut down cleanly — close MCP explicitly
await tool_router.__aexit__(None, None, None)
finally:
await notification_gateway.close()

# Now safe to cancel the listener (agent is done emitting events)
listener_task.cancel()
Expand Down Expand Up @@ -1048,6 +1054,8 @@ async def headless_main(
config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json"
config = load_config(config_path)
config.yolo_mode = True # Auto-approve everything in headless mode
notification_gateway = NotificationGateway(config.messaging)
await notification_gateway.start()

if model:
config.model_name = model
Expand Down Expand Up @@ -1076,6 +1084,7 @@ async def headless_main(
hf_token=hf_token,
local_mode=True,
stream=stream,
notification_gateway=notification_gateway,
)
)

Expand Down Expand Up @@ -1214,6 +1223,8 @@ async def headless_main(
except asyncio.TimeoutError:
agent_task.cancel()
await tool_router.__aexit__(None, None, None)
finally:
await notification_gateway.close()


def cli():
Expand Down
15 changes: 15 additions & 0 deletions agent/messaging/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from agent.messaging.gateway import NotificationGateway
from agent.messaging.models import (
MessagingConfig,
NotificationRequest,
NotificationResult,
SUPPORTED_AUTO_EVENT_TYPES,
)

# Names exported by ``from agent.messaging import *``: the gateway plus the
# model types and constant imported above.
__all__ = [
"MessagingConfig",
"NotificationGateway",
"NotificationRequest",
"NotificationResult",
"SUPPORTED_AUTO_EVENT_TYPES",
]
27 changes: 27 additions & 0 deletions agent/messaging/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod

import httpx

from agent.messaging.models import DestinationConfig, NotificationRequest, NotificationResult


class NotificationError(Exception):
    """Notification delivery failure that should not be retried.

    Base class for the delivery errors defined in this module.
    """


class RetryableNotificationError(NotificationError):
    """Transient notification delivery failure; the send can be retried.

    Subclasses ``NotificationError``, so a handler for the base class also
    catches this one.
    """


class NotificationProvider(ABC):
    """Abstract interface implemented by each notification backend."""

    # Declared here without a value; concrete providers are expected to set it.
    provider_name: str

    @abstractmethod
    async def send(
        self,
        client: httpx.AsyncClient,
        destination_name: str,
        destination: DestinationConfig,
        request: NotificationRequest,
    ) -> NotificationResult:
        """Send ``request`` to a single destination and return the result.

        Args:
            client: HTTP client passed in by the caller.
            destination_name: Name identifying the destination.
            destination: Configuration for that destination.
            request: The notification to deliver.
        """
Loading