Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions aai_cli/commands/speak/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,12 @@ def on_warning(message: str) -> None:
audio.PcmPlayer() as player,
output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet),
):
result = session.synthesize(api_key, cfg, on_audio=player.feed, on_warning=on_warning)
result = session.synthesize_chunked(
api_key, cfg, on_audio=player.feed, on_warning=on_warning
)
else:
with output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet):
result = session.synthesize(api_key, cfg, on_warning=on_warning)
result = session.synthesize_chunked(api_key, cfg, on_warning=on_warning)
audio.write_wav(opts.out, result.pcm, result.sample_rate)
_emit_single(result, cfg, opts.out, json_mode=json_mode)

Expand Down
63 changes: 57 additions & 6 deletions aai_cli/tts/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import json
from abc import abstractmethod
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import Protocol
from urllib.parse import urlencode

from aai_cli.core import environments
from aai_cli.core import ws as wsutil
from aai_cli.core.errors import APIError, CLIError
from aai_cli.streaming import diagnostics
from aai_cli.tts import audio
from aai_cli.tts import audio, text


class _WebSocket(Protocol):
Expand Down Expand Up @@ -164,12 +164,21 @@ def _recv_raw(ws: _WebSocket) -> str | bytes:


def _default_connect(
url: str, *, additional_headers: dict[str, str], max_size: int | None
url: str,
*,
additional_headers: dict[str, str],
max_size: int | None,
ping_timeout: float | None,
) -> _WebSocket:
"""The real websockets sync client, imported lazily so tests can inject a fake."""
from websockets.sync.client import connect

return connect(url, additional_headers=additional_headers, max_size=max_size)
return connect(
url,
additional_headers=additional_headers,
max_size=max_size,
ping_timeout=ping_timeout,
)


def _run_protocol(
Expand Down Expand Up @@ -252,6 +261,13 @@ def synthesize(
# Bearer token upgrades fine but is rejected in-band as an Error frame.
bearer=False,
max_size=None, # no frame cap: a synthesis's Audio frames can exceed the 1 MiB default
# Disable websockets' keepalive pong deadline. A long input (notably a --url page or
# PDF, sent as one Generate frame) can leave the server silent for longer than the
# default 20s ping_timeout before it streams the first Audio frame; websockets would
# then close the still-alive connection itself with code 1011 ("keepalive ping
# timeout"). _RECV_TIMEOUT_SECONDS already bounds server silence per frame, so it is
# the single liveness authority — pings still flow, but a slow pong no longer kills us.
ping_timeout=None,
)
try:
return _run_protocol(ws, config, on_warning, on_audio)
Expand All @@ -264,6 +280,39 @@ def synthesize(
ws.close()


def synthesize_chunked(
    api_key: str,
    config: SpeakConfig,
    *,
    connect: _Connect | None = None,
    on_warning: Callable[[str], None] | None = None,
    on_audio: Callable[[bytes, int], None] | None = None,
) -> SpeakResult:
    """Synthesize ``config.text`` chunk by chunk and return the joined audio.

    The text is broken up by ``text.chunk_text`` and every chunk is sent through
    its own ``synthesize`` call (one streaming-TTS connection each) using a copy
    of ``config`` whose ``text`` is just that chunk. PocketTTS is a streaming model
    meant to be fed incrementally; a whole document in a single Generate frame
    stalls the server, which then misses the keepalive ping and the socket closes
    with a 1011. Small chunks avoid that and let playback start on the first one.

    ``connect``, ``on_warning`` and ``on_audio`` are forwarded unchanged to every
    ``synthesize`` call, so ``on_audio`` still sees each PCM chunk as it arrives.

    Returns a ``SpeakResult`` holding the concatenated PCM, the sample rate
    reported by the last chunk (``_DEFAULT_SAMPLE_RATE`` when the text produced
    no chunks at all) and a duration recomputed from the combined PCM.
    """
    combined = bytearray()
    rate = _DEFAULT_SAMPLE_RATE
    for piece in text.chunk_text(config.text):
        part = synthesize(
            api_key,
            replace(config, text=piece),
            connect=connect,
            on_warning=on_warning,
            on_audio=on_audio,
        )
        combined += part.pcm
        # Last chunk wins; every chunk is synthesized from the same config.
        rate = part.sample_rate
    duration = _pcm_duration_seconds(combined, rate)
    return SpeakResult(bytes(combined), rate, duration)


def synthesize_dialogue(
api_key: str,
segments: list[tuple[str, str]],
Expand All @@ -281,8 +330,10 @@ def synthesize_dialogue(
"""
pcm = bytearray()
sample_rate_out = _DEFAULT_SAMPLE_RATE
for index, (voice, text) in enumerate(segments):
config = SpeakConfig(text=text, voice=voice, language=language, sample_rate=sample_rate)
for index, (voice, turn_text) in enumerate(segments):
config = SpeakConfig(
text=turn_text, voice=voice, language=language, sample_rate=sample_rate
)
result = synthesize(api_key, config, connect=connect, on_warning=on_warning)
if index:
pcm.extend(audio.silence(result.sample_rate, _INTER_TURN_SILENCE_SECONDS))
Expand Down
73 changes: 73 additions & 0 deletions aai_cli/tts/text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Pure text helpers for streaming TTS: sentence splitting and chunking.

PocketTTS (the streaming-TTS model behind ``assembly speak``) is fed incrementally —
a whole document in a single ``Generate`` frame stalls the server. ``chunk_text`` breaks
the input into sentence-aligned chunks small enough to synthesize one connection at a
time (see ``tts.session.synthesize_chunked``). Kept Rich-free and dependency-light so it
is trivially unit-testable.
"""

from __future__ import annotations

# Characters that can end a sentence (mirrors the agent-cascade splitter).
# ``split_sentences`` tests each character of its input for membership in this string.
_TERMINATORS = ".!?"

# Conservative upper bound on the characters in a single Generate frame. PocketTTS is a
# streaming model with a bounded context; everywhere else in the codebase it is fed one
# sentence at a time. Sentences are packed up to this budget to keep the connection count
# down on a long page while keeping each frame comfortably small.
# This is the default ``max_chars`` of ``chunk_text``.
_MAX_CHUNK_CHARS = 500  # pragma: no mutate -- a +-1 char budget is immaterial


def split_sentences(text: str) -> list[str]:
    """Break ``text`` into sentences that end in ``.``, ``!`` or ``?``.

    A terminator only closes a sentence when it is the final character of the
    text or the next character is whitespace. A ``.`` inside a number ("$3.50")
    and runs of terminators ("..." / "?!") therefore stay inside one sentence.
    Each sentence is returned stripped of surrounding whitespace. Text left over
    after the last boundary is kept as a final sentence unless it is empty or
    whitespace-only, so no spoken content is dropped.
    """
    found: list[str] = []
    final_pos = len(text) - 1
    begin = 0
    for pos, ch in enumerate(text):
        if ch not in _TERMINATORS:
            continue
        if pos != final_pos and not text[pos + 1].isspace():
            # Terminator glued to more text (e.g. "3.50", "..."): not a boundary.
            continue
        # At a confirmed boundary the character after ``pos`` is whitespace or the
        # end of the text, so moving this cut one further only ever spans
        # whitespace that .strip() removes (equivalent under mutation).
        cut = pos + 1  # pragma: no mutate
        found.append(text[begin:cut].strip())
        begin = cut
    remainder = text[begin:].strip()
    if remainder:
        found.append(remainder)
    return found


def _bounded(sentence: str, max_chars: int) -> list[str]:
"""Slice ``sentence`` into ``<= max_chars`` pieces. A sentence within the budget comes
back as a single piece (the one slice covers it); an over-long one (e.g. a PDF blob
with no sentence terminators) is split so no single Generate frame can stall the
server."""
return [sentence[i : i + max_chars] for i in range(0, len(sentence), max_chars)]


def chunk_text(text: str, max_chars: int = _MAX_CHUNK_CHARS) -> list[str]:
    """Pack the sentences of ``text`` into chunks of at most ``max_chars`` characters.

    Sentences come from ``split_sentences`` and are bounded by ``_bounded``, then
    joined greedily with a single space: a piece is added to the chunk being built
    while the result still fits, otherwise that chunk is closed and the piece
    starts the next one. Short sentences therefore share a chunk (and thus one
    connection), and a chunk only breaks mid-sentence when that sentence alone
    exceeds the budget. Whitespace-only input yields an empty list.
    """
    packed: list[str] = []
    pending = ""
    pieces = (
        piece
        for sentence in split_sentences(text)
        for piece in _bounded(sentence, max_chars)
    )
    for piece in pieces:
        if pending and len(pending) + len(piece) + 1 > max_chars:
            # The joining space plus this piece would overflow: close the chunk.
            packed.append(pending)
            pending = piece
        elif pending:
            pending = " ".join((pending, piece))
        else:
            pending = piece
    if pending:
        packed.append(pending)
    return packed
82 changes: 82 additions & 0 deletions tests/_tts_session_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Shared test doubles for the streaming-TTS session suites.

Split out of ``test_tts_session.py`` so the single-synthesis tests and the
``synthesize_dialogue`` tests can each stay under the 500-line file gate while
driving the same fake socket and frame builders.
"""

from __future__ import annotations

import base64
import json

from aai_cli.core import environments


def use_env(name: str) -> None:
    """Look up the environment called ``name`` and make it the active one."""
    selected = environments.get(name)
    environments.set_active(selected)


class FakeWS:
    """Minimal stand-in for a websockets sync connection.

    Replays a scripted list of incoming frames in order and records what the
    code under test does to the socket: decoded outgoing frames in ``sent``, the
    timeout of every ``recv`` call in ``recv_timeouts``, and whether ``close``
    was called in ``closed``.
    """

    def __init__(self, incoming: list[str]) -> None:
        # Work on a copy so consuming frames never mutates the caller's list.
        self._incoming = [*incoming]
        self.sent: list[dict[str, object]] = []
        self.closed = False
        self.recv_timeouts: list[float | None] = []

    def recv(self, timeout: float | None = None) -> str:
        """Record ``timeout`` and return the next scripted frame, oldest first."""
        self.recv_timeouts.append(timeout)
        frame = self._incoming.pop(0)
        return frame

    def send(self, data: str) -> None:
        """Decode the outgoing JSON frame and remember it for later assertions."""
        decoded = json.loads(data)
        self.sent.append(decoded)

    def close(self) -> None:
        """Mark the fake connection as closed."""
        self.closed = True


def audio_frame(pcm: bytes, *, final: bool) -> str:
    """Build an Audio frame that carries an explicit ``is_final`` flag.

    This drives the defensive end-of-stream path; the live server instead omits
    ``is_final`` and ends with FlushDone (see ``audio_chunk``). The sample rate
    is reported once, up front, in the Begin frame's configuration, so it is not
    part of this frame.
    """
    encoded = base64.b64encode(pcm).decode("ascii")
    frame = {"type": "Audio", "audio": encoded, "is_final": final}
    return json.dumps(frame)


def audio_chunk(pcm: bytes) -> str:
    """Build an Audio frame shaped like the real server's.

    The frame carries the base64 PCM payload and a ``flush_id`` but NO
    ``is_final`` flag; completion is signalled by a separate FlushDone frame.
    """
    payload = base64.b64encode(pcm).decode("ascii")
    frame = {"type": "Audio", "audio": payload, "flush_id": 0}
    return json.dumps(frame)


def flush_done_frame() -> str:
    """Build the FlushDone frame used to signal the end of a synthesis."""
    frame = {"type": "FlushDone", "flush_id": 0, "audio_duration_ms": 880}
    return json.dumps(frame)


def begin_frame(*, sample_rate: int = 24000) -> str:
    """Build a Begin frame whose configuration reports ``sample_rate``."""
    configuration = {
        "voice": "jane",
        "language": "english",
        "sample_rate": sample_rate,
    }
    frame = {
        "type": "Begin",
        "id": "s1",
        "expires_at": 1,
        "configuration": configuration,
    }
    return json.dumps(frame)


def connect_returning(ws: FakeWS, captured: dict[str, object]):
    """Return a fake ``connect`` callable that always hands back ``ws``.

    Each call stores the URL under ``captured["url"]`` and the keyword arguments
    under ``captured["kwargs"]`` so a test can assert how the connection was
    opened.
    """

    def _connect(url: str, **kwargs):
        captured.update(url=url, kwargs=kwargs)
        return ws

    return _connect
33 changes: 30 additions & 3 deletions tests/test_speak.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,24 @@ def test_url_reads_web_page_aloud(monkeypatch, fake_synthesize):
assert fake_synthesize["cfg"].text == "The article body."


def test_long_text_is_synthesized_one_chunk_per_connection(monkeypatch):
    # _speak_single feeds PocketTTS one chunk at a time (one synthesize/connection each),
    # never the whole document in a single Generate — the fix for the long --url/PDF case.
    seen_texts: list[str] = []

    def _record(api_key, cfg, *, connect=None, on_warning=None, on_audio=None):
        # Stand-in for session.synthesize: note which text each call was given.
        seen_texts.append(cfg.text)
        return session.SpeakResult(pcm=b"\x01\x02", sample_rate=24000, audio_duration_seconds=0.0)

    # Pin the chunker's output so the test does not depend on the real splitting rules.
    monkeypatch.setattr("aai_cli.tts.text.chunk_text", lambda _t: ["First.", "Second.", "Third."])
    monkeypatch.setattr("aai_cli.commands.speak._exec.audio.PcmPlayer", lambda **_: _FakePlayer())
    monkeypatch.setattr(session, "synthesize", _record)

    outcome = runner.invoke(app, ["--sandbox", "speak", "First. Second. Third."])

    assert outcome.exit_code == 0
    # One synthesize call per chunk, each carrying just that chunk's text.
    assert seen_texts == ["First.", "Second.", "Third."]


def test_url_and_text_argument_are_mutually_exclusive(monkeypatch):
result = runner.invoke(
app, ["--sandbox", "speak", "Hello", "--url", "https://example.com/post"]
Expand Down Expand Up @@ -229,16 +247,25 @@ def test_explicit_voice_beats_the_language_default(monkeypatch, fake_synthesize)
assert fake_synthesize["cfg"].voice == "jane"


def test_json_mode_emits_metadata_object_on_stdout(monkeypatch, fake_synthesize):
def test_json_mode_emits_metadata_object_on_stdout(monkeypatch):
monkeypatch.setattr("aai_cli.commands.speak._exec.audio.play_pcm", lambda *a, **k: None)
# 5926 bytes of 16-bit mono PCM at 24 kHz is 5926/2/24000 = 0.12345833…s — a value
# with >3 decimals so the round-to-3 in the summary is actually exercised. The fake
# reports the matching duration; synthesize_chunked recomputes it from the PCM anyway.
pcm = b"\x00" * 5926

def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None):
return session.SpeakResult(pcm=pcm, sample_rate=24000, audio_duration_seconds=0.12345833)

monkeypatch.setattr(session, "synthesize", _fake)
result = runner.invoke(app, ["--sandbox", "speak", "Hi", "--voice", "jane", "--json"])
assert result.exit_code == 0
# The behavioral split: --json yields a parseable object, not human prose.
payload = json.loads(result.stdout.strip())
assert payload["voice"] == "jane"
assert payload["sample_rate"] == 24000
assert payload["bytes"] == 4
# Duration is rounded to 3 decimals (0.123456 -> 0.123, not 0.1235).
assert payload["bytes"] == 5926
# Duration is rounded to 3 decimals (0.12345833… -> 0.123, not 0.1235).
assert payload["audio_duration_seconds"] == 0.123
# No --out -> the reported path is null, not the string "None".
assert payload["out"] is None
Expand Down
Loading
Loading