diff --git a/agent_loop.py b/agent_loop.py index dd8f548..63b9ea3 100644 --- a/agent_loop.py +++ b/agent_loop.py @@ -21,6 +21,7 @@ from modality import CognitiveEvent, CognitiveIntent, ModalityType from pipeline_state import PipelineState from providers import AGENT_TOOLS, InferenceProvider +from schemas.bargein import BargeinContext if TYPE_CHECKING: from channels import BrowserChannel @@ -149,6 +150,9 @@ def __init__( self.draft_queue = DraftQueue() self._speculative_context: list[dict[str, str]] = [] # Context for speculative inference self._human_speaking = False # Whether human is currently speaking + # A2: typed barge-in context prepared before the next turn, consumed by A3 + # for prompt injection. Set by _prepare_bargein_context() on the WS path. + self._pending_bargein: BargeinContext | None = None async def handle_event(self, event: CognitiveEvent) -> None: """Called when a CognitiveEvent arrives from the channel.""" @@ -175,12 +179,46 @@ async def handle_event(self, event: CognitiveEvent) -> None: async def _process(self, event: CognitiveEvent) -> None: """Core: event → provider → tool dispatch.""" - # Context stitching: inject interrupt context from dashboard path - # This closes the barge-in loop — the agent knows what was spoken, - # what was unsaid, and what the user interrupted with. - interrupt_context = self._build_interrupt_context(event.content) - if interrupt_context: - self.conversation.append({"role": "system", "content": interrupt_context}) + # A2: build typed BargeinContext from pipeline_state.last_interrupt (if any) + # and stash on self._pending_bargein. A3 will consume it for prompt injection. + self._prepare_bargein_context(user_text=event.content) + + # MOD3_USE_COGOS_AGENT fork: forward user turn to kernel bus instead of + # calling local provider. Response arrives asynchronously via the + # cogos_agent_bridge → BrowserChannel.broadcast_response_text path. 
+ from cogos_agent_bridge import is_enabled as _cogos_agent_enabled + from cogos_agent_bridge import post_user_message as _post_user_message + + if _cogos_agent_enabled(): + session_id = f"mod3:{self.channel_id or 'unknown'}" + # Fold any pending barge-in context into the forwarded text so the + # kernel cycle sees it. A full structured payload will come in a + # later iteration; for v1 we prepend the terse prompt renderer. + forwarded_text = event.content + pending = self._pending_bargein + if pending is not None: + self._pending_bargein = None + forwarded_text = ( + "[interrupted earlier] " + + pending.format_for_prompt() + + "\n" + + forwarded_text + ) + ok = await _post_user_message(forwarded_text, session_id=session_id) + if not ok and self._channel_ref: + try: + await self._channel_ref.send_response_text( + "[cogos-agent unreachable — check kernel]" + ) + await self._channel_ref.send_response_complete( + metrics={"provider": "cogos-agent", "error": "unreachable"} + ) + except Exception: + pass + # Track the user turn in history so subsequent turns carry it. + self.conversation.append({"role": "user", "content": event.content}) + self._trim_history() + return self.conversation.append({"role": "user", "content": event.content}) self._trim_history() @@ -190,6 +228,7 @@ async def _process(self, event: CognitiveEvent) -> None: # Assemble system prompt with kernel context (afferent path) kernel_ctx = _fetch_kernel_context() system_prompt = _BASE_SYSTEM_PROMPT + kernel_ctx + system_prompt = self._inject_pending_bargein(system_prompt) response = await self.provider.chat( messages=self.conversation, @@ -510,46 +549,51 @@ async def background_validate_drafts(self, latest_user_text: str) -> None: await self._push_draft_queue_state() - def _build_interrupt_context(self, user_text: str) -> str | None: - """Build context stitch from pipeline_state.last_interrupt. 
- - When the user barged in during TTS playback, captures what was - spoken vs unspoken and injects it as structured context for the - next inference call. Consumes the interrupt (clears it). + def _prepare_bargein_context(self, user_text: str | None) -> None: + """Read pipeline_state.last_interrupt and stash a typed BargeinContext. - Returns a context string, or None if no interrupt occurred. + Called at the top of each WS turn. If the previous assistant reply was + interrupted (and the interrupt is still fresh, < 30s), build a + BargeinContext via the A1 schema and store it on ``self._pending_bargein`` + for A3 to pick up during prompt construction. Clears last_interrupt so + the next turn does not re-consume a stale record. """ info = self.pipeline_state.last_interrupt if info is None: - return None + self._pending_bargein = None + return # Only use recent interrupts (within last 30 seconds) if time.time() - info.timestamp > 30: - return None + # Stale — clear and skip. + with self.pipeline_state._lock: + self.pipeline_state._last_interrupt = None + self._pending_bargein = None + return - # Clear the interrupt so we don't re-inject it + # Consume the interrupt so we don't re-inject it on subsequent turns. + # pipeline_state has no public consume helper yet; clear the private + # slot under its lock (matches the pre-existing pattern on this path). 
with self.pipeline_state._lock: self.pipeline_state._last_interrupt = None - # Compute unspoken remainder - unspoken = "" - if info.full_text and info.delivered_text: - if info.full_text.startswith(info.delivered_text): - unspoken = info.full_text[len(info.delivered_text) :].strip() - else: - # Fallback: everything after the delivered percentage - unspoken = info.full_text[len(info.delivered_text) :].strip() + self._pending_bargein = BargeinContext.from_interrupt_info( + info, + source="browser_vad", + user_said=user_text or None, + ) - parts = [] - parts.append("[Barge-in context — your previous response was interrupted]") - parts.append(f'spoken (user heard this): "{info.delivered_text}"') - if unspoken: - parts.append(f'unspoken (user did NOT hear this): "{unspoken}"') - parts.append(f"interrupted_at: {info.spoken_pct * 100:.0f}%") - parts.append(f'user_said: "{user_text}"') - parts.append("Acknowledge what was interrupted and respond to the user's new input.") - - return "\n".join(parts) + def _inject_pending_bargein(self, system_prompt: str) -> str: + """Append the pending BargeinContext (if any) to the system prompt. + + Consumes ``self._pending_bargein`` so it does not leak into subsequent + turns. Returns the prompt unchanged if no barge-in is pending. + """ + pending = self._pending_bargein + if pending is None: + return system_prompt + self._pending_bargein = None + return system_prompt + "\n\n" + pending.format_for_prompt() def _trim_history(self) -> None: """Keep conversation within MAX_HISTORY messages.""" diff --git a/bargein/__init__.py b/bargein/__init__.py new file mode 100644 index 0000000..068e18f --- /dev/null +++ b/bargein/__init__.py @@ -0,0 +1,296 @@ +"""Barge-in subsystem. + +This package owns the first-class barge-in primitive inside mod3. Sources +(SuperWhisper, browser VAD, MCP signals, etc.) register as +``BargeinProvider`` instances; each one emits ``BargeinEvent``s through a +callback. 
The registry below wires those callbacks into the shared consumer +helper ``handle_bargein_event``, which does the same work the legacy +``/tmp/mod3-barge-in.json`` file watcher in ``server.py`` does today: +interrupt in-progress playback via ``pipeline_state.interrupt()`` and log. + +Env-driven config: + MOD3_BARGEIN_PROVIDERS — comma-separated provider names (default: empty). + Example: ``MOD3_BARGEIN_PROVIDERS=superwhisper`` + +Default is empty so users without SuperWhisper installed see no behavior +change from the current setup — they can still run the standalone +``integrations/bargein-producer.py`` script and the legacy file watcher +in ``server.py`` keeps picking up its signals. +""" + +from __future__ import annotations + +import json +import logging +import os +import threading +from typing import Callable + +from pipeline_state import InterruptInfo, PipelineState +from schemas.bargein import BargeinSource + +from .providers.base import BargeinCallback, BargeinEvent, BargeinEventType, BargeinProvider + +log = logging.getLogger("bargein") + +# --------------------------------------------------------------------------- +# Shared consumer helper +# --------------------------------------------------------------------------- +# +# Both the legacy file watcher in server.py and the new provider registry +# call this when a "user is speaking" signal arrives. It is the single +# authoritative "barge-in start" handler. +# +# Returning the InterruptInfo (or None) lets the file watcher continue its +# extra work of writing the interrupt detail back into the signal file — +# cross-process coordination that only matters for the file-based IPC. +# In-process providers ignore the return. + + +def handle_bargein_start( + pipeline_state: PipelineState, + source: str, + metadata: dict | None = None, +) -> InterruptInfo | None: + """Attempt to interrupt in-progress TTS playback because the user began speaking. 
+ + Returns the ``InterruptInfo`` if playback was actually halted, or ``None`` + if nothing was speaking (or another process owns the speech — only the + file watcher can handle that via the cross-process lock). + """ + if not pipeline_state.is_speaking: + return None + info = pipeline_state.interrupt(reason="barge_in") + if info is not None: + log.info( + "Barge-in from %s: paused local playback (%.0f%% delivered)%s", + source, + info.spoken_pct * 100, + f" meta={metadata}" if metadata else "", + ) + return info + + +# --------------------------------------------------------------------------- +# Provider registry +# --------------------------------------------------------------------------- + + +PROVIDER_NAMES = ["superwhisper"] + + +def _build_provider(name: str, on_event: BargeinCallback) -> BargeinProvider | None: + """Instantiate a provider by name. Returns None if unknown or import fails.""" + name = name.strip().lower() + if not name: + return None + if name == "superwhisper": + from .providers.superwhisper import SuperWhisperProvider + + return SuperWhisperProvider(on_event=on_event) + log.warning("Unknown barge-in provider: %r (known: %s)", name, PROVIDER_NAMES) + return None + + +class BargeinRegistry: + """Owns the set of active barge-in providers and routes their events. + + Use: + registry = BargeinRegistry(pipeline_state) + registry.start_from_env() # or registry.register(SomeProvider(...)) + # ... later, on shutdown: + registry.stop_all() + + Tests can install their own dispatch by passing ``on_event`` to + ``register``; registry-level dispatch goes through ``_dispatch`` which + calls both ``handle_bargein_start`` and any extra subscribers. 
+ """ + + def __init__(self, pipeline_state: PipelineState): + self._pipeline_state = pipeline_state + self._providers: list[BargeinProvider] = [] + self._subscribers: list[Callable[[BargeinEvent], None]] = [] + self._lock = threading.Lock() + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + def register(self, provider: BargeinProvider) -> None: + """Register a pre-built provider. Does NOT start it (see ``start_all``).""" + with self._lock: + self._providers.append(provider) + + def subscribe(self, callback: Callable[[BargeinEvent], None]) -> None: + """Register an additional event subscriber (fires after the consumer helper). + + Useful for tests and for future observers (metrics, bus emits, etc.). + """ + with self._lock: + self._subscribers.append(callback) + + def unsubscribe(self, callback: Callable[[BargeinEvent], None]) -> None: + """Remove a previously-registered subscriber. Idempotent.""" + with self._lock: + try: + self._subscribers.remove(callback) + except ValueError: + pass + + # ------------------------------------------------------------------ + # Synchronous wait primitive + # ------------------------------------------------------------------ + + def wait_for_event( + self, + event_type: BargeinEventType, + source: BargeinSource | None = None, + timeout: float | None = None, + ) -> BargeinEvent | None: + """Block until a matching event is dispatched, or until ``timeout``. + + Returns the matching ``BargeinEvent`` on success, or ``None`` on timeout. + Thread-safe; multiple waiters may run concurrently — each receives the + first matching event emitted after its wait began. + + Example:: + + event = registry.wait_for_event("user_speaking_end", timeout=180) + if event is None: + ... 
# timed out + """ + signal = threading.Event() + captured: list[BargeinEvent] = [] + + def _waiter(event: BargeinEvent) -> None: + if event.event_type != event_type: + return + if source is not None and event.source != source: + return + if signal.is_set(): + return + captured.append(event) + signal.set() + + self.subscribe(_waiter) + try: + if signal.wait(timeout): + return captured[0] + return None + finally: + self.unsubscribe(_waiter) + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def start_all(self) -> None: + """Start every registered provider.""" + with self._lock: + providers = list(self._providers) + for p in providers: + p.start() + + def stop_all(self, timeout: float = 2.0) -> None: + """Signal shutdown and (best-effort) join every provider thread.""" + with self._lock: + providers = list(self._providers) + for p in providers: + p.stop(timeout=timeout) + + def start_from_env(self, env_var: str = "MOD3_BARGEIN_PROVIDERS") -> list[str]: + """Instantiate and start providers listed in the env var. Returns started names. + + Providers already present on the registry are kept; we append whatever + the env var asks for that isn't already there. + """ + raw = os.environ.get(env_var, "").strip() + if not raw: + log.info("No barge-in providers configured (set %s=superwhisper to enable)", env_var) + return [] + + requested = [n.strip().lower() for n in raw.split(",") if n.strip()] + already = {type(p).__name__.lower() for p in self._providers} + started: list[str] = [] + for name in requested: + # Match by normalized class name (SuperWhisperProvider -> "superwhisperprovider") + # or the logical name the factory accepts. 
+ if f"{name}provider" in already: + continue + provider = _build_provider(name, self._dispatch) + if provider is None: + continue + self.register(provider) + provider.start() + started.append(name) + log.info("Barge-in providers started: %s", started) + return started + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def _dispatch(self, event: BargeinEvent) -> None: + """Route a provider event through the shared consumer + any subscribers.""" + try: + if event.event_type == "user_speaking_start": + handle_bargein_start( + self._pipeline_state, + source=event.source, + metadata=event.metadata, + ) + # user_speaking_end has no in-process consumer today (the legacy + # file watcher also only reacts to "start"). Subscribers still + # see it so future code can use it. + except Exception: + log.exception("consumer helper raised while handling %s", event) + + with self._lock: + subs = list(self._subscribers) + for cb in subs: + try: + cb(event) + except Exception: + log.exception("barge-in subscriber raised") + + +def make_file_mirror_subscriber(signal_path: str) -> Callable[[BargeinEvent], None]: + """Build a registry subscriber that mirrors events into the legacy signal file. + + The legacy ``/tmp/mod3-barge-in.json`` file is consumed by + out-of-process clients (e.g. ``mcp_shim.py``'s ``await_voice_input``) + that cannot subscribe to the in-process registry. Installing this + subscriber lets in-process providers reach those pollers. + + Writes are atomic (tmp + rename). ``OSError`` is swallowed and logged + at debug level — the file mirror is best-effort and must never break + in-process delivery. 
+ """ + + def _mirror(event: BargeinEvent) -> None: + try: + payload = { + "event": event.event_type, + "source": event.source, + "timestamp": event.timestamp.isoformat(), + "via": "bargein_registry", + **event.metadata, + } + tmp = signal_path + ".tmp" + with open(tmp, "w") as f: + json.dump(payload, f) + os.replace(tmp, signal_path) + except OSError: + log.debug("file mirror write failed", exc_info=True) + + return _mirror + + +__all__ = [ + "BargeinEvent", + "BargeinProvider", + "BargeinRegistry", + "handle_bargein_start", + "make_file_mirror_subscriber", + "PROVIDER_NAMES", +] diff --git a/bargein/providers/__init__.py b/bargein/providers/__init__.py new file mode 100644 index 0000000..664a82b --- /dev/null +++ b/bargein/providers/__init__.py @@ -0,0 +1,14 @@ +"""Barge-in providers. + +Each provider watches a different signal source (SuperWhisper, browser VAD, +hotkey, mic-level VAD, …) and emits ``BargeinEvent`` through a callback. +""" + +from .base import BargeinCallback, BargeinEvent, BargeinEventType, BargeinProvider + +__all__ = [ + "BargeinCallback", + "BargeinEvent", + "BargeinEventType", + "BargeinProvider", +] diff --git a/bargein/providers/base.py b/bargein/providers/base.py new file mode 100644 index 0000000..7c7cd1d --- /dev/null +++ b/bargein/providers/base.py @@ -0,0 +1,132 @@ +"""Barge-in provider base class + event shape. + +A provider watches some external signal source (SuperWhisper recordings, +browser VAD, a push-to-talk hotkey, a mic-level silero VAD, …) and emits +``BargeinEvent``s through an ``on_event`` callback supplied at construction. +The mod3 provider registry wires that callback to the shared consumer helper +(``bargein._handle_bargein_start``), which takes the same action the legacy +``/tmp/mod3-barge-in.json`` file watcher takes today. + +Concurrency: threads. Providers run their own polling loop on a daemon +thread started by ``start()`` and stopped by ``stop()``. This matches the +existing ``_bargein_watcher`` in server.py. 
The SuperWhisper provider's +inner loop does blocking filesystem + sqlite3 reads, so a thread is the +natural fit; an async shape would force every provider to wrap blocking +calls in ``asyncio.to_thread``. +""" + +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Callable, Literal + +from schemas.bargein import BargeinSource + +BargeinEventType = Literal["user_speaking_start", "user_speaking_end"] + + +@dataclass +class BargeinEvent: + """A single emission from a ``BargeinProvider``. + + ``metadata`` carries provider-specific detail (folder names, confidence + scores, etc.) that the consumer may log but must not depend on for + correctness. + """ + + source: BargeinSource + event_type: BargeinEventType + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + metadata: dict = field(default_factory=dict) + + +BargeinCallback = Callable[[BargeinEvent], None] + + +class BargeinProvider(ABC): + """Abstract barge-in provider. + + Subclasses implement ``_run`` as a blocking poll loop. ``start()`` spawns + it on a daemon thread; ``stop()`` sets the stop-event and (best-effort) + joins the thread. + """ + + source: BargeinSource # class-level — subclasses set this + + def __init__(self, on_event: BargeinCallback): + self._on_event = on_event + self._stop = threading.Event() + self._thread: threading.Thread | None = None + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def start(self) -> None: + """Start the provider's background thread. 
Idempotent.""" + if self._thread is not None and self._thread.is_alive(): + return + self._stop.clear() + self._thread = threading.Thread( + target=self._run_guarded, + name=f"bargein-{self.source}", + daemon=True, + ) + self._thread.start() + + def stop(self, timeout: float = 2.0) -> None: + """Signal shutdown and best-effort join the thread.""" + self._stop.set() + thread = self._thread + if thread is not None and thread.is_alive(): + thread.join(timeout=timeout) + self._thread = None + + @property + def is_running(self) -> bool: + return self._thread is not None and self._thread.is_alive() + + # ------------------------------------------------------------------ + # Subclass contract + # ------------------------------------------------------------------ + + @abstractmethod + def _run(self) -> None: + """Provider-specific poll loop. Must return when ``self._stop`` is set.""" + + def _emit( + self, + event_type: BargeinEventType, + metadata: dict | None = None, + ) -> None: + """Emit an event to the registered callback. Swallows callback errors.""" + try: + self._on_event( + BargeinEvent( + source=self.source, + event_type=event_type, + metadata=metadata or {}, + ) + ) + except Exception: + # Provider must not die because the consumer threw. 
+ import logging + + logging.getLogger(f"bargein.{self.source}").exception("barge-in callback raised; continuing") + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + def _run_guarded(self) -> None: + """Wrap ``_run`` so an unexpected raise logs instead of vanishing silently.""" + import logging + + log = logging.getLogger(f"bargein.{self.source}") + try: + self._run() + except Exception: + log.exception("provider loop crashed") diff --git a/bargein/providers/superwhisper.py b/bargein/providers/superwhisper.py new file mode 100644 index 0000000..c95c2e3 --- /dev/null +++ b/bargein/providers/superwhisper.py @@ -0,0 +1,265 @@ +"""SuperWhisper barge-in provider. + +Watches the SuperWhisper recordings directory and its SQLite DB for +recording start/end, emitting ``BargeinEvent``s through the registered +callback. This is the in-process replacement for the standalone +``integrations/bargein-producer.py`` script: same detection logic, but +instead of writing ``/tmp/mod3-barge-in.json`` it calls directly into +mod3's barge-in consumer. + +Detection: + * Start: a new empty timestamped folder appears under the recordings dir. + * End (any of): + - ``output.wav`` or ``meta.json`` appears in that folder, OR + - a matching row appears in ``superwhisper.sqlite`` (structural ground + truth — written only after transcription completes), OR + - the folder disappears (cancellation), OR + - the staleness timeout elapses without the above (crash / sleep). 
+ +Environment variables: + SW_RECORDINGS_DIR — override recordings path + BARGEIN_POLL_MS — poll interval in ms (default: 150) +""" + +from __future__ import annotations + +import logging +import os +import time +from pathlib import Path + +from .base import BargeinProvider + +log = logging.getLogger("bargein.superwhisper") + + +class SuperWhisperProvider(BargeinProvider): + """Barge-in provider backed by SuperWhisper's recordings folder + DB.""" + + source = "superwhisper" + + # Default ~/Documents/superwhisper/recordings, overridable via env. + _DEFAULT_REC_DIR = os.path.expanduser("~/Documents/superwhisper/recordings") + # SuperWhisper SQLite DB — secondary "recording finished" signal. + _SW_DB = os.path.expanduser("~/Library/Application Support/SuperWhisper/database/superwhisper.sqlite") + # 2.5 minutes — recordings can legitimately run 60s+; be generous + # before declaring a stuck folder stale. + _STALE_TIMEOUT = 150 + _STARTUP_FRESH_SECS = 30 + + def __init__(self, on_event, recordings_dir: str | None = None, poll_ms: int | None = None): + super().__init__(on_event) + self.recordings_dir = Path(recordings_dir or os.environ.get("SW_RECORDINGS_DIR", self._DEFAULT_REC_DIR)) + poll_ms = poll_ms if poll_ms is not None else int(os.environ.get("BARGEIN_POLL_MS", "150")) + self._poll_interval = poll_ms / 1000.0 + + # Mutable state (touched only from the provider thread) + self._recording = False + self._active_folder: str | None = None + self._known_folders: set[str] = set() + self._last_dir_mtime: float = 0.0 + + # ------------------------------------------------------------------ + # State transitions (emit events through the callback) + # ------------------------------------------------------------------ + + def _start_recording(self, folder: str) -> None: + if self._recording and self._active_folder == folder: + return + self._recording = True + self._active_folder = folder + log.info("Recording started (folder=%s)", folder) + self._emit("user_speaking_start", 
{"folder": folder}) + + def _end_recording(self, reason: str) -> None: + if not self._recording: + return + folder = self._active_folder + self._recording = False + if folder: + self._known_folders.add(folder) + self._active_folder = None + log.info("Recording finished (folder=%s, reason=%s)", folder, reason) + self._emit("user_speaking_end", {"folder": folder, "reason": reason}) + + # ------------------------------------------------------------------ + # Detection helpers + # ------------------------------------------------------------------ + + @staticmethod + def _is_empty_dir(path: Path) -> bool: + try: + return path.is_dir() and not any(path.iterdir()) + except OSError: + return False + + @staticmethod + def _has_output(path: Path) -> bool: + return (path / "output.wav").exists() or (path / "meta.json").exists() + + @classmethod + def _is_in_db(cls, folder_name: str) -> bool: + """True if SuperWhisper's DB has a ``recording`` row for this folder. + + SuperWhisper writes the row only after transcription completes, so a + hit here is a definitive "recording is done" signal regardless of + filesystem state. + """ + try: + import sqlite3 + + conn = sqlite3.connect(f"file:{cls._SW_DB}?mode=ro", uri=True, timeout=1.0) + cursor = conn.execute( + "SELECT 1 FROM recording WHERE folderName = ? 
LIMIT 1", + (folder_name,), + ) + found = cursor.fetchone() is not None + conn.close() + return found + except Exception: + return False + + def _scan(self) -> None: + """One poll cycle: detect state changes in the recordings dir.""" + rec_dir = self.recordings_dir + + # Fast path: if we're tracking an active recording, check completion signals + if self._recording and self._active_folder: + active_path = rec_dir / self._active_folder + if self._has_output(active_path): + self._end_recording(reason="output_files") + return + if self._is_in_db(self._active_folder): + log.info("DB confirms recording complete (filesystem missed it)") + self._end_recording(reason="db") + return + if not active_path.exists(): + log.warning("Active recording folder disappeared, clearing state") + self._end_recording(reason="folder_gone") + return + # Fall through so we can detect a newer recording superseding this one + + # Stat-then-iterdir: skip the expensive scan if mtime is unchanged + try: + dir_mtime = os.stat(rec_dir).st_mtime + except OSError: + return + if dir_mtime == self._last_dir_mtime: + return + self._last_dir_mtime = dir_mtime + + try: + candidates: list[Path] = [] + for entry in rec_dir.iterdir(): + if entry.is_dir() and entry.name.isdigit() and entry.name not in self._known_folders: + candidates.append(entry) + except OSError: + return + + candidates.sort(key=lambda p: p.name, reverse=True) + for entry in candidates[:5]: + name = entry.name + if self._is_empty_dir(entry): + self._start_recording(name) + return + # Non-empty, previously unseen — completed recording we missed + self._known_folders.add(name) + + def _check_stale(self) -> None: + """Clear stuck recording state if the active folder has been empty too long. + + Before clearing, double-check the DB so legitimately long recordings + aren't thrown away when they finally land. 
+ """ + if not self._recording or not self._active_folder: + return + folder = self.recordings_dir / self._active_folder + try: + ctime = folder.stat().st_birthtime + except (OSError, AttributeError): + return + if time.time() - ctime <= self._STALE_TIMEOUT: + return + + if self._is_in_db(self._active_folder): + log.info("Stale timeout hit but DB confirms completion — ending normally") + self._end_recording(reason="db_after_stale") + elif self._has_output(folder): + log.info("Stale timeout hit but output files present — ending normally") + self._end_recording(reason="output_after_stale") + else: + log.warning( + "Stale recording (>%ds), no DB entry, no output files — clearing as cancelled/crashed", + self._STALE_TIMEOUT, + ) + self._end_recording(reason="stale") + + # ------------------------------------------------------------------ + # Startup scan: handle recordings that existed before we started + # ------------------------------------------------------------------ + + def _startup_scan(self) -> None: + now = time.time() + newest_empty: tuple[str, float] | None = None + try: + for entry in self.recordings_dir.iterdir(): + if not (entry.is_dir() and entry.name.isdigit()): + continue + if self._has_output(entry): + self._known_folders.add(entry.name) + elif self._is_empty_dir(entry): + try: + age = now - entry.stat().st_birthtime + except (OSError, AttributeError): + age = float("inf") + if age < self._STARTUP_FRESH_SECS: + if newest_empty is None or entry.name > newest_empty[0]: + newest_empty = (entry.name, age) + else: + self._known_folders.add(entry.name) + except OSError as e: + log.warning("Startup scan error: %s", e) + + if newest_empty: + log.info("Detected in-progress recording on startup (age=%.1fs)", newest_empty[1]) + self._start_recording(newest_empty[0]) + + # ------------------------------------------------------------------ + # Provider contract + # ------------------------------------------------------------------ + + def _run(self) -> None: + 
rec_dir = self.recordings_dir + if not rec_dir.is_dir(): + log.warning( + "SuperWhisper recordings directory not found: %s (provider inactive)", + rec_dir, + ) + return + + self._startup_scan() + log.info( + "SuperWhisper provider running (poll=%dms, recordings=%s, known=%d)", + self._poll_interval * 1000, + rec_dir, + len(self._known_folders), + ) + + stale_every = max(1, int(2.0 / self._poll_interval)) + stale_counter = 0 + while not self._stop.is_set(): + try: + self._scan() + stale_counter += 1 + if stale_counter >= stale_every: + self._check_stale() + stale_counter = 0 + except Exception: + log.exception("SuperWhisper poll cycle raised; continuing") + # Use Event.wait for responsive shutdown + if self._stop.wait(self._poll_interval): + break + + if self._recording: + # Emit a synthetic end so consumers don't stay in "speaking" forever + self._end_recording(reason="shutdown") diff --git a/bus_bridge.py b/bus_bridge.py new file mode 100644 index 0000000..ad2f014 --- /dev/null +++ b/bus_bridge.py @@ -0,0 +1,305 @@ +"""Kernel-bus SSE subscriber. + +Consumes http://localhost:6931/v1/events/stream and yields parsed bus events. +Reconnects on disconnect with exponential backoff. Tolerates unknown event kinds +per ADR-083 (cycle-trace event contract). + +C3 will consume this to broadcast CycleEvents to dashboard WebSocket clients. + +The kernel (see apps/cogos/bus_stream.go) emits SSE frames of the form: + + data: {"id":"live_*_42","type":"bus.event","timestamp":"...","data":{}}\\n\\n + +Heartbeats arrive as SSE comment lines: + + : keep-alive\\n\\n + +An initial frame of {"type":"connected","bus_id":"*","timestamp":"..."} is +sent on subscribe — we surface that as a BusEnvelope with kind="connected". 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +from dataclasses import dataclass, field +from typing import Any, AsyncIterator, Optional + +import httpx + +logger = logging.getLogger("mod3.bus_bridge") + +# Path appended to ``COGOS_ENDPOINT`` (or the default below) to form the +# kernel SSE stream URL. +KERNEL_BUS_STREAM_PATH = "/v1/events/stream" + +_DEFAULT_KERNEL_BASE = "http://localhost:6931" + + +def default_stream_url() -> str: + """Build the kernel bus stream URL from ``COGOS_ENDPOINT`` (or the default). + + Resolved at call time, not at import time, so tests and runtime config + can override the env var before the bridge is constructed. + """ + base = os.environ.get("COGOS_ENDPOINT", _DEFAULT_KERNEL_BASE).rstrip("/") + return f"{base}{KERNEL_BUS_STREAM_PATH}" + + +# Back-compat module attribute. New code should call ``default_stream_url()`` +# so that ``COGOS_ENDPOINT`` overrides take effect at runtime. +KERNEL_BUS_STREAM_URL = default_stream_url() + + +@dataclass +class BusEnvelope: + """Raw bus-envelope record as received from the kernel SSE stream. + + `raw` is the full outer JSON (the bus.event envelope). `payload` is the + inner CogBlock dict (envelope["data"]) — may be {} for non-bus.event + frames (e.g. the initial "connected" frame). `kind` is the best-effort + event-kind string: preferring payload["kind"] (ADR-083 CycleEvent), then + payload["type"], then envelope["type"]. Consumers MUST tolerate unknown + kinds. + """ + + raw: dict + kind: str + payload: dict = field(default_factory=dict) + ts: Optional[str] = None + event_id: Optional[str] = None + + +def _extract_kind(envelope: dict, payload: dict) -> str: + for src in (payload, envelope): + for key in ("kind", "type"): + val = src.get(key) if isinstance(src, dict) else None + if isinstance(val, str) and val: + return val + return "unknown" + + +class KernelBusSubscriber: + """Async SSE subscriber for the cogos kernel bus stream. 
+ + Usage:: + + sub = KernelBusSubscriber() + async for env in sub.stream(): + handle(env) + + `stream()` yields indefinitely; on any transport error it reconnects + with exponential backoff clamped to [reconnect_min_s, reconnect_max_s]. + Call `close()` (or cancel the consuming task) to stop. + """ + + def __init__( + self, + url: Optional[str] = None, + *, + bus_filter: str = "*", + consumer_id: Optional[str] = None, + reconnect_min_s: float = 1.0, + reconnect_max_s: float = 30.0, + request_timeout_s: float = 10.0, + ) -> None: + # ``COGOS_ENDPOINT`` is honored at construction time when ``url`` is + # not explicitly provided, so the subscriber tracks whatever endpoint + # the rest of the cogos client code is using. + self._url = url or default_stream_url() + self._bus_filter = bus_filter + self._consumer_id = consumer_id + self._min_backoff = reconnect_min_s + self._max_backoff = reconnect_max_s + self._request_timeout = request_timeout_s + self._last_event_id: Optional[str] = None + self._closed = asyncio.Event() + self._client: Optional[httpx.AsyncClient] = None + + async def close(self) -> None: + self._closed.set() + if self._client is not None: + try: + await self._client.aclose() + except Exception: # pragma: no cover - best-effort + pass + self._client = None + + def _build_params(self) -> dict[str, str]: + params: dict[str, str] = {} + if self._bus_filter and self._bus_filter != "*": + params["bus_id"] = self._bus_filter + if self._consumer_id: + params["consumer"] = self._consumer_id + return params + + def _build_headers(self) -> dict[str, str]: + headers = {"Accept": "text/event-stream", "Cache-Control": "no-cache"} + if self._last_event_id: + # Harmless if the kernel doesn't honor it today; future protocol + # bump may use it for resume. 
+ headers["Last-Event-ID"] = self._last_event_id + return headers + + async def stream(self) -> AsyncIterator[BusEnvelope]: + backoff = self._min_backoff + # Generous read timeout — SSE is long-lived with 30s heartbeats. + timeout = httpx.Timeout(self._request_timeout, read=None) + while not self._closed.is_set(): + self._client = httpx.AsyncClient(timeout=timeout) + try: + async with self._client.stream( + "GET", + self._url, + params=self._build_params(), + headers=self._build_headers(), + ) as resp: + if resp.status_code != 200: + logger.info( + "bus-bridge: non-200 from %s: %s — backing off %.1fs", + self._url, + resp.status_code, + backoff, + ) + await self._sleep_or_close(backoff) + backoff = min(self._max_backoff, max(self._min_backoff, backoff * 2)) + continue + + logger.info("bus-bridge: connected to %s", self._url) + backoff = self._min_backoff # reset on successful connect + + async for envelope in self._iter_sse(resp): + yield envelope + except (httpx.HTTPError, asyncio.TimeoutError, ConnectionError) as e: + logger.info( + "bus-bridge: transport error (%s); reconnecting in %.1fs", + e.__class__.__name__, + backoff, + ) + await self._sleep_or_close(backoff) + backoff = min(self._max_backoff, max(self._min_backoff, backoff * 2)) + except asyncio.CancelledError: + await self.close() + raise + finally: + if self._client is not None: + try: + await self._client.aclose() + except Exception: # pragma: no cover + pass + self._client = None + + async def _sleep_or_close(self, seconds: float) -> None: + try: + await asyncio.wait_for(self._closed.wait(), timeout=seconds) + except asyncio.TimeoutError: + return + + async def _iter_sse(self, resp: httpx.Response) -> AsyncIterator[BusEnvelope]: + """Parse the SSE byte stream into BusEnvelope records. + + Minimal SSE parser: we accumulate field lines into the current event, + dispatch on blank-line boundaries, silently skip comment lines + (`: heartbeat`), and honor `data:`, `event:`, `id:` fields. 
+ """ + event_name: Optional[str] = None + data_lines: list[str] = [] + event_id: Optional[str] = None + + async for raw_line in resp.aiter_lines(): + if self._closed.is_set(): + return + # httpx strips the trailing \n but preserves empty lines. + if raw_line == "": + # Dispatch boundary. + if data_lines: + env = self._parse_event(event_name, "\n".join(data_lines), event_id) + if env is not None: + yield env + event_name = None + data_lines = [] + event_id = None + continue + if raw_line.startswith(":"): + # Comment line / heartbeat. + continue + field, _, value = raw_line.partition(":") + if value.startswith(" "): + value = value[1:] + if field == "data": + data_lines.append(value) + elif field == "event": + event_name = value + elif field == "id": + event_id = value + self._last_event_id = value + # retry / unknown fields: ignore + + def _parse_event(self, event_name: Optional[str], data: str, event_id: Optional[str]) -> Optional[BusEnvelope]: + try: + envelope: Any = json.loads(data) + except json.JSONDecodeError: + logger.debug("bus-bridge: non-JSON data frame dropped: %r", data[:200]) + return None + if not isinstance(envelope, dict): + logger.debug("bus-bridge: non-object data frame dropped: %r", envelope) + return None + + inner = envelope.get("data") + payload: dict = inner if isinstance(inner, dict) else {} + kind = _extract_kind(envelope, payload) + ts = envelope.get("timestamp") or payload.get("ts") or payload.get("timestamp") + eid = event_id or envelope.get("id") + if eid and not self._last_event_id: + self._last_event_id = eid + + if kind not in ("state_transition", "tool_dispatch", "assessment", "bus.event", "connected"): + # Tolerate unknowns — just log and forward. 
+ logger.debug("bus-bridge: forwarding unknown event kind=%r", kind) + + return BusEnvelope( + raw=envelope, + kind=kind, + payload=payload, + ts=ts if isinstance(ts, str) else None, + event_id=eid if isinstance(eid, str) else None, + ) + + +# --------------------------------------------------------------------------- +# Manual validation entry point +# --------------------------------------------------------------------------- + + +async def _main() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + sub = KernelBusSubscriber() + print(f"bus-bridge: subscribing to {sub._url} (Ctrl-C to stop)") + try: + async for env in sub.stream(): + print( + json.dumps( + { + "kind": env.kind, + "ts": env.ts, + "id": env.event_id, + "payload_keys": sorted(env.payload.keys())[:12], + } + ) + ) + except (KeyboardInterrupt, asyncio.CancelledError): + pass + finally: + await sub.close() + + +if __name__ == "__main__": + try: + asyncio.run(_main()) + except KeyboardInterrupt: + pass diff --git a/bus_bridge_runner.py b/bus_bridge_runner.py new file mode 100644 index 0000000..aa788c4 --- /dev/null +++ b/bus_bridge_runner.py @@ -0,0 +1,154 @@ +"""Kernel-bus → dashboard bridge runner. + +Consumes `KernelBusSubscriber.stream()` (see `bus_bridge.py`) and fans the +ADR-083 cycle-trace events out to every connected dashboard WebSocket via +`BrowserChannel.broadcast_trace_event()` (see `channels.py`). + +Wiring: + + kernel (bus_cycle_trace) + └─► SSE /v1/events/stream?bus_id=bus_cycle_trace + └─► KernelBusSubscriber.stream() [C1] + └─► run_bridge() filter + forward + └─► BrowserChannel.broadcast_trace_event() [C2] + +The subscriber does its own reconnect with exponential backoff, so a kernel +that is temporarily unreachable does not affect server startup. Disable the +bridge entirely at process boot by setting env `MOD3_BUS_BRIDGE_DISABLED=1`. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Optional + +from bus_bridge import KernelBusSubscriber, default_stream_url +from channels import BrowserChannel + +logger = logging.getLogger("mod3.bus_bridge") + +# ADR-083 kinds the dashboard trace panel cares about. Kept as a module-level +# constant so tests and the lifespan wiring share one definition. +ADR083_KINDS: frozenset[str] = frozenset({"state_transition", "tool_dispatch", "assessment"}) + +# Kernel-side bus name (see apps/cogos/trace_emit.go:const traceBusID). +TRACE_BUS_ID = "bus_cycle_trace" + +# Env flag consulted at startup. +DISABLE_ENV = "MOD3_BUS_BRIDGE_DISABLED" + + +def is_disabled() -> bool: + """True when MOD3_BUS_BRIDGE_DISABLED is set to a truthy value.""" + v = os.environ.get(DISABLE_ENV, "").strip().lower() + return v in ("1", "true", "yes", "on") + + +async def run_bridge( + subscriber: KernelBusSubscriber, + *, + filter_kinds: Optional[set[str]] = None, +) -> None: + """Consume `subscriber` and broadcast cycle-trace events to dashboard clients. + + `filter_kinds`: + - `None`: forward everything (dev mode — useful when inspecting the raw + stream through a dashboard). + - set of kind strings: only forward envelopes whose `BusEnvelope.kind` + is in the set. Unknown kinds are tolerated per ADR-083 — they simply + won't pass this filter. + + `BrowserChannel.broadcast_trace_event()` is thread-safe and non-blocking: + it dispatches each WS send via `run_coroutine_threadsafe`. We call it + directly (no await). + """ + first_event_logged = False + forwarded = 0 + async for env in subscriber.stream(): + if filter_kinds is not None and env.kind not in filter_kinds: + continue + # The "connected" bootstrap frame has an empty payload; skip silently. 
+ if env.kind == "connected": + continue + if not first_event_logged: + logger.info( + "bridge: first event forwarded kind=%s event_id=%s", + env.kind, + env.event_id, + ) + first_event_logged = True + try: + BrowserChannel.broadcast_trace_event(env.payload) + forwarded += 1 + logger.debug( + "bridge: forwarded kind=%s event_id=%s (total=%d)", + env.kind, + env.event_id, + forwarded, + ) + except Exception as exc: # noqa: BLE001 — broadcaster is best-effort + logger.debug("bridge: broadcast failed: %s", exc) + + +async def start_bridge( + app_state: object, + *, + url: Optional[str] = None, + bus_filter: str = TRACE_BUS_ID, + filter_kinds: Optional[set[str]] = frozenset(ADR083_KINDS), +) -> None: + """Construct the subscriber + bridge task and store them on `app_state`. + + Startup is non-blocking: we don't await the task or probe the kernel. + The subscriber's own backoff loop handles reconnects. Logs a disabled + notice and returns cleanly when `MOD3_BUS_BRIDGE_DISABLED` is set. + + ``url`` defaults to ``COGOS_ENDPOINT`` (resolved at call time) so the + subscriber tracks whatever endpoint the rest of the cogos client code is + using. 
+ """ + if is_disabled(): + logger.info("bridge: disabled via %s=1", DISABLE_ENV) + setattr(app_state, "bus_bridge_subscriber", None) + setattr(app_state, "bus_bridge_task", None) + return + + resolved_url = url or default_stream_url() + subscriber = KernelBusSubscriber(url=resolved_url, bus_filter=bus_filter, consumer_id="mod3-dashboard") + task = asyncio.create_task( + run_bridge(subscriber, filter_kinds=set(filter_kinds) if filter_kinds else None), + name="mod3-bus-bridge", + ) + setattr(app_state, "bus_bridge_subscriber", subscriber) + setattr(app_state, "bus_bridge_task", task) + logger.info( + "bridge: started, target=%s bus_id=%s filter=%s", + resolved_url, + bus_filter, + sorted(filter_kinds) if filter_kinds else "*", + ) + + +async def stop_bridge(app_state: object, *, timeout_s: float = 2.0) -> None: + """Gracefully stop the bridge: close subscriber, await task, cancel on timeout.""" + subscriber: Optional[KernelBusSubscriber] = getattr(app_state, "bus_bridge_subscriber", None) + task: Optional[asyncio.Task] = getattr(app_state, "bus_bridge_task", None) + if subscriber is None and task is None: + return + if subscriber is not None: + try: + await subscriber.close() + except Exception: # pragma: no cover - best-effort + pass + if task is not None: + try: + await asyncio.wait_for(task, timeout=timeout_s) + except (asyncio.TimeoutError, asyncio.CancelledError): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): # pragma: no cover + pass + logger.info("bridge: stopped") diff --git a/channels.py b/channels.py index c026e64..702dbc8 100644 --- a/channels.py +++ b/channels.py @@ -8,6 +8,12 @@ T1 (Whisper Base, ~31ms): per-chunk during speech T2 (Whisper Large, ~470ms): on natural pause T3 (Whisper Large, ~470ms): on end-of-utterance (final) + +Server→client WebSocket message types: + audio, response_text, response_complete, interrupted, + partial_transcript, transcript, + trace_event — kernel cycle-trace events (ADR-083), fanned 
out via + BrowserChannel.broadcast_trace_event(). """ from __future__ import annotations @@ -33,6 +39,12 @@ class BrowserChannel: """WebSocket-backed channel for the browser dashboard.""" + # Registry of currently-active dashboard channels. Used by + # broadcast_trace_event() to fan kernel cycle-trace events out to every + # connected dashboard client (see ADR-083). Populated in __init__, + # pruned in _cleanup. + _active_channels: "set[BrowserChannel]" = set() + def __init__( self, ws: WebSocket, @@ -71,6 +83,7 @@ def __init__( modalities=[ModalityType.VOICE, ModalityType.TEXT], deliver=self._deliver_sync, ) + BrowserChannel._active_channels.add(self) logger.info("BrowserChannel registered: %s", self.channel_id) # ------------------------------------------------------------------ @@ -424,6 +437,62 @@ async def send_response_complete(self, metrics: dict | None = None) -> None: except Exception: self._active = False + # ------------------------------------------------------------------ + # Trace event broadcast (kernel cycle-trace → dashboards) + # ------------------------------------------------------------------ + + @classmethod + def broadcast_trace_event(cls, event: dict) -> None: + """Fan a kernel cycle-trace event out to every connected dashboard. + + Per ADR-083, `event` is a pre-parsed CycleEvent dict + (id, ts, source, cycle_id, kind, payload). Wrapped in the + `{"type": "trace_event", "event": ...}` envelope and sent to each + active BrowserChannel's WebSocket. Clients whose send fails are + skipped silently (they will be pruned by their own disconnect path). 
+ """ + frame = {"type": "trace_event", "event": event} + for ch in list(cls._active_channels): + if not ch._active: + continue + try: + asyncio.run_coroutine_threadsafe(ch.ws.send_json(frame), ch._loop) + except Exception as exc: # noqa: BLE001 — disconnected clients are expected + logger.debug("trace_event send failed for %s: %s", ch.channel_id, exc) + + @classmethod + def broadcast_response_text(cls, text: str, session_id: str | None = None) -> None: + """Push an agent-reply text frame to dashboard WebSocket clients. + + Used by the MOD3_USE_COGOS_AGENT response bridge (see + `cogos_agent_bridge.run_response_bridge`). The frame matches the + existing text-response shape emitted by `_deliver_async` and + `send_response_text`: `{"type": "response_text", "text": }`. + + If `session_id` is None (default) the frame is broadcast to every + active dashboard channel. When provided, only channels whose + `channel_id` matches the `mod3:` convention from + `cogos_agent_bridge.post_user_message` receive the frame — this is + how future multi-user routing will land, but for v1 a None + broadcast is the common case (only one dashboard attached). + + Thread-safe: dispatches each WS send via `run_coroutine_threadsafe` + on the channel's own loop, matching `broadcast_trace_event`. 
+ """ + frame = {"type": "response_text", "text": text} + expected_channel = None + if session_id and session_id.startswith("mod3:"): + expected_channel = session_id[len("mod3:"):] + for ch in list(cls._active_channels): + if not ch._active: + continue + if expected_channel and ch.channel_id != expected_channel: + continue + try: + asyncio.run_coroutine_threadsafe(ch.ws.send_json(frame), ch._loop) + except Exception as exc: # noqa: BLE001 — disconnected clients are expected + logger.debug("response_text send failed for %s: %s", ch.channel_id, exc) + # ------------------------------------------------------------------ # Cleanup # ------------------------------------------------------------------ @@ -431,6 +500,7 @@ async def send_response_complete(self, metrics: dict | None = None) -> None: def _cleanup(self) -> None: """Deactivate channel and cancel pending TTS jobs on disconnect.""" self._active = False + BrowserChannel._active_channels.discard(self) ch = self.bus._channels.get(self.channel_id) if ch: ch.active = False diff --git a/cogos_agent_bridge.py b/cogos_agent_bridge.py new file mode 100644 index 0000000..c8616f2 --- /dev/null +++ b/cogos_agent_bridge.py @@ -0,0 +1,292 @@ +"""CogOS kernel agent bridge (MOD3_USE_COGOS_AGENT=1). + +When the env flag is set, Mod³'s agent loop forwards user turns to the +cogos kernel's metabolic cycle instead of the local inference provider: + + browser → WS turn → post_user_message() ─POST /v1/bus/send─► kernel + │ + ▼ + bus_dashboard_chat + │ + ▼ + kernel cycle → `respond` tool + │ + ▼ + bus_dashboard_response + │ + SSE /v1/events/stream + │ + ▼ + KernelBusSubscriber.stream() + │ + ▼ + run_response_bridge() + │ + ▼ + BrowserChannel.broadcast_response_text() + +The subscriber does its own reconnect with exponential backoff (see +`bus_bridge.py`). Disable the whole fork by leaving `MOD3_USE_COGOS_AGENT` +unset (default). 
+ +Note: the kernel's `POST /v1/bus/send` takes a flat `{bus_id, from, to, +message, type}` body — the inner JSON event is serialised into `message` +(matches the pattern used by other cogos producers). +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +from datetime import datetime, timezone +from typing import Optional + +import httpx + +from bus_bridge import KernelBusSubscriber, default_stream_url +from channels import BrowserChannel + +logger = logging.getLogger("mod3.cogos_agent") + +# Bus names — contract with the kernel side (see ADR / c-agent subagent). +CHAT_BUS_ID = "bus_dashboard_chat" +RESPONSE_BUS_ID = "bus_dashboard_response" + + +def _kernel_base() -> str: + """Resolve the kernel base URL from ``COGOS_ENDPOINT`` at call time.""" + return os.environ.get("COGOS_ENDPOINT", "http://localhost:6931").rstrip("/") + + +def _bus_send_url() -> str: + """Build the kernel bus-send URL from the current ``COGOS_ENDPOINT``.""" + return f"{_kernel_base()}/v1/bus/send" + + +# Back-compat module attribute. Use ``_bus_send_url()`` for runtime resolution. +BUS_SEND_URL = _bus_send_url() + +# Env gate. +ENABLE_ENV = "MOD3_USE_COGOS_AGENT" + +_POST_TIMEOUT_S = 5.0 + + +def is_enabled() -> bool: + """True when MOD3_USE_COGOS_AGENT is set to a truthy value.""" + v = os.environ.get(ENABLE_ENV, "").strip().lower() + return v in ("1", "true", "yes", "on") + + +def _now_rfc3339() -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +async def post_user_message(text: str, session_id: str) -> bool: + """POST a user turn to the kernel's `bus_dashboard_chat` bus. + + Returns True if the send succeeded (kernel replied 2xx), False otherwise. + Logs at warning-level on failure but never raises — callers use graceful + degradation (e.g. show an error response frame to the dashboard). 
+ + The kernel's handleBusSend (see apps/cogos/bus_api.go) accepts + `{bus_id, from, to, message, type}` — we JSON-encode the full event dict + into `message` so the kernel's cycle receives the structured payload. + """ + event = { + "type": "user_message", + "text": text, + "session_id": session_id, + "ts": _now_rfc3339(), + } + body = { + "bus_id": CHAT_BUS_ID, + "from": "mod3-dashboard", + "type": "user_message", + "message": json.dumps(event, separators=(",", ":")), + } + url = _bus_send_url() + try: + async with httpx.AsyncClient(timeout=_POST_TIMEOUT_S) as client: + resp = await client.post(url, json=body) + except httpx.HTTPError as exc: + logger.warning("cogos-agent: post to %s failed: %s", url, exc) + return False + if resp.status_code // 100 != 2: + logger.warning( + "cogos-agent: post non-2xx: %d body=%r", + resp.status_code, + resp.text[:200], + ) + return False + logger.info( + "cogos-agent: forwarded user turn to kernel bus (session=%s)", + session_id, + ) + return True + + +def _extract_session_id(payload: dict) -> Optional[str]: + """Extract the ``session_id`` from a kernel reply payload, if present. + + Mirrors :func:`_extract_response_text`: checks the top-level shape and + the JSON-encoded ``content`` wrapper that ``handleBusSend`` produces. + Returns ``None`` for older kernels that don't include a session id, or + for non-session-scoped events. + + The downstream :meth:`BrowserChannel.broadcast_response_text` falls + back to broadcasting when ``session_id`` is ``None``, preserving the + backward-compat behavior. 
+ """ + if not isinstance(payload, dict): + return None + top = payload.get("session_id") + if isinstance(top, str) and top: + return top + content = payload.get("content") + if isinstance(content, str) and content: + try: + inner = json.loads(content) + except (TypeError, ValueError): + return None + if isinstance(inner, dict): + sid = inner.get("session_id") + if isinstance(sid, str) and sid: + return sid + return None + + +def _extract_response_text(payload: dict) -> Optional[str]: + """Dig the assistant reply out of the bus event payload. + + Kernel's `handleBusSend` wraps the sent `message` string inside a + `{"content": ""}` map. On SSE delivery, the envelope's `data` + field is that map. We look first for structured keys (`text`, direct + agent_response shape), then fall through to parsing `content` as JSON. + """ + if not isinstance(payload, dict): + return None + # Direct shape (if an upstream producer wrote the event dict at the top level). + for key in ("text", "reply", "response"): + val = payload.get(key) + if isinstance(val, str) and val: + return val + # Standard bus envelope: payload = {"content": ""} + content = payload.get("content") + if isinstance(content, str) and content: + try: + inner = json.loads(content) + except (TypeError, ValueError): + # Free-form string — treat the whole thing as the reply. + return content + if isinstance(inner, dict): + for key in ("text", "reply", "response"): + val = inner.get(key) + if isinstance(val, str) and val: + return val + elif isinstance(inner, str) and inner: + return inner + return None + + +async def run_response_bridge(subscriber: KernelBusSubscriber) -> None: + """Consume `subscriber` and broadcast agent replies to dashboard clients. + + `BrowserChannel.broadcast_response_text()` is thread-safe via + `run_coroutine_threadsafe`, matching the existing trace-event pattern. + Malformed events (no recoverable text) are logged at debug and skipped. 
+ """ + first_event_logged = False + forwarded = 0 + async for env in subscriber.stream(): + if env.kind == "connected": + continue + text = _extract_response_text(env.payload) + if not text: + logger.debug( + "cogos-agent: skip event with no text kind=%s id=%s", + env.kind, + env.event_id, + ) + continue + if not first_event_logged: + logger.info( + "cogos-agent: first response forwarded kind=%s event_id=%s", + env.kind, + env.event_id, + ) + first_event_logged = True + session_id = _extract_session_id(env.payload) + try: + BrowserChannel.broadcast_response_text(text, session_id=session_id) + forwarded += 1 + logger.debug( + "cogos-agent: forwarded response event_id=%s session=%s (total=%d)", + env.event_id, + session_id, + forwarded, + ) + except Exception as exc: # noqa: BLE001 — best-effort fan-out + logger.debug("cogos-agent: broadcast failed: %s", exc) + + +async def start_response_bridge( + app_state: object, + *, + url: Optional[str] = None, +) -> None: + """Construct the response subscriber + bridge task and store on `app_state`. + + No-op (logs once) when `MOD3_USE_COGOS_AGENT` is unset. + + ``url`` defaults to ``COGOS_ENDPOINT`` (resolved at call time) so the + subscriber tracks the same kernel endpoint as ``post_user_message``. 
+ """ + if not is_enabled(): + logger.debug("cogos-agent: response bridge disabled (%s unset)", ENABLE_ENV) + setattr(app_state, "cogos_agent_subscriber", None) + setattr(app_state, "cogos_agent_task", None) + return + + resolved_url = url or default_stream_url() + subscriber = KernelBusSubscriber( + url=resolved_url, + bus_filter=RESPONSE_BUS_ID, + consumer_id="mod3-dashboard-agent", + ) + task = asyncio.create_task( + run_response_bridge(subscriber), + name="mod3-cogos-agent-bridge", + ) + setattr(app_state, "cogos_agent_subscriber", subscriber) + setattr(app_state, "cogos_agent_task", task) + logger.info( + "cogos-agent: response bridge started, target=%s bus_id=%s", + resolved_url, + RESPONSE_BUS_ID, + ) + + +async def stop_response_bridge(app_state: object, *, timeout_s: float = 2.0) -> None: + """Gracefully stop the response bridge: close subscriber, await task, cancel on timeout.""" + subscriber: Optional[KernelBusSubscriber] = getattr(app_state, "cogos_agent_subscriber", None) + task: Optional[asyncio.Task] = getattr(app_state, "cogos_agent_task", None) + if subscriber is None and task is None: + return + if subscriber is not None: + try: + await subscriber.close() + except Exception: # pragma: no cover - best-effort + pass + if task is not None: + try: + await asyncio.wait_for(task, timeout=timeout_s) + except (asyncio.TimeoutError, asyncio.CancelledError): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): # pragma: no cover + pass + logger.info("cogos-agent: response bridge stopped") diff --git a/dashboard/index.html b/dashboard/index.html index 914a2b0..ae166c5 100644 --- a/dashboard/index.html +++ b/dashboard/index.html @@ -171,6 +171,49 @@ 50% { opacity: 0.3; } } + /* Cycle Trace panel (bottom drawer) */ + #trace-panel { + position: fixed; left: 0; right: 0; bottom: 0; + background: var(--surface); border-top: 1px solid var(--border); + font-family: ui-monospace, SFMono-Regular, Menlo, monospace; + z-index: 20; 
max-height: 40vh; display: flex; flex-direction: column;
  transition: max-height 0.2s ease;
}
#trace-panel.collapsed { max-height: 32px; }
#trace-panel .trace-header {
  display: flex; align-items: center; gap: 8px;
  padding: 6px 16px; border-bottom: 1px solid var(--border);
  background: var(--bg); cursor: pointer; user-select: none;
  font-size: 0.75rem; color: var(--muted); text-transform: uppercase;
  letter-spacing: 0.5px; flex-shrink: 0; height: 32px;
}
#trace-panel .trace-header .trace-title { font-weight: 600; }
#trace-panel .trace-header .trace-toggle {
  margin-left: auto; font-size: 0.7rem; color: var(--muted);
}
#trace-panel.collapsed #trace-entries { display: none; }
#trace-entries {
  overflow-y: auto; padding: 4px 0; flex: 1; min-height: 0;
  font-size: 0.75rem; line-height: 1.4;
}
.trace-entry {
  display: flex; gap: 8px; align-items: baseline;
  padding: 2px 16px; border-bottom: 1px solid rgba(48,54,61,0.3);
  white-space: nowrap; overflow: hidden;
}
.trace-entry:hover { background: rgba(88,166,255,0.05); }
.trace-time { color: var(--muted); flex-shrink: 0; font-variant-numeric: tabular-nums; }
.trace-source { color: var(--muted); flex-shrink: 0; font-size: 0.7rem; }
.trace-kind {
  flex-shrink: 0; font-size: 0.65rem; padding: 1px 6px;
  border: 1px solid var(--border); border-radius: 3px;
  text-transform: uppercase; letter-spacing: 0.5px;
}
.trace-summary { color: var(--text); overflow: hidden; text-overflow: ellipsis; }

/* Leave room at the bottom so the drawer doesn't cover the input */
body { padding-bottom: 32px; }

/* Responsive */
@media (max-width: 700px) {
  .main { padding: 12px 16px; }

/* @@ -300,6 +343,15 @@ — markup additions follow (tags stripped in this capture) */

Mod³

For voice, use headphones. Speak naturally — the system detects when you start and stop. Speak during playback to interrupt.
+ + + +