diff --git a/aai_cli/agent_cascade/config.py b/aai_cli/agent_cascade/config.py index ac44810b..ed5481e6 100644 --- a/aai_cli/agent_cascade/config.py +++ b/aai_cli/agent_cascade/config.py @@ -7,12 +7,16 @@ from __future__ import annotations -from dataclasses import dataclass +from collections.abc import Mapping +from dataclasses import dataclass, field from aai_cli.agent_cascade.voices import DEFAULT_VOICE from aai_cli.core import llm DEFAULT_MODEL = llm.DEFAULT_MODEL +DEFAULT_MAX_TOKENS = llm.DEFAULT_MAX_TOKENS +# The realtime model the cascade transcribes with (same as the agent-cascade template). +DEFAULT_SPEECH_MODEL = "u3-rt-pro" DEFAULT_SYSTEM_PROMPT = ( "You are a friendly, concise voice assistant. Keep replies short and " "conversational. Your reply is read aloud by a text-to-speech engine, so " @@ -32,3 +36,13 @@ class CascadeConfig: greeting: str = DEFAULT_GREETING model: str = DEFAULT_MODEL max_history: int = DEFAULT_MAX_HISTORY + # TTS language (None lets the server pick from the voice). + language: str | None = None + # LLM: cap per-reply tokens and pass through any extra gateway request fields. + max_tokens: int = DEFAULT_MAX_TOKENS + llm_extra: Mapping[str, object] = field(default_factory=dict[str, object]) + # Extra streaming-TTS query params (the --tts-config escape hatch). + tts_extra: Mapping[str, str] = field(default_factory=dict[str, str]) + # Whether STT formats finalized turns. The reply trigger waits for the formatted + # turn when on; with it off, an unformatted end-of-turn is the cue instead. 
+ format_turns: bool = True diff --git a/aai_cli/agent_cascade/engine.py b/aai_cli/agent_cascade/engine.py index 3ac070fb..9c400657 100644 --- a/aai_cli/agent_cascade/engine.py +++ b/aai_cli/agent_cascade/engine.py @@ -20,7 +20,7 @@ from aai_cli.agent_cascade.config import CascadeConfig from aai_cli.agent_cascade.text import split_sentences, trim_history -from aai_cli.core import client, config_builder, llm +from aai_cli.core import client, llm from aai_cli.core.errors import CLIError from aai_cli.tts import session as tts_session from aai_cli.tts.session import SpeakConfig @@ -96,23 +96,6 @@ def _spawn_thread(target: Callable[[], None]) -> _Worker: return thread -# The realtime model the cascade transcribes with (same as the agent-cascade template). -STT_SPEECH_MODEL = "u3-rt-pro" - - -def _stt_params(sample_rate: int) -> StreamingParameters: - """Streaming v3 params for the cascade: PCM at ``sample_rate`` with formatted turns - (so ``turn_is_formatted`` marks the cue to reply).""" - merged = config_builder.merge_streaming_params( - flags={ - "sample_rate": sample_rate, - "format_turns": True, - "speech_model": STT_SPEECH_MODEL, - } - ) - return config_builder.construct_streaming_params(merged) - - @dataclass class CascadeDeps: """The cascade's three network legs plus its thread spawner, all injectable. 
@@ -133,17 +116,29 @@ def real( config: CascadeConfig, *, audio: Iterable[bytes], - sample_rate: int, + stt_params: StreamingParameters, ) -> CascadeDeps: def run_stt(on_turn: Callable[[object], None]) -> None: - client.stream_audio(api_key, audio, params=_stt_params(sample_rate), on_turn=on_turn) + client.stream_audio(api_key, audio, params=stt_params, on_turn=on_turn) def complete_reply(messages: list[ChatCompletionMessageParam]) -> str: - response = llm.complete(api_key, model=config.model, messages=messages) + response = llm.complete( + api_key, + model=config.model, + messages=messages, + max_tokens=config.max_tokens, + extra=dict(config.llm_extra) or None, + ) return llm.content_of(response) def synthesize(text: str) -> bytes: - spec = SpeakConfig(text=text, voice=config.voice, sample_rate=TTS_SAMPLE_RATE) + spec = SpeakConfig( + text=text, + voice=config.voice, + language=config.language, + sample_rate=TTS_SAMPLE_RATE, + extra=config.tts_extra, + ) return tts_session.synthesize(api_key, spec).pcm return cls(run_stt=run_stt, complete_reply=complete_reply, synthesize=synthesize) @@ -186,7 +181,7 @@ def on_turn(self, event: object) -> None: text = (getattr(event, "transcript", "") or "").strip() if not text: return - if _is_final_turn(event): + if _is_final_turn(event, format_turns=self.config.format_turns): self.renderer.user_final(text) self._barge_in() self.history.append({"role": "user", "content": text}) @@ -261,11 +256,16 @@ def shutdown(self) -> None: self._join_reply() -def _is_final_turn(event: object) -> bool: - """True for a finalized, formatted end-of-turn — the cue to generate a reply.""" - return bool(getattr(event, "end_of_turn", False)) and bool( - getattr(event, "turn_is_formatted", False) - ) +def _is_final_turn(event: object, *, format_turns: bool) -> bool: + """True for an end-of-turn that's the cue to generate a reply. 
+ + With formatting on, wait for the *formatted* turn (better text for the LLM); + with it off the server never sets ``turn_is_formatted``, so a bare end-of-turn + is the cue — otherwise ``--no-format-turns`` would make the agent never reply. + """ + if not bool(getattr(event, "end_of_turn", False)): + return False + return bool(getattr(event, "turn_is_formatted", False)) or not format_turns def run_cascade( diff --git a/aai_cli/commands/agent_cascade/__init__.py b/aai_cli/commands/agent_cascade/__init__.py index 322b7de4..448520ae 100644 --- a/aai_cli/commands/agent_cascade/__init__.py +++ b/aai_cli/commands/agent_cascade/__init__.py @@ -8,16 +8,24 @@ from aai_cli.agent_cascade import voices from aai_cli.agent_cascade.config import ( DEFAULT_GREETING, + DEFAULT_MAX_TOKENS, DEFAULT_MODEL, + DEFAULT_SPEECH_MODEL, DEFAULT_SYSTEM_PROMPT, ) from aai_cli.agent_cascade.voices import DEFAULT_VOICE from aai_cli.app.context import AppState, run_command, run_with_options from aai_cli.commands.agent_cascade import _exec as agent_cascade_exec from aai_cli.core import choices, llm +from aai_cli.streaming.turn_presets import TurnDetectionPreset from aai_cli.ui import output from aai_cli.ui.help_text import examples_epilog +# Option panels that group the per-leg knobs in `--help` instead of one flat wall. +_PANEL_STT = "Speech-to-text" +_PANEL_LLM = "Language model" +_PANEL_TTS = "Text-to-speech" + app = typer.Typer() SPEC = command_registry.CommandModuleSpec( @@ -65,12 +73,71 @@ def agent_cascade( "--voice", help="TTS voice. 
See --list-voices.", autocompletion=voices.complete_voice, + rich_help_panel=_PANEL_TTS, + ), + language: str | None = typer.Option( + None, + "--language", + help="TTS language (defaults to the voice's language)", + rich_help_panel=_PANEL_TTS, + ), + tts_config: list[str] | None = typer.Option( + None, + "--tts-config", + help="Set any extra streaming-TTS query field as KEY=VALUE (repeatable)", + rich_help_panel=_PANEL_TTS, ), model: str = typer.Option( DEFAULT_MODEL, "--model", help="LLM Gateway model that powers the agent's replies", autocompletion=llm.complete_model, + rich_help_panel=_PANEL_LLM, + ), + max_tokens: int = typer.Option( + DEFAULT_MAX_TOKENS, + "--max-tokens", + help="Max tokens per reply", + min=1, + rich_help_panel=_PANEL_LLM, + ), + llm_config: list[str] | None = typer.Option( + None, + "--llm-config", + help="Set any LLM Gateway request field as KEY=VALUE (repeatable)", + rich_help_panel=_PANEL_LLM, + ), + speech_model: str = typer.Option( + DEFAULT_SPEECH_MODEL, + "--speech-model", + help="Streaming speech model", + rich_help_panel=_PANEL_STT, + ), + format_turns: bool = typer.Option( + True, + "--format-turns/--no-format-turns", + help="Format (punctuate) finalized turns before replying", + rich_help_panel=_PANEL_STT, + ), + turn_detection: TurnDetectionPreset | None = typer.Option( + None, + "--turn-detection", + help="Turn-detection sensitivity preset", + rich_help_panel=_PANEL_STT, + ), + stt_config: list[str] | None = typer.Option( + None, + "--stt-config", + help="Set any StreamingParameters field as KEY=VALUE (repeatable)", + rich_help_panel=_PANEL_STT, + ), + stt_config_file: Path | None = typer.Option( + None, + "--stt-config-file", + help="JSON file of streaming fields", + exists=True, + dir_okay=False, + rich_help_panel=_PANEL_STT, ), system_prompt: str = typer.Option( DEFAULT_SYSTEM_PROMPT, "--system-prompt", help="System prompt (the agent's persona)" @@ -125,5 +192,14 @@ def agent_cascade( greeting=greeting, device=device, 
output_field=output_field, + speech_model=speech_model, + format_turns=format_turns, + turn_detection=turn_detection, + stt_config=tuple(stt_config or ()), + stt_config_file=stt_config_file, + max_tokens=max_tokens, + llm_config=tuple(llm_config or ()), + language=language, + tts_config=tuple(tts_config or ()), ) run_with_options(ctx, agent_cascade_exec.run_agent_cascade, opts, json=json_out) diff --git a/aai_cli/commands/agent_cascade/_exec.py b/aai_cli/commands/agent_cascade/_exec.py index 03f8406e..3f364f5e 100644 --- a/aai_cli/commands/agent_cascade/_exec.py +++ b/aai_cli/commands/agent_cascade/_exec.py @@ -11,6 +11,7 @@ from collections.abc import Iterable from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING import typer @@ -20,12 +21,24 @@ from aai_cli.agent_cascade.config import CascadeConfig from aai_cli.app.agent_shared import resolve_system_prompt as _resolve_system_prompt from aai_cli.app.context import AppState -from aai_cli.core import choices, client +from aai_cli.core import choices, client, config_builder, llm from aai_cli.core.errors import UsageError +from aai_cli.streaming import turn_presets from aai_cli.streaming.session import resolve_output_modes from aai_cli.streaming.sources import FileSource from aai_cli.tts import session as tts_session +if TYPE_CHECKING: + from assemblyai.streaming.v3 import StreamingParameters + +# A --tts-config key that has its own named flag (or is owned by the cascade), with the +# message steering the user to the right place instead of silently fighting the cascade. 
+_RESERVED_TTS_KEYS: dict[str, str] = { + "voice": "Set the voice with --voice, not --tts-config.", + "language": "Set the language with --language, not --tts-config.", + "sample_rate": "TTS sample rate is fixed to match the live speaker and can't be overridden.", +} + @dataclass(frozen=True) class AgentCascadeOptions: @@ -45,6 +58,58 @@ class AgentCascadeOptions: greeting: str device: int | None output_field: choices.TextOrJson | None + # Speech-to-text: common knobs named, everything else via --stt-config(-file). + speech_model: str + format_turns: bool + turn_detection: turn_presets.TurnDetectionPreset | None + stt_config: tuple[str, ...] + stt_config_file: Path | None + # Language model: token cap plus any extra gateway request field. + max_tokens: int + llm_config: tuple[str, ...] + # Text-to-speech: language named, any other query param via --tts-config. + language: str | None + tts_config: tuple[str, ...] + + +def _build_stt_params(opts: AgentCascadeOptions, sample_rate: int) -> StreamingParameters: + """Construct the cascade's StreamingParameters from the STT flags + escape hatch. + + A turn-detection preset expands into the three end-of-turn knobs; --stt-config / + --stt-config-file then override any field (including those knobs). 
sample_rate is + fixed by the audio source, so it's merged in here rather than user-set.""" + eot, min_silence, max_silence = turn_presets.resolve(opts.turn_detection, None, None, None) + flags: dict[str, object] = { + "speech_model": opts.speech_model, + "format_turns": opts.format_turns, + "end_of_turn_confidence_threshold": eot, + "min_turn_silence": min_silence, + "max_turn_silence": max_silence, + } + merged = config_builder.merge_streaming_params( + flags=flags | {"sample_rate": sample_rate}, + overrides=opts.stt_config or None, + config_file=opts.stt_config_file, + ) + return config_builder.construct_streaming_params(merged) + + +def _parse_tts_config(pairs: tuple[str, ...]) -> dict[str, str]: + """Parse --tts-config KEY=VALUE pairs into extra streaming-TTS query params, + rejecting keys that have a named flag (or are cascade-owned).""" + extra: dict[str, str] = {} + for pair in pairs: + key, sep, value = pair.partition("=") + key = key.strip() + if not sep or not key: + raise UsageError( + f"--tts-config expects KEY=VALUE, got {pair!r}.", + suggestion="e.g. --tts-config chunk_size_ms=100", + ) + if key in _RESERVED_TTS_KEYS: + raise UsageError(_RESERVED_TTS_KEYS[key]) + extra[key] = value + return extra def _open_audio( @@ -89,6 +154,10 @@ def run_agent_cascade(opts: AgentCascadeOptions, state: AppState, *, json_mode: # Existence-check the clip before credentials, so a typo'd path reads as # "file not found" instead of triggering a login. client.resolve_audio_source(opts.source, sample=opts.sample) + # Parse the LLM/TTS escape hatches before opening the device, so a bad KEY=VALUE + # fails fast instead of after the mic is live. 
+ llm_extra = llm.parse_gateway_overrides(opts.llm_config) + tts_extra = _parse_tts_config(opts.tts_config) api_key = state.resolve_api_key() config = CascadeConfig( @@ -97,12 +166,18 @@ def run_agent_cascade(opts: AgentCascadeOptions, state: AppState, *, json_mode: # File-driven runs speak a clip and end after the reply, so skip the greeting. greeting="" if from_file else opts.greeting, model=opts.model, + language=opts.language, + max_tokens=opts.max_tokens, + format_turns=opts.format_turns, + llm_extra=llm_extra, + tts_extra=tts_extra, ) renderer = AgentRenderer(json_mode=json_mode, text_mode=text_mode, mic_input=not from_file) audio, player, sample_rate = _open_audio( renderer, source=opts.source, sample=opts.sample, device=opts.device, from_file=from_file ) - deps = engine.CascadeDeps.real(api_key, config, audio=audio, sample_rate=sample_rate) + stt_params = _build_stt_params(opts, sample_rate) + deps = engine.CascadeDeps.real(api_key, config, audio=audio, stt_params=stt_params) try: engine.run_cascade(renderer=renderer, player=player, config=config, deps=deps) except KeyboardInterrupt: diff --git a/aai_cli/streaming/diagnostics.py b/aai_cli/streaming/diagnostics.py index 892464ce..c652a8c5 100644 --- a/aai_cli/streaming/diagnostics.py +++ b/aai_cli/streaming/diagnostics.py @@ -92,17 +92,23 @@ def open_authorized_ws[T]( *, message: str, host: str, + bearer: bool = True, **connect_kwargs: object, ) -> T: - """Open a Bearer-authorized WebSocket, mapping a connect failure via ``classify_error``. + """Open an ``Authorization``-headered WebSocket, mapping a connect failure via + ``classify_error``. The one connect path for the raw-websocket sessions (agent, speak), so a rejected handshake (HTTP 401/403) carries the same actionable suggestion in both and everything else keeps the shared classification. 
+ + ``bearer`` selects the AssemblyAI auth scheme for the endpoint: the Voice Agent + socket expects a ``Bearer <key>`` token (the default), while the streaming + sockets (STT, TTS) authenticate with the **raw** key — pass ``bearer=False`` + for those, or the server refuses the session with an in-band Error frame. """ + token = f"Bearer {api_key}" if bearer else api_key try: - return connect( - url, additional_headers={"Authorization": f"Bearer {api_key}"}, **connect_kwargs - ) + return connect(url, additional_headers={"Authorization": token}, **connect_kwargs) except Exception as exc: raise classify_error(exc, message, host=host) from exc diff --git a/aai_cli/tts/session.py b/aai_cli/tts/session.py index d6774312..71fe4c93 100644 --- a/aai_cli/tts/session.py +++ b/aai_cli/tts/session.py @@ -5,8 +5,8 @@ import contextlib import json from abc import abstractmethod -from collections.abc import Callable -from dataclasses import dataclass +from collections.abc import Callable, Mapping +from dataclasses import dataclass, field from typing import Protocol from urllib.parse import urlencode @@ -54,15 +54,20 @@ def close(self) -> None: @dataclass(frozen=True) class SpeakConfig: """Per-run TTS parameters. Optional fields are only sent when set, so the - server applies its own defaults (voice/language/sample_rate) otherwise.""" + server applies its own defaults (voice/language/sample_rate) otherwise. + + ``extra`` carries any further query params verbatim (the escape hatch for + fields the named ones don't cover); the named fields always win a key clash.""" text: str voice: str | None = None language: str | None = None sample_rate: int | None = None + extra: Mapping[str, str] = field(default_factory=dict[str, str]) def query_params(self) -> dict[str, str]: - params: dict[str, str] = {} + # extra first so the named fields below override it on a key clash. 
+ params: dict[str, str] = {str(k): str(v) for k, v in self.extra.items()} if self.voice is not None: params["voice"] = self.voice if self.language is not None: @@ -125,6 +130,12 @@ def _decode_audio_frame(msg: dict[str, object]) -> bytes: raise APIError(f"TTS service sent an Audio frame that is not valid base64: {exc}") from exc +def _error_frame_detail(msg: dict[str, object]) -> str: + """The ``(code): reason`` tail of an Error frame, with an ``unknown`` reason + fallback so a detail-less frame still yields a non-empty message.""" + return f"({msg.get('error_code', '')}): {msg.get('error') or 'unknown'}" + + def _recv_raw(ws: _WebSocket) -> str | bytes: """One frame off the socket, with a bounded wait: a server that goes silent mid-session (e.g. never sends the final Audio frame) must fail the command, @@ -149,14 +160,22 @@ def _default_connect( def _run_protocol( ws: _WebSocket, config: SpeakConfig, on_warning: Callable[[str], None] | None ) -> SpeakResult: - """Send Generate + ForceFlushTextBuffer, collect Audio until is_final, then Terminate.""" + """Send Generate + Flush, collect Audio until is_final or FlushDone, then Terminate.""" begin = json.loads(_recv_raw(ws)) - if begin.get("type") != "Begin": - raise APIError(f"TTS service did not start the session (got {begin.get('type')!r}).") + begin_type = begin.get("type") + if begin_type == "Error": + # The server refused the session and put the reason in the frame — surface it + # (e.g. a Bearer-vs-raw auth mismatch reads as the auth error here), rather + # than discarding it behind a generic "got 'Error'". 
+ raise APIError(f"TTS service rejected the session {_error_frame_detail(begin)}") + if begin_type != "Begin": + raise APIError(f"TTS service did not start the session (got {begin_type!r}).") sample_rate = int(begin.get("configuration", {}).get("sample_rate", _DEFAULT_SAMPLE_RATE)) ws.send(json.dumps({"type": "Generate", "text": config.text})) - ws.send(json.dumps({"type": "ForceFlushTextBuffer"})) + # "Flush" forces synthesis of any buffered text; the server rejects the older + # "ForceFlushTextBuffer" tag with a validation error (matches the template). + ws.send(json.dumps({"type": "Flush"})) pcm = bytearray() while True: @@ -166,10 +185,13 @@ def _run_protocol( pcm.extend(_decode_audio_frame(msg)) if msg.get("is_final"): break + elif mtype == "FlushDone": + # The live server ends a synthesis with FlushDone (its Audio frames carry + # no is_final flag), so this is the real end-of-stream marker — stop here, + # or the loop blocks until the recv timeout and the audio is lost. + break elif mtype == "Error": - raise APIError( - f"TTS error ({msg.get('error_code', '')}): {msg.get('error') or 'unknown'}" - ) + raise APIError(f"TTS error {_error_frame_detail(msg)}") elif mtype == "Warning" and on_warning is not None: on_warning(str(msg.get("warning", ""))) @@ -201,6 +223,10 @@ def synthesize( ws_url(config.query_params()), message="Could not connect to the TTS service", host=environments.active().streaming_tts_host, + # Streaming TTS authenticates with the raw API key, not a Bearer token (the + # AssemblyAI streaming convention the agent-cascade template follows); a + # Bearer token upgrades fine but is rejected in-band as an Error frame. 
+ bearer=False, max_size=None, # no frame cap: a synthesis's Audio frames can exceed the 1 MiB default ) try: diff --git a/tests/__snapshots__/test_snapshots_help_run.ambr b/tests/__snapshots__/test_snapshots_help_run.ambr index 93509711..1e5b4073 100644 --- a/tests/__snapshots__/test_snapshots_help_run.ambr +++ b/tests/__snapshots__/test_snapshots_help_run.ambr @@ -28,12 +28,6 @@ ╭─ Options ────────────────────────────────────────────────────────────────────╮ │ --sample Speak the hosted wildfires.mp3 │ │ sample to the agent │ - │ --voice TEXT TTS voice. See --list-voices. │ - │ [default: jane] │ - │ --model TEXT LLM Gateway model that powers the │ - │ agent's replies │ - │ [default: │ - │ claude-haiku-4-5-20251001] │ │ --system-prompt TEXT System prompt (the agent's │ │ persona) │ │ [default: You are a friendly, │ @@ -57,6 +51,44 @@ │ lines as plain stdout, │ │ pipe-friendly) or json │ │ --help Show this message and exit. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ Text-to-speech ─────────────────────────────────────────────────────────────╮ + │ --voice TEXT TTS voice. See --list-voices. 
[default: jane] │ + │ --language TEXT TTS language (defaults to the voice's language) │ + │ --tts-config TEXT Set any extra streaming-TTS query field as │ + │ KEY=VALUE (repeatable) │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ Language model ─────────────────────────────────────────────────────────────╮ + │ --model TEXT LLM Gateway model that powers the │ + │ agent's replies │ + │ [default: │ + │ claude-haiku-4-5-20251001] │ + │ --max-tokens INTEGER RANGE [x>=1] Max tokens per reply │ + │ [default: 1000] │ + │ --llm-config TEXT Set any LLM Gateway request field │ + │ as KEY=VALUE (repeatable) │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ Speech-to-text ─────────────────────────────────────────────────────────────╮ + │ --speech-model TEXT Streaming speech │ + │ model │ + │ [default: │ + │ u3-rt-pro] │ + │ --format-turns --no-format-turns Format │ + │ (punctuate) │ + │ finalized turns │ + │ before replying │ + │ [default: │ + │ format-turns] │ + │ --turn-detection [aggressive|bala Turn-detection │ + │ nced|conservativ sensitivity │ + │ e] preset │ + │ --stt-config TEXT Set any │ + │ StreamingParame… │ + │ field as │ + │ KEY=VALUE │ + │ (repeatable) │ + │ --stt-config-file FILE JSON file of │ + │ streaming fields │ ╰──────────────────────────────────────────────────────────────────────────────╯ Examples diff --git a/tests/test_agent_cascade_command.py b/tests/test_agent_cascade_command.py index c89b759a..ad010cda 100644 --- a/tests/test_agent_cascade_command.py +++ b/tests/test_agent_cascade_command.py @@ -20,9 +20,10 @@ from aai_cli.app.context import AppState from aai_cli.commands.agent_cascade import _exec from aai_cli.commands.agent_cascade._exec import AgentCascadeOptions, run_agent_cascade -from aai_cli.core import config +from aai_cli.core import config, config_builder from aai_cli.core.errors import CLIError, UsageError from aai_cli.main import app +from aai_cli.streaming 
import turn_presets runner = CliRunner() @@ -37,6 +38,15 @@ greeting="hello", device=None, output_field=None, + speech_model="u3-rt-pro", + format_turns=True, + turn_detection=None, + stt_config=(), + stt_config_file=None, + max_tokens=1000, + llm_config=(), + language=None, + tts_config=(), ) @@ -98,6 +108,40 @@ def test_device_with_file_source_is_rejected(monkeypatch): run_agent_cascade(_opts(source="clip.wav", device=2), AppState(), json_mode=False) +# --- argv -> options seam ---------------------------------------------------- + + +@pytest.mark.parametrize( + ("argv", "expected"), + [([], True), (["--no-format-turns"], False), (["--format-turns"], True)], +) +def test_format_turns_flag_resolves_into_options(monkeypatch, argv, expected): + # Pin the Typer default (omitted -> True) and both explicit forms, captured at the + # argv->options seam so the run body never executes. + captured = {} + + def fake_run(opts, state, *, json_mode): + captured["opts"] = opts + + monkeypatch.setattr(_exec, "run_agent_cascade", fake_run) + result = runner.invoke(app, ["agent-cascade", *argv]) + assert result.exit_code == 0 + assert captured["opts"].format_turns is expected + + +def test_stt_config_file_must_exist(): + # --stt-config-file is existence-checked at parse time (exists=True), so a missing + # path fails as a Typer usage error before the body runs — not later on open. Wide + # terminal so the "does not exist" message isn't wrapped by the 80-col error box. 
+ result = runner.invoke( + app, + ["agent-cascade", "--stt-config-file", "/no/such/file.json"], + env={"COLUMNS": "300"}, + ) + assert result.exit_code == 2 + assert "does not exist" in result.output + + # --- system prompt resolution ------------------------------------------------ @@ -230,10 +274,110 @@ def boom(**kwargs): assert rendered["r"].closed is True +# --- STT param + TTS config builders ----------------------------------------- + + +def test_build_stt_params_threads_named_flags(): + params = _exec._build_stt_params(_opts(speech_model="u3-rt-pro", format_turns=False), 8000) + assert params.sample_rate == 8000 # fixed by the audio source, not a flag + assert params.format_turns is False + assert params.speech_model.value == "u3-rt-pro" + + +def test_build_stt_params_expands_turn_detection_preset(): + params = _exec._build_stt_params( + _opts(turn_detection=turn_presets.TurnDetectionPreset.conservative), 16000 + ) + # The conservative preset's published end-of-turn confidence threshold. 
+ assert params.end_of_turn_confidence_threshold == 0.7 + + +def test_build_stt_params_stt_config_overrides_any_field(): + params = _exec._build_stt_params( + _opts(stt_config=("end_of_turn_confidence_threshold=0.9",)), 16000 + ) + assert params.end_of_turn_confidence_threshold == 0.9 + + +def test_build_stt_params_reads_config_file(tmp_path): + cfg = tmp_path / "stt.json" + cfg.write_text('{"min_turn_silence": 123}', encoding="utf-8") + params = _exec._build_stt_params(_opts(stt_config_file=cfg), 16000) + assert params.min_turn_silence == 123 + + +def test_parse_tts_config_parses_pairs(): + assert _exec._parse_tts_config(("chunk_size_ms=100", "foo=bar")) == { + "chunk_size_ms": "100", + "foo": "bar", + } + + +def test_parse_tts_config_rejects_malformed_pair(): + with pytest.raises(UsageError, match="expects KEY=VALUE"): + _exec._parse_tts_config(("no-equals",)) + + +@pytest.mark.parametrize( + ("key", "hint"), + [("voice", "--voice"), ("language", "--language"), ("sample_rate", "fixed")], +) +def test_parse_tts_config_rejects_reserved_keys(key, hint): + with pytest.raises(UsageError, match=hint): + _exec._parse_tts_config((f"{key}=x",)) + + +def test_run_threads_all_leg_options_into_config_and_params(monkeypatch): + monkeypatch.setattr(_exec.tts_session, "require_available", lambda _c: None) + monkeypatch.setattr(config, "resolve_api_key", lambda **_: "k") + monkeypatch.setattr(_exec, "FileSource", lambda src: types.SimpleNamespace(sample_rate=16000)) + monkeypatch.setattr(_exec.client, "resolve_audio_source", lambda source, sample: "clip.wav") + monkeypatch.setattr(_exec.engine, "run_cascade", lambda **kw: None) + captured = {} + + def fake_real(api_key, config, *, audio, stt_params): + captured["config"] = config + captured["stt_params"] = stt_params + return CascadeDeps( + run_stt=lambda _o: None, complete_reply=lambda _m: "", synthesize=lambda _t: b"" + ) + + monkeypatch.setattr(_exec.engine.CascadeDeps, "real", fake_real) + run_agent_cascade( + _opts( + 
source="clip.wav", + language="en", + max_tokens=321, + format_turns=False, + llm_config=("temperature=0.3",), + tts_config=("chunk_size_ms=100",), + speech_model="u3-rt-pro", + ), + AppState(), + json_mode=False, + ) + cfg = captured["config"] + assert cfg.language == "en" + assert cfg.max_tokens == 321 + assert cfg.format_turns is False + assert cfg.llm_extra == {"temperature": 0.3} + assert cfg.tts_extra == {"chunk_size_ms": "100"} + # The STT flags are realized into the params the cascade will stream with. + assert captured["stt_params"].format_turns is False + assert captured["stt_params"].sample_rate == 16000 + + # --- CascadeDeps.real (the three live legs) ---------------------------------- -def test_deps_real_run_stt_passes_formatted_params(monkeypatch): +def _stt_params(**flags: object): + merged = config_builder.merge_streaming_params( + flags={"sample_rate": 16000, "format_turns": True, "speech_model": "u3-rt-pro", **flags} + ) + return config_builder.construct_streaming_params(merged) + + +def test_deps_real_run_stt_passes_prebuilt_params_through(monkeypatch): captured = {} def fake_stream_audio(api_key, source, *, params, on_turn): @@ -243,34 +387,66 @@ def fake_stream_audio(api_key, source, *, params, on_turn): monkeypatch.setattr(engine.client, "stream_audio", fake_stream_audio) audio: list[bytes] = [] - deps = CascadeDeps.real("k", CascadeConfig(), audio=audio, sample_rate=16000) + params = _stt_params() + deps = CascadeDeps.real("k", CascadeConfig(), audio=audio, stt_params=params) deps.run_stt(lambda event: None) assert captured["api_key"] == "k" assert captured["source"] is audio - assert captured["params"].sample_rate == 16000 - assert captured["params"].format_turns is True + # The cascade streams exactly the params it was handed — no re-derivation. 
+ assert captured["params"] is params + +def test_deps_real_complete_reply_threads_model_tokens_and_extra(monkeypatch): + captured = {} -def test_deps_real_complete_reply_returns_content(monkeypatch): - monkeypatch.setattr(engine.llm, "complete", lambda api_key, **kwargs: "raw-response") + def fake_complete(api_key, **kwargs): + captured.update(kwargs) + return "raw-response" + + monkeypatch.setattr(engine.llm, "complete", fake_complete) monkeypatch.setattr(engine.llm, "content_of", lambda response: response.upper()) - deps = CascadeDeps.real("k", CascadeConfig(model="m"), audio=[], sample_rate=16000) + cfg = CascadeConfig(model="m", max_tokens=222, llm_extra={"temperature": 0.5}) + deps = CascadeDeps.real("k", cfg, audio=[], stt_params=_stt_params()) assert deps.complete_reply([{"role": "user", "content": "hi"}]) == "RAW-RESPONSE" + assert captured["model"] == "m" + assert captured["max_tokens"] == 222 + assert captured["extra"] == {"temperature": 0.5} + + +def test_deps_real_complete_reply_sends_no_extra_when_unset(monkeypatch): + captured = {} + + def fake_complete(api_key, **kwargs): + captured.update(kwargs) + return "x" + + monkeypatch.setattr(engine.llm, "complete", fake_complete) + monkeypatch.setattr(engine.llm, "content_of", lambda response: response) + deps = CascadeDeps.real("k", CascadeConfig(), audio=[], stt_params=_stt_params()) + deps.complete_reply([{"role": "user", "content": "hi"}]) + # Empty overrides collapse to None, not an empty dict, so the gateway sees no extra body. 
+ assert captured["extra"] is None -def test_deps_real_synthesize_returns_pcm(monkeypatch): +def test_deps_real_synthesize_threads_voice_language_and_extra(monkeypatch): captured = {} def fake_synth(api_key, spec): captured["voice"] = spec.voice + captured["language"] = spec.language captured["text"] = spec.text captured["sample_rate"] = spec.sample_rate + captured["params"] = spec.query_params() return types.SimpleNamespace(pcm=b"AUDIO") monkeypatch.setattr(engine.tts_session, "synthesize", fake_synth) - deps = CascadeDeps.real("k", CascadeConfig(voice="vera"), audio=[], sample_rate=16000) + cfg = CascadeConfig(voice="vera", language="en", tts_extra={"chunk_size_ms": "100"}) + deps = CascadeDeps.real("k", cfg, audio=[], stt_params=_stt_params()) assert deps.synthesize("say this") == b"AUDIO" assert captured["voice"] == "vera" + assert captured["language"] == "en" assert captured["text"] == "say this" # TTS always synthesizes at the 24 kHz the live player is opened at. assert captured["sample_rate"] == engine.TTS_SAMPLE_RATE == 24000 + # The --tts-config escape hatch rides along as an extra query param. + assert captured["params"]["chunk_size_ms"] == "100" diff --git a/tests/test_agent_cascade_config.py b/tests/test_agent_cascade_config.py index 3d990ebc..e366187c 100644 --- a/tests/test_agent_cascade_config.py +++ b/tests/test_agent_cascade_config.py @@ -19,6 +19,13 @@ def test_default_config_values(): # The sliding-window default keeps the last 40 messages of context. assert config.max_history == 40 assert DEFAULT_MAX_HISTORY == 40 + # Formatting is on by default, so the reply trigger waits for the formatted turn. + assert config.format_turns is True + assert config.language is None + assert config.max_tokens == llm.DEFAULT_MAX_TOKENS + # Escape-hatch overrides start empty. 
+ assert dict(config.llm_extra) == {} + assert dict(config.tts_extra) == {} def test_config_is_frozen(): diff --git a/tests/test_agent_cascade_engine.py b/tests/test_agent_cascade_engine.py index 917870a7..d1134339 100644 --- a/tests/test_agent_cascade_engine.py +++ b/tests/test_agent_cascade_engine.py @@ -335,16 +335,33 @@ def test_shutdown_without_worker_is_safe(): ("end_of_turn", "formatted", "expected"), [(True, True, True), (True, False, False), (False, True, False), (False, False, False)], ) -def test_is_final_turn(end_of_turn, formatted, expected): +def test_is_final_turn_with_formatting_waits_for_formatted(end_of_turn, formatted, expected): + # With formatting on, only a formatted end-of-turn is the cue (better text for the LLM). event = _turn("hi", end_of_turn=end_of_turn, turn_is_formatted=formatted) - assert engine._is_final_turn(event) is expected + assert engine._is_final_turn(event, format_turns=True) is expected + + +@pytest.mark.parametrize( + ("end_of_turn", "formatted", "expected"), + [(True, False, True), (True, True, True), (False, False, False), (False, True, False)], +) +def test_is_final_turn_without_formatting_triggers_on_end_of_turn(end_of_turn, formatted, expected): + # With --no-format-turns the server never sets turn_is_formatted, so a bare + # end-of-turn must be the cue — otherwise the agent would never reply. + event = _turn("hi", end_of_turn=end_of_turn, turn_is_formatted=formatted) + assert engine._is_final_turn(event, format_turns=False) is expected def test_is_final_turn_defaults_missing_attrs_to_not_final(): # A formatted turn missing end_of_turn (and vice versa) must read as not-final, so # each absent field defaults to False rather than being treated as present-and-true. 
- assert engine._is_final_turn(types.SimpleNamespace(turn_is_formatted=True)) is False - assert engine._is_final_turn(types.SimpleNamespace(end_of_turn=True)) is False + assert ( + engine._is_final_turn(types.SimpleNamespace(turn_is_formatted=True), format_turns=True) + is False + ) + assert ( + engine._is_final_turn(types.SimpleNamespace(end_of_turn=True), format_turns=True) is False + ) def test_spawn_thread_runs_target(): diff --git a/tests/test_streaming_diagnostics.py b/tests/test_streaming_diagnostics.py index 17c5cd78..150f7a3a 100644 --- a/tests/test_streaming_diagnostics.py +++ b/tests/test_streaming_diagnostics.py @@ -6,6 +6,7 @@ import logging import types +from collections.abc import Callable import pytest @@ -16,6 +17,7 @@ SDK_STREAMING_LOGGER, handshake_error, handshake_suggestion, + open_authorized_ws, silence_streaming_logging, ) @@ -139,3 +141,31 @@ def test_non_handshake_errors_return_none(): assert handshake_error(closed, "Streaming error", host="h") is None # Other HTTP statuses (e.g. a 500 on the upgrade) are not auth-shaped. assert handshake_error(_WsHandshake(500), "Streaming error", host="h") is None + + +def _header_capture() -> tuple[Callable[..., str], dict[str, str]]: + """A connect double plus the dict its Authorization header lands in.""" + captured: dict[str, str] = {} + + def _connect(url: str, *, additional_headers: dict[str, str], **kwargs: object) -> str: + captured.update(additional_headers) + return "ws" + + return _connect, captured + + +def test_open_authorized_ws_defaults_to_bearer_token(): + # The Voice Agent endpoint expects a Bearer token, so that's the default when + # `bearer` is omitted entirely (not just when passed True). 
+ connect, captured = _header_capture() + open_authorized_ws(connect, "secret", "wss://h/ws", message="m", host="h") + assert captured["Authorization"] == "Bearer secret" + + +def test_open_authorized_ws_sends_raw_key_when_bearer_false(): + # Streaming endpoints (STT, TTS) authenticate with the raw API key — no "Bearer " + # prefix. bearer=False must send the key verbatim, not a Bearer token. + connect, captured = _header_capture() + open_authorized_ws(connect, "secret", "wss://h/ws", message="m", host="h", bearer=False) + assert captured["Authorization"] == "secret" + assert "Bearer" not in captured["Authorization"] diff --git a/tests/test_tts_session.py b/tests/test_tts_session.py index 092ee719..a181876d 100644 --- a/tests/test_tts_session.py +++ b/tests/test_tts_session.py @@ -85,8 +85,9 @@ def close(self) -> None: def _audio_frame(pcm: bytes, *, final: bool) -> str: - # The real server's Audio frames carry only the PCM payload and the final flag; - # the sample rate is reported once, up front, in the Begin frame's configuration. + # An Audio frame with an explicit is_final flag — the defensive end-of-stream path + # (the live server instead omits is_final and ends with FlushDone, see _audio_chunk). + # The sample rate is reported once, up front, in the Begin frame's configuration. return json.dumps( { "type": "Audio", @@ -96,6 +97,18 @@ def _audio_frame(pcm: bytes, *, final: bool) -> str: ) +def _audio_chunk(pcm: bytes) -> str: + # The real server's Audio frames carry the PCM payload and a flush_id, but NO + # is_final flag — completion is signalled by a separate FlushDone frame. 
+ return json.dumps( + {"type": "Audio", "audio": base64.b64encode(pcm).decode("ascii"), "flush_id": 0} + ) + + +def _flush_done_frame() -> str: + return json.dumps({"type": "FlushDone", "flush_id": 0, "audio_duration_ms": 880}) + + def _begin_frame(*, sample_rate: int = 24000) -> str: return json.dumps( { @@ -128,8 +141,10 @@ def test_synthesize_drives_the_full_protocol(): cfg = session.SpeakConfig(text="hello", voice="jane") result = session.synthesize("k", cfg, connect=_connect_returning(ws, captured)) - # Sends Generate(text), then ForceFlushTextBuffer, then Terminate — in that order. - assert [m["type"] for m in ws.sent] == ["Generate", "ForceFlushTextBuffer", "Terminate"] + # Sends Generate(text), then Flush, then Terminate — in that order. The flush tag + # is "Flush" (the server's accepted tag, as the agent-cascade template sends), + # not "ForceFlushTextBuffer", which the server rejects with a validation error. + assert [m["type"] for m in ws.sent] == ["Generate", "Flush", "Terminate"] assert ws.sent[0]["text"] == "hello" # Accumulates decoded PCM across chunks and stops on the is_final frame. assert result.pcm == b"\x01\x02\x03\x04\x05\x06" @@ -137,10 +152,31 @@ def test_synthesize_drives_the_full_protocol(): assert result.sample_rate == 24000 # 6 bytes / 2 bytes-per-sample / 24000 Hz. assert result.audio_duration_seconds == pytest.approx(6 / 2 / 24000) - # Auth header carries the key as a Bearer token. - assert captured["kwargs"]["additional_headers"]["Authorization"] == "Bearer k" + # Streaming TTS authenticates with the raw API key — no "Bearer " prefix (the + # AssemblyAI streaming convention; the working agent-cascade template does the + # same). A Bearer token makes the server reject the session with an Error frame. 
+ assert captured["kwargs"]["additional_headers"]["Authorization"] == "k" + assert "Bearer" not in captured["kwargs"]["additional_headers"]["Authorization"] # The frame-size cap is lifted: Audio frames can exceed websockets' 1 MiB default. assert captured["kwargs"]["max_size"] is None + + +def test_synthesize_stops_on_flush_done_when_audio_omits_is_final(): + # The live server ends a synthesis with a FlushDone frame and never sets is_final + # on its Audio frames. Without handling FlushDone the loop blocks until the recv + # timeout (the audio is silently lost), so FlushDone must end collection and return + # every PCM byte gathered so far — then Terminate the session as usual. + ws = FakeWS( + [ + _begin_frame(sample_rate=24000), + _audio_chunk(b"\x01\x02\x03\x04"), + _audio_chunk(b"\x05\x06"), + _flush_done_frame(), + ] + ) + result = session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) + assert result.pcm == b"\x01\x02\x03\x04\x05\x06" + assert [m["type"] for m in ws.sent] == ["Generate", "Flush", "Terminate"] assert ws.closed is True @@ -167,8 +203,17 @@ def test_synthesize_falls_back_to_default_rate_when_begin_omits_configuration(): def test_synthesize_raises_on_missing_begin(): ws = FakeWS([json.dumps({"type": "Audio"})]) - with pytest.raises(APIError): + with pytest.raises(APIError, match="did not start the session"): + session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) + + +def test_synthesize_session_start_error_frame_surfaces_its_reason(): + # When the very first frame is an Error (not Begin), the server is explaining why + # it refused the session — surface that reason and code, not a generic "got 'Error'". 
+ ws = FakeWS([json.dumps({"type": "Error", "error_code": 401, "error": "bad token"})]) + with pytest.raises(APIError, match="bad token") as excinfo: session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) + assert "401" in str(excinfo.value) def test_synthesize_maps_audio_frame_without_payload_to_api_error():