Skip to content

Commit aeae9b7

Browse files
alexkroman, alexkroman-assembly, and claude
authored
Chunk speak text and disable TTS keepalive deadline for long --url input (#231)
## Problem

`assembly speak --url https://arxiv.org/pdf/...` (a long PDF/article) failed with:

```
Error: TTS session failed: sent 1011 (internal error) keepalive ping timeout; no close frame received
```

The "no close frame received" means the **client** (websockets) gave up on a still-alive socket, not that the server crashed.

## Root cause

Two independent contributors, both fixed here:

1. **The whole document was sent in a single PocketTTS `Generate` frame.** PocketTTS is a streaming model meant to be fed incrementally; a whole paper stalls the server long enough that it stops answering the websocket keepalive ping, and the client closes the socket with code 1011.
2. **websockets' default 20s keepalive pong deadline is too aggressive for this workload.** A server slow to emit the first Audio frame under load gets killed before producing anything.

## Fix

- **`session.synthesize_chunked`** splits the text into sentence-aligned chunks packed to a safe char budget and synthesizes **one connection per chunk** — the same one-sentence-per-connection pattern `agent-cascade` already uses. An over-long, terminator-less PDF blob is hard-sliced so no single `Generate` can blow past the server's input ceiling. Bonus: audio now starts on the first chunk instead of after the whole document.
- **`ping_timeout=None`** on the TTS socket disables the redundant pong deadline. `_RECV_TIMEOUT_SECONDS` (60s) is already the per-frame liveness authority, so a slow-but-alive server is no longer killed; a genuinely dead connection still fails cleanly.

## Notes

- New `aai_cli/tts/text.py` holds the pure `split_sentences`/`chunk_text` helpers (Rich-free, unit-tested).
- The oversized `test_tts_session.py` was split along its natural seam (single-synthesis vs dialogue) via a shared `tests/_tts_session_helpers.py` to stay under the 500-line file gate.
- Full `scripts/check.sh` gate passes (coverage, 100% patch coverage, mutation gate, build).
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Alex Kroman <alex@assemblyai.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 6e43cdc commit aeae9b7

8 files changed

Lines changed: 471 additions & 145 deletions

File tree

aai_cli/commands/speak/_exec.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,12 @@ def on_warning(message: str) -> None:
156156
audio.PcmPlayer() as player,
157157
output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet),
158158
):
159-
result = session.synthesize(api_key, cfg, on_audio=player.feed, on_warning=on_warning)
159+
result = session.synthesize_chunked(
160+
api_key, cfg, on_audio=player.feed, on_warning=on_warning
161+
)
160162
else:
161163
with output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet):
162-
result = session.synthesize(api_key, cfg, on_warning=on_warning)
164+
result = session.synthesize_chunked(api_key, cfg, on_warning=on_warning)
163165
audio.write_wav(opts.out, result.pcm, result.sample_rate)
164166
_emit_single(result, cfg, opts.out, json_mode=json_mode)
165167

aai_cli/tts/session.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
import json
77
from abc import abstractmethod
88
from collections.abc import Callable, Mapping
9-
from dataclasses import dataclass, field
9+
from dataclasses import dataclass, field, replace
1010
from typing import Protocol
1111
from urllib.parse import urlencode
1212

1313
from aai_cli.core import environments
1414
from aai_cli.core import ws as wsutil
1515
from aai_cli.core.errors import APIError, CLIError
1616
from aai_cli.streaming import diagnostics
17-
from aai_cli.tts import audio
17+
from aai_cli.tts import audio, text
1818

1919

2020
class _WebSocket(Protocol):
@@ -164,12 +164,21 @@ def _recv_raw(ws: _WebSocket) -> str | bytes:
164164

165165

166166
def _default_connect(
    url: str,
    *,
    additional_headers: dict[str, str],
    max_size: int | None,
    ping_timeout: float | None,
) -> _WebSocket:
    """Open ``url`` with the real ``websockets`` sync client.

    The client is imported inside the function (not at module import time) so tests can
    inject a fake connect callable instead. Every keyword argument is forwarded to the
    client unchanged.
    """
    from websockets.sync.client import connect as open_socket

    connection = open_socket(
        url,
        additional_headers=additional_headers,
        max_size=max_size,
        ping_timeout=ping_timeout,
    )
    return connection
173182

174183

175184
def _run_protocol(
@@ -252,6 +261,13 @@ def synthesize(
252261
# Bearer token upgrades fine but is rejected in-band as an Error frame.
253262
bearer=False,
254263
max_size=None, # no frame cap: a synthesis's Audio frames can exceed the 1 MiB default
264+
# Disable websockets' keepalive pong deadline. A long input (notably a --url page or
265+
# PDF, sent as one Generate frame) can leave the server silent for longer than the
266+
# default 20s ping_timeout before it streams the first Audio frame; websockets would
267+
# then close the still-alive connection itself with code 1011 ("keepalive ping
268+
# timeout"). _RECV_TIMEOUT_SECONDS already bounds server silence per frame, so it is
269+
# the single liveness authority — pings still flow, but a slow pong no longer kills us.
270+
ping_timeout=None,
255271
)
256272
try:
257273
return _run_protocol(ws, config, on_warning, on_audio)
@@ -264,6 +280,39 @@ def synthesize(
264280
ws.close()
265281

266282

283+
def synthesize_chunked(
    api_key: str,
    config: SpeakConfig,
    *,
    connect: _Connect | None = None,
    on_warning: Callable[[str], None] | None = None,
    on_audio: Callable[[bytes, int], None] | None = None,
) -> SpeakResult:
    """Synthesize ``config.text`` chunk by chunk and return the concatenated PCM.

    The text is split by ``text.chunk_text`` and each chunk goes through its own
    ``synthesize`` call (one streaming-TTS connection per chunk) with a copy of ``config``
    whose ``text`` is replaced by that chunk. PocketTTS is a streaming model meant to be
    fed incrementally; a whole document in a single Generate frame stalls the server (it
    then misses the keepalive ping and the socket closes with a 1011). Chunking keeps
    every Generate small and lets playback start on the first chunk — the same
    one-connection-per-sentence pattern the agent-cascade path uses.

    ``connect``, ``on_warning`` and ``on_audio`` are passed through to every
    ``synthesize`` call, so ``on_audio`` still sees each PCM chunk as it arrives.

    Returns a ``SpeakResult`` holding all chunks' PCM in order, the sample rate reported
    by the last chunk (``_DEFAULT_SAMPLE_RATE`` when the text produced no chunks), and a
    duration recomputed from the combined PCM.
    """
    combined = bytearray()
    rate = _DEFAULT_SAMPLE_RATE
    for piece in text.chunk_text(config.text):
        piece_config = replace(config, text=piece)
        piece_result = synthesize(
            api_key,
            piece_config,
            connect=connect,
            on_warning=on_warning,
            on_audio=on_audio,
        )
        combined += piece_result.pcm
        # Later chunks overwrite the rate, so the last chunk's report wins.
        rate = piece_result.sample_rate
    return SpeakResult(bytes(combined), rate, _pcm_duration_seconds(combined, rate))
314+
315+
267316
def synthesize_dialogue(
268317
api_key: str,
269318
segments: list[tuple[str, str]],
@@ -281,8 +330,10 @@ def synthesize_dialogue(
281330
"""
282331
pcm = bytearray()
283332
sample_rate_out = _DEFAULT_SAMPLE_RATE
284-
for index, (voice, text) in enumerate(segments):
285-
config = SpeakConfig(text=text, voice=voice, language=language, sample_rate=sample_rate)
333+
for index, (voice, turn_text) in enumerate(segments):
334+
config = SpeakConfig(
335+
text=turn_text, voice=voice, language=language, sample_rate=sample_rate
336+
)
286337
result = synthesize(api_key, config, connect=connect, on_warning=on_warning)
287338
if index:
288339
pcm.extend(audio.silence(result.sample_rate, _INTER_TURN_SILENCE_SECONDS))

aai_cli/tts/text.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Pure text helpers for streaming TTS: sentence splitting and chunking.
2+
3+
PocketTTS (the streaming-TTS model behind ``assembly speak``) is fed incrementally —
4+
a whole document in a single ``Generate`` frame stalls the server. ``chunk_text`` breaks
5+
the input into sentence-aligned chunks small enough to synthesize one connection at a
6+
time (see ``tts.session.synthesize_chunked``). Kept Rich-free and dependency-light so it
7+
is trivially unit-testable.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
# A sentence ends at one of these terminators (mirrors the agent-cascade splitter).
13+
_TERMINATORS = ".!?"
14+
15+
# Conservative upper bound on the characters in a single Generate frame. PocketTTS is a
16+
# streaming model with a bounded context; everywhere else in the codebase it is fed one
17+
# sentence at a time. Sentences are packed up to this budget to keep the connection count
18+
# down on a long page while keeping each frame comfortably small.
19+
_MAX_CHUNK_CHARS = 500 # pragma: no mutate -- a +-1 char budget is immaterial
20+
21+
22+
def split_sentences(text: str) -> list[str]:
    """Break ``text`` into stripped sentences at ``.``/``!``/``?`` boundaries.

    A terminator counts as a boundary only when it is the final character or the next
    character is whitespace, so a ``.`` inside a number ("$3.50") and all but the last of
    a run of stacked terminators ("..."/"?!") stay inside their sentence. Whatever follows
    the last boundary is kept as a final piece when it is non-blank, even without
    terminal punctuation, so no text is dropped.
    """
    pieces: list[str] = []
    begin = 0
    last = len(text) - 1
    for pos, char in enumerate(text):
        if char not in _TERMINATORS:
            continue
        if pos != last and not text[pos + 1].isspace():
            continue
        # The two `+ 1`s below are equivalent under mutation: a confirmed boundary means
        # pos+1 is whitespace or end-of-text, so widening the slice / advancing begin by
        # one extra char only ever spans whitespace that .strip() removes.
        pieces.append(text[begin : pos + 1].strip())  # pragma: no mutate
        begin = pos + 1  # pragma: no mutate
    remainder = text[begin:].strip()
    if remainder:
        pieces.append(remainder)
    return pieces
43+
44+
45+
def _bounded(sentence: str, max_chars: int) -> list[str]:
46+
"""Slice ``sentence`` into ``<= max_chars`` pieces. A sentence within the budget comes
47+
back as a single piece (the one slice covers it); an over-long one (e.g. a PDF blob
48+
with no sentence terminators) is split so no single Generate frame can stall the
49+
server."""
50+
return [sentence[i : i + max_chars] for i in range(0, len(sentence), max_chars)]
51+
52+
53+
def chunk_text(text: str, max_chars: int = _MAX_CHUNK_CHARS) -> list[str]:
    """Pack the sentences of ``text`` into chunks of at most ``max_chars`` characters.

    Sentences come from ``split_sentences`` and are joined greedily with a single space,
    so short ones share a chunk (and thus one connection). A chunk is closed as soon as
    the next piece would push it past the budget. Only a sentence that is itself longer
    than ``max_chars`` is cut mid-sentence (by ``_bounded``). Whitespace-only input
    yields an empty list.
    """
    packed: list[str] = []
    pending = ""
    pieces = (
        piece
        for sentence in split_sentences(text)
        for piece in _bounded(sentence, max_chars)
    )
    for piece in pieces:
        if pending and len(pending) + 1 + len(piece) > max_chars:
            # No room for a separator plus this piece: close the chunk, start a new one.
            packed.append(pending)
            pending = piece
        elif pending:
            pending = f"{pending} {piece}"
        else:
            pending = piece
    if pending:
        packed.append(pending)
    return packed

tests/_tts_session_helpers.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Shared test doubles for the streaming-TTS session suites.
2+
3+
Split out of ``test_tts_session.py`` so the single-synthesis tests and the
4+
``synthesize_dialogue`` tests can each stay under the 500-line file gate while
5+
driving the same fake socket and frame builders.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import base64
11+
import json
12+
13+
from aai_cli.core import environments
14+
15+
16+
def use_env(name: str) -> None:
    """Look up the environment called ``name`` and make it the active one."""
    selected = environments.get(name)
    environments.set_active(selected)
18+
19+
20+
class FakeWS:
    """A minimal stand-in for a websockets sync connection.

    Replays a scripted list of incoming frames and records what the code under test
    does with the socket: frames sent (JSON-decoded), recv timeouts, and whether the
    socket was closed.
    """

    def __init__(self, incoming: list[str]) -> None:
        self.closed = False
        self.sent: list[dict[str, object]] = []
        self.recv_timeouts: list[float | None] = []
        # Copy the script so receiving frames never consumes the caller's list.
        self._incoming = list(incoming)

    def recv(self, timeout: float | None = None) -> str:
        """Record ``timeout`` and hand back the next scripted frame."""
        self.recv_timeouts.append(timeout)
        next_frame = self._incoming.pop(0)
        return next_frame

    def send(self, data: str) -> None:
        """Decode the outgoing JSON frame and keep it for later assertions."""
        decoded = json.loads(data)
        self.sent.append(decoded)

    def close(self) -> None:
        """Mark the socket as closed."""
        self.closed = True
38+
39+
40+
def audio_frame(pcm: bytes, *, final: bool) -> str:
    """Build a JSON Audio frame that carries an explicit ``is_final`` flag."""
    # This is the defensive end-of-stream path (the live server instead omits is_final
    # and ends with FlushDone, see audio_chunk). The sample rate is reported once, up
    # front, in the Begin frame's configuration.
    encoded = base64.b64encode(pcm).decode("ascii")
    frame = {"type": "Audio", "audio": encoded, "is_final": final}
    return json.dumps(frame)
51+
52+
53+
def audio_chunk(pcm: bytes) -> str:
    """Build a JSON Audio frame shaped like the real server's: payload plus flush_id."""
    # The real server's Audio frames carry NO is_final flag — completion is signalled
    # by a separate FlushDone frame.
    encoded = base64.b64encode(pcm).decode("ascii")
    frame = {"type": "Audio", "audio": encoded, "flush_id": 0}
    return json.dumps(frame)
59+
60+
61+
def flush_done_frame() -> str:
    """Build the JSON FlushDone frame, with a fixed flush_id and audio duration."""
    frame = {"type": "FlushDone", "flush_id": 0, "audio_duration_ms": 880}
    return json.dumps(frame)
63+
64+
65+
def begin_frame(*, sample_rate: int = 24000) -> str:
    """Build the JSON Begin frame; its configuration carries ``sample_rate``."""
    configuration = {"voice": "jane", "language": "english", "sample_rate": sample_rate}
    frame = {
        "type": "Begin",
        "id": "s1",
        "expires_at": 1,
        "configuration": configuration,
    }
    return json.dumps(frame)
74+
75+
76+
def connect_returning(ws: FakeWS, captured: dict[str, object]):
    """Return a connect callable that always yields ``ws``.

    Each call stores the URL under ``captured["url"]`` and the keyword arguments under
    ``captured["kwargs"]`` so a test can assert on how the connection was opened.
    """

    def _connect(url: str, **kwargs):
        captured.update(url=url, kwargs=kwargs)
        return ws

    return _connect

tests/test_speak.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,24 @@ def test_url_reads_web_page_aloud(monkeypatch, fake_synthesize):
147147
assert fake_synthesize["cfg"].text == "The article body."
148148

149149

150+
def test_long_text_is_synthesized_one_chunk_per_connection(monkeypatch):
    # _speak_single feeds PocketTTS one chunk at a time (one synthesize/connection each),
    # never the whole document in a single Generate — the fix for the long --url/PDF case.
    chunks = ["First.", "Second.", "Third."]
    monkeypatch.setattr("aai_cli.tts.text.chunk_text", lambda _t: list(chunks))
    monkeypatch.setattr("aai_cli.commands.speak._exec.audio.PcmPlayer", lambda **_: _FakePlayer())
    spoken: list[str] = []

    def _record(api_key, cfg, *, connect=None, on_warning=None, on_audio=None):
        # Stand-in for session.synthesize: remember the text this call was asked to speak.
        spoken.append(cfg.text)
        return session.SpeakResult(pcm=b"\x01\x02", sample_rate=24000, audio_duration_seconds=0.0)

    monkeypatch.setattr(session, "synthesize", _record)
    outcome = runner.invoke(app, ["--sandbox", "speak", "First. Second. Third."])
    assert outcome.exit_code == 0
    # One synthesize call per chunk, each carrying just that chunk's text.
    assert spoken == ["First.", "Second.", "Third."]
166+
167+
150168
def test_url_and_text_argument_are_mutually_exclusive(monkeypatch):
151169
result = runner.invoke(
152170
app, ["--sandbox", "speak", "Hello", "--url", "https://example.com/post"]
@@ -229,16 +247,25 @@ def test_explicit_voice_beats_the_language_default(monkeypatch, fake_synthesize)
229247
assert fake_synthesize["cfg"].voice == "jane"
230248

231249

232-
def test_json_mode_emits_metadata_object_on_stdout(monkeypatch, fake_synthesize):
250+
def test_json_mode_emits_metadata_object_on_stdout(monkeypatch):
233251
monkeypatch.setattr("aai_cli.commands.speak._exec.audio.play_pcm", lambda *a, **k: None)
252+
# 5926 bytes of 16-bit mono PCM at 24 kHz is 5926/2/24000 = 0.12345833…s — a value
253+
# with >3 decimals so the round-to-3 in the summary is actually exercised. The fake
254+
# reports the matching duration; synthesize_chunked recomputes it from the PCM anyway.
255+
pcm = b"\x00" * 5926
256+
257+
def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None):
258+
return session.SpeakResult(pcm=pcm, sample_rate=24000, audio_duration_seconds=0.12345833)
259+
260+
monkeypatch.setattr(session, "synthesize", _fake)
234261
result = runner.invoke(app, ["--sandbox", "speak", "Hi", "--voice", "jane", "--json"])
235262
assert result.exit_code == 0
236263
# The behavioral split: --json yields a parseable object, not human prose.
237264
payload = json.loads(result.stdout.strip())
238265
assert payload["voice"] == "jane"
239266
assert payload["sample_rate"] == 24000
240-
assert payload["bytes"] == 4
241-
# Duration is rounded to 3 decimals (0.123456 -> 0.123, not 0.1235).
267+
assert payload["bytes"] == 5926
268+
# Duration is rounded to 3 decimals (0.12345833… -> 0.123, not 0.1235).
242269
assert payload["audio_duration_seconds"] == 0.123
243270
# No --out -> the reported path is null, not the string "None".
244271
assert payload["out"] is None

0 commit comments

Comments
 (0)