diff --git a/channels.py b/channels.py index 9f6f592..94dfa67 100644 --- a/channels.py +++ b/channels.py @@ -14,6 +14,11 @@ partial_transcript, transcript, trace_event — kernel cycle-trace events (ADR-083), fanned out via BrowserChannel.broadcast_trace_event(). + +The MOD3_USE_COGOS_AGENT kernel-bridged path emits response_text AND +response_complete via BrowserChannel.broadcast_response_{text,complete} +so the dashboard UI's turn-done signal fires on every turn, matching the +local-inference path's behavior. """ from __future__ import annotations @@ -493,6 +498,40 @@ def broadcast_response_text(cls, text: str, session_id: str | None = None) -> No except Exception as exc: # noqa: BLE001 — disconnected clients are expected logger.debug("response_text send failed for %s: %s", ch.channel_id, exc) + @classmethod + def broadcast_response_complete( + cls, + metrics: dict | None = None, + session_id: str | None = None, + ) -> None: + """Push a `response_complete` frame to dashboard WebSocket clients. + + Companion to :meth:`broadcast_response_text`: the MOD3_USE_COGOS_AGENT + response bridge emits exactly one complete-frame per kernel + `agent_response` event so the dashboard UI's per-turn `isResponding` + state gets cleared (otherwise the chat panel spinner hangs forever). + + Routing and threading match `broadcast_response_text` 1:1 — pass the + same `session_id` so the completion frame lands on the same channel + that received the text frames for this turn. `metrics` follows the + local-path convention from `agent_loop._process` (`{"llm_ms": ..., + "provider": ...}`); the kernel path populates it with + `{"provider": "cogos-agent", ...}`. + """ + frame = {"type": "response_complete", "metrics": metrics or {}} + expected_channel = None + if session_id and session_id.startswith("mod3:"): + expected_channel = session_id[len("mod3:") :] + for ch in list(cls._active_channels): + if not ch._active: + continue + if expected_channel and ch.channel_id != expected_channel: + continue + try: + asyncio.run_coroutine_threadsafe(ch.ws.send_json(frame), ch._loop) + except Exception as exc: # noqa: BLE001 — disconnected clients are expected + logger.debug("response_complete send failed for %s: %s", ch.channel_id, exc) + # ------------------------------------------------------------------ # Cleanup # ------------------------------------------------------------------ diff --git a/cogos_agent_bridge.py b/cogos_agent_bridge.py index c8616f2..05ba8ab 100644 --- a/cogos_agent_bridge.py +++ b/cogos_agent_bridge.py @@ -194,9 +194,24 @@ def _extract_response_text(payload: dict) -> Optional[str]: async def run_response_bridge(subscriber: KernelBusSubscriber) -> None: """Consume `subscriber` and broadcast agent replies to dashboard clients. - `BrowserChannel.broadcast_response_text()` is thread-safe via - `run_coroutine_threadsafe`, matching the existing trace-event pattern. - Malformed events (no recoverable text) are logged at debug and skipped. + Each kernel `agent_response` event on `bus_dashboard_response` is a + complete per-turn reply (see `apps/cogos/agent_tools_respond.go` — the + `respond` tool is documented as "call at most once per user turn" and + the auto-fallback publishes once if the model skipped the tool call). + We therefore emit two dashboard frames per kernel event: + + * ``broadcast_response_text`` — the reply body (chat panel render) + * ``broadcast_response_complete`` — the turn-done signal so the UI's + per-turn spinner clears. Without this, the dashboard hangs + awaiting completion because the kernel path never reaches the + ``send_response_complete`` call that the local-inference branch + emits at ``agent_loop._process`` ~L300. + + ``BrowserChannel.broadcast_response_{text,complete}()`` are thread-safe + via ``run_coroutine_threadsafe``, matching the existing trace-event + pattern. Malformed events (no recoverable text) are logged at debug + and skipped — we do NOT emit a completion frame for skipped events + (keeps the 1:1 pairing with what the UI actually rendered). """ first_event_logged = False forwarded = 0 @@ -221,6 +236,16 @@ async def run_response_bridge(subscriber: KernelBusSubscriber) -> None: session_id = _extract_session_id(env.payload) try: BrowserChannel.broadcast_response_text(text, session_id=session_id) + # Pair the text frame with a completion frame so the dashboard's + # per-turn "awaiting response" state clears. Kernel emits exactly + # one agent_response per user turn, so one complete per event is + # the correct cardinality. + metrics: dict = {"provider": "cogos-agent"} + if env.event_id: + metrics["event_id"] = env.event_id + if env.ts: + metrics["kernel_ts"] = env.ts + BrowserChannel.broadcast_response_complete(metrics, session_id=session_id) forwarded += 1 logger.debug( "cogos-agent: forwarded response event_id=%s session=%s (total=%d)", diff --git a/demo/e2e_dashboard_harness.py b/demo/e2e_dashboard_harness.py index dd67df9..213a12c 100644 --- a/demo/e2e_dashboard_harness.py +++ b/demo/e2e_dashboard_harness.py @@ -120,7 +120,7 @@ async def deferred_interrupt() -> None: audio_pcm.append(wav_bytes) except Exception as e: print(f" (audio decode err: {e})", flush=True) - elif t == "response_done" or t == "turn_complete": + elif t == "response_complete": got_done = True done_ts = time.time() elif t == "trace_event": diff --git a/http_api.py b/http_api.py index 6b7260c..3492897 100644 --- a/http_api.py +++ b/http_api.py @@ -39,6 +39,10 @@ from modality import EncodedOutput, ModalityType from modules.text import TextModule from modules.voice import VoiceModule +from session_registry import ( + get_default_registry, + resolve_output_device, +) from vad import detect_speech_file, is_hallucination from vad import is_model_loaded as vad_loaded @@ -244,6 +248,11 @@ class SynthesizeRequest(BaseModel): speed: float = Field(default=1.25) emotion: float = Field(default=0.5) format: str = Field(default="wav", pattern="^(wav|pcm)$") + # ADR-082 Phase 1: optional session routing. When present, the + # session's assigned_voice overrides ``voice`` (unless an explicit + # non-default was passed), and the session is advanced in the global + # serializer's round-robin. + session_id: str | None = Field(default=None) class SpeechRequest(BaseModel): @@ -254,6 +263,9 @@ class SpeechRequest(BaseModel): voice: str = Field(default="af_heart") response_format: str = Field(default="mp3") speed: float = Field(default=1.0) + # ADR-082 Phase 1 extension — not part of the OpenAI schema but harmless + # to accept. When absent, behavior is identical to before Phase 1. + session_id: str | None = Field(default=None) class ShutdownRequest(BaseModel): @@ -263,6 +275,17 @@ class ShutdownRequest(BaseModel): reason: str = Field(default="shutdown-requested") +class SessionRegisterRequest(BaseModel): + """Register a session with the Mod3 communication bus (ADR-082).""" + + session_id: str + participant_id: str + participant_type: str = Field(default="agent") + preferred_voice: str | None = Field(default=None) + preferred_output_device: str = Field(default="system-default") + priority: int = Field(default=0) + + # --------------------------------------------------------------------------- # Shutdown middleware — reject new requests once shutdown is initiated # --------------------------------------------------------------------------- @@ -290,6 +313,39 @@ def synthesize(req: SynthesizeRequest): import numpy as np t_request = time.perf_counter() + + # ADR-082 Phase 1: session routing. If the request names a session, we + # honor the session's assigned voice (unless the caller explicitly + # picked a non-default voice) and account the job against the session's + # queue + serializer so multi-session callers can see round-robin. + session_id = req.session_id + session_payload: dict | None = None + if session_id: + registry = get_default_registry() + session = registry.get(session_id) + if session is None: + return JSONResponse( + status_code=404, + content={ + "error": f"session '{session_id}' is not registered — POST /v1/sessions/register first", + }, + ) + if req.voice == "bm_lewis" and session.assigned_voice != "bm_lewis": + req.voice = session.assigned_voice + # Register the submission with the serializer for accounting only. + # The synthesize endpoint is non-blocking on the audio side (we + # return bytes synchronously), so we do not run the registry's + # dispatcher here — we just record the submission. + try: + registry.submit(session_id, {"type": "synthesize", "text": req.text[:200]}) + except Exception as exc: # noqa: BLE001 + logger.debug("session submit accounting failed: %s", exc) + session_payload = { + "session_id": session.session_id, + "assigned_voice": session.assigned_voice, + "preferred_output_device": session.preferred_output_device, + } + job_id = _record_job( { "type": "synthesize", @@ -301,6 +357,7 @@ def synthesize(req: SynthesizeRequest): "emotion": req.emotion, "format": req.format, "engine": None, + "session_id": session_id, "timeline": [{"event": "request_received", "t": 0.0}], } ) @@ -389,6 +446,8 @@ def synthesize(req: SynthesizeRequest): "X-Mod3-RTF": f"{duration / gen_time:.2f}" if gen_time > 0 else "0", "X-Mod3-Chunks": str(len(chunk_metrics)), } + if session_payload is not None: + headers["X-Mod3-Session-Id"] = session_payload["session_id"] return Response(content=audio_bytes, media_type=media_type, headers=headers) @@ -400,6 +459,30 @@ def audio_speech(req: SpeechRequest): t_request = time.perf_counter() + # ADR-082 Phase 1: optional session routing. Same semantics as + # /v1/synthesize — the session's assigned voice overrides ``voice`` when + # the caller passed the default, and the submission is accounted against + # the session's queue. + session_id = req.session_id + if session_id: + registry = get_default_registry() + session = registry.get(session_id) + if session is None: + return JSONResponse( + status_code=404, + content={ + "error": f"session '{session_id}' is not registered — POST /v1/sessions/register first", + }, + ) + # OpenAI default is af_heart; if the caller left it at the default, + # prefer the session's voice. + if req.voice == "af_heart" and session.assigned_voice != "af_heart": + req.voice = session.assigned_voice + try: + registry.submit(session_id, {"type": "audio_speech", "text": req.input[:200]}) + except Exception as exc: # noqa: BLE001 + logger.debug("session submit accounting failed: %s", exc) + voice = req.voice try: voice = _resolve_voice_via_bus(voice) @@ -414,6 +497,7 @@ def audio_speech(req: SpeechRequest): "text": req.input[:200], "voice": voice, "speed": req.speed, + "session_id": session_id, "timeline": [{"event": "request_received", "t": 0.0}], } ) @@ -464,6 +548,8 @@ def audio_speech(req: SpeechRequest): "X-Mod3-Gen-Time-Sec": f"{gen_time:.3f}", "X-Mod3-Total-Time-Sec": f"{total_time:.3f}", } + if session_id: + headers["X-Mod3-Session-Id"] = session_id return Response(content=audio_bytes, media_type="audio/wav", headers=headers) @@ -650,6 +736,81 @@ def stop_speech(job_id: str = ""): return JSONResponse(status_code=503, content={"error": "Speech queue not available in HTTP-only mode"}) +# --------------------------------------------------------------------------- +# Sessions — ADR-082 Phase 1 +# --------------------------------------------------------------------------- + + +@app.post("/v1/sessions/register") +def session_register(req: SessionRegisterRequest): + """Register a session with the Mod3 communication bus (ADR-082). + + Body: + { + "session_id": "...", + "participant_id": "cog" | "sandy" | "slowbro" | ..., + "participant_type": "agent" | "user", + "preferred_voice": "bm_lewis" | ... | null, + "preferred_output_device": "system-default" | "" + } + + Returns the SessionChannel with a live-resolved output_device. + """ + registry = get_default_registry() + try: + result = registry.register( + session_id=req.session_id, + participant_id=req.participant_id, + participant_type=req.participant_type, + preferred_voice=req.preferred_voice, + preferred_output_device=req.preferred_output_device or "system-default", + priority=req.priority, + ) + except Exception as exc: # noqa: BLE001 — surface the error verbatim + return JSONResponse(status_code=400, content={"error": str(exc)}) + + payload = result.session.to_dict(device_resolver=resolve_output_device) + payload["created"] = result.created + # Top-level live device snapshot so the caller does not have to + # round-trip; the nested one is available for debugging. + payload["output_device"] = registry.resolve_device(result.session.session_id).to_dict() + return payload + + +@app.post("/v1/sessions/{session_id}/deregister") +def session_deregister(session_id: str): + """Drop a session — drains/cancels pending jobs, returns voice to pool.""" + registry = get_default_registry() + result = registry.deregister(session_id) + if result.get("status") == "not_found": + return JSONResponse(status_code=404, content=result) + return result + + +@app.get("/v1/sessions") +def session_list(): + """List all registered sessions plus a serializer snapshot.""" + registry = get_default_registry() + return { + "sessions": registry.list_serialized(), + "serializer": registry.serializer.snapshot(), + "voice_pool": registry.voice_pool(), + "voice_holders": registry.voice_holder_snapshot(), + } + + +@app.get("/v1/sessions/{session_id}") +def session_get(session_id: str): + """Get a single session's current state (with live device resolution).""" + registry = get_default_registry() + session = registry.get(session_id) + if session is None: + return JSONResponse(status_code=404, content={"error": f"session '{session_id}' not found"}) + payload = session.to_dict(device_resolver=resolve_output_device) + payload["output_device"] = registry.resolve_device(session_id).to_dict() + return payload + + @app.get("/health") def health(): """Health check — standardized CogOS service format.""" diff --git a/mcp_shim.py b/mcp_shim.py index 3f6f87b..ca8a63b 100644 --- a/mcp_shim.py +++ b/mcp_shim.py @@ -42,6 +42,14 @@ _current_sd_stream = None _playback_interrupt = threading.Event() +# ADR-082 Phase 1: local session state. Populated by tool_register_session +# so tool_speak can live-resolve the session's preferred output device +# before each playback. The HTTP service owns the canonical registry; this +# is a thin cache so the shim does not have to re-query per play. +_shim_sessions: dict[str, dict[str, Any]] = {} +_shim_sessions_lock = threading.Lock() +_active_session_id: str | None = None + # Job tracking (lightweight — just for speak/stop/status) _jobs: OrderedDict = OrderedDict() _jobs_lock = threading.Lock() @@ -86,7 +94,92 @@ def _http_request(method: str, path: str, body: dict | None = None, timeout: flo return 0, {"error": f"Request failed: {e}"} -def _play_wav_bytes(wav_bytes: bytes, job_id: str): +def _resolve_device_live(preferred: str | None) -> tuple[Any, dict[str, Any]]: + """Resolve an output device live, per the ADR-082 2026-04-22 amendment. + + ``preferred`` mirrors the SessionChannel field: "system-default" re-queries + the OS default immediately; a named device is resolved by substring + against the current device list; a numeric string picks by index; and + ``None`` falls back to the legacy ``_output_device`` module global. + + Returns ``(device_arg, diag)`` where ``device_arg`` is ready to pass to + ``sd.play(device=...)`` and ``diag`` is a dict describing how we resolved + for logging / responses. + """ + try: + import sounddevice as sd + except ImportError: + return None, {"preferred": preferred, "index": None, "reason": "sounddevice not installed"} + + if preferred is None: + return _output_device, { + "preferred": None, + "index": _output_device if isinstance(_output_device, int) else None, + "reason": "legacy module default", + } + + pref = preferred.strip() if isinstance(preferred, str) else "system-default" + if not pref or pref.lower() in ("system-default", "default"): + # Live re-query. sd.default.device is (input, output) — we want output. + try: + devices = sd.query_devices() + default_tuple = sd.default.device + idx = default_tuple[1] if isinstance(default_tuple, (tuple, list)) else None + if isinstance(idx, int) and 0 <= idx < len(devices): + return idx, { + "preferred": pref, + "index": idx, + "name": devices[idx].get("name", ""), + "reason": "OS default (live-queried)", + } + except Exception as exc: # noqa: BLE001 + return None, {"preferred": pref, "index": None, "reason": f"default query failed: {exc}"} + return None, {"preferred": pref, "index": None, "reason": "OS default unknown"} + + # Named / indexed device + try: + devices = sd.query_devices() + except Exception as exc: # noqa: BLE001 + return None, {"preferred": pref, "index": None, "reason": f"query failed: {exc}"} + + if pref.isdigit(): + i = int(pref) + if 0 <= i < len(devices) and devices[i].get("max_output_channels", 0) > 0: + return i, { + "preferred": pref, + "index": i, + "name": devices[i].get("name", ""), + "reason": "index match", + } + + low = pref.lower() + for i, d in enumerate(devices): + if d.get("max_output_channels", 0) > 0 and low in str(d.get("name", "")).lower(): + return i, { + "preferred": pref, + "index": i, + "name": d.get("name", ""), + "reason": "name match", + } + + # Fall back to system default — identity just changed devices. + try: + default_tuple = sd.default.device + idx = default_tuple[1] if isinstance(default_tuple, (tuple, list)) else None + if isinstance(idx, int) and 0 <= idx < len(devices): + return idx, { + "preferred": pref, + "index": idx, + "name": devices[idx].get("name", ""), + "fallback": True, + "reason": f"named device '{pref}' unavailable — fell back to system default", + } + except Exception: + pass + return None, {"preferred": pref, "index": None, "fallback": True, "reason": "no match, no default"} + + +def _play_wav_bytes(wav_bytes: bytes, job_id: str, session_id: str | None = None): """Play WAV audio bytes through speakers via sounddevice.""" global _current_sd_stream try: @@ -126,7 +219,30 @@ def _play_wav_bytes(wav_bytes: bytes, job_id: str): _jobs[job_id]["duration_sec"] = round(duration, 2) _playback_interrupt.clear() - device = _output_device + + # Live device resolution per playback. If the session pins a device, + # honor it; if it's system-default, re-read OS default now. This is + # the core of the ADR-082 2026-04-22 amendment. + preferred: str | None = None + if session_id: + with _shim_sessions_lock: + cfg = _shim_sessions.get(session_id) + if cfg is not None: + preferred = cfg.get("preferred_output_device", "system-default") + if preferred is None and _active_session_id: + with _shim_sessions_lock: + cfg = _shim_sessions.get(_active_session_id) + if cfg is not None: + preferred = cfg.get("preferred_output_device", "system-default") + + if preferred is not None: + device, diag = _resolve_device_live(preferred) + with _jobs_lock: + if job_id in _jobs: + _jobs[job_id]["device_resolution"] = diag + else: + device = _output_device + with _current_player_lock: _current_sd_stream = job_id @@ -165,9 +281,21 @@ def _estimate_duration(text: str, speed: float) -> float: def tool_speak( - text: str, voice: str = "bm_lewis", stream: bool = True, speed: float = 1.25, emotion: float = 0.5 + text: str, + voice: str = "bm_lewis", + stream: bool = True, + speed: float = 1.25, + emotion: float = 0.5, + session_id: str = "", ) -> str: - """Synthesize via HTTP, play locally.""" + """Synthesize via HTTP, play locally. + + When ``session_id`` is provided, the HTTP call includes it so the server + routes through ADR-082 session-aware playback (assigned voice, + preferred output device). The shim resolves its own preferred output + device live from its cached session info — each playback picks up the + current OS default. + """ if not text.strip(): return json.dumps({"status": "error", "error": "Nothing to say"}) @@ -189,16 +317,20 @@ def tool_speak( pass # Request synthesis from HTTP service + synth_body: dict[str, Any] = { + "text": text, + "voice": voice, + "speed": speed, + "emotion": emotion, + "format": "wav", + } + if session_id: + synth_body["session_id"] = session_id + status, resp = _http_request( "POST", "/v1/synthesize", - { - "text": text, - "voice": voice, - "speed": speed, - "emotion": emotion, - "format": "wav", - }, + synth_body, timeout=60.0, ) @@ -218,14 +350,19 @@ def tool_speak( "text": text[:100], "voice": voice, "created": time.time(), + "session_id": session_id or None, } while len(_jobs) > _MAX_JOBS: _jobs.popitem(last=False) - t = threading.Thread(target=_play_wav_bytes, args=(resp, job_id), daemon=True) + t = threading.Thread( + target=_play_wav_bytes, + args=(resp, job_id, session_id or None), + daemon=True, + ) t.start() - return json.dumps({"status": "speaking", "job_id": job_id}) + return json.dumps({"status": "speaking", "job_id": job_id, "session_id": session_id or None}) def tool_stop(job_id: str = "") -> str: @@ -415,6 +552,78 @@ def tool_await_voice_input(timeout_sec: float = 180.0) -> str: return json.dumps({"status": "error", "error": "Could not retrieve transcript"}) +def tool_register_session( + session_id: str, + participant_id: str, + participant_type: str = "agent", + preferred_voice: str = "", + preferred_output_device: str = "system-default", +) -> str: + """Register a session with the Mod3 bus (ADR-082 Phase 1). + + Forwards to POST /v1/sessions/register on the HTTP service, then caches + the result locally so future tool_speak() calls can live-resolve the + session's preferred output device before each playback. + """ + global _active_session_id + + body: dict[str, Any] = { + "session_id": session_id, + "participant_id": participant_id, + "participant_type": participant_type, + "preferred_output_device": preferred_output_device or "system-default", + } + if preferred_voice: + body["preferred_voice"] = preferred_voice + + status, resp = _http_request("POST", "/v1/sessions/register", body, timeout=10.0) + if status != 200: + err = resp.get("error", f"HTTP {status}") if isinstance(resp, dict) else f"HTTP {status}" + return json.dumps({"status": "error", "error": err}) + + # Cache locally — the playback path reads preferred_output_device from + # here each play. + if isinstance(resp, dict): + with _shim_sessions_lock: + _shim_sessions[session_id] = { + "participant_id": participant_id, + "participant_type": participant_type, + "assigned_voice": resp.get("assigned_voice"), + "preferred_output_device": resp.get("preferred_output_device", "system-default"), + } + _active_session_id = session_id + resp["status"] = "ok" + return json.dumps(resp) + return json.dumps({"status": "error", "error": "unexpected response shape"}) + + +def tool_deregister_session(session_id: str) -> str: + """Release a session's voice and drop pending jobs (ADR-082 Phase 1).""" + global _active_session_id + status, resp = _http_request("POST", f"/v1/sessions/{session_id}/deregister", {}, timeout=5.0) + with _shim_sessions_lock: + _shim_sessions.pop(session_id, None) + if _active_session_id == session_id: + _active_session_id = None + if status == 200 and isinstance(resp, dict): + return json.dumps(resp) + if status == 404: + return json.dumps({"status": "not_found", "session_id": session_id}) + err = resp.get("error", f"HTTP {status}") if isinstance(resp, dict) else f"HTTP {status}" + return json.dumps({"status": "error", "error": err}) + + +def tool_list_sessions() -> str: + """List all registered sessions (ADR-082 Phase 1).""" + status, resp = _http_request("GET", "/v1/sessions", timeout=5.0) + if status != 200: + return json.dumps({"status": "error", "error": f"HTTP {status}"}) + if isinstance(resp, dict): + resp["status"] = "ok" + return json.dumps(resp) + return json.dumps({"status": "error", "error": "unexpected response shape"}) + + def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: """VAD check via HTTP.""" if not os.path.exists(file_path): @@ -503,6 +712,15 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "default": 0.5, "description": "Emotion/exaggeration intensity 0.0-1.0 (Chatterbox only). Default 0.5.", }, + "session_id": { + "type": "string", + "default": "", + "description": ( + "Optional ADR-082 session id. When set and the session is registered, " + "playback uses the session's assigned voice + preferred_output_device " + "(live-resolved per playback). When empty, behaves as before." + ), + }, }, "required": ["text"], }, @@ -629,6 +847,73 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "required": ["file_path"], }, }, + { + "name": "register_session", + "description": ( + "Register a session with the Mod3 communication bus (ADR-082 Phase 1).\n\n" + "Each registered session gets its own output queue, an assigned voice from\n" + "the ranked pool, and a preferred output device that is re-queried live per\n" + "playback when set to 'system-default'. Multiple sessions share one physical\n" + "speaker via a global round-robin serializer.\n\n" + "Args:\n" + " session_id: Caller-chosen id (e.g., the Claude Code session id).\n" + " participant_id: Identity of the speaker (e.g., 'cog', 'sandy', 'slowbro').\n" + " participant_type: 'agent' or 'user'. Free-form beyond that.\n" + " preferred_voice: Optional voice preset (e.g., 'bm_lewis'). If taken,\n" + " voice_conflict=true is returned but assignment still succeeds.\n" + " preferred_output_device: 'system-default' (re-queried per playback), a\n" + " device-name substring, or a numeric index." + ), + "inputSchema": { + "type": "object", + "properties": { + "session_id": {"type": "string", "description": "Caller-chosen session id."}, + "participant_id": { + "type": "string", + "description": "Identity of the speaker (e.g., 'cog', 'sandy', 'slowbro').", + }, + "participant_type": { + "type": "string", + "default": "agent", + "description": "'agent' or 'user'. Free-form beyond that.", + }, + "preferred_voice": { + "type": "string", + "default": "", + "description": "Optional voice preset. If taken, voice_conflict is flagged.", + }, + "preferred_output_device": { + "type": "string", + "default": "system-default", + "description": "'system-default', device-name substring, or numeric index.", + }, + }, + "required": ["session_id", "participant_id"], + }, + }, + { + "name": "deregister_session", + "description": ( + "Release a session's voice and drop its pending jobs (ADR-082 Phase 1).\n\n" + "Call at session end so the voice can be allocated to a new session." + ), + "inputSchema": { + "type": "object", + "properties": { + "session_id": {"type": "string", "description": "Session id to deregister."}, + }, + "required": ["session_id"], + }, + }, + { + "name": "list_sessions", + "description": ( + "List all registered sessions with live device resolution (ADR-082 Phase 1).\n\n" + "Returns each session's assigned voice, preferred output device, queue depth,\n" + "and the serializer's current state (policy + round-robin cursor)." + ), + "inputSchema": {"type": "object", "properties": {}}, + }, ] TOOL_DISPATCH = { @@ -638,6 +923,7 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: stream=args.get("stream", True), speed=args.get("speed", 1.25), emotion=args.get("emotion", 0.5), + session_id=args.get("session_id", ""), ), "speech_status": lambda args: tool_speech_status( job_id=args.get("job_id", ""), @@ -656,6 +942,17 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: file_path=args["file_path"], threshold=args.get("threshold", 0.5), ), + "register_session": lambda args: tool_register_session( + session_id=args["session_id"], + participant_id=args["participant_id"], + participant_type=args.get("participant_type", "agent"), + preferred_voice=args.get("preferred_voice", ""), + preferred_output_device=args.get("preferred_output_device", "system-default"), + ), + "deregister_session": lambda args: tool_deregister_session( + session_id=args["session_id"], + ), + "list_sessions": lambda args: tool_list_sessions(), } diff --git a/server.py b/server.py index 8909f02..e520096 100644 --- a/server.py +++ b/server.py @@ -40,6 +40,11 @@ from modality import ModalityType, ModuleStatus from modules.voice import PlaceholderDecoder, VoiceModule from pipeline_state import InterruptInfo, PipelineState +from session_registry import ( + ResolvedOutputDevice, + get_default_registry, + resolve_output_device, +) logger = logging.getLogger("mod3.server") @@ -748,6 +753,28 @@ def _estimate_duration_sec(text: str, speed: float) -> float: return (words / 150.0) * 60.0 / speed +def _resolve_device_for_entry(entry: dict) -> tuple[int | str | None, ResolvedOutputDevice | None]: + """Resolve the output device for a speech job, live. + + Priority (per the ADR-082 2026-04-22 amendment): + 1. If the job's session has a preferred_output_device, re-query live — + "system-default" always reads the current OS default, and named + devices are enumerated per dispatch. + 2. Otherwise fall back to the legacy ``_output_device`` module global + set by set_output_device() so existing callers keep working. + """ + session_id = entry.get("session_id") + if session_id: + try: + registry = get_default_registry() + resolved = registry.resolve_device(session_id) + entry["resolved_device"] = resolved + return resolved.index, resolved + except Exception as exc: # noqa: BLE001 — never fail synthesis on resolution + logger.warning("device resolution failed for session %s: %s", session_id, exc) + return _output_device, None + + def _run_speech_job(entry: dict) -> None: """Execute a single speech job (blocking). Called from the drain thread.""" global _last_metrics, _current_player @@ -765,7 +792,8 @@ def _run_speech_job(entry: dict) -> None: AdaptivePlayer = _adaptive_player_class() engine, resolved_voice = _resolve_voice_via_bus(voice) model = engine_module.get_model(engine) - player = AdaptivePlayer(sample_rate=model.sample_rate, device=_output_device) + device, _resolved = _resolve_device_for_entry(entry) + player = AdaptivePlayer(sample_rate=model.sample_rate, device=device) except Exception as e: _jobs[job_id]["status"] = "error" _jobs[job_id]["error"] = str(e) @@ -865,10 +893,17 @@ def _start_speech( streaming_interval: float = 1.0, speed: float = 1.0, emotion: float = 0.5, + session_id: str | None = None, ) -> tuple[str, int]: """Submit speech to the queue. Returns (job_id, queue_position). queue_position is 0 if playing immediately, >0 if queued behind others. + + When ``session_id`` is provided, the job is tagged with it so the drain + thread can live-resolve the session's preferred output device before + playback. Voice selection still uses the explicit ``voice`` argument — + callers should pass the session's assigned_voice when registering a job + against a session. """ job_id = uuid.uuid4().hex[:8] _jobs[job_id] = { @@ -884,20 +919,21 @@ def _start_speech( "player": None, "speed": speed, "estimated_duration_sec": round(_estimate_duration_sec(text, speed), 1), + "session_id": session_id, } _prune_jobs() - position = _speech_queue.enqueue( - job_id, - { - "text": text, - "voice": voice, - "stream": stream, - "streaming_interval": streaming_interval, - "speed": speed, - "emotion": emotion, - }, - ) + entry = { + "text": text, + "voice": voice, + "stream": stream, + "streaming_interval": streaming_interval, + "speed": speed, + "emotion": emotion, + } + if session_id: + entry["session_id"] = session_id + position = _speech_queue.enqueue(job_id, entry) return job_id, position @@ -947,6 +983,7 @@ def speak( stream: bool = True, speed: float = 1.25, emotion: float = 0.5, + session_id: str = "", ) -> str: """Synthesize text to speech and play it through the user's speakers. @@ -966,10 +1003,39 @@ def speak( If False, generates all audio first then plays (better prosody). speed: Speed multiplier (engines with speed support). Default 1.25. emotion: Emotion/exaggeration intensity 0.0-1.0 (Chatterbox only). Default 0.5. + session_id: Optional ADR-082 session id. When provided and the session + is registered (see register_session), the job is routed + through the per-session queue and the session's assigned + voice + preferred_output_device are used. When empty, + falls back to today's global-queue behavior for backward + compatibility. """ if not text.strip(): return json.dumps({"status": "error", "error": "Nothing to say"}) + # Route through the session registry when session_id is provided. + # If the session is registered, its assigned_voice overrides the ``voice`` + # argument unless the caller explicitly passed a non-default voice — the + # ADR treats voice as a session identity attribute, not a per-call knob. + effective_session_id: str | None = session_id or None + if effective_session_id: + registry = get_default_registry() + session = registry.get(effective_session_id) + if session is None: + return json.dumps( + { + "status": "error", + "error": f"session '{effective_session_id}' is not registered — call register_session first", + } + ) + # If caller did not pass an explicit non-default voice, use the + # session's assigned voice. "bm_lewis" is the old default so we can't + # distinguish "explicit bm_lewis" from "unspecified"; tolerate that + # and only override when the caller asks for the default. + if voice == "bm_lewis" and session.assigned_voice != "bm_lewis": + voice = session.assigned_voice + session.state = "speaking" + # Check if user is currently speaking (barge-in signal file) user_state = "idle" try: @@ -999,7 +1065,14 @@ def speak( ) try: - job_id, position = _start_speech(text, voice, stream=stream, speed=speed, emotion=emotion) + job_id, position = _start_speech( + text, + voice, + stream=stream, + speed=speed, + emotion=emotion, + session_id=effective_session_id, + ) except ValueError as e: return json.dumps({"status": "error", "error": str(e)}) except Exception as e: @@ -1480,6 +1553,95 @@ def set_output_device(device: str = "") -> str: return json.dumps({"status": "ok", "device": _output_device}) +# --------------------------------------------------------------------------- +# Session registry (ADR-082 Phase 1) +# --------------------------------------------------------------------------- + + +@mcp.tool( + annotations={ + "readOnlyHint": False, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + } +) +def register_session( + session_id: str, + participant_id: str, + participant_type: str = "agent", + preferred_voice: str = "", + preferred_output_device: str = "system-default", +) -> str: + """Register a session with the Mod3 communication bus (ADR-082). + + Each registered session gets its own output queue, an assigned voice + from the ranked pool, and a preferred output device. The global + serializer interleaves speech across sessions (round-robin by default) + so two concurrent agents do not collide on the shared speaker. + + Args: + session_id: Caller-chosen id (e.g., the Claude Code session id). + participant_id: Identity of the speaker (e.g., 'cog', 'sandy', 'slowbro'). + participant_type: 'agent' or 'user'. Free-form beyond that. + preferred_voice: Optional voice preset. If taken, voice_conflict=true + is returned but assignment still succeeds. + preferred_output_device: 'system-default' (re-queried per playback), + a device-name substring, or a numeric index. + """ + registry = get_default_registry() + result = registry.register( + session_id=session_id, + participant_id=participant_id, + participant_type=participant_type, + preferred_voice=preferred_voice or None, + preferred_output_device=preferred_output_device or "system-default", + ) + payload = result.session.to_dict(device_resolver=resolve_output_device) + payload["status"] = "ok" + payload["created"] = result.created + # Also expose a live-resolved device at the top level for convenience — + # callers can log or display it without walking nested keys. + live = registry.resolve_device(result.session.session_id) + payload["output_device"] = live.to_dict() + return json.dumps(payload) + + +@mcp.tool( + annotations={ + "readOnlyHint": False, + "destructiveHint": True, + "idempotentHint": True, + "openWorldHint": False, + } +) +def deregister_session(session_id: str) -> str: + """Release a session's voice and drop its pending jobs.""" + registry = get_default_registry() + result = registry.deregister(session_id) + return json.dumps(result) + + +@mcp.tool( + annotations={ + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + } +) +def list_sessions() -> str: + """List all registered sessions with live device resolution.""" + registry = get_default_registry() + return json.dumps( + { + "status": "ok", + "sessions": registry.list_serialized(), + "serializer": registry.serializer.snapshot(), + } + ) + + # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- diff --git a/session_registry.py b/session_registry.py new file mode 100644 index 0000000..1aad8dc --- /dev/null +++ b/session_registry.py @@ -0,0 +1,895 @@ +"""Session-aware communication bus registry (ADR-082 Phase 1). + +Evolves Mod3 from a stateless TTS engine toward a session-aware communication +bus. Each registered session owns a SessionChannel with an assigned voice, a +preferred output device (live-queried by default), and its own output queue. +A global round-robin serializer picks the next job across sessions with +pending work so multi-agent sessions can share one physical speaker without +collisions. + +Scope for Phase 1 (per the ADR's "Migration Path" section): + + - register_session / deregister_session + list_sessions + - Per-session output queues with a global serializer (round-robin default, + priority / fifo-global policies pluggable) + - Voice assignment from the ranked Kokoro pool + - preferred_output_device field on SessionChannel with live OS-default + re-query per playback (2026-04-22 amendment) + +Out of scope for Phase 1: input routing, barge-in state machine, native input +provider. See later phases of the ADR. + +Backward compatibility: an implicit "default" session is created on first +use so existing callers that do not supply session_id keep working exactly +as before. +""" + +from __future__ import annotations + +import atexit +import heapq +import logging +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Iterable + +logger = logging.getLogger("mod3.session_registry") + + +# --------------------------------------------------------------------------- +# Voice pool — ranked allocation per ADR-082 § Voice Assignment +# --------------------------------------------------------------------------- + +# The canonical Kokoro voice pool, ranked. Greedy allocation walks this list. +# bm_lewis is first because it is the legacy default; remaining Kokoro voices +# follow in the order the ADR specifies (heart / emma / adam / bella / isabella), +# then the unlisted Kokoro voices round out the pool so we do not run dry. +VOICE_POOL: tuple[str, ...] = ( + "bm_lewis", + "af_heart", + "bf_emma", + "am_adam", + "af_bella", + "bf_isabella", + "bm_george", + "am_michael", + "af_sarah", + "af_nicole", + "af_sky", +) + +DEFAULT_SESSION_ID = "default" +DEFAULT_PARTICIPANT_ID = "legacy" +DEFAULT_PARTICIPANT_TYPE = "agent" + +SERIALIZATION_POLICIES = ("round-robin", "priority", "fifo-global") + + +# --------------------------------------------------------------------------- +# Output-device resolution (ADR-082 amendment — 2026-04-22) +# --------------------------------------------------------------------------- + + +@dataclass +class ResolvedOutputDevice: + """Resolution result for an output-device lookup. + + ``preferred`` is the string the session asked for ("system-default" or a + named device). ``index`` is the sounddevice index that was chosen, or + ``None`` to let PortAudio pick — callers should treat ``None`` as + "PortAudio fallback" rather than relying on it implicitly resolving to + the current system default (PortAudio behavior is inconsistent across + versions, per the ADR amendment). + + ``name`` mirrors the resolved device's name when known. ``fallback`` is + True when the requested name was unavailable and we backed off to the + system default. + """ + + preferred: str + index: int | None + name: str + fallback: bool = False + reason: str = "" + + def to_dict(self) -> dict[str, Any]: + return { + "preferred": self.preferred, + "index": self.index, + "name": self.name, + "fallback": self.fallback, + "reason": self.reason, + } + + +def resolve_output_device( + preferred: str | None = "system-default", + *, + query_devices: Callable[[], list[dict[str, Any]]] | None = None, + default_output_index: Callable[[], int | None] | None = None, +) -> ResolvedOutputDevice: + """Resolve ``preferred`` to a concrete sounddevice index, live. + + For ``"system-default"`` (or None/empty) the OS default is re-queried + every call — no caching. When pinned to a named device we enumerate the + current device list and pick by substring match; if nothing matches we + fall back to system-default and flag ``fallback=True`` so callers can + emit a device_fallback warning event. + + The ``query_devices`` and ``default_output_index`` callables are test + seams; they default to sounddevice when omitted. This keeps the live + requery contract explicit — every call goes through the callable, no + module-level state caches the previous result. + """ + + if query_devices is None or default_output_index is None: + + def _query() -> list[dict[str, Any]]: + import sounddevice as sd + + return list(sd.query_devices()) + + def _default() -> int | None: + import sounddevice as sd + + # sd.default.device is (input_index, output_index). We only care + # about the output side. Re-read every call so we pick up whatever + # the OS chose just now. + value = sd.default.device + if isinstance(value, (tuple, list)) and len(value) >= 2: + idx = value[1] + return int(idx) if isinstance(idx, int) and idx >= 0 else None + return None + + query_devices = query_devices or _query + default_output_index = default_output_index or _default + + pref = (preferred or "system-default").strip() + if not pref: + pref = "system-default" + + try: + devices = query_devices() + except Exception as exc: # noqa: BLE001 — device enumeration can fail noisily + logger.warning("resolve_output_device: query_devices failed: %s", exc) + return ResolvedOutputDevice( + preferred=pref, + index=None, + name="(unresolved)", + fallback=True, + reason=f"query_devices failed: {exc}", + ) + + def _default_resolution() -> ResolvedOutputDevice: + try: + idx = default_output_index() + except Exception as exc: # noqa: BLE001 + logger.warning("resolve_output_device: default_output_index failed: %s", exc) + return ResolvedOutputDevice( + preferred=pref, + index=None, + name="(system default, unresolved)", + fallback=True, + reason=f"default_output_index failed: {exc}", + ) + + if idx is None or idx < 0 or idx >= len(devices): + # Falling back to PortAudio implicit — we would rather be explicit + # but the platform did not give us a usable default. + return ResolvedOutputDevice( + preferred=pref, + index=None, + name="(system default)", + fallback=False, + reason="OS default unknown — PortAudio implicit", + ) + + name = devices[idx].get("name", "(unknown)") + return ResolvedOutputDevice( + preferred=pref, + index=idx, + name=name, + fallback=False, + reason="OS default", + ) + + if pref.lower() in ("system-default", "default"): + return _default_resolution() + + # Named device — try numeric index first, then substring on name. + if pref.isdigit(): + idx = int(pref) + if 0 <= idx < len(devices) and devices[idx].get("max_output_channels", 0) > 0: + return ResolvedOutputDevice( + preferred=pref, + index=idx, + name=devices[idx].get("name", "(unknown)"), + fallback=False, + reason="index match", + ) + + pref_lower = pref.lower() + for i, d in enumerate(devices): + if d.get("max_output_channels", 0) <= 0: + continue + if pref_lower in str(d.get("name", "")).lower(): + return ResolvedOutputDevice( + preferred=pref, + index=i, + name=d.get("name", "(unknown)"), + fallback=False, + reason="name match", + ) + + # No match — fall back to system default. + resolved = _default_resolution() + resolved.fallback = True + resolved.reason = f"named device '{pref}' unavailable — fell back to system default" + logger.warning( + "resolve_output_device: %s unavailable, falling back to system default (idx=%s name=%s)", + pref, + resolved.index, + resolved.name, + ) + return resolved + + +# --------------------------------------------------------------------------- +# Session channel +# --------------------------------------------------------------------------- + + +SessionState = str # "idle" | "speaking" | "blocked" | "waiting_for_input" + + +@dataclass +class SessionChannel: + """A registered session's bus endpoint. + + Holds identity, voice, device preference, a per-session output queue, + and lifecycle state. The instance is mutated in place under + ``SessionRegistry``'s lock — callers should not hold references across + deregistration. + """ + + session_id: str + participant_id: str + participant_type: str + assigned_voice: str + voice_conflict: bool = False + preferred_voice: str | None = None + preferred_output_device: str = "system-default" + state: SessionState = "idle" + priority: int = 0 + registered_at: float = field(default_factory=time.time) + last_active: float = field(default_factory=time.time) + # Internal: pending jobs waiting for the global serializer to pick them up. + # Elements are opaque to the registry; the serializer pops them in FIFO + # order within a session. + _queue: deque[Any] = field(default_factory=deque, repr=False) + # Per-session monotonic submit counter — used as a tie-breaker for the + # global round-robin policy and for fifo-global arrival ordering. + _submit_seq: int = field(default=0, repr=False) + + def to_dict(self, *, device_resolver: Callable[[str], ResolvedOutputDevice] | None = None) -> dict[str, Any]: + """Serialize for API responses. ``device_resolver`` lets the caller + re-query the current device (per the ADR amendment — no caching). + """ + d = { + "session_id": self.session_id, + "participant_id": self.participant_id, + "participant_type": self.participant_type, + "assigned_voice": self.assigned_voice, + "voice_conflict": self.voice_conflict, + "preferred_voice": self.preferred_voice, + "preferred_output_device": self.preferred_output_device, + "state": self.state, + "priority": self.priority, + "registered_at": self.registered_at, + "last_active": self.last_active, + "queue_depth": len(self._queue), + } + if device_resolver is not None: + try: + d["output_device"] = device_resolver(self.preferred_output_device).to_dict() + except Exception as exc: # noqa: BLE001 — never fail serialization on device enumeration + d["output_device"] = { + "preferred": self.preferred_output_device, + "index": None, + "name": "(unresolved)", + "fallback": True, + "reason": f"resolver error: {exc}", + } + return d + + +# --------------------------------------------------------------------------- +# Global serializer +# --------------------------------------------------------------------------- + + +@dataclass(order=True) +class _SerializedJob: + """Internal envelope for the global serializer. + + Priority-tuple ordering: lower sort-key first. We sort by + (negative_priority, arrival_seq) so higher-priority jobs win and FIFO + breaks ties. ``payload`` is opaque — the serializer only forwards it to + the dispatch callback. + """ + + sort_key: tuple[int, int] + session_id: str = field(compare=False) + submit_seq: int = field(compare=False) + payload: Any = field(compare=False, default=None) + + +class GlobalSerializer: + """Layer-2 serializer over per-session output queues. + + Policies (ADR-082 § Output Serialization): + + * ``round-robin`` — alternate across sessions with pending work. This + is the default and the reason Phase 1 exists at all: two concurrent + sessions should not be able to starve each other. + * ``priority`` — highest-priority session drains first; ties fall back + to round-robin ordering. + * ``fifo-global`` — strict arrival order across all sessions (matches + today's single-global-queue behavior for migration parity). + + The serializer is push-driven: registrations submit jobs, and a + dedicated dispatcher thread picks the next one per policy. Callers + receive a ``QueuedJob``-shaped handle so existing call sites (speech + queue, bus.act) can swap in without type changes. + """ + + def __init__( + self, + *, + policy: str = "round-robin", + dispatcher: Callable[[str, Any], Any] | None = None, + now: Callable[[], float] = time.time, + ): + if policy not in SERIALIZATION_POLICIES: + raise ValueError(f"unknown serialization policy: {policy}") + self._policy = policy + self._dispatcher = dispatcher + self._now = now + + self._lock = threading.RLock() + self._cond = threading.Condition(self._lock) + # round-robin cursor — list of session_ids in rotation order; we pop + # from the front and append to the back when a session has more work. + self._rr_cursor: deque[str] = deque() + self._rr_seen: set[str] = set() + # priority heap for "priority" policy: (neg_prio, submit_seq, session, job_id) + self._priority_heap: list[tuple[int, int, str, str]] = [] + # fifo-global heap: (submit_seq, session, job_id) + self._fifo_heap: list[tuple[int, str, str]] = [] + # monotonic arrival counter shared across all sessions + self._global_seq = 0 + + self._sessions: dict[str, SessionChannel] = {} + self._thread: threading.Thread | None = None + self._stopping = False + # Diagnostics: order in which jobs are dispatched. Bounded; newest + # first. Primarily for tests and the /v1/sessions dashboard. + self._dispatch_log: deque[tuple[float, str, str]] = deque(maxlen=256) + + # -- Policy plumbing ---------------------------------------------------- + + @property + def policy(self) -> str: + return self._policy + + def set_policy(self, policy: str) -> None: + if policy not in SERIALIZATION_POLICIES: + raise ValueError(f"unknown serialization policy: {policy}") + with self._lock: + self._policy = policy + + def attach_dispatcher(self, dispatcher: Callable[[str, Any], Any]) -> None: + """Set or replace the per-job dispatcher. + + The dispatcher receives ``(session_id, payload)`` on the dispatcher + thread and runs synchronously — its completion is what advances the + queue. The intent is: dispatcher blocks until the playback for this + job finishes, preserving the single-speaker contract. + """ + with self._lock: + self._dispatcher = dispatcher + + # -- Registration management ------------------------------------------- + + def attach_session(self, session: SessionChannel) -> None: + with self._lock: + self._sessions[session.session_id] = session + + def detach_session(self, session_id: str) -> int: + """Drop a session and any queued jobs. Returns count of jobs dropped.""" + with self._lock: + session = self._sessions.pop(session_id, None) + dropped = 0 + if session is not None: + dropped = len(session._queue) + session._queue.clear() + # Best-effort: prune cursor entries for this session. The heaps + # may still reference stale jobs; _next_job() will skip them. + self._rr_seen.discard(session_id) + self._rr_cursor = deque(s for s in self._rr_cursor if s != session_id) + return dropped + + # -- Submission --------------------------------------------------------- + + def submit( + self, + session_id: str, + payload: Any, + *, + priority: int | None = None, + ) -> str: + """Submit a job for ``session_id``. + + Returns an opaque job_id string. The payload is handed verbatim to + the dispatcher when its turn comes up. Raises KeyError if the + session is not registered — callers that want auto-registration + should go through SessionRegistry.submit() instead. + """ + with self._cond: + session = self._sessions.get(session_id) + if session is None: + raise KeyError(f"session '{session_id}' is not registered") + job_id = uuid.uuid4().hex[:8] + self._global_seq += 1 + session._submit_seq += 1 + prio = priority if priority is not None else session.priority + session._queue.append((self._global_seq, job_id, payload)) + session.last_active = self._now() + + # Update scheduling structures + if session_id not in self._rr_seen: + self._rr_seen.add(session_id) + self._rr_cursor.append(session_id) + heapq.heappush(self._priority_heap, (-prio, self._global_seq, session_id, job_id)) + heapq.heappush(self._fifo_heap, (self._global_seq, session_id, job_id)) + + self._cond.notify_all() + return job_id + + # -- Dispatch thread ---------------------------------------------------- + + def start(self) -> None: + """Start the dispatcher thread. Idempotent.""" + with self._lock: + if self._thread is not None and self._thread.is_alive(): + return + self._stopping = False + self._thread = threading.Thread( + target=self._run, + name="mod3-global-serializer", + daemon=True, + ) + self._thread.start() + + def stop(self) -> None: + with self._cond: + self._stopping = True + self._cond.notify_all() + t = self._thread + if t is not None: + t.join(timeout=2.0) + self._thread = None + + def _run(self) -> None: + while True: + with self._cond: + while not self._stopping and not self._has_pending_unlocked(): + self._cond.wait(timeout=0.5) + if self._stopping: + return + picked = self._pop_next_unlocked() + dispatcher = self._dispatcher + if picked is None: + continue + session_id, job_id, payload = picked + self._dispatch_log.appendleft((self._now(), session_id, job_id)) + if dispatcher is None: + logger.debug( + "GlobalSerializer: no dispatcher attached — dropping job %s for %s", + job_id, + session_id, + ) + continue + try: + dispatcher(session_id, payload) + except Exception as exc: # noqa: BLE001 — keep dispatcher robust + logger.exception("GlobalSerializer dispatcher raised: %s", exc) + + def _has_pending_unlocked(self) -> bool: + for s in self._sessions.values(): + if s._queue: + return True + return False + + def _pop_next_unlocked(self) -> tuple[str, str, Any] | None: + """Pick and remove the next job according to policy.""" + if self._policy == "round-robin": + return self._pop_round_robin_unlocked() + if self._policy == "priority": + return self._pop_priority_unlocked() + # fifo-global + return self._pop_fifo_unlocked() + + def _pop_round_robin_unlocked(self) -> tuple[str, str, Any] | None: + # Walk the cursor until we find a session with pending work. Skip + # sessions whose queues are empty — we rebuild the cursor lazily. + checked = 0 + total = len(self._rr_cursor) + while checked < total and self._rr_cursor: + sid = self._rr_cursor.popleft() + session = self._sessions.get(sid) + if session is None or not session._queue: + self._rr_seen.discard(sid) + checked += 1 + continue + seq, job_id, payload = session._queue.popleft() + # If the session still has more, put it at the back of the + # rotation so other sessions go first. + if session._queue: + self._rr_cursor.append(sid) + else: + self._rr_seen.discard(sid) + return sid, job_id, payload + return None + + def _pop_priority_unlocked(self) -> tuple[str, str, Any] | None: + while self._priority_heap: + neg_prio, seq, sid, job_id = heapq.heappop(self._priority_heap) + session = self._sessions.get(sid) + if session is None or not session._queue: + continue + # Pop the matching (seq, job_id) from the session queue. We match + # by seq because heap order and queue order can diverge when two + # sessions submit at once. + found_idx = None + for i, (qseq, qjid, _payload) in enumerate(session._queue): + if qseq == seq and qjid == job_id: + found_idx = i + break + if found_idx is None: + continue + qseq, qjid, payload = session._queue[found_idx] + del session._queue[found_idx] + return sid, qjid, payload + return None + + def _pop_fifo_unlocked(self) -> tuple[str, str, Any] | None: + while self._fifo_heap: + seq, sid, job_id = heapq.heappop(self._fifo_heap) + session = self._sessions.get(sid) + if session is None or not session._queue: + continue + found_idx = None + for i, (qseq, qjid, _payload) in enumerate(session._queue): + if qseq == seq and qjid == job_id: + found_idx = i + break + if found_idx is None: + continue + qseq, qjid, payload = session._queue[found_idx] + del session._queue[found_idx] + return sid, qjid, payload + return None + + # -- Introspection ------------------------------------------------------ + + def snapshot(self) -> dict[str, Any]: + with self._lock: + return { + "policy": self._policy, + "sessions": { + sid: { + "queue_depth": len(s._queue), + "last_active": s.last_active, + "state": s.state, + } + for sid, s in self._sessions.items() + }, + "rr_cursor": list(self._rr_cursor), + "recent_dispatches": list(self._dispatch_log), + } + + +# --------------------------------------------------------------------------- +# Session registry +# --------------------------------------------------------------------------- + + +@dataclass +class RegistrationResult: + session: SessionChannel + created: bool # False when re-registering an existing session_id + voice_conflict: bool + + +class SessionRegistry: + """Thread-safe registry of SessionChannels. + + Owns voice-pool allocation, device preference, and the global serializer. + The registry is deliberately independent of ModalityBus — tests can spin + it up without the full bus stack, and the bus can adopt it incrementally. + """ + + def __init__( + self, + *, + voice_pool: Iterable[str] | None = None, + serializer: GlobalSerializer | None = None, + device_resolver: Callable[[str], ResolvedOutputDevice] | None = None, + ): + self._lock = threading.RLock() + self._sessions: dict[str, SessionChannel] = {} + self._voice_pool: list[str] = list(voice_pool if voice_pool is not None else VOICE_POOL) + # Track which voice is currently held by which session, first-come + # first-served. A second request for the same voice is honored (voice + # is identity — collisions should be rare) but flagged. + self._voice_holder: dict[str, str] = {} + self._serializer = serializer or GlobalSerializer() + self._device_resolver = device_resolver or resolve_output_device + + # -- Lifecycle ---------------------------------------------------------- + + @property + def serializer(self) -> GlobalSerializer: + return self._serializer + + def start(self) -> None: + self._serializer.start() + + def stop(self) -> None: + self._serializer.stop() + + # -- Session management ------------------------------------------------- + + def register( + self, + *, + session_id: str, + participant_id: str, + participant_type: str = DEFAULT_PARTICIPANT_TYPE, + preferred_voice: str | None = None, + preferred_output_device: str = "system-default", + priority: int = 0, + ) -> RegistrationResult: + with self._lock: + existing = self._sessions.get(session_id) + if existing is not None: + existing.participant_id = participant_id + existing.participant_type = participant_type + existing.preferred_output_device = preferred_output_device or "system-default" + existing.last_active = time.time() + # Don't reshuffle voice on re-register. If the caller wants a + # different voice they should deregister first. + return RegistrationResult(existing, created=False, voice_conflict=existing.voice_conflict) + + voice, conflict = self._allocate_voice(session_id, preferred_voice) + session = SessionChannel( + session_id=session_id, + participant_id=participant_id, + participant_type=participant_type, + assigned_voice=voice, + voice_conflict=conflict, + preferred_voice=preferred_voice, + preferred_output_device=preferred_output_device or "system-default", + priority=priority, + ) + self._sessions[session_id] = session + self._serializer.attach_session(session) + logger.info( + "registered session: id=%s participant=%s voice=%s conflict=%s device=%s", + session_id, + participant_id, + voice, + conflict, + preferred_output_device, + ) + return RegistrationResult(session, created=True, voice_conflict=conflict) + + def deregister(self, session_id: str) -> dict[str, Any]: + with self._lock: + session = self._sessions.pop(session_id, None) + if session is None: + return {"status": "not_found", "session_id": session_id} + dropped = self._serializer.detach_session(session_id) + # Return the voice to the pool if we still hold it for this session. + if self._voice_holder.get(session.assigned_voice) == session_id: + del self._voice_holder[session.assigned_voice] + logger.info( + "deregistered session: id=%s voice=%s dropped_jobs=%d", + session_id, + session.assigned_voice, + dropped, + ) + return { + "status": "ok", + "session_id": session_id, + "released_voice": session.assigned_voice, + "dropped_jobs": dropped, + } + + def get(self, session_id: str) -> SessionChannel | None: + with self._lock: + return self._sessions.get(session_id) + + def get_or_create_default(self) -> SessionChannel: + """Backward-compat path — route legacy callers to an implicit session.""" + with self._lock: + session = self._sessions.get(DEFAULT_SESSION_ID) + if session is not None: + return session + result = self.register( + session_id=DEFAULT_SESSION_ID, + participant_id=DEFAULT_PARTICIPANT_ID, + participant_type=DEFAULT_PARTICIPANT_TYPE, + preferred_voice=None, + preferred_output_device="system-default", + ) + return result.session + + def list(self) -> list[SessionChannel]: + with self._lock: + return list(self._sessions.values()) + + def list_serialized(self) -> list[dict[str, Any]]: + with self._lock: + resolver = self._device_resolver + return [s.to_dict(device_resolver=resolver) for s in self._sessions.values()] + + def voice_pool(self) -> list[str]: + return list(self._voice_pool) + + def voice_holder_snapshot(self) -> dict[str, str]: + with self._lock: + return dict(self._voice_holder) + + # -- Submission --------------------------------------------------------- + + def submit( + self, + session_id: str | None, + payload: Any, + *, + priority: int | None = None, + auto_create_default: bool = True, + ) -> tuple[str, str]: + """Enqueue ``payload`` on ``session_id``. + + When session_id is None and auto_create_default is True, falls back + to the "default" session so legacy call sites keep working. + + Returns ``(resolved_session_id, job_id)``. + """ + if session_id is None or session_id == "": + if not auto_create_default: + raise ValueError("session_id is required when auto_create_default=False") + session = self.get_or_create_default() + session_id = session.session_id + else: + with self._lock: + if session_id not in self._sessions: + raise KeyError(f"session '{session_id}' is not registered") + + job_id = self._serializer.submit(session_id, payload, priority=priority) + return session_id, job_id + + # -- Device routing ----------------------------------------------------- + + def resolve_device(self, session_id: str) -> ResolvedOutputDevice: + """Live-resolve the output device for ``session_id``. + + Per the 2026-04-22 amendment, this is a live property: the OS default + is re-queried every call when the session's preference is + ``system-default``. Never cache the return value on the session. + """ + with self._lock: + session = self._sessions.get(session_id) + preferred = session.preferred_output_device if session else "system-default" + return self._device_resolver(preferred) + + def set_preferred_device(self, session_id: str, preferred: str) -> ResolvedOutputDevice: + with self._lock: + session = self._sessions.get(session_id) + if session is None: + raise KeyError(session_id) + session.preferred_output_device = preferred or "system-default" + return self.resolve_device(session_id) + + # -- Internals ---------------------------------------------------------- + + def _allocate_voice(self, session_id: str, preferred: str | None) -> tuple[str, bool]: + # Explicit preference — honor it, flag if someone else already holds + # it. Voice is identity, not exclusive; collisions are operator bugs + # at worst. + if preferred: + if preferred not in self._voice_pool: + # Extend the pool lazily so out-of-band voices still work. + self._voice_pool.append(preferred) + holder = self._voice_holder.get(preferred) + conflict = holder is not None and holder != session_id + self._voice_holder.setdefault(preferred, session_id) + return preferred, conflict + + # Greedy allocation — first voice in the pool whose holder is absent + # or dead. + for voice in self._voice_pool: + if voice not in self._voice_holder: + self._voice_holder[voice] = session_id + return voice, False + + # Pool exhausted — fall back to the first voice and flag a collision. + fallback = self._voice_pool[0] if self._voice_pool else "bm_lewis" + return fallback, True + + +# --------------------------------------------------------------------------- +# Process-global default registry — shared across server.py, http_api.py, +# and tests. Tests can instantiate their own SessionRegistry if they need +# isolation; the module-level singleton is just a convenience for runtime. +# --------------------------------------------------------------------------- + + +_default_registry: SessionRegistry | None = None +_default_registry_lock = threading.Lock() + + +def get_default_registry() -> SessionRegistry: + global _default_registry + with _default_registry_lock: + if _default_registry is None: + _default_registry = SessionRegistry() + _default_registry.start() + return _default_registry + + +def reset_default_registry() -> None: + """For tests — tear down the module-level registry.""" + global _default_registry + with _default_registry_lock: + if _default_registry is not None: + _default_registry.stop() + _default_registry = None + + +@atexit.register +def _shutdown_default_registry() -> None: + """Stop the dispatcher thread cleanly at interpreter shutdown. + + Without this the daemon thread gets killed mid-callback during + finalization and Python emits a Fatal Python error about a NULL + thread state. The dispatcher is idempotent for stop() so this is + safe to call even if the registry was never created. + """ + try: + reset_default_registry() + except Exception: + pass + + +__all__ = [ + "DEFAULT_SESSION_ID", + "DEFAULT_PARTICIPANT_ID", + "DEFAULT_PARTICIPANT_TYPE", + "GlobalSerializer", + "RegistrationResult", + "ResolvedOutputDevice", + "SessionChannel", + "SessionRegistry", + "SERIALIZATION_POLICIES", + "VOICE_POOL", + "get_default_registry", + "reset_default_registry", + "resolve_output_device", +] diff --git a/tests/test_browser_channel_routing.py b/tests/test_browser_channel_routing.py index fd7f89c..917049b 100644 --- a/tests/test_browser_channel_routing.py +++ b/tests/test_browser_channel_routing.py @@ -92,6 +92,17 @@ def _broadcast_with_loop(text: str, session_id: str | None = None) -> None: loop.close() +def _broadcast_complete_with_loop(metrics: dict | None = None, session_id: str | None = None) -> None: + """Sibling of _broadcast_with_loop for the response_complete frame.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + with patch("channels.asyncio.run_coroutine_threadsafe", _patched_run): + BrowserChannel.broadcast_response_complete(metrics, session_id=session_id) + finally: + loop.close() + + def test_broadcast_with_no_session_id_fans_out_to_all_active_channels(): loop = asyncio.new_event_loop() a = _FakeChannel("browser:aaa", loop) @@ -169,5 +180,78 @@ def test_broadcast_session_id_with_no_match_drops_silently(): loop.close() +# --------------------------------------------------------------------------- +# response_complete routing (kernel-path turn-done signal) +# +# Paired with broadcast_response_text in cogos_agent_bridge.run_response_bridge; +# must follow the same session-scoped routing so the complete-frame lands on +# the originating dashboard channel (otherwise multi-client setups see +# cross-talk or hang on a non-matching spinner). +# --------------------------------------------------------------------------- + + +def test_broadcast_complete_with_no_session_id_fans_out_to_all_active_channels(): + loop = asyncio.new_event_loop() + a = _FakeChannel("browser:aaa", loop) + b = _FakeChannel("browser:bbb", loop) + BrowserChannel._active_channels.update({a, b}) + + _broadcast_complete_with_loop({"provider": "cogos-agent"}) + + for ch in (a, b): + assert len(ch.ws.sent) == 1 + frame = ch.ws.sent[0] + assert frame["type"] == "response_complete" + assert frame["metrics"] == {"provider": "cogos-agent"} + + loop.close() + + +def test_broadcast_complete_routes_to_matching_session_only(): + loop = asyncio.new_event_loop() + a = _FakeChannel("browser:aaa", loop) + b = _FakeChannel("browser:bbb", loop) + BrowserChannel._active_channels.update({a, b}) + + _broadcast_complete_with_loop( + {"provider": "cogos-agent", "event_id": "r42"}, + session_id="mod3:browser:bbb", + ) + + assert a.ws.sent == [] + assert len(b.ws.sent) == 1 + assert b.ws.sent[0]["type"] == "response_complete" + assert b.ws.sent[0]["metrics"]["event_id"] == "r42" + + loop.close() + + +def test_broadcast_complete_defaults_to_empty_metrics(): + loop = asyncio.new_event_loop() + a = _FakeChannel("browser:aaa", loop) + BrowserChannel._active_channels.add(a) + + _broadcast_complete_with_loop() + + assert a.ws.sent == [{"type": "response_complete", "metrics": {}}] + + loop.close() + + +def test_broadcast_complete_skips_inactive_channels(): + loop = asyncio.new_event_loop() + a = _FakeChannel("browser:aaa", loop) + a._active = False + b = _FakeChannel("browser:bbb", loop) + BrowserChannel._active_channels.update({a, b}) + + _broadcast_complete_with_loop({"provider": "cogos-agent"}) + + assert a.ws.sent == [] + assert len(b.ws.sent) == 1 + + loop.close() + + if __name__ == "__main__": sys.exit(pytest.main([__file__, "-v"])) diff --git a/tests/test_cogos_agent_bridge.py b/tests/test_cogos_agent_bridge.py index 2d5de43..5063fd0 100644 --- a/tests/test_cogos_agent_bridge.py +++ b/tests/test_cogos_agent_bridge.py @@ -104,11 +104,20 @@ def test_run_response_bridge_fans_out_to_broadcast(): ] sub = _FakeSubscriber(envelopes) - with patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_bcast: + with ( + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_text, + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_complete") as mock_done, + ): asyncio.run(run_response_bridge(sub)) - texts = [c.args[0] for c in mock_bcast.call_args_list] + texts = [c.args[0] for c in mock_text.call_args_list] assert texts == ["reply one", "free-form string reply"] + # One completion frame per forwarded text — never zero, never doubled. + assert mock_done.call_count == 2 + # Payload shape: provider tag plus any kernel-supplied timing/ids. + for call in mock_done.call_args_list: + metrics = call.args[0] + assert metrics["provider"] == "cogos-agent" # --------------------------------------------------------------------------- @@ -144,14 +153,22 @@ def test_run_response_bridge_forwards_session_id_when_present(): envelopes = [_env({"content": inner}, "r1")] sub = _FakeSubscriber(envelopes) - with patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_bcast: + with ( + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_text, + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_complete") as mock_done, + ): asyncio.run(run_response_bridge(sub)) - assert mock_bcast.call_count == 1 - call = mock_bcast.call_args_list[0] - assert call.args[0] == "scoped reply" + assert mock_text.call_count == 1 + text_call = mock_text.call_args_list[0] + assert text_call.args[0] == "scoped reply" # session_id passed as keyword - assert call.kwargs.get("session_id") == "mod3:browser:abc" + assert text_call.kwargs.get("session_id") == "mod3:browser:abc" + # Completion frame routes to the same session so the originating + # channel's spinner clears — not a broadcast. + assert mock_done.call_count == 1 + done_call = mock_done.call_args_list[0] + assert done_call.kwargs.get("session_id") == "mod3:browser:abc" def test_run_response_bridge_falls_back_to_broadcast_when_no_session_id(): @@ -160,13 +177,38 @@ def test_run_response_bridge_falls_back_to_broadcast_when_no_session_id(): envelopes = [_env({"content": inner}, "r1")] sub = _FakeSubscriber(envelopes) - with patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_bcast: + with ( + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_text, + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_complete") as mock_done, + ): asyncio.run(run_response_bridge(sub)) - assert mock_bcast.call_count == 1 - call = mock_bcast.call_args_list[0] - assert call.args[0] == "broadcast reply" - assert call.kwargs.get("session_id") is None + assert mock_text.call_count == 1 + text_call = mock_text.call_args_list[0] + assert text_call.args[0] == "broadcast reply" + assert text_call.kwargs.get("session_id") is None + # Completion frame also broadcasts (matches the text-frame routing). + assert mock_done.call_count == 1 + assert mock_done.call_args_list[0].kwargs.get("session_id") is None + + +def test_run_response_bridge_skips_complete_when_text_missing(): + """Events with no recoverable text must not emit a completion frame. + + Holding the 1:1 pairing keeps the UI's turn counter honest: we only + mark a turn done when we actually rendered something for it. + """ + envelopes = [_env({"foo": "bar"}, "r1"), _env({}, "r2")] + sub = _FakeSubscriber(envelopes) + + with ( + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_text") as mock_text, + patch("cogos_agent_bridge.BrowserChannel.broadcast_response_complete") as mock_done, + ): + asyncio.run(run_response_bridge(sub)) + + assert mock_text.call_count == 0 + assert mock_done.call_count == 0 def test_post_user_message_uses_runtime_endpoint(monkeypatch): diff --git a/tests/test_session_registry.py b/tests/test_session_registry.py new file mode 100644 index 0000000..f6df350 --- /dev/null +++ b/tests/test_session_registry.py @@ -0,0 +1,410 @@ +"""Unit + integration tests for the ADR-082 Phase 1 session registry. + +Covers: + * Voice-pool allocation (greedy + preferred + collision flagging) + * Device resolution with a stubbed sounddevice-shaped lookup + — "system-default" live re-query, named-device match, fallback + * Global serializer round-robin across sessions + * SessionRegistry submit() auto-creating a "default" session for + legacy callers + * HTTP surface: /v1/sessions endpoints smoke-tested via FastAPI TestClient + +Run with: ``.venv/bin/python -m pytest tests/test_session_registry.py -v`` +""" + +from __future__ import annotations + +import os +import sys +import threading +from typing import Any + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from session_registry import ( # noqa: E402 + DEFAULT_SESSION_ID, + VOICE_POOL, + GlobalSerializer, + SessionRegistry, + resolve_output_device, +) + +# --------------------------------------------------------------------------- +# Voice-pool allocation +# --------------------------------------------------------------------------- + + +class TestVoiceAllocation: + def test_greedy_allocation_walks_the_pool(self): + reg = SessionRegistry() + a = reg.register(session_id="s1", participant_id="a").session + b = reg.register(session_id="s2", participant_id="b").session + c = reg.register(session_id="s3", participant_id="c").session + assert a.assigned_voice == VOICE_POOL[0] + assert b.assigned_voice == VOICE_POOL[1] + assert c.assigned_voice == VOICE_POOL[2] + assert not a.voice_conflict + assert not b.voice_conflict + assert not c.voice_conflict + + def test_preferred_voice_honored(self): + reg = SessionRegistry() + result = reg.register(session_id="s1", participant_id="a", preferred_voice="bf_emma") + assert result.session.assigned_voice == "bf_emma" + assert not result.voice_conflict + + def test_preferred_voice_collision_flagged_but_assigned(self): + reg = SessionRegistry() + first = reg.register(session_id="s1", participant_id="a", preferred_voice="bm_lewis") + second = reg.register(session_id="s2", participant_id="b", preferred_voice="bm_lewis") + # Both get the voice (collision is a flag, not a veto — per ADR) + assert first.session.assigned_voice == "bm_lewis" + assert not first.voice_conflict + assert second.session.assigned_voice == "bm_lewis" + assert second.voice_conflict + + def test_deregister_returns_voice_to_pool(self): + reg = SessionRegistry() + reg.register(session_id="s1", participant_id="a") # takes bm_lewis + result = reg.deregister("s1") + assert result["status"] == "ok" + assert result["released_voice"] == VOICE_POOL[0] + # Next registration picks up the released voice + next_session = reg.register(session_id="s2", participant_id="b").session + assert next_session.assigned_voice == VOICE_POOL[0] + + def test_out_of_pool_preferred_voice_is_added(self): + """Using a voice the pool didn't know about still works.""" + reg = SessionRegistry() + result = reg.register(session_id="s1", participant_id="a", preferred_voice="af_ember") + assert result.session.assigned_voice == "af_ember" + assert "af_ember" in reg.voice_pool() + + +# --------------------------------------------------------------------------- +# Device resolution +# --------------------------------------------------------------------------- + + +def _fake_device_list() -> list[dict[str, Any]]: + """Synthetic device list — stable across calls.""" + return [ + {"name": "MacBook Pro Speakers", "max_output_channels": 2}, + {"name": "DisplayLink Monitor", "max_output_channels": 2}, + {"name": "Realtek USB2.0 Audio", "max_output_channels": 2}, + {"name": "Microphone", "max_output_channels": 0}, + ] + + +class TestDeviceResolution: + def test_system_default_live_requery(self): + """The default must be re-read every call — not cached.""" + current = {"idx": 0} + + def query() -> list[dict[str, Any]]: + return _fake_device_list() + + def default_idx() -> int: + return current["idx"] + + r1 = resolve_output_device("system-default", query_devices=query, default_output_index=default_idx) + assert r1.index == 0 + assert r1.name == "MacBook Pro Speakers" + assert not r1.fallback + + # User plugs in headphones; the OS default changes. + current["idx"] = 2 + r2 = resolve_output_device("system-default", query_devices=query, default_output_index=default_idx) + assert r2.index == 2 + assert r2.name == "Realtek USB2.0 Audio" + + def test_named_device_match(self): + r = resolve_output_device( + "Realtek", + query_devices=_fake_device_list, + default_output_index=lambda: 0, + ) + assert r.index == 2 + assert r.name == "Realtek USB2.0 Audio" + assert not r.fallback + + def test_named_device_missing_falls_back_to_default(self): + r = resolve_output_device( + "AirPods Pro", # not plugged in + query_devices=_fake_device_list, + default_output_index=lambda: 1, + ) + assert r.fallback is True + assert r.index == 1 + assert r.name == "DisplayLink Monitor" + assert "fell back" in r.reason.lower() + + def test_numeric_index_match(self): + r = resolve_output_device( + "2", + query_devices=_fake_device_list, + default_output_index=lambda: 0, + ) + assert r.index == 2 + assert r.name == "Realtek USB2.0 Audio" + + def test_empty_preferred_treated_as_default(self): + r = resolve_output_device( + "", + query_devices=_fake_device_list, + default_output_index=lambda: 1, + ) + assert r.index == 1 + + def test_default_index_out_of_range_returns_implicit(self): + r = resolve_output_device( + "system-default", + query_devices=_fake_device_list, + default_output_index=lambda: 99, + ) + assert r.index is None + # Not a "fallback" in the device-fallback sense — OS just couldn't tell us + assert "portaudio" in r.reason.lower() or "unknown" in r.reason.lower() + + +# --------------------------------------------------------------------------- +# Global serializer +# --------------------------------------------------------------------------- + + +class TestGlobalSerializer: + def _build(self, policy: str = "round-robin") -> tuple[GlobalSerializer, list[tuple[str, Any]]]: + """Return (serializer, log) where log records (session_id, payload).""" + log: list[tuple[str, Any]] = [] + barrier = threading.Event() + + def dispatcher(session_id: str, payload: Any) -> None: + log.append((session_id, payload)) + barrier.set() + + ser = GlobalSerializer(policy=policy, dispatcher=dispatcher) + return ser, log + + def test_round_robin_interleaves_two_sessions(self): + reg = SessionRegistry() + ser = reg.serializer + + events: list[tuple[str, Any]] = [] + completed = threading.Event() + counter = {"n": 0} + + def dispatcher(session_id: str, payload: Any) -> None: + events.append((session_id, payload)) + counter["n"] += 1 + if counter["n"] >= 6: + completed.set() + + ser.attach_dispatcher(dispatcher) + reg.register(session_id="A", participant_id="a") + reg.register(session_id="B", participant_id="b") + + # Submit 3 jobs to each before starting the dispatcher so we can + # observe round-robin ordering instead of racing with a ticking queue. + for i in range(3): + reg.submit("A", {"job": f"A{i}"}) + for i in range(3): + reg.submit("B", {"job": f"B{i}"}) + + reg.start() + assert completed.wait(timeout=2.0), f"serializer stalled: {events}" + reg.stop() + + sessions = [sid for sid, _ in events] + # Round-robin starting from A (first session with pending work): + # the cursor walks A → B alternately. + expected = ["A", "B", "A", "B", "A", "B"] + assert sessions == expected, f"Expected round-robin, got {sessions}" + + def test_fifo_global_preserves_arrival_order(self): + reg = SessionRegistry() + reg.serializer.set_policy("fifo-global") + events: list[tuple[str, Any]] = [] + done = threading.Event() + + def dispatcher(session_id: str, payload: Any) -> None: + events.append((session_id, payload)) + if len(events) == 4: + done.set() + + reg.serializer.attach_dispatcher(dispatcher) + reg.register(session_id="A", participant_id="a") + reg.register(session_id="B", participant_id="b") + + reg.submit("A", "a1") + reg.submit("B", "b1") + reg.submit("A", "a2") + reg.submit("B", "b2") + + reg.start() + assert done.wait(timeout=2.0) + reg.stop() + + order = [payload for _, payload in events] + assert order == ["a1", "b1", "a2", "b2"], f"FIFO broken: {order}" + + def test_priority_policy_drains_higher_priority_first(self): + reg = SessionRegistry() + reg.serializer.set_policy("priority") + events: list[tuple[str, Any]] = [] + done = threading.Event() + + def dispatcher(session_id: str, payload: Any) -> None: + events.append((session_id, payload)) + if len(events) == 3: + done.set() + + reg.serializer.attach_dispatcher(dispatcher) + reg.register(session_id="low", participant_id="low", priority=0) + reg.register(session_id="hi", participant_id="hi", priority=10) + + reg.submit("low", "L1") + reg.submit("low", "L2") + reg.submit("hi", "H1") # this should drain first under priority + + reg.start() + assert done.wait(timeout=2.0) + reg.stop() + + assert events[0] == ("hi", "H1"), f"Priority not honored: {events}" + + def test_submit_on_unregistered_session_raises(self): + reg = SessionRegistry() + with pytest.raises(KeyError): + reg.submit("ghost", "payload", auto_create_default=False) + + def test_submit_without_session_id_auto_creates_default(self): + reg = SessionRegistry() + done = threading.Event() + events: list[tuple[str, Any]] = [] + + def dispatcher(session_id: str, payload: Any) -> None: + events.append((session_id, payload)) + done.set() + + reg.serializer.attach_dispatcher(dispatcher) + reg.submit(None, "legacy-call") + reg.start() + assert done.wait(timeout=1.0) + reg.stop() + + # Auto-created default session + assert events[0][0] == DEFAULT_SESSION_ID + assert reg.get(DEFAULT_SESSION_ID) is not None + + def test_deregister_drops_queued_jobs(self): + reg = SessionRegistry() + reg.register(session_id="A", participant_id="a") + reg.submit("A", "j1") + reg.submit("A", "j2") + result = reg.deregister("A") + assert result["dropped_jobs"] == 2 + + +# --------------------------------------------------------------------------- +# HTTP surface +# --------------------------------------------------------------------------- + + +class TestHTTPSessionEndpoints: + """Smoke-test the HTTP endpoints via FastAPI TestClient. + + We use the live registry module singleton — tests register unique session + ids with a 'pytest-' prefix and deregister in teardown so we do not leak. + """ + + @pytest.fixture + def client(self): + from fastapi.testclient import TestClient + + import http_api + + return TestClient(http_api.app) + + @pytest.fixture(autouse=True) + def _cleanup(self): + yield + # Clean up any pytest- sessions that leaked. + from session_registry import get_default_registry + + reg = get_default_registry() + for s in list(reg.list()): + if s.session_id.startswith("pytest-"): + reg.deregister(s.session_id) + + def test_register_returns_session_channel(self, client): + r = client.post( + "/v1/sessions/register", + json={ + "session_id": "pytest-s1", + "participant_id": "pytest-cog", + "participant_type": "agent", + "preferred_output_device": "system-default", + }, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["session_id"] == "pytest-s1" + assert body["participant_id"] == "pytest-cog" + assert body["assigned_voice"] in VOICE_POOL + assert body["preferred_output_device"] == "system-default" + # live-resolved device block present + assert "output_device" in body + assert "preferred" in body["output_device"] + + def test_list_includes_registered_session(self, client): + client.post( + "/v1/sessions/register", + json={"session_id": "pytest-list-1", "participant_id": "pytest-sandy"}, + ) + r = client.get("/v1/sessions") + assert r.status_code == 200 + body = r.json() + ids = [s["session_id"] for s in body["sessions"]] + assert "pytest-list-1" in ids + assert "voice_pool" in body + assert body["serializer"]["policy"] in ("round-robin", "priority", "fifo-global") + + def test_deregister_releases_voice(self, client): + client.post( + "/v1/sessions/register", + json={ + "session_id": "pytest-dr-1", + "participant_id": "pytest-cog", + "preferred_voice": "bf_emma", + }, + ) + r = client.post("/v1/sessions/pytest-dr-1/deregister") + assert r.status_code == 200 + body = r.json() + assert body["status"] == "ok" + assert body["released_voice"] == "bf_emma" + + def test_deregister_unknown_returns_404(self, client): + r = client.post("/v1/sessions/pytest-nonexistent/deregister") + assert r.status_code == 404 + + def test_get_single_session(self, client): + client.post( + "/v1/sessions/register", + json={"session_id": "pytest-get-1", "participant_id": "pytest-user"}, + ) + r = client.get("/v1/sessions/pytest-get-1") + assert r.status_code == 200 + assert r.json()["session_id"] == "pytest-get-1" + + def test_synthesize_rejects_unknown_session(self, client): + r = client.post( + "/v1/synthesize", + json={ + "text": "hello", + "session_id": "pytest-ghost", + }, + ) + assert r.status_code == 404