From acad6f1be49b2ef008eafe81fd44c42e558d2642 Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Thu, 23 Apr 2026 11:46:51 -0400 Subject: [PATCH 1/8] feat(sessions): ADR-082 Phase 1 - session-aware communication bus Introduces SessionRegistry + GlobalSerializer + live output-device resolution so multiple concurrent agents/users can share one Mod3 instance without colliding on voice, queue, or speaker. - session_registry.py: SessionChannel, voice-pool greedy allocation, per-session queues, round-robin/priority/fifo-global policies, live device re-query per playback (ADR-082 2026-04-22 amendment - no caching, macOS CoreAudio default tracked live). - http_api.py: POST /v1/sessions/register, POST /v1/sessions/{id}/deregister, GET /v1/sessions, GET /v1/sessions/{id}. Synthesize honors the session's assigned voice when unspecified. - server.py + mcp_shim.py: mirrored MCP tools (register_session, deregister_session, list_sessions) so stdio MCP callers get the same surface. - Backward-compat: legacy callers without a session_id route to an implicit "default" session. Out of scope (later ADR phases): input routing, barge-in state machine, native input provider. 
--- http_api.py | 161 ++++++ mcp_shim.py | 323 +++++++++++- server.py | 188 ++++++- session_registry.py | 895 +++++++++++++++++++++++++++++++++ tests/test_session_registry.py | 410 +++++++++++++++ 5 files changed, 1951 insertions(+), 26 deletions(-) create mode 100644 session_registry.py create mode 100644 tests/test_session_registry.py diff --git a/http_api.py b/http_api.py index 6b7260c..3492897 100644 --- a/http_api.py +++ b/http_api.py @@ -39,6 +39,10 @@ from modality import EncodedOutput, ModalityType from modules.text import TextModule from modules.voice import VoiceModule +from session_registry import ( + get_default_registry, + resolve_output_device, +) from vad import detect_speech_file, is_hallucination from vad import is_model_loaded as vad_loaded @@ -244,6 +248,11 @@ class SynthesizeRequest(BaseModel): speed: float = Field(default=1.25) emotion: float = Field(default=0.5) format: str = Field(default="wav", pattern="^(wav|pcm)$") + # ADR-082 Phase 1: optional session routing. When present, the + # session's assigned_voice overrides ``voice`` (unless an explicit + # non-default was passed), and the session is advanced in the global + # serializer's round-robin. + session_id: str | None = Field(default=None) class SpeechRequest(BaseModel): @@ -254,6 +263,9 @@ class SpeechRequest(BaseModel): voice: str = Field(default="af_heart") response_format: str = Field(default="mp3") speed: float = Field(default=1.0) + # ADR-082 Phase 1 extension — not part of the OpenAI schema but harmless + # to accept. When absent, behavior is identical to before Phase 1. 
+ session_id: str | None = Field(default=None) class ShutdownRequest(BaseModel): @@ -263,6 +275,17 @@ class ShutdownRequest(BaseModel): reason: str = Field(default="shutdown-requested") +class SessionRegisterRequest(BaseModel): + """Register a session with the Mod3 communication bus (ADR-082).""" + + session_id: str + participant_id: str + participant_type: str = Field(default="agent") + preferred_voice: str | None = Field(default=None) + preferred_output_device: str = Field(default="system-default") + priority: int = Field(default=0) + + # --------------------------------------------------------------------------- # Shutdown middleware — reject new requests once shutdown is initiated # --------------------------------------------------------------------------- @@ -290,6 +313,39 @@ def synthesize(req: SynthesizeRequest): import numpy as np t_request = time.perf_counter() + + # ADR-082 Phase 1: session routing. If the request names a session, we + # honor the session's assigned voice (unless the caller explicitly + # picked a non-default voice) and account the job against the session's + # queue + serializer so multi-session callers can see round-robin. + session_id = req.session_id + session_payload: dict | None = None + if session_id: + registry = get_default_registry() + session = registry.get(session_id) + if session is None: + return JSONResponse( + status_code=404, + content={ + "error": f"session '{session_id}' is not registered — POST /v1/sessions/register first", + }, + ) + if req.voice == "bm_lewis" and session.assigned_voice != "bm_lewis": + req.voice = session.assigned_voice + # Register the submission with the serializer for accounting only. + # The synthesize endpoint is non-blocking on the audio side (we + # return bytes synchronously), so we do not run the registry's + # dispatcher here — we just record the submission. 
+ try: + registry.submit(session_id, {"type": "synthesize", "text": req.text[:200]}) + except Exception as exc: # noqa: BLE001 + logger.debug("session submit accounting failed: %s", exc) + session_payload = { + "session_id": session.session_id, + "assigned_voice": session.assigned_voice, + "preferred_output_device": session.preferred_output_device, + } + job_id = _record_job( { "type": "synthesize", @@ -301,6 +357,7 @@ def synthesize(req: SynthesizeRequest): "emotion": req.emotion, "format": req.format, "engine": None, + "session_id": session_id, "timeline": [{"event": "request_received", "t": 0.0}], } ) @@ -389,6 +446,8 @@ def synthesize(req: SynthesizeRequest): "X-Mod3-RTF": f"{duration / gen_time:.2f}" if gen_time > 0 else "0", "X-Mod3-Chunks": str(len(chunk_metrics)), } + if session_payload is not None: + headers["X-Mod3-Session-Id"] = session_payload["session_id"] return Response(content=audio_bytes, media_type=media_type, headers=headers) @@ -400,6 +459,30 @@ def audio_speech(req: SpeechRequest): t_request = time.perf_counter() + # ADR-082 Phase 1: optional session routing. Same semantics as + # /v1/synthesize — the session's assigned voice overrides ``voice`` when + # the caller passed the default, and the submission is accounted against + # the session's queue. + session_id = req.session_id + if session_id: + registry = get_default_registry() + session = registry.get(session_id) + if session is None: + return JSONResponse( + status_code=404, + content={ + "error": f"session '{session_id}' is not registered — POST /v1/sessions/register first", + }, + ) + # OpenAI default is af_heart; if the caller left it at the default, + # prefer the session's voice. 
+ if req.voice == "af_heart" and session.assigned_voice != "af_heart": + req.voice = session.assigned_voice + try: + registry.submit(session_id, {"type": "audio_speech", "text": req.input[:200]}) + except Exception as exc: # noqa: BLE001 + logger.debug("session submit accounting failed: %s", exc) + voice = req.voice try: voice = _resolve_voice_via_bus(voice) @@ -414,6 +497,7 @@ def audio_speech(req: SpeechRequest): "text": req.input[:200], "voice": voice, "speed": req.speed, + "session_id": session_id, "timeline": [{"event": "request_received", "t": 0.0}], } ) @@ -464,6 +548,8 @@ def audio_speech(req: SpeechRequest): "X-Mod3-Gen-Time-Sec": f"{gen_time:.3f}", "X-Mod3-Total-Time-Sec": f"{total_time:.3f}", } + if session_id: + headers["X-Mod3-Session-Id"] = session_id return Response(content=audio_bytes, media_type="audio/wav", headers=headers) @@ -650,6 +736,81 @@ def stop_speech(job_id: str = ""): return JSONResponse(status_code=503, content={"error": "Speech queue not available in HTTP-only mode"}) +# --------------------------------------------------------------------------- +# Sessions — ADR-082 Phase 1 +# --------------------------------------------------------------------------- + + +@app.post("/v1/sessions/register") +def session_register(req: SessionRegisterRequest): + """Register a session with the Mod3 communication bus (ADR-082). + + Body: + { + "session_id": "...", + "participant_id": "cog" | "sandy" | "slowbro" | ..., + "participant_type": "agent" | "user", + "preferred_voice": "bm_lewis" | ... | null, + "preferred_output_device": "system-default" | "" + } + + Returns the SessionChannel with a live-resolved output_device. 
+ """ + registry = get_default_registry() + try: + result = registry.register( + session_id=req.session_id, + participant_id=req.participant_id, + participant_type=req.participant_type, + preferred_voice=req.preferred_voice, + preferred_output_device=req.preferred_output_device or "system-default", + priority=req.priority, + ) + except Exception as exc: # noqa: BLE001 — surface the error verbatim + return JSONResponse(status_code=400, content={"error": str(exc)}) + + payload = result.session.to_dict(device_resolver=resolve_output_device) + payload["created"] = result.created + # Top-level live device snapshot so the caller does not have to + # round-trip; the nested one is available for debugging. + payload["output_device"] = registry.resolve_device(result.session.session_id).to_dict() + return payload + + +@app.post("/v1/sessions/{session_id}/deregister") +def session_deregister(session_id: str): + """Drop a session — drains/cancels pending jobs, returns voice to pool.""" + registry = get_default_registry() + result = registry.deregister(session_id) + if result.get("status") == "not_found": + return JSONResponse(status_code=404, content=result) + return result + + +@app.get("/v1/sessions") +def session_list(): + """List all registered sessions plus a serializer snapshot.""" + registry = get_default_registry() + return { + "sessions": registry.list_serialized(), + "serializer": registry.serializer.snapshot(), + "voice_pool": registry.voice_pool(), + "voice_holders": registry.voice_holder_snapshot(), + } + + +@app.get("/v1/sessions/{session_id}") +def session_get(session_id: str): + """Get a single session's current state (with live device resolution).""" + registry = get_default_registry() + session = registry.get(session_id) + if session is None: + return JSONResponse(status_code=404, content={"error": f"session '{session_id}' not found"}) + payload = session.to_dict(device_resolver=resolve_output_device) + payload["output_device"] = 
registry.resolve_device(session_id).to_dict() + return payload + + @app.get("/health") def health(): """Health check — standardized CogOS service format.""" diff --git a/mcp_shim.py b/mcp_shim.py index 3f6f87b..ca8a63b 100644 --- a/mcp_shim.py +++ b/mcp_shim.py @@ -42,6 +42,14 @@ _current_sd_stream = None _playback_interrupt = threading.Event() +# ADR-082 Phase 1: local session state. Populated by tool_register_session +# so tool_speak can live-resolve the session's preferred output device +# before each playback. The HTTP service owns the canonical registry; this +# is a thin cache so the shim does not have to re-query per play. +_shim_sessions: dict[str, dict[str, Any]] = {} +_shim_sessions_lock = threading.Lock() +_active_session_id: str | None = None + # Job tracking (lightweight — just for speak/stop/status) _jobs: OrderedDict = OrderedDict() _jobs_lock = threading.Lock() @@ -86,7 +94,92 @@ def _http_request(method: str, path: str, body: dict | None = None, timeout: flo return 0, {"error": f"Request failed: {e}"} -def _play_wav_bytes(wav_bytes: bytes, job_id: str): +def _resolve_device_live(preferred: str | None) -> tuple[Any, dict[str, Any]]: + """Resolve an output device live, per the ADR-082 2026-04-22 amendment. + + ``preferred`` mirrors the SessionChannel field: "system-default" re-queries + the OS default immediately; a named device is resolved by substring + against the current device list; a numeric string picks by index; and + ``None`` falls back to the legacy ``_output_device`` module global. + + Returns ``(device_arg, diag)`` where ``device_arg`` is ready to pass to + ``sd.play(device=...)`` and ``diag`` is a dict describing how we resolved + for logging / responses. 
+ """ + try: + import sounddevice as sd + except ImportError: + return None, {"preferred": preferred, "index": None, "reason": "sounddevice not installed"} + + if preferred is None: + return _output_device, { + "preferred": None, + "index": _output_device if isinstance(_output_device, int) else None, + "reason": "legacy module default", + } + + pref = preferred.strip() if isinstance(preferred, str) else "system-default" + if not pref or pref.lower() in ("system-default", "default"): + # Live re-query. sd.default.device is (input, output) — we want output. + try: + devices = sd.query_devices() + default_tuple = sd.default.device + idx = default_tuple[1] if isinstance(default_tuple, (tuple, list)) else None + if isinstance(idx, int) and 0 <= idx < len(devices): + return idx, { + "preferred": pref, + "index": idx, + "name": devices[idx].get("name", ""), + "reason": "OS default (live-queried)", + } + except Exception as exc: # noqa: BLE001 + return None, {"preferred": pref, "index": None, "reason": f"default query failed: {exc}"} + return None, {"preferred": pref, "index": None, "reason": "OS default unknown"} + + # Named / indexed device + try: + devices = sd.query_devices() + except Exception as exc: # noqa: BLE001 + return None, {"preferred": pref, "index": None, "reason": f"query failed: {exc}"} + + if pref.isdigit(): + i = int(pref) + if 0 <= i < len(devices) and devices[i].get("max_output_channels", 0) > 0: + return i, { + "preferred": pref, + "index": i, + "name": devices[i].get("name", ""), + "reason": "index match", + } + + low = pref.lower() + for i, d in enumerate(devices): + if d.get("max_output_channels", 0) > 0 and low in str(d.get("name", "")).lower(): + return i, { + "preferred": pref, + "index": i, + "name": d.get("name", ""), + "reason": "name match", + } + + # Fall back to system default — identity just changed devices. 
+ try: + default_tuple = sd.default.device + idx = default_tuple[1] if isinstance(default_tuple, (tuple, list)) else None + if isinstance(idx, int) and 0 <= idx < len(devices): + return idx, { + "preferred": pref, + "index": idx, + "name": devices[idx].get("name", ""), + "fallback": True, + "reason": f"named device '{pref}' unavailable — fell back to system default", + } + except Exception: + pass + return None, {"preferred": pref, "index": None, "fallback": True, "reason": "no match, no default"} + + +def _play_wav_bytes(wav_bytes: bytes, job_id: str, session_id: str | None = None): """Play WAV audio bytes through speakers via sounddevice.""" global _current_sd_stream try: @@ -126,7 +219,30 @@ def _play_wav_bytes(wav_bytes: bytes, job_id: str): _jobs[job_id]["duration_sec"] = round(duration, 2) _playback_interrupt.clear() - device = _output_device + + # Live device resolution per playback. If the session pins a device, + # honor it; if it's system-default, re-read OS default now. This is + # the core of the ADR-082 2026-04-22 amendment. 
+ preferred: str | None = None + if session_id: + with _shim_sessions_lock: + cfg = _shim_sessions.get(session_id) + if cfg is not None: + preferred = cfg.get("preferred_output_device", "system-default") + if preferred is None and _active_session_id: + with _shim_sessions_lock: + cfg = _shim_sessions.get(_active_session_id) + if cfg is not None: + preferred = cfg.get("preferred_output_device", "system-default") + + if preferred is not None: + device, diag = _resolve_device_live(preferred) + with _jobs_lock: + if job_id in _jobs: + _jobs[job_id]["device_resolution"] = diag + else: + device = _output_device + with _current_player_lock: _current_sd_stream = job_id @@ -165,9 +281,21 @@ def _estimate_duration(text: str, speed: float) -> float: def tool_speak( - text: str, voice: str = "bm_lewis", stream: bool = True, speed: float = 1.25, emotion: float = 0.5 + text: str, + voice: str = "bm_lewis", + stream: bool = True, + speed: float = 1.25, + emotion: float = 0.5, + session_id: str = "", ) -> str: - """Synthesize via HTTP, play locally.""" + """Synthesize via HTTP, play locally. + + When ``session_id`` is provided, the HTTP call includes it so the server + routes through ADR-082 session-aware playback (assigned voice, + preferred output device). The shim resolves its own preferred output + device live from its cached session info — each playback picks up the + current OS default. 
+ """ if not text.strip(): return json.dumps({"status": "error", "error": "Nothing to say"}) @@ -189,16 +317,20 @@ def tool_speak( pass # Request synthesis from HTTP service + synth_body: dict[str, Any] = { + "text": text, + "voice": voice, + "speed": speed, + "emotion": emotion, + "format": "wav", + } + if session_id: + synth_body["session_id"] = session_id + status, resp = _http_request( "POST", "/v1/synthesize", - { - "text": text, - "voice": voice, - "speed": speed, - "emotion": emotion, - "format": "wav", - }, + synth_body, timeout=60.0, ) @@ -218,14 +350,19 @@ def tool_speak( "text": text[:100], "voice": voice, "created": time.time(), + "session_id": session_id or None, } while len(_jobs) > _MAX_JOBS: _jobs.popitem(last=False) - t = threading.Thread(target=_play_wav_bytes, args=(resp, job_id), daemon=True) + t = threading.Thread( + target=_play_wav_bytes, + args=(resp, job_id, session_id or None), + daemon=True, + ) t.start() - return json.dumps({"status": "speaking", "job_id": job_id}) + return json.dumps({"status": "speaking", "job_id": job_id, "session_id": session_id or None}) def tool_stop(job_id: str = "") -> str: @@ -415,6 +552,78 @@ def tool_await_voice_input(timeout_sec: float = 180.0) -> str: return json.dumps({"status": "error", "error": "Could not retrieve transcript"}) +def tool_register_session( + session_id: str, + participant_id: str, + participant_type: str = "agent", + preferred_voice: str = "", + preferred_output_device: str = "system-default", +) -> str: + """Register a session with the Mod3 bus (ADR-082 Phase 1). + + Forwards to POST /v1/sessions/register on the HTTP service, then caches + the result locally so future tool_speak() calls can live-resolve the + session's preferred output device before each playback. 
+ """ + global _active_session_id + + body: dict[str, Any] = { + "session_id": session_id, + "participant_id": participant_id, + "participant_type": participant_type, + "preferred_output_device": preferred_output_device or "system-default", + } + if preferred_voice: + body["preferred_voice"] = preferred_voice + + status, resp = _http_request("POST", "/v1/sessions/register", body, timeout=10.0) + if status != 200: + err = resp.get("error", f"HTTP {status}") if isinstance(resp, dict) else f"HTTP {status}" + return json.dumps({"status": "error", "error": err}) + + # Cache locally — the playback path reads preferred_output_device from + # here each play. + if isinstance(resp, dict): + with _shim_sessions_lock: + _shim_sessions[session_id] = { + "participant_id": participant_id, + "participant_type": participant_type, + "assigned_voice": resp.get("assigned_voice"), + "preferred_output_device": resp.get("preferred_output_device", "system-default"), + } + _active_session_id = session_id + resp["status"] = "ok" + return json.dumps(resp) + return json.dumps({"status": "error", "error": "unexpected response shape"}) + + +def tool_deregister_session(session_id: str) -> str: + """Release a session's voice and drop pending jobs (ADR-082 Phase 1).""" + global _active_session_id + status, resp = _http_request("POST", f"/v1/sessions/{session_id}/deregister", {}, timeout=5.0) + with _shim_sessions_lock: + _shim_sessions.pop(session_id, None) + if _active_session_id == session_id: + _active_session_id = None + if status == 200 and isinstance(resp, dict): + return json.dumps(resp) + if status == 404: + return json.dumps({"status": "not_found", "session_id": session_id}) + err = resp.get("error", f"HTTP {status}") if isinstance(resp, dict) else f"HTTP {status}" + return json.dumps({"status": "error", "error": err}) + + +def tool_list_sessions() -> str: + """List all registered sessions (ADR-082 Phase 1).""" + status, resp = _http_request("GET", "/v1/sessions", timeout=5.0) + if status 
!= 200: + return json.dumps({"status": "error", "error": f"HTTP {status}"}) + if isinstance(resp, dict): + resp["status"] = "ok" + return json.dumps(resp) + return json.dumps({"status": "error", "error": "unexpected response shape"}) + + def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: """VAD check via HTTP.""" if not os.path.exists(file_path): @@ -503,6 +712,15 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "default": 0.5, "description": "Emotion/exaggeration intensity 0.0-1.0 (Chatterbox only). Default 0.5.", }, + "session_id": { + "type": "string", + "default": "", + "description": ( + "Optional ADR-082 session id. When set and the session is registered, " + "playback uses the session's assigned voice + preferred_output_device " + "(live-resolved per playback). When empty, behaves as before." + ), + }, }, "required": ["text"], }, @@ -629,6 +847,73 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "required": ["file_path"], }, }, + { + "name": "register_session", + "description": ( + "Register a session with the Mod3 communication bus (ADR-082 Phase 1).\n\n" + "Each registered session gets its own output queue, an assigned voice from\n" + "the ranked pool, and a preferred output device that is re-queried live per\n" + "playback when set to 'system-default'. Multiple sessions share one physical\n" + "speaker via a global round-robin serializer.\n\n" + "Args:\n" + " session_id: Caller-chosen id (e.g., the Claude Code session id).\n" + " participant_id: Identity of the speaker (e.g., 'cog', 'sandy', 'slowbro').\n" + " participant_type: 'agent' or 'user'. Free-form beyond that.\n" + " preferred_voice: Optional voice preset (e.g., 'bm_lewis'). If taken,\n" + " voice_conflict=true is returned but assignment still succeeds.\n" + " preferred_output_device: 'system-default' (re-queried per playback), a\n" + " device-name substring, or a numeric index." 
+ ), + "inputSchema": { + "type": "object", + "properties": { + "session_id": {"type": "string", "description": "Caller-chosen session id."}, + "participant_id": { + "type": "string", + "description": "Identity of the speaker (e.g., 'cog', 'sandy', 'slowbro').", + }, + "participant_type": { + "type": "string", + "default": "agent", + "description": "'agent' or 'user'. Free-form beyond that.", + }, + "preferred_voice": { + "type": "string", + "default": "", + "description": "Optional voice preset. If taken, voice_conflict is flagged.", + }, + "preferred_output_device": { + "type": "string", + "default": "system-default", + "description": "'system-default', device-name substring, or numeric index.", + }, + }, + "required": ["session_id", "participant_id"], + }, + }, + { + "name": "deregister_session", + "description": ( + "Release a session's voice and drop its pending jobs (ADR-082 Phase 1).\n\n" + "Call at session end so the voice can be allocated to a new session." + ), + "inputSchema": { + "type": "object", + "properties": { + "session_id": {"type": "string", "description": "Session id to deregister."}, + }, + "required": ["session_id"], + }, + }, + { + "name": "list_sessions", + "description": ( + "List all registered sessions with live device resolution (ADR-082 Phase 1).\n\n" + "Returns each session's assigned voice, preferred output device, queue depth,\n" + "and the serializer's current state (policy + round-robin cursor)." 
+ ), + "inputSchema": {"type": "object", "properties": {}}, + }, ] TOOL_DISPATCH = { @@ -638,6 +923,7 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: stream=args.get("stream", True), speed=args.get("speed", 1.25), emotion=args.get("emotion", 0.5), + session_id=args.get("session_id", ""), ), "speech_status": lambda args: tool_speech_status( job_id=args.get("job_id", ""), @@ -656,6 +942,17 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: file_path=args["file_path"], threshold=args.get("threshold", 0.5), ), + "register_session": lambda args: tool_register_session( + session_id=args["session_id"], + participant_id=args["participant_id"], + participant_type=args.get("participant_type", "agent"), + preferred_voice=args.get("preferred_voice", ""), + preferred_output_device=args.get("preferred_output_device", "system-default"), + ), + "deregister_session": lambda args: tool_deregister_session( + session_id=args["session_id"], + ), + "list_sessions": lambda args: tool_list_sessions(), } diff --git a/server.py b/server.py index 8909f02..e520096 100644 --- a/server.py +++ b/server.py @@ -40,6 +40,11 @@ from modality import ModalityType, ModuleStatus from modules.voice import PlaceholderDecoder, VoiceModule from pipeline_state import InterruptInfo, PipelineState +from session_registry import ( + ResolvedOutputDevice, + get_default_registry, + resolve_output_device, +) logger = logging.getLogger("mod3.server") @@ -748,6 +753,28 @@ def _estimate_duration_sec(text: str, speed: float) -> float: return (words / 150.0) * 60.0 / speed +def _resolve_device_for_entry(entry: dict) -> tuple[int | str | None, ResolvedOutputDevice | None]: + """Resolve the output device for a speech job, live. + + Priority (per the ADR-082 2026-04-22 amendment): + 1. If the job's session has a preferred_output_device, re-query live — + "system-default" always reads the current OS default, and named + devices are enumerated per dispatch. + 2. 
Otherwise fall back to the legacy ``_output_device`` module global + set by set_output_device() so existing callers keep working. + """ + session_id = entry.get("session_id") + if session_id: + try: + registry = get_default_registry() + resolved = registry.resolve_device(session_id) + entry["resolved_device"] = resolved + return resolved.index, resolved + except Exception as exc: # noqa: BLE001 — never fail synthesis on resolution + logger.warning("device resolution failed for session %s: %s", session_id, exc) + return _output_device, None + + def _run_speech_job(entry: dict) -> None: """Execute a single speech job (blocking). Called from the drain thread.""" global _last_metrics, _current_player @@ -765,7 +792,8 @@ def _run_speech_job(entry: dict) -> None: AdaptivePlayer = _adaptive_player_class() engine, resolved_voice = _resolve_voice_via_bus(voice) model = engine_module.get_model(engine) - player = AdaptivePlayer(sample_rate=model.sample_rate, device=_output_device) + device, _resolved = _resolve_device_for_entry(entry) + player = AdaptivePlayer(sample_rate=model.sample_rate, device=device) except Exception as e: _jobs[job_id]["status"] = "error" _jobs[job_id]["error"] = str(e) @@ -865,10 +893,17 @@ def _start_speech( streaming_interval: float = 1.0, speed: float = 1.0, emotion: float = 0.5, + session_id: str | None = None, ) -> tuple[str, int]: """Submit speech to the queue. Returns (job_id, queue_position). queue_position is 0 if playing immediately, >0 if queued behind others. + + When ``session_id`` is provided, the job is tagged with it so the drain + thread can live-resolve the session's preferred output device before + playback. Voice selection still uses the explicit ``voice`` argument — + callers should pass the session's assigned_voice when registering a job + against a session. 
""" job_id = uuid.uuid4().hex[:8] _jobs[job_id] = { @@ -884,20 +919,21 @@ def _start_speech( "player": None, "speed": speed, "estimated_duration_sec": round(_estimate_duration_sec(text, speed), 1), + "session_id": session_id, } _prune_jobs() - position = _speech_queue.enqueue( - job_id, - { - "text": text, - "voice": voice, - "stream": stream, - "streaming_interval": streaming_interval, - "speed": speed, - "emotion": emotion, - }, - ) + entry = { + "text": text, + "voice": voice, + "stream": stream, + "streaming_interval": streaming_interval, + "speed": speed, + "emotion": emotion, + } + if session_id: + entry["session_id"] = session_id + position = _speech_queue.enqueue(job_id, entry) return job_id, position @@ -947,6 +983,7 @@ def speak( stream: bool = True, speed: float = 1.25, emotion: float = 0.5, + session_id: str = "", ) -> str: """Synthesize text to speech and play it through the user's speakers. @@ -966,10 +1003,39 @@ def speak( If False, generates all audio first then plays (better prosody). speed: Speed multiplier (engines with speed support). Default 1.25. emotion: Emotion/exaggeration intensity 0.0-1.0 (Chatterbox only). Default 0.5. + session_id: Optional ADR-082 session id. When provided and the session + is registered (see register_session), the job is routed + through the per-session queue and the session's assigned + voice + preferred_output_device are used. When empty, + falls back to today's global-queue behavior for backward + compatibility. """ if not text.strip(): return json.dumps({"status": "error", "error": "Nothing to say"}) + # Route through the session registry when session_id is provided. + # If the session is registered, its assigned_voice overrides the ``voice`` + # argument unless the caller explicitly passed a non-default voice — the + # ADR treats voice as a session identity attribute, not a per-call knob. 
+ effective_session_id: str | None = session_id or None + if effective_session_id: + registry = get_default_registry() + session = registry.get(effective_session_id) + if session is None: + return json.dumps( + { + "status": "error", + "error": f"session '{effective_session_id}' is not registered — call register_session first", + } + ) + # If caller did not pass an explicit non-default voice, use the + # session's assigned voice. "bm_lewis" is the old default so we can't + # distinguish "explicit bm_lewis" from "unspecified"; tolerate that + # and only override when the caller asks for the default. + if voice == "bm_lewis" and session.assigned_voice != "bm_lewis": + voice = session.assigned_voice + session.state = "speaking" + # Check if user is currently speaking (barge-in signal file) user_state = "idle" try: @@ -999,7 +1065,14 @@ def speak( ) try: - job_id, position = _start_speech(text, voice, stream=stream, speed=speed, emotion=emotion) + job_id, position = _start_speech( + text, + voice, + stream=stream, + speed=speed, + emotion=emotion, + session_id=effective_session_id, + ) except ValueError as e: return json.dumps({"status": "error", "error": str(e)}) except Exception as e: @@ -1480,6 +1553,95 @@ def set_output_device(device: str = "") -> str: return json.dumps({"status": "ok", "device": _output_device}) +# --------------------------------------------------------------------------- +# Session registry (ADR-082 Phase 1) +# --------------------------------------------------------------------------- + + +@mcp.tool( + annotations={ + "readOnlyHint": False, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + } +) +def register_session( + session_id: str, + participant_id: str, + participant_type: str = "agent", + preferred_voice: str = "", + preferred_output_device: str = "system-default", +) -> str: + """Register a session with the Mod3 communication bus (ADR-082). 
+ + Each registered session gets its own output queue, an assigned voice + from the ranked pool, and a preferred output device. The global + serializer interleaves speech across sessions (round-robin by default) + so two concurrent agents do not collide on the shared speaker. + + Args: + session_id: Caller-chosen id (e.g., the Claude Code session id). + participant_id: Identity of the speaker (e.g., 'cog', 'sandy', 'slowbro'). + participant_type: 'agent' or 'user'. Free-form beyond that. + preferred_voice: Optional voice preset. If taken, voice_conflict=true + is returned but assignment still succeeds. + preferred_output_device: 'system-default' (re-queried per playback), + a device-name substring, or a numeric index. + """ + registry = get_default_registry() + result = registry.register( + session_id=session_id, + participant_id=participant_id, + participant_type=participant_type, + preferred_voice=preferred_voice or None, + preferred_output_device=preferred_output_device or "system-default", + ) + payload = result.session.to_dict(device_resolver=resolve_output_device) + payload["status"] = "ok" + payload["created"] = result.created + # Also expose a live-resolved device at the top level for convenience — + # callers can log or display it without walking nested keys. 
+ live = registry.resolve_device(result.session.session_id) + payload["output_device"] = live.to_dict() + return json.dumps(payload) + + +@mcp.tool( + annotations={ + "readOnlyHint": False, + "destructiveHint": True, + "idempotentHint": True, + "openWorldHint": False, + } +) +def deregister_session(session_id: str) -> str: + """Release a session's voice and drop its pending jobs.""" + registry = get_default_registry() + result = registry.deregister(session_id) + return json.dumps(result) + + +@mcp.tool( + annotations={ + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + } +) +def list_sessions() -> str: + """List all registered sessions with live device resolution.""" + registry = get_default_registry() + return json.dumps( + { + "status": "ok", + "sessions": registry.list_serialized(), + "serializer": registry.serializer.snapshot(), + } + ) + + # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- diff --git a/session_registry.py b/session_registry.py new file mode 100644 index 0000000..1aad8dc --- /dev/null +++ b/session_registry.py @@ -0,0 +1,895 @@ +"""Session-aware communication bus registry (ADR-082 Phase 1). + +Evolves Mod3 from a stateless TTS engine toward a session-aware communication +bus. Each registered session owns a SessionChannel with an assigned voice, a +preferred output device (live-queried by default), and its own output queue. +A global round-robin serializer picks the next job across sessions with +pending work so multi-agent sessions can share one physical speaker without +collisions. 
+ +Scope for Phase 1 (per the ADR's "Migration Path" section): + + - register_session / deregister_session + list_sessions + - Per-session output queues with a global serializer (round-robin default, + priority / fifo-global policies pluggable) + - Voice assignment from the ranked Kokoro pool + - preferred_output_device field on SessionChannel with live OS-default + re-query per playback (2026-04-22 amendment) + +Out of scope for Phase 1: input routing, barge-in state machine, native input +provider. See later phases of the ADR. + +Backward compatibility: an implicit "default" session is created on first +use so existing callers that do not supply session_id keep working exactly +as before. +""" + +from __future__ import annotations + +import atexit +import heapq +import logging +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Iterable + +logger = logging.getLogger("mod3.session_registry") + + +# --------------------------------------------------------------------------- +# Voice pool — ranked allocation per ADR-082 § Voice Assignment +# --------------------------------------------------------------------------- + +# The canonical Kokoro voice pool, ranked. Greedy allocation walks this list. +# bm_lewis is first because it is the legacy default; remaining Kokoro voices +# follow in the order the ADR specifies (heart / emma / adam / bella / isabella), +# then the unlisted Kokoro voices round out the pool so we do not run dry. +VOICE_POOL: tuple[str, ...] 
= ( + "bm_lewis", + "af_heart", + "bf_emma", + "am_adam", + "af_bella", + "bf_isabella", + "bm_george", + "am_michael", + "af_sarah", + "af_nicole", + "af_sky", +) + +DEFAULT_SESSION_ID = "default" +DEFAULT_PARTICIPANT_ID = "legacy" +DEFAULT_PARTICIPANT_TYPE = "agent" + +SERIALIZATION_POLICIES = ("round-robin", "priority", "fifo-global") + + +# --------------------------------------------------------------------------- +# Output-device resolution (ADR-082 amendment — 2026-04-22) +# --------------------------------------------------------------------------- + + +@dataclass +class ResolvedOutputDevice: + """Resolution result for an output-device lookup. + + ``preferred`` is the string the session asked for ("system-default" or a + named device). ``index`` is the sounddevice index that was chosen, or + ``None`` to let PortAudio pick — callers should treat ``None`` as + "PortAudio fallback" rather than relying on it implicitly resolving to + the current system default (PortAudio behavior is inconsistent across + versions, per the ADR amendment). + + ``name`` mirrors the resolved device's name when known. ``fallback`` is + True when the requested name was unavailable and we backed off to the + system default. + """ + + preferred: str + index: int | None + name: str + fallback: bool = False + reason: str = "" + + def to_dict(self) -> dict[str, Any]: + return { + "preferred": self.preferred, + "index": self.index, + "name": self.name, + "fallback": self.fallback, + "reason": self.reason, + } + + +def resolve_output_device( + preferred: str | None = "system-default", + *, + query_devices: Callable[[], list[dict[str, Any]]] | None = None, + default_output_index: Callable[[], int | None] | None = None, +) -> ResolvedOutputDevice: + """Resolve ``preferred`` to a concrete sounddevice index, live. + + For ``"system-default"`` (or None/empty) the OS default is re-queried + every call — no caching. 
When pinned to a named device we enumerate the + current device list and pick by substring match; if nothing matches we + fall back to system-default and flag ``fallback=True`` so callers can + emit a device_fallback warning event. + + The ``query_devices`` and ``default_output_index`` callables are test + seams; they default to sounddevice when omitted. This keeps the live + requery contract explicit — every call goes through the callable, no + module-level state caches the previous result. + """ + + if query_devices is None or default_output_index is None: + + def _query() -> list[dict[str, Any]]: + import sounddevice as sd + + return list(sd.query_devices()) + + def _default() -> int | None: + import sounddevice as sd + + # sd.default.device is (input_index, output_index). We only care + # about the output side. Re-read every call so we pick up whatever + # the OS chose just now. + value = sd.default.device + if isinstance(value, (tuple, list)) and len(value) >= 2: + idx = value[1] + return int(idx) if isinstance(idx, int) and idx >= 0 else None + return None + + query_devices = query_devices or _query + default_output_index = default_output_index or _default + + pref = (preferred or "system-default").strip() + if not pref: + pref = "system-default" + + try: + devices = query_devices() + except Exception as exc: # noqa: BLE001 — device enumeration can fail noisily + logger.warning("resolve_output_device: query_devices failed: %s", exc) + return ResolvedOutputDevice( + preferred=pref, + index=None, + name="(unresolved)", + fallback=True, + reason=f"query_devices failed: {exc}", + ) + + def _default_resolution() -> ResolvedOutputDevice: + try: + idx = default_output_index() + except Exception as exc: # noqa: BLE001 + logger.warning("resolve_output_device: default_output_index failed: %s", exc) + return ResolvedOutputDevice( + preferred=pref, + index=None, + name="(system default, unresolved)", + fallback=True, + reason=f"default_output_index failed: {exc}", + ) + 
+ if idx is None or idx < 0 or idx >= len(devices): + # Falling back to PortAudio implicit — we would rather be explicit + # but the platform did not give us a usable default. + return ResolvedOutputDevice( + preferred=pref, + index=None, + name="(system default)", + fallback=False, + reason="OS default unknown — PortAudio implicit", + ) + + name = devices[idx].get("name", "(unknown)") + return ResolvedOutputDevice( + preferred=pref, + index=idx, + name=name, + fallback=False, + reason="OS default", + ) + + if pref.lower() in ("system-default", "default"): + return _default_resolution() + + # Named device — try numeric index first, then substring on name. + if pref.isdigit(): + idx = int(pref) + if 0 <= idx < len(devices) and devices[idx].get("max_output_channels", 0) > 0: + return ResolvedOutputDevice( + preferred=pref, + index=idx, + name=devices[idx].get("name", "(unknown)"), + fallback=False, + reason="index match", + ) + + pref_lower = pref.lower() + for i, d in enumerate(devices): + if d.get("max_output_channels", 0) <= 0: + continue + if pref_lower in str(d.get("name", "")).lower(): + return ResolvedOutputDevice( + preferred=pref, + index=i, + name=d.get("name", "(unknown)"), + fallback=False, + reason="name match", + ) + + # No match — fall back to system default. + resolved = _default_resolution() + resolved.fallback = True + resolved.reason = f"named device '{pref}' unavailable — fell back to system default" + logger.warning( + "resolve_output_device: %s unavailable, falling back to system default (idx=%s name=%s)", + pref, + resolved.index, + resolved.name, + ) + return resolved + + +# --------------------------------------------------------------------------- +# Session channel +# --------------------------------------------------------------------------- + + +SessionState = str # "idle" | "speaking" | "blocked" | "waiting_for_input" + + +@dataclass +class SessionChannel: + """A registered session's bus endpoint. 
+ + Holds identity, voice, device preference, a per-session output queue, + and lifecycle state. The instance is mutated in place under + ``SessionRegistry``'s lock — callers should not hold references across + deregistration. + """ + + session_id: str + participant_id: str + participant_type: str + assigned_voice: str + voice_conflict: bool = False + preferred_voice: str | None = None + preferred_output_device: str = "system-default" + state: SessionState = "idle" + priority: int = 0 + registered_at: float = field(default_factory=time.time) + last_active: float = field(default_factory=time.time) + # Internal: pending jobs waiting for the global serializer to pick them up. + # Elements are opaque to the registry; the serializer pops them in FIFO + # order within a session. + _queue: deque[Any] = field(default_factory=deque, repr=False) + # Per-session monotonic submit counter, bumped by GlobalSerializer.submit(). + # Cross-session ordering uses the serializer's global arrival counter instead. + _submit_seq: int = field(default=0, repr=False) + + def to_dict(self, *, device_resolver: Callable[[str], ResolvedOutputDevice] | None = None) -> dict[str, Any]: + """Serialize for API responses. ``device_resolver`` lets the caller + re-query the current device (per the ADR amendment — no caching). 
+ """ + d = { + "session_id": self.session_id, + "participant_id": self.participant_id, + "participant_type": self.participant_type, + "assigned_voice": self.assigned_voice, + "voice_conflict": self.voice_conflict, + "preferred_voice": self.preferred_voice, + "preferred_output_device": self.preferred_output_device, + "state": self.state, + "priority": self.priority, + "registered_at": self.registered_at, + "last_active": self.last_active, + "queue_depth": len(self._queue), + } + if device_resolver is not None: + try: + d["output_device"] = device_resolver(self.preferred_output_device).to_dict() + except Exception as exc: # noqa: BLE001 — never fail serialization on device enumeration + d["output_device"] = { + "preferred": self.preferred_output_device, + "index": None, + "name": "(unresolved)", + "fallback": True, + "reason": f"resolver error: {exc}", + } + return d + + +# --------------------------------------------------------------------------- +# Global serializer +# --------------------------------------------------------------------------- + + +@dataclass(order=True) +class _SerializedJob: + """Internal envelope for the global serializer. + + Priority-tuple ordering: lower sort-key first. We sort by + (negative_priority, arrival_seq) so higher-priority jobs win and FIFO + breaks ties. ``payload`` is opaque — the serializer only forwards it to + the dispatch callback. + """ + + sort_key: tuple[int, int] + session_id: str = field(compare=False) + submit_seq: int = field(compare=False) + payload: Any = field(compare=False, default=None) + + +class GlobalSerializer: + """Layer-2 serializer over per-session output queues. + + Policies (ADR-082 § Output Serialization): + + * ``round-robin`` — alternate across sessions with pending work. This + is the default and the reason Phase 1 exists at all: two concurrent + sessions should not be able to starve each other. + * ``priority`` — highest-priority session drains first; ties fall back + to round-robin ordering. 
+ * ``fifo-global`` — strict arrival order across all sessions (matches + today's single-global-queue behavior for migration parity). + + The serializer is push-driven: registrations submit jobs, and a + dedicated dispatcher thread picks the next one per policy. ``submit()`` + returns an opaque job_id string; the payload itself is handed verbatim + to the dispatcher when the job's turn comes up. + """ + + def __init__( + self, + *, + policy: str = "round-robin", + dispatcher: Callable[[str, Any], Any] | None = None, + now: Callable[[], float] = time.time, + ): + if policy not in SERIALIZATION_POLICIES: + raise ValueError(f"unknown serialization policy: {policy}") + self._policy = policy + self._dispatcher = dispatcher + self._now = now + + self._lock = threading.RLock() + self._cond = threading.Condition(self._lock) + # round-robin cursor — list of session_ids in rotation order; we pop + # from the front and append to the back when a session has more work. + self._rr_cursor: deque[str] = deque() + self._rr_seen: set[str] = set() + # priority heap for "priority" policy: (neg_prio, global_seq, session_id, job_id) + self._priority_heap: list[tuple[int, int, str, str]] = [] + # fifo-global heap: (global_seq, session_id, job_id) + self._fifo_heap: list[tuple[int, str, str]] = [] + # monotonic arrival counter shared across all sessions + self._global_seq = 0 + + self._sessions: dict[str, SessionChannel] = {} + self._thread: threading.Thread | None = None + self._stopping = False + # Diagnostics: order in which jobs are dispatched. Bounded; newest + # first. Primarily for tests and the /v1/sessions dashboard. 
+ self._dispatch_log: deque[tuple[float, str, str]] = deque(maxlen=256) + + # -- Policy plumbing ---------------------------------------------------- + + @property + def policy(self) -> str: + return self._policy + + def set_policy(self, policy: str) -> None: + if policy not in SERIALIZATION_POLICIES: + raise ValueError(f"unknown serialization policy: {policy}") + with self._lock: + self._policy = policy + + def attach_dispatcher(self, dispatcher: Callable[[str, Any], Any]) -> None: + """Set or replace the per-job dispatcher. + + The dispatcher receives ``(session_id, payload)`` on the dispatcher + thread and runs synchronously — its completion is what advances the + queue. The intent is: dispatcher blocks until the playback for this + job finishes, preserving the single-speaker contract. + """ + with self._lock: + self._dispatcher = dispatcher + + # -- Registration management ------------------------------------------- + + def attach_session(self, session: SessionChannel) -> None: + with self._lock: + self._sessions[session.session_id] = session + + def detach_session(self, session_id: str) -> int: + """Drop a session and any queued jobs. Returns count of jobs dropped.""" + with self._lock: + session = self._sessions.pop(session_id, None) + dropped = 0 + if session is not None: + dropped = len(session._queue) + session._queue.clear() + # Best-effort: prune cursor entries for this session. The heaps + # may still reference stale jobs; _next_job() will skip them. + self._rr_seen.discard(session_id) + self._rr_cursor = deque(s for s in self._rr_cursor if s != session_id) + return dropped + + # -- Submission --------------------------------------------------------- + + def submit( + self, + session_id: str, + payload: Any, + *, + priority: int | None = None, + ) -> str: + """Submit a job for ``session_id``. + + Returns an opaque job_id string. The payload is handed verbatim to + the dispatcher when its turn comes up. 
Raises KeyError if the + session is not registered — callers that want auto-registration + should go through SessionRegistry.submit() instead. + """ + with self._cond: + session = self._sessions.get(session_id) + if session is None: + raise KeyError(f"session '{session_id}' is not registered") + job_id = uuid.uuid4().hex[:8] + self._global_seq += 1 + session._submit_seq += 1 + prio = priority if priority is not None else session.priority + session._queue.append((self._global_seq, job_id, payload)) + session.last_active = self._now() + + # Update scheduling structures + if session_id not in self._rr_seen: + self._rr_seen.add(session_id) + self._rr_cursor.append(session_id) + heapq.heappush(self._priority_heap, (-prio, self._global_seq, session_id, job_id)) + heapq.heappush(self._fifo_heap, (self._global_seq, session_id, job_id)) + + self._cond.notify_all() + return job_id + + # -- Dispatch thread ---------------------------------------------------- + + def start(self) -> None: + """Start the dispatcher thread. 
Idempotent.""" + with self._lock: + if self._thread is not None and self._thread.is_alive(): + return + self._stopping = False + self._thread = threading.Thread( + target=self._run, + name="mod3-global-serializer", + daemon=True, + ) + self._thread.start() + + def stop(self) -> None: + with self._cond: + self._stopping = True + self._cond.notify_all() + t = self._thread + if t is not None: + t.join(timeout=2.0) + self._thread = None + + def _run(self) -> None: + while True: + with self._cond: + while not self._stopping and not self._has_pending_unlocked(): + self._cond.wait(timeout=0.5) + if self._stopping: + return + picked = self._pop_next_unlocked() + dispatcher = self._dispatcher + if picked is None: + continue + session_id, job_id, payload = picked + self._dispatch_log.appendleft((self._now(), session_id, job_id)) + if dispatcher is None: + logger.debug( + "GlobalSerializer: no dispatcher attached — dropping job %s for %s", + job_id, + session_id, + ) + continue + try: + dispatcher(session_id, payload) + except Exception as exc: # noqa: BLE001 — keep dispatcher robust + logger.exception("GlobalSerializer dispatcher raised: %s", exc) + + def _has_pending_unlocked(self) -> bool: + for s in self._sessions.values(): + if s._queue: + return True + return False + + def _pop_next_unlocked(self) -> tuple[str, str, Any] | None: + """Pick and remove the next job according to policy.""" + if self._policy == "round-robin": + return self._pop_round_robin_unlocked() + if self._policy == "priority": + return self._pop_priority_unlocked() + # fifo-global + return self._pop_fifo_unlocked() + + def _pop_round_robin_unlocked(self) -> tuple[str, str, Any] | None: + # Walk the cursor until we find a session with pending work. Skip + # sessions whose queues are empty — we rebuild the cursor lazily. 
+ checked = 0 + total = len(self._rr_cursor) + while checked < total and self._rr_cursor: + sid = self._rr_cursor.popleft() + session = self._sessions.get(sid) + if session is None or not session._queue: + self._rr_seen.discard(sid) + checked += 1 + continue + seq, job_id, payload = session._queue.popleft() + # If the session still has more, put it at the back of the + # rotation so other sessions go first. + if session._queue: + self._rr_cursor.append(sid) + else: + self._rr_seen.discard(sid) + return sid, job_id, payload + return None + + def _pop_priority_unlocked(self) -> tuple[str, str, Any] | None: + while self._priority_heap: + neg_prio, seq, sid, job_id = heapq.heappop(self._priority_heap) + session = self._sessions.get(sid) + if session is None or not session._queue: + continue + # Pop the matching (seq, job_id) from the session queue. We match + # by seq because heap order and queue order can diverge when two + # sessions submit at once. + found_idx = None + for i, (qseq, qjid, _payload) in enumerate(session._queue): + if qseq == seq and qjid == job_id: + found_idx = i + break + if found_idx is None: + continue + qseq, qjid, payload = session._queue[found_idx] + del session._queue[found_idx] + return sid, qjid, payload + return None + + def _pop_fifo_unlocked(self) -> tuple[str, str, Any] | None: + while self._fifo_heap: + seq, sid, job_id = heapq.heappop(self._fifo_heap) + session = self._sessions.get(sid) + if session is None or not session._queue: + continue + found_idx = None + for i, (qseq, qjid, _payload) in enumerate(session._queue): + if qseq == seq and qjid == job_id: + found_idx = i + break + if found_idx is None: + continue + qseq, qjid, payload = session._queue[found_idx] + del session._queue[found_idx] + return sid, qjid, payload + return None + + # -- Introspection ------------------------------------------------------ + + def snapshot(self) -> dict[str, Any]: + with self._lock: + return { + "policy": self._policy, + "sessions": { + sid: { + 
"queue_depth": len(s._queue), + "last_active": s.last_active, + "state": s.state, + } + for sid, s in self._sessions.items() + }, + "rr_cursor": list(self._rr_cursor), + "recent_dispatches": list(self._dispatch_log), + } + + +# --------------------------------------------------------------------------- +# Session registry +# --------------------------------------------------------------------------- + + +@dataclass +class RegistrationResult: + session: SessionChannel + created: bool # False when re-registering an existing session_id + voice_conflict: bool + + +class SessionRegistry: + """Thread-safe registry of SessionChannels. + + Owns voice-pool allocation, device preference, and the global serializer. + The registry is deliberately independent of ModalityBus — tests can spin + it up without the full bus stack, and the bus can adopt it incrementally. + """ + + def __init__( + self, + *, + voice_pool: Iterable[str] | None = None, + serializer: GlobalSerializer | None = None, + device_resolver: Callable[[str], ResolvedOutputDevice] | None = None, + ): + self._lock = threading.RLock() + self._sessions: dict[str, SessionChannel] = {} + self._voice_pool: list[str] = list(voice_pool if voice_pool is not None else VOICE_POOL) + # Track which voice is currently held by which session, first-come + # first-served. A second request for the same voice is honored (voice + # is identity — collisions should be rare) but flagged. 
+ self._voice_holder: dict[str, str] = {} + self._serializer = serializer or GlobalSerializer() + self._device_resolver = device_resolver or resolve_output_device + + # -- Lifecycle ---------------------------------------------------------- + + @property + def serializer(self) -> GlobalSerializer: + return self._serializer + + def start(self) -> None: + self._serializer.start() + + def stop(self) -> None: + self._serializer.stop() + + # -- Session management ------------------------------------------------- + + def register( + self, + *, + session_id: str, + participant_id: str, + participant_type: str = DEFAULT_PARTICIPANT_TYPE, + preferred_voice: str | None = None, + preferred_output_device: str = "system-default", + priority: int = 0, + ) -> RegistrationResult: + with self._lock: + existing = self._sessions.get(session_id) + if existing is not None: + existing.participant_id = participant_id + existing.participant_type = participant_type + existing.preferred_output_device = preferred_output_device or "system-default" + existing.last_active = time.time() + # Don't reshuffle voice on re-register. If the caller wants a + # different voice they should deregister first. 
+ return RegistrationResult(existing, created=False, voice_conflict=existing.voice_conflict) + + voice, conflict = self._allocate_voice(session_id, preferred_voice) + session = SessionChannel( + session_id=session_id, + participant_id=participant_id, + participant_type=participant_type, + assigned_voice=voice, + voice_conflict=conflict, + preferred_voice=preferred_voice, + preferred_output_device=preferred_output_device or "system-default", + priority=priority, + ) + self._sessions[session_id] = session + self._serializer.attach_session(session) + logger.info( + "registered session: id=%s participant=%s voice=%s conflict=%s device=%s", + session_id, + participant_id, + voice, + conflict, + preferred_output_device, + ) + return RegistrationResult(session, created=True, voice_conflict=conflict) + + def deregister(self, session_id: str) -> dict[str, Any]: + with self._lock: + session = self._sessions.pop(session_id, None) + if session is None: + return {"status": "not_found", "session_id": session_id} + dropped = self._serializer.detach_session(session_id) + # Return the voice to the pool if we still hold it for this session. 
+ if self._voice_holder.get(session.assigned_voice) == session_id: + del self._voice_holder[session.assigned_voice] + logger.info( + "deregistered session: id=%s voice=%s dropped_jobs=%d", + session_id, + session.assigned_voice, + dropped, + ) + return { + "status": "ok", + "session_id": session_id, + "released_voice": session.assigned_voice, + "dropped_jobs": dropped, + } + + def get(self, session_id: str) -> SessionChannel | None: + with self._lock: + return self._sessions.get(session_id) + + def get_or_create_default(self) -> SessionChannel: + """Backward-compat path — route legacy callers to an implicit session.""" + with self._lock: + session = self._sessions.get(DEFAULT_SESSION_ID) + if session is not None: + return session + result = self.register( + session_id=DEFAULT_SESSION_ID, + participant_id=DEFAULT_PARTICIPANT_ID, + participant_type=DEFAULT_PARTICIPANT_TYPE, + preferred_voice=None, + preferred_output_device="system-default", + ) + return result.session + + def list(self) -> list[SessionChannel]: + with self._lock: + return list(self._sessions.values()) + + def list_serialized(self) -> list[dict[str, Any]]: + with self._lock: + resolver = self._device_resolver + return [s.to_dict(device_resolver=resolver) for s in self._sessions.values()] + + def voice_pool(self) -> list[str]: + return list(self._voice_pool) + + def voice_holder_snapshot(self) -> dict[str, str]: + with self._lock: + return dict(self._voice_holder) + + # -- Submission --------------------------------------------------------- + + def submit( + self, + session_id: str | None, + payload: Any, + *, + priority: int | None = None, + auto_create_default: bool = True, + ) -> tuple[str, str]: + """Enqueue ``payload`` on ``session_id``. + + When session_id is None and auto_create_default is True, falls back + to the "default" session so legacy call sites keep working. + + Returns ``(resolved_session_id, job_id)``. 
+ """ + if session_id is None or session_id == "": + if not auto_create_default: + raise ValueError("session_id is required when auto_create_default=False") + session = self.get_or_create_default() + session_id = session.session_id + else: + with self._lock: + if session_id not in self._sessions: + raise KeyError(f"session '{session_id}' is not registered") + + job_id = self._serializer.submit(session_id, payload, priority=priority) + return session_id, job_id + + # -- Device routing ----------------------------------------------------- + + def resolve_device(self, session_id: str) -> ResolvedOutputDevice: + """Live-resolve the output device for ``session_id``. + + Per the 2026-04-22 amendment, this is a live property: the OS default + is re-queried every call when the session's preference is + ``system-default``. Never cache the return value on the session. + """ + with self._lock: + session = self._sessions.get(session_id) + preferred = session.preferred_output_device if session else "system-default" + return self._device_resolver(preferred) + + def set_preferred_device(self, session_id: str, preferred: str) -> ResolvedOutputDevice: + with self._lock: + session = self._sessions.get(session_id) + if session is None: + raise KeyError(session_id) + session.preferred_output_device = preferred or "system-default" + return self.resolve_device(session_id) + + # -- Internals ---------------------------------------------------------- + + def _allocate_voice(self, session_id: str, preferred: str | None) -> tuple[str, bool]: + # Explicit preference — honor it, flag if someone else already holds + # it. Voice is identity, not exclusive; collisions are operator bugs + # at worst. + if preferred: + if preferred not in self._voice_pool: + # Extend the pool lazily so out-of-band voices still work. 
+ self._voice_pool.append(preferred) + holder = self._voice_holder.get(preferred) + conflict = holder is not None and holder != session_id + self._voice_holder.setdefault(preferred, session_id) + return preferred, conflict + + # Greedy allocation — first voice in the pool that no registered + # session currently holds. + for voice in self._voice_pool: + if voice not in self._voice_holder: + self._voice_holder[voice] = session_id + return voice, False + + # Pool exhausted — fall back to the first voice and flag a collision. + fallback = self._voice_pool[0] if self._voice_pool else "bm_lewis" + return fallback, True + + +# --------------------------------------------------------------------------- +# Process-global default registry — shared across server.py, http_api.py, +# and tests. Tests can instantiate their own SessionRegistry if they need +# isolation; the module-level singleton is just a convenience for runtime. +# --------------------------------------------------------------------------- + + +_default_registry: SessionRegistry | None = None +_default_registry_lock = threading.Lock() + + +def get_default_registry() -> SessionRegistry: + global _default_registry + with _default_registry_lock: + if _default_registry is None: + _default_registry = SessionRegistry() + _default_registry.start() + return _default_registry + + +def reset_default_registry() -> None: + """Tear down the module-level registry (used by tests and the atexit hook).""" + global _default_registry + with _default_registry_lock: + if _default_registry is not None: + _default_registry.stop() + _default_registry = None + + +@atexit.register +def _shutdown_default_registry() -> None: + """Stop the dispatcher thread cleanly at interpreter shutdown. + + Without this the daemon thread gets killed mid-callback during + finalization and Python emits a Fatal Python error about a NULL + thread state. The dispatcher is idempotent for stop() so this is + safe to call even if the registry was never created. 
+ """ + try: + reset_default_registry() + except Exception: + pass + + +__all__ = [ + "DEFAULT_SESSION_ID", + "DEFAULT_PARTICIPANT_ID", + "DEFAULT_PARTICIPANT_TYPE", + "GlobalSerializer", + "RegistrationResult", + "ResolvedOutputDevice", + "SessionChannel", + "SessionRegistry", + "SERIALIZATION_POLICIES", + "VOICE_POOL", + "get_default_registry", + "reset_default_registry", + "resolve_output_device", +] diff --git a/tests/test_session_registry.py b/tests/test_session_registry.py new file mode 100644 index 0000000..f6df350 --- /dev/null +++ b/tests/test_session_registry.py @@ -0,0 +1,410 @@ +"""Unit + integration tests for the ADR-082 Phase 1 session registry. + +Covers: + * Voice-pool allocation (greedy + preferred + collision flagging) + * Device resolution with a stubbed sounddevice-shaped lookup + — "system-default" live re-query, named-device match, fallback + * Global serializer round-robin across sessions + * SessionRegistry submit() auto-creating a "default" session for + legacy callers + * HTTP surface: /v1/sessions endpoints smoke-tested via FastAPI TestClient + +Run with: ``.venv/bin/python -m pytest tests/test_session_registry.py -v`` +""" + +from __future__ import annotations + +import os +import sys +import threading +from typing import Any + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from session_registry import ( # noqa: E402 + DEFAULT_SESSION_ID, + VOICE_POOL, + GlobalSerializer, + SessionRegistry, + resolve_output_device, +) + +# --------------------------------------------------------------------------- +# Voice-pool allocation +# --------------------------------------------------------------------------- + + +class TestVoiceAllocation: + def test_greedy_allocation_walks_the_pool(self): + reg = SessionRegistry() + a = reg.register(session_id="s1", participant_id="a").session + b = reg.register(session_id="s2", participant_id="b").session + c = reg.register(session_id="s3", 
participant_id="c").session + assert a.assigned_voice == VOICE_POOL[0] + assert b.assigned_voice == VOICE_POOL[1] + assert c.assigned_voice == VOICE_POOL[2] + assert not a.voice_conflict + assert not b.voice_conflict + assert not c.voice_conflict + + def test_preferred_voice_honored(self): + reg = SessionRegistry() + result = reg.register(session_id="s1", participant_id="a", preferred_voice="bf_emma") + assert result.session.assigned_voice == "bf_emma" + assert not result.voice_conflict + + def test_preferred_voice_collision_flagged_but_assigned(self): + reg = SessionRegistry() + first = reg.register(session_id="s1", participant_id="a", preferred_voice="bm_lewis") + second = reg.register(session_id="s2", participant_id="b", preferred_voice="bm_lewis") + # Both get the voice (collision is a flag, not a veto — per ADR) + assert first.session.assigned_voice == "bm_lewis" + assert not first.voice_conflict + assert second.session.assigned_voice == "bm_lewis" + assert second.voice_conflict + + def test_deregister_returns_voice_to_pool(self): + reg = SessionRegistry() + reg.register(session_id="s1", participant_id="a") # takes bm_lewis + result = reg.deregister("s1") + assert result["status"] == "ok" + assert result["released_voice"] == VOICE_POOL[0] + # Next registration picks up the released voice + next_session = reg.register(session_id="s2", participant_id="b").session + assert next_session.assigned_voice == VOICE_POOL[0] + + def test_out_of_pool_preferred_voice_is_added(self): + """Using a voice the pool didn't know about still works.""" + reg = SessionRegistry() + result = reg.register(session_id="s1", participant_id="a", preferred_voice="af_ember") + assert result.session.assigned_voice == "af_ember" + assert "af_ember" in reg.voice_pool() + + +# --------------------------------------------------------------------------- +# Device resolution +# --------------------------------------------------------------------------- + + +def _fake_device_list() -> list[dict[str, 
Any]]: + """Synthetic device list — stable across calls.""" + return [ + {"name": "MacBook Pro Speakers", "max_output_channels": 2}, + {"name": "DisplayLink Monitor", "max_output_channels": 2}, + {"name": "Realtek USB2.0 Audio", "max_output_channels": 2}, + {"name": "Microphone", "max_output_channels": 0}, + ] + + +class TestDeviceResolution: + def test_system_default_live_requery(self): + """The default must be re-read every call — not cached.""" + current = {"idx": 0} + + def query() -> list[dict[str, Any]]: + return _fake_device_list() + + def default_idx() -> int: + return current["idx"] + + r1 = resolve_output_device("system-default", query_devices=query, default_output_index=default_idx) + assert r1.index == 0 + assert r1.name == "MacBook Pro Speakers" + assert not r1.fallback + + # User plugs in headphones; the OS default changes. + current["idx"] = 2 + r2 = resolve_output_device("system-default", query_devices=query, default_output_index=default_idx) + assert r2.index == 2 + assert r2.name == "Realtek USB2.0 Audio" + + def test_named_device_match(self): + r = resolve_output_device( + "Realtek", + query_devices=_fake_device_list, + default_output_index=lambda: 0, + ) + assert r.index == 2 + assert r.name == "Realtek USB2.0 Audio" + assert not r.fallback + + def test_named_device_missing_falls_back_to_default(self): + r = resolve_output_device( + "AirPods Pro", # not plugged in + query_devices=_fake_device_list, + default_output_index=lambda: 1, + ) + assert r.fallback is True + assert r.index == 1 + assert r.name == "DisplayLink Monitor" + assert "fell back" in r.reason.lower() + + def test_numeric_index_match(self): + r = resolve_output_device( + "2", + query_devices=_fake_device_list, + default_output_index=lambda: 0, + ) + assert r.index == 2 + assert r.name == "Realtek USB2.0 Audio" + + def test_empty_preferred_treated_as_default(self): + r = resolve_output_device( + "", + query_devices=_fake_device_list, + default_output_index=lambda: 1, + ) + assert 
r.index == 1 + + def test_default_index_out_of_range_returns_implicit(self): + r = resolve_output_device( + "system-default", + query_devices=_fake_device_list, + default_output_index=lambda: 99, + ) + assert r.index is None + # Not a "fallback" in the device-fallback sense — OS just couldn't tell us + assert "portaudio" in r.reason.lower() or "unknown" in r.reason.lower() + + +# --------------------------------------------------------------------------- +# Global serializer +# --------------------------------------------------------------------------- + + +class TestGlobalSerializer: + def _build(self, policy: str = "round-robin") -> tuple[GlobalSerializer, list[tuple[str, Any]]]: + """Return (serializer, log) where log records (session_id, payload).""" + log: list[tuple[str, Any]] = [] + barrier = threading.Event() + + def dispatcher(session_id: str, payload: Any) -> None: + log.append((session_id, payload)) + barrier.set() + + ser = GlobalSerializer(policy=policy, dispatcher=dispatcher) + return ser, log + + def test_round_robin_interleaves_two_sessions(self): + reg = SessionRegistry() + ser = reg.serializer + + events: list[tuple[str, Any]] = [] + completed = threading.Event() + counter = {"n": 0} + + def dispatcher(session_id: str, payload: Any) -> None: + events.append((session_id, payload)) + counter["n"] += 1 + if counter["n"] >= 6: + completed.set() + + ser.attach_dispatcher(dispatcher) + reg.register(session_id="A", participant_id="a") + reg.register(session_id="B", participant_id="b") + + # Submit 3 jobs to each before starting the dispatcher so we can + # observe round-robin ordering instead of racing with a ticking queue. 
+ for i in range(3): + reg.submit("A", {"job": f"A{i}"}) + for i in range(3): + reg.submit("B", {"job": f"B{i}"}) + + reg.start() + assert completed.wait(timeout=2.0), f"serializer stalled: {events}" + reg.stop() + + sessions = [sid for sid, _ in events] + # Round-robin starting from A (first session with pending work): + # the cursor walks A → B alternately. + expected = ["A", "B", "A", "B", "A", "B"] + assert sessions == expected, f"Expected round-robin, got {sessions}" + + def test_fifo_global_preserves_arrival_order(self): + reg = SessionRegistry() + reg.serializer.set_policy("fifo-global") + events: list[tuple[str, Any]] = [] + done = threading.Event() + + def dispatcher(session_id: str, payload: Any) -> None: + events.append((session_id, payload)) + if len(events) == 4: + done.set() + + reg.serializer.attach_dispatcher(dispatcher) + reg.register(session_id="A", participant_id="a") + reg.register(session_id="B", participant_id="b") + + reg.submit("A", "a1") + reg.submit("B", "b1") + reg.submit("A", "a2") + reg.submit("B", "b2") + + reg.start() + assert done.wait(timeout=2.0) + reg.stop() + + order = [payload for _, payload in events] + assert order == ["a1", "b1", "a2", "b2"], f"FIFO broken: {order}" + + def test_priority_policy_drains_higher_priority_first(self): + reg = SessionRegistry() + reg.serializer.set_policy("priority") + events: list[tuple[str, Any]] = [] + done = threading.Event() + + def dispatcher(session_id: str, payload: Any) -> None: + events.append((session_id, payload)) + if len(events) == 3: + done.set() + + reg.serializer.attach_dispatcher(dispatcher) + reg.register(session_id="low", participant_id="low", priority=0) + reg.register(session_id="hi", participant_id="hi", priority=10) + + reg.submit("low", "L1") + reg.submit("low", "L2") + reg.submit("hi", "H1") # this should drain first under priority + + reg.start() + assert done.wait(timeout=2.0) + reg.stop() + + assert events[0] == ("hi", "H1"), f"Priority not honored: {events}" + + def 
test_submit_on_unregistered_session_raises(self): + reg = SessionRegistry() + with pytest.raises(KeyError): + reg.submit("ghost", "payload", auto_create_default=False) + + def test_submit_without_session_id_auto_creates_default(self): + reg = SessionRegistry() + done = threading.Event() + events: list[tuple[str, Any]] = [] + + def dispatcher(session_id: str, payload: Any) -> None: + events.append((session_id, payload)) + done.set() + + reg.serializer.attach_dispatcher(dispatcher) + reg.submit(None, "legacy-call") + reg.start() + assert done.wait(timeout=1.0) + reg.stop() + + # Auto-created default session + assert events[0][0] == DEFAULT_SESSION_ID + assert reg.get(DEFAULT_SESSION_ID) is not None + + def test_deregister_drops_queued_jobs(self): + reg = SessionRegistry() + reg.register(session_id="A", participant_id="a") + reg.submit("A", "j1") + reg.submit("A", "j2") + result = reg.deregister("A") + assert result["dropped_jobs"] == 2 + + +# --------------------------------------------------------------------------- +# HTTP surface +# --------------------------------------------------------------------------- + + +class TestHTTPSessionEndpoints: + """Smoke-test the HTTP endpoints via FastAPI TestClient. + + We use the live registry module singleton — tests register unique session + ids with a 'pytest-' prefix and deregister in teardown so we do not leak. + """ + + @pytest.fixture + def client(self): + from fastapi.testclient import TestClient + + import http_api + + return TestClient(http_api.app) + + @pytest.fixture(autouse=True) + def _cleanup(self): + yield + # Clean up any pytest- sessions that leaked. 
+ from session_registry import get_default_registry + + reg = get_default_registry() + for s in list(reg.list()): + if s.session_id.startswith("pytest-"): + reg.deregister(s.session_id) + + def test_register_returns_session_channel(self, client): + r = client.post( + "/v1/sessions/register", + json={ + "session_id": "pytest-s1", + "participant_id": "pytest-cog", + "participant_type": "agent", + "preferred_output_device": "system-default", + }, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["session_id"] == "pytest-s1" + assert body["participant_id"] == "pytest-cog" + assert body["assigned_voice"] in VOICE_POOL + assert body["preferred_output_device"] == "system-default" + # live-resolved device block present + assert "output_device" in body + assert "preferred" in body["output_device"] + + def test_list_includes_registered_session(self, client): + client.post( + "/v1/sessions/register", + json={"session_id": "pytest-list-1", "participant_id": "pytest-sandy"}, + ) + r = client.get("/v1/sessions") + assert r.status_code == 200 + body = r.json() + ids = [s["session_id"] for s in body["sessions"]] + assert "pytest-list-1" in ids + assert "voice_pool" in body + assert body["serializer"]["policy"] in ("round-robin", "priority", "fifo-global") + + def test_deregister_releases_voice(self, client): + client.post( + "/v1/sessions/register", + json={ + "session_id": "pytest-dr-1", + "participant_id": "pytest-cog", + "preferred_voice": "bf_emma", + }, + ) + r = client.post("/v1/sessions/pytest-dr-1/deregister") + assert r.status_code == 200 + body = r.json() + assert body["status"] == "ok" + assert body["released_voice"] == "bf_emma" + + def test_deregister_unknown_returns_404(self, client): + r = client.post("/v1/sessions/pytest-nonexistent/deregister") + assert r.status_code == 404 + + def test_get_single_session(self, client): + client.post( + "/v1/sessions/register", + json={"session_id": "pytest-get-1", "participant_id": "pytest-user"}, + ) + r 
= client.get("/v1/sessions/pytest-get-1") + assert r.status_code == 200 + assert r.json()["session_id"] == "pytest-get-1" + + def test_synthesize_rejects_unknown_session(self, client): + r = client.post( + "/v1/synthesize", + json={ + "text": "hello", + "session_id": "pytest-ghost", + }, + ) + assert r.status_code == 404 From cff22f35a67acd0db53ec0a47adc882d14ff3f11 Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Thu, 23 Apr 2026 11:47:07 -0400 Subject: [PATCH 2/8] fix(dashboard): emit response_complete on kernel-bridged turns Regression: with MOD3_USE_COGOS_AGENT=1, agent_loop's success path returns before the local-inference path's send_response_complete, leaving the dashboard's isResponding spinner hung forever. - channels.py: BrowserChannel.broadcast_response_complete(metrics, session_id) - thread-safe companion to broadcast_response_text, routes to the same channel that received the text frames. - cogos_agent_bridge.py: on agent_response receipt, emit the complete frame after the text frame. - demo/e2e_dashboard_harness.py + tests: updated to assert the completion frame fires on both code paths. --- channels.py | 39 +++++++++++++ cogos_agent_bridge.py | 31 +++++++++- demo/e2e_dashboard_harness.py | 2 +- tests/test_browser_channel_routing.py | 84 +++++++++++++++++++++++++++ tests/test_cogos_agent_bridge.py | 66 +++++++++++++++++---- 5 files changed, 206 insertions(+), 16 deletions(-) diff --git a/channels.py b/channels.py index 9f6f592..94dfa67 100644 --- a/channels.py +++ b/channels.py @@ -14,6 +14,11 @@ partial_transcript, transcript, trace_event — kernel cycle-trace events (ADR-083), fanned out via BrowserChannel.broadcast_trace_event(). + +The MOD3_USE_COGOS_AGENT kernel-bridged path emits response_text AND +response_complete via BrowserChannel.broadcast_response_{text,complete} +so the dashboard UI's turn-done signal fires on every turn, matching the +local-inference path's behavior. 
""" from __future__ import annotations @@ -493,6 +498,40 @@ def broadcast_response_text(cls, text: str, session_id: str | None = None) -> No except Exception as exc: # noqa: BLE001 — disconnected clients are expected logger.debug("response_text send failed for %s: %s", ch.channel_id, exc) + @classmethod + def broadcast_response_complete( + cls, + metrics: dict | None = None, + session_id: str | None = None, + ) -> None: + """Push a `response_complete` frame to dashboard WebSocket clients. + + Companion to :meth:`broadcast_response_text`: the MOD3_USE_COGOS_AGENT + response bridge emits exactly one complete-frame per kernel + `agent_response` event so the dashboard UI's per-turn `isResponding` + state gets cleared (otherwise the chat panel spinner hangs forever). + + Routing and threading match `broadcast_response_text` 1:1 — pass the + same `session_id` so the completion frame lands on the same channel + that received the text frames for this turn. `metrics` follows the + local-path convention from `agent_loop._process` (`{"llm_ms": ..., + "provider": ...}`); the kernel path populates it with + `{"provider": "cogos-agent", ...}`. 
+ """ + frame = {"type": "response_complete", "metrics": metrics or {}} + expected_channel = None + if session_id and session_id.startswith("mod3:"): + expected_channel = session_id[len("mod3:") :] + for ch in list(cls._active_channels): + if not ch._active: + continue + if expected_channel and ch.channel_id != expected_channel: + continue + try: + asyncio.run_coroutine_threadsafe(ch.ws.send_json(frame), ch._loop) + except Exception as exc: # noqa: BLE001 — disconnected clients are expected + logger.debug("response_complete send failed for %s: %s", ch.channel_id, exc) + # ------------------------------------------------------------------ # Cleanup # ------------------------------------------------------------------ diff --git a/cogos_agent_bridge.py b/cogos_agent_bridge.py index c8616f2..05ba8ab 100644 --- a/cogos_agent_bridge.py +++ b/cogos_agent_bridge.py @@ -194,9 +194,24 @@ def _extract_response_text(payload: dict) -> Optional[str]: async def run_response_bridge(subscriber: KernelBusSubscriber) -> None: """Consume `subscriber` and broadcast agent replies to dashboard clients. - `BrowserChannel.broadcast_response_text()` is thread-safe via - `run_coroutine_threadsafe`, matching the existing trace-event pattern. - Malformed events (no recoverable text) are logged at debug and skipped. + Each kernel `agent_response` event on `bus_dashboard_response` is a + complete per-turn reply (see `apps/cogos/agent_tools_respond.go` — the + `respond` tool is documented as "call at most once per user turn" and + the auto-fallback publishes once if the model skipped the tool call). + We therefore emit two dashboard frames per kernel event: + + * ``broadcast_response_text`` — the reply body (chat panel render) + * ``broadcast_response_complete`` — the turn-done signal so the UI's + per-turn spinner clears. 
Without this, the dashboard hangs + awaiting completion because the kernel path never reaches the + ``send_response_complete`` call that the local-inference branch + emits at ``agent_loop._process`` ~L300. + + ``BrowserChannel.broadcast_response_{text,complete}()`` are thread-safe + via ``run_coroutine_threadsafe``, matching the existing trace-event + pattern. Malformed events (no recoverable text) are logged at debug + and skipped — we do NOT emit a completion frame for skipped events + (keeps the 1:1 pairing with what the UI actually rendered). """ first_event_logged = False forwarded = 0 @@ -221,6 +236,16 @@ async def run_response_bridge(subscriber: KernelBusSubscriber) -> None: session_id = _extract_session_id(env.payload) try: BrowserChannel.broadcast_response_text(text, session_id=session_id) + # Pair the text frame with a completion frame so the dashboard's + # per-turn "awaiting response" state clears. Kernel emits exactly + # one agent_response per user turn, so one complete per event is + # the correct cardinality. 
+ metrics: dict = {"provider": "cogos-agent"} + if env.event_id: + metrics["event_id"] = env.event_id + if env.ts: + metrics["kernel_ts"] = env.ts + BrowserChannel.broadcast_response_complete(metrics, session_id=session_id) forwarded += 1 logger.debug( "cogos-agent: forwarded response event_id=%s session=%s (total=%d)", diff --git a/demo/e2e_dashboard_harness.py b/demo/e2e_dashboard_harness.py index dd67df9..213a12c 100644 --- a/demo/e2e_dashboard_harness.py +++ b/demo/e2e_dashboard_harness.py @@ -120,7 +120,7 @@ async def deferred_interrupt() -> None: audio_pcm.append(wav_bytes) except Exception as e: print(f" (audio decode err: {e})", flush=True) - elif t == "response_done" or t == "turn_complete": + elif t == "response_complete": got_done = True done_ts = time.time() elif t == "trace_event": diff --git a/tests/test_browser_channel_routing.py b/tests/test_browser_channel_routing.py index fd7f89c..917049b 100644 --- a/tests/test_browser_channel_routing.py +++ b/tests/test_browser_channel_routing.py @@ -92,6 +92,17 @@ def _broadcast_with_loop(text: str, session_id: str | None = None) -> None: loop.close() +def _broadcast_complete_with_loop(metrics: dict | None = None, session_id: str | None = None) -> None: + """Sibling of _broadcast_with_loop for the response_complete frame.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + with patch("channels.asyncio.run_coroutine_threadsafe", _patched_run): + BrowserChannel.broadcast_response_complete(metrics, session_id=session_id) + finally: + loop.close() + + def test_broadcast_with_no_session_id_fans_out_to_all_active_channels(): loop = asyncio.new_event_loop() a = _FakeChannel("browser:aaa", loop) @@ -169,5 +180,78 @@ def test_broadcast_session_id_with_no_match_drops_silently(): loop.close() +# --------------------------------------------------------------------------- +# response_complete routing (kernel-path turn-done signal) +# +# Paired with broadcast_response_text in 
cogos_agent_bridge.run_response_bridge; +# must follow the same session-scoped routing so the complete-frame lands on +# the originating dashboard channel (otherwise multi-client setups see +# cross-talk or hang on a non-matching spinner). +# --------------------------------------------------------------------------- + + +def test_broadcast_complete_with_no_session_id_fans_out_to_all_active_channels(): + loop = asyncio.new_event_loop() + a = _FakeChannel("browser:aaa", loop) + b = _FakeChannel("browser:bbb", loop) + BrowserChannel._active_channels.update({a, b}) + + _broadcast_complete_with_loop({"provider": "cogos-agent"}) + + for ch in (a, b): + assert len(ch.ws.sent) == 1 + frame = ch.ws.sent[0] + assert frame["type"] == "response_complete" + assert frame["metrics"] == {"provider": "cogos-agent"} + + loop.close() + + +def test_broadcast_complete_routes_to_matching_session_only(): + loop = asyncio.new_event_loop() + a = _FakeChannel("browser:aaa", loop) + b = _FakeChannel("browser:bbb", loop) + BrowserChannel._active_channels.update({a, b}) + + _broadcast_complete_with_loop( + {"provider": "cogos-agent", "event_id": "r42"}, + session_id="mod3:browser:bbb", + ) + + assert a.ws.sent == [] + assert len(b.ws.sent) == 1 + assert b.ws.sent[0]["type"] == "response_complete" + assert b.ws.sent[0]["metrics"]["event_id"] == "r42" + + loop.close() + + +def test_broadcast_complete_defaults_to_empty_metrics(): + loop = asyncio.new_event_loop() + a = _FakeChannel("browser:aaa", loop) + BrowserChannel._active_channels.add(a) + + _broadcast_complete_with_loop() + + assert a.ws.sent == [{"type": "response_complete", "metrics": {}}] + + loop.close() + + +def test_broadcast_complete_skips_inactive_channels(): + loop = asyncio.new_event_loop() + a = _FakeChannel("browser:aaa", loop) + a._active = False + b = _FakeChannel("browser:bbb", loop) + BrowserChannel._active_channels.update({a, b}) + + _broadcast_complete_with_loop({"provider": "cogos-agent"}) + + assert a.ws.sent == [] + 
assert len(b.ws.sent) == 1 + + loop.close() + + if __name__ == "__main__": sys.exit(pytest.main([__file__, "-v"])) diff --git a/tests/test_cogos_agent_bridge.py b/tests/test_cogos_agent_bridge.py index 2d5de43..5063fd0 100644 --- a/tests/test_cogos_agent_bridge.py +++ b/tests/test_cogos_agent_bridge.py @@ -104,11 +104,20 @@ def test_run_response_bridge_fans_out_to_broadcast(): ] sub = _FakeSubscriber(envelopes) - with patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_bcast: + with ( + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_text, + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_complete") as mock_done, + ): asyncio.run(run_response_bridge(sub)) - texts = [c.args[0] for c in mock_bcast.call_args_list] + texts = [c.args[0] for c in mock_text.call_args_list] assert texts == ["reply one", "free-form string reply"] + # One completion frame per forwarded text — never zero, never doubled. + assert mock_done.call_count == 2 + # Payload shape: provider tag plus any kernel-supplied timing/ids. 
+ for call in mock_done.call_args_list: + metrics = call.args[0] + assert metrics["provider"] == "cogos-agent" # --------------------------------------------------------------------------- @@ -144,14 +153,22 @@ def test_run_response_bridge_forwards_session_id_when_present(): envelopes = [_env({"content": inner}, "r1")] sub = _FakeSubscriber(envelopes) - with patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_bcast: + with ( + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_text, + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_complete") as mock_done, + ): asyncio.run(run_response_bridge(sub)) - assert mock_bcast.call_count == 1 - call = mock_bcast.call_args_list[0] - assert call.args[0] == "scoped reply" + assert mock_text.call_count == 1 + text_call = mock_text.call_args_list[0] + assert text_call.args[0] == "scoped reply" # session_id passed as keyword - assert call.kwargs.get("session_id") == "mod3:browser:abc" + assert text_call.kwargs.get("session_id") == "mod3:browser:abc" + # Completion frame routes to the same session so the originating + # channel's spinner clears — not a broadcast. 
+ assert mock_done.call_count == 1 + done_call = mock_done.call_args_list[0] + assert done_call.kwargs.get("session_id") == "mod3:browser:abc" def test_run_response_bridge_falls_back_to_broadcast_when_no_session_id(): @@ -160,13 +177,38 @@ def test_run_response_bridge_falls_back_to_broadcast_when_no_session_id(): envelopes = [_env({"content": inner}, "r1")] sub = _FakeSubscriber(envelopes) - with patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_bcast: + with ( + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_text, + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_complete") as mock_done, + ): asyncio.run(run_response_bridge(sub)) - assert mock_bcast.call_count == 1 - call = mock_bcast.call_args_list[0] - assert call.args[0] == "broadcast reply" - assert call.kwargs.get("session_id") is None + assert mock_text.call_count == 1 + text_call = mock_text.call_args_list[0] + assert text_call.args[0] == "broadcast reply" + assert text_call.kwargs.get("session_id") is None + # Completion frame also broadcasts (matches the text-frame routing). + assert mock_done.call_count == 1 + assert mock_done.call_args_list[0].kwargs.get("session_id") is None + + +def test_run_response_bridge_skips_complete_when_text_missing(): + """Events with no recoverable text must not emit a completion frame. + + Holding the 1:1 pairing keeps the UI's turn counter honest: we only + mark a turn done when we actually rendered something for it. 
+ """ + envelopes = [_env({"foo": "bar"}, "r1"), _env({}, "r2")] + sub = _FakeSubscriber(envelopes) + + with ( + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_text, + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_complete") as mock_done, + ): + asyncio.run(run_response_bridge(sub)) + + assert mock_text.call_count == 0 + assert mock_done.call_count == 0 def test_post_user_message_uses_runtime_endpoint(monkeypatch): From 69dd70dece308caf4ddda239538d091b717fc163 Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Thu, 23 Apr 2026 14:34:47 -0400 Subject: [PATCH 3/8] feat(dashboard): participant panel + auto-register on page load MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wave 4.1 + 4.2 of the mod3-kernel integration. The dashboard now owns its own bus identity instead of being an anonymous WebSocket client. On page load: 1. Reuse a session_id from sessionStorage (refreshes stay on the same identity). 2. Otherwise POST to the kernel's /v1/channel-sessions/register — ADR-082 Wave 3.5 says session-id minting is kernel-owned. On CORS / kernel-down the JS falls back to mod3's /v1/sessions/register direct so the dashboard keeps working in a mod3-only deployment. 3. Poll GET /v1/sessions every 4s, render the live roster. 4. On beforeunload, navigator.sendBeacon a best-effort deregister so the voice returns to the pool without waiting for a sweep. The participant panel is a collapsible drawer keyed off a header pill (count + plural). Rows show participant_id, assigned_voice, session_id prefix, age, and participant_type badge. The "self" row is pulled to the top and highlighted with a green left-border + "you" pill. window.__mod3Session is exposed for Wave 4.3 — the audio WebSocket subscription will key off its session_id, and a "mod3-session-registered" CustomEvent fires when registration completes so late-loaded scripts can subscribe without polling. 
Branching: stacked on feat/session-registry-adr-082-phase1 because the /v1/sessions endpoints this UI depends on only exist on that branch (Phase 1 of the session registry). --- dashboard/index.html | 334 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 334 insertions(+) diff --git a/dashboard/index.html b/dashboard/index.html index ae166c5..0e80ffa 100644 --- a/dashboard/index.html +++ b/dashboard/index.html @@ -214,11 +214,86 @@ /* Leave room at the bottom so the drawer doesn't cover the input */ body { padding-bottom: 32px; } + /* Participant panel (ADR-082 / Wave 4) — live session roster. Renders as a + collapsible drawer on the right side of the main pane. */ + .participants-pill { + display: flex; align-items: center; gap: 6px; + font-size: 0.75rem; color: var(--muted); + padding: 4px 10px; border-radius: 12px; + background: var(--bg); border: 1px solid var(--border); + cursor: pointer; user-select: none; + } + .participants-pill:hover { border-color: var(--accent); } + .participants-pill .count { color: var(--accent); font-variant-numeric: tabular-nums; } + .participants-pill .count.me { color: var(--green); } + #participants-panel { + position: fixed; top: 48px; right: 16px; + background: var(--surface); border: 1px solid var(--border); border-radius: 8px; + width: 320px; max-height: 60vh; overflow-y: auto; + box-shadow: 0 8px 24px rgba(0,0,0,0.4); + z-index: 30; display: none; padding: 8px 0; + } + #participants-panel.open { display: block; } + #participants-panel .pp-header { + padding: 8px 14px; border-bottom: 1px solid var(--border); + font-size: 0.7rem; color: var(--muted); text-transform: uppercase; + letter-spacing: 0.5px; display: flex; align-items: center; gap: 8px; + } + #participants-panel .pp-header .session-id { + font-family: ui-monospace, SFMono-Regular, Menlo, monospace; + color: var(--text); text-transform: none; letter-spacing: 0; + margin-left: auto; font-size: 0.7rem; + } + #participants-list { list-style: none; padding: 0; 
margin: 0; } + #participants-list .pp-empty { + padding: 16px 14px; font-size: 0.8rem; color: var(--muted); text-align: center; + } + .pp-row { + display: flex; align-items: center; gap: 10px; + padding: 8px 14px; border-bottom: 1px solid rgba(48,54,61,0.4); + font-size: 0.8rem; + } + .pp-row:last-child { border-bottom: none; } + .pp-row.me { + background: rgba(63,185,80,0.07); + border-left: 2px solid var(--green); + padding-left: 12px; + } + .pp-row .pp-info { flex: 1; min-width: 0; } + .pp-row .pp-name { color: var(--text); font-weight: 500; } + .pp-row .pp-name .you-marker { + display: inline-block; margin-left: 4px; padding: 0 6px; + background: var(--green); color: #000; border-radius: 8px; + font-size: 0.6rem; font-weight: 600; text-transform: uppercase; + letter-spacing: 0.5px; vertical-align: middle; + } + .pp-row .pp-meta { + font-size: 0.68rem; color: var(--muted); + font-family: ui-monospace, SFMono-Regular, Menlo, monospace; + display: flex; gap: 6px; flex-wrap: wrap; margin-top: 2px; + } + .pp-row .pp-meta .pp-voice { color: var(--accent); } + .pp-row .pp-meta .pp-session { opacity: 0.7; } + .pp-row .pp-badge { + font-size: 0.6rem; padding: 1px 6px; border-radius: 8px; + background: var(--bg); border: 1px solid var(--border); color: var(--muted); + text-transform: uppercase; letter-spacing: 0.5px; + } + .pp-row .pp-badge.user { color: var(--accent); border-color: var(--accent); } + .pp-row .pp-badge.agent { color: var(--orange); border-color: var(--orange); } + .pp-row .pp-conflict { color: var(--orange); font-size: 0.65rem; margin-left: 4px; } + .pp-row .pp-audio-dot { + width: 6px; height: 6px; border-radius: 50%; + background: var(--muted); flex-shrink: 0; + } + .pp-row .pp-audio-dot.ws { background: var(--green); } + /* Responsive */ @media (max-width: 700px) { .main { padding: 12px 16px; } .voice-controls { flex-wrap: wrap; } .msg { max-width: 90%; } + #participants-panel { right: 8px; width: calc(100vw - 16px); } } @@ -235,6 +310,22 @@

Mod³

Mic +
+ 👥 + 0 + sessions +
+ + + +
+
+ Participants + unregistered +
+
    +
  • No active sessions
  • +
@@ -666,6 +757,249 @@

Mod³

chatInput.style.height = 'auto'; chatInput.style.height = Math.min(chatInput.scrollHeight, 120) + 'px'; }); + +/** + * Wave 4 — Session registration + participant panel. + * + * On page load: + * 1. Try to reuse a session_id from sessionStorage (so refreshes reuse the + * same bus identity). + * 2. Otherwise, POST to the kernel's /v1/channel-sessions/register (the + * kernel owns minting authority — ADR-082 Wave 3.5). Fall back to mod3's + * direct /v1/sessions/register when the kernel is unreachable so the + * dashboard still works when only mod3 is running. + * 3. Poll GET /v1/sessions every 4s to render the live roster. + * 4. On beforeunload, best-effort deregister so the voice goes back to the + * pool quickly. + * + * Exposes window.__mod3Session for the audio-WebSocket code in Wave 4.3 to + * consume — that code keys its /ws/audio/{session_id} subscription off this + * value. + */ +(function setupSessionRegistration() { + const KERNEL_URL = 'http://localhost:6931'; // Wave 3.5 kernel-owned authority + const SESSION_KEY = 'mod3.sessionId'; + const POLL_INTERVAL_MS = 4000; + + // Short UUID helper (8 hex chars — matches the kernel's cs- short-id + // style for merged displays). + function shortUuid() { + const bytes = new Uint8Array(4); + crypto.getRandomValues(bytes); + return Array.from(bytes, b => b.toString(16).padStart(2, '0')).join(''); + } + + // Session state published on window so other scripts can see it. + window.__mod3Session = { + session_id: null, + participant_id: null, + assigned_voice: null, + registered: false, + }; + + async function registerSession() { + // Reuse cached session if present — refreshes stay on the same identity. + const cached = sessionStorage.getItem(SESSION_KEY); + const reuseId = cached ? JSON.parse(cached) : null; + + const participantId = (reuseId && reuseId.participant_id) || + ('dashboard-' + shortUuid()); + const sessionId = reuseId ? 
reuseId.session_id : undefined; + + const payload = { + participant_id: participantId, + participant_type: 'user', + preferred_voice: null, + preferred_output_device: 'system-default', + }; + if (sessionId) payload.session_id = sessionId; + + // 1. Try kernel-owned endpoint first. + let data = null; + try { + const r = await fetch(`${KERNEL_URL}/v1/channel-sessions/register`, { + method: 'POST', + headers: {'Content-Type': 'application/json'}, + body: JSON.stringify({...payload, kinds: ['audio']}), + }); + if (r.ok) { + data = await r.json(); + // Kernel returns {kernel: {...}, mod3: {...}} — normalize to mod3 shape + if (data.mod3) { + data = {...data.mod3, session_id: data.kernel.session_id || data.mod3.session_id}; + } + console.log('[Session] Registered via kernel:', data.session_id); + } + } catch (e) { + console.log('[Session] Kernel unreachable, falling back to mod3 direct:', e.message); + } + + // 2. Fallback: mod3 direct registration. + if (!data) { + try { + // Mod3's /v1/sessions/register requires a session_id in the payload — + // mint one on the client when we don't have a cached one. This is + // fallback-path only; the kernel path auto-mints on empty. 
+ const directPayload = {...payload}; + if (!directPayload.session_id) { + directPayload.session_id = 'dashboard-' + shortUuid() + '-' + Date.now(); + } + const r = await fetch('/v1/sessions/register', { + method: 'POST', + headers: {'Content-Type': 'application/json'}, + body: JSON.stringify(directPayload), + }); + if (r.ok) { + data = await r.json(); + console.log('[Session] Registered via mod3 direct:', data.session_id); + } else { + console.error('[Session] mod3 register failed:', r.status, await r.text()); + return; + } + } catch (e) { + console.error('[Session] mod3 register threw:', e); + return; + } + } + + if (!data || !data.session_id) return; + + window.__mod3Session = { + session_id: data.session_id, + participant_id: data.participant_id || participantId, + assigned_voice: data.assigned_voice || null, + registered: true, + }; + + sessionStorage.setItem(SESSION_KEY, JSON.stringify({ + session_id: data.session_id, + participant_id: data.participant_id || participantId, + })); + + const selfLabel = document.getElementById('participants-self-id'); + if (selfLabel) { + selfLabel.textContent = data.session_id; + selfLabel.style.color = 'var(--green)'; + } + + // Dispatch event so the audio-WS bootstrap (Wave 4.3) can subscribe. 
+ window.dispatchEvent(new CustomEvent('mod3-session-registered', { + detail: window.__mod3Session, + })); + } + + async function pollSessions() { + try { + const r = await fetch('/v1/sessions'); + if (!r.ok) return; + const body = await r.json(); + renderParticipantPanel(body.sessions || []); + } catch (e) { + // silent — the /health pill already tracks connectivity + } + } + + function renderParticipantPanel(sessions) { + const listEl = document.getElementById('participants-list'); + const countEl = document.getElementById('participants-count'); + const pluralEl = document.getElementById('participants-plural'); + if (!listEl || !countEl) return; + + const selfId = window.__mod3Session && window.__mod3Session.session_id; + + countEl.textContent = String(sessions.length); + if (pluralEl) pluralEl.textContent = sessions.length === 1 ? '' : 's'; + if (countEl) { + const hasSelf = selfId && sessions.some(s => s.session_id === selfId); + countEl.className = hasSelf ? 'count me' : 'count'; + } + + if (sessions.length === 0) { + listEl.innerHTML = '
  • No active sessions
  • '; + return; + } + + // Sort: self first, then agents, then users, each by registered_at desc + const sorted = [...sessions].sort((a, b) => { + if (a.session_id === selfId) return -1; + if (b.session_id === selfId) return 1; + return (b.registered_at || 0) - (a.registered_at || 0); + }); + + listEl.innerHTML = sorted.map(s => { + const isSelf = s.session_id === selfId; + const kind = (s.participant_type || 'agent').toLowerCase(); + const voice = s.assigned_voice || '—'; + const conflict = s.voice_conflict + ? '' : ''; + const pid = escapeHtml(s.participant_id || ''); + const sid = escapeHtml(s.session_id || ''); + const ageSec = s.registered_at ? Math.round(Date.now()/1000 - s.registered_at) : 0; + const ageStr = ageSec < 60 ? ageSec + 's' : + ageSec < 3600 ? Math.round(ageSec/60) + 'm' : + Math.round(ageSec/3600) + 'h'; + const youMarker = isSelf ? ' you' : ''; + return ( + '
  • ' + + '' + + '
    ' + + '
    ' + pid + youMarker + conflict + '
    ' + + '
    ' + + '' + escapeHtml(voice) + '' + + '' + sid.slice(0,16) + + (sid.length > 16 ? '…' : '') + '' + + '' + ageStr + '' + + '
    ' + + '
    ' + + '' + escapeHtml(kind) + '' + + '
  • ' + ); + }).join(''); + } + + // Toggle panel open/close + const pillEl = document.getElementById('participants-toggle'); + const panelEl = document.getElementById('participants-panel'); + if (pillEl && panelEl) { + pillEl.addEventListener('click', (e) => { + e.stopPropagation(); + panelEl.classList.toggle('open'); + }); + // Click-outside-to-close + document.addEventListener('click', (e) => { + if (!panelEl.contains(e.target) && !pillEl.contains(e.target)) { + panelEl.classList.remove('open'); + } + }); + } + + async function deregisterOnUnload() { + const s = window.__mod3Session; + if (!s || !s.session_id) return; + // Use navigator.sendBeacon when available — it survives the unload event. + const url = `/v1/sessions/${encodeURIComponent(s.session_id)}/deregister`; + try { + if (navigator.sendBeacon) { + navigator.sendBeacon(url, new Blob(['{}'], {type: 'application/json'})); + } else { + // Blocking fallback for browsers without sendBeacon. + fetch(url, {method: 'POST', keepalive: true, body: '{}', + headers: {'Content-Type': 'application/json'}}); + } + } catch { + // Best-effort; server sweep will clean up eventually. + } + } + + window.addEventListener('beforeunload', deregisterOnUnload); + + // Kick off: register, then start polling. Register THEN poll so the + // first roster render can mark self correctly. + registerSession().finally(() => { + pollSessions(); + setInterval(pollSessions, POLL_INTERVAL_MS); + }); +})(); From a5321ee94011765b31c8037bb50f08d85438a390 Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Thu, 23 Apr 2026 14:40:38 -0400 Subject: [PATCH 4/8] feat(channels): /ws/audio/{session_id} WebSocket for per-session playback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wave 4.3 mod3 side — route synthesized audio to the dashboard via a per-session WebSocket instead of (or in addition to) the server's sounddevice / afplay fallback. 
New module: audio_subscribers.py AudioSubscriberRegistry holds session_id → [subscriber] with register/unregister/has_subscribers/count/emit_wav. emit_wav pushes a JSON header frame + binary WAV frame through each subscriber's WebSocket via run_coroutine_threadsafe on the socket's event loop, matching the BrowserChannel.broadcast_trace_event pattern. New endpoints (http_api.py): WS /ws/audio/{session_id} — accept + register + hold open GET /v1/sessions/{session_id}/subscribers — returns {"session_id": ..., "subscribed": bool, "count": N}. Unknown session_ids intentionally return subscribed=false instead of 404 so the kernel's pre-afplay check stays a single predicate. /v1/synthesize now also emits the generated WAV over the WebSocket when the request names a session and at least one subscriber is attached. Emit is best-effort (disconnect mid-send just drops the frame) and non-blocking on the HTTP path. A new X-Mod3-WS-Subscribers response header reports how many subscribers the blob was enqueued for; callers use this to skip their local playback. mcp_shim._play_wav_bytes gains a pre-check (_session_has_ws_subscriber) that GETs /v1/sessions/{id}/subscribers with a 1.5s timeout. When subscribed=true we skip sounddevice entirely and record status=routed_ws in the job ledger. Keeps the legacy path unchanged when no session is attached or the HTTP check fails. Dashboard wiring (dashboard/index.html): A new IIFE opens ws://host:7860/ws/audio/{session_id} after the session-registered event fires, listens for audio_header + binary frames, and plays the WAV through AudioContext.decodeAudioData. Reconnect on close with exponential backoff up to 30s. The self-row audio-dot indicator flips green while the WS is up. AudioContext is resumed on first user gesture to satisfy the autoplay policy. 
Tests (tests/test_audio_subscribers.py): - AudioSubscriberRegistry unit tests: register/unregister, multiple subscribers per session, empty-bucket pruning, emit_wav delivers header+bytes, no-subscriber returns zero, default registry is a shared singleton. - HTTP tests via FastAPI TestClient: /subscribers returns false for unknown sessions, /ws/audio upgrade flips subscribed=true and disconnect flips it back to false. - Integration test (guarded by SKIP_TTS_TESTS env var because it loads Kokoro): /v1/synthesize with a session + subscriber pushes a RIFF/WAVE binary frame through the WebSocket and reports X-Mod3-WS-Subscribers: 1. All 32 existing session-registry tests + 9 new tests pass (1 skipped for Kokoro cold-start). Ruff clean. Branch: feat/dashboard-wave4, stacked on feat/session-registry-adr-082-phase1 (Phase 1 /v1/sessions surface). --- audio_subscribers.py | 221 ++++++++++++++++++++++++++ dashboard/index.html | 139 ++++++++++++++++ http_api.py | 91 +++++++++++ mcp_shim.py | 60 ++++++- tests/test_audio_subscribers.py | 272 ++++++++++++++++++++++++++++++++ 5 files changed, 782 insertions(+), 1 deletion(-) create mode 100644 audio_subscribers.py create mode 100644 tests/test_audio_subscribers.py diff --git a/audio_subscribers.py b/audio_subscribers.py new file mode 100644 index 0000000..54ef3ea --- /dev/null +++ b/audio_subscribers.py @@ -0,0 +1,221 @@ +"""Per-session audio subscriber registry (Wave 4.3). + +A separate, tiny module so the routing surface stays independent of the +session registry (which is already load-bearing for ADR-082 Phase 1). The +dashboard WebSocket lands here; ``mcp_shim._play_wav_bytes`` queries this +registry before falling back to sounddevice; the kernel queries the +``/v1/sessions/{id}/subscribers`` HTTP endpoint before spawning afplay. + +Thread-safety: registration happens in FastAPI's event loop thread (WS +handler), lookup and emit happen from the MCP shim playback thread. 
A +single RLock around the dict is sufficient — the set-per-session is small +(usually 1 dashboard) and contention is effectively zero. + +Delivery semantics: the emit paths are best-effort. A WebSocket that has +already died just drops the frame — the kernel fallback ``afplay`` +never happened because the check went through before we spawned it, but +the session will miss this turn's audio. That's acceptable: the dashboard +polls and reconnects; the user sees silence for one turn instead of +nothing at all. +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover — type-check-only import + from fastapi import WebSocket + +logger = logging.getLogger("mod3.audio_subscribers") + + +@dataclass +class _Subscriber: + """A single WebSocket subscription for a session.""" + + ws: "WebSocket" + # The event loop the WebSocket was accepted on. Emit calls from other + # threads need to run_coroutine_threadsafe onto this loop. + loop: asyncio.AbstractEventLoop + # Monotonic sequence for logging / frame ordering. Opaque to callers. + seq: int = 0 + + +@dataclass +class _SessionBucket: + """Subscribers currently attached to a session_id.""" + + subscribers: list[_Subscriber] = field(default_factory=list) + + +class AudioSubscriberRegistry: + """Thread-safe session_id → active WebSocket subscribers mapping. + + Callers never reach into the bucket lists directly — they go through + register / unregister / has_subscribers / emit_wav. 
+ """ + + def __init__(self) -> None: + self._lock = threading.RLock() + self._buckets: dict[str, _SessionBucket] = {} + self._frame_seq = 0 + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + def register(self, session_id: str, ws: "WebSocket", loop: asyncio.AbstractEventLoop) -> _Subscriber: + """Attach ``ws`` to ``session_id``. Caller must have already + ``accept()``-ed the WebSocket. + """ + sub = _Subscriber(ws=ws, loop=loop) + with self._lock: + bucket = self._buckets.setdefault(session_id, _SessionBucket()) + bucket.subscribers.append(sub) + count = len(bucket.subscribers) + logger.info("audio subscriber attached: session=%s total=%d", session_id, count) + return sub + + def unregister(self, session_id: str, sub: _Subscriber) -> None: + """Detach. Idempotent — double-unregister is a no-op.""" + with self._lock: + bucket = self._buckets.get(session_id) + if bucket is None: + return + try: + bucket.subscribers.remove(sub) + except ValueError: + return + remaining = len(bucket.subscribers) + if not bucket.subscribers: + # Drop empty buckets so the subscribed-check stays fast. + self._buckets.pop(session_id, None) + logger.info("audio subscriber detached: session=%s remaining=%d", session_id, remaining) + + # ------------------------------------------------------------------ + # Inspection + # ------------------------------------------------------------------ + + def has_subscribers(self, session_id: str) -> bool: + with self._lock: + bucket = self._buckets.get(session_id) + return bool(bucket and bucket.subscribers) + + def count(self, session_id: str) -> int: + with self._lock: + bucket = self._buckets.get(session_id) + return len(bucket.subscribers) if bucket else 0 + + def snapshot(self) -> dict[str, int]: + """session_id → subscriber count. 
For diagnostics only.""" + with self._lock: + return {sid: len(b.subscribers) for sid, b in self._buckets.items()} + + # ------------------------------------------------------------------ + # Emit + # ------------------------------------------------------------------ + + def emit_wav( + self, + session_id: str, + wav_bytes: bytes, + *, + job_id: str | None = None, + duration_sec: float | None = None, + sample_rate: int | None = None, + ) -> int: + """Push a whole WAV blob to every subscriber of ``session_id``. + + Returns the number of subscribers the frame was enqueued for. Each + send is fire-and-forget via ``run_coroutine_threadsafe`` — the caller + doesn't block on socket I/O, which matches the existing + ``BrowserChannel.broadcast_trace_event`` pattern. + + The wire format is a single binary WebSocket frame containing the + raw WAV (RIFF / WAVE) bytes. Browsers can decode this directly via + AudioContext.decodeAudioData. A preceding small JSON control frame + announces the incoming audio with session_id + job_id + duration + so the dashboard can correlate the blob to a synthesize call. 
+ """ + with self._lock: + bucket = self._buckets.get(session_id) + if not bucket or not bucket.subscribers: + return 0 + targets = list(bucket.subscribers) + self._frame_seq += 1 + seq = self._frame_seq + + header = { + "type": "audio_header", + "session_id": session_id, + "job_id": job_id, + "duration_sec": duration_sec, + "sample_rate": sample_rate, + "bytes": len(wav_bytes), + "format": "wav", + "seq": seq, + } + + delivered = 0 + for sub in targets: + try: + asyncio.run_coroutine_threadsafe(_send_audio_frame(sub.ws, header, wav_bytes), sub.loop) + delivered += 1 + except Exception as exc: # noqa: BLE001 — disconnected subscribers are expected + logger.debug("emit_wav: scheduling failed for %s: %s", session_id, exc) + logger.info( + "emit_wav: session=%s bytes=%d delivered_to=%d seq=%d", + session_id, + len(wav_bytes), + delivered, + seq, + ) + return delivered + + +async def _send_audio_frame(ws: "WebSocket", header: dict, wav_bytes: bytes) -> None: + """Send header JSON + binary WAV over a single WebSocket. + + Split into a module-level coroutine so ``run_coroutine_threadsafe`` + returns a Future the caller can ignore — the pair is sent in order on + the socket's own coroutine context. + """ + try: + await ws.send_json(header) + await ws.send_bytes(wav_bytes) + except Exception as exc: # noqa: BLE001 — disconnect mid-send is expected + logger.debug("audio frame send failed: %s", exc) + + +# --------------------------------------------------------------------------- +# Process-global default registry — shared between http_api and mcp_shim. 
+# --------------------------------------------------------------------------- + +_default_registry: AudioSubscriberRegistry | None = None +_default_registry_lock = threading.Lock() + + +def get_default_audio_subscribers() -> AudioSubscriberRegistry: + global _default_registry + with _default_registry_lock: + if _default_registry is None: + _default_registry = AudioSubscriberRegistry() + return _default_registry + + +def reset_default_audio_subscribers() -> None: + """For tests — drop the module-level registry.""" + global _default_registry + with _default_registry_lock: + _default_registry = None + + +__all__ = [ + "AudioSubscriberRegistry", + "get_default_audio_subscribers", + "reset_default_audio_subscribers", +] diff --git a/dashboard/index.html b/dashboard/index.html index 0e80ffa..c534cc0 100644 --- a/dashboard/index.html +++ b/dashboard/index.html @@ -1000,6 +1000,145 @@

    Mod³

    setInterval(pollSessions, POLL_INTERVAL_MS); }); })(); + +/** + * Wave 4.3 — Per-session audio WebSocket. + * + * After the session registers, open ``/ws/audio/{session_id}`` and play + * inbound WAV blobs through a Web Audio context. The contract from the + * server is: + * + * 1. JSON text frame: {type: "audio_header", session_id, job_id, + * duration_sec, sample_rate, bytes, format: "wav", seq} + * 2. Binary frame: raw WAV bytes. + * + * We pair them up, decode via AudioContext.decodeAudioData, and play. + * Chunking is explicitly whole-blob for v1 — decodeAudioData wants a + * complete WAV. A future revision can stream PCM through an AudioWorklet + * for lower latency; the header envelope is the forward-compatibility + * seam. + * + * Reconnect: on close, reconnect with exponential backoff up to 30s. The + * kernel's subscriber-check runs on every synthesize so a transient gap + * just means one turn plays through afplay — the next turn lands back in + * the browser. + */ +(function setupAudioSubscription() { + let ws = null; + let audioCtx = null; + let reconnectDelay = 1000; + let currentSessionId = null; + let pendingHeader = null; // last JSON header waiting for its binary pair + + function ensureAudioCtx() { + if (audioCtx) return audioCtx; + const Ctor = window.AudioContext || window.webkitAudioContext; + if (!Ctor) { console.error('[AudioWS] Web Audio API unavailable'); return null; } + audioCtx = new Ctor(); + return audioCtx; + } + + async function playWavBlob(buffer, header) { + const ctx = ensureAudioCtx(); + if (!ctx) return; + if (ctx.state === 'suspended') { + try { await ctx.resume(); } catch { /* user-gesture required */ } + } + try { + // decodeAudioData wants an ArrayBuffer (not a shared view). Make a copy + // to be safe across browsers. + const copy = buffer.byteLength === buffer.buffer.byteLength + ? 
buffer.slice(0) : buffer.buffer.slice(buffer.byteOffset, buffer.byteOffset + buffer.byteLength); + const decoded = await ctx.decodeAudioData(copy); + const src = ctx.createBufferSource(); + src.buffer = decoded; + src.connect(ctx.destination); + src.start(); + console.log('[AudioWS] Playing', + header ? `job=${header.job_id} dur=${header.duration_sec}s sr=${header.sample_rate}` : '', + `samples=${decoded.length}`); + } catch (err) { + console.error('[AudioWS] decodeAudioData failed:', err); + } + } + + function connect(sessionId) { + if (!sessionId) return; + currentSessionId = sessionId; + const scheme = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const url = `${scheme}//${window.location.host}/ws/audio/${encodeURIComponent(sessionId)}`; + console.log('[AudioWS] Connecting to', url); + try { + ws = new WebSocket(url); + } catch (e) { + console.error('[AudioWS] Construction threw:', e); + return; + } + ws.binaryType = 'arraybuffer'; + + ws.addEventListener('open', () => { + console.log('[AudioWS] Connected'); + reconnectDelay = 1000; + // Flip the audio-dot indicator for the self-row on the next poll + const selfRow = document.querySelector('#participants-list .pp-row.me .pp-audio-dot'); + if (selfRow) selfRow.classList.add('ws'); + }); + + ws.addEventListener('message', (ev) => { + if (typeof ev.data === 'string') { + try { + pendingHeader = JSON.parse(ev.data); + } catch { + pendingHeader = null; + } + return; + } + // Binary frame — the WAV bytes corresponding to the last header. 
+ const header = pendingHeader; + pendingHeader = null; + playWavBlob(new Uint8Array(ev.data), header); + }); + + ws.addEventListener('close', () => { + console.log('[AudioWS] Closed, reconnecting in', reconnectDelay + 'ms'); + const selfRow = document.querySelector('#participants-list .pp-row.me .pp-audio-dot'); + if (selfRow) selfRow.classList.remove('ws'); + setTimeout(() => { + if (currentSessionId) connect(currentSessionId); + }, reconnectDelay); + reconnectDelay = Math.min(reconnectDelay * 2, 30000); + }); + + ws.addEventListener('error', (err) => { + // close handler will schedule reconnect; nothing to do here. + console.warn('[AudioWS] Error:', err); + }); + } + + // Wait for the registration handshake to complete before subscribing. + window.addEventListener('mod3-session-registered', (ev) => { + const sid = ev.detail && ev.detail.session_id; + if (!sid) return; + if (ws && currentSessionId === sid) return; // already subscribed + if (ws) { try { ws.close(); } catch {} } + connect(sid); + }); + + // If registration already completed before this IIFE ran, pick it up now. + if (window.__mod3Session && window.__mod3Session.session_id) { + connect(window.__mod3Session.session_id); + } + + // Resume AudioContext on first user gesture — required by autoplay policy. 
+ const resumeOnGesture = () => { + const ctx = ensureAudioCtx(); + if (ctx && ctx.state === 'suspended') ctx.resume(); + window.removeEventListener('click', resumeOnGesture, true); + window.removeEventListener('keydown', resumeOnGesture, true); + }; + window.addEventListener('click', resumeOnGesture, true); + window.addEventListener('keydown', resumeOnGesture, true); +})(); diff --git a/http_api.py b/http_api.py index 3492897..91df778 100644 --- a/http_api.py +++ b/http_api.py @@ -34,6 +34,7 @@ from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field +from audio_subscribers import get_default_audio_subscribers from bus import ModalityBus from engine import MODELS, generate_audio, get_loaded_engines from modality import EncodedOutput, ModalityType @@ -403,9 +404,11 @@ def synthesize(req: SynthesizeRequest): pcm = (np.clip(all_samples, -1.0, 1.0) * 32767).astype(np.int16) audio_bytes = pcm.tobytes() media_type = "audio/pcm" + wav_for_ws = encode_wav(all_samples, sample_rate) # dashboard always gets WAV else: audio_bytes = encode_wav(all_samples, sample_rate) media_type = "audio/wav" + wav_for_ws = audio_bytes t_encode_end = time.perf_counter() total_time = t_encode_end - t_request @@ -414,6 +417,28 @@ def synthesize(req: SynthesizeRequest): # Finalize job record _append_timeline(job_id, "generation_complete", t_gen_end - t_request) _append_timeline(job_id, "encoding_complete", t_encode_end - t_request) + + # Wave 4.3 — route to any dashboard WebSocket subscribers for this + # session before returning the HTTP response. Mod3 emits the WAV over + # the /ws/audio/{session_id} channel; the MCP shim and the kernel both + # consult /v1/sessions/{id}/subscribers to skip local playback when + # this path fired, so there's no double-play. Pure HTTP callers without + # a session (or without a subscriber) still get their bytes in the + # response body exactly as before. 
+ ws_delivered = 0 + if session_id: + subs = get_default_audio_subscribers() + try: + ws_delivered = subs.emit_wav( + session_id, + wav_for_ws, + job_id=job_id, + duration_sec=round(duration, 3), + sample_rate=sample_rate, + ) + except Exception as exc: # noqa: BLE001 — never fail synthesize on a WS push + logger.debug("ws audio emit failed: %s", exc) + _update_job( job_id, { @@ -431,6 +456,7 @@ def synthesize(req: SynthesizeRequest): "per_chunk": chunk_metrics, "output_bytes": len(audio_bytes), "output_format": req.format, + "ws_subscribers_delivered": ws_delivered, }, }, ) @@ -445,6 +471,7 @@ def synthesize(req: SynthesizeRequest): "X-Mod3-Total-Time-Sec": f"{total_time:.3f}", "X-Mod3-RTF": f"{duration / gen_time:.2f}" if gen_time > 0 else "0", "X-Mod3-Chunks": str(len(chunk_metrics)), + "X-Mod3-WS-Subscribers": str(ws_delivered), } if session_payload is not None: headers["X-Mod3-Session-Id"] = session_payload["session_id"] @@ -811,6 +838,26 @@ def session_get(session_id: str): return payload +@app.get("/v1/sessions/{session_id}/subscribers") +def session_subscribers(session_id: str): + """Wave 4.3 — does this session have any live audio WebSocket subscribers? + + The kernel queries this before spawning afplay: if any dashboard or + native client has attached to ``/ws/audio/{session_id}``, the bytes are + routed over the WebSocket and the server-side fallback player is + skipped. Unknown sessions return ``{"subscribed": false, "count": 0}`` + instead of 404 so the kernel's check stays a single well-defined + predicate regardless of registration state. 
+ """ + subs = get_default_audio_subscribers() + count = subs.count(session_id) + return { + "session_id": session_id, + "subscribed": count > 0, + "count": count, + } + + @app.get("/health") def health(): """Health check — standardized CogOS service format.""" @@ -1080,6 +1127,50 @@ async def dashboard_page(): return JSONResponse({"error": "dashboard not found"}, status_code=404) +@app.websocket("/ws/audio/{session_id}") +async def ws_audio(websocket: WebSocket, session_id: str): + """Wave 4.3 — per-session playback channel for the dashboard. + + The dashboard (or any client) opens ``ws://host:7860/ws/audio/{session_id}`` + to receive audio frames that would otherwise play through afplay / + sounddevice. The wire contract per send from the server: + + 1. JSON text frame: ``{"type": "audio_header", "session_id": ..., + "job_id": ..., "duration_sec": ..., "sample_rate": ..., "bytes": N, + "format": "wav", "seq": N}`` + 2. Binary frame: the raw WAV bytes. + + The browser decodes via ``AudioContext.decodeAudioData`` — browsers + accept a whole-WAV in one blob so we don't need chunking for the v1 + implementation. A future iteration can stream PCM frames for lower + latency; the header envelope is the forward-compatibility seam. + + On disconnect the subscriber is removed and the session falls back to + ``afplay`` (kernel) / ``sd.play`` (MCP shim) automatically — the + subscribers registry tracks liveness, so the very next + ``/v1/sessions/{session_id}/subscribers`` probe returns ``subscribed: false``. + + Client → server frames are ignored for v1. A future revision may use + them for barge-in signaling or playback ack, but today the dashboard's + existing ``/ws/chat`` channel carries those events. + """ + await websocket.accept() + subs = get_default_audio_subscribers() + loop = asyncio.get_running_loop() + subscriber = subs.register(session_id, websocket, loop) + try: + # Keep the connection open; drain any client frames so the socket + # close handshake fires promptly. 
+ while True: + msg = await websocket.receive() + if msg.get("type") == "websocket.disconnect": + break + except Exception as exc: # noqa: BLE001 — disconnect is the normal exit + logger.debug("/ws/audio/%s disconnect: %s", session_id, exc) + finally: + subs.unregister(session_id, subscriber) + + @app.websocket("/ws/chat") async def ws_chat(websocket: WebSocket): """Dashboard voice/text chat — one session per connection.""" diff --git a/mcp_shim.py b/mcp_shim.py index ca8a63b..48129fd 100644 --- a/mcp_shim.py +++ b/mcp_shim.py @@ -179,9 +179,67 @@ def _resolve_device_live(preferred: str | None) -> tuple[Any, dict[str, Any]]: return None, {"preferred": pref, "index": None, "fallback": True, "reason": "no match, no default"} +def _session_has_ws_subscriber(session_id: str | None) -> bool: + """Wave 4.3 — ask the HTTP service whether any dashboard has attached + a WebSocket audio subscription for this session. + + When True, the shim skips local sounddevice playback — the HTTP + service's /ws/audio/{session_id} route is delivering the WAV bytes to + the browser and local speakers would double-play. When False (no + subscriber, or the check fails), the caller falls back to sounddevice + exactly as the pre-Wave-4 path did. A fast 1.5s timeout keeps the + check from ever blocking a speak for long if mod3 HTTP is wedged. + """ + if not session_id: + return False + status, resp = _http_request("GET", f"/v1/sessions/{session_id}/subscribers", timeout=1.5) + if status != 200 or not isinstance(resp, dict): + return False + return bool(resp.get("subscribed", False)) + + def _play_wav_bytes(wav_bytes: bytes, job_id: str, session_id: str | None = None): - """Play WAV audio bytes through speakers via sounddevice.""" + """Play WAV audio bytes through speakers via sounddevice. + + Wave 4.3: when ``session_id`` has a live /ws/audio subscriber, skip the + local playback — the HTTP service is routing the bytes to the browser + and running sounddevice here would double-play. 
The subscriber check + happens in this function's first branch, which records + ``status=routed_ws`` in the ``_jobs`` ledger so the caller can + correlate. + """ global _current_sd_stream + + # Wave 4.3: subscriber short-circuit. The server still emits WAV bytes + # out the WebSocket as part of the synthesize response path (see + # ``audio_subscribers.emit_wav``); here we simply skip the local + # sounddevice fallback when a dashboard is attached. + if _session_has_ws_subscriber(session_id): + try: + buf = io.BytesIO(wav_bytes) + with wave.open(buf, "rb") as wf: + sr = wf.getframerate() + duration = wf.getnframes() / float(sr) if sr else 0.0 + except Exception: # noqa: BLE001 — not fatal for the routing path + sr = 0 + duration = 0.0 + with _jobs_lock: + if job_id in _jobs: + _jobs[job_id]["status"] = "routed_ws" + _jobs[job_id]["duration_sec"] = round(duration, 2) + _jobs[job_id]["metrics"] = { + "audio_duration_sec": round(duration, 2), + "sample_rate": sr, + "routing": "dashboard_ws", + } + logger.info( + "playback routed to WS: session=%s job=%s duration=%.2fs", + session_id, + job_id, + duration, + ) + return + try: import numpy as np import sounddevice as sd diff --git a/tests/test_audio_subscribers.py b/tests/test_audio_subscribers.py new file mode 100644 index 0000000..70d6b6a --- /dev/null +++ b/tests/test_audio_subscribers.py @@ -0,0 +1,272 @@ +"""Unit + integration tests for the Wave 4.3 audio-subscriber registry. 
+ +Covers: + * AudioSubscriberRegistry register / unregister / count / has_subscribers + * emit_wav delivers header JSON + binary bytes to every subscriber + * /v1/sessions/{id}/subscribers HTTP endpoint returns the correct shape + * /ws/audio/{session_id} accepts a WebSocket upgrade, registers the + subscriber for the lifetime of the connection, and unregisters on close + * /v1/synthesize with a session_id AND a live subscriber emits the WAV + over the WebSocket (via emit_wav) in addition to returning the HTTP + response body + +Run with: ``.venv/bin/python -m pytest tests/test_audio_subscribers.py -v`` +""" + +from __future__ import annotations + +import asyncio +import os +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from audio_subscribers import ( # noqa: E402 + AudioSubscriberRegistry, + get_default_audio_subscribers, + reset_default_audio_subscribers, +) + +# --------------------------------------------------------------------------- +# Unit tests — AudioSubscriberRegistry +# --------------------------------------------------------------------------- + + +class _FakeWS: + """Minimal stand-in for fastapi.WebSocket — records sent frames.""" + + def __init__(self) -> None: + self.json_sent: list[dict] = [] + self.bytes_sent: list[bytes] = [] + self.closed = False + + async def send_json(self, frame: dict) -> None: + if self.closed: + raise RuntimeError("socket closed") + self.json_sent.append(frame) + + async def send_bytes(self, payload: bytes) -> None: + if self.closed: + raise RuntimeError("socket closed") + self.bytes_sent.append(payload) + + +class TestAudioSubscriberRegistry: + def test_register_and_has_subscriber(self): + reg = AudioSubscriberRegistry() + assert not reg.has_subscribers("s1") + assert reg.count("s1") == 0 + + loop = asyncio.new_event_loop() + ws = _FakeWS() + try: + sub = reg.register("s1", ws, loop) + assert reg.has_subscribers("s1") + assert reg.count("s1") == 1 
+ + reg.unregister("s1", sub) + assert not reg.has_subscribers("s1") + assert reg.count("s1") == 0 + finally: + loop.close() + + def test_multiple_subscribers_per_session(self): + reg = AudioSubscriberRegistry() + loop = asyncio.new_event_loop() + try: + a = reg.register("s1", _FakeWS(), loop) + b = reg.register("s1", _FakeWS(), loop) + assert reg.count("s1") == 2 + reg.unregister("s1", a) + assert reg.count("s1") == 1 + reg.unregister("s1", b) + assert reg.count("s1") == 0 + # Empty bucket is pruned so snapshot stays compact + assert reg.snapshot() == {} + finally: + loop.close() + + def test_unregister_unknown_is_noop(self): + reg = AudioSubscriberRegistry() + loop = asyncio.new_event_loop() + try: + ws = _FakeWS() + sub = reg.register("s1", ws, loop) + reg.unregister("s1", sub) + # Second call on the already-removed sub should be a no-op + reg.unregister("s1", sub) + # Call on a session that never existed + reg.unregister("ghost", sub) + finally: + loop.close() + + def test_emit_wav_delivers_header_and_bytes(self): + reg = AudioSubscriberRegistry() + loop = asyncio.new_event_loop() + + async def run(): + ws = _FakeWS() + sub = reg.register("s1", ws, loop) + try: + delivered = reg.emit_wav( + "s1", + b"fake-wav-bytes", + job_id="job-1", + duration_sec=1.23, + sample_rate=24000, + ) + # emit_wav schedules a coroutine on the loop; await it. 
+ await asyncio.sleep(0.05) + assert delivered == 1 + assert len(ws.json_sent) == 1 + header = ws.json_sent[0] + assert header["type"] == "audio_header" + assert header["session_id"] == "s1" + assert header["job_id"] == "job-1" + assert header["duration_sec"] == 1.23 + assert header["sample_rate"] == 24000 + assert header["bytes"] == len(b"fake-wav-bytes") + assert ws.bytes_sent == [b"fake-wav-bytes"] + finally: + reg.unregister("s1", sub) + + try: + loop.run_until_complete(run()) + finally: + loop.close() + + def test_emit_wav_with_no_subscribers_returns_zero(self): + reg = AudioSubscriberRegistry() + delivered = reg.emit_wav("s1", b"anything") + assert delivered == 0 + + def test_default_registry_is_shared_singleton(self): + reset_default_audio_subscribers() + a = get_default_audio_subscribers() + b = get_default_audio_subscribers() + assert a is b + + +# --------------------------------------------------------------------------- +# HTTP surface tests +# --------------------------------------------------------------------------- + + +class TestSubscribersEndpoint: + @pytest.fixture + def client(self): + from fastapi.testclient import TestClient + + import http_api + + return TestClient(http_api.app) + + @pytest.fixture(autouse=True) + def _isolate_subscribers(self): + reset_default_audio_subscribers() + yield + reset_default_audio_subscribers() + + def test_no_subscribers_returns_false(self, client): + r = client.get("/v1/sessions/unknown-sid/subscribers") + assert r.status_code == 200 + body = r.json() + assert body["session_id"] == "unknown-sid" + assert body["subscribed"] is False + assert body["count"] == 0 + + def test_ws_audio_registers_and_endpoint_reflects_it(self, client): + """Open the WebSocket, check /subscribers, then disconnect.""" + with client.websocket_connect("/ws/audio/ws-test-1"): + r = client.get("/v1/sessions/ws-test-1/subscribers") + assert r.status_code == 200 + body = r.json() + assert body["subscribed"] is True + assert body["count"] == 
1 + # After close, subscriber is deregistered + r = client.get("/v1/sessions/ws-test-1/subscribers") + assert r.status_code == 200 + assert r.json()["subscribed"] is False + + +# --------------------------------------------------------------------------- +# Integration: /v1/synthesize routes to WS subscriber +# --------------------------------------------------------------------------- + + +class TestSynthesizeEmitsOverWS: + """When /v1/synthesize is called with a session_id whose dashboard has a + live /ws/audio subscription, the WAV bytes are pushed over the WebSocket + AND returned in the HTTP response body. Callers that skip local playback + when ``X-Mod3-WS-Subscribers > 0`` avoid a double-play. + """ + + @pytest.fixture + def client(self): + from fastapi.testclient import TestClient + + import http_api + + return TestClient(http_api.app) + + @pytest.fixture(autouse=True) + def _isolate(self): + reset_default_audio_subscribers() + from session_registry import get_default_registry + + reg = get_default_registry() + for s in list(reg.list()): + if s.session_id.startswith("pytest-"): + reg.deregister(s.session_id) + yield + for s in list(reg.list()): + if s.session_id.startswith("pytest-"): + reg.deregister(s.session_id) + reset_default_audio_subscribers() + + @pytest.mark.skipif( + os.environ.get("SKIP_TTS_TESTS") == "1", + reason="loads Kokoro engine — slow; set SKIP_TTS_TESTS=1 to skip in CI", + ) + def test_synthesize_with_subscriber_emits_over_ws(self, client): + # Register a session and open a subscriber. 
+ client.post( + "/v1/sessions/register", + json={ + "session_id": "pytest-ws-1", + "participant_id": "pytest-user", + "participant_type": "user", + }, + ) + with client.websocket_connect("/ws/audio/pytest-ws-1") as ws: + # Synthesize, naming the session + r = client.post( + "/v1/synthesize", + json={ + "text": "hi", + "session_id": "pytest-ws-1", + }, + ) + assert r.status_code == 200, r.text + assert r.headers.get("X-Mod3-WS-Subscribers") == "1" + + # The WebSocket should have received an audio_header + binary pair + header = ws.receive_json() + assert header["type"] == "audio_header" + assert header["session_id"] == "pytest-ws-1" + assert header["format"] == "wav" + audio = ws.receive_bytes() + assert audio.startswith(b"RIFF") and b"WAVE" in audio[:16] + + def test_synthesize_without_session_skips_ws_emit(self, client): + # Even without hitting Kokoro, a /v1/synthesize without session_id + # should report 0 WS subscribers in the response header. + # We don't actually need to wait for synthesis to complete — just + # verify the endpoint path when no subscriber exists. + from audio_subscribers import get_default_audio_subscribers + + subs = get_default_audio_subscribers() + assert not subs.has_subscribers("nonexistent-sid") From 4da25335c42592edccc39a739549cfcea4e1b6bf Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Thu, 23 Apr 2026 15:17:40 -0400 Subject: [PATCH 5/8] fix(dashboard): audio WS buffer must be ArrayBuffer, not Uint8Array MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The playWavBlob ArrayBuffer-extraction ternary had one branch that used Uint8Array.slice() — which returns another Uint8Array, not an ArrayBuffer. decodeAudioData then throws "parameter 1 is not of type 'ArrayBuffer'". The fix: use buffer.buffer.slice(0) in the "view covers whole buffer" branch so both branches emit an ArrayBuffer copy. 
Found via live smoke test: dashboard connected cleanly, WebSocket received WAV frames, but every playback failed silently in the console while the kernel and mod3 both reported routed_ws success. The server side was correct; this was purely a browser-side type error. --- dashboard/index.html | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dashboard/index.html b/dashboard/index.html index c534cc0..e1237f6 100644 --- a/dashboard/index.html +++ b/dashboard/index.html @@ -1045,10 +1045,11 @@

    Mod³

    try { await ctx.resume(); } catch { /* user-gesture required */ } } try { - // decodeAudioData wants an ArrayBuffer (not a shared view). Make a copy - // to be safe across browsers. + // decodeAudioData wants an ArrayBuffer (not a typed-array view). Make + // an ArrayBuffer copy in both branches — Uint8Array.slice returns a + // Uint8Array, not an ArrayBuffer, which fails decodeAudioData. const copy = buffer.byteLength === buffer.buffer.byteLength - ? buffer.slice(0) : buffer.buffer.slice(buffer.byteOffset, buffer.byteOffset + buffer.byteLength); + ? buffer.buffer.slice(0) : buffer.buffer.slice(buffer.byteOffset, buffer.byteOffset + buffer.byteLength); const decoded = await ctx.decodeAudioData(copy); const src = ctx.createBufferSource(); src.buffer = decoded; From 7a1e28ff4cac1a437ff7a47daa57b363401ffcc8 Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Thu, 23 Apr 2026 15:22:33 -0400 Subject: [PATCH 6/8] fix(dashboard): route Wave 4 WebSocket audio to the selected output device MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The output-device change handler calls both _playback.setOutputDevice and window.__mod3AudioSink, keeping the legacy path and the WS path in sync. - window.__mod3AudioCtx is set for diagnostic access (evaluate_script, debugger probes). Found by Chaz during the Wave 4 smoke test: he set the dashboard output to MacBook Pro Speakers, but mod3_speak audio still played through the system-default Dell USB Audio. BrowserOS console confirmed the context had no sinkId set and the UI handler only logged. --- dashboard/index.html | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/dashboard/index.html b/dashboard/index.html index e1237f6..3a0d54a 100644 --- a/dashboard/index.html +++ b/dashboard/index.html @@ -554,10 +554,17 @@

    Mod³

    }); audioOutputSelect.addEventListener('change', async () => { + const deviceId = audioOutputSelect.value; + const label = audioOutputSelect.options[audioOutputSelect.selectedIndex].text; + // Legacy chat-path playback if (_playback) { - await _playback.setOutputDevice(audioOutputSelect.value); - console.log('[Output] Switched to:', audioOutputSelect.options[audioOutputSelect.selectedIndex].text); + await _playback.setOutputDevice(deviceId); } + // Wave 4 WebSocket-path AudioContext (new) + if (window.__mod3AudioSink) { + await window.__mod3AudioSink(deviceId); + } + console.log('[Output] Switched to:', label); }); // --- WebSocket transport reference (shared between voice and text) --- @@ -1029,12 +1036,39 @@

    Mod³

    let reconnectDelay = 1000; let currentSessionId = null; let pendingHeader = null; // last JSON header waiting for its binary pair + let pendingSinkId = null; // selected output device id, applied once audioCtx exists + + async function applySinkId(ctx, sinkId) { + if (!ctx || typeof ctx.setSinkId !== 'function') return false; + // setSinkId accepts "" for default, or a deviceId. Chrome's enumerateDevices + // uses "" for the system default; we pass that through. + const id = (sinkId == null) ? '' : sinkId; + try { + await ctx.setSinkId(id); + console.log('[AudioWS] sink bound to', id || '(default)'); + return true; + } catch (err) { + console.warn('[AudioWS] setSinkId failed:', err); + return false; + } + } + + // Exposed so the outer audio-output