Merged
2 changes: 1 addition & 1 deletion aai_cli/AGENTS.md
@@ -152,7 +152,7 @@ heavily-reworked commands with long bodies; small commands keep the inline
 - **`core/sync_stt.py`** + **`core/signals.py`** + `commands/dictate/` — `assembly dictate`: headless dictation over the **Sync STT API** (`Environment.sync_base`, one POST `/transcribe` per utterance with the required `X-AAI-Model: u3-sync-pro` header; 80 ms–120 s of PCM/WAV). It needs no terminal: recording starts immediately and `dictate_exec._record` polls `signals.stop_on_terminate` between ~100 ms mic chunks for a SIGTERM, which finishes the utterance (clean exit 0) — so a hotkey tool like Hammerspoon can launch it as a background task and `kill -TERM`/`task:terminate()` to transcribe. SIGINT (Ctrl-C) still cancels (exit 130). Both boundaries (the stop latch, mic, HTTP) are injectable, so the suite never needs a real signal or microphone (`tests/test_dictate_exec.py` scripts the SIGTERM latch). Contrast `signals.terminate_as_interrupt` (used by `stream`/`agent`/`speak`), which routes SIGTERM into the *cancel* path instead.
 - **`agent/`** — full-duplex voice agent (mic in, TTS out via `voices.py`).
 - **`agent_cascade/`** + `commands/agent_cascade/` — `assembly agent-cascade`: the same live terminal conversation as `assembly agent`, but **client-orchestrated** — `engine.run_cascade` wires Streaming STT → the LLM Gateway → streaming TTS itself instead of talking to the Voice Agent endpoint, mirroring what the `agent-cascade` `assembly init` template does server-side. **Sandbox-only** (streaming TTS has no prod host; guarded via `tts.session.require_available`). Reuses the agent slice's `DuplexAudio`/`AgentRenderer` and `core.client.stream_audio`/`core.llm.complete`/`tts.session.synthesize`; the three network legs are injected through `engine.CascadeDeps` (the `tts/session.py` seam) so the cascade — greeting, per-sentence TTS, barge-in, history window — is unit-tested against fakes with no sockets/mic/speaker.
-- **`tts/`** + `commands/speak.py` — `assembly speak` synthesizes text to speech over the sandbox streaming-TTS WebSocket (`streaming-tts.sandbox000.…`). **Sandbox-only:** `session.is_available()` is false in production (empty `Environment.streaming_tts_host`), so the command exits 2 with a `--sandbox` hint. `session.synthesize` drives a Begin→Generate→Flush→Audio→Terminate protocol with an injectable `connect` for hermetic tests (mirrors `agent/session.py`); `audio.py` plays the PCM (default) or writes a WAV (`--out`).
+- **`tts/`** + `commands/speak.py` — `assembly speak` synthesizes text to speech over the sandbox streaming-TTS WebSocket (`streaming-tts.sandbox000.…`). **Sandbox-only:** `session.is_available()` is false in production (empty `Environment.streaming_tts_host`), so the command exits 2 with a `--sandbox` hint. `session.synthesize` drives a Begin→Generate→Flush→Audio→Terminate protocol with an injectable `connect` for hermetic tests (mirrors `agent/session.py`); `audio.py` plays the PCM (default) or writes a WAV (`--out`). The single-voice default-playback path **streams**: `synthesize`'s `on_audio(chunk, sample_rate)` callback is wired to `audio.PcmPlayer.feed`, so speech starts on the first Audio frame (it opens the device lazily, since the rate is only known at Begin) instead of after the whole text — the win for a long `--url` page. `--out` (needs the full buffer) and the multi-voice dialogue path (`synthesize_dialogue` → `_output_audio` → buffered `play_pcm`) stay buffered; `synthesize` still returns the complete PCM for the summary regardless.
 - **`code_gen/`** — backs `--show-code` on `transcribe`/`stream`/`agent`: builds a ready-to-run Python SDK script from exactly the flags passed (no API key needed; generated code reads `ASSEMBLYAI_API_KEY`).
 - **`auth/`** — browser-assisted `assembly login` via AMS + **Stytch B2B OAuth discovery** (`discovery.py`, `flow.py`, `loopback.py`, `ams.py`). Not Stytch Connected Apps.
 - **`init/`** — scaffolds a self-contained FastAPI + HTML starter (`audio-transcription`/`live-captions`/`voice-agent` templates), optionally installs deps and opens the browser; writes the key to a git-ignored `.env`.
21 changes: 16 additions & 5 deletions aai_cli/commands/speak/_exec.py
@@ -145,11 +145,22 @@ def _speak_single(
     cfg = session.SpeakConfig(
         text=text, voice=voice, language=opts.language, sample_rate=opts.sample_rate
     )
-    with output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet):
-        result = session.synthesize(
-            api_key, cfg, on_warning=lambda m: output.emit_warning(m, json_mode=json_mode)
-        )
-    _output_audio(result, opts.out)
+
+    def on_warning(message: str) -> None:
+        output.emit_warning(message, json_mode=json_mode)
+
+    if opts.out is None:
+        # Play each audio frame as it arrives so speech starts on the first frame
+        # rather than after the whole text is synthesized (the long --url page case).
+        with (
+            audio.PcmPlayer() as player,
+            output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet),
+        ):
+            result = session.synthesize(api_key, cfg, on_audio=player.feed, on_warning=on_warning)
+    else:
+        with output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet):
+            result = session.synthesize(api_key, cfg, on_warning=on_warning)
+        audio.write_wav(opts.out, result.pcm, result.sample_rate)
     _emit_single(result, cfg, opts.out, json_mode=json_mode)


103 changes: 71 additions & 32 deletions aai_cli/tts/audio.py
@@ -4,14 +4,14 @@
 import wave
 from collections.abc import Callable
 from pathlib import Path
-from typing import Protocol
+from typing import Literal, Protocol

 from aai_cli.core.errors import CLIError
 from aai_cli.core.microphone import import_sounddevice


 class _OutputStream(Protocol):
-    """The slice of a sounddevice output stream play_pcm drives — named as a
+    """The slice of a sounddevice output stream the player drives — named as a
     Protocol so the untyped library boundary is structurally typed, not opaque."""

     def start(self) -> None:
@@ -67,41 +67,80 @@ def _playback_error(exc: Exception) -> CLIError:
     )


+class PcmPlayer:
+    """An incremental 16-bit mono player: ``feed`` each PCM chunk as it is produced.
+
+    Used as a context manager so audio can start on the *first* chunk while later
+    chunks are still being synthesized (streaming TTS), instead of waiting for the
+    whole clip. The output sample rate isn't known until the server reports it
+    mid-stream, so the device is opened lazily on the first ``feed``. On a normal
+    exit the stream drains; on Ctrl-C (or any error) it is aborted — buffered
+    frames discarded for an immediate stop — and the cancel propagates. Each chunk
+    is written in short pieces so a Ctrl-C lands promptly between writes. A device
+    failure is wrapped in a clean CLIError that points at --out as the headless
+    escape hatch. ``stream_factory`` is injectable for tests.
+    """
+
+    def __init__(self, *, stream_factory: Callable[[int], _OutputStream] | None = None) -> None:
+        self._factory = stream_factory or _default_output_stream
+        self._stream: _OutputStream | None = None
+
+    def __enter__(self) -> PcmPlayer:
+        return self
+
+    def feed(self, pcm: bytes, sample_rate: int) -> None:
+        """Play one PCM chunk, opening the device on the first chunk."""
+        if self._stream is None:
+            self._stream = self._open(sample_rate)
+        self._write(self._stream, pcm)
+
+    def _open(self, sample_rate: int) -> _OutputStream:
+        try:
+            stream = self._factory(sample_rate)
+            stream.start()
+        except CLIError:
+            raise  # audio_missing_error() is already user-facing
+        except Exception as exc:
+            raise _playback_error(exc) from exc
+        return stream
+
+    @staticmethod
+    def _write(stream: _OutputStream, pcm: bytes) -> None:
+        # KeyboardInterrupt (a BaseException) passes through this Exception handler
+        # to __exit__, which aborts the device; only real device errors are wrapped.
+        try:
+            for offset in range(0, len(pcm), _PLAYBACK_CHUNK_BYTES):
+                stream.write(pcm[offset : offset + _PLAYBACK_CHUNK_BYTES])
+        except Exception as exc:
+            raise _playback_error(exc) from exc
+
+    def __exit__(self, exc_type: object, *_: object) -> Literal[False]:  # pragma: no mutate
+        stream = self._stream
+        if stream is not None:
+            try:
+                if exc_type is None:  # normal exit -> drain; an error/Ctrl-C -> abort
+                    stream.stop()
+                else:
+                    # Cut sound immediately (discard buffered frames) instead of
+                    # letting stop() drain the rest, then let the error propagate.
+                    with contextlib.suppress(Exception):
+                        stream.abort()
+            finally:
+                with contextlib.suppress(Exception):
+                    stream.close()
+        return False  # never suppress: Ctrl-C / device errors must reach the CLI
+
+
 def play_pcm(
     pcm: bytes,
     sample_rate: int,
     *,
     stream_factory: Callable[[int], _OutputStream] | None = None,
 ) -> None:
-    """Play 16-bit mono PCM through the default output device (blocks until done).
+    """Play a complete 16-bit mono PCM buffer through the default output device.

-    Audio is written in short chunks so a Ctrl-C interrupts promptly: on
-    KeyboardInterrupt the stream is aborted (buffered frames discarded) for an
-    immediate stop, then the cancel propagates. ``stream_factory`` is injectable
-    for tests; a device failure is wrapped in a clean CLIError that points at
-    --out as the headless escape hatch.
+    A thin convenience over ``PcmPlayer`` for callers that already hold the whole
+    clip (the multi-voice dialogue path); it blocks until playback finishes.
     """
-    factory = stream_factory or _default_output_stream
-    try:
-        stream = factory(sample_rate)
-    except CLIError:
-        raise  # audio_missing_error() is already user-facing
-    except Exception as exc:
-        raise _playback_error(exc) from exc
-
-    try:
-        stream.start()
-        for offset in range(0, len(pcm), _PLAYBACK_CHUNK_BYTES):
-            stream.write(pcm[offset : offset + _PLAYBACK_CHUNK_BYTES])
-        stream.stop()
-    except KeyboardInterrupt:
-        # Cut sound immediately (discard whatever is still buffered in the device)
-        # instead of letting stop() drain the rest, then propagate the cancel.
-        with contextlib.suppress(Exception):
-            stream.abort()
-        raise
-    except Exception as exc:
-        raise _playback_error(exc) from exc
-    finally:
-        with contextlib.suppress(Exception):
-            stream.close()
+    with PcmPlayer(stream_factory=stream_factory) as player:
+        player.feed(pcm, sample_rate)
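
The lifecycle that the `PcmPlayer` docstring describes (open lazily on the first chunk, drain on a clean exit, abort on an error or Ctrl-C) can be exercised without audio hardware. The sketch below is an illustration under stated assumptions, not the code from this PR: `RecordingStream` and `LazyPlayer` are hypothetical names, and the fake stream only mirrors the `start`/`write`/`stop`/`abort`/`close` calls visible in the diff.

```python
import contextlib
from typing import Callable, Optional


class RecordingStream:
    """Hypothetical stand-in for an output stream: logs calls instead of playing."""

    def __init__(self, sample_rate: int) -> None:
        self.sample_rate = sample_rate
        self.calls = []

    def start(self) -> None:
        self.calls.append("start")

    def write(self, data: bytes) -> None:
        self.calls.append(f"write:{len(data)}")

    def stop(self) -> None:
        self.calls.append("stop")

    def abort(self) -> None:
        self.calls.append("abort")

    def close(self) -> None:
        self.calls.append("close")


class LazyPlayer:
    """Opens the stream on the first feed(), when the sample rate is first known."""

    def __init__(self, factory: Callable[[int], RecordingStream], chunk_bytes: int = 4) -> None:
        self._factory = factory
        self._chunk_bytes = chunk_bytes  # assumed small value, for the demo only
        self.stream: Optional[RecordingStream] = None

    def __enter__(self) -> "LazyPlayer":
        return self

    def feed(self, pcm: bytes, sample_rate: int) -> None:
        if self.stream is None:  # first chunk: open and start the device
            self.stream = self._factory(sample_rate)
            self.stream.start()
        # Short writes keep the gap between interrupt checks small.
        for offset in range(0, len(pcm), self._chunk_bytes):
            self.stream.write(pcm[offset : offset + self._chunk_bytes])

    def __exit__(self, exc_type: object, *_: object) -> bool:
        if self.stream is not None:
            try:
                if exc_type is None:
                    self.stream.stop()  # clean exit: let buffered audio drain
                else:
                    with contextlib.suppress(Exception):
                        self.stream.abort()  # error or Ctrl-C: cut sound now
            finally:
                with contextlib.suppress(Exception):
                    self.stream.close()
        return False  # never swallow the exception


def demo():
    # Clean run: one 6-byte chunk is split into a 4-byte and a 2-byte write.
    with LazyPlayer(RecordingStream) as ok:
        ok.feed(b"\x00" * 6, 24000)
    # Interrupted run: the simulated Ctrl-C propagates out of the player.
    failed = LazyPlayer(RecordingStream)
    with contextlib.suppress(KeyboardInterrupt):
        with failed:
            failed.feed(b"\x00" * 2, 24000)
            raise KeyboardInterrupt
    return ok.stream.calls, failed.stream.calls
```

Calling `demo()` returns the two call logs, which makes the drain-versus-abort split directly assertable in a unit test; the same idea is what the injectable `stream_factory` in the diff enables.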
34 changes: 29 additions & 5 deletions aai_cli/tts/session.py
@@ -130,6 +130,21 @@ def _decode_audio_frame(msg: dict[str, object]) -> bytes:
         raise APIError(f"TTS service sent an Audio frame that is not valid base64: {exc}") from exc


+def _consume_audio_frame(
+    msg: dict[str, object],
+    pcm: bytearray,
+    sample_rate: int,
+    on_audio: Callable[[bytes, int], None] | None,
+) -> bool:
+    """Append one Audio frame's PCM, stream it to ``on_audio``, and report whether
+    it is the final frame (so the collection loop can stop)."""
+    chunk = _decode_audio_frame(msg)
+    pcm.extend(chunk)
+    if on_audio is not None:
+        on_audio(chunk, sample_rate)
+    return bool(msg.get("is_final"))
+
+
 def _error_frame_detail(msg: dict[str, object]) -> str:
     """The ``(code): reason`` tail of an Error frame, with an ``unknown`` reason
     fallback so a detail-less frame still yields a non-empty message."""
@@ -158,9 +173,16 @@ def _default_connect(


 def _run_protocol(
-    ws: _WebSocket, config: SpeakConfig, on_warning: Callable[[str], None] | None
+    ws: _WebSocket,
+    config: SpeakConfig,
+    on_warning: Callable[[str], None] | None,
+    on_audio: Callable[[bytes, int], None] | None,
 ) -> SpeakResult:
-    """Send Generate + Flush, collect Audio until is_final, then Terminate."""
+    """Send Generate + Flush, collect Audio until is_final, then Terminate.
+
+    Each decoded Audio chunk is handed to ``on_audio(chunk, sample_rate)`` (when
+    given) as it arrives, so a caller can play it immediately instead of waiting
+    for the whole synthesis; the full PCM is still accumulated and returned."""
     begin = json.loads(_recv_raw(ws))
     begin_type = begin.get("type")
     if begin_type == "Error":
@@ -182,8 +204,7 @@ def _run_protocol(
         msg = json.loads(_recv_raw(ws))
         mtype = msg.get("type")
         if mtype == "Audio":
-            pcm.extend(_decode_audio_frame(msg))
-            if msg.get("is_final"):
+            if _consume_audio_frame(msg, pcm, sample_rate, on_audio):
                 break
         elif mtype == "FlushDone":
             # The live server ends a synthesis with FlushDone (its Audio frames carry
@@ -207,10 +228,13 @@ def synthesize(
     *,
     connect: _Connect | None = None,
     on_warning: Callable[[str], None] | None = None,
+    on_audio: Callable[[bytes, int], None] | None = None,
 ) -> SpeakResult:
     """Open the streaming-TTS socket and synthesize ``config.text`` to PCM.

     ``connect`` defaults to websockets' synchronous client; injectable for tests.
+    ``on_audio(chunk, sample_rate)``, when given, receives each PCM chunk as it
+    arrives (for incremental playback); the full PCM is still returned regardless.
     Connect/session failures map to a clean CLIError (a rejected key -> exit 4).
     """
     wsutil.silence_websockets_logging()
@@ -230,7 +254,7 @@ def synthesize(
         max_size=None,  # no frame cap: a synthesis's Audio frames can exceed the 1 MiB default
     )
     try:
-        return _run_protocol(ws, config, on_warning)
+        return _run_protocol(ws, config, on_warning, on_audio)
     except (CLIError, KeyboardInterrupt, BrokenPipeError):
         raise  # clean CLI errors, Ctrl-C, and a closed pipe are handled upstream
     except Exception as exc:
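
The accumulate-and-forward pattern that `_consume_audio_frame` adds can be tried in isolation with hand-built frames. This is a rough sketch, not the session code itself: `consume_frames` is a hypothetical helper, the `"audio"` key is an assumed field name (the diff only shows that Audio frames carry base64 and an `is_final` flag), and other frame types such as `FlushDone` are simply skipped here.

```python
import base64
from typing import Callable, Optional


def consume_frames(
    frames: list,
    sample_rate: int,
    on_audio: Optional[Callable[[bytes, int], None]] = None,
) -> bytes:
    """Accumulate PCM from Audio frames, forwarding each chunk as it is decoded."""
    pcm = bytearray()
    for msg in frames:
        if msg.get("type") != "Audio":
            continue  # this sketch ignores every other frame type
        chunk = base64.b64decode(msg["audio"])  # "audio" is an assumed field name
        pcm.extend(chunk)  # the full buffer is still built for the caller
        if on_audio is not None:
            on_audio(chunk, sample_rate)  # e.g. a player's feed method
        if msg.get("is_final"):
            break  # stop collecting once the final frame is seen
    return bytes(pcm)


def demo():
    def frame(data: bytes, is_final: bool) -> dict:
        return {"type": "Audio", "audio": base64.b64encode(data).decode(), "is_final": is_final}

    frames = [
        frame(b"\x01\x02", False),
        frame(b"\x03\x04", True),
        frame(b"\x05", False),  # arrives after the final frame, so it is never read
    ]
    seen = []
    full = consume_frames(frames, 24000, on_audio=lambda chunk, rate: seen.append((chunk, rate)))
    return full, seen
```

Because the callback is optional and the buffer is returned either way, a caller that needs the whole clip (the `--out` path in the diff) can omit `on_audio` and still get identical PCM.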
66 changes: 56 additions & 10 deletions tests/test_speak.py
@@ -3,6 +3,7 @@
 import json
 import re
 import signal
+from typing import Literal

 import pytest
 from typer.testing import CliRunner
@@ -23,11 +24,28 @@ def _fake_key(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setattr(config, "resolve_api_key", lambda **_: "test-key")


+class _FakePlayer:
+    """A no-op stand-in for audio.PcmPlayer: records the chunks fed to it instead
+    of opening a real output device (none exists on a headless CI box)."""
+
+    def __init__(self) -> None:
+        self.fed: list[tuple[bytes, int]] = []
+
+    def __enter__(self) -> _FakePlayer:
+        return self
+
+    def __exit__(self, *_exc: object) -> Literal[False]:
+        return False
+
+    def feed(self, pcm: bytes, sample_rate: int) -> None:
+        self.fed.append((pcm, sample_rate))
+
+
 @pytest.fixture
 def fake_synthesize(monkeypatch: pytest.MonkeyPatch):
     calls: dict[str, object] = {}

-    def _fake(api_key, cfg, *, connect=None, on_warning=None):
+    def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None):
         calls["api_key"] = api_key
         calls["cfg"] = cfg
         return session.SpeakResult(
@@ -49,23 +67,51 @@ def test_production_env_is_rejected_with_sandbox_hint():
     assert "before the command" in " ".join(result.output.split())


-def test_plays_audio_by_default(monkeypatch, fake_synthesize):
-    played: dict = {}
-    monkeypatch.setattr(
-        "aai_cli.commands.speak._exec.audio.play_pcm",
-        lambda pcm, rate, **_: played.update(pcm=pcm, rate=rate),
-    )
+def test_plays_audio_by_default(monkeypatch):
+    # Default (no --out): each audio frame is streamed to the speaker as it arrives,
+    # via PcmPlayer.feed, rather than buffered and played at the end. The fake
+    # synthesize plays one frame back through the wired on_audio callback.
+    player = _FakePlayer()
+    monkeypatch.setattr("aai_cli.commands.speak._exec.audio.PcmPlayer", lambda **_: player)
+    captured: list[session.SpeakConfig] = []
+
+    def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None):
+        captured.append(cfg)
+        assert on_audio is not None
+        on_audio(b"\x01\x02\x03\x04", 24000)  # the server emitting one Audio frame
+        return session.SpeakResult(
+            pcm=b"\x01\x02\x03\x04", sample_rate=24000, audio_duration_seconds=0.1
+        )
+
+    monkeypatch.setattr(session, "synthesize", _fake)
     result = runner.invoke(app, ["--sandbox", "speak", "Hello there"])
     assert result.exit_code == 0
-    assert played == {"pcm": b"\x01\x02\x03\x04", "rate": 24000}
-    assert fake_synthesize["cfg"].text == "Hello there"
+    # The frame reached the speaker (chunk + the server's reported rate), proving the
+    # on_audio=player.feed wiring — not a buffered play_pcm at the end.
+    assert player.fed == [(b"\x01\x02\x03\x04", 24000)]
+    assert captured[0].text == "Hello there"
     # No --voice given -> single-voice path falls back to the default "jane".
-    assert fake_synthesize["cfg"].voice == "jane"
+    assert captured[0].voice == "jane"
     # Human summary (stderr) reports the default "played" disposition.
     assert "played" in result.stderr
     assert "saved to" not in result.stderr

+
+def test_single_voice_server_warning_is_surfaced(monkeypatch):
+    # A Warning frame during single-voice synthesis is surfaced through the wired
+    # on_warning callback; in --json mode it ships as its own {"warning": …} object.
+    def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None):
+        assert on_warning is not None
+        on_warning("slow synthesis")
+        return session.SpeakResult(pcm=b"", sample_rate=24000, audio_duration_seconds=0.0)
+
+    monkeypatch.setattr(session, "synthesize", _fake)
+    result = runner.invoke(app, ["--sandbox", "speak", "Hi", "--json"])
+    assert result.exit_code == 0
+    warning = next(json.loads(line) for line in result.stderr.splitlines() if line.startswith("{"))
+    assert warning["warning"] == "slow synthesis"
+

 def test_out_writes_wav_and_does_not_play(monkeypatch, tmp_path, fake_synthesize):
     monkeypatch.setattr(
         "aai_cli.commands.speak._exec.audio.play_pcm",