diff --git a/aai_cli/AGENTS.md b/aai_cli/AGENTS.md index 6ee011b6..35131c33 100644 --- a/aai_cli/AGENTS.md +++ b/aai_cli/AGENTS.md @@ -152,7 +152,7 @@ heavily-reworked commands with long bodies; small commands keep the inline - **`core/sync_stt.py`** + **`core/signals.py`** + `commands/dictate/` — `assembly dictate`: headless dictation over the **Sync STT API** (`Environment.sync_base`, one POST `/transcribe` per utterance with the required `X-AAI-Model: u3-sync-pro` header; 80 ms–120 s of PCM/WAV). It needs no terminal: recording starts immediately and `dictate_exec._record` polls `signals.stop_on_terminate` between ~100 ms mic chunks for a SIGTERM, which finishes the utterance (clean exit 0) — so a hotkey tool like Hammerspoon can launch it as a background task and `kill -TERM`/`task:terminate()` to transcribe. SIGINT (Ctrl-C) still cancels (exit 130). Both boundaries (the stop latch, mic, HTTP) are injectable, so the suite never needs a real signal or microphone (`tests/test_dictate_exec.py` scripts the SIGTERM latch). Contrast `signals.terminate_as_interrupt` (used by `stream`/`agent`/`speak`), which routes SIGTERM into the *cancel* path instead. - **`agent/`** — full-duplex voice agent (mic in, TTS out via `voices.py`). - **`agent_cascade/`** + `commands/agent_cascade/` — `assembly agent-cascade`: the same live terminal conversation as `assembly agent`, but **client-orchestrated** — `engine.run_cascade` wires Streaming STT → the LLM Gateway → streaming TTS itself instead of talking to the Voice Agent endpoint, mirroring what the `agent-cascade` `assembly init` template does server-side. **Sandbox-only** (streaming TTS has no prod host; guarded via `tts.session.require_available`). 
Reuses the agent slice's `DuplexAudio`/`AgentRenderer` and `core.client.stream_audio`/`core.llm.complete`/`tts.session.synthesize`; the three network legs are injected through `engine.CascadeDeps` (the `tts/session.py` seam) so the cascade — greeting, per-sentence TTS, barge-in, history window — is unit-tested against fakes with no sockets/mic/speaker. -- **`tts/`** + `commands/speak.py` — `assembly speak` synthesizes text to speech over the sandbox streaming-TTS WebSocket (`streaming-tts.sandbox000.…`). **Sandbox-only:** `session.is_available()` is false in production (empty `Environment.streaming_tts_host`), so the command exits 2 with a `--sandbox` hint. `session.synthesize` drives a Begin→Generate→Flush→Audio→Terminate protocol with an injectable `connect` for hermetic tests (mirrors `agent/session.py`); `audio.py` plays the PCM (default) or writes a WAV (`--out`). +- **`tts/`** + `commands/speak.py` — `assembly speak` synthesizes text to speech over the sandbox streaming-TTS WebSocket (`streaming-tts.sandbox000.…`). **Sandbox-only:** `session.is_available()` is false in production (empty `Environment.streaming_tts_host`), so the command exits 2 with a `--sandbox` hint. `session.synthesize` drives a Begin→Generate→Flush→Audio→Terminate protocol with an injectable `connect` for hermetic tests (mirrors `agent/session.py`); `audio.py` plays the PCM (default) or writes a WAV (`--out`). The single-voice default-playback path **streams**: `synthesize`'s `on_audio(chunk, sample_rate)` callback is wired to `audio.PcmPlayer.feed`, so speech starts on the first Audio frame (it opens the device lazily, since the rate is only known at Begin) instead of after the whole text — the win for a long `--url` page. `--out` (needs the full buffer) and the multi-voice dialogue path (`synthesize_dialogue` → `_output_audio` → buffered `play_pcm`) stay buffered; `synthesize` still returns the complete PCM for the summary regardless. 
- **`code_gen/`** — backs `--show-code` on `transcribe`/`stream`/`agent`: builds a ready-to-run Python SDK script from exactly the flags passed (no API key needed; generated code reads `ASSEMBLYAI_API_KEY`). - **`auth/`** — browser-assisted `assembly login` via AMS + **Stytch B2B OAuth discovery** (`discovery.py`, `flow.py`, `loopback.py`, `ams.py`). Not Stytch Connected Apps. - **`init/`** — scaffolds a self-contained FastAPI + HTML starter (`audio-transcription`/`live-captions`/`voice-agent` templates), optionally installs deps and opens the browser; writes the key to a git-ignored `.env`. diff --git a/aai_cli/commands/speak/_exec.py b/aai_cli/commands/speak/_exec.py index 6badac86..57ece68e 100644 --- a/aai_cli/commands/speak/_exec.py +++ b/aai_cli/commands/speak/_exec.py @@ -145,11 +145,22 @@ def _speak_single( cfg = session.SpeakConfig( text=text, voice=voice, language=opts.language, sample_rate=opts.sample_rate ) - with output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet): - result = session.synthesize( - api_key, cfg, on_warning=lambda m: output.emit_warning(m, json_mode=json_mode) - ) - _output_audio(result, opts.out) + + def on_warning(message: str) -> None: + output.emit_warning(message, json_mode=json_mode) + + if opts.out is None: + # Play each audio frame as it arrives so speech starts on the first frame + # rather than after the whole text is synthesized (the long --url page case). 
class PcmPlayer:
    """An incremental 16-bit mono player: ``feed`` each PCM chunk as it is produced.

    Used as a context manager so audio can start on the *first* chunk while later
    chunks are still being synthesized (streaming TTS), instead of waiting for the
    whole clip. The output sample rate isn't known until the server reports it
    mid-stream, so the device is opened lazily on the first ``feed``; later feeds
    reuse that stream and their ``sample_rate`` argument is not consulted again.

    On a normal exit the stream drains (``stop``); on Ctrl-C (or any error) it is
    aborted — buffered frames discarded for an immediate stop — and the cancel
    propagates. Each chunk is written in short pieces so a Ctrl-C lands promptly
    between writes. Any device failure (opening, starting, writing, or draining)
    is wrapped in a clean CLIError via ``_playback_error``. A stream that was
    created but failed to start is closed before the error is raised, so the
    device is never leaked. ``stream_factory`` is injectable for tests.
    """

    def __init__(self, *, stream_factory: Callable[[int], _OutputStream] | None = None) -> None:
        self._factory = stream_factory or _default_output_stream
        # None until the first feed() opens the device.
        self._stream: _OutputStream | None = None

    def __enter__(self) -> PcmPlayer:
        return self

    def feed(self, pcm: bytes, sample_rate: int) -> None:
        """Play one PCM chunk, opening the device on the first chunk."""
        if self._stream is None:
            self._stream = self._open(sample_rate)
        self._write(self._stream, pcm)

    def _open(self, sample_rate: int) -> _OutputStream:
        """Create and start the output stream, returning it ready for writes.

        Raises a CLIError when the device cannot be created or started. If the
        stream was created but ``start`` fails, it is closed first: the player does
        not own it yet, so ``__exit__`` would never get the chance to release it.
        """
        try:
            stream = self._factory(sample_rate)
        except CLIError:
            raise  # audio_missing_error() is already user-facing
        except Exception as exc:
            raise _playback_error(exc) from exc

        try:
            stream.start()
        except BaseException as exc:
            # Release the half-opened device whatever went wrong (Ctrl-C included).
            with contextlib.suppress(Exception):
                stream.close()
            # Only real device errors are wrapped; a CLIError or a BaseException
            # such as KeyboardInterrupt is re-raised untouched.
            if isinstance(exc, Exception) and not isinstance(exc, CLIError):
                raise _playback_error(exc) from exc
            raise
        return stream

    @staticmethod
    def _write(stream: _OutputStream, pcm: bytes) -> None:
        """Write ``pcm`` to ``stream`` in ``_PLAYBACK_CHUNK_BYTES``-sized pieces."""
        # KeyboardInterrupt (a BaseException) passes through this Exception handler
        # to __exit__, which aborts the device; only real device errors are wrapped.
        try:
            for offset in range(0, len(pcm), _PLAYBACK_CHUNK_BYTES):
                stream.write(pcm[offset : offset + _PLAYBACK_CHUNK_BYTES])
        except Exception as exc:
            raise _playback_error(exc) from exc

    def __exit__(self, exc_type: object, *_: object) -> Literal[False]:  # pragma: no mutate
        """Drain on a clean exit, abort on an error, and always close the device."""
        stream = self._stream
        if stream is None:
            return False  # nothing was ever fed, so no device was opened

        try:
            if exc_type is None:
                # Normal exit -> drain. A device failure while draining is a
                # playback failure like any other, so it gets the same CLIError.
                try:
                    stream.stop()
                except Exception as exc:
                    raise _playback_error(exc) from exc
            else:
                # Cut sound immediately (discard buffered frames) instead of
                # letting stop() drain the rest, then let the error propagate.
                with contextlib.suppress(Exception):
                    stream.abort()
        finally:
            with contextlib.suppress(Exception):
                stream.close()
        return False  # never suppress: Ctrl-C / device errors must reach the CLI
+ """Play a complete 16-bit mono PCM buffer through the default output device. - Audio is written in short chunks so a Ctrl-C interrupts promptly: on - KeyboardInterrupt the stream is aborted (buffered frames discarded) for an - immediate stop, then the cancel propagates. ``stream_factory`` is injectable - for tests; a device failure is wrapped in a clean CLIError that points at - --out as the headless escape hatch. + A thin convenience over ``PcmPlayer`` for callers that already hold the whole + clip (the multi-voice dialogue path); it blocks until playback finishes. """ - factory = stream_factory or _default_output_stream - try: - stream = factory(sample_rate) - except CLIError: - raise # audio_missing_error() is already user-facing - except Exception as exc: - raise _playback_error(exc) from exc - - try: - stream.start() - for offset in range(0, len(pcm), _PLAYBACK_CHUNK_BYTES): - stream.write(pcm[offset : offset + _PLAYBACK_CHUNK_BYTES]) - stream.stop() - except KeyboardInterrupt: - # Cut sound immediately (discard whatever is still buffered in the device) - # instead of letting stop() drain the rest, then propagate the cancel. 
def _consume_audio_frame(
    msg: dict[str, object],
    pcm: bytearray,
    sample_rate: int,
    on_audio: Callable[[bytes, int], None] | None,
) -> bool:
    """Fold one Audio frame into the running synthesis.

    The frame's decoded PCM is appended to ``pcm`` and, when a listener is given,
    passed to ``on_audio(chunk, sample_rate)``. Returns True when the frame's
    ``is_final`` field is truthy, which tells the collection loop to stop.
    """
    decoded = _decode_audio_frame(msg)
    pcm += decoded  # in-place append on the caller's bytearray
    is_last = bool(msg.get("is_final"))
    if on_audio is None:
        return is_last
    on_audio(decoded, sample_rate)
    return is_last
+ + Each decoded Audio chunk is handed to ``on_audio(chunk, sample_rate)`` (when + given) as it arrives, so a caller can play it immediately instead of waiting + for the whole synthesis; the full PCM is still accumulated and returned.""" begin = json.loads(_recv_raw(ws)) begin_type = begin.get("type") if begin_type == "Error": @@ -182,8 +204,7 @@ def _run_protocol( msg = json.loads(_recv_raw(ws)) mtype = msg.get("type") if mtype == "Audio": - pcm.extend(_decode_audio_frame(msg)) - if msg.get("is_final"): + if _consume_audio_frame(msg, pcm, sample_rate, on_audio): break elif mtype == "FlushDone": # The live server ends a synthesis with FlushDone (its Audio frames carry @@ -207,10 +228,13 @@ def synthesize( *, connect: _Connect | None = None, on_warning: Callable[[str], None] | None = None, + on_audio: Callable[[bytes, int], None] | None = None, ) -> SpeakResult: """Open the streaming-TTS socket and synthesize ``config.text`` to PCM. ``connect`` defaults to websockets' synchronous client; injectable for tests. + ``on_audio(chunk, sample_rate)``, when given, receives each PCM chunk as it + arrives (for incremental playback); the full PCM is still returned regardless. Connect/session failures map to a clean CLIError (a rejected key -> exit 4). 
""" wsutil.silence_websockets_logging() @@ -230,7 +254,7 @@ def synthesize( max_size=None, # no frame cap: a synthesis's Audio frames can exceed the 1 MiB default ) try: - return _run_protocol(ws, config, on_warning) + return _run_protocol(ws, config, on_warning, on_audio) except (CLIError, KeyboardInterrupt, BrokenPipeError): raise # clean CLI errors, Ctrl-C, and a closed pipe are handled upstream except Exception as exc: diff --git a/tests/test_speak.py b/tests/test_speak.py index a118d37c..2c8f95ba 100644 --- a/tests/test_speak.py +++ b/tests/test_speak.py @@ -3,6 +3,7 @@ import json import re import signal +from typing import Literal import pytest from typer.testing import CliRunner @@ -23,11 +24,28 @@ def _fake_key(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(config, "resolve_api_key", lambda **_: "test-key") +class _FakePlayer: + """A no-op stand-in for audio.PcmPlayer: records the chunks fed to it instead + of opening a real output device (none exists on a headless CI box).""" + + def __init__(self) -> None: + self.fed: list[tuple[bytes, int]] = [] + + def __enter__(self) -> _FakePlayer: + return self + + def __exit__(self, *_exc: object) -> Literal[False]: + return False + + def feed(self, pcm: bytes, sample_rate: int) -> None: + self.fed.append((pcm, sample_rate)) + + @pytest.fixture def fake_synthesize(monkeypatch: pytest.MonkeyPatch): calls: dict[str, object] = {} - def _fake(api_key, cfg, *, connect=None, on_warning=None): + def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None): calls["api_key"] = api_key calls["cfg"] = cfg return session.SpeakResult( @@ -49,23 +67,51 @@ def test_production_env_is_rejected_with_sandbox_hint(): assert "before the command" in " ".join(result.output.split()) -def test_plays_audio_by_default(monkeypatch, fake_synthesize): - played: dict = {} - monkeypatch.setattr( - "aai_cli.commands.speak._exec.audio.play_pcm", - lambda pcm, rate, **_: played.update(pcm=pcm, rate=rate), - ) +def 
test_plays_audio_by_default(monkeypatch): + # Default (no --out): each audio frame is streamed to the speaker as it arrives, + # via PcmPlayer.feed, rather than buffered and played at the end. The fake + # synthesize plays one frame back through the wired on_audio callback. + player = _FakePlayer() + monkeypatch.setattr("aai_cli.commands.speak._exec.audio.PcmPlayer", lambda **_: player) + captured: list[session.SpeakConfig] = [] + + def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None): + captured.append(cfg) + assert on_audio is not None + on_audio(b"\x01\x02\x03\x04", 24000) # the server emitting one Audio frame + return session.SpeakResult( + pcm=b"\x01\x02\x03\x04", sample_rate=24000, audio_duration_seconds=0.1 + ) + + monkeypatch.setattr(session, "synthesize", _fake) result = runner.invoke(app, ["--sandbox", "speak", "Hello there"]) assert result.exit_code == 0 - assert played == {"pcm": b"\x01\x02\x03\x04", "rate": 24000} - assert fake_synthesize["cfg"].text == "Hello there" + # The frame reached the speaker (chunk + the server's reported rate), proving the + # on_audio=player.feed wiring — not a buffered play_pcm at the end. + assert player.fed == [(b"\x01\x02\x03\x04", 24000)] + assert captured[0].text == "Hello there" # No --voice given -> single-voice path falls back to the default "jane". - assert fake_synthesize["cfg"].voice == "jane" + assert captured[0].voice == "jane" # Human summary (stderr) reports the default "played" disposition. assert "played" in result.stderr assert "saved to" not in result.stderr +def test_single_voice_server_warning_is_surfaced(monkeypatch): + # A Warning frame during single-voice synthesis is surfaced through the wired + # on_warning callback; in --json mode it ships as its own {"warning": …} object. 
+ def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None): + assert on_warning is not None + on_warning("slow synthesis") + return session.SpeakResult(pcm=b"", sample_rate=24000, audio_duration_seconds=0.0) + + monkeypatch.setattr(session, "synthesize", _fake) + result = runner.invoke(app, ["--sandbox", "speak", "Hi", "--json"]) + assert result.exit_code == 0 + warning = next(json.loads(line) for line in result.stderr.splitlines() if line.startswith("{")) + assert warning["warning"] == "slow synthesis" + + def test_out_writes_wav_and_does_not_play(monkeypatch, tmp_path, fake_synthesize): monkeypatch.setattr( "aai_cli.commands.speak._exec.audio.play_pcm", diff --git a/tests/test_tts_audio.py b/tests/test_tts_audio.py index fdde061a..93c6095d 100644 --- a/tests/test_tts_audio.py +++ b/tests/test_tts_audio.py @@ -118,6 +118,80 @@ def _missing(_rate: int): assert "Could not play audio" not in excinfo.value.message +def test_pcm_player_opens_device_once_and_drains_on_exit(): + # Streaming playback: the device is opened lazily on the FIRST feed and reused + # for every later chunk (one "start", not one per chunk), then drained (stop) + # and closed on a normal exit. The factory is called exactly once. 
+ streams: list[FakeStream] = [] + + def _factory(rate: int) -> FakeStream: + streams.append(FakeStream()) + return streams[-1] + + with audio.PcmPlayer(stream_factory=_factory) as player: + player.feed(b"\x01\x02", 24000) + player.feed(b"\x03\x04", 24000) + assert len(streams) == 1 # opened once, not re-opened per chunk + stream = streams[0] + assert stream.events == ["start", "write", "write", "stop", "close"] + assert stream.written == b"\x01\x02\x03\x04" + + +def test_pcm_player_uses_the_first_feeds_sample_rate(): + captured: dict[str, int] = {} + + def _factory(rate: int) -> FakeStream: + captured["rate"] = rate + return FakeStream() + + with audio.PcmPlayer(stream_factory=_factory) as player: + player.feed(b"\x01\x02", 16000) + assert captured["rate"] == 16000 # the device is opened at the reported rate + + +def test_pcm_player_feed_writes_in_bounded_chunks(): + stream = FakeStream() + pcm = bytes(range(256)) * 40 # 10240 bytes > 2 * chunk + with audio.PcmPlayer(stream_factory=lambda rate: stream) as player: + player.feed(pcm, 24000) + assert [len(c) for c in stream.writes] == [4096, 4096, 2048] + assert b"".join(stream.writes) == pcm + + +def test_pcm_player_aborts_and_propagates_when_a_chunk_fails(): + # An error while playing a chunk maps to a clean CLIError; the device is aborted + # (not drained) and still closed. + stream = FakeStream(raise_on_write=RuntimeError("device fell over")) + with pytest.raises(CLIError, match="Could not play audio"): + with audio.PcmPlayer(stream_factory=lambda rate: stream) as player: + player.feed(b"\x01\x02", 24000) + assert "abort" in stream.events + assert "stop" not in stream.events + assert stream.events[-1] == "close" + + +def test_pcm_player_without_any_feed_is_a_clean_noop(): + # No audio ever arrived (e.g. an empty synthesis): the device was never opened, + # so a clean exit touches nothing and does not crash. 
def test_pcm_player_propagates_a_body_error_when_never_opened():
    # The body fails before any feed, so no device is ever opened; the player
    # must still let the error through rather than swallow it.
    player = audio.PcmPlayer(stream_factory=lambda rate: FakeStream())
    with pytest.raises(ValueError, match="boom"), player:
        raise ValueError("boom")
def test_synthesize_without_on_audio_still_returns_full_pcm():
    # on_audio is optional: leaving it out must not alter the buffered result.
    frames = [_begin_frame(), _audio_frame(b"\x01\x02", final=True)]
    socket = FakeWS(frames)

    def _connect(*args, **kwargs):
        return socket

    outcome = session.synthesize("k", session.SpeakConfig(text="hi"), connect=_connect)
    assert outcome.pcm == b"\x01\x02"