Skip to content
221 changes: 221 additions & 0 deletions audio_subscribers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
"""Per-session audio subscriber registry (Wave 4.3).

A separate, tiny module so the routing surface stays independent of the
session registry (which is already load-bearing for ADR-082 Phase 1). The
dashboard WebSocket lands here; ``mcp_shim._play_wav_bytes`` queries this
registry before falling back to sounddevice; the kernel queries the
``/v1/sessions/{id}/subscribers`` HTTP endpoint before spawning afplay.

Thread-safety: registration happens in FastAPI's event loop thread (WS
handler), lookup and emit happen from the MCP shim playback thread. A
single RLock around the dict is sufficient — the set-per-session is small
(usually 1 dashboard) and contention is effectively zero.

Delivery semantics: the emit paths are best-effort. A WebSocket that has
already died just drops the frame — the kernel fallback ``afplay``
never happened because the check went through before we spawned it, but
the session will miss this turn's audio. That's acceptable: the dashboard
polls and reconnects; the user sees silence for one turn instead of
nothing at all.
"""

from __future__ import annotations

import asyncio
import logging
import threading
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover — type-check-only import
from fastapi import WebSocket

logger = logging.getLogger("mod3.audio_subscribers")


@dataclass
class _Subscriber:
"""A single WebSocket subscription for a session."""

ws: "WebSocket"
# The event loop the WebSocket was accepted on. Emit calls from other
# threads need to run_coroutine_threadsafe onto this loop.
loop: asyncio.AbstractEventLoop
# Monotonic sequence for logging / frame ordering. Opaque to callers.
seq: int = 0


@dataclass
class _SessionBucket:
"""Subscribers currently attached to a session_id."""

subscribers: list[_Subscriber] = field(default_factory=list)


class AudioSubscriberRegistry:
"""Thread-safe session_id → active WebSocket subscribers mapping.

Callers never reach into the bucket lists directly — they go through
register / unregister / has_subscribers / emit_wav.
"""

def __init__(self) -> None:
self._lock = threading.RLock()
self._buckets: dict[str, _SessionBucket] = {}
self._frame_seq = 0

# ------------------------------------------------------------------
# Registration
# ------------------------------------------------------------------

def register(self, session_id: str, ws: "WebSocket", loop: asyncio.AbstractEventLoop) -> _Subscriber:
"""Attach ``ws`` to ``session_id``. Caller must have already
``accept()``-ed the WebSocket.
"""
sub = _Subscriber(ws=ws, loop=loop)
with self._lock:
bucket = self._buckets.setdefault(session_id, _SessionBucket())
bucket.subscribers.append(sub)
count = len(bucket.subscribers)
logger.info("audio subscriber attached: session=%s total=%d", session_id, count)
return sub

def unregister(self, session_id: str, sub: _Subscriber) -> None:
"""Detach. Idempotent — double-unregister is a no-op."""
with self._lock:
bucket = self._buckets.get(session_id)
if bucket is None:
return
try:
bucket.subscribers.remove(sub)
except ValueError:
return
remaining = len(bucket.subscribers)
if not bucket.subscribers:
# Drop empty buckets so the subscribed-check stays fast.
self._buckets.pop(session_id, None)
logger.info("audio subscriber detached: session=%s remaining=%d", session_id, remaining)

# ------------------------------------------------------------------
# Inspection
# ------------------------------------------------------------------

def has_subscribers(self, session_id: str) -> bool:
with self._lock:
bucket = self._buckets.get(session_id)
return bool(bucket and bucket.subscribers)

def count(self, session_id: str) -> int:
with self._lock:
bucket = self._buckets.get(session_id)
return len(bucket.subscribers) if bucket else 0

def snapshot(self) -> dict[str, int]:
"""session_id → subscriber count. For diagnostics only."""
with self._lock:
return {sid: len(b.subscribers) for sid, b in self._buckets.items()}

# ------------------------------------------------------------------
# Emit
# ------------------------------------------------------------------

def emit_wav(
self,
session_id: str,
wav_bytes: bytes,
*,
job_id: str | None = None,
duration_sec: float | None = None,
sample_rate: int | None = None,
) -> int:
"""Push a whole WAV blob to every subscriber of ``session_id``.

Returns the number of subscribers the frame was enqueued for. Each
send is fire-and-forget via ``run_coroutine_threadsafe`` — the caller
doesn't block on socket I/O, which matches the existing
``BrowserChannel.broadcast_trace_event`` pattern.

The wire format is a single binary WebSocket frame containing the
raw WAV (RIFF / WAVE) bytes. Browsers can decode this directly via
AudioContext.decodeAudioData. A preceding small JSON control frame
announces the incoming audio with session_id + job_id + duration
so the dashboard can correlate the blob to a synthesize call.
"""
with self._lock:
bucket = self._buckets.get(session_id)
if not bucket or not bucket.subscribers:
return 0
targets = list(bucket.subscribers)
self._frame_seq += 1
seq = self._frame_seq

header = {
"type": "audio_header",
"session_id": session_id,
"job_id": job_id,
"duration_sec": duration_sec,
"sample_rate": sample_rate,
"bytes": len(wav_bytes),
"format": "wav",
"seq": seq,
}

delivered = 0
for sub in targets:
try:
asyncio.run_coroutine_threadsafe(_send_audio_frame(sub.ws, header, wav_bytes), sub.loop)
delivered += 1
except Exception as exc: # noqa: BLE001 — disconnected subscribers are expected
logger.debug("emit_wav: scheduling failed for %s: %s", session_id, exc)
logger.info(
"emit_wav: session=%s bytes=%d delivered_to=%d seq=%d",
session_id,
len(wav_bytes),
delivered,
seq,
)
return delivered


async def _send_audio_frame(ws: "WebSocket", header: dict, wav_bytes: bytes) -> None:
"""Send header JSON + binary WAV over a single WebSocket.

Split into a module-level coroutine so ``run_coroutine_threadsafe``
returns a Future the caller can ignore — the pair is sent in order on
the socket's own coroutine context.
"""
try:
await ws.send_json(header)
await ws.send_bytes(wav_bytes)
except Exception as exc: # noqa: BLE001 — disconnect mid-send is expected
logger.debug("audio frame send failed: %s", exc)


# ---------------------------------------------------------------------------
# Process-global default registry — shared between http_api and mcp_shim.
# ---------------------------------------------------------------------------

_default_registry: AudioSubscriberRegistry | None = None
_default_registry_lock = threading.Lock()


def get_default_audio_subscribers() -> AudioSubscriberRegistry:
global _default_registry
with _default_registry_lock:
if _default_registry is None:
_default_registry = AudioSubscriberRegistry()
return _default_registry


def reset_default_audio_subscribers() -> None:
"""For tests — drop the module-level registry."""
global _default_registry
with _default_registry_lock:
_default_registry = None


__all__ = [
"AudioSubscriberRegistry",
"get_default_audio_subscribers",
"reset_default_audio_subscribers",
]
39 changes: 39 additions & 0 deletions channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
partial_transcript, transcript,
trace_event — kernel cycle-trace events (ADR-083), fanned out via
BrowserChannel.broadcast_trace_event().

The MOD3_USE_COGOS_AGENT kernel-bridged path emits response_text AND
response_complete via BrowserChannel.broadcast_response_{text,complete}
so the dashboard UI's turn-done signal fires on every turn, matching the
local-inference path's behavior.
"""

from __future__ import annotations
Expand Down Expand Up @@ -493,6 +498,40 @@ def broadcast_response_text(cls, text: str, session_id: str | None = None) -> No
except Exception as exc: # noqa: BLE001 — disconnected clients are expected
logger.debug("response_text send failed for %s: %s", ch.channel_id, exc)

@classmethod
def broadcast_response_complete(
cls,
metrics: dict | None = None,
session_id: str | None = None,
) -> None:
"""Push a `response_complete` frame to dashboard WebSocket clients.

Companion to :meth:`broadcast_response_text`: the MOD3_USE_COGOS_AGENT
response bridge emits exactly one complete-frame per kernel
`agent_response` event so the dashboard UI's per-turn `isResponding`
state gets cleared (otherwise the chat panel spinner hangs forever).

Routing and threading match `broadcast_response_text` 1:1 — pass the
same `session_id` so the completion frame lands on the same channel
that received the text frames for this turn. `metrics` follows the
local-path convention from `agent_loop._process` (`{"llm_ms": ...,
"provider": ...}`); the kernel path populates it with
`{"provider": "cogos-agent", ...}`.
"""
frame = {"type": "response_complete", "metrics": metrics or {}}
expected_channel = None
if session_id and session_id.startswith("mod3:"):
expected_channel = session_id[len("mod3:") :]
for ch in list(cls._active_channels):
if not ch._active:
continue
if expected_channel and ch.channel_id != expected_channel:
continue
try:
asyncio.run_coroutine_threadsafe(ch.ws.send_json(frame), ch._loop)
except Exception as exc: # noqa: BLE001 — disconnected clients are expected
logger.debug("response_complete send failed for %s: %s", ch.channel_id, exc)

# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
Expand Down
31 changes: 28 additions & 3 deletions cogos_agent_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,24 @@ def _extract_response_text(payload: dict) -> Optional[str]:
async def run_response_bridge(subscriber: KernelBusSubscriber) -> None:
"""Consume `subscriber` and broadcast agent replies to dashboard clients.

`BrowserChannel.broadcast_response_text()` is thread-safe via
`run_coroutine_threadsafe`, matching the existing trace-event pattern.
Malformed events (no recoverable text) are logged at debug and skipped.
Each kernel `agent_response` event on `bus_dashboard_response` is a
complete per-turn reply (see `apps/cogos/agent_tools_respond.go` — the
`respond` tool is documented as "call at most once per user turn" and
the auto-fallback publishes once if the model skipped the tool call).
We therefore emit two dashboard frames per kernel event:

* ``broadcast_response_text`` — the reply body (chat panel render)
* ``broadcast_response_complete`` — the turn-done signal so the UI's
per-turn spinner clears. Without this, the dashboard hangs
awaiting completion because the kernel path never reaches the
``send_response_complete`` call that the local-inference branch
emits at ``agent_loop._process`` ~L300.

``BrowserChannel.broadcast_response_{text,complete}()`` are thread-safe
via ``run_coroutine_threadsafe``, matching the existing trace-event
pattern. Malformed events (no recoverable text) are logged at debug
and skipped — we do NOT emit a completion frame for skipped events
(keeps the 1:1 pairing with what the UI actually rendered).
"""
first_event_logged = False
forwarded = 0
Expand All @@ -221,6 +236,16 @@ async def run_response_bridge(subscriber: KernelBusSubscriber) -> None:
session_id = _extract_session_id(env.payload)
try:
BrowserChannel.broadcast_response_text(text, session_id=session_id)
# Pair the text frame with a completion frame so the dashboard's
# per-turn "awaiting response" state clears. Kernel emits exactly
# one agent_response per user turn, so one complete per event is
# the correct cardinality.
metrics: dict = {"provider": "cogos-agent"}
if env.event_id:
metrics["event_id"] = env.event_id
if env.ts:
metrics["kernel_ts"] = env.ts
BrowserChannel.broadcast_response_complete(metrics, session_id=session_id)
forwarded += 1
logger.debug(
"cogos-agent: forwarded response event_id=%s session=%s (total=%d)",
Expand Down
Loading
Loading