From 576ae946ec96fae0499bc9cf198a47eb9d6f0dcb Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Wed, 15 Apr 2026 15:04:19 -0400 Subject: [PATCH 1/9] feat: bidirectional voice pipeline + MCP shim + dashboard enhancements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Migrated from cog-workspace/apps/tts-mcp — consolidating development into the canonical mod3 repo. Voice Pipeline (5 phases): - Three-tier adaptive STT (Whisper Base 31ms + Large 470ms) - Speculative generation (agent thinks while human speaks) - Opacity-as-state rendering (transparent → solidifying → solid) - Barge-in context stitching (state snapshot on interrupt) - Self-barge draft revision (agent revises its own queued output) New files: - draft_queue.py: Thread-safe DraftQueue for speculative generation - mcp_shim.py: Lightweight MCP-to-HTTP proxy (no model loading) Modified: - agent_loop.py: Context stitching, speculative inference, self-barge - channels.py: Three-tier STT scheduler - modules/voice.py: decode_streaming(), Whisper Base loader, TTS validation - dashboard/index.html: Opacity CSS, solidification, partials, queue preview - dashboard/playback.js: Progress tracking for word-level solidification - server.py: Session-aware queue foundations Co-Authored-By: Claude Opus 4.6 (1M context) --- ARCHITECTURE.md | 372 +++++++++++---------- adaptive_player.py | 25 +- agent_loop.py | 305 ++++++++++++++++- channels.py | 185 ++++++++-- dashboard/index.html | 307 ++++++++++++++++- dashboard/playback.js | 77 ++++- dashboard/transport.js | 3 + draft_queue.py | 267 +++++++++++++++ http_api.py | 18 +- mcp.channel.json | 2 +- mcp_shim.py | 742 +++++++++++++++++++++++++++++++++++++++++ modules/voice.py | 192 ++++++++++- providers.py | 29 +- server.py | 202 ++++++++++- 14 files changed, 2449 insertions(+), 277 deletions(-) create mode 100644 draft_queue.py create mode 100644 mcp_shim.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 9def800..b037f9f 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -1,216 +1,218 @@ -# Mod3 Architecture: The Modality Bus +# Mod³ Dashboard — Process Architecture -The modality bus is the sensorimotor boundary between cognitive agents and physical signals. Agents think in cognitive events ("someone spoke", "say this"); the bus translates between those events and raw bytes (audio, text, future: vision, spatial). +## Intended Flow ``` - ModalityBus - ┌──────────────────────────────────────────────┐ - │ │ - │ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ - │ │ Voice │ │ Text │ │ Vision* │ ... │ - │ │ Module │ │ Module │ │ Module │ │ - │ └────┬─────┘ └────┬────┘ └────┬────┘ │ - │ │ │ │ │ - │ ┌────┴─────────────┴────────────┴────┐ │ - │ │ Event Log + Listeners │ │ - │ └────┬─────────────┬────────────┬────┘ │ - │ │ │ │ │ - │ ┌────┴────┐ ┌─────┴─────┐ ┌──┴───┐ │ - │ │ Channel │ │ Channel │ │ ... │ │ - │ │ discord │ │ http-api │ │ │ │ - │ └─────────┘ └───────────┘ └──────┘ │ - └──────────────────────────────────────────────┘ - - * Vision/Spatial are defined in ModalityType but not yet implemented. 
+┌─────────────────────────────────────────────────────────────┐ +│ BROWSER │ +│ │ +│ ┌──────────┐ ┌──────────┐ ┌───────────┐ │ +│ │ Silero │ │ Text │ │ Audio │ │ +│ │ VAD v5 │ │ Input │ │ Playback │ │ +│ │ (ONNX) │ │ │ │ (Web Audio│ │ +│ └────┬─────┘ └────┬─────┘ └─────▲─────┘ │ +│ │ │ │ │ +│ │ onSpeechEnd │ sendControl │ enqueueWav │ +│ │ (Int16 PCM) │ (JSON) │ (base64 WAV) │ +│ ▼ ▼ │ │ +│ ┌────────────────────────────────────┐│ │ +│ │ VoiceTransport (WebSocket) ││ │ +│ │ binary frames ──► ──► JSON ││ │ +│ │ JSON frames ──► ◄── JSON/b64 ││ │ +│ └────────────────┬───────────────────┘│ │ +│ │ │ │ +└───────────────────┼────────────────────┼─────────────────────┘ + │ WebSocket /ws/chat │ + ▼ │ +┌───────────────────┼────────────────────┼─────────────────────┐ +│ │ MOD³ SERVER │ │ +│ │ (single process) │ │ +│ ▼ │ │ +│ ┌─────────────────────────────────────────┐ │ +│ │ BrowserChannel │ │ +│ │ │ │ +│ │ _handle_audio(pcm) → buffer │ │ +│ │ _handle_json(msg) → dispatch │ │ +│ │ _deliver_async() → send to browser │ │ +│ └──────┬─────────┬──────────▲──────────────┘ │ +│ │ │ │ │ +│ PCM audio text msg encoded output │ +│ │ │ │ │ +│ ▼ │ │ │ +│ ┌──────────┐ │ │ │ +│ │ STT │ │ │ │ +│ │ (mlx_ │ │ │ │ +│ │ whisper) │ │ │ │ +│ │ temp WAV │ │ │ │ +│ └────┬─────┘ │ │ │ +│ │ │ │ │ +│ │ transcript│ │ │ +│ ▼ ▼ │ │ +│ ┌─────────────────────┐ │ │ +│ │ CognitiveEvent │ │ │ +│ │ {content: "text"} │ │ │ +│ └──────────┬──────────┘ │ │ +│ │ │ │ +│ ▼ │ │ +│ ┌──────────────────────┐ │ │ +│ │ AgentLoop │ │ │ +│ │ │ │ │ +│ │ conversation[] │ │ │ +│ │ provider.chat() │ │ │ +│ │ → tool_calls │ │ │ +│ │ │ │ │ +│ │ DISPATCH: │ │ │ +│ │ speak(text) │ │ │ +│ │ → send_response_text ──────► channel (text to chat) │ +│ │ → bus.act(VOICE) │ │ │ +│ │ ▼ │ │ │ +│ │ send_text(text) │ │ │ +│ │ → send_response_text ──────► channel (text to chat) │ +│ │ │ │ │ +│ │ think(reasoning) │ │ │ +│ │ → (internal only) │ │ │ +│ └──────────┬────────────┘ │ │ +│ │ │ │ +│ bus.act(VOICE intent) │ │ +│ │ │ │ +│ ▼ │ │ +│ ┌──────────────────────┐ │ │ +│ │ ModalityBus │ │ │ +│ │ │ │ │ +│ │ OutputQueue │ │ │ +│ │ (per-channel FIFO) │ │ │ +│ │ │ │ │ │ +│ │ ▼ │ │ │ +│ │ VoiceEncoder │ │ │ +│ │ (Kokoro TTS) │ │ │ +│ │ → WAV bytes │ │ │ +│ │ │ │ │ │ +│ │ ch.deliver(output) ─────┘ │ +│ │ (base64 JSON) │ │ +│ └──────────────────────┘ │ +│ │ +│ ┌──────────────────────┐ │ +│ │ InferenceProvider │ │ +│ │ (mlx-lm / Ollama) │ │ +│ │ │ │ +│ │ model resident in │ │ +│ │ memory (in-process) │ │ +│ └──────────────────────┘ │ +│ │ +└──────────────────────────────────────────────────────────────┘ ``` -## Core Types (modality.py) - -### Cognitive Primitives - -The agent never touches raw bytes. It sees these: - -```python -@dataclass -class CognitiveEvent: # Input percept - modality: ModalityType # VOICE, TEXT, VISION, SPATIAL - content: str # The meaning (transcribed text, caption, etc.) - source_channel: str # Which channel it arrived on - confidence: float # Decoder certainty (0.0 - 1.0) - timestamp: float - metadata: dict[str, Any] - -@dataclass -class CognitiveIntent: # Output intent (not yet encoded) - modality: ModalityType | None # None = let the bus decide - content: str # What to communicate - target_channel: str # Specific channel, or "" for bus routing - priority: int # Higher = more urgent - metadata: dict[str, Any] # voice, speed, emotion, etc. - -@dataclass -class EncodedOutput: # Raw signal ready for delivery - modality: ModalityType - data: bytes # WAV, PNG, JSON, etc. - format: str # "wav", "png", "text", etc. 
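+
+The browser multiplexes both directions over a single WebSocket: binary
+frames carry raw Int16 PCM from the VAD, text frames carry JSON control
+messages. An illustrative sketch of the server-side dispatch (the real
+loop is `BrowserChannel.run()` in `channels.py`; this sketch assumes the
+Starlette `receive()` message shape):
+
+```python
+import json
+
+async def run(self) -> None:
+    """Sketch only: route incoming frames to the existing handlers."""
+    while self._active:
+        msg = await self.ws.receive()
+        if msg["type"] == "websocket.disconnect":
+            break
+        if msg.get("bytes") is not None:
+            # Hot path: raw Int16 PCM at 16kHz; buffer it, no decoding here.
+            self._handle_audio(msg["bytes"])
+        elif msg.get("text") is not None:
+            # Control path: text input, settings updates, interrupt.
+            await self._handle_json(json.loads(msg["text"]))
+    self._cleanup()
+```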
- duration_sec: float - metadata: dict[str, Any] -``` - -### Abstract Base Classes - -Every modality module implements three components: - -```python -class Gate(ABC): - def check(self, raw: bytes, **kwargs) -> GateResult: ... - -class Decoder(ABC): - def decode(self, raw: bytes, **kwargs) -> CognitiveEvent: ... +## Current Problems -class Encoder(ABC): - def encode(self, intent: CognitiveIntent) -> EncodedOutput: ... +### 1. Agent blocks on TTS delivery -class ModalityModule(ABC): - modality_type -> ModalityType # Which modality this handles - gate -> Gate | None # Input filter (None = pass all) - decoder -> Decoder | None # raw -> CognitiveEvent - encoder -> Encoder | None # CognitiveIntent -> EncodedOutput - state -> ModuleState # Live HUD state - health() -> dict # Diagnostics +``` +agent_loop._process(): + await send_response_text(text) # ← fast, JSON to browser + await asyncio.to_thread(bus.act) # ← BLOCKS until TTS generates + delivers + # agent can't process next event until TTS finishes ``` -`Gate` is optional. Text has no gate (all text passes). Voice uses VAD (Voice Activity Detection) to reject silence. - -## The Bus (bus.py) - -`ModalityBus` manages module registration, signal routing, and state tracking. +**Should be:** fire-and-forget the bus.act() intent, return immediately. +bus.act(blocking=False) already returns QueuedJob — just don't await the result. -### perceive() -- Input Path +### 2. Kokoro cold start blocks OutputQueue drain thread ``` -raw bytes ──→ Gate.check() ──→ Decoder.decode() ──→ CognitiveEvent - │ │ - (rejected?) (empty content?) - ↓ ↓ - None None (filtered) -``` - -```python -bus.perceive(raw: bytes, modality: str | ModalityType, channel: str = "", **kwargs) - -> CognitiveEvent | None +OutputQueue drain thread: + _do_encode() → VoiceEncoder.encode() → engine.synthesize() + → Kokoro first-time init: ~60s blocking + → All other queued jobs wait + → _deliver_sync timeout (10s) fires on older jobs ``` -1. Resolve the modality module from the registry -2. If the module has a gate, run `gate.check(raw)`. Emit a `modality.gate` bus event. Return `None` if rejected. -3. Run `decoder.decode(raw)`. If content is empty (e.g., hallucination filtered), emit `modality.filtered` and return `None`. -4. Stamp `source_channel`, emit `modality.input`, return the event. +**Should be:** pre-warm Kokoro on server startup (background thread). -### act() -- Output Path +### 3. WebSocket lifecycle fragility ``` -CognitiveIntent ──→ resolve modality ──→ Encoder.encode() ──→ EncodedOutput - │ - channel.deliver() -``` - -```python -bus.act(intent: CognitiveIntent, channel: str = "", blocking: bool = False) - -> QueuedJob | EncodedOutput +Browser page reload → new WebSocket → new BrowserChannel + Old channel's deliver callback still referenced by bus + Old OutputQueue drain thread still running + → sends to dead WebSocket → timeout → error cascade ``` -1. Resolve output modality: explicit on intent, or inferred from channel capabilities (prefers voice over text), or defaults to text. -2. Encode via the module's encoder. Emits `modality.encode_start` and `modality.output` bus events. -3. If the target channel has a `deliver` callback, call it with the encoded output. -4. If `blocking=True`, returns `EncodedOutput` directly. Otherwise queues via `OutputQueueManager` and returns a `QueuedJob`. +**Should be:** channel cleanup on disconnect should cancel all queued jobs +for that channel. -### hud() -- Agent Awareness +### 4. 
STT blocks the event loop context -```python -bus.hud() -> dict ``` - -Returns a live snapshot of all modules and channels: current status, active jobs, queue depths, recent events. Designed to be injected into the agent's context window so it knows what the body is doing. - -### Channels - -Channels declare which modalities they support. The bus auto-routes output based on channel capabilities. - -```python -bus.register_channel("discord-voice", [ModalityType.VOICE, ModalityType.TEXT], - deliver=send_to_discord) +_process_utterance(): + await asyncio.to_thread(_transcribe) # blocks a thread pool thread + → mlx_whisper.transcribe() # 1-2s CPU-bound + → blocks one thread pool slot ``` -### Bus Events - -Every boundary crossing is recorded as a `BusEvent` (type, modality, channel, timestamp, data). Listeners can subscribe via `bus.on_event(callback)` for ledger integration. The bus keeps the last 500 events in memory. - -## Current Modalities +This is fine for one user. But the thread pool is shared with bus.act(). -### Voice (modules/voice.py) +### 5. No separation between thinking and acting -| Component | Class | Implementation | -|-----------|-------|----------------| -| Gate | `VoiceGate` | Silero VAD via `vad.detect_speech()`. Threshold-configurable (default 0.5). Rejects audio with no detected speech. | -| Decoder | `WhisperDecoder` | `mlx_whisper` STT on Apple Silicon. Lazy-loads `mlx-community/whisper-turbo`. Applies `vad.is_hallucination()` filter to reject phantom transcripts. | -| Decoder (legacy) | `PlaceholderDecoder` | Accepts pre-transcribed text. Used by the MCP server for the `speak` tool path where text is already known. | -| Encoder | `VoiceEncoder` | Wraps `engine.synthesize()` (Kokoro, Voxtral, Chatterbox, Spark). Default voice: `bm_lewis` at 1.25x speed. Returns WAV bytes. | +The agent loop processes ONE event at a time (_processing flag). +If bus.act() blocks, no new events can be processed. +The agent should be able to think about the next input while +TTS is generating for the current one. -### Text (modules/text.py) +## Intended Architecture (what we should build toward) -| Component | Class | Implementation | -|-----------|-------|----------------| -| Gate | None | All text passes through. | -| Decoder | `TextDecoder` | Identity transform: `bytes.decode("utf-8")` -> `CognitiveEvent`. | -| Encoder | `TextEncoder` | Identity transform: `intent.content.encode("utf-8")` -> `EncodedOutput`. | - -Text exists so it is a first-class modality on the bus, not a special case. - -## Integration Points - -### MCP Server (server.py) - -The MCP server creates the bus singleton at module level: - -```python -_bus = _create_bus() # ModalityBus with VoiceModule(decoder=PlaceholderDecoder()) ``` - -MCP tools (`speak`, `diagnostics`, `vad_check`) use `_bus` for voice state tracking, health reports, and VAD. The `speak` tool resolves voices through the bus's voice module, sets encoder state, and uses the engine directly for synthesis (the adaptive player handles local playback). - -The `diagnostics` tool returns `_bus.health()` and `_bus.hud()`. 
- -### HTTP API (http_api.py) - -The HTTP API imports the bus singleton from the MCP server: - -```python -from server import _bus as _shared_bus # Shared instance when co-hosted -_bus = _shared_bus # Falls back to fresh ModalityBus if import fails +Browser ──WebSocket──► BrowserChannel + │ + ┌────▼────┐ + │ INPUT │ (fast, non-blocking) + │ QUEUE │ CognitiveEvents + └────┬────┘ + │ + ┌────▼────┐ + │ AGENT │ (owns conversation, calls LLM) + │ LOOP │ processes events sequentially + │ │ but NEVER blocks on output + └────┬────┘ + │ + tool calls (non-blocking) + │ + ┌────────────┼────────────┐ + │ │ │ + speak(text) send_text() think() + │ │ │ + ▼ ▼ │ + ┌──────────┐ ┌──────────┐ (log) + │ OUTPUT │ │ channel │ + │ QUEUE │ │ .deliver │ + │ (async) │ │ (JSON) │ + └────┬─────┘ └──────────┘ + │ + ┌────▼─────┐ + │ TTS │ (background thread) + │ Kokoro │ + └────┬─────┘ + │ + ch.deliver(base64 WAV) + │ + ▼ + Browser playback ``` -It ensures both Text and Voice modules are registered, then exposes the bus directly: - -| Endpoint | Bus Method | -|----------|------------| -| `GET /v1/bus/hud` | `_bus.hud()` | -| `GET /v1/bus/health` | `_bus.health()` | -| `POST /v1/bus/perceive` | `_bus.perceive(raw, modality, channel)` | -| `POST /v1/bus/act` | `_bus.act(intent, channel, blocking=True)` | -| `GET /health` | includes `_bus.health()` and `_bus.hud()` | - -When running with `--all`, both MCP and HTTP share the same bus instance and model cache. - -## Adding a New Modality - -1. **Create `modules/your_modality.py`** -- implement `Gate`, `Decoder`, `Encoder` (all optional), and a `ModalityModule` subclass that wires them together. See `modules/text.py` for the minimal case or `modules/voice.py` for the full pattern. - -2. **Add the modality type** to `ModalityType` in `modality.py` if needed. `VISION` and `SPATIAL` are already defined. - -3. **Register with the bus** where it is created (`server.py` and/or `http_api.py`): - ```python - bus.register(VisionModule()) - bus.register_channel("webcam-feed", [ModalityType.VISION]) - ``` - -4. **No routing changes needed.** The bus auto-routes `act()` based on channel capabilities. The HTTP API's `/v1/bus/perceive` and `/v1/bus/act` already accept any registered modality via the `modality` parameter. +Key principle: **the agent never waits for output delivery.** +speak() queues a TTS job and returns immediately. +The bus handles encoding and delivery asynchronously. 
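+
+A minimal sketch of that contract, assuming `bus.act(blocking=False)`
+returns a `QueuedJob` as bus.py already does (`dispatch_speak` is an
+illustrative helper, not existing code):
+
+```python
+from modality import CognitiveIntent, ModalityType
+
+def dispatch_speak(bus, channel_id: str, text: str):
+    """Sketch only: enqueue TTS and return without awaiting delivery."""
+    intent = CognitiveIntent(
+        modality=ModalityType.VOICE,
+        content=text,
+        target_channel=channel_id,
+        metadata={"voice": "bm_lewis", "speed": 1.25},
+    )
+    # blocking=False queues the job via OutputQueueManager and returns a
+    # QueuedJob handle; Kokoro encoding and delivery happen on the drain
+    # thread, so the agent loop is free to process the next event.
+    return bus.act(intent, blocking=False)
+```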
+ +## Files + +| File | Role | Lines | Status | +|------|------|-------|--------| +| `providers.py` | InferenceProvider: MLX, Ollama, CogOS | ~450 | Working | +| `channels.py` | BrowserChannel: WebSocket ↔ bus | ~260 | Working (fragile) | +| `agent_loop.py` | Event → LLM → tool dispatch | ~160 | Working (blocks on TTS) | +| `dashboard/index.html` | UI: chat, VAD, settings | ~700 | Working | +| `dashboard/transport.js` | WebSocket framing | ~100 | Working | +| `dashboard/playback.js` | Web Audio playback | ~113 | Working | +| `http_api.py` | WebSocket endpoint, static serving | +70 | Working | +| `server.py` | --dashboard startup mode | +12 | Working | +| `modules/voice.py` | VoiceGate, WhisperDecoder, VoiceEncoder | 309 | Working (not used for dashboard STT) | +| `bus.py` | ModalityBus: perceive/act, OutputQueue | 318 | Working | diff --git a/adaptive_player.py b/adaptive_player.py index 89b8062..7638437 100644 --- a/adaptive_player.py +++ b/adaptive_player.py @@ -263,13 +263,36 @@ def wait(self, timeout: float = 120.0) -> PlaybackMetrics: # Internal # ------------------------------------------------------------------ + def _resolve_device(self): + """Resolve the output device, falling back to system default if unavailable.""" + if self.device is None: + return None # sounddevice uses system default + + try: + devices = sd.query_devices() + if isinstance(self.device, int): + if self.device < len(devices): + info = devices[self.device] + if info["max_output_channels"] > 0: + return self.device + elif isinstance(self.device, str): + for i, d in enumerate(devices): + if self.device in d["name"] and d["max_output_channels"] > 0: + return i + except Exception: + pass + + # Device unavailable — fall back to system default. + return None + def _start_stream(self): self._stream_finished.clear() + resolved = self._resolve_device() self._stream = sd.OutputStream( samplerate=self.sample_rate, channels=1, dtype="float32", - device=self.device, + device=resolved, callback=self._callback, finished_callback=self._on_stream_finished, blocksize=self.buffer_size, diff --git a/agent_loop.py b/agent_loop.py index 449fcba..2f28aad 100644 --- a/agent_loop.py +++ b/agent_loop.py @@ -11,11 +11,12 @@ import logging import os import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import httpx from bus import ModalityBus +from draft_queue import DraftQueue from modality import CognitiveEvent, CognitiveIntent, ModalityType from pipeline_state import PipelineState from providers import AGENT_TOOLS, InferenceProvider @@ -82,10 +83,11 @@ def _fetch_kernel_context() -> str: interrupted = signal.get("interrupted") if interrupted: delivered = interrupted.get("delivered_text", "") + full = interrupted.get("full_text", "") pct = interrupted.get("spoken_pct", 0) parts.append( - f"[barge-in] Claude's speech was interrupted at {pct * 100:.0f}%. " - f'Delivered: "{delivered}". ' + f"[barge-in] Claude's speech was interrupted at {pct*100:.0f}%. " + f"Delivered: \"{delivered}\". " f"The user interrupted to say something — acknowledge and respond to them." 
) except Exception: @@ -122,7 +124,6 @@ def _log_exchange_to_bus(user_text: str, assistant_text: str, provider_name: str except Exception as e: logger.debug("Failed to log exchange to bus: %s", e) - MAX_HISTORY = 50 @@ -143,6 +144,9 @@ def __init__( self.conversation: list[dict[str, str]] = [] self._channel_ref: BrowserChannel | None = None self._processing = False + self.draft_queue = DraftQueue() + self._speculative_context: list[dict[str, str]] = [] # Context for speculative inference + self._human_speaking = False # Whether human is currently speaking async def handle_event(self, event: CognitiveEvent) -> None: """Called when a CognitiveEvent arrives from the channel.""" @@ -169,6 +173,13 @@ async def handle_event(self, event: CognitiveEvent) -> None: async def _process(self, event: CognitiveEvent) -> None: """Core: event → provider → tool dispatch.""" + # Context stitching: inject interrupt context from dashboard path + # This closes the barge-in loop — the agent knows what was spoken, + # what was unsaid, and what the user interrupted with. + interrupt_context = self._build_interrupt_context(event.content) + if interrupt_context: + self.conversation.append({"role": "system", "content": interrupt_context}) + self.conversation.append({"role": "user", "content": event.content}) self._trim_history() @@ -203,9 +214,7 @@ async def _process(self, event: CognitiveEvent) -> None: content=text, target_channel=self.channel_id, metadata={ - "voice": self._channel_ref.config.get("voice", "bm_lewis") - if self._channel_ref - else "bm_lewis", + "voice": self._channel_ref.config.get("voice", "bm_lewis") if self._channel_ref else "bm_lewis", "speed": self._channel_ref.config.get("speed", 1.25) if self._channel_ref else 1.25, }, ) @@ -240,12 +249,10 @@ async def _process(self, event: CognitiveEvent) -> None: # Update conversation history if assistant_parts: assistant_text = " ".join(assistant_parts) - self.conversation.append( - { - "role": "assistant", - "content": assistant_text, - } - ) + self.conversation.append({ + "role": "assistant", + "content": assistant_text, + }) # Log exchange to CogOS bus (observation channel — Claude can see this) _log_exchange_to_bus(event.content, assistant_text, self.provider.name) @@ -256,6 +263,278 @@ async def _process(self, event: CognitiveEvent) -> None: metrics={"llm_ms": round(t_llm, 1), "provider": self.provider.name} ) + async def speculative_infer(self, committed_text: str) -> None: + """D2: Speculative inference trigger. + + When T3 commits a sentence while the human is still speaking, + launch background inference with context-so-far. Store result + in the DraftQueue. Does NOT play — just buffers. 
+ """ + if not committed_text.strip(): + return + + logger.info("speculative_infer: '%s'", committed_text[:80]) + + # Build speculative conversation with committed text so far + spec_messages = list(self.conversation) + [ + {"role": "user", "content": committed_text}, + ] + + try: + t_start = time.perf_counter() + kernel_ctx = _fetch_kernel_context() + system_prompt = _BASE_SYSTEM_PROMPT + kernel_ctx + + response = await self.provider.chat( + messages=spec_messages, + tools=AGENT_TOOLS, + system=system_prompt, + ) + + t_ms = (time.perf_counter() - t_start) * 1000 + + # Extract response text + response_text = "" + for tc in response.tool_calls: + if tc.name == "speak": + response_text += tc.arguments.get("text", "") + " " + if not response_text and response.text: + response_text = response.text + + response_text = response_text.strip() + if not response_text: + return + + # Add to draft queue + import hashlib + ctx_hash = hashlib.md5(committed_text.encode()).hexdigest()[:8] + block = self.draft_queue.add_block( + text=response_text, + context_hash=ctx_hash, + generation_ms=t_ms, + ) + + logger.info( + "speculative block %s: '%s' (%.0fms)", + block.id, response_text[:60], t_ms, + ) + + # F2: Speculative TTS pre-synthesis + # Generate audio immediately but don't play + await self._presynthesise_block(block) + + # Notify dashboard of draft queue state + if self._channel_ref: + await self._channel_ref.ws.send_json({ + "type": "draft_queue", + "blocks": [b.to_dict() for b in self.draft_queue.get_pending()], + }) + + except Exception as e: + logger.debug("speculative_infer failed: %s", e) + + async def self_barge_snip(self, block_id: str) -> bool: + """E1: Remove a queued block that's no longer relevant.""" + result = self.draft_queue.snip(block_id) + if result: + logger.info("self-barge: snipped block %s", block_id) + await self._push_draft_queue_state() + return result + + async def self_barge_inject(self, position: int, text: str) -> None: + """E1: Insert a new block at position.""" + block = self.draft_queue.inject(position, text) + logger.info("self-barge: injected block %s at pos %d", block.id, position) + # Pre-synthesize the new block + await self._presynthesise_block(block) + await self._push_draft_queue_state() + + async def self_barge_revise(self, block_id: str, new_text: str) -> bool: + """E1: Replace a block's content and re-synthesize TTS.""" + result = self.draft_queue.revise(block_id, new_text) + if result: + logger.info("self-barge: revised block %s -> '%s'", block_id, new_text[:60]) + # Find the block and re-synthesize + for block in self.draft_queue.all_blocks: + if block.id == block_id: + await self._presynthesise_block(block) + break + await self._push_draft_queue_state() + return result + + async def _push_draft_queue_state(self) -> None: + """Push current draft queue state to the dashboard.""" + if self._channel_ref: + try: + await self._channel_ref.ws.send_json({ + "type": "draft_queue", + "blocks": [b.to_dict() for b in self.draft_queue.all_blocks], + }) + except Exception: + pass + + async def invalidate_stale_drafts(self, new_context: str) -> int: + """D3: Draft block invalidation. + + When a new T3 sentence arrives, check if existing draft blocks + are still valid given the updated context. Mark stale ones. + + Uses context hash comparison: if a block was generated with + different context than what we have now, it's potentially stale. + + Returns count of invalidated blocks. 
+ """ + import hashlib + + new_hash = hashlib.md5(new_context.encode()).hexdigest()[:8] + invalidated = 0 + + for block in self.draft_queue.get_pending(): + if block.context_hash and block.context_hash != new_hash: + self.draft_queue.invalidate(block.id) + invalidated += 1 + logger.info("invalidated stale draft block %s (context changed)", block.id) + + if invalidated > 0 and self._channel_ref: + try: + await self._channel_ref.ws.send_json({ + "type": "draft_queue", + "blocks": [b.to_dict() for b in self.draft_queue.all_blocks], + }) + except Exception: + pass + + return invalidated + + async def _presynthesise_block(self, block) -> None: + """F2: Pre-synthesize TTS audio for a draft block. + + Generates audio immediately and attaches it to the block. + Ready for instant playback when the human stops speaking. + """ + from modules.voice import VoiceEncoder, _encode_wav + + try: + voice = "bm_lewis" + speed = 1.25 + if self._channel_ref: + voice = self._channel_ref.config.get("voice", "bm_lewis") + speed = self._channel_ref.config.get("speed", 1.25) + + def _synth(): + from engine import synthesize + samples, sample_rate = synthesize( + block.text, + voice=voice, + speed=speed, + ) + wav_bytes = _encode_wav(samples, sample_rate) + duration = len(samples) / sample_rate + return wav_bytes, duration + + wav_bytes, duration = await asyncio.to_thread(_synth) + block.tts_audio = wav_bytes + block.tts_duration_sec = duration + logger.info("pre-synthesized block %s: %.1fs audio", block.id, duration) + + except Exception as e: + logger.debug("pre-synthesis failed for block %s: %s", block.id, e) + + async def background_validate_drafts(self, latest_user_text: str) -> None: + """E2: Background validation loop. + + After each new human sentence, re-evaluate all queued draft blocks. + Snips/revises if context has invalidated them. This runs between + TTS synthesis and playback — the revision window. + """ + pending = self.draft_queue.get_pending() + if not pending: + return + + logger.info("background_validate: checking %d pending blocks", len(pending)) + + # First, invalidate any blocks whose context is clearly stale + await self.invalidate_stale_drafts(latest_user_text) + + # Then re-evaluate remaining valid blocks + still_pending = self.draft_queue.get_pending() + if not still_pending: + return + + # Build context with latest human input + check_messages = list(self.conversation) + [ + {"role": "user", "content": latest_user_text}, + ] + + for block in still_pending: + try: + # Quick relevance check: ask the model if this block is still appropriate + check_prompt = ( + f"Given the user just said: \"{latest_user_text}\"\n" + f"Is this planned response still appropriate? " + f"Response: \"{block.text}\"\n" + f"Answer KEEP or REVISE in one word." + ) + + response = await self.provider.chat( + messages=[{"role": "user", "content": check_prompt}], + tools=[], + system="You are evaluating whether a planned response is still valid. Answer KEEP or REVISE.", + ) + + answer = (response.text or "").strip().upper() + if "REVISE" in answer: + logger.info("background_validate: block %s needs revision", block.id) + self.draft_queue.invalidate(block.id) + else: + logger.debug("background_validate: block %s still valid", block.id) + + except Exception as e: + logger.debug("background_validate error for block %s: %s", block.id, e) + + await self._push_draft_queue_state() + + def _build_interrupt_context(self, user_text: str) -> str | None: + """Build context stitch from pipeline_state.last_interrupt. 
+ + When the user barged in during TTS playback, captures what was + spoken vs unspoken and injects it as structured context for the + next inference call. Consumes the interrupt (clears it). + + Returns a context string, or None if no interrupt occurred. + """ + info = self.pipeline_state.last_interrupt + if info is None: + return None + + # Only use recent interrupts (within last 30 seconds) + if time.time() - info.timestamp > 30: + return None + + # Clear the interrupt so we don't re-inject it + with self.pipeline_state._lock: + self.pipeline_state._last_interrupt = None + + # Compute unspoken remainder + unspoken = "" + if info.full_text and info.delivered_text: + if info.full_text.startswith(info.delivered_text): + unspoken = info.full_text[len(info.delivered_text):].strip() + else: + # Fallback: everything after the delivered percentage + unspoken = info.full_text[len(info.delivered_text):].strip() + + parts = [] + parts.append("[Barge-in context — your previous response was interrupted]") + parts.append(f"spoken (user heard this): \"{info.delivered_text}\"") + if unspoken: + parts.append(f"unspoken (user did NOT hear this): \"{unspoken}\"") + parts.append(f"interrupted_at: {info.spoken_pct*100:.0f}%") + parts.append(f"user_said: \"{user_text}\"") + parts.append("Acknowledge what was interrupted and respond to the user's new input.") + + return "\n".join(parts) + def _trim_history(self) -> None: """Keep conversation within MAX_HISTORY messages.""" if len(self.conversation) > MAX_HISTORY: diff --git a/channels.py b/channels.py index d0ab8c8..8480953 100644 --- a/channels.py +++ b/channels.py @@ -3,6 +3,11 @@ Wraps a FastAPI WebSocket connection as a ChannelDescriptor on the bus. Knows the WebSocket protocol (binary PCM / JSON control frames), knows nothing about LLMs or agent logic. 
+ +Includes three-tier adaptive STT scheduler: + T1 (Whisper Base, ~31ms): per-chunk during speech + T2 (Whisper Large, ~470ms): on natural pause + T3 (Whisper Large, ~470ms): on end-of-utterance (final) """ from __future__ import annotations @@ -14,10 +19,12 @@ import uuid from typing import Any, Awaitable, Callable +import numpy as np from fastapi import WebSocket, WebSocketDisconnect from bus import ModalityBus from modality import CognitiveEvent, EncodedOutput, ModalityType +from modules.voice import WhisperDecoder from pipeline_state import PipelineState logger = logging.getLogger("mod3.channels") @@ -48,6 +55,16 @@ def __init__( self._audio_buffer = bytearray() self._active = True + # Three-tier STT state + self._streaming_decoder = WhisperDecoder(load_base=True) + self._streaming_audio = bytearray() # Growing buffer for streaming STT + self._last_t1_time = 0.0 # Last T1 transcription time + self._last_speech_time = 0.0 # Last time we received speech audio + self._t1_interval = 0.3 # Run T1 every 300ms + self._t2_pause_threshold = 0.6 # Run T2 after 600ms pause + self._is_speaking = False # Whether user is currently speaking + self._t2_scheduled = False # Whether T2 is already scheduled + # Register on the bus with a delivery callback bus.register_channel( self.channel_id, @@ -65,7 +82,9 @@ def _deliver_sync(self, output: EncodedOutput) -> None: if not self._active: return try: - future = asyncio.run_coroutine_threadsafe(self._deliver_async(output), self._loop) + future = asyncio.run_coroutine_threadsafe( + self._deliver_async(output), self._loop + ) future.result(timeout=10.0) except (WebSocketDisconnect, RuntimeError, TimeoutError): logger.debug("deliver failed (client disconnected?), deactivating channel") @@ -87,15 +106,13 @@ async def _deliver_async(self, output: EncodedOutput) -> None: # Send audio as base64 JSON (avoids binary frame issues) audio_b64 = base64.b64encode(output.data).decode("ascii") logger.info("deliver: sending base64 audio JSON (%d chars)", len(audio_b64)) - await self.ws.send_json( - { - "type": "audio", - "data": audio_b64, - "format": output.format or "wav", - "duration_sec": round(output.duration_sec, 2), - "sample_rate": output.metadata.get("sample_rate", 24000), - } - ) + await self.ws.send_json({ + "type": "audio", + "data": audio_b64, + "format": output.format or "wav", + "duration_sec": round(output.duration_sec, 2), + "sample_rate": output.metadata.get("sample_rate", 24000), + }) logger.info("deliver: audio sent OK") elif output.modality == ModalityType.TEXT: text = output.data.decode("utf-8") if isinstance(output.data, bytes) else str(output.data) @@ -128,8 +145,29 @@ async def run(self) -> None: self._cleanup() def _handle_audio(self, pcm_bytes: bytes) -> None: - """Binary frame: raw Int16 PCM at 16kHz from browser Silero VAD.""" + """Binary frame: raw Int16 PCM at 16kHz from browser Silero VAD. + + A5: Receives streaming audio during speech (from onFrameProcessed) + AND the final complete buffer (from onSpeechEnd). Both accumulate + for the final T3 utterance processing. + + During speech, audio also accumulates in _streaming_audio for T1/T2 + partial transcription. 
+ """ self._audio_buffer.extend(pcm_bytes) + self._streaming_audio.extend(pcm_bytes) + self._last_speech_time = time.monotonic() + self._is_speaking = True + + # T1: Fast Whisper Base transcription every _t1_interval + now = time.monotonic() + if now - self._last_t1_time >= self._t1_interval and len(self._streaming_audio) > 6400: + self._last_t1_time = now + asyncio.ensure_future(self._run_t1()) + + # Schedule T2 check on pause detection + if not self._t2_scheduled: + asyncio.ensure_future(self._schedule_t2_on_pause()) async def _handle_json(self, msg: dict) -> None: """JSON frame: control message dispatch.""" @@ -149,12 +187,92 @@ async def _handle_json(self, msg: dict) -> None: if key in msg: self.config[key] = msg[key] + # ------------------------------------------------------------------ + # Three-Tier STT + # ------------------------------------------------------------------ + + async def _run_t1(self) -> None: + """T1: Fast Whisper Base transcription on growing audio buffer (~31ms). + + Runs every ~300ms during speech. Emits partial_transcript with + confirmed/tentative text at 30% opacity. + """ + if not self._streaming_audio: + return + + pcm_data = bytes(self._streaming_audio) + + def _transcribe_t1(): + audio = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0 + if len(audio) < 4800: # <300ms + return None + return self._streaming_decoder.decode_streaming(audio, tier="t1") + + try: + result = await asyncio.to_thread(_transcribe_t1) + if result and result.get("changed") and not result.get("filtered"): + await self.ws.send_json({ + "type": "partial_transcript", + "confirmed": result["confirmed"], + "tentative": result["tentative"], + "tier": "t1", + "elapsed_ms": result["elapsed_ms"], + }) + except Exception as e: + logger.debug("T1 error: %s", e) + + async def _run_t2(self) -> None: + """T2: Large model transcription on natural pause (~470ms). + + Runs when speech pauses for >600ms but hasn't ended. Emits + partial_transcript with higher confidence (60% opacity). + """ + if not self._streaming_audio: + return + + pcm_data = bytes(self._streaming_audio) + + def _transcribe_t2(): + audio = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0 + if len(audio) < 8000: # <500ms + return None + return self._streaming_decoder.decode_streaming(audio, tier="t2") + + try: + result = await asyncio.to_thread(_transcribe_t2) + if result and not result.get("filtered"): + await self.ws.send_json({ + "type": "partial_transcript", + "confirmed": result["confirmed"], + "tentative": result["tentative"], + "tier": "t2", + "elapsed_ms": result["elapsed_ms"], + }) + except Exception as e: + logger.debug("T2 error: %s", e) + finally: + self._t2_scheduled = False + + async def _schedule_t2_on_pause(self) -> None: + """Check if speech has paused long enough for T2.""" + await asyncio.sleep(self._t2_pause_threshold) + if not self._is_speaking: + return + # Check if there's been a pause since last audio + silence = time.monotonic() - self._last_speech_time + if silence >= self._t2_pause_threshold and not self._t2_scheduled: + self._t2_scheduled = True + await self._run_t2() + # ------------------------------------------------------------------ # Processing # ------------------------------------------------------------------ async def _process_utterance(self) -> None: - """PCM audio buffer → WhisperDecoder STT → CognitiveEvent → agent loop. + """T3: PCM audio buffer → WhisperDecoder STT → CognitiveEvent → agent loop. + + This is the final tier — end-of-utterance. 
Uses the Large model for + maximum accuracy. Everything is confirmed (100% opacity). Skips the server-side VoiceGate (Silero VAD) because the browser already ran Silero VAD client-side — no need to validate again, @@ -163,6 +281,11 @@ async def _process_utterance(self) -> None: pcm_data = bytes(self._audio_buffer) self._audio_buffer.clear() + # Reset streaming state + self._streaming_audio.clear() + self._streaming_decoder.reset_streaming() + self._is_speaking = False + if len(pcm_data) < 6400: # <200ms at 16kHz Int16 return @@ -185,7 +308,7 @@ def _transcribe(): # Skip silence if len(audio) < 16000 * 0.3: return None - rms = float(np.sqrt(np.mean(audio**2))) + rms = float(np.sqrt(np.mean(audio ** 2))) if rms < 0.005: return None @@ -234,14 +357,12 @@ def _transcribe(): if event and event.content: # Send transcript to browser - await self.ws.send_json( - { - "type": "transcript", - "text": event.content, - "stt_ms": round(stt_ms, 1), - "source": "voice", - } - ) + await self.ws.send_json({ + "type": "transcript", + "text": event.content, + "stt_ms": round(stt_ms, 1), + "source": "voice", + }) # Forward to agent loop event.metadata["stt_ms"] = stt_ms if self._on_event: @@ -255,13 +376,11 @@ async def _process_text(self, text: str) -> None: source_channel=self.channel_id, confidence=1.0, ) - await self.ws.send_json( - { - "type": "transcript", - "text": text, - "source": "text", - } - ) + await self.ws.send_json({ + "type": "transcript", + "text": text, + "source": "text", + }) if self._on_event: await self._on_event(event) @@ -288,12 +407,10 @@ async def send_response_complete(self, metrics: dict | None = None) -> None: """Signal response is complete.""" if self._active: try: - await self.ws.send_json( - { - "type": "response_complete", - "metrics": metrics or {}, - } - ) + await self.ws.send_json({ + "type": "response_complete", + "metrics": metrics or {}, + }) except Exception: self._active = False diff --git a/dashboard/index.html b/dashboard/index.html index 6b1f9f8..914a2b0 100644 --- a/dashboard/index.html +++ b/dashboard/index.html @@ -106,6 +106,71 @@ /* Headphone hint */ .hint { font-size: 0.7rem; color: var(--muted); padding: 4px 0; flex-shrink: 0; } + /* Opacity-as-state: three visual states for text blocks */ + .opacity-inflight { opacity: 0.3; transition: opacity 0.3s ease; } + .opacity-corrected { opacity: 0.6; transition: opacity 0.3s ease; } + .opacity-committed { opacity: 1.0; transition: opacity 0.3s ease; } + + /* Word-level solidification spans */ + .voice-word { + display: inline; + transition: opacity 0.2s ease; + } + .voice-word.spoken { opacity: 1.0; } + .voice-word.speaking { opacity: 0.85; color: var(--accent); } + .voice-word.unspoken { opacity: 0.3; } + + /* Draft queue preview blocks */ + .draft-preview { + opacity: 0.3; + padding: 6px 12px; + margin-top: 4px; + border-left: 2px solid var(--accent); + border-radius: 4px; + font-size: 0.85rem; + color: var(--muted); + transition: opacity 0.3s ease; + } + .draft-preview.validated { opacity: 0.6; } + .draft-preview.stale { + opacity: 0.15; + text-decoration: line-through; + border-left-color: var(--orange); + } + .draft-preview.revised { + border-left-color: var(--green); + animation: revision-flash 0.6s ease; + } + @keyframes revision-flash { + 0% { opacity: 0.8; border-left-color: var(--green); } + 100% { opacity: 0.3; border-left-color: var(--accent); } + } + + /* Partial transcript (assembling) */ + .partial-transcript { + font-size: 0.85rem; + color: var(--muted); + padding: 4px 0; + min-height: 1.3em; + } + 
.partial-confirmed { opacity: 0.6; color: var(--text); } + .partial-tentative { opacity: 0.3; color: var(--muted); font-style: italic; } + + /* Interruption marker */ + .interrupt-marker { + display: inline-block; + width: 2px; + height: 1em; + background: var(--red); + margin: 0 2px; + vertical-align: middle; + animation: blink 1s ease-in-out 3; + } + @keyframes blink { + 0%, 100% { opacity: 1; } + 50% { opacity: 0.3; } + } + /* Responsive */ @media (max-width: 700px) { .main { padding: 12px 16px; } @@ -214,9 +279,15 @@


+      <!-- Streaming STT: partial transcript (confirmed / tentative) -->
+      <div id="partial-transcript" class="partial-transcript" style="display: none"></div>
+
+      <!-- Speculative generation: draft queue preview -->
+      <div id="draft-queue-preview" style="display: none"></div>
+
@@ -379,6 +450,49 @@


} let _currentAssistantMsg = null; +let _currentAssistantText = ''; // Full text of current response for word tracking + +/** + * C2: Word-level timing estimation. + * Given total audio duration and full text, estimates the time position + * of each word assuming uniform speech rate. + * + * Returns array of { word, startSec, endSec, index } + */ +function estimateWordTimings(text, totalDurationSec) { + const words = text.split(/\s+/).filter(w => w.length > 0); + if (words.length === 0) return []; + + // Weight by character count (longer words take more time) + const totalChars = words.reduce((sum, w) => sum + w.length, 0); + let currentTime = 0; + + return words.map((word, index) => { + const fraction = word.length / totalChars; + const duration = fraction * totalDurationSec; + const entry = { + word, + startSec: currentTime, + endSec: currentTime + duration, + index, + }; + currentTime += duration; + return entry; + }); +} + +/** + * Get the word index being spoken at a given playback time. + */ +function getWordAtTime(timings, currentTimeSec) { + for (let i = timings.length - 1; i >= 0; i--) { + if (currentTimeSec >= timings[i].startSec) return i; + } + return 0; +} + +let _wordTimings = []; // Current word timings for solidification +let _solidificationActive = false; function appendAssistantStart(source) { const div = document.createElement('div'); @@ -393,6 +507,7 @@


chatArea.appendChild(div); chatArea.scrollTop = chatArea.scrollHeight; _currentAssistantMsg = div.querySelector('.msg-text'); + _currentAssistantText = ''; return div; } @@ -401,15 +516,72 @@


if (!_currentAssistantMsg) { appendAssistantStart(); } - _currentAssistantMsg.innerHTML += escapeHtml(text); + _currentAssistantText += text; + + // C4: Wrap each word in a span for solidification animation + const words = text.split(/(\s+)/); + words.forEach(segment => { + if (/^\s+$/.test(segment)) { + _currentAssistantMsg.appendChild(document.createTextNode(segment)); + } else if (segment) { + const span = document.createElement('span'); + span.className = 'voice-word unspoken'; + span.textContent = segment; + _currentAssistantMsg.appendChild(span); + } + }); + chatArea.scrollTop = chatArea.scrollHeight; } +function startSolidification(totalDurationSec) { + /** + * C4: Begin solidification animation — words solidify left-to-right + * tracking the audio playback position. Uses C1 progress + C2 timing. + */ + if (!_currentAssistantMsg || !_currentAssistantText) return; + + _wordTimings = estimateWordTimings(_currentAssistantText, totalDurationSec); + _solidificationActive = true; +} + +// Listen for playback progress to drive solidification +window.addEventListener('playback-progress', (e) => { + if (!_solidificationActive || !_currentAssistantMsg) return; + + const { progress } = e.detail; + const currentTime = progress * (_wordTimings.length > 0 ? + _wordTimings[_wordTimings.length - 1].endSec : 0); + + const wordSpans = _currentAssistantMsg.querySelectorAll('.voice-word'); + const currentWordIdx = getWordAtTime(_wordTimings, currentTime); + + wordSpans.forEach((span, i) => { + if (i < currentWordIdx) { + span.className = 'voice-word spoken'; + } else if (i === currentWordIdx) { + span.className = 'voice-word speaking'; + } else { + span.className = 'voice-word unspoken'; + } + }); +}); + function finalizeAssistant(suffix) { - if (_currentAssistantMsg && suffix) { - _currentAssistantMsg.innerHTML += ` ${escapeHtml(suffix)}`; + _solidificationActive = false; + _wordTimings = []; + + if (_currentAssistantMsg) { + // Mark all words as spoken (committed) + _currentAssistantMsg.querySelectorAll('.voice-word').forEach(span => { + span.className = 'voice-word spoken'; + }); + if (suffix) { + _currentAssistantMsg.innerHTML += ` ${escapeHtml(suffix)}`; + } } _currentAssistantMsg = null; + _currentAssistantText = ''; } async function sendTextMessage(text) { @@ -471,19 +643,51 @@


} }; + // Wire up playback progress for solidification + pb.onProgress = (samplesPlayed, totalSamples) => { + // Will be used by C4 solidification animation (Wave 2) + const pct = totalSamples > 0 ? samplesPlayed / totalSamples : 0; + // Emit custom event for word-level tracking + window.dispatchEvent(new CustomEvent('playback-progress', { + detail: { samplesPlayed, totalSamples, progress: pct } + })); + }; + const t = new VoiceTransport(wsUrl, { onAudio: (data) => { pb.enqueueWav(data); + // Start solidification when first audio arrives + if (_currentAssistantText && !_solidificationActive) { + // Estimate total duration from accumulated audio + // (will refine as more chunks arrive) + startSolidification(pb.totalDuration || 2.0); + } const el = document.getElementById('voice-status'); if (el) { el.textContent = 'Speaking...'; el.style.color = 'var(--accent)'; } }, onTranscript: (msg) => { if (msg.source !== 'text') appendMessage('user', msg.text, 'voice'); - // Don't create assistant bubble here — appendAssistantChunk lazy-creates it + // Hide partial transcript when final arrives + const ptEl = document.getElementById('partial-transcript'); + if (ptEl) { ptEl.style.display = 'none'; ptEl.innerHTML = ''; } const el = document.getElementById('voice-status'); if (el) { el.textContent = 'Thinking...'; el.style.color = 'var(--orange)'; } }, - onResponseText: (msg) => appendAssistantChunk(msg.text), + onPartialTranscript: (msg) => { + // Show partial transcript with confirmed/tentative styling + const ptEl = document.getElementById('partial-transcript'); + if (ptEl) { + ptEl.style.display = ''; + let html = ''; + if (msg.confirmed) html += `${escapeHtml(msg.confirmed)} `; + if (msg.tentative) html += `${escapeHtml(msg.tentative)}`; + ptEl.innerHTML = html || '...'; + } + }, + onResponseText: (msg) => { + pb.resetProgress(); // Reset progress for new response + appendAssistantChunk(msg.text); + }, onResponseComplete: (msg) => { finalizeAssistant(); if (msg.metrics) { @@ -500,6 +704,67 @@


pb.flush(); finalizeAssistant('[interrupted]'); }, + onDraftQueue: (msg) => { + /** + * E3 + D4: Self-barge visual feedback + Queue preview UI. + * + * When the agent revises queued output, the transparent text visibly + * changes in the dashboard. Each block gets a data-block-id attribute + * for targeted updates. Revised blocks flash briefly to draw attention. + */ + const dqEl = document.getElementById('draft-queue-preview'); + if (!dqEl) return; + if (msg.blocks && msg.blocks.length > 0) { + dqEl.style.display = ''; + + msg.blocks.forEach(b => { + const existing = dqEl.querySelector(`[data-block-id="${b.id}"]`); + if (existing) { + // Block exists — check if content changed (self-barge revision) + const oldText = existing.getAttribute('data-text') || ''; + if (oldText !== b.text) { + // E3: Content changed — flash animation to show revision + existing.innerHTML = escapeHtml(b.text); + existing.setAttribute('data-text', b.text); + existing.style.transition = 'none'; + existing.style.borderLeftColor = 'var(--green)'; + existing.style.opacity = '0.8'; + requestAnimationFrame(() => { + existing.style.transition = 'opacity 0.5s ease, border-left-color 0.5s ease'; + existing.style.opacity = b.status === 'stale' ? '0.15' : '0.3'; + existing.style.borderLeftColor = b.status === 'stale' ? 'var(--orange)' : 'var(--accent)'; + }); + } + // Update status class + existing.className = b.status === 'stale' ? 'draft-preview stale' : + b.status === 'valid' ? 'draft-preview' : 'draft-preview validated'; + } else { + // New block — create element + const div = document.createElement('div'); + div.className = b.status === 'stale' ? 'draft-preview stale' : + b.status === 'valid' ? 'draft-preview' : 'draft-preview validated'; + div.setAttribute('data-block-id', b.id); + div.setAttribute('data-text', b.text); + div.textContent = b.text; + dqEl.appendChild(div); + } + }); + + // Remove blocks that no longer exist + const currentIds = new Set(msg.blocks.map(b => b.id)); + dqEl.querySelectorAll('[data-block-id]').forEach(el => { + if (!currentIds.has(el.getAttribute('data-block-id'))) { + // Snipped block — fade out + el.style.transition = 'opacity 0.3s ease'; + el.style.opacity = '0'; + setTimeout(() => el.remove(), 300); + } + }); + } else { + dqEl.style.display = 'none'; + dqEl.innerHTML = ''; + } + }, onMetrics: (msg) => { if (msg.sample_rate && msg.sample_rate !== pb.sampleRate) pb.setSampleRate(msg.sample_rate); }, @@ -632,7 +897,27 @@


preSpeechPadFrames: 8, onSpeechStart: () => { console.log('[Silero VAD] Speech START'); + // B3: Barge-in visual state if (_playback && _playback.isPlaying) { + // Freeze unspoken words at 30% opacity with interrupt marker + if (_currentAssistantMsg && _solidificationActive) { + const wordSpans = _currentAssistantMsg.querySelectorAll('.voice-word'); + let interrupted = false; + wordSpans.forEach(span => { + if (span.classList.contains('unspoken') || span.classList.contains('speaking')) { + span.classList.remove('speaking'); + span.classList.add('unspoken'); + if (!interrupted) { + // Insert interrupt marker before the first unspoken word + const marker = document.createElement('span'); + marker.className = 'interrupt-marker'; + span.parentNode.insertBefore(marker, span); + interrupted = true; + } + } + }); + _solidificationActive = false; + } _playback.flush(); if (_transport) _transport.interrupt(); } @@ -653,9 +938,19 @@


onVADMisfire: () => { if (micDebug) micDebug.textContent = 'VAD: misfire (too short)'; }, - onFrameProcessed: (probs) => { + onFrameProcessed: (probs, audioFrame) => { if (levelBar) levelBar.style.width = Math.min(100, probs.isSpeech * 100) + '%'; if (micDebug) micDebug.textContent = `silero: ${probs.isSpeech.toFixed(3)} thr=${vadThreshold}`; + + // A5: Stream audio chunks during speech for server-side streaming STT + // Send frames when speech is detected (server accumulates for T1/T2) + if (probs.isSpeech > vadThreshold && audioFrame && _transport && _transport.connected) { + const int16 = new Int16Array(audioFrame.length); + for (let i = 0; i < audioFrame.length; i++) { + int16[i] = Math.max(-32768, Math.min(32767, audioFrame[i] * 32768)); + } + _transport.sendAudio(int16.buffer); + } }, }); capture.start(); diff --git a/dashboard/playback.js b/dashboard/playback.js index 82be279..9849b97 100644 --- a/dashboard/playback.js +++ b/dashboard/playback.js @@ -1,6 +1,8 @@ /** * Streaming audio playback engine. * Receives Int16 PCM chunks and plays them seamlessly via Web Audio API. + * Tracks playback progress (samplesPlayed/totalSamples) for word-level + * solidification animation. */ class AudioPlayback { constructor(sampleRate = 24000) { @@ -12,7 +14,32 @@ class AudioPlayback { this.nextStartTime = 0; this.onPlaybackStart = null; this.onPlaybackEnd = null; + this.onProgress = null; // (samplesPlayed, totalSamples) => void this.sinkId = undefined; // output device ID + + // Progress tracking + this.totalSamples = 0; // Total samples across all queued buffers + this.samplesPlayed = 0; // Samples played so far + this._chunkStartSample = 0; // Sample offset of current chunk + this._currentChunkSamples = 0; + this._playbackStartTime = 0; // audioContext.currentTime when chunk started + this._progressTimer = null; + } + + /** Current playback progress as 0.0-1.0 */ + get progress() { + if (this.totalSamples === 0) return 0; + return Math.min(1.0, this.samplesPlayed / this.totalSamples); + } + + /** Estimated current playback time in seconds */ + get currentTime() { + return this.samplesPlayed / this.sampleRate; + } + + /** Total duration in seconds of all queued audio */ + get totalDuration() { + return this.totalSamples / this.sampleRate; } _ensureContext() { @@ -48,6 +75,7 @@ class AudioPlayback { const buffer = this.audioContext.createBuffer(1, float32.length, this.sampleRate); buffer.getChannelData(0).set(float32); this.queue.push(buffer); + this.totalSamples += float32.length; if (!this.isPlaying) this._playNext(); } @@ -58,6 +86,7 @@ class AudioPlayback { try { const audioBuffer = await this.audioContext.decodeAudioData(wavArrayBuffer.slice(0)); this.queue.push(audioBuffer); + this.totalSamples += audioBuffer.length; if (!this.isPlaying) this._playNext(); } catch (err) { console.error("[AudioPlayback] Failed to decode WAV:", err); @@ -67,6 +96,9 @@ class AudioPlayback { _playNext() { if (this.queue.length === 0) { this.isPlaying = false; + this._stopProgressTimer(); + this.samplesPlayed = this.totalSamples; // Mark fully played + if (this.onProgress) this.onProgress(this.samplesPlayed, this.totalSamples); if (this.onPlaybackEnd) this.onPlaybackEnd(); return; } @@ -81,22 +113,65 @@ class AudioPlayback { const source = this.audioContext.createBufferSource(); source.buffer = buffer; source.connect(this.audioContext.destination); - source.onended = () => this._playNext(); + + // Track progress for this chunk + this._chunkStartSample = this.samplesPlayed; + this._currentChunkSamples = 
buffer.length; + + source.onended = () => { + // Mark chunk as fully played + this.samplesPlayed = this._chunkStartSample + this._currentChunkSamples; + if (this.onProgress) this.onProgress(this.samplesPlayed, this.totalSamples); + this._playNext(); + }; // Schedule this chunk right after the previous one for gapless playback const startTime = Math.max(this.nextStartTime, this.audioContext.currentTime); source.start(startTime); + this._playbackStartTime = startTime; this.nextStartTime = startTime + buffer.duration; this.currentSource = source; + + // Start progress timer for smooth updates during playback + this._startProgressTimer(); + } + + _startProgressTimer() { + this._stopProgressTimer(); + this._progressTimer = setInterval(() => { + if (!this.isPlaying || !this.audioContext) return; + const elapsed = this.audioContext.currentTime - this._playbackStartTime; + const chunkProgress = Math.min(elapsed * this.sampleRate, this._currentChunkSamples); + this.samplesPlayed = this._chunkStartSample + Math.floor(chunkProgress); + if (this.onProgress) this.onProgress(this.samplesPlayed, this.totalSamples); + }, 50); // 20 fps progress updates + } + + _stopProgressTimer() { + if (this._progressTimer) { + clearInterval(this._progressTimer); + this._progressTimer = null; + } } flush() { this.queue = []; + this._stopProgressTimer(); if (this.currentSource) { try { this.currentSource.stop(); } catch {} } this.isPlaying = false; this.nextStartTime = 0; + // Keep samplesPlayed/totalSamples for interrupt context + // (tells us how much was delivered before flush) + } + + /** Reset all progress counters (call when starting a new response) */ + resetProgress() { + this.totalSamples = 0; + this.samplesPlayed = 0; + this._chunkStartSample = 0; + this._currentChunkSamples = 0; } setSampleRate(rate) { diff --git a/dashboard/transport.js b/dashboard/transport.js index d0d02ad..ee8dc5c 100644 --- a/dashboard/transport.js +++ b/dashboard/transport.js @@ -63,9 +63,12 @@ class VoiceTransport { const handlerMap = { transcript: "onTranscript", + partial_transcript: "onPartialTranscript", response_text: "onResponseText", response_complete: "onResponseComplete", interrupted: "onInterrupted", + tts_progress: "onTtsProgress", + draft_queue: "onDraftQueue", metrics: "onMetrics", error: "onError", }; diff --git a/draft_queue.py b/draft_queue.py new file mode 100644 index 0000000..dccbc15 --- /dev/null +++ b/draft_queue.py @@ -0,0 +1,267 @@ +"""Draft Queue — speculative response blocks with status tracking. + +Holds draft response blocks generated speculatively while the human is +still speaking. Each block has a status lifecycle: + + valid → spoken (played aloud) + valid → stale (invalidated by new context) + valid → snipped (removed from queue by self-barge) + +Thread-safe. Used by the agent loop for speculative inference and +self-barge operations (snip, inject, revise). 
+""" + +from __future__ import annotations + +import threading +import time +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class BlockStatus(Enum): + """Lifecycle states for a draft block.""" + VALID = "valid" # Generated, awaiting playback + STALE = "stale" # Invalidated by new context + SPOKEN = "spoken" # Successfully played aloud + SNIPPED = "snipped" # Removed by self-barge + SPEAKING = "speaking" # Currently being spoken + + +@dataclass +class DraftBlock: + """A single draft response block with metadata.""" + + id: str + text: str + status: BlockStatus = BlockStatus.VALID + created_at: float = field(default_factory=time.time) + context_hash: str = "" # Hash of context at generation time + generation_ms: float = 0.0 # How long inference took + tts_audio: bytes | None = None # Pre-synthesized audio (if available) + tts_duration_sec: float = 0.0 + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def is_playable(self) -> bool: + """Whether this block can be played.""" + return self.status == BlockStatus.VALID + + @property + def is_active(self) -> bool: + """Whether this block is still relevant (not stale/snipped).""" + return self.status in (BlockStatus.VALID, BlockStatus.SPEAKING) + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "text": self.text, + "status": self.status.value, + "created_at": self.created_at, + "has_audio": self.tts_audio is not None, + "tts_duration_sec": self.tts_duration_sec, + "generation_ms": self.generation_ms, + } + + +class DraftQueue: + """Thread-safe queue of speculative draft response blocks. + + The agent generates blocks speculatively while the human speaks. + Blocks are played in order when the human stops. Blocks can be + invalidated (stale), removed (snip), or replaced (revise) before + they're spoken. + + Operations: + add_block — append a new draft block + invalidate — mark a block as stale (context changed) + snip — remove a block from the queue + inject — insert a new block at a position + revise — replace a block's text (and optionally audio) + get_pending — get all valid blocks awaiting playback + mark_speaking — mark a block as currently being spoken + mark_spoken — mark a block as successfully spoken + clear — reset the queue + """ + + def __init__(self): + self._lock = threading.Lock() + self._blocks: list[DraftBlock] = [] + self._spoken_history: list[DraftBlock] = [] # Archive of spoken blocks + + # ------------------------------------------------------------------ + # Core operations + # ------------------------------------------------------------------ + + def add_block( + self, + text: str, + context_hash: str = "", + generation_ms: float = 0.0, + **metadata, + ) -> DraftBlock: + """Add a new draft block to the end of the queue.""" + block = DraftBlock( + id=uuid.uuid4().hex[:8], + text=text, + context_hash=context_hash, + generation_ms=generation_ms, + metadata=metadata, + ) + with self._lock: + self._blocks.append(block) + return block + + def invalidate(self, block_id: str) -> bool: + """Mark a block as stale. Returns True if found and invalidated.""" + with self._lock: + for block in self._blocks: + if block.id == block_id and block.is_active: + block.status = BlockStatus.STALE + return True + return False + + def invalidate_all(self) -> int: + """Mark all valid blocks as stale. 
Returns count invalidated.""" + count = 0 + with self._lock: + for block in self._blocks: + if block.status == BlockStatus.VALID: + block.status = BlockStatus.STALE + count += 1 + return count + + def snip(self, block_id: str) -> bool: + """Remove a block from the queue. Returns True if found.""" + with self._lock: + for i, block in enumerate(self._blocks): + if block.id == block_id: + block.status = BlockStatus.SNIPPED + self._blocks.pop(i) + return True + return False + + def inject( + self, + position: int, + text: str, + context_hash: str = "", + generation_ms: float = 0.0, + **metadata, + ) -> DraftBlock: + """Insert a new block at the given position.""" + block = DraftBlock( + id=uuid.uuid4().hex[:8], + text=text, + context_hash=context_hash, + generation_ms=generation_ms, + metadata=metadata, + ) + with self._lock: + self._blocks.insert(position, block) + return block + + def revise( + self, + block_id: str, + new_text: str, + new_audio: bytes | None = None, + new_duration: float = 0.0, + ) -> bool: + """Replace a block's content. Returns True if found and revised.""" + with self._lock: + for block in self._blocks: + if block.id == block_id and block.is_active: + block.text = new_text + if new_audio is not None: + block.tts_audio = new_audio + block.tts_duration_sec = new_duration + block.metadata["revised_at"] = time.time() + return True + return False + + # ------------------------------------------------------------------ + # Playback lifecycle + # ------------------------------------------------------------------ + + def get_pending(self) -> list[DraftBlock]: + """Get all valid blocks awaiting playback, in order.""" + with self._lock: + return [b for b in self._blocks if b.status == BlockStatus.VALID] + + def get_next(self) -> DraftBlock | None: + """Get the next valid block to play, or None.""" + with self._lock: + for block in self._blocks: + if block.status == BlockStatus.VALID: + return block + return None + + def mark_speaking(self, block_id: str) -> bool: + """Mark a block as currently being spoken.""" + with self._lock: + for block in self._blocks: + if block.id == block_id: + block.status = BlockStatus.SPEAKING + return True + return False + + def mark_spoken(self, block_id: str) -> bool: + """Mark a block as successfully spoken and archive it.""" + with self._lock: + for i, block in enumerate(self._blocks): + if block.id == block_id: + block.status = BlockStatus.SPOKEN + self._spoken_history.append(block) + self._blocks.pop(i) + return True + return False + + # ------------------------------------------------------------------ + # Query + # ------------------------------------------------------------------ + + @property + def depth(self) -> int: + """Number of blocks in the queue (all statuses).""" + with self._lock: + return len(self._blocks) + + @property + def pending_count(self) -> int: + """Number of valid (playable) blocks.""" + with self._lock: + return sum(1 for b in self._blocks if b.status == BlockStatus.VALID) + + @property + def all_blocks(self) -> list[DraftBlock]: + """Snapshot of all blocks in current queue.""" + with self._lock: + return list(self._blocks) + + @property + def spoken_text(self) -> str: + """All text that has been successfully spoken.""" + with self._lock: + return " ".join(b.text for b in self._spoken_history) + + def clear(self) -> int: + """Clear the queue. 
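+
+        Drops blocks regardless of status; the spoken-history archive is
+        left intact.
+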
Returns number of blocks removed.""" + with self._lock: + count = len(self._blocks) + self._blocks.clear() + return count + + def status(self) -> dict[str, Any]: + """Queue status snapshot.""" + with self._lock: + return { + "total": len(self._blocks), + "valid": sum(1 for b in self._blocks if b.status == BlockStatus.VALID), + "stale": sum(1 for b in self._blocks if b.status == BlockStatus.STALE), + "speaking": sum(1 for b in self._blocks if b.status == BlockStatus.SPEAKING), + "spoken_total": len(self._spoken_history), + "blocks": [b.to_dict() for b in self._blocks], + } diff --git a/http_api.py b/http_api.py index 45a7fb8..981d7fc 100644 --- a/http_api.py +++ b/http_api.py @@ -29,7 +29,7 @@ from threading import Lock from typing import Optional -from fastapi import FastAPI, Request, Response, UploadFile, WebSocket +from fastapi import FastAPI, Request, Response, UploadFile, WebSocket, WebSocketDisconnect from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field @@ -58,7 +58,6 @@ async def _warmup_kokoro(): def _do_warmup(): try: from engine import get_model - get_model("kokoro") logger.info("Kokoro TTS engine pre-warmed successfully") except Exception as e: @@ -572,7 +571,6 @@ def stop_speech(job_id: str = ""): """ try: from server import _speech_queue, pipeline_state - if job_id: cancelled = _speech_queue.cancel(job_id) return {"status": "ok", "message": f"Cancelled {job_id}" if cancelled else f"Job {job_id} not found"} @@ -679,16 +677,24 @@ async def _graceful_exit(): deadline = time.time() + timeout_sec while time.time() < deadline: with _jobs_lock: - active = sum(1 for j in _jobs.values() if j.get("status") in ("generating", "processing")) + active = sum( + 1 for j in _jobs.values() + if j.get("status") in ("generating", "processing") + ) if active == 0: break await asyncio.sleep(0.25) with _jobs_lock: - remaining = sum(1 for j in _jobs.values() if j.get("status") in ("generating", "processing")) + remaining = sum( + 1 for j in _jobs.values() + if j.get("status") in ("generating", "processing") + ) if remaining: - logger.warning("Shutdown timeout reached with %d active jobs — forcing exit", remaining) + logger.warning( + "Shutdown timeout reached with %d active jobs — forcing exit", remaining + ) else: logger.info("All jobs drained — exiting cleanly") diff --git a/mcp.channel.json b/mcp.channel.json index 53ce5bf..ccd2bf4 100644 --- a/mcp.channel.json +++ b/mcp.channel.json @@ -2,7 +2,7 @@ "mcpServers": { "mod3-voice": { "command": "python3", - "args": ["${MOD3_ROOT}/server.py", "--channel"] + "args": ["/Users/slowbro/workspaces/mod3/server.py", "--channel"] } } } diff --git a/mcp_shim.py b/mcp_shim.py new file mode 100644 index 0000000..c99ee73 --- /dev/null +++ b/mcp_shim.py @@ -0,0 +1,742 @@ +#!/usr/bin/env python3 +"""Mod³ MCP shim — thin stdio proxy to a running Mod³ HTTP service. + +Instead of spawning a full server.py (which loads TTS models, ~4GB VRAM), +this shim implements the MCP stdio protocol and forwards tool calls to +the Mod³ HTTP API at localhost:7860. + +Tools that are purely local (set_output_device, await_voice_input) are +handled in-process without touching the HTTP service. + +For `speak`, the shim posts to /v1/synthesize for audio generation, then +plays the returned WAV bytes locally via sounddevice. 
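+
+The wire protocol is newline-delimited JSON-RPC 2.0 over stdio. A typical
+exchange looks like this (values illustrative):
+
+    → {"jsonrpc": "2.0", "id": 1, "method": "initialize", ...}
+    ← {"jsonrpc": "2.0", "id": 1, "result": {"protocolVersion": "2024-11-05", ...}}
+    → {"jsonrpc": "2.0", "id": 2, "method": "tools/call",
+       "params": {"name": "speak", "arguments": {"text": "hello"}}}
+    ← {"jsonrpc": "2.0", "id": 2, "result": {"content": [{"type": "text",
+       "text": "{\"status\": \"speaking\", \"job_id\": \"shim-...\"}"}]}}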
+ +Usage: + python mcp_shim.py # normal MCP stdio mode + python mcp_shim.py --test # connectivity check, then exit +""" + +import io +import json +import logging +import os +import struct +import sys +import threading +import time +import urllib.error +import urllib.request +import wave +from collections import OrderedDict +from typing import Any + +logger = logging.getLogger("mod3.shim") + +MOD3_BASE = os.environ.get("MOD3_URL", "http://localhost:7860") + +# --------------------------------------------------------------------------- +# Lightweight audio playback (only needs sounddevice, not full TTS stack) +# --------------------------------------------------------------------------- + +_output_device: Any = None +_current_player_lock = threading.Lock() +_current_sd_stream = None +_playback_interrupt = threading.Event() + +# Job tracking (lightweight — just for speak/stop/status) +_jobs: OrderedDict = OrderedDict() +_jobs_lock = threading.Lock() +_MAX_JOBS = 50 + +# Barge-in signal file (same as server.py) +_BARGEIN_SIGNAL = os.path.expanduser("~/.mod3_bargein_signal.json") + + +def _http_request(method: str, path: str, body: dict | None = None, + timeout: float = 30.0) -> tuple[int, dict | bytes]: + """Make an HTTP request to the Mod3 service. Returns (status_code, parsed_json_or_bytes).""" + url = f"{MOD3_BASE}{path}" + headers = {"Content-Type": "application/json"} if body is not None else {} + data = json.dumps(body).encode() if body is not None else None + + req = urllib.request.Request(url, data=data, headers=headers, method=method) + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + content_type = resp.headers.get("Content-Type", "") + raw = resp.read() + if "application/json" in content_type: + return resp.status, json.loads(raw) + elif "audio/" in content_type: + return resp.status, raw + else: + try: + return resp.status, json.loads(raw) + except (json.JSONDecodeError, ValueError): + return resp.status, raw + except urllib.error.HTTPError as e: + try: + body_bytes = e.read() + return e.code, json.loads(body_bytes) + except Exception: + return e.code, {"error": str(e)} + except urllib.error.URLError as e: + return 0, {"error": f"Mod3 service unreachable: {e.reason}"} + except Exception as e: + return 0, {"error": f"Request failed: {e}"} + + +def _play_wav_bytes(wav_bytes: bytes, job_id: str): + """Play WAV audio bytes through speakers via sounddevice.""" + global _current_sd_stream + try: + import numpy as np + import sounddevice as sd + except ImportError: + logger.error("sounddevice/numpy not available — cannot play audio") + with _jobs_lock: + if job_id in _jobs: + _jobs[job_id]["status"] = "error" + _jobs[job_id]["error"] = "sounddevice not installed" + return + + try: + buf = io.BytesIO(wav_bytes) + with wave.open(buf, "rb") as wf: + sr = wf.getframerate() + ch = wf.getnchannels() + sw = wf.getsampwidth() + frames = wf.readframes(wf.getnframes()) + + if sw == 2: + audio = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32767.0 + elif sw == 4: + audio = np.frombuffer(frames, dtype=np.int32).astype(np.float32) / 2147483647.0 + else: + audio = np.frombuffer(frames, dtype=np.float32) + + if ch > 1: + audio = audio.reshape(-1, ch)[:, 0] # mono mixdown + + duration = len(audio) / sr + with _jobs_lock: + if job_id in _jobs: + _jobs[job_id]["status"] = "speaking" + _jobs[job_id]["start_time"] = time.time() + _jobs[job_id]["duration_sec"] = round(duration, 2) + + _playback_interrupt.clear() + device = _output_device + with _current_player_lock: + 
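+            # Record which job currently owns the output stream; cleared
+            # again when playback finishes or errors out.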
_current_sd_stream = job_id + + sd.play(audio, samplerate=sr, device=device, blocking=True) + + with _current_player_lock: + _current_sd_stream = None + + if not _playback_interrupt.is_set(): + with _jobs_lock: + if job_id in _jobs: + _jobs[job_id]["status"] = "done" + _jobs[job_id]["metrics"] = { + "audio_duration_sec": round(duration, 2), + "sample_rate": sr, + } + except Exception as e: + logger.error("Playback error: %s", e) + with _current_player_lock: + _current_sd_stream = None + with _jobs_lock: + if job_id in _jobs: + _jobs[job_id]["status"] = "error" + _jobs[job_id]["error"] = str(e) + + +def _estimate_duration(text: str, speed: float) -> float: + words = len(text.split()) + base_wpm = 160 * speed + return (words / base_wpm) * 60 + + +# --------------------------------------------------------------------------- +# Tool implementations +# --------------------------------------------------------------------------- + +def tool_speak(text: str, voice: str = "bm_lewis", stream: bool = True, + speed: float = 1.25, emotion: float = 0.5) -> str: + """Synthesize via HTTP, play locally.""" + if not text.strip(): + return json.dumps({"status": "error", "error": "Nothing to say"}) + + # Check barge-in + try: + if os.path.exists(_BARGEIN_SIGNAL): + with open(_BARGEIN_SIGNAL) as f: + sig = json.load(f) + if sig.get("event") == "user_speaking_start": + return json.dumps({ + "status": "held", + "reason": "User is currently speaking — re-send after user finishes.", + "user_state": "recording", + "estimated_duration_sec": round(_estimate_duration(text, speed), 1), + }) + except Exception: + pass + + # Request synthesis from HTTP service + status, resp = _http_request("POST", "/v1/synthesize", { + "text": text, "voice": voice, "speed": speed, "emotion": emotion, + "format": "wav", + }, timeout=60.0) + + if status == 0: + return json.dumps({"status": "error", "error": resp.get("error", "Service unreachable")}) + if status != 200: + err = resp.get("error", f"HTTP {status}") if isinstance(resp, dict) else f"HTTP {status}" + return json.dumps({"status": "error", "error": err}) + if not isinstance(resp, bytes): + return json.dumps({"status": "error", "error": "Expected audio bytes from synthesize"}) + + # Create job and play in background + job_id = f"shim-{int(time.time()*1000)}" + with _jobs_lock: + _jobs[job_id] = { + "status": "generating", + "text": text[:100], + "voice": voice, + "created": time.time(), + } + while len(_jobs) > _MAX_JOBS: + _jobs.popitem(last=False) + + t = threading.Thread(target=_play_wav_bytes, args=(resp, job_id), daemon=True) + t.start() + + return json.dumps({"status": "speaking", "job_id": job_id}) + + +def tool_stop(job_id: str = "") -> str: + """Stop playback.""" + try: + import sounddevice as sd + except ImportError: + pass + + if job_id: + with _jobs_lock: + job = _jobs.get(job_id) + if not job: + return json.dumps({"status": "error", "error": f"Unknown job '{job_id}'"}) + if job["status"] == "speaking": + _playback_interrupt.set() + try: + import sounddevice as sd + sd.stop() + except Exception: + pass + with _jobs_lock: + _jobs[job_id]["status"] = "interrupted" + return json.dumps({"status": "ok", "message": f"Interrupted '{job_id}'"}) + return json.dumps({"status": "ok", "message": f"Job '{job_id}' status: {job['status']}"}) + + # Stop all + _playback_interrupt.set() + try: + import sounddevice as sd + sd.stop() + except Exception: + pass + with _jobs_lock: + for j in _jobs.values(): + if j["status"] in ("speaking", "generating"): + j["status"] = "interrupted" + 
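+    # Queued jobs flip to "interrupted" as well, so a later speech_status()
+    # call sees the flush rather than a stuck queue.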
return json.dumps({"status": "ok", "message": "Stopped all playback"}) + + +def tool_speech_status(job_id: str = "", verbose: bool = False) -> str: + """Check job status.""" + with _jobs_lock: + if not job_id: + if not _jobs: + return json.dumps({"status": "idle", "message": "No speech jobs", "queue_depth": 0}) + job_id = next(reversed(_jobs)) + job = _jobs.get(job_id) + + if not job: + return json.dumps({"status": "error", "error": f"Unknown job '{job_id}'"}) + + result = {"job_id": job_id, "status": job["status"]} + if job["status"] == "speaking" and "start_time" in job: + result["elapsed_sec"] = round(time.time() - job["start_time"], 1) + if job.get("metrics"): + result["metrics"] = job["metrics"] + if job.get("error"): + result["error"] = job["error"] + + # Queue state + with _jobs_lock: + speaking = sum(1 for j in _jobs.values() if j["status"] == "speaking") + result["queue"] = {"depth": speaking, "currently_playing": None} + + return json.dumps(result) + + +def tool_list_voices() -> str: + """List voices via HTTP.""" + status, resp = _http_request("GET", "/v1/voices") + if status != 200: + return json.dumps({"status": "error", "error": "Could not reach Mod3 service"}) + + engines = resp.get("engines", {}) + lines = [] + for engine, cfg in engines.items(): + supports = cfg.get("supports", []) + tag = f" ({', '.join(supports)})" if supports else "" + voices = cfg.get("voices", []) + lines.append(f" {engine}{tag}: {', '.join(voices)}") + return "Available voices:\n" + "\n".join(lines) + + +def tool_diagnostics() -> str: + """Diagnostics via HTTP.""" + status, resp = _http_request("GET", "/diagnostics") + if status != 200: + return json.dumps({"status": "error", "error": "Could not reach Mod3 service"}) + return json.dumps(resp, indent=2) + + +def tool_set_output_device(device: str = "") -> str: + """List or set audio output device (local only).""" + global _output_device + try: + import sounddevice as sd + except ImportError: + return json.dumps({"status": "error", "error": "sounddevice not installed"}) + + outputs = [] + for i, d in enumerate(sd.query_devices()): + if d["max_output_channels"] > 0: + is_default = i == sd.default.device[1] + is_active = ( + (_output_device is None and is_default) + or _output_device == i + or (isinstance(_output_device, str) and _output_device in d["name"]) + ) + outputs.append({"index": i, "name": d["name"], "active": is_active, "default": is_default}) + + if not device: + return json.dumps({"devices": outputs}) + + if device == "default": + _output_device = None + return json.dumps({"status": "ok", "message": "Tracking system default"}) + + # Try numeric index + try: + idx = int(device) + for d in outputs: + if d["index"] == idx: + _output_device = idx + return json.dumps({"status": "ok", "device": d["name"], "index": idx}) + return json.dumps({"status": "error", "error": f"No output device at index {idx}"}) + except ValueError: + pass + + # Try name substring + for d in outputs: + if device.lower() in d["name"].lower(): + _output_device = d["index"] + return json.dumps({"status": "ok", "device": d["name"], "index": d["index"]}) + + return json.dumps({"status": "error", "error": f"No device matching '{device}'"}) + + +def tool_await_voice_input(timeout_sec: float = 180.0) -> str: + """Block until SuperWhisper recording finishes (local only).""" + _rec_dir = os.path.expanduser("~/Documents/superwhisper/recordings") + + start = time.time() + while time.time() - start < timeout_sec: + try: + if os.path.exists(_BARGEIN_SIGNAL): + with open(_BARGEIN_SIGNAL) 
as f: + signal = json.load(f) + if signal.get("event") == "user_speaking_end": + break + except (OSError, json.JSONDecodeError): + pass + time.sleep(0.2) + else: + return json.dumps({"status": "timeout", "error": f"No recording completed within {timeout_sec}s"}) + + # Find latest transcript + try: + folders = sorted( + [d for d in os.listdir(_rec_dir) if d.isdigit()], + key=int, reverse=True, + ) + if folders: + meta_path = os.path.join(_rec_dir, folders[0], "meta.json") + if os.path.exists(meta_path): + with open(meta_path) as f: + meta = json.load(f) + raw = meta.get("rawResult", "").strip() + result = meta.get("result", raw).strip() + duration_ms = meta.get("duration", 0) + return json.dumps({ + "status": "ok", + "transcript": result if result else raw, + "raw_transcript": raw, + "duration_sec": round(duration_ms / 1000, 1), + "folder": folders[0], + "source": "superwhisper", + }) + except Exception as e: + logger.warning("await_voice_input error: %s", e) + + return json.dumps({"status": "error", "error": "Could not retrieve transcript"}) + + +def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: + """VAD check via HTTP.""" + if not os.path.exists(file_path): + return json.dumps({"status": "error", "error": f"File not found: {file_path}"}) + + # Read WAV and send to HTTP endpoint + try: + with open(file_path, "rb") as f: + wav_data = f.read() + except Exception as e: + return json.dumps({"status": "error", "error": str(e)}) + + # The HTTP API expects multipart file upload, use urllib + import mimetypes + boundary = "----Mod3ShimBoundary" + body = ( + f"--{boundary}\r\n" + f'Content-Disposition: form-data; name="file"; filename="{os.path.basename(file_path)}"\r\n' + f"Content-Type: audio/wav\r\n\r\n" + ).encode() + wav_data + f"\r\n--{boundary}--\r\n".encode() + + url = f"{MOD3_BASE}/v1/vad" + if threshold != 0.5: + url += f"?threshold={threshold}" + req = urllib.request.Request( + url, data=body, method="POST", + headers={"Content-Type": f"multipart/form-data; boundary={boundary}"}, + ) + try: + with urllib.request.urlopen(req, timeout=30) as resp: + return json.dumps(json.loads(resp.read())) + except Exception as e: + return json.dumps({"status": "error", "error": f"VAD request failed: {e}"}) + + +# --------------------------------------------------------------------------- +# Tool registry (matches server.py exactly) +# --------------------------------------------------------------------------- + +TOOLS = [ + { + "name": "speak", + "description": ( + "Synthesize text to speech and play it through the user's speakers.\n\n" + "Non-blocking: returns immediately with a job ID while audio plays or is\n" + "queued. If nothing is playing, starts immediately. If audio is already\n" + "playing, the new request is queued and will play automatically when the\n" + "current item finishes.\n\n" + "The response always includes the current queue state so the agent knows\n" + "exactly what's happening on the output channel without a separate status call.\n\n" + "Args:\n" + " text: The text to speak aloud. Keep it conversational.\n" + " voice: Voice preset. Use list_voices() to see options.\n" + ' Defaults to "bm_lewis" (Kokoro).\n' + " stream: If True, plays audio chunks as they generate (lower latency).\n" + " If False, generates all audio first then plays (better prosody).\n" + " speed: Speed multiplier (engines with speed support). Default 1.25.\n" + " emotion: Emotion/exaggeration intensity 0.0-1.0 (Chatterbox only). Default 0.5." 
+ ), + "inputSchema": { + "type": "object", + "properties": { + "text": {"type": "string", "description": "The text to speak aloud. Keep it conversational."}, + "voice": {"type": "string", "default": "bm_lewis", "description": "Voice preset. Use list_voices() to see options. Defaults to \"bm_lewis\" (Kokoro)."}, + "stream": {"type": "boolean", "default": True, "description": "If True, plays audio chunks as they generate (lower latency)."}, + "speed": {"type": "number", "default": 1.25, "description": "Speed multiplier (engines with speed support). Default 1.25."}, + "emotion": {"type": "number", "default": 0.5, "description": "Emotion/exaggeration intensity 0.0-1.0 (Chatterbox only). Default 0.5."}, + }, + "required": ["text"], + }, + }, + { + "name": "speech_status", + "description": ( + "Check status of a speech job, or get the most recent result.\n\n" + "Always includes queue state so the agent has full output channel awareness.\n\n" + "Args:\n" + " job_id: The job ID returned by speak(). If empty, returns the latest job.\n" + " verbose: If True, include per-chunk metrics. Default False (summary only)." + ), + "inputSchema": { + "type": "object", + "properties": { + "job_id": {"type": "string", "default": "", "description": "The job ID returned by speak(). If empty, returns the latest job."}, + "verbose": {"type": "boolean", "default": False, "description": "If True, include per-chunk metrics. Default False (summary only)."}, + }, + }, + }, + { + "name": "stop", + "description": ( + "Stop current speech or cancel a specific queued item.\n\n" + "Args:\n" + " job_id: If provided, cancels that specific queued job (not yet playing).\n" + " If the job_id is the currently playing job, interrupts playback.\n" + " If empty, interrupts current playback AND clears the entire queue." + ), + "inputSchema": { + "type": "object", + "properties": { + "job_id": {"type": "string", "default": "", "description": "If provided, cancels that specific job. If empty, stops everything."}, + }, + }, + }, + { + "name": "list_voices", + "description": "List all available voice presets grouped by engine.", + "inputSchema": {"type": "object", "properties": {}}, + }, + { + "name": "await_voice_input", + "description": ( + "Block until the user finishes a SuperWhisper recording, then return the transcript.\n\n" + "This closes the voice input loop: instead of waiting for the user to paste\n" + "their transcribed text, you can directly receive what they said. Use this\n" + "when speak() returns \"held\" (user is recording) or when you want to listen\n" + "for the next voice input.\n\n" + "Polls the barge-in signal file for user_speaking_end, then reads the\n" + "transcript from SuperWhisper's recordings directory.\n\n" + "Args:\n" + " timeout_sec: Maximum seconds to wait for recording to finish. Default 180 (3 minutes)." + ), + "inputSchema": { + "type": "object", + "properties": { + "timeout_sec": {"type": "number", "default": 180, "description": "Maximum seconds to wait for recording to finish. Default 180 (3 minutes)."}, + }, + }, + }, + { + "name": "diagnostics", + "description": "Return engine state and last generation metrics for debugging.", + "inputSchema": {"type": "object", "properties": {}}, + }, + { + "name": "set_output_device", + "description": ( + "List audio output devices, or set the active one.\n\n" + "Args:\n" + " device: Device index (e.g. \"3\"), name substring (e.g. 
\"AirPods\"),\n" + " or \"default\" to track the system default automatically.\n" + " If empty, lists available devices without changing anything." + ), + "inputSchema": { + "type": "object", + "properties": { + "device": {"type": "string", "default": "", "description": "Device index, name substring, or 'default'. If empty, lists devices."}, + }, + }, + }, + { + "name": "vad_check", + "description": ( + "Check if an audio file contains speech using Silero VAD.\n\n" + "Use this before transcription to avoid Whisper hallucinations on\n" + "silence or ambient noise.\n\n" + "Args:\n" + " file_path: Path to a WAV audio file.\n" + " threshold: Speech probability threshold 0-1 (default 0.5). Higher = stricter." + ), + "inputSchema": { + "type": "object", + "properties": { + "file_path": {"type": "string", "description": "Path to a WAV audio file."}, + "threshold": {"type": "number", "default": 0.5, "description": "Speech probability threshold 0-1 (default 0.5). Higher = stricter."}, + }, + "required": ["file_path"], + }, + }, +] + +TOOL_DISPATCH = { + "speak": lambda args: tool_speak( + args["text"], + voice=args.get("voice", "bm_lewis"), + stream=args.get("stream", True), + speed=args.get("speed", 1.25), + emotion=args.get("emotion", 0.5), + ), + "speech_status": lambda args: tool_speech_status( + job_id=args.get("job_id", ""), + verbose=args.get("verbose", False), + ), + "stop": lambda args: tool_stop(job_id=args.get("job_id", "")), + "list_voices": lambda args: tool_list_voices(), + "await_voice_input": lambda args: tool_await_voice_input( + timeout_sec=args.get("timeout_sec", 180.0), + ), + "diagnostics": lambda args: tool_diagnostics(), + "set_output_device": lambda args: tool_set_output_device( + device=args.get("device", ""), + ), + "vad_check": lambda args: tool_vad_check( + file_path=args["file_path"], + threshold=args.get("threshold", 0.5), + ), +} + + +# --------------------------------------------------------------------------- +# MCP stdio protocol +# --------------------------------------------------------------------------- + +SERVER_INFO = { + "name": "mod3", + "version": "0.3.0-shim", +} + +CAPABILITIES = { + "tools": {}, +} + + +def _read_message() -> dict | None: + """Read a JSON-RPC message from stdin (newline-delimited).""" + try: + line = sys.stdin.readline() + if not line: + return None + return json.loads(line.strip()) + except (json.JSONDecodeError, ValueError): + return None + + +def _write_message(msg: dict): + """Write a JSON-RPC message to stdout.""" + sys.stdout.write(json.dumps(msg) + "\n") + sys.stdout.flush() + + +def _jsonrpc_response(id: Any, result: Any) -> dict: + return {"jsonrpc": "2.0", "id": id, "result": result} + + +def _jsonrpc_error(id: Any, code: int, message: str) -> dict: + return {"jsonrpc": "2.0", "id": id, "error": {"code": code, "message": message}} + + +def handle_initialize(msg: dict) -> dict: + return _jsonrpc_response(msg["id"], { + "protocolVersion": "2024-11-05", + "serverInfo": SERVER_INFO, + "capabilities": CAPABILITIES, + }) + + +def handle_tools_list(msg: dict) -> dict: + return _jsonrpc_response(msg["id"], {"tools": TOOLS}) + + +def handle_tools_call(msg: dict) -> dict: + params = msg.get("params", {}) + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + + handler = TOOL_DISPATCH.get(tool_name) + if not handler: + return _jsonrpc_error(msg["id"], -32602, f"Unknown tool: {tool_name}") + + try: + result_text = handler(arguments) + except Exception as e: + result_text = json.dumps({"status": "error", "error": 
str(e)}) + + return _jsonrpc_response(msg["id"], { + "content": [{"type": "text", "text": result_text}], + }) + + +def handle_notifications_initialized(msg: dict): + """Client sends this after initialize — no response needed.""" + pass + + +METHOD_HANDLERS = { + "initialize": handle_initialize, + "tools/list": handle_tools_list, + "tools/call": handle_tools_call, + "notifications/initialized": handle_notifications_initialized, + "ping": lambda msg: _jsonrpc_response(msg["id"], {}), +} + + +def run_stdio(): + """Main MCP stdio loop.""" + logging.basicConfig(level=logging.WARNING, stream=sys.stderr) + + while True: + msg = _read_message() + if msg is None: + break # EOF + + method = msg.get("method", "") + handler = METHOD_HANDLERS.get(method) + + if handler is None: + # Unknown method — if it has an id, return error; if notification, ignore + if "id" in msg: + _write_message(_jsonrpc_error(msg["id"], -32601, f"Method not found: {method}")) + continue + + result = handler(msg) + if result is not None: + _write_message(result) + + +# --------------------------------------------------------------------------- +# Self-test +# --------------------------------------------------------------------------- + +def self_test(): + """Quick connectivity check.""" + print(f"Mod3 shim — testing connection to {MOD3_BASE}") + + status, resp = _http_request("GET", "/health") + if status == 200: + engines = resp.get("engines", {}) + loaded = [k for k, v in engines.items() if v == "loaded"] + print(f" OK: Mod3 service healthy — {len(loaded)} engine(s) loaded: {', '.join(loaded) or 'none'}") + elif status == 0: + print(f" WARN: Mod3 service not reachable at {MOD3_BASE}") + print(" Tools will return errors until the service starts.") + else: + print(f" WARN: Unexpected status {status} from /health") + + # Check sounddevice + try: + import sounddevice as sd + default_out = sd.query_devices(sd.default.device[1]) + print(f" OK: sounddevice available — default output: {default_out['name']}") + except ImportError: + print(" WARN: sounddevice not installed — speak/stop will fail") + except Exception as e: + print(f" WARN: sounddevice error: {e}") + + print(" Shim ready.") + + +if __name__ == "__main__": + if "--test" in sys.argv: + self_test() + else: + run_stdio() diff --git a/modules/voice.py b/modules/voice.py index 6cacce3..aa03263 100644 --- a/modules/voice.py +++ b/modules/voice.py @@ -82,22 +82,28 @@ class WhisperDecoder(Decoder): Accepts PCM float32 bytes at 16kHz or a numpy float32 array directly. Lazy-loads the model on first call; subsequent calls reuse it. Applies BoH hallucination filter to transcripts. + + Supports two models: + - Large (whisper-large-v3-turbo): high-quality, used for T2/T3 tiers (~470ms) + - Base (whisper-base-mlx): fast, used for T1 tier (~31ms) """ - DEFAULT_MODEL = "mlx-community/whisper-turbo" + DEFAULT_MODEL = "mlx-community/whisper-large-v3-turbo" + BASE_MODEL = "mlx-community/whisper-base-mlx" - def __init__(self, model: str | None = None): + def __init__(self, model: str | None = None, load_base: bool = True): self._model = model or self.DEFAULT_MODEL self._loaded = False + self._base_loaded = False + self._load_base = load_base + # Streaming state: last transcript for diff-based partial detection + self._last_streaming_text: str = "" def _ensure_model(self) -> None: """Trigger model download/load on first use.""" if not self._loaded: import mlx_whisper - # A dry-run transcribe forces the model to download & cache. 
- # mlx_whisper handles caching internally — subsequent calls - # with the same path_or_hf_repo are fast. logger.info("WhisperDecoder: loading model %s (first call)", self._model) mlx_whisper.transcribe( np.zeros(16000, dtype=np.float32), # 1 s of silence @@ -106,6 +112,182 @@ def _ensure_model(self) -> None: self._loaded = True logger.info("WhisperDecoder: model ready") + def _ensure_base_model(self) -> None: + """Load Whisper Base model for T1 fast transcription.""" + if not self._base_loaded: + import mlx_whisper + + logger.info("WhisperDecoder: loading base model %s", self.BASE_MODEL) + mlx_whisper.transcribe( + np.zeros(16000, dtype=np.float32), + path_or_hf_repo=self.BASE_MODEL, + ) + self._base_loaded = True + logger.info("WhisperDecoder: base model ready") + + def decode_streaming( + self, + audio: np.ndarray, + tier: str = "t1", + **kwargs, + ) -> dict: + """Chunked re-transcription with LocalAgreement-2 diff. + + Re-runs mlx_whisper.transcribe() on the growing audio buffer, + diffs consecutive outputs to produce confirmed vs tentative text. + + Args: + audio: Growing float32 audio buffer at 16kHz. + tier: "t1" (Base, fast), "t2" (Large, on pause), "t3" (Large, final). + + Returns: + dict with keys: + - confirmed: str — text stable across 2+ consecutive runs + - tentative: str — new text not yet confirmed + - full_text: str — complete transcript from this run + - tier: str — which tier was used + - changed: bool — whether output differs from last run + """ + import mlx_whisper + + from vad import is_hallucination + + # Select model based on tier + if tier == "t1": + self._ensure_base_model() + model_path = self.BASE_MODEL + else: + self._ensure_model() + model_path = self._model + + t0 = time.time() + result = mlx_whisper.transcribe( + audio, + path_or_hf_repo=model_path, + language="en", + ) + elapsed_ms = (time.time() - t0) * 1000 + + transcript: str = result.get("text", "").strip() + + if is_hallucination(transcript): + return { + "confirmed": "", + "tentative": "", + "full_text": "", + "tier": tier, + "changed": False, + "elapsed_ms": round(elapsed_ms, 1), + "filtered": True, + } + + # LocalAgreement-2 diff: find longest common prefix with last run + prev = self._last_streaming_text + changed = transcript != prev + + # Confirmed = common prefix (stable across consecutive runs) + confirmed = "" + min_len = min(len(prev), len(transcript)) + for i in range(min_len): + if prev[i] == transcript[i]: + confirmed = transcript[: i + 1] + else: + break + + # Snap to word boundary + if confirmed and not confirmed.endswith(" "): + last_space = confirmed.rfind(" ") + if last_space > 0: + confirmed = confirmed[:last_space] + + # Tentative = remainder after confirmed prefix + tentative = transcript[len(confirmed):].strip() + + # T3 = end-of-utterance, everything is confirmed + if tier == "t3": + confirmed = transcript + tentative = "" + + self._last_streaming_text = transcript + + return { + "confirmed": confirmed.strip(), + "tentative": tentative, + "full_text": transcript, + "tier": tier, + "changed": changed, + "elapsed_ms": round(elapsed_ms, 1), + } + + def reset_streaming(self) -> None: + """Reset streaming state between utterances.""" + self._last_streaming_text = "" + + def validate_tts_output(self, audio_samples: np.ndarray, source_text: str, sample_rate: int = 24000) -> dict: + """Whisper validation loop: run TTS audio through Whisper Base and compare. + + After TTS generates an audio chunk, run it through Whisper Base (~31ms) + and compare transcript to source text. 
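+        A word-overlap similarity of at least 0.7 against the source counts
+        as a match.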
Flag mismatches. + + Args: + audio_samples: Float32 audio samples from TTS. + source_text: The original text that was synthesized. + sample_rate: Sample rate of the TTS audio. + + Returns: + dict with keys: + - match: bool — whether transcript matches source + - transcript: str — what Whisper heard + - source: str — original text + - similarity: float — 0.0-1.0 word overlap ratio + - elapsed_ms: float + """ + import mlx_whisper + + self._ensure_base_model() + + # Resample to 16kHz if needed (Whisper expects 16kHz) + if sample_rate != 16000: + # Simple linear resampling + ratio = 16000 / sample_rate + new_len = int(len(audio_samples) * ratio) + indices = np.linspace(0, len(audio_samples) - 1, new_len) + audio_16k = np.interp(indices, np.arange(len(audio_samples)), audio_samples).astype(np.float32) + else: + audio_16k = audio_samples + + t0 = time.time() + result = mlx_whisper.transcribe( + audio_16k, + path_or_hf_repo=self.BASE_MODEL, + language="en", + ) + elapsed_ms = (time.time() - t0) * 1000 + + transcript = result.get("text", "").strip().lower() + source_clean = source_text.strip().lower() + + # Word-level similarity + source_words = set(source_clean.split()) + transcript_words = set(transcript.split()) + + if source_words: + overlap = len(source_words & transcript_words) + similarity = overlap / len(source_words) + else: + similarity = 1.0 if not transcript_words else 0.0 + + # Match if similarity >= 0.7 (TTS output may have minor variations) + match = similarity >= 0.7 + + return { + "match": match, + "transcript": transcript, + "source": source_text, + "similarity": round(similarity, 3), + "elapsed_ms": round(elapsed_ms, 1), + } + def decode(self, raw: bytes, **kwargs) -> CognitiveEvent: import mlx_whisper diff --git a/providers.py b/providers.py index cc551c4..9068733 100644 --- a/providers.py +++ b/providers.py @@ -121,7 +121,8 @@ def _format_tools_for_prompt(tools: list[dict]) -> str: for pname, pinfo in props.items(): req_marker = " (required)" if pname in required else "" lines.append( - f" - {pname} ({pinfo.get('type', 'string')}): {pinfo.get('description', '')}{req_marker}" + f" - {pname} ({pinfo.get('type', 'string')}): " + f"{pinfo.get('description', '')}{req_marker}" ) lines.append( "\nTo call a tool, output exactly:\n" @@ -134,7 +135,9 @@ def _format_tools_for_prompt(tools: list[dict]) -> str: return "\n".join(lines) -_TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) +_TOOL_CALL_RE = re.compile( + r"\s*(\{.*?\})\s*", re.DOTALL +) def _parse_tool_calls(text: str) -> list[ToolCall]: @@ -165,7 +168,9 @@ class MlxProvider: """ def __init__(self, model_id: str | None = None): - self._model_id = model_id or os.environ.get("MLX_MODEL", "mlx-community/gemma-3-4b-it-4bit") + self._model_id = model_id or os.environ.get( + "MLX_MODEL", "mlx-community/gemma-3-4b-it-4bit" + ) self._model = None self._tokenizer = None @@ -205,7 +210,9 @@ def _generate_sync( msgs = [{"role": "system", "content": "\n\n".join(system_parts)}] + msgs # Apply chat template - prompt = self._tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) + prompt = self._tokenizer.apply_chat_template( + msgs, add_generation_prompt=True, tokenize=False + ) max_tokens = int(os.environ.get("MLX_MAX_TOKENS", "512")) raw_output = generate( @@ -233,7 +240,9 @@ async def chat( tools: list[dict] | None = None, system: str = "", ) -> ProviderResponse: - return await asyncio.to_thread(self._generate_sync, messages, tools, system) + return await asyncio.to_thread( + self._generate_sync, messages, 
tools, system + ) # --------------------------------------------------------------------------- @@ -249,7 +258,9 @@ def __init__( endpoint: str | None = None, model: str | None = None, ): - self._endpoint = endpoint or os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434") + self._endpoint = endpoint or os.environ.get( + "OLLAMA_ENDPOINT", "http://localhost:11434" + ) self._model = model or os.environ.get("OLLAMA_MODEL", "gemma4:e4b") @property @@ -311,7 +322,9 @@ class CogOSProvider: """CogOS kernel — OpenAI-compatible chat/completions with tool support.""" def __init__(self, endpoint: str | None = None): - self._endpoint = endpoint or os.environ.get("COGOS_ENDPOINT", "http://localhost:5100") + self._endpoint = endpoint or os.environ.get( + "COGOS_ENDPOINT", "http://localhost:5100" + ) @property def name(self) -> str: @@ -430,7 +443,7 @@ def auto_detect_provider() -> InferenceProvider: return MlxProvider() try: - asyncio.get_running_loop() + loop = asyncio.get_running_loop() except RuntimeError: return asyncio.run(auto_detect_provider_async()) diff --git a/server.py b/server.py index 4d18002..8c79bcb 100644 --- a/server.py +++ b/server.py @@ -365,18 +365,60 @@ async def _filter_read_stream(): # --------------------------------------------------------------------------- _BARGEIN_SIGNAL = "/tmp/mod3-barge-in.json" +_SPEAKING_LOCK = "/tmp/mod3-speaking.json" _bargein_last_mtime: float = 0.0 +def _acquire_speaking_lock(job_id: str, text: str): + """Write cross-process speaking lock so the barge-in watcher knows ANY Mod³ is speaking.""" + try: + payload = { + "speaking": True, + "job_id": job_id, + "text": text, + "pid": os.getpid(), + "timestamp": time.time(), + } + tmp = _SPEAKING_LOCK + ".tmp" + with open(tmp, "w") as f: + json.dump(payload, f) + os.replace(tmp, _SPEAKING_LOCK) + except OSError: + pass + + +def _release_speaking_lock(): + """Clear the cross-process speaking lock.""" + try: + if os.path.exists(_SPEAKING_LOCK): + os.remove(_SPEAKING_LOCK) + except OSError: + pass + + +def _is_any_process_speaking() -> dict | None: + """Check if ANY Mod³ process is currently speaking (cross-process).""" + try: + if not os.path.exists(_SPEAKING_LOCK): + return None + with open(_SPEAKING_LOCK) as f: + lock = json.load(f) + # Stale lock check: if older than 60s, ignore it (crashed process) + if time.time() - lock.get("timestamp", 0) > 60: + os.remove(_SPEAKING_LOCK) + return None + return lock + except (OSError, json.JSONDecodeError): + return None + + def _bargein_watcher(): """Background thread that watches for barge-in signal file changes.""" global _bargein_last_mtime import json as _json - while True: try: import os - if os.path.exists(_BARGEIN_SIGNAL): mtime = os.path.getmtime(_BARGEIN_SIGNAL) if mtime > _bargein_last_mtime: @@ -384,10 +426,10 @@ def _bargein_watcher(): with open(_BARGEIN_SIGNAL) as f: signal = _json.load(f) if signal.get("event") == "user_speaking_start": + # Check local pipeline state first (same process) if pipeline_state.is_speaking: info = pipeline_state.interrupt(reason="barge_in") if info: - # Write interrupt context back to signal file signal["interrupted"] = { "spoken_pct": info.spoken_pct, "delivered_text": info.delivered_text, @@ -395,9 +437,25 @@ def _bargein_watcher(): } with open(_BARGEIN_SIGNAL, "w") as f: _json.dump(signal, f, indent=2) - logging.info( - "Barge-in: paused playback (%.0f%% delivered)", info.spoken_pct * 100 if info else 0 - ) + logging.info("Barge-in: paused local playback (%.0f%% delivered)", info.spoken_pct * 100 if info else 0) + 
else: + # Check cross-process lock (another Mod³ process may be speaking) + lock = _is_any_process_speaking() + if lock: + # We can't interrupt another process's pipeline_state, + # but we CAN write the interrupt context from the lock data + signal["interrupted"] = { + "spoken_pct": 0.0, # Unknown from cross-process + "delivered_text": "", + "full_text": lock.get("text", ""), + "cross_process": True, + "source_pid": lock.get("pid"), + } + with open(_BARGEIN_SIGNAL, "w") as f: + _json.dump(signal, f, indent=2) + # Clear the speaking lock to signal the other process + _release_speaking_lock() + logging.info("Barge-in: cross-process interrupt (pid=%s)", lock.get("pid")) except Exception as e: logging.debug("Barge-in watcher error: %s", e) time.sleep(0.1) # 100ms poll @@ -591,6 +649,7 @@ def _run_speech_job(entry: dict) -> None: # Register with the reflex arc so inbound VAD can interrupt us pipeline_state.start_speaking(text, player) + _acquire_speaking_lock(job_id, text) try: for chunk in engine_module.generate_audio( text, @@ -600,6 +659,11 @@ def _run_speech_job(entry: dict) -> None: speed=speed, emotion=emotion, ): + # Check if barge-in cleared our speaking lock (cross-process interrupt) + if not os.path.exists(_SPEAKING_LOCK): + logging.info("Speaking lock cleared by barge-in watcher — stopping generation") + player.stop() + break player.queue_audio(chunk.samples, chunk_meta=chunk.metadata if chunk.metadata else None) _set_bus_voice_state( status=ModuleStatus.ENCODING, @@ -617,6 +681,7 @@ def _run_speech_job(entry: dict) -> None: # Final position update and clear speaking state pipeline_state.update_position(*player.get_progress()) pipeline_state.stop_speaking() + _release_speaking_lock() result = metrics.to_dict() result["engine"] = engine @@ -769,14 +834,12 @@ def speak( # can't be cleared by stop(). if user_state == "recording": est_duration = _estimate_duration_sec(text, speed) - return json.dumps( - { - "status": "held", - "reason": "User is currently speaking — re-send this speak() call after user finishes.", - "user_state": "recording", - "estimated_duration_sec": round(est_duration, 1), - } - ) + return json.dumps({ + "status": "held", + "reason": "User is currently speaking — re-send this speak() call after user finishes.", + "user_state": "recording", + "estimated_duration_sec": round(est_duration, 1), + }) try: job_id, position = _start_speech(text, voice, stream=stream, speed=speed, emotion=emotion) @@ -1075,6 +1138,106 @@ def list_voices() -> str: return "Available voices:\n" + "\n".join(lines) +@mcp.tool( + annotations={ + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": False, + "openWorldHint": True, + } +) +def await_voice_input(timeout_sec: float = 180.0) -> str: + """Block until the user finishes a SuperWhisper recording, then return the transcript. + + This closes the voice input loop: instead of waiting for the user to paste + their transcribed text, you can directly receive what they said. Use this + when speak() returns "held" (user is recording) or when you want to listen + for the next voice input. + + Polls the barge-in signal file for user_speaking_end, then reads the + transcript from SuperWhisper's recordings directory. + + Args: + timeout_sec: Maximum seconds to wait for recording to finish. Default 180 (3 minutes). 
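+
+    Returns (JSON string; shapes mirror the fallbacks below):
+        {"status": "ok", "transcript": ..., "raw_transcript": ...,
+         "duration_sec": ..., "folder": ..., "source": "superwhisper"|"superwhisper_db"}
+        {"status": "timeout", "error": "No recording completed within <timeout_sec>s"}
+        {"status": "error", "error": "Could not retrieve transcript"}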
+ """ + import sqlite3 as _sqlite3 + + _sw_db = os.path.expanduser( + "~/Library/Application Support/SuperWhisper/database/superwhisper.sqlite" + ) + _rec_dir = os.path.expanduser("~/Documents/superwhisper/recordings") + + start = time.time() + # If user is currently recording, wait for them to finish + while time.time() - start < timeout_sec: + try: + if os.path.exists(_BARGEIN_SIGNAL): + with open(_BARGEIN_SIGNAL) as f: + signal = json.load(f) + if signal.get("event") == "user_speaking_end": + break + except (OSError, json.JSONDecodeError): + pass + time.sleep(0.2) + else: + return json.dumps({"status": "timeout", "error": f"No recording completed within {timeout_sec}s"}) + + # Recording finished — find the latest transcript + # Method 1: Check the most recent recording folder's meta.json + try: + folders = sorted( + [d for d in os.listdir(_rec_dir) if d.isdigit()], + key=int, + reverse=True, + ) + if folders: + meta_path = os.path.join(_rec_dir, folders[0], "meta.json") + if os.path.exists(meta_path): + with open(meta_path) as f: + meta = json.load(f) + raw = meta.get("rawResult", "").strip() + result = meta.get("result", raw).strip() + duration_ms = meta.get("duration", 0) + return json.dumps({ + "status": "ok", + "transcript": result if result else raw, + "raw_transcript": raw, + "duration_sec": round(duration_ms / 1000, 1), + "folder": folders[0], + "source": "superwhisper", + }) + except Exception as e: + logger.warning("await_voice_input meta.json fallback failed: %s", e) + + # Method 2: Query SuperWhisper SQLite DB + try: + conn = _sqlite3.connect(f"file:{_sw_db}?mode=ro", uri=True, timeout=2.0) + row = conn.execute( + "SELECT folderName, duration FROM recording ORDER BY datetime DESC LIMIT 1" + ).fetchone() + conn.close() + if row: + folder_name, duration = row + meta_path = os.path.join(_rec_dir, folder_name, "meta.json") + if os.path.exists(meta_path): + with open(meta_path) as f: + meta = json.load(f) + raw = meta.get("rawResult", "").strip() + result = meta.get("result", raw).strip() + return json.dumps({ + "status": "ok", + "transcript": result if result else raw, + "raw_transcript": raw, + "duration_sec": round(duration / 1000, 1), + "folder": folder_name, + "source": "superwhisper_db", + }) + except Exception as e: + logger.warning("await_voice_input DB fallback failed: %s", e) + + return json.dumps({"status": "error", "error": "Could not retrieve transcript"}) + + @mcp.tool( annotations={ "readOnlyHint": True, @@ -1125,7 +1288,8 @@ def set_output_device(device: str = "") -> str: """List audio output devices, or set the active one. Args: - device: Device index (e.g. "3") or name substring (e.g. "AirPods"). + device: Device index (e.g. "3"), name substring (e.g. "AirPods"), + or "default" to track the system default automatically. If empty, lists available devices without changing anything. 
""" import sounddevice as sd @@ -1141,12 +1305,16 @@ def set_output_device(device: str = "") -> str: or _output_device == i or (isinstance(_output_device, str) and _output_device in d["name"]) ) - outputs.append({"index": i, "name": d["name"], "active": is_active}) + outputs.append({"index": i, "name": d["name"], "active": is_active, "default": is_default}) if not device: - lines = [f" [{'*' if d['active'] else ' '}] {d['index']}: {d['name']}" for d in outputs] + lines = [f" [{'*' if d['active'] else ' '}] {d['index']}: {d['name']}{' (system default)' if d['default'] else ''}" for d in outputs] return "Audio output devices (* = active):\n" + "\n".join(lines) + if device.lower() == "default": + _output_device = None + return json.dumps({"status": "ok", "device": "system_default", "note": "Now tracking system default output device"}) + if device.isdigit(): _output_device = int(device) else: From df8b3799d5f36e4049dda352981fed647c9d7fbc Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Wed, 15 Apr 2026 15:34:01 -0400 Subject: [PATCH 2/9] fix: resolve 9 ruff lint errors (unused imports, missing asyncio, formatting) - Remove unused imports: typing.Any, VoiceEncoder, WebSocketDisconnect, struct, mimetypes - Add missing asyncio import for to_thread() in speculative TTS - Prefix unused variables with _ (full, check_messages, loop) - Auto-fixed by ruff --fix + manual corrections Co-Authored-By: Claude Opus 4.6 (1M context) --- agent_loop.py | 9 +++++---- http_api.py | 2 +- mcp_shim.py | 2 -- providers.py | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/agent_loop.py b/agent_loop.py index 2f28aad..afdca23 100644 --- a/agent_loop.py +++ b/agent_loop.py @@ -7,11 +7,12 @@ from __future__ import annotations +import asyncio import json as _json import logging import os import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import httpx @@ -83,7 +84,7 @@ def _fetch_kernel_context() -> str: interrupted = signal.get("interrupted") if interrupted: delivered = interrupted.get("delivered_text", "") - full = interrupted.get("full_text", "") + _full = interrupted.get("full_text", "") pct = interrupted.get("spoken_pct", 0) parts.append( f"[barge-in] Claude's speech was interrupted at {pct*100:.0f}%. " @@ -412,7 +413,7 @@ async def _presynthesise_block(self, block) -> None: Generates audio immediately and attaches it to the block. Ready for instant playback when the human stops speaking. 
""" - from modules.voice import VoiceEncoder, _encode_wav + from modules.voice import _encode_wav try: voice = "bm_lewis" @@ -462,7 +463,7 @@ async def background_validate_drafts(self, latest_user_text: str) -> None: return # Build context with latest human input - check_messages = list(self.conversation) + [ + _check_messages = list(self.conversation) + [ {"role": "user", "content": latest_user_text}, ] diff --git a/http_api.py b/http_api.py index 981d7fc..0a6396f 100644 --- a/http_api.py +++ b/http_api.py @@ -29,7 +29,7 @@ from threading import Lock from typing import Optional -from fastapi import FastAPI, Request, Response, UploadFile, WebSocket, WebSocketDisconnect +from fastapi import FastAPI, Request, Response, UploadFile, WebSocket from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field diff --git a/mcp_shim.py b/mcp_shim.py index c99ee73..83831fb 100644 --- a/mcp_shim.py +++ b/mcp_shim.py @@ -20,7 +20,6 @@ import json import logging import os -import struct import sys import threading import time @@ -410,7 +409,6 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: return json.dumps({"status": "error", "error": str(e)}) # The HTTP API expects multipart file upload, use urllib - import mimetypes boundary = "----Mod3ShimBoundary" body = ( f"--{boundary}\r\n" diff --git a/providers.py b/providers.py index 9068733..a8034b9 100644 --- a/providers.py +++ b/providers.py @@ -443,7 +443,7 @@ def auto_detect_provider() -> InferenceProvider: return MlxProvider() try: - loop = asyncio.get_running_loop() + _loop = asyncio.get_running_loop() except RuntimeError: return asyncio.run(auto_detect_provider_async()) From da06193a9e632718042915cc846d277efc0cc582 Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Wed, 15 Apr 2026 15:36:48 -0400 Subject: [PATCH 3/9] style: apply ruff formatting to pass CI format check Co-Authored-By: Claude Opus 4.6 (1M context) --- agent_loop.py | 71 ++++++++++++-------- channels.py | 90 ++++++++++++++----------- draft_queue.py | 13 ++-- http_api.py | 16 ++--- mcp_shim.py | 170 +++++++++++++++++++++++++++++++++-------------- modules/voice.py | 2 +- providers.py | 27 ++------ server.py | 74 ++++++++++++--------- 8 files changed, 276 insertions(+), 187 deletions(-) diff --git a/agent_loop.py b/agent_loop.py index afdca23..dd8f548 100644 --- a/agent_loop.py +++ b/agent_loop.py @@ -87,8 +87,8 @@ def _fetch_kernel_context() -> str: _full = interrupted.get("full_text", "") pct = interrupted.get("spoken_pct", 0) parts.append( - f"[barge-in] Claude's speech was interrupted at {pct*100:.0f}%. " - f"Delivered: \"{delivered}\". " + f"[barge-in] Claude's speech was interrupted at {pct * 100:.0f}%. " + f'Delivered: "{delivered}". ' f"The user interrupted to say something — acknowledge and respond to them." 
) except Exception: @@ -125,6 +125,7 @@ def _log_exchange_to_bus(user_text: str, assistant_text: str, provider_name: str except Exception as e: logger.debug("Failed to log exchange to bus: %s", e) + MAX_HISTORY = 50 @@ -215,7 +216,9 @@ async def _process(self, event: CognitiveEvent) -> None: content=text, target_channel=self.channel_id, metadata={ - "voice": self._channel_ref.config.get("voice", "bm_lewis") if self._channel_ref else "bm_lewis", + "voice": self._channel_ref.config.get("voice", "bm_lewis") + if self._channel_ref + else "bm_lewis", "speed": self._channel_ref.config.get("speed", 1.25) if self._channel_ref else 1.25, }, ) @@ -250,10 +253,12 @@ async def _process(self, event: CognitiveEvent) -> None: # Update conversation history if assistant_parts: assistant_text = " ".join(assistant_parts) - self.conversation.append({ - "role": "assistant", - "content": assistant_text, - }) + self.conversation.append( + { + "role": "assistant", + "content": assistant_text, + } + ) # Log exchange to CogOS bus (observation channel — Claude can see this) _log_exchange_to_bus(event.content, assistant_text, self.provider.name) @@ -308,6 +313,7 @@ async def speculative_infer(self, committed_text: str) -> None: # Add to draft queue import hashlib + ctx_hash = hashlib.md5(committed_text.encode()).hexdigest()[:8] block = self.draft_queue.add_block( text=response_text, @@ -317,7 +323,9 @@ async def speculative_infer(self, committed_text: str) -> None: logger.info( "speculative block %s: '%s' (%.0fms)", - block.id, response_text[:60], t_ms, + block.id, + response_text[:60], + t_ms, ) # F2: Speculative TTS pre-synthesis @@ -326,10 +334,12 @@ async def speculative_infer(self, committed_text: str) -> None: # Notify dashboard of draft queue state if self._channel_ref: - await self._channel_ref.ws.send_json({ - "type": "draft_queue", - "blocks": [b.to_dict() for b in self.draft_queue.get_pending()], - }) + await self._channel_ref.ws.send_json( + { + "type": "draft_queue", + "blocks": [b.to_dict() for b in self.draft_queue.get_pending()], + } + ) except Exception as e: logger.debug("speculative_infer failed: %s", e) @@ -367,10 +377,12 @@ async def _push_draft_queue_state(self) -> None: """Push current draft queue state to the dashboard.""" if self._channel_ref: try: - await self._channel_ref.ws.send_json({ - "type": "draft_queue", - "blocks": [b.to_dict() for b in self.draft_queue.all_blocks], - }) + await self._channel_ref.ws.send_json( + { + "type": "draft_queue", + "blocks": [b.to_dict() for b in self.draft_queue.all_blocks], + } + ) except Exception: pass @@ -398,10 +410,12 @@ async def invalidate_stale_drafts(self, new_context: str) -> int: if invalidated > 0 and self._channel_ref: try: - await self._channel_ref.ws.send_json({ - "type": "draft_queue", - "blocks": [b.to_dict() for b in self.draft_queue.all_blocks], - }) + await self._channel_ref.ws.send_json( + { + "type": "draft_queue", + "blocks": [b.to_dict() for b in self.draft_queue.all_blocks], + } + ) except Exception: pass @@ -424,6 +438,7 @@ async def _presynthesise_block(self, block) -> None: def _synth(): from engine import synthesize + samples, sample_rate = synthesize( block.text, voice=voice, @@ -471,9 +486,9 @@ async def background_validate_drafts(self, latest_user_text: str) -> None: try: # Quick relevance check: ask the model if this block is still appropriate check_prompt = ( - f"Given the user just said: \"{latest_user_text}\"\n" + f'Given the user just said: "{latest_user_text}"\n' f"Is this planned response still appropriate? 
" - f"Response: \"{block.text}\"\n" + f'Response: "{block.text}"\n' f"Answer KEEP or REVISE in one word." ) @@ -520,18 +535,18 @@ def _build_interrupt_context(self, user_text: str) -> str | None: unspoken = "" if info.full_text and info.delivered_text: if info.full_text.startswith(info.delivered_text): - unspoken = info.full_text[len(info.delivered_text):].strip() + unspoken = info.full_text[len(info.delivered_text) :].strip() else: # Fallback: everything after the delivered percentage - unspoken = info.full_text[len(info.delivered_text):].strip() + unspoken = info.full_text[len(info.delivered_text) :].strip() parts = [] parts.append("[Barge-in context — your previous response was interrupted]") - parts.append(f"spoken (user heard this): \"{info.delivered_text}\"") + parts.append(f'spoken (user heard this): "{info.delivered_text}"') if unspoken: - parts.append(f"unspoken (user did NOT hear this): \"{unspoken}\"") - parts.append(f"interrupted_at: {info.spoken_pct*100:.0f}%") - parts.append(f"user_said: \"{user_text}\"") + parts.append(f'unspoken (user did NOT hear this): "{unspoken}"') + parts.append(f"interrupted_at: {info.spoken_pct * 100:.0f}%") + parts.append(f'user_said: "{user_text}"') parts.append("Acknowledge what was interrupted and respond to the user's new input.") return "\n".join(parts) diff --git a/channels.py b/channels.py index 8480953..c026e64 100644 --- a/channels.py +++ b/channels.py @@ -82,9 +82,7 @@ def _deliver_sync(self, output: EncodedOutput) -> None: if not self._active: return try: - future = asyncio.run_coroutine_threadsafe( - self._deliver_async(output), self._loop - ) + future = asyncio.run_coroutine_threadsafe(self._deliver_async(output), self._loop) future.result(timeout=10.0) except (WebSocketDisconnect, RuntimeError, TimeoutError): logger.debug("deliver failed (client disconnected?), deactivating channel") @@ -106,13 +104,15 @@ async def _deliver_async(self, output: EncodedOutput) -> None: # Send audio as base64 JSON (avoids binary frame issues) audio_b64 = base64.b64encode(output.data).decode("ascii") logger.info("deliver: sending base64 audio JSON (%d chars)", len(audio_b64)) - await self.ws.send_json({ - "type": "audio", - "data": audio_b64, - "format": output.format or "wav", - "duration_sec": round(output.duration_sec, 2), - "sample_rate": output.metadata.get("sample_rate", 24000), - }) + await self.ws.send_json( + { + "type": "audio", + "data": audio_b64, + "format": output.format or "wav", + "duration_sec": round(output.duration_sec, 2), + "sample_rate": output.metadata.get("sample_rate", 24000), + } + ) logger.info("deliver: audio sent OK") elif output.modality == ModalityType.TEXT: text = output.data.decode("utf-8") if isinstance(output.data, bytes) else str(output.data) @@ -211,13 +211,15 @@ def _transcribe_t1(): try: result = await asyncio.to_thread(_transcribe_t1) if result and result.get("changed") and not result.get("filtered"): - await self.ws.send_json({ - "type": "partial_transcript", - "confirmed": result["confirmed"], - "tentative": result["tentative"], - "tier": "t1", - "elapsed_ms": result["elapsed_ms"], - }) + await self.ws.send_json( + { + "type": "partial_transcript", + "confirmed": result["confirmed"], + "tentative": result["tentative"], + "tier": "t1", + "elapsed_ms": result["elapsed_ms"], + } + ) except Exception as e: logger.debug("T1 error: %s", e) @@ -241,13 +243,15 @@ def _transcribe_t2(): try: result = await asyncio.to_thread(_transcribe_t2) if result and not result.get("filtered"): - await self.ws.send_json({ - "type": 
"partial_transcript", - "confirmed": result["confirmed"], - "tentative": result["tentative"], - "tier": "t2", - "elapsed_ms": result["elapsed_ms"], - }) + await self.ws.send_json( + { + "type": "partial_transcript", + "confirmed": result["confirmed"], + "tentative": result["tentative"], + "tier": "t2", + "elapsed_ms": result["elapsed_ms"], + } + ) except Exception as e: logger.debug("T2 error: %s", e) finally: @@ -308,7 +312,7 @@ def _transcribe(): # Skip silence if len(audio) < 16000 * 0.3: return None - rms = float(np.sqrt(np.mean(audio ** 2))) + rms = float(np.sqrt(np.mean(audio**2))) if rms < 0.005: return None @@ -357,12 +361,14 @@ def _transcribe(): if event and event.content: # Send transcript to browser - await self.ws.send_json({ - "type": "transcript", - "text": event.content, - "stt_ms": round(stt_ms, 1), - "source": "voice", - }) + await self.ws.send_json( + { + "type": "transcript", + "text": event.content, + "stt_ms": round(stt_ms, 1), + "source": "voice", + } + ) # Forward to agent loop event.metadata["stt_ms"] = stt_ms if self._on_event: @@ -376,11 +382,13 @@ async def _process_text(self, text: str) -> None: source_channel=self.channel_id, confidence=1.0, ) - await self.ws.send_json({ - "type": "transcript", - "text": text, - "source": "text", - }) + await self.ws.send_json( + { + "type": "transcript", + "text": text, + "source": "text", + } + ) if self._on_event: await self._on_event(event) @@ -407,10 +415,12 @@ async def send_response_complete(self, metrics: dict | None = None) -> None: """Signal response is complete.""" if self._active: try: - await self.ws.send_json({ - "type": "response_complete", - "metrics": metrics or {}, - }) + await self.ws.send_json( + { + "type": "response_complete", + "metrics": metrics or {}, + } + ) except Exception: self._active = False diff --git a/draft_queue.py b/draft_queue.py index dccbc15..3a7c630 100644 --- a/draft_queue.py +++ b/draft_queue.py @@ -23,10 +23,11 @@ class BlockStatus(Enum): """Lifecycle states for a draft block.""" - VALID = "valid" # Generated, awaiting playback - STALE = "stale" # Invalidated by new context - SPOKEN = "spoken" # Successfully played aloud - SNIPPED = "snipped" # Removed by self-barge + + VALID = "valid" # Generated, awaiting playback + STALE = "stale" # Invalidated by new context + SPOKEN = "spoken" # Successfully played aloud + SNIPPED = "snipped" # Removed by self-barge SPEAKING = "speaking" # Currently being spoken @@ -38,8 +39,8 @@ class DraftBlock: text: str status: BlockStatus = BlockStatus.VALID created_at: float = field(default_factory=time.time) - context_hash: str = "" # Hash of context at generation time - generation_ms: float = 0.0 # How long inference took + context_hash: str = "" # Hash of context at generation time + generation_ms: float = 0.0 # How long inference took tts_audio: bytes | None = None # Pre-synthesized audio (if available) tts_duration_sec: float = 0.0 metadata: dict[str, Any] = field(default_factory=dict) diff --git a/http_api.py b/http_api.py index 0a6396f..45a7fb8 100644 --- a/http_api.py +++ b/http_api.py @@ -58,6 +58,7 @@ async def _warmup_kokoro(): def _do_warmup(): try: from engine import get_model + get_model("kokoro") logger.info("Kokoro TTS engine pre-warmed successfully") except Exception as e: @@ -571,6 +572,7 @@ def stop_speech(job_id: str = ""): """ try: from server import _speech_queue, pipeline_state + if job_id: cancelled = _speech_queue.cancel(job_id) return {"status": "ok", "message": f"Cancelled {job_id}" if cancelled else f"Job {job_id} not found"} @@ 
-677,24 +679,16 @@ async def _graceful_exit(): deadline = time.time() + timeout_sec while time.time() < deadline: with _jobs_lock: - active = sum( - 1 for j in _jobs.values() - if j.get("status") in ("generating", "processing") - ) + active = sum(1 for j in _jobs.values() if j.get("status") in ("generating", "processing")) if active == 0: break await asyncio.sleep(0.25) with _jobs_lock: - remaining = sum( - 1 for j in _jobs.values() - if j.get("status") in ("generating", "processing") - ) + remaining = sum(1 for j in _jobs.values() if j.get("status") in ("generating", "processing")) if remaining: - logger.warning( - "Shutdown timeout reached with %d active jobs — forcing exit", remaining - ) + logger.warning("Shutdown timeout reached with %d active jobs — forcing exit", remaining) else: logger.info("All jobs drained — exiting cleanly") diff --git a/mcp_shim.py b/mcp_shim.py index 83831fb..a34ca35 100644 --- a/mcp_shim.py +++ b/mcp_shim.py @@ -51,8 +51,7 @@ _BARGEIN_SIGNAL = os.path.expanduser("~/.mod3_bargein_signal.json") -def _http_request(method: str, path: str, body: dict | None = None, - timeout: float = 30.0) -> tuple[int, dict | bytes]: +def _http_request(method: str, path: str, body: dict | None = None, timeout: float = 30.0) -> tuple[int, dict | bytes]: """Make an HTTP request to the Mod3 service. Returns (status_code, parsed_json_or_bytes).""" url = f"{MOD3_BASE}{path}" headers = {"Content-Type": "application/json"} if body is not None else {} @@ -161,8 +160,10 @@ def _estimate_duration(text: str, speed: float) -> float: # Tool implementations # --------------------------------------------------------------------------- -def tool_speak(text: str, voice: str = "bm_lewis", stream: bool = True, - speed: float = 1.25, emotion: float = 0.5) -> str: + +def tool_speak( + text: str, voice: str = "bm_lewis", stream: bool = True, speed: float = 1.25, emotion: float = 0.5 +) -> str: """Synthesize via HTTP, play locally.""" if not text.strip(): return json.dumps({"status": "error", "error": "Nothing to say"}) @@ -173,20 +174,30 @@ def tool_speak(text: str, voice: str = "bm_lewis", stream: bool = True, with open(_BARGEIN_SIGNAL) as f: sig = json.load(f) if sig.get("event") == "user_speaking_start": - return json.dumps({ - "status": "held", - "reason": "User is currently speaking — re-send after user finishes.", - "user_state": "recording", - "estimated_duration_sec": round(_estimate_duration(text, speed), 1), - }) + return json.dumps( + { + "status": "held", + "reason": "User is currently speaking — re-send after user finishes.", + "user_state": "recording", + "estimated_duration_sec": round(_estimate_duration(text, speed), 1), + } + ) except Exception: pass # Request synthesis from HTTP service - status, resp = _http_request("POST", "/v1/synthesize", { - "text": text, "voice": voice, "speed": speed, "emotion": emotion, - "format": "wav", - }, timeout=60.0) + status, resp = _http_request( + "POST", + "/v1/synthesize", + { + "text": text, + "voice": voice, + "speed": speed, + "emotion": emotion, + "format": "wav", + }, + timeout=60.0, + ) if status == 0: return json.dumps({"status": "error", "error": resp.get("error", "Service unreachable")}) @@ -197,7 +208,7 @@ def tool_speak(text: str, voice: str = "bm_lewis", stream: bool = True, return json.dumps({"status": "error", "error": "Expected audio bytes from synthesize"}) # Create job and play in background - job_id = f"shim-{int(time.time()*1000)}" + job_id = f"shim-{int(time.time() * 1000)}" with _jobs_lock: _jobs[job_id] = { "status": 
"generating", @@ -230,6 +241,7 @@ def tool_stop(job_id: str = "") -> str: _playback_interrupt.set() try: import sounddevice as sd + sd.stop() except Exception: pass @@ -242,6 +254,7 @@ def tool_stop(job_id: str = "") -> str: _playback_interrupt.set() try: import sounddevice as sd + sd.stop() except Exception: pass @@ -372,7 +385,8 @@ def tool_await_voice_input(timeout_sec: float = 180.0) -> str: try: folders = sorted( [d for d in os.listdir(_rec_dir) if d.isdigit()], - key=int, reverse=True, + key=int, + reverse=True, ) if folders: meta_path = os.path.join(_rec_dir, folders[0], "meta.json") @@ -382,14 +396,16 @@ def tool_await_voice_input(timeout_sec: float = 180.0) -> str: raw = meta.get("rawResult", "").strip() result = meta.get("result", raw).strip() duration_ms = meta.get("duration", 0) - return json.dumps({ - "status": "ok", - "transcript": result if result else raw, - "raw_transcript": raw, - "duration_sec": round(duration_ms / 1000, 1), - "folder": folders[0], - "source": "superwhisper", - }) + return json.dumps( + { + "status": "ok", + "transcript": result if result else raw, + "raw_transcript": raw, + "duration_sec": round(duration_ms / 1000, 1), + "folder": folders[0], + "source": "superwhisper", + } + ) except Exception as e: logger.warning("await_voice_input error: %s", e) @@ -411,16 +427,22 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: # The HTTP API expects multipart file upload, use urllib boundary = "----Mod3ShimBoundary" body = ( - f"--{boundary}\r\n" - f'Content-Disposition: form-data; name="file"; filename="{os.path.basename(file_path)}"\r\n' - f"Content-Type: audio/wav\r\n\r\n" - ).encode() + wav_data + f"\r\n--{boundary}--\r\n".encode() + ( + f"--{boundary}\r\n" + f'Content-Disposition: form-data; name="file"; filename="{os.path.basename(file_path)}"\r\n' + f"Content-Type: audio/wav\r\n\r\n" + ).encode() + + wav_data + + f"\r\n--{boundary}--\r\n".encode() + ) url = f"{MOD3_BASE}/v1/vad" if threshold != 0.5: url += f"?threshold={threshold}" req = urllib.request.Request( - url, data=body, method="POST", + url, + data=body, + method="POST", headers={"Content-Type": f"multipart/form-data; boundary={boundary}"}, ) try: @@ -458,10 +480,26 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "type": "object", "properties": { "text": {"type": "string", "description": "The text to speak aloud. Keep it conversational."}, - "voice": {"type": "string", "default": "bm_lewis", "description": "Voice preset. Use list_voices() to see options. Defaults to \"bm_lewis\" (Kokoro)."}, - "stream": {"type": "boolean", "default": True, "description": "If True, plays audio chunks as they generate (lower latency)."}, - "speed": {"type": "number", "default": 1.25, "description": "Speed multiplier (engines with speed support). Default 1.25."}, - "emotion": {"type": "number", "default": 0.5, "description": "Emotion/exaggeration intensity 0.0-1.0 (Chatterbox only). Default 0.5."}, + "voice": { + "type": "string", + "default": "bm_lewis", + "description": 'Voice preset. Use list_voices() to see options. Defaults to "bm_lewis" (Kokoro).', + }, + "stream": { + "type": "boolean", + "default": True, + "description": "If True, plays audio chunks as they generate (lower latency).", + }, + "speed": { + "type": "number", + "default": 1.25, + "description": "Speed multiplier (engines with speed support). Default 1.25.", + }, + "emotion": { + "type": "number", + "default": 0.5, + "description": "Emotion/exaggeration intensity 0.0-1.0 (Chatterbox only). 
Default 0.5.", + }, }, "required": ["text"], }, @@ -478,8 +516,16 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "inputSchema": { "type": "object", "properties": { - "job_id": {"type": "string", "default": "", "description": "The job ID returned by speak(). If empty, returns the latest job."}, - "verbose": {"type": "boolean", "default": False, "description": "If True, include per-chunk metrics. Default False (summary only)."}, + "job_id": { + "type": "string", + "default": "", + "description": "The job ID returned by speak(). If empty, returns the latest job.", + }, + "verbose": { + "type": "boolean", + "default": False, + "description": "If True, include per-chunk metrics. Default False (summary only).", + }, }, }, }, @@ -495,7 +541,11 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "inputSchema": { "type": "object", "properties": { - "job_id": {"type": "string", "default": "", "description": "If provided, cancels that specific job. If empty, stops everything."}, + "job_id": { + "type": "string", + "default": "", + "description": "If provided, cancels that specific job. If empty, stops everything.", + }, }, }, }, @@ -510,7 +560,7 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "Block until the user finishes a SuperWhisper recording, then return the transcript.\n\n" "This closes the voice input loop: instead of waiting for the user to paste\n" "their transcribed text, you can directly receive what they said. Use this\n" - "when speak() returns \"held\" (user is recording) or when you want to listen\n" + 'when speak() returns "held" (user is recording) or when you want to listen\n' "for the next voice input.\n\n" "Polls the barge-in signal file for user_speaking_end, then reads the\n" "transcript from SuperWhisper's recordings directory.\n\n" @@ -520,7 +570,11 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "inputSchema": { "type": "object", "properties": { - "timeout_sec": {"type": "number", "default": 180, "description": "Maximum seconds to wait for recording to finish. Default 180 (3 minutes)."}, + "timeout_sec": { + "type": "number", + "default": 180, + "description": "Maximum seconds to wait for recording to finish. Default 180 (3 minutes).", + }, }, }, }, @@ -534,14 +588,18 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "description": ( "List audio output devices, or set the active one.\n\n" "Args:\n" - " device: Device index (e.g. \"3\"), name substring (e.g. \"AirPods\"),\n" - " or \"default\" to track the system default automatically.\n" + ' device: Device index (e.g. "3"), name substring (e.g. "AirPods"),\n' + ' or "default" to track the system default automatically.\n' " If empty, lists available devices without changing anything." ), "inputSchema": { "type": "object", "properties": { - "device": {"type": "string", "default": "", "description": "Device index, name substring, or 'default'. If empty, lists devices."}, + "device": { + "type": "string", + "default": "", + "description": "Device index, name substring, or 'default'. If empty, lists devices.", + }, }, }, }, @@ -559,7 +617,11 @@ def tool_vad_check(file_path: str, threshold: float = 0.5) -> str: "type": "object", "properties": { "file_path": {"type": "string", "description": "Path to a WAV audio file."}, - "threshold": {"type": "number", "default": 0.5, "description": "Speech probability threshold 0-1 (default 0.5). 
Higher = stricter."}, + "threshold": { + "type": "number", + "default": 0.5, + "description": "Speech probability threshold 0-1 (default 0.5). Higher = stricter.", + }, }, "required": ["file_path"], }, @@ -634,11 +696,14 @@ def _jsonrpc_error(id: Any, code: int, message: str) -> dict: def handle_initialize(msg: dict) -> dict: - return _jsonrpc_response(msg["id"], { - "protocolVersion": "2024-11-05", - "serverInfo": SERVER_INFO, - "capabilities": CAPABILITIES, - }) + return _jsonrpc_response( + msg["id"], + { + "protocolVersion": "2024-11-05", + "serverInfo": SERVER_INFO, + "capabilities": CAPABILITIES, + }, + ) def handle_tools_list(msg: dict) -> dict: @@ -659,9 +724,12 @@ def handle_tools_call(msg: dict) -> dict: except Exception as e: result_text = json.dumps({"status": "error", "error": str(e)}) - return _jsonrpc_response(msg["id"], { - "content": [{"type": "text", "text": result_text}], - }) + return _jsonrpc_response( + msg["id"], + { + "content": [{"type": "text", "text": result_text}], + }, + ) def handle_notifications_initialized(msg: dict): @@ -705,6 +773,7 @@ def run_stdio(): # Self-test # --------------------------------------------------------------------------- + def self_test(): """Quick connectivity check.""" print(f"Mod3 shim — testing connection to {MOD3_BASE}") @@ -723,6 +792,7 @@ def self_test(): # Check sounddevice try: import sounddevice as sd + default_out = sd.query_devices(sd.default.device[1]) print(f" OK: sounddevice available — default output: {default_out['name']}") except ImportError: diff --git a/modules/voice.py b/modules/voice.py index aa03263..20e02e1 100644 --- a/modules/voice.py +++ b/modules/voice.py @@ -201,7 +201,7 @@ def decode_streaming( confirmed = confirmed[:last_space] # Tentative = remainder after confirmed prefix - tentative = transcript[len(confirmed):].strip() + tentative = transcript[len(confirmed) :].strip() # T3 = end-of-utterance, everything is confirmed if tier == "t3": diff --git a/providers.py b/providers.py index a8034b9..259361b 100644 --- a/providers.py +++ b/providers.py @@ -121,8 +121,7 @@ def _format_tools_for_prompt(tools: list[dict]) -> str: for pname, pinfo in props.items(): req_marker = " (required)" if pname in required else "" lines.append( - f" - {pname} ({pinfo.get('type', 'string')}): " - f"{pinfo.get('description', '')}{req_marker}" + f" - {pname} ({pinfo.get('type', 'string')}): {pinfo.get('description', '')}{req_marker}" ) lines.append( "\nTo call a tool, output exactly:\n" @@ -135,9 +134,7 @@ def _format_tools_for_prompt(tools: list[dict]) -> str: return "\n".join(lines) -_TOOL_CALL_RE = re.compile( - r"\s*(\{.*?\})\s*", re.DOTALL -) +_TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) def _parse_tool_calls(text: str) -> list[ToolCall]: @@ -168,9 +165,7 @@ class MlxProvider: """ def __init__(self, model_id: str | None = None): - self._model_id = model_id or os.environ.get( - "MLX_MODEL", "mlx-community/gemma-3-4b-it-4bit" - ) + self._model_id = model_id or os.environ.get("MLX_MODEL", "mlx-community/gemma-3-4b-it-4bit") self._model = None self._tokenizer = None @@ -210,9 +205,7 @@ def _generate_sync( msgs = [{"role": "system", "content": "\n\n".join(system_parts)}] + msgs # Apply chat template - prompt = self._tokenizer.apply_chat_template( - msgs, add_generation_prompt=True, tokenize=False - ) + prompt = self._tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) max_tokens = int(os.environ.get("MLX_MAX_TOKENS", "512")) raw_output = generate( @@ -240,9 +233,7 @@ async def chat( tools: 
list[dict] | None = None, system: str = "", ) -> ProviderResponse: - return await asyncio.to_thread( - self._generate_sync, messages, tools, system - ) + return await asyncio.to_thread(self._generate_sync, messages, tools, system) # --------------------------------------------------------------------------- @@ -258,9 +249,7 @@ def __init__( endpoint: str | None = None, model: str | None = None, ): - self._endpoint = endpoint or os.environ.get( - "OLLAMA_ENDPOINT", "http://localhost:11434" - ) + self._endpoint = endpoint or os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434") self._model = model or os.environ.get("OLLAMA_MODEL", "gemma4:e4b") @property @@ -322,9 +311,7 @@ class CogOSProvider: """CogOS kernel — OpenAI-compatible chat/completions with tool support.""" def __init__(self, endpoint: str | None = None): - self._endpoint = endpoint or os.environ.get( - "COGOS_ENDPOINT", "http://localhost:5100" - ) + self._endpoint = endpoint or os.environ.get("COGOS_ENDPOINT", "http://localhost:5100") @property def name(self) -> str: diff --git a/server.py b/server.py index 8c79bcb..cc9016f 100644 --- a/server.py +++ b/server.py @@ -416,9 +416,11 @@ def _bargein_watcher(): """Background thread that watches for barge-in signal file changes.""" global _bargein_last_mtime import json as _json + while True: try: import os + if os.path.exists(_BARGEIN_SIGNAL): mtime = os.path.getmtime(_BARGEIN_SIGNAL) if mtime > _bargein_last_mtime: @@ -437,7 +439,10 @@ def _bargein_watcher(): } with open(_BARGEIN_SIGNAL, "w") as f: _json.dump(signal, f, indent=2) - logging.info("Barge-in: paused local playback (%.0f%% delivered)", info.spoken_pct * 100 if info else 0) + logging.info( + "Barge-in: paused local playback (%.0f%% delivered)", + info.spoken_pct * 100 if info else 0, + ) else: # Check cross-process lock (another Mod³ process may be speaking) lock = _is_any_process_speaking() @@ -834,12 +839,14 @@ def speak( # can't be cleared by stop(). 
if user_state == "recording": est_duration = _estimate_duration_sec(text, speed) - return json.dumps({ - "status": "held", - "reason": "User is currently speaking — re-send this speak() call after user finishes.", - "user_state": "recording", - "estimated_duration_sec": round(est_duration, 1), - }) + return json.dumps( + { + "status": "held", + "reason": "User is currently speaking — re-send this speak() call after user finishes.", + "user_state": "recording", + "estimated_duration_sec": round(est_duration, 1), + } + ) try: job_id, position = _start_speech(text, voice, stream=stream, speed=speed, emotion=emotion) @@ -1162,9 +1169,7 @@ def await_voice_input(timeout_sec: float = 180.0) -> str: """ import sqlite3 as _sqlite3 - _sw_db = os.path.expanduser( - "~/Library/Application Support/SuperWhisper/database/superwhisper.sqlite" - ) + _sw_db = os.path.expanduser("~/Library/Application Support/SuperWhisper/database/superwhisper.sqlite") _rec_dir = os.path.expanduser("~/Documents/superwhisper/recordings") start = time.time() @@ -1198,23 +1203,23 @@ def await_voice_input(timeout_sec: float = 180.0) -> str: raw = meta.get("rawResult", "").strip() result = meta.get("result", raw).strip() duration_ms = meta.get("duration", 0) - return json.dumps({ - "status": "ok", - "transcript": result if result else raw, - "raw_transcript": raw, - "duration_sec": round(duration_ms / 1000, 1), - "folder": folders[0], - "source": "superwhisper", - }) + return json.dumps( + { + "status": "ok", + "transcript": result if result else raw, + "raw_transcript": raw, + "duration_sec": round(duration_ms / 1000, 1), + "folder": folders[0], + "source": "superwhisper", + } + ) except Exception as e: logger.warning("await_voice_input meta.json fallback failed: %s", e) # Method 2: Query SuperWhisper SQLite DB try: conn = _sqlite3.connect(f"file:{_sw_db}?mode=ro", uri=True, timeout=2.0) - row = conn.execute( - "SELECT folderName, duration FROM recording ORDER BY datetime DESC LIMIT 1" - ).fetchone() + row = conn.execute("SELECT folderName, duration FROM recording ORDER BY datetime DESC LIMIT 1").fetchone() conn.close() if row: folder_name, duration = row @@ -1224,14 +1229,16 @@ def await_voice_input(timeout_sec: float = 180.0) -> str: meta = json.load(f) raw = meta.get("rawResult", "").strip() result = meta.get("result", raw).strip() - return json.dumps({ - "status": "ok", - "transcript": result if result else raw, - "raw_transcript": raw, - "duration_sec": round(duration / 1000, 1), - "folder": folder_name, - "source": "superwhisper_db", - }) + return json.dumps( + { + "status": "ok", + "transcript": result if result else raw, + "raw_transcript": raw, + "duration_sec": round(duration / 1000, 1), + "folder": folder_name, + "source": "superwhisper_db", + } + ) except Exception as e: logger.warning("await_voice_input DB fallback failed: %s", e) @@ -1308,12 +1315,17 @@ def set_output_device(device: str = "") -> str: outputs.append({"index": i, "name": d["name"], "active": is_active, "default": is_default}) if not device: - lines = [f" [{'*' if d['active'] else ' '}] {d['index']}: {d['name']}{' (system default)' if d['default'] else ''}" for d in outputs] + lines = [ + f" [{'*' if d['active'] else ' '}] {d['index']}: {d['name']}{' (system default)' if d['default'] else ''}" + for d in outputs + ] return "Audio output devices (* = active):\n" + "\n".join(lines) if device.lower() == "default": _output_device = None - return json.dumps({"status": "ok", "device": "system_default", "note": "Now tracking system default output device"}) + 
return json.dumps( + {"status": "ok", "device": "system_default", "note": "Now tracking system default output device"} + ) if device.isdigit(): _output_device = int(device) From f7f0df9ab6abd5889b0936951cc063caeb50d0c7 Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Wed, 15 Apr 2026 15:40:57 -0400 Subject: [PATCH 4/9] =?UTF-8?q?fix:=20player.stop()=20=E2=86=92=20player.f?= =?UTF-8?q?lush()=20for=20barge-in=20interrupt=20(pyright=20type=20error)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index cc9016f..0c0e4e5 100644 --- a/server.py +++ b/server.py @@ -667,7 +667,7 @@ def _run_speech_job(entry: dict) -> None: # Check if barge-in cleared our speaking lock (cross-process interrupt) if not os.path.exists(_SPEAKING_LOCK): logging.info("Speaking lock cleared by barge-in watcher — stopping generation") - player.stop() + player.flush() break player.queue_audio(chunk.samples, chunk_meta=chunk.metadata if chunk.metadata else None) _set_bus_voice_state( From 503057e1ca7ddf77e5941569a6a73e893b0dfa80 Mon Sep 17 00:00:00 2001 From: Chaz Dinkle Date: Fri, 17 Apr 2026 23:47:24 -0400 Subject: [PATCH 5/9] =?UTF-8?q?feat:=20bus-mediated=20dashboard=20chat=20?= =?UTF-8?q?=E2=80=94=20cogos=20kernel=20as=20inference=20backend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires Mod³'s dashboard chat to route user messages through the cogos kernel's running metabolic-cycle agent instead of the local MLX Gemma provider. When MOD3_USE_COGOS_AGENT=1, user turns flow as bus events (bus_dashboard_chat → kernel inlet → harness observation → respond tool → bus_dashboard_response) and render back in the dashboard as response_text frames. Voice and text now share a single conversation through the same metabolic cycle. Also lands bidirectional barge-in context stitching on the WebSocket path: BargeinContext schema + agent_loop injection into next-turn system prompt. Fixes the gap where dashboard interruptions halted TTS but didn't surface structured context to the agent (previously only the MCP/SuperWhisper file-signal path injected it). 6 new bargein tests; 2 new bus-bridge tests; 5 new cogos-agent bridge tests. 47 pytest collect total. Dashboard: live Cycle Trace drawer consuming bus_cycle_trace via SSE subscriber. Bottom-drawer UI, 100-entry rolling window, collapsible with localStorage. ort.min.js + WASM for VAD runtime. Whisper default pinned to whisper-base-mlx to reduce concurrent MLX Metal pressure (Gemma + Kokoro + Whisper segfault). Large-v3-turbo restoration is a separate MLX-stability fix; voice-input path still crashes on mic due to underlying MLX concurrency issue (known, tracked separately). 
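For reference, a forwarded turn reaches the kernel as a flat bus/send
body with the user event JSON-encoded into `message` (shape per
cogos_agent_bridge.post_user_message in this patch; field values here
are illustrative):

    {"bus_id": "bus_dashboard_chat", "from": "mod3-dashboard",
     "type": "user_message",
     "message": "{\"type\": \"user_message\", \"text\": \"hi\", ...}"}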
Co-Authored-By: Claude Opus 4.7 (1M context) --- agent_loop.py | 112 +- bus_bridge.py | 281 +++ bus_bridge_runner.py | 144 ++ channels.py | 70 + cogos_agent_bridge.py | 240 +++ dashboard/index.html | 53 + dashboard/trace.js | 179 ++ dashboard/transport.js | 10 + dashboard/vad/ort.min.js | 2869 ++++++++++++++++++++++++++++++ demo/e2e_audio_trace_demo.py | 140 ++ demo/e2e_dashboard_harness.py | 176 ++ http_api.py | 52 + modules/voice.py | 4 +- schemas/__init__.py | 1 + schemas/bargein.py | 70 + tests/test_bargein_context.py | 200 +++ tests/test_bus_bridge_runner.py | 74 + tests/test_cogos_agent_bridge.py | 114 ++ 18 files changed, 4754 insertions(+), 35 deletions(-) create mode 100644 bus_bridge.py create mode 100644 bus_bridge_runner.py create mode 100644 cogos_agent_bridge.py create mode 100644 dashboard/trace.js create mode 100644 dashboard/vad/ort.min.js create mode 100644 demo/e2e_audio_trace_demo.py create mode 100644 demo/e2e_dashboard_harness.py create mode 100644 schemas/__init__.py create mode 100644 schemas/bargein.py create mode 100644 tests/test_bargein_context.py create mode 100644 tests/test_bus_bridge_runner.py create mode 100644 tests/test_cogos_agent_bridge.py diff --git a/agent_loop.py b/agent_loop.py index dd8f548..63b9ea3 100644 --- a/agent_loop.py +++ b/agent_loop.py @@ -21,6 +21,7 @@ from modality import CognitiveEvent, CognitiveIntent, ModalityType from pipeline_state import PipelineState from providers import AGENT_TOOLS, InferenceProvider +from schemas.bargein import BargeinContext if TYPE_CHECKING: from channels import BrowserChannel @@ -149,6 +150,9 @@ def __init__( self.draft_queue = DraftQueue() self._speculative_context: list[dict[str, str]] = [] # Context for speculative inference self._human_speaking = False # Whether human is currently speaking + # A2: typed barge-in context prepared before the next turn, consumed by A3 + # for prompt injection. Set by _prepare_bargein_context() on the WS path. + self._pending_bargein: BargeinContext | None = None async def handle_event(self, event: CognitiveEvent) -> None: """Called when a CognitiveEvent arrives from the channel.""" @@ -175,12 +179,46 @@ async def handle_event(self, event: CognitiveEvent) -> None: async def _process(self, event: CognitiveEvent) -> None: """Core: event → provider → tool dispatch.""" - # Context stitching: inject interrupt context from dashboard path - # This closes the barge-in loop — the agent knows what was spoken, - # what was unsaid, and what the user interrupted with. - interrupt_context = self._build_interrupt_context(event.content) - if interrupt_context: - self.conversation.append({"role": "system", "content": interrupt_context}) + # A2: build typed BargeinContext from pipeline_state.last_interrupt (if any) + # and stash on self._pending_bargein. A3 will consume it for prompt injection. + self._prepare_bargein_context(user_text=event.content) + + # MOD3_USE_COGOS_AGENT fork: forward user turn to kernel bus instead of + # calling local provider. Response arrives asynchronously via the + # cogos_agent_bridge → BrowserChannel.broadcast_response_text path. + from cogos_agent_bridge import is_enabled as _cogos_agent_enabled + from cogos_agent_bridge import post_user_message as _post_user_message + + if _cogos_agent_enabled(): + session_id = f"mod3:{self.channel_id or 'unknown'}" + # Fold any pending barge-in context into the forwarded text so the + # kernel cycle sees it. A full structured payload will come in a + # later iteration; for v1 we prepend the terse prompt renderer. 
+ forwarded_text = event.content + pending = self._pending_bargein + if pending is not None: + self._pending_bargein = None + forwarded_text = ( + "[interrupted earlier] " + + pending.format_for_prompt() + + "\n" + + forwarded_text + ) + ok = await _post_user_message(forwarded_text, session_id=session_id) + if not ok and self._channel_ref: + try: + await self._channel_ref.send_response_text( + "[cogos-agent unreachable — check kernel]" + ) + await self._channel_ref.send_response_complete( + metrics={"provider": "cogos-agent", "error": "unreachable"} + ) + except Exception: + pass + # Track the user turn in history so subsequent turns carry it. + self.conversation.append({"role": "user", "content": event.content}) + self._trim_history() + return self.conversation.append({"role": "user", "content": event.content}) self._trim_history() @@ -190,6 +228,7 @@ async def _process(self, event: CognitiveEvent) -> None: # Assemble system prompt with kernel context (afferent path) kernel_ctx = _fetch_kernel_context() system_prompt = _BASE_SYSTEM_PROMPT + kernel_ctx + system_prompt = self._inject_pending_bargein(system_prompt) response = await self.provider.chat( messages=self.conversation, @@ -510,46 +549,51 @@ async def background_validate_drafts(self, latest_user_text: str) -> None: await self._push_draft_queue_state() - def _build_interrupt_context(self, user_text: str) -> str | None: - """Build context stitch from pipeline_state.last_interrupt. - - When the user barged in during TTS playback, captures what was - spoken vs unspoken and injects it as structured context for the - next inference call. Consumes the interrupt (clears it). + def _prepare_bargein_context(self, user_text: str | None) -> None: + """Read pipeline_state.last_interrupt and stash a typed BargeinContext. - Returns a context string, or None if no interrupt occurred. + Called at the top of each WS turn. If the previous assistant reply was + interrupted (and the interrupt is still fresh, < 30s), build a + BargeinContext via the A1 schema and store it on ``self._pending_bargein`` + for A3 to pick up during prompt construction. Clears last_interrupt so + the next turn does not re-consume a stale record. """ info = self.pipeline_state.last_interrupt if info is None: - return None + self._pending_bargein = None + return # Only use recent interrupts (within last 30 seconds) if time.time() - info.timestamp > 30: - return None + # Stale — clear and skip. + with self.pipeline_state._lock: + self.pipeline_state._last_interrupt = None + self._pending_bargein = None + return - # Clear the interrupt so we don't re-inject it + # Consume the interrupt so we don't re-inject it on subsequent turns. + # pipeline_state has no public consume helper yet; clear the private + # slot under its lock (matches the pre-existing pattern on this path). 
with self.pipeline_state._lock: self.pipeline_state._last_interrupt = None - # Compute unspoken remainder - unspoken = "" - if info.full_text and info.delivered_text: - if info.full_text.startswith(info.delivered_text): - unspoken = info.full_text[len(info.delivered_text) :].strip() - else: - # Fallback: everything after the delivered percentage - unspoken = info.full_text[len(info.delivered_text) :].strip() + self._pending_bargein = BargeinContext.from_interrupt_info( + info, + source="browser_vad", + user_said=user_text or None, + ) - parts = [] - parts.append("[Barge-in context — your previous response was interrupted]") - parts.append(f'spoken (user heard this): "{info.delivered_text}"') - if unspoken: - parts.append(f'unspoken (user did NOT hear this): "{unspoken}"') - parts.append(f"interrupted_at: {info.spoken_pct * 100:.0f}%") - parts.append(f'user_said: "{user_text}"') - parts.append("Acknowledge what was interrupted and respond to the user's new input.") - - return "\n".join(parts) + def _inject_pending_bargein(self, system_prompt: str) -> str: + """Append the pending BargeinContext (if any) to the system prompt. + + Consumes ``self._pending_bargein`` so it does not leak into subsequent + turns. Returns the prompt unchanged if no barge-in is pending. + """ + pending = self._pending_bargein + if pending is None: + return system_prompt + self._pending_bargein = None + return system_prompt + "\n\n" + pending.format_for_prompt() def _trim_history(self) -> None: """Keep conversation within MAX_HISTORY messages.""" diff --git a/bus_bridge.py b/bus_bridge.py new file mode 100644 index 0000000..27745c1 --- /dev/null +++ b/bus_bridge.py @@ -0,0 +1,281 @@ +"""Kernel-bus SSE subscriber. + +Consumes http://localhost:6931/v1/events/stream and yields parsed bus events. +Reconnects on disconnect with exponential backoff. Tolerates unknown event kinds +per ADR-083 (cycle-trace event contract). + +C3 will consume this to broadcast CycleEvents to dashboard WebSocket clients. + +The kernel (see apps/cogos/bus_stream.go) emits SSE frames of the form: + + data: {"id":"live_*_42","type":"bus.event","timestamp":"...","data":{}}\\n\\n + +Heartbeats arrive as SSE comment lines: + + : keep-alive\\n\\n + +An initial frame of {"type":"connected","bus_id":"*","timestamp":"..."} is +sent on subscribe — we surface that as a BusEnvelope with kind="connected". +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from dataclasses import dataclass, field +from typing import Any, AsyncIterator, Optional + +import httpx + +logger = logging.getLogger("mod3.bus_bridge") + +KERNEL_BUS_STREAM_URL = "http://localhost:6931/v1/events/stream" + + +@dataclass +class BusEnvelope: + """Raw bus-envelope record as received from the kernel SSE stream. + + `raw` is the full outer JSON (the bus.event envelope). `payload` is the + inner CogBlock dict (envelope["data"]) — may be {} for non-bus.event + frames (e.g. the initial "connected" frame). `kind` is the best-effort + event-kind string: preferring payload["kind"] (ADR-083 CycleEvent), then + payload["type"], then envelope["type"]. Consumers MUST tolerate unknown + kinds. 
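+
+    Illustrative instance (field values invented for this example; the
+    shapes follow the module docstring above)::
+
+        BusEnvelope(
+            raw={"id": "live_bus_cycle_trace_42", "type": "bus.event",
+                 "timestamp": "2026-04-17T23:40:00Z",
+                 "data": {"kind": "state_transition", "cycle_id": "c7"}},
+            kind="state_transition",
+            payload={"kind": "state_transition", "cycle_id": "c7"},
+            ts="2026-04-17T23:40:00Z",
+            event_id="live_bus_cycle_trace_42",
+        )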
+ """ + + raw: dict + kind: str + payload: dict = field(default_factory=dict) + ts: Optional[str] = None + event_id: Optional[str] = None + + +def _extract_kind(envelope: dict, payload: dict) -> str: + for src in (payload, envelope): + for key in ("kind", "type"): + val = src.get(key) if isinstance(src, dict) else None + if isinstance(val, str) and val: + return val + return "unknown" + + +class KernelBusSubscriber: + """Async SSE subscriber for the cogos kernel bus stream. + + Usage:: + + sub = KernelBusSubscriber() + async for env in sub.stream(): + handle(env) + + `stream()` yields indefinitely; on any transport error it reconnects + with exponential backoff clamped to [reconnect_min_s, reconnect_max_s]. + Call `close()` (or cancel the consuming task) to stop. + """ + + def __init__( + self, + url: str = KERNEL_BUS_STREAM_URL, + *, + bus_filter: str = "*", + consumer_id: Optional[str] = None, + reconnect_min_s: float = 1.0, + reconnect_max_s: float = 30.0, + request_timeout_s: float = 10.0, + ) -> None: + self._url = url + self._bus_filter = bus_filter + self._consumer_id = consumer_id + self._min_backoff = reconnect_min_s + self._max_backoff = reconnect_max_s + self._request_timeout = request_timeout_s + self._last_event_id: Optional[str] = None + self._closed = asyncio.Event() + self._client: Optional[httpx.AsyncClient] = None + + async def close(self) -> None: + self._closed.set() + if self._client is not None: + try: + await self._client.aclose() + except Exception: # pragma: no cover - best-effort + pass + self._client = None + + def _build_params(self) -> dict[str, str]: + params: dict[str, str] = {} + if self._bus_filter and self._bus_filter != "*": + params["bus_id"] = self._bus_filter + if self._consumer_id: + params["consumer"] = self._consumer_id + return params + + def _build_headers(self) -> dict[str, str]: + headers = {"Accept": "text/event-stream", "Cache-Control": "no-cache"} + if self._last_event_id: + # Harmless if the kernel doesn't honor it today; future protocol + # bump may use it for resume. + headers["Last-Event-ID"] = self._last_event_id + return headers + + async def stream(self) -> AsyncIterator[BusEnvelope]: + backoff = self._min_backoff + # Generous read timeout — SSE is long-lived with 30s heartbeats. 
+ timeout = httpx.Timeout(self._request_timeout, read=None) + while not self._closed.is_set(): + self._client = httpx.AsyncClient(timeout=timeout) + try: + async with self._client.stream( + "GET", + self._url, + params=self._build_params(), + headers=self._build_headers(), + ) as resp: + if resp.status_code != 200: + logger.info( + "bus-bridge: non-200 from %s: %s — backing off %.1fs", + self._url, resp.status_code, backoff, + ) + await self._sleep_or_close(backoff) + backoff = min(self._max_backoff, max(self._min_backoff, backoff * 2)) + continue + + logger.info("bus-bridge: connected to %s", self._url) + backoff = self._min_backoff # reset on successful connect + + async for envelope in self._iter_sse(resp): + yield envelope + except (httpx.HTTPError, asyncio.TimeoutError, ConnectionError) as e: + logger.info( + "bus-bridge: transport error (%s); reconnecting in %.1fs", + e.__class__.__name__, backoff, + ) + await self._sleep_or_close(backoff) + backoff = min(self._max_backoff, max(self._min_backoff, backoff * 2)) + except asyncio.CancelledError: + await self.close() + raise + finally: + if self._client is not None: + try: + await self._client.aclose() + except Exception: # pragma: no cover + pass + self._client = None + + async def _sleep_or_close(self, seconds: float) -> None: + try: + await asyncio.wait_for(self._closed.wait(), timeout=seconds) + except asyncio.TimeoutError: + return + + async def _iter_sse(self, resp: httpx.Response) -> AsyncIterator[BusEnvelope]: + """Parse the SSE byte stream into BusEnvelope records. + + Minimal SSE parser: we accumulate field lines into the current event, + dispatch on blank-line boundaries, silently skip comment lines + (`: heartbeat`), and honor `data:`, `event:`, `id:` fields. + """ + event_name: Optional[str] = None + data_lines: list[str] = [] + event_id: Optional[str] = None + + async for raw_line in resp.aiter_lines(): + if self._closed.is_set(): + return + # httpx strips the trailing \n but preserves empty lines. + if raw_line == "": + # Dispatch boundary. + if data_lines: + env = self._parse_event(event_name, "\n".join(data_lines), event_id) + if env is not None: + yield env + event_name = None + data_lines = [] + event_id = None + continue + if raw_line.startswith(":"): + # Comment line / heartbeat. + continue + field, _, value = raw_line.partition(":") + if value.startswith(" "): + value = value[1:] + if field == "data": + data_lines.append(value) + elif field == "event": + event_name = value + elif field == "id": + event_id = value + self._last_event_id = value + # retry / unknown fields: ignore + + def _parse_event( + self, event_name: Optional[str], data: str, event_id: Optional[str] + ) -> Optional[BusEnvelope]: + try: + envelope: Any = json.loads(data) + except json.JSONDecodeError: + logger.debug("bus-bridge: non-JSON data frame dropped: %r", data[:200]) + return None + if not isinstance(envelope, dict): + logger.debug("bus-bridge: non-object data frame dropped: %r", envelope) + return None + + inner = envelope.get("data") + payload: dict = inner if isinstance(inner, dict) else {} + kind = _extract_kind(envelope, payload) + ts = envelope.get("timestamp") or payload.get("ts") or payload.get("timestamp") + eid = event_id or envelope.get("id") + if eid and not self._last_event_id: + self._last_event_id = eid + + if kind not in ("state_transition", "tool_dispatch", "assessment", "bus.event", "connected"): + # Tolerate unknowns — just log and forward. 
+ logger.debug("bus-bridge: forwarding unknown event kind=%r", kind) + + return BusEnvelope( + raw=envelope, + kind=kind, + payload=payload, + ts=ts if isinstance(ts, str) else None, + event_id=eid if isinstance(eid, str) else None, + ) + + +# --------------------------------------------------------------------------- +# Manual validation entry point +# --------------------------------------------------------------------------- + + +async def _main() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + sub = KernelBusSubscriber() + print(f"bus-bridge: subscribing to {sub._url} (Ctrl-C to stop)") + try: + async for env in sub.stream(): + print( + json.dumps( + { + "kind": env.kind, + "ts": env.ts, + "id": env.event_id, + "payload_keys": sorted(env.payload.keys())[:12], + } + ) + ) + except (KeyboardInterrupt, asyncio.CancelledError): + pass + finally: + await sub.close() + + +if __name__ == "__main__": + try: + asyncio.run(_main()) + except KeyboardInterrupt: + pass diff --git a/bus_bridge_runner.py b/bus_bridge_runner.py new file mode 100644 index 0000000..5747ec2 --- /dev/null +++ b/bus_bridge_runner.py @@ -0,0 +1,144 @@ +"""Kernel-bus → dashboard bridge runner. + +Consumes `KernelBusSubscriber.stream()` (see `bus_bridge.py`) and fans the +ADR-083 cycle-trace events out to every connected dashboard WebSocket via +`BrowserChannel.broadcast_trace_event()` (see `channels.py`). + +Wiring: + + kernel (bus_cycle_trace) + └─► SSE /v1/events/stream?bus_id=bus_cycle_trace + └─► KernelBusSubscriber.stream() [C1] + └─► run_bridge() filter + forward + └─► BrowserChannel.broadcast_trace_event() [C2] + +The subscriber does its own reconnect with exponential backoff, so a kernel +that is temporarily unreachable does not affect server startup. Disable the +bridge entirely at process boot by setting env `MOD3_BUS_BRIDGE_DISABLED=1`. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Optional + +from bus_bridge import KERNEL_BUS_STREAM_URL, BusEnvelope, KernelBusSubscriber +from channels import BrowserChannel + +logger = logging.getLogger("mod3.bus_bridge") + +# ADR-083 kinds the dashboard trace panel cares about. Kept as a module-level +# constant so tests and the lifespan wiring share one definition. +ADR083_KINDS: frozenset[str] = frozenset({"state_transition", "tool_dispatch", "assessment"}) + +# Kernel-side bus name (see apps/cogos/trace_emit.go:const traceBusID). +TRACE_BUS_ID = "bus_cycle_trace" + +# Env flag consulted at startup. +DISABLE_ENV = "MOD3_BUS_BRIDGE_DISABLED" + + +def is_disabled() -> bool: + """True when MOD3_BUS_BRIDGE_DISABLED is set to a truthy value.""" + v = os.environ.get(DISABLE_ENV, "").strip().lower() + return v in ("1", "true", "yes", "on") + + +async def run_bridge( + subscriber: KernelBusSubscriber, + *, + filter_kinds: Optional[set[str]] = None, +) -> None: + """Consume `subscriber` and broadcast cycle-trace events to dashboard clients. + + `filter_kinds`: + - `None`: forward everything (dev mode — useful when inspecting the raw + stream through a dashboard). + - set of kind strings: only forward envelopes whose `BusEnvelope.kind` + is in the set. Unknown kinds are tolerated per ADR-083 — they simply + won't pass this filter. + + `BrowserChannel.broadcast_trace_event()` is thread-safe and non-blocking: + it dispatches each WS send via `run_coroutine_threadsafe`. We call it + directly (no await). 
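+
+    Wiring sketch (standalone form; `start_bridge()` below does the
+    equivalent and parks the task on app.state)::
+
+        sub = KernelBusSubscriber(bus_filter=TRACE_BUS_ID)
+        asyncio.create_task(run_bridge(sub, filter_kinds=set(ADR083_KINDS)))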
+ """ + first_event_logged = False + forwarded = 0 + async for env in subscriber.stream(): + if filter_kinds is not None and env.kind not in filter_kinds: + continue + # The "connected" bootstrap frame has an empty payload; skip silently. + if env.kind == "connected": + continue + if not first_event_logged: + logger.info( + "bridge: first event forwarded kind=%s event_id=%s", + env.kind, env.event_id, + ) + first_event_logged = True + try: + BrowserChannel.broadcast_trace_event(env.payload) + forwarded += 1 + logger.debug( + "bridge: forwarded kind=%s event_id=%s (total=%d)", + env.kind, env.event_id, forwarded, + ) + except Exception as exc: # noqa: BLE001 — broadcaster is best-effort + logger.debug("bridge: broadcast failed: %s", exc) + + +async def start_bridge( + app_state: object, + *, + url: str = KERNEL_BUS_STREAM_URL, + bus_filter: str = TRACE_BUS_ID, + filter_kinds: Optional[set[str]] = frozenset(ADR083_KINDS), +) -> None: + """Construct the subscriber + bridge task and store them on `app_state`. + + Startup is non-blocking: we don't await the task or probe the kernel. + The subscriber's own backoff loop handles reconnects. Logs a disabled + notice and returns cleanly when `MOD3_BUS_BRIDGE_DISABLED` is set. + """ + if is_disabled(): + logger.info("bridge: disabled via %s=1", DISABLE_ENV) + setattr(app_state, "bus_bridge_subscriber", None) + setattr(app_state, "bus_bridge_task", None) + return + + subscriber = KernelBusSubscriber(url=url, bus_filter=bus_filter, consumer_id="mod3-dashboard") + task = asyncio.create_task( + run_bridge(subscriber, filter_kinds=set(filter_kinds) if filter_kinds else None), + name="mod3-bus-bridge", + ) + setattr(app_state, "bus_bridge_subscriber", subscriber) + setattr(app_state, "bus_bridge_task", task) + logger.info( + "bridge: started, target=%s bus_id=%s filter=%s", + url, bus_filter, sorted(filter_kinds) if filter_kinds else "*", + ) + + +async def stop_bridge(app_state: object, *, timeout_s: float = 2.0) -> None: + """Gracefully stop the bridge: close subscriber, await task, cancel on timeout.""" + subscriber: Optional[KernelBusSubscriber] = getattr(app_state, "bus_bridge_subscriber", None) + task: Optional[asyncio.Task] = getattr(app_state, "bus_bridge_task", None) + if subscriber is None and task is None: + return + if subscriber is not None: + try: + await subscriber.close() + except Exception: # pragma: no cover - best-effort + pass + if task is not None: + try: + await asyncio.wait_for(task, timeout=timeout_s) + except (asyncio.TimeoutError, asyncio.CancelledError): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): # pragma: no cover + pass + logger.info("bridge: stopped") diff --git a/channels.py b/channels.py index c026e64..702dbc8 100644 --- a/channels.py +++ b/channels.py @@ -8,6 +8,12 @@ T1 (Whisper Base, ~31ms): per-chunk during speech T2 (Whisper Large, ~470ms): on natural pause T3 (Whisper Large, ~470ms): on end-of-utterance (final) + +Server→client WebSocket message types: + audio, response_text, response_complete, interrupted, + partial_transcript, transcript, + trace_event — kernel cycle-trace events (ADR-083), fanned out via + BrowserChannel.broadcast_trace_event(). """ from __future__ import annotations @@ -33,6 +39,12 @@ class BrowserChannel: """WebSocket-backed channel for the browser dashboard.""" + # Registry of currently-active dashboard channels. Used by + # broadcast_trace_event() to fan kernel cycle-trace events out to every + # connected dashboard client (see ADR-083). 
Populated in __init__,
+    # pruned in _cleanup.
+    _active_channels: "set[BrowserChannel]" = set()
+
     def __init__(
         self,
         ws: WebSocket,
@@ -71,6 +83,7 @@ def __init__(
             modalities=[ModalityType.VOICE, ModalityType.TEXT],
             deliver=self._deliver_sync,
         )
+        BrowserChannel._active_channels.add(self)
         logger.info("BrowserChannel registered: %s", self.channel_id)
 
     # ------------------------------------------------------------------
@@ -424,6 +437,62 @@ async def send_response_complete(self, metrics: dict | None = None) -> None:
             except Exception:
                 self._active = False
 
+    # ------------------------------------------------------------------
+    # Trace event broadcast (kernel cycle-trace → dashboards)
+    # ------------------------------------------------------------------
+
+    @classmethod
+    def broadcast_trace_event(cls, event: dict) -> None:
+        """Fan a kernel cycle-trace event out to every connected dashboard.
+
+        Per ADR-083, `event` is a pre-parsed CycleEvent dict
+        (id, ts, source, cycle_id, kind, payload). Wrapped in the
+        `{"type": "trace_event", "event": ...}` envelope and sent to each
+        active BrowserChannel's WebSocket. Clients whose send fails are
+        skipped silently (they will be pruned by their own disconnect path).
+        """
+        frame = {"type": "trace_event", "event": event}
+        for ch in list(cls._active_channels):
+            if not ch._active:
+                continue
+            try:
+                asyncio.run_coroutine_threadsafe(ch.ws.send_json(frame), ch._loop)
+            except Exception as exc:  # noqa: BLE001 — disconnected clients are expected
+                logger.debug("trace_event send failed for %s: %s", ch.channel_id, exc)
+
+    @classmethod
+    def broadcast_response_text(cls, text: str, session_id: str | None = None) -> None:
+        """Push an agent-reply text frame to dashboard WebSocket clients.
+
+        Used by the MOD3_USE_COGOS_AGENT response bridge (see
+        `cogos_agent_bridge.run_response_bridge`). The frame matches the
+        existing text-response shape emitted by `_deliver_async` and
+        `send_response_text`: `{"type": "response_text", "text": <text>}`.
+
+        If `session_id` is None (default) the frame is broadcast to every
+        active dashboard channel. When provided, only channels whose
+        `channel_id` matches the `mod3:<channel_id>` convention from
+        `cogos_agent_bridge.post_user_message` receive the frame — this is
+        how future multi-user routing will land, but for v1 a None
+        broadcast is the common case (only one dashboard attached).
+
+        Thread-safe: dispatches each WS send via `run_coroutine_threadsafe`
+        on the channel's own loop, matching `broadcast_trace_event`.
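+
+        Routing sketch (the channel id is invented for illustration)::
+
+            # Targeted: only the dashboard whose channel_id == "abc123"
+            BrowserChannel.broadcast_response_text("On it.", session_id="mod3:abc123")
+            # Broadcast: session_id=None reaches every active dashboard
+            BrowserChannel.broadcast_response_text("On it.")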
+ """ + frame = {"type": "response_text", "text": text} + expected_channel = None + if session_id and session_id.startswith("mod3:"): + expected_channel = session_id[len("mod3:"):] + for ch in list(cls._active_channels): + if not ch._active: + continue + if expected_channel and ch.channel_id != expected_channel: + continue + try: + asyncio.run_coroutine_threadsafe(ch.ws.send_json(frame), ch._loop) + except Exception as exc: # noqa: BLE001 — disconnected clients are expected + logger.debug("response_text send failed for %s: %s", ch.channel_id, exc) + # ------------------------------------------------------------------ # Cleanup # ------------------------------------------------------------------ @@ -431,6 +500,7 @@ async def send_response_complete(self, metrics: dict | None = None) -> None: def _cleanup(self) -> None: """Deactivate channel and cancel pending TTS jobs on disconnect.""" self._active = False + BrowserChannel._active_channels.discard(self) ch = self.bus._channels.get(self.channel_id) if ch: ch.active = False diff --git a/cogos_agent_bridge.py b/cogos_agent_bridge.py new file mode 100644 index 0000000..0137b8a --- /dev/null +++ b/cogos_agent_bridge.py @@ -0,0 +1,240 @@ +"""CogOS kernel agent bridge (MOD3_USE_COGOS_AGENT=1). + +When the env flag is set, Mod³'s agent loop forwards user turns to the +cogos kernel's metabolic cycle instead of the local inference provider: + + browser → WS turn → post_user_message() ─POST /v1/bus/send─► kernel + │ + ▼ + bus_dashboard_chat + │ + ▼ + kernel cycle → `respond` tool + │ + ▼ + bus_dashboard_response + │ + SSE /v1/events/stream + │ + ▼ + KernelBusSubscriber.stream() + │ + ▼ + run_response_bridge() + │ + ▼ + BrowserChannel.broadcast_response_text() + +The subscriber does its own reconnect with exponential backoff (see +`bus_bridge.py`). Disable the whole fork by leaving `MOD3_USE_COGOS_AGENT` +unset (default). + +Note: the kernel's `POST /v1/bus/send` takes a flat `{bus_id, from, to, +message, type}` body — the inner JSON event is serialised into `message` +(matches the pattern used by other cogos producers). +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +from datetime import datetime, timezone +from typing import Optional + +import httpx + +from bus_bridge import KERNEL_BUS_STREAM_URL, KernelBusSubscriber +from channels import BrowserChannel + +logger = logging.getLogger("mod3.cogos_agent") + +# Bus names — contract with the kernel side (see ADR / c-agent subagent). +CHAT_BUS_ID = "bus_dashboard_chat" +RESPONSE_BUS_ID = "bus_dashboard_response" + +# Kernel endpoints. +_DEFAULT_KERNEL_BASE = os.environ.get("COGOS_ENDPOINT", "http://localhost:6931") +BUS_SEND_URL = f"{_DEFAULT_KERNEL_BASE}/v1/bus/send" + +# Env gate. +ENABLE_ENV = "MOD3_USE_COGOS_AGENT" + +_POST_TIMEOUT_S = 5.0 + + +def is_enabled() -> bool: + """True when MOD3_USE_COGOS_AGENT is set to a truthy value.""" + v = os.environ.get(ENABLE_ENV, "").strip().lower() + return v in ("1", "true", "yes", "on") + + +def _now_rfc3339() -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +async def post_user_message(text: str, session_id: str) -> bool: + """POST a user turn to the kernel's `bus_dashboard_chat` bus. + + Returns True if the send succeeded (kernel replied 2xx), False otherwise. + Logs at warning-level on failure but never raises — callers use graceful + degradation (e.g. show an error response frame to the dashboard). 
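+
+    Call-site sketch (session-id convention from agent_loop._process; the
+    channel id is illustrative)::
+
+        ok = await post_user_message("hello", session_id="mod3:abc123")
+        if not ok:
+            ...  # degrade gracefully; see the caller's error frame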
+
+    The kernel's handleBusSend (see apps/cogos/bus_api.go) accepts
+    `{bus_id, from, to, message, type}` — we JSON-encode the full event dict
+    into `message` so the kernel's cycle receives the structured payload.
+    """
+    event = {
+        "type": "user_message",
+        "text": text,
+        "session_id": session_id,
+        "ts": _now_rfc3339(),
+    }
+    body = {
+        "bus_id": CHAT_BUS_ID,
+        "from": "mod3-dashboard",
+        "type": "user_message",
+        "message": json.dumps(event, separators=(",", ":")),
+    }
+    try:
+        async with httpx.AsyncClient(timeout=_POST_TIMEOUT_S) as client:
+            resp = await client.post(BUS_SEND_URL, json=body)
+    except httpx.HTTPError as exc:
+        logger.warning("cogos-agent: post to %s failed: %s", BUS_SEND_URL, exc)
+        return False
+    if resp.status_code // 100 != 2:
+        logger.warning(
+            "cogos-agent: post non-2xx: %d body=%r",
+            resp.status_code, resp.text[:200],
+        )
+        return False
+    logger.info(
+        "cogos-agent: forwarded user turn to kernel bus (session=%s)",
+        session_id,
+    )
+    return True
+
+
+def _extract_response_text(payload: dict) -> Optional[str]:
+    """Dig the assistant reply out of the bus event payload.
+
+    The kernel's `handleBusSend` wraps the sent `message` string inside a
+    `{"content": "<message>"}` map. On SSE delivery, the envelope's `data`
+    field is that map. We look first for structured keys (`text`, direct
+    agent_response shape), then fall through to parsing `content` as JSON.
+    """
+    if not isinstance(payload, dict):
+        return None
+    # Direct shape (if an upstream producer wrote the event dict at the top level).
+    for key in ("text", "reply", "response"):
+        val = payload.get(key)
+        if isinstance(val, str) and val:
+            return val
+    # Standard bus envelope: payload = {"content": "<json string>"}
+    content = payload.get("content")
+    if isinstance(content, str) and content:
+        try:
+            inner = json.loads(content)
+        except (TypeError, ValueError):
+            # Free-form string — treat the whole thing as the reply.
+            return content
+        if isinstance(inner, dict):
+            for key in ("text", "reply", "response"):
+                val = inner.get(key)
+                if isinstance(val, str) and val:
+                    return val
+        elif isinstance(inner, str) and inner:
+            return inner
+    return None
+
+
+async def run_response_bridge(subscriber: KernelBusSubscriber) -> None:
+    """Consume `subscriber` and broadcast agent replies to dashboard clients.
+
+    `BrowserChannel.broadcast_response_text()` is thread-safe via
+    `run_coroutine_threadsafe`, matching the existing trace-event pattern.
+    Malformed events (no recoverable text) are logged at debug and skipped.
+    """
+    first_event_logged = False
+    forwarded = 0
+    async for env in subscriber.stream():
+        if env.kind == "connected":
+            continue
+        text = _extract_response_text(env.payload)
+        if not text:
+            logger.debug(
+                "cogos-agent: skip event with no text kind=%s id=%s",
+                env.kind, env.event_id,
+            )
+            continue
+        if not first_event_logged:
+            logger.info(
+                "cogos-agent: first response forwarded kind=%s event_id=%s",
+                env.kind, env.event_id,
+            )
+            first_event_logged = True
+        try:
+            BrowserChannel.broadcast_response_text(text)
+            forwarded += 1
+            logger.debug(
+                "cogos-agent: forwarded response event_id=%s (total=%d)",
+                env.event_id, forwarded,
+            )
+        except Exception as exc:  # noqa: BLE001 — best-effort fan-out
+            logger.debug("cogos-agent: broadcast failed: %s", exc)
+
+
+async def start_response_bridge(
+    app_state: object,
+    *,
+    url: str = KERNEL_BUS_STREAM_URL,
+) -> None:
+    """Construct the response subscriber + bridge task and store on `app_state`.
+
+    No-op (logs once) when `MOD3_USE_COGOS_AGENT` is unset.
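+
+    Intended wiring, as a sketch (the FastAPI `app.state` spelling and the
+    startup/shutdown placement are assumptions):
+
+        await start_response_bridge(app.state)  # on startup
+        ...
+        await stop_response_bridge(app.state)   # on shutdown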
+    """
+    if not is_enabled():
+        logger.debug("cogos-agent: response bridge disabled (%s unset)", ENABLE_ENV)
+        setattr(app_state, "cogos_agent_subscriber", None)
+        setattr(app_state, "cogos_agent_task", None)
+        return
+
+    subscriber = KernelBusSubscriber(
+        url=url,
+        bus_filter=RESPONSE_BUS_ID,
+        consumer_id="mod3-dashboard-agent",
+    )
+    task = asyncio.create_task(
+        run_response_bridge(subscriber),
+        name="mod3-cogos-agent-bridge",
+    )
+    setattr(app_state, "cogos_agent_subscriber", subscriber)
+    setattr(app_state, "cogos_agent_task", task)
+    logger.info(
+        "cogos-agent: response bridge started, target=%s bus_id=%s",
+        url, RESPONSE_BUS_ID,
+    )
+
+
+async def stop_response_bridge(app_state: object, *, timeout_s: float = 2.0) -> None:
+    """Gracefully stop the response bridge: close subscriber, await task, cancel on timeout."""
+    subscriber: Optional[KernelBusSubscriber] = getattr(app_state, "cogos_agent_subscriber", None)
+    task: Optional[asyncio.Task] = getattr(app_state, "cogos_agent_task", None)
+    if subscriber is None and task is None:
+        return
+    if subscriber is not None:
+        try:
+            await subscriber.close()
+        except Exception:  # pragma: no cover - best-effort
+            pass
+    if task is not None:
+        try:
+            await asyncio.wait_for(task, timeout=timeout_s)
+        except (asyncio.TimeoutError, asyncio.CancelledError):
+            task.cancel()
+            try:
+                await task
+            except (asyncio.CancelledError, Exception):  # pragma: no cover
+                pass
+    logger.info("cogos-agent: response bridge stopped")
diff --git a/dashboard/index.html b/dashboard/index.html
index 914a2b0..ae166c5 100644
--- a/dashboard/index.html
+++ b/dashboard/index.html
@@ -171,6 +171,49 @@
       50% { opacity: 0.3; }
     }

+    /* Cycle Trace panel (bottom drawer) */
+    #trace-panel {
+      position: fixed; left: 0; right: 0; bottom: 0;
+      background: var(--surface); border-top: 1px solid var(--border);
+      font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
+      z-index: 20; max-height: 40vh; display: flex; flex-direction: column;
+      transition: max-height 0.2s ease;
+    }
+    #trace-panel.collapsed { max-height: 32px; }
+    #trace-panel .trace-header {
+      display: flex; align-items: center; gap: 8px;
+      padding: 6px 16px; border-bottom: 1px solid var(--border);
+      background: var(--bg); cursor: pointer; user-select: none;
+      font-size: 0.75rem; color: var(--muted); text-transform: uppercase;
+      letter-spacing: 0.5px; flex-shrink: 0; height: 32px;
+    }
+    #trace-panel .trace-header .trace-title { font-weight: 600; }
+    #trace-panel .trace-header .trace-toggle {
+      margin-left: auto; font-size: 0.7rem; color: var(--muted);
+    }
+    #trace-panel.collapsed #trace-entries { display: none; }
+    #trace-entries {
+      overflow-y: auto; padding: 4px 0; flex: 1; min-height: 0;
+      font-size: 0.75rem; line-height: 1.4;
+    }
+    .trace-entry {
+      display: flex; gap: 8px; align-items: baseline;
+      padding: 2px 16px; border-bottom: 1px solid rgba(48,54,61,0.3);
+      white-space: nowrap; overflow: hidden;
+    }
+    .trace-entry:hover { background: rgba(88,166,255,0.05); }
+    .trace-time { color: var(--muted); flex-shrink: 0; font-variant-numeric: tabular-nums; }
+    .trace-source { color: var(--muted); flex-shrink: 0; font-size: 0.7rem; }
+    .trace-kind {
+      flex-shrink: 0; font-size: 0.65rem; padding: 1px 6px;
+      border: 1px solid var(--border); border-radius: 3px;
+      text-transform: uppercase; letter-spacing: 0.5px;
+    }
+    .trace-summary { color: var(--text); overflow: hidden; text-overflow: ellipsis; }
+
+    /* Leave room at the bottom so the drawer doesn't cover the input */
+    body { padding-bottom: 32px; }
+
     /* Responsive */
    @media (max-width: 700px) {
      .main { padding: 12px 16px; }
@@ -300,6 +343,15 @@

     <h1>Mod³</h1>

     <p class="hint">For voice, use headphones. Speak naturally — the system detects when you start and stop. Speak during playback to interrupt.</p>
+    <!-- Cycle Trace drawer (see #trace-panel styles above) -->
+    <div id="trace-panel" class="collapsed">
+      <div class="trace-header">
+        <span class="trace-title">Cycle Trace</span>