diff --git a/aai_cli/commands/speak/_exec.py b/aai_cli/commands/speak/_exec.py index 57ece68e..02c43d11 100644 --- a/aai_cli/commands/speak/_exec.py +++ b/aai_cli/commands/speak/_exec.py @@ -156,10 +156,12 @@ def on_warning(message: str) -> None: audio.PcmPlayer() as player, output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet), ): - result = session.synthesize(api_key, cfg, on_audio=player.feed, on_warning=on_warning) + result = session.synthesize_chunked( + api_key, cfg, on_audio=player.feed, on_warning=on_warning + ) else: with output.status("Synthesizing speech…", json_mode=json_mode, quiet=quiet): - result = session.synthesize(api_key, cfg, on_warning=on_warning) + result = session.synthesize_chunked(api_key, cfg, on_warning=on_warning) audio.write_wav(opts.out, result.pcm, result.sample_rate) _emit_single(result, cfg, opts.out, json_mode=json_mode) diff --git a/aai_cli/tts/session.py b/aai_cli/tts/session.py index 4d685114..a98005e2 100644 --- a/aai_cli/tts/session.py +++ b/aai_cli/tts/session.py @@ -6,7 +6,7 @@ import json from abc import abstractmethod from collections.abc import Callable, Mapping -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import Protocol from urllib.parse import urlencode @@ -14,7 +14,7 @@ from aai_cli.core import ws as wsutil from aai_cli.core.errors import APIError, CLIError from aai_cli.streaming import diagnostics -from aai_cli.tts import audio +from aai_cli.tts import audio, text class _WebSocket(Protocol): @@ -164,12 +164,21 @@ def _recv_raw(ws: _WebSocket) -> str | bytes: def _default_connect( - url: str, *, additional_headers: dict[str, str], max_size: int | None + url: str, + *, + additional_headers: dict[str, str], + max_size: int | None, + ping_timeout: float | None, ) -> _WebSocket: """The real websockets sync client, imported lazily so tests can inject a fake.""" from websockets.sync.client import connect - return connect(url, additional_headers=additional_headers, max_size=max_size) + return connect( + url, + additional_headers=additional_headers, + max_size=max_size, + ping_timeout=ping_timeout, + ) def _run_protocol( @@ -252,6 +261,13 @@ def synthesize( # Bearer token upgrades fine but is rejected in-band as an Error frame. bearer=False, max_size=None, # no frame cap: a synthesis's Audio frames can exceed the 1 MiB default + # Disable websockets' keepalive pong deadline. A long input (notably a --url page or + # PDF, sent as one Generate frame) can leave the server silent for longer than the + # default 20s ping_timeout before it streams the first Audio frame; websockets would + # then close the still-alive connection itself with code 1011 ("keepalive ping + # timeout"). _RECV_TIMEOUT_SECONDS already bounds server silence per frame, so it is + # the single liveness authority — pings still flow, but a slow pong no longer kills us. + ping_timeout=None, ) try: return _run_protocol(ws, config, on_warning, on_audio) @@ -264,6 +280,39 @@ def synthesize( ws.close() +def synthesize_chunked( + api_key: str, + config: SpeakConfig, + *, + connect: _Connect | None = None, + on_warning: Callable[[str], None] | None = None, + on_audio: Callable[[bytes, int], None] | None = None, +) -> SpeakResult: + """Synthesize ``config.text`` as a sequence of sentence-packed chunks, one + streaming-TTS connection each, concatenating the PCM. + + PocketTTS is a streaming model meant to be fed incrementally; a whole document in a + single Generate frame stalls the server (it then misses the keepalive ping and the + socket closes with a 1011). Chunking keeps every Generate small and starts playback + on the first chunk. ``on_audio`` still receives each PCM chunk as it arrives and the + full PCM is returned for the summary, exactly like ``synthesize`` — this is the same + one-connection-per-sentence pattern the agent-cascade path uses. + """ + pcm = bytearray() + sample_rate = _DEFAULT_SAMPLE_RATE + for chunk in text.chunk_text(config.text): + result = synthesize( + api_key, + replace(config, text=chunk), + connect=connect, + on_warning=on_warning, + on_audio=on_audio, + ) + pcm.extend(result.pcm) + sample_rate = result.sample_rate + return SpeakResult(bytes(pcm), sample_rate, _pcm_duration_seconds(pcm, sample_rate)) + + def synthesize_dialogue( api_key: str, segments: list[tuple[str, str]], @@ -281,8 +330,10 @@ def synthesize_dialogue( """ pcm = bytearray() sample_rate_out = _DEFAULT_SAMPLE_RATE - for index, (voice, text) in enumerate(segments): - config = SpeakConfig(text=text, voice=voice, language=language, sample_rate=sample_rate) + for index, (voice, turn_text) in enumerate(segments): + config = SpeakConfig( + text=turn_text, voice=voice, language=language, sample_rate=sample_rate + ) result = synthesize(api_key, config, connect=connect, on_warning=on_warning) if index: pcm.extend(audio.silence(result.sample_rate, _INTER_TURN_SILENCE_SECONDS)) diff --git a/aai_cli/tts/text.py b/aai_cli/tts/text.py new file mode 100644 index 00000000..c6370273 --- /dev/null +++ b/aai_cli/tts/text.py @@ -0,0 +1,73 @@ +"""Pure text helpers for streaming TTS: sentence splitting and chunking. + +PocketTTS (the streaming-TTS model behind ``assembly speak``) is fed incrementally — +a whole document in a single ``Generate`` frame stalls the server. ``chunk_text`` breaks +the input into sentence-aligned chunks small enough to synthesize one connection at a +time (see ``tts.session.synthesize_chunked``). Kept Rich-free and dependency-light so it +is trivially unit-testable. +""" + +from __future__ import annotations + +# A sentence ends at one of these terminators (mirrors the agent-cascade splitter). +_TERMINATORS = ".!?" + +# Conservative upper bound on the characters in a single Generate frame. PocketTTS is a +# streaming model with a bounded context; everywhere else in the codebase it is fed one +# sentence at a time. Sentences are packed up to this budget to keep the connection count +# down on a long page while keeping each frame comfortably small. +_MAX_CHUNK_CHARS = 500 # pragma: no mutate -- a +-1 char budget is immaterial + + +def split_sentences(text: str) -> list[str]: + """Split ``text`` into sentences, each ending in ``.``/``!``/``?``. + + A terminator ends a sentence only when it is the last character or is followed by + whitespace — so a ``.`` inside a number ("$3.50") or stacked terminators ("..."/"?!") + don't fragment one spoken sentence. A trailing fragment with no terminal punctuation + is kept, so no text is ever dropped; empty/whitespace-only pieces are discarded. + """ + sentences: list[str] = [] + start = 0 + for index, char in enumerate(text): + if char in _TERMINATORS and (index + 1 == len(text) or text[index + 1].isspace()): + # The two `+ 1`s below are equivalent under mutation: a confirmed boundary means + # index+1 is whitespace or end-of-text, so widening the slice / advancing start by + # one extra char only ever spans whitespace that .strip() removes. + sentences.append(text[start : index + 1].strip()) # pragma: no mutate + start = index + 1 # pragma: no mutate + tail = text[start:].strip() + if tail: + sentences.append(tail) + return sentences + + +def _bounded(sentence: str, max_chars: int) -> list[str]: + """Slice ``sentence`` into ``<= max_chars`` pieces. A sentence within the budget comes + back as a single piece (the one slice covers it); an over-long one (e.g. a PDF blob + with no sentence terminators) is split so no single Generate frame can stall the + server.""" + return [sentence[i : i + max_chars] for i in range(0, len(sentence), max_chars)] + + +def chunk_text(text: str, max_chars: int = _MAX_CHUNK_CHARS) -> list[str]: + """Split ``text`` into sentence-aligned chunks, each ``<= max_chars``. + + Sentences are packed greedily so short ones share a chunk (and thus one connection); + packing never breaks mid-sentence unless a single sentence exceeds the budget, in + which case that sentence alone is sliced. Whitespace-only input yields no chunks. + """ + chunks: list[str] = [] + current = "" + for sentence in split_sentences(text): + for piece in _bounded(sentence, max_chars): + if not current: + current = piece + elif len(current) + 1 + len(piece) <= max_chars: + current = f"{current} {piece}" + else: + chunks.append(current) + current = piece + if current: + chunks.append(current) + return chunks diff --git a/tests/_tts_session_helpers.py b/tests/_tts_session_helpers.py new file mode 100644 index 00000000..cef8a1d1 --- /dev/null +++ b/tests/_tts_session_helpers.py @@ -0,0 +1,82 @@ +"""Shared test doubles for the streaming-TTS session suites. + +Split out of ``test_tts_session.py`` so the single-synthesis tests and the +``synthesize_dialogue`` tests can each stay under the 500-line file gate while +driving the same fake socket and frame builders. +""" + +from __future__ import annotations + +import base64 +import json + +from aai_cli.core import environments + + +def use_env(name: str) -> None: + environments.set_active(environments.get(name)) + + +class FakeWS: + """A minimal stand-in for a websockets sync connection.""" + + def __init__(self, incoming: list[str]) -> None: + self._incoming = list(incoming) + self.sent: list[dict[str, object]] = [] + self.closed = False + self.recv_timeouts: list[float | None] = [] + + def recv(self, timeout: float | None = None) -> str: + self.recv_timeouts.append(timeout) + return self._incoming.pop(0) + + def send(self, data: str) -> None: + self.sent.append(json.loads(data)) + + def close(self) -> None: + self.closed = True + + +def audio_frame(pcm: bytes, *, final: bool) -> str: + # An Audio frame with an explicit is_final flag — the defensive end-of-stream path + # (the live server instead omits is_final and ends with FlushDone, see audio_chunk). + # The sample rate is reported once, up front, in the Begin frame's configuration. + return json.dumps( + { + "type": "Audio", + "audio": base64.b64encode(pcm).decode("ascii"), + "is_final": final, + } + ) + + +def audio_chunk(pcm: bytes) -> str: + # The real server's Audio frames carry the PCM payload and a flush_id, but NO + # is_final flag — completion is signalled by a separate FlushDone frame. + return json.dumps( + {"type": "Audio", "audio": base64.b64encode(pcm).decode("ascii"), "flush_id": 0} + ) + + +def flush_done_frame() -> str: + return json.dumps({"type": "FlushDone", "flush_id": 0, "audio_duration_ms": 880}) + + +def begin_frame(*, sample_rate: int = 24000) -> str: + return json.dumps( + { + "type": "Begin", + "id": "s1", + "expires_at": 1, + "configuration": {"voice": "jane", "language": "english", "sample_rate": sample_rate}, + } + ) + + +def connect_returning(ws: FakeWS, captured: dict[str, object]): + def _connect(url: str, **kwargs): + captured["url"] = url + captured["kwargs"] = kwargs + return ws + + return _connect diff --git a/tests/test_speak.py b/tests/test_speak.py index 2c8f95ba..2cf1ae71 100644 --- a/tests/test_speak.py +++ b/tests/test_speak.py @@ -147,6 +147,24 @@ def test_url_reads_web_page_aloud(monkeypatch, fake_synthesize): assert fake_synthesize["cfg"].text == "The article body." +def test_long_text_is_synthesized_one_chunk_per_connection(monkeypatch): + # _speak_single feeds PocketTTS one chunk at a time (one synthesize/connection each), + # never the whole document in a single Generate — the fix for the long --url/PDF case. + monkeypatch.setattr("aai_cli.tts.text.chunk_text", lambda _t: ["First.", "Second.", "Third."]) + monkeypatch.setattr("aai_cli.commands.speak._exec.audio.PcmPlayer", lambda **_: _FakePlayer()) + spoken: list[str] = [] + + def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None): + spoken.append(cfg.text) + return session.SpeakResult(pcm=b"\x01\x02", sample_rate=24000, audio_duration_seconds=0.0) + + monkeypatch.setattr(session, "synthesize", _fake) + result = runner.invoke(app, ["--sandbox", "speak", "First. Second. Third."]) + assert result.exit_code == 0 + # One synthesize call per chunk, each carrying just that chunk's text. + assert spoken == ["First.", "Second.", "Third."] + + def test_url_and_text_argument_are_mutually_exclusive(monkeypatch): result = runner.invoke( app, ["--sandbox", "speak", "Hello", "--url", "https://example.com/post"] @@ -229,16 +247,25 @@ def test_explicit_voice_beats_the_language_default(monkeypatch, fake_synthesize) assert fake_synthesize["cfg"].voice == "jane" -def test_json_mode_emits_metadata_object_on_stdout(monkeypatch, fake_synthesize): +def test_json_mode_emits_metadata_object_on_stdout(monkeypatch): monkeypatch.setattr("aai_cli.commands.speak._exec.audio.play_pcm", lambda *a, **k: None) + # 5926 bytes of 16-bit mono PCM at 24 kHz is 5926/2/24000 = 0.12345833…s — a value + # with >3 decimals so the round-to-3 in the summary is actually exercised. The fake + # reports the matching duration; synthesize_chunked recomputes it from the PCM anyway. + pcm = b"\x00" * 5926 + + def _fake(api_key, cfg, *, connect=None, on_warning=None, on_audio=None): + return session.SpeakResult(pcm=pcm, sample_rate=24000, audio_duration_seconds=0.12345833) + + monkeypatch.setattr(session, "synthesize", _fake) result = runner.invoke(app, ["--sandbox", "speak", "Hi", "--voice", "jane", "--json"]) assert result.exit_code == 0 # The behavioral split: --json yields a parseable object, not human prose. payload = json.loads(result.stdout.strip()) assert payload["voice"] == "jane" assert payload["sample_rate"] == 24000 - assert payload["bytes"] == 4 - # Duration is rounded to 3 decimals (0.123456 -> 0.123, not 0.1235). + assert payload["bytes"] == 5926 + # Duration is rounded to 3 decimals (0.12345833… -> 0.123, not 0.1235). assert payload["audio_duration_seconds"] == 0.123 # No --out -> the reported path is null, not the string "None". assert payload["out"] is None diff --git a/tests/test_tts_session.py b/tests/test_tts_session.py index f4d1167d..cc3a11e4 100644 --- a/tests/test_tts_session.py +++ b/tests/test_tts_session.py @@ -1,31 +1,34 @@ from __future__ import annotations -import base64 import json import pytest -from aai_cli.core import environments from aai_cli.core.errors import APIError, CLIError, NotAuthenticated from aai_cli.tts import session - - -def _use_env(name: str) -> None: - environments.set_active(environments.get(name)) +from tests._tts_session_helpers import ( + FakeWS, + audio_chunk, + audio_frame, + begin_frame, + connect_returning, + flush_done_frame, + use_env, +) def test_is_available_true_in_sandbox(): - _use_env("sandbox000") + use_env("sandbox000") assert session.is_available() is True def test_is_available_false_in_production(): - _use_env("production") + use_env("production") assert session.is_available() is False def test_ws_url_includes_set_params_only(): - _use_env("sandbox000") + use_env("sandbox000") cfg = session.SpeakConfig(text="hi", voice="jane", language="English") url = session.ws_url(cfg.query_params()) assert url.startswith("wss://streaming-tts.sandbox000.assemblyai-labs.com/v1/ws/?") @@ -35,7 +38,7 @@ def test_ws_url_includes_set_params_only(): def test_ws_url_no_params_has_no_query_string(): - _use_env("sandbox000") + use_env("sandbox000") url = session.ws_url(session.SpeakConfig(text="hi").query_params()) assert url == "wss://streaming-tts.sandbox000.assemblyai-labs.com/v1/ws/" @@ -64,82 +67,17 @@ def test_speak_result_is_immutable(): _set_attr(result, "sample_rate", 1) -class FakeWS: - """A minimal stand-in for a websockets sync connection.""" - - def __init__(self, incoming: list[str]) -> None: - self._incoming = list(incoming) - self.sent: list[dict[str, object]] = [] - self.closed = False - self.recv_timeouts: list[float | None] = [] - - def recv(self, timeout: float | None = None) -> str: - self.recv_timeouts.append(timeout) - return self._incoming.pop(0) - - def send(self, data: str) -> None: - self.sent.append(json.loads(data)) - - def close(self) -> None: - self.closed = True - - -def _audio_frame(pcm: bytes, *, final: bool) -> str: - # An Audio frame with an explicit is_final flag — the defensive end-of-stream path - # (the live server instead omits is_final and ends with FlushDone, see _audio_chunk). - # The sample rate is reported once, up front, in the Begin frame's configuration. - return json.dumps( - { - "type": "Audio", - "audio": base64.b64encode(pcm).decode("ascii"), - "is_final": final, - } - ) - - -def _audio_chunk(pcm: bytes) -> str: - # The real server's Audio frames carry the PCM payload and a flush_id, but NO - # is_final flag — completion is signalled by a separate FlushDone frame. - return json.dumps( - {"type": "Audio", "audio": base64.b64encode(pcm).decode("ascii"), "flush_id": 0} - ) - - -def _flush_done_frame() -> str: - return json.dumps({"type": "FlushDone", "flush_id": 0, "audio_duration_ms": 880}) - - -def _begin_frame(*, sample_rate: int = 24000) -> str: - return json.dumps( - { - "type": "Begin", - "id": "s1", - "expires_at": 1, - "configuration": {"voice": "jane", "language": "english", "sample_rate": sample_rate}, - } - ) - - -def _connect_returning(ws: FakeWS, captured: dict[str, object]): - def _connect(url: str, **kwargs): - captured["url"] = url - captured["kwargs"] = kwargs - return ws - - return _connect - - def test_synthesize_drives_the_full_protocol(): captured: dict = {} ws = FakeWS( [ - _begin_frame(sample_rate=24000), - _audio_frame(b"\x01\x02\x03\x04", final=False), - _audio_frame(b"\x05\x06", final=True), + begin_frame(sample_rate=24000), + audio_frame(b"\x01\x02\x03\x04", final=False), + audio_frame(b"\x05\x06", final=True), ] ) cfg = session.SpeakConfig(text="hello", voice="jane") - result = session.synthesize("k", cfg, connect=_connect_returning(ws, captured)) + result = session.synthesize("k", cfg, connect=connect_returning(ws, captured)) # Sends Generate(text), then Flush, then Terminate — in that order. The flush tag # is "Flush" (the server's accepted tag, as the agent-cascade template sends), @@ -159,6 +97,11 @@ def test_synthesize_drives_the_full_protocol(): assert "Bearer" not in captured["kwargs"]["additional_headers"]["Authorization"] # The frame-size cap is lifted: Audio frames can exceed websockets' 1 MiB default. assert captured["kwargs"]["max_size"] is None + # The keepalive pong deadline is disabled: a long input (e.g. a --url PDF) can leave + # the server silent for >20s before the first Audio frame, and websockets' default + # ping_timeout would close the still-alive connection with code 1011. _RECV_TIMEOUT_SECONDS + # is the single liveness authority instead, so ping_timeout must be None. + assert captured["kwargs"]["ping_timeout"] is None def test_synthesize_stops_on_flush_done_when_audio_omits_is_final(): @@ -168,10 +111,10 @@ def test_synthesize_stops_on_flush_done_when_audio_omits_is_final(): # every PCM byte gathered so far — then Terminate the session as usual. ws = FakeWS( [ - _begin_frame(sample_rate=24000), - _audio_chunk(b"\x01\x02\x03\x04"), - _audio_chunk(b"\x05\x06"), - _flush_done_frame(), + begin_frame(sample_rate=24000), + audio_chunk(b"\x01\x02\x03\x04"), + audio_chunk(b"\x05\x06"), + flush_done_frame(), ] ) result = session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) @@ -186,10 +129,10 @@ def test_synthesize_streams_each_chunk_to_on_audio_as_it_arrives(): # order, with the server's reported rate — while the full PCM is still returned. ws = FakeWS( [ - _begin_frame(sample_rate=16000), - _audio_chunk(b"\x01\x02"), - _audio_chunk(b"\x03\x04"), - _flush_done_frame(), + begin_frame(sample_rate=16000), + audio_chunk(b"\x01\x02"), + audio_chunk(b"\x03\x04"), + flush_done_frame(), ] ) streamed: list[tuple[bytes, int]] = [] @@ -207,7 +150,7 @@ def test_synthesize_streams_each_chunk_to_on_audio_as_it_arrives(): def test_synthesize_without_on_audio_still_returns_full_pcm(): # The callback is optional: omitting it must not change the buffered result. - ws = FakeWS([_begin_frame(), _audio_frame(b"\x01\x02", final=True)]) + ws = FakeWS([begin_frame(), audio_frame(b"\x01\x02", final=True)]) result = session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) assert result.pcm == b"\x01\x02" @@ -215,7 +158,7 @@ def test_synthesize_without_on_audio_still_returns_full_pcm(): def test_synthesize_reads_sample_rate_from_begin_configuration(): # A non-default rate in the Begin frame flows into the result and its duration, # proving the rate is read from Begin.configuration rather than hardcoded. - ws = FakeWS([_begin_frame(sample_rate=16000), _audio_frame(b"\x01\x02\x03\x04", final=True)]) + ws = FakeWS([begin_frame(sample_rate=16000), audio_frame(b"\x01\x02\x03\x04", final=True)]) result = session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) assert result.sample_rate == 16000 assert result.audio_duration_seconds == pytest.approx(4 / 2 / 16000) @@ -226,7 +169,7 @@ def test_synthesize_falls_back_to_default_rate_when_begin_omits_configuration(): ws = FakeWS( [ json.dumps({"type": "Begin", "id": "s", "expires_at": 1}), - _audio_frame(b"\x01\x02", final=True), + audio_frame(b"\x01\x02", final=True), ] ) result = session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) @@ -286,7 +229,7 @@ def test_synthesize_maps_error_frame_to_api_error(): def test_synthesize_bounds_every_recv_with_the_frame_timeout(): # Each frame wait is bounded (60s): an unbounded recv() would hang the command # forever if the server went silent mid-session. - ws = FakeWS([_begin_frame(), _audio_frame(b"\x01\x02", final=True)]) + ws = FakeWS([begin_frame(), audio_frame(b"\x01\x02", final=True)]) session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) assert ws.recv_timeouts == [60.0, 60.0] @@ -321,7 +264,7 @@ def test_synthesize_invokes_on_warning_then_continues(): [ json.dumps({"type": "Begin", "id": "s", "expires_at": 1}), json.dumps({"type": "Warning", "warning_code": 1, "warning": "slow"}), - _audio_frame(b"\x01\x02", final=True), + audio_frame(b"\x01\x02", final=True), ] ) result = session.synthesize( @@ -339,7 +282,7 @@ def test_synthesize_ignores_warning_without_callback(): [ json.dumps({"type": "Begin", "id": "s", "expires_at": 1}), json.dumps({"type": "Warning", "warning_code": 1, "warning": "slow"}), - _audio_frame(b"\x01\x02", final=True), + audio_frame(b"\x01\x02", final=True), ] ) # No on_warning: the warning is silently skipped, not an error. @@ -359,7 +302,7 @@ def _connect(*_a, **_k): def test_synthesize_maps_forbidden_connect_to_api_error(): - _use_env("sandbox000") + use_env("sandbox000") class Resp: status_code = 403 @@ -427,60 +370,88 @@ def test_synthesize_without_connect_uses_real_client_and_fails_cleanly(): # never a raw socket error. disable_socket pins that blocked-at-creation behavior # on Windows too (the suite-wide conftest otherwise allows loopback there, which # would let the socket be created and then leak when the real connect is blocked). - _use_env("sandbox000") + use_env("sandbox000") with pytest.raises(CLIError): session.synthesize("k", session.SpeakConfig(text="hi")) -def test_synthesize_dialogue_concatenates_segments_with_silence(): - # One fresh fake socket per segment; record the voice each connection requested. +def test_default_connect_forwards_keepalive_and_frame_settings(monkeypatch): + # The real-client factory must forward the frame cap and the keepalive pong deadline + # verbatim to websockets — synthesize() passes ping_timeout=None to disable the deadline + # so a long --url synthesis can't die with a 1011 "keepalive ping timeout" the moment the + # server pauses >20s before the first frame. Sentinel (non-None) values are forwarded so a + # mutant that hardcodes either argument can't slip through. + captured: dict = {} + + def _fake_connect(url, **kwargs): + captured["url"] = url + captured["kwargs"] = kwargs + return object() + + monkeypatch.setattr("websockets.sync.client.connect", _fake_connect) + session._default_connect( + "wss://example/ws", + additional_headers={"Authorization": "k"}, + max_size=4096, + ping_timeout=42.0, + ) + assert captured["url"] == "wss://example/ws" + assert captured["kwargs"]["additional_headers"] == {"Authorization": "k"} + assert captured["kwargs"]["max_size"] == 4096 + assert captured["kwargs"]["ping_timeout"] == 42.0 + + +def test_synthesize_chunked_synthesizes_each_chunk_on_its_own_connection(monkeypatch): + # PocketTTS is fed one chunk per connection — never the whole document in one Generate. + monkeypatch.setattr("aai_cli.tts.text.chunk_text", lambda _t: ["First chunk.", "Second chunk."]) sockets = [ - FakeWS([_begin_frame(sample_rate=24000), _audio_frame(b"\xaa\xbb", final=True)]), - FakeWS([_begin_frame(sample_rate=24000), _audio_frame(b"\xcc\xdd", final=True)]), + FakeWS([begin_frame(), audio_frame(b"\x01\x02", final=True)]), + FakeWS([begin_frame(), audio_frame(b"\x03\x04", final=True)]), ] - urls: list[str] = [] + pool = list(sockets) + result = session.synthesize_chunked( + "k", session.SpeakConfig(text="ignored"), connect=lambda *a, **k: pool.pop(0) + ) + # One Generate per chunk, each on its own connection, in order. + assert [ws.sent[0]["text"] for ws in sockets] == ["First chunk.", "Second chunk."] + assert not pool # both connections were opened + # PCM is concatenated across the chunk connections, and the duration recomputed from it. + assert result.pcm == b"\x01\x02\x03\x04" + assert result.audio_duration_seconds == pytest.approx(4 / 2 / 24000) - def _connect(url: str, **_kwargs): - urls.append(url) - return sockets.pop(0) - result = session.synthesize_dialogue( +def test_synthesize_chunked_streams_each_chunk_to_on_audio(monkeypatch): + # on_audio fires per chunk as it arrives (incremental playback), carrying each + # connection's reported sample rate, across all chunks. + monkeypatch.setattr("aai_cli.tts.text.chunk_text", lambda _t: ["a.", "b."]) + pool = [ + FakeWS([begin_frame(sample_rate=16000), audio_frame(b"\x01\x02", final=True)]), + FakeWS([begin_frame(sample_rate=16000), audio_frame(b"\x03\x04", final=True)]), + ] + streamed: list[tuple[bytes, int]] = [] + session.synthesize_chunked( "k", - [("jane", "Hello."), ("michael", "Hi.")], - language="English", - connect=_connect, + session.SpeakConfig(text="x"), + connect=lambda *a, **k: pool.pop(0), + on_audio=lambda chunk, rate: streamed.append((chunk, rate)), ) - # Each segment connected with its own voice. - assert "voice=jane" in urls[0] - assert "voice=michael" in urls[1] - # 0.25 s of silence (24000 * 0.25 * 2 = 12000 zero bytes) sits BETWEEN the two - # segments' PCM, with none at the ends. - gap = b"\x00" * 12000 - assert result.pcm == b"\xaa\xbb" + gap + b"\xcc\xdd" - assert result.sample_rate == 24000 - # Pin the duration formula (len/2/rate) so its operators survive the mutation gate. - assert result.audio_duration_seconds == pytest.approx(len(result.pcm) / 2 / 24000) - - -def test_synthesize_dialogue_single_segment_has_no_silence(): - ws = FakeWS([_begin_frame(sample_rate=24000), _audio_frame(b"\x01\x02", final=True)]) - result = session.synthesize_dialogue("k", [("jane", "Hi.")], connect=lambda *a, **k: ws) - assert result.pcm == b"\x01\x02" # no leading/trailing pad + assert streamed == [(b"\x01\x02", 16000), (b"\x03\x04", 16000)] -def test_synthesize_dialogue_uses_server_sample_rate(): - # A non-default server rate must flow into the result (proving the per-segment - # rate is read, not left at the default) and into the duration denominator. - ws = FakeWS([_begin_frame(sample_rate=16000), _audio_frame(b"\x01\x02", final=True)]) - result = session.synthesize_dialogue("k", [("jane", "Hi.")], connect=lambda *a, **k: ws) - assert result.sample_rate == 16000 - assert result.audio_duration_seconds == pytest.approx(2 / 2 / 16000) +def test_synthesize_chunked_single_chunk_passes_text_through(): + # A short input is one chunk: the original text is synthesized unchanged. + captured: dict = {} + ws = FakeWS([begin_frame(), audio_frame(b"\x01\x02", final=True)]) + session.synthesize_chunked( + "k", session.SpeakConfig(text="Just one sentence."), connect=connect_returning(ws, captured) + ) + assert ws.sent[0]["text"] == "Just one sentence." -def test_synthesize_dialogue_empty_segments_returns_silent_default(): - # No segments -> no audio at the default rate, and no crash. connect is omitted - # entirely: the loop body never runs, so no connection is ever attempted. - result = session.synthesize_dialogue("k", []) +def test_synthesize_chunked_empty_text_returns_silent_default(): + # Whitespace-only text yields no chunks -> no audio at the default rate, and no crash. + # connect is omitted entirely: the loop never runs, so no connection is attempted. + result = session.synthesize_chunked("k", session.SpeakConfig(text=" ")) assert result.pcm == b"" assert result.sample_rate == 24000 assert result.audio_duration_seconds == 0.0 diff --git a/tests/test_tts_session_dialogue.py b/tests/test_tts_session_dialogue.py new file mode 100644 index 00000000..bc8ee8c2 --- /dev/null +++ b/tests/test_tts_session_dialogue.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import pytest + +from aai_cli.tts import session +from tests._tts_session_helpers import FakeWS, audio_frame, begin_frame + + +def test_synthesize_dialogue_concatenates_segments_with_silence(): + # One fresh fake socket per segment; record the voice each connection requested. + sockets = [ + FakeWS([begin_frame(sample_rate=24000), audio_frame(b"\xaa\xbb", final=True)]), + FakeWS([begin_frame(sample_rate=24000), audio_frame(b"\xcc\xdd", final=True)]), + ] + urls: list[str] = [] + + def _connect(url: str, **_kwargs): + urls.append(url) + return sockets.pop(0) + + result = session.synthesize_dialogue( + "k", + [("jane", "Hello."), ("michael", "Hi.")], + language="English", + connect=_connect, + ) + # Each segment connected with its own voice. + assert "voice=jane" in urls[0] + assert "voice=michael" in urls[1] + # 0.25 s of silence (24000 * 0.25 * 2 = 12000 zero bytes) sits BETWEEN the two + # segments' PCM, with none at the ends. + gap = b"\x00" * 12000 + assert result.pcm == b"\xaa\xbb" + gap + b"\xcc\xdd" + assert result.sample_rate == 24000 + # Pin the duration formula (len/2/rate) so its operators survive the mutation gate. + assert result.audio_duration_seconds == pytest.approx(len(result.pcm) / 2 / 24000) + + +def test_synthesize_dialogue_single_segment_has_no_silence(): + ws = FakeWS([begin_frame(sample_rate=24000), audio_frame(b"\x01\x02", final=True)]) + result = session.synthesize_dialogue("k", [("jane", "Hi.")], connect=lambda *a, **k: ws) + assert result.pcm == b"\x01\x02" # no leading/trailing pad + + +def test_synthesize_dialogue_uses_server_sample_rate(): + # A non-default server rate must flow into the result (proving the per-segment + # rate is read, not left at the default) and into the duration denominator. + ws = FakeWS([begin_frame(sample_rate=16000), audio_frame(b"\x01\x02", final=True)]) + result = session.synthesize_dialogue("k", [("jane", "Hi.")], connect=lambda *a, **k: ws) + assert result.sample_rate == 16000 + assert result.audio_duration_seconds == pytest.approx(2 / 2 / 16000) + + +def test_synthesize_dialogue_empty_segments_returns_silent_default(): + # No segments -> no audio at the default rate, and no crash. connect is omitted + # entirely: the loop body never runs, so no connection is ever attempted. + result = session.synthesize_dialogue("k", []) + assert result.pcm == b"" + assert result.sample_rate == 24000 + assert result.audio_duration_seconds == 0.0 diff --git a/tests/test_tts_text.py b/tests/test_tts_text.py new file mode 100644 index 00000000..d6f34146 --- /dev/null +++ b/tests/test_tts_text.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from aai_cli.tts import text + + +def test_split_sentences_keeps_terminators_and_drops_blanks(): + assert text.split_sentences("Hello there. How are you?") == ["Hello there.", "How are you?"] + + +def test_split_sentences_does_not_break_on_mid_number_period(): + # A "." inside "$3.50" is not a sentence boundary (no following whitespace). + assert text.split_sentences("It costs $3.50 today.") == ["It costs $3.50 today."] + + +def test_split_sentences_keeps_unterminated_tail(): + # PDF/article text often has no closing punctuation — the tail is kept, not dropped. + assert text.split_sentences("an extracted blob with no terminator") == [ + "an extracted blob with no terminator" + ] + + +def test_chunk_text_packs_short_sentences_into_one_chunk(): + # Several short sentences well under the budget ride in a single chunk (one + # connection) rather than one connection per sentence. + out = text.chunk_text("One. Two. Three.", max_chars=100) + assert out == ["One. Two. Three."] + + +def test_chunk_text_packs_two_sentences_exactly_at_the_budget(): + # "ab." + " " + "cd." == 7 chars: at a budget of 7 they pack into one chunk. Pins the + # space-joiner (+1) and the inclusive `<=` boundary — a budget of 7 must still pack. + assert text.chunk_text("ab. cd.", max_chars=7) == ["ab. cd."] + + +def test_chunk_text_splits_two_sentences_one_over_the_budget(): + # The same two sentences need 7 chars joined; a budget of 6 can't hold both, so the + # second rolls to its own chunk (packing never breaks mid-sentence). Pins that the + # joiner counts the separating space — without it 3+3 would wrongly fit in 6. + assert text.chunk_text("ab. cd.", max_chars=6) == ["ab.", "cd."] + + +def test_chunk_text_slices_a_single_oversized_sentence(): + # A lone "sentence" longer than the budget (no terminators — the PDF case) is sliced + # so no single Generate frame can blow past the server's input ceiling. + out = text.chunk_text("abcdefghij", max_chars=4) + assert out == ["abcd", "efgh", "ij"] + assert all(len(piece) <= 4 for piece in out) + + +def test_chunk_text_empty_input_returns_no_chunks(): + assert text.chunk_text(" ") == [] + + +def test_chunk_text_every_chunk_within_budget_for_a_long_paragraph(): + para = " ".join(f"Sentence number {n} here." for n in range(200)) + out = text.chunk_text(para, max_chars=120) + assert len(out) > 1 # a long paragraph really is chunked + assert all(len(piece) <= 120 for piece in out) + # No text is lost: rejoining the chunks recovers every word in order. + assert out and " ".join(out).split() == para.split()