diff --git a/.importlinter b/.importlinter index 0bb57bd3..ef910f60 100644 --- a/.importlinter +++ b/.importlinter @@ -32,6 +32,7 @@ type = forbidden ; deliberate, so this short list does not drift the way the old core list did. source_modules = aai_cli.agent + aai_cli.agent_framework aai_cli.auth aai_cli.code_gen aai_cli.init diff --git a/README.md b/README.md index 5080dd62..26837579 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ That's it. Run `assembly onboard` for a guided tour, or see [Installation](#-ins | `assembly stream` | Real-time transcription from your microphone, a file, or a URL — on macOS it can capture system audio too | | `assembly dictate` | Push-to-talk dictation: press Enter to record, Enter again for instant text (Sync STT API, up to 120 s per utterance) | | `assembly agent` | Full-duplex spoken conversation with a voice agent, right in your terminal | +| `assembly agent-framework` | Same live conversation, but wired client-side from Streaming STT + the LLM Gateway + streaming TTS, like the `agent-framework` starter (sandbox-only) | | `assembly speak` | Synthesize text to speech over the streaming-TTS WebSocket (sandbox-only) | | `assembly llm` | Prompt the LLM Gateway over a transcript, stdin, or a live stream | | `assembly clip` | Cut audio/video with ffmpeg by diarized speaker, text match, LLM pick, or time range (`--video` keeps the picture for URL sources) — clip boundaries snap into nearby silence | diff --git a/REFERENCE.md b/REFERENCE.md index 6c301b9e..443b63e2 100644 --- a/REFERENCE.md +++ b/REFERENCE.md @@ -76,6 +76,7 @@ each carrying a `"type"` field to dispatch on: | ------- | ----------- | | `assembly stream --json` | `begin`, `turn`, `termination` | | `assembly agent --json` | `session.ready`, `transcript.user.delta`, `transcript.user`, `reply.started`, `transcript.agent`, `reply.done` | +| `assembly agent-framework --json` | `session.ready`, `transcript.user.delta`, `transcript.user`, `reply.started`, `transcript.agent`, `reply.done` | | `assembly dictate --json` | `utterance` | | `assembly llm --follow --json` | `answer` | | `assembly transcribe --json` | `result` (one per source) | diff --git a/aai_cli/AGENTS.md b/aai_cli/AGENTS.md index 2f77e96b..fe6444f6 100644 --- a/aai_cli/AGENTS.md +++ b/aai_cli/AGENTS.md @@ -151,6 +151,7 @@ heavily-reworked commands with long bodies; small commands keep the inline - **`streaming/`** + `client.stream_audio` — v3 realtime API. Event callbacks run on the SDK reader thread and guard against `BrokenPipeError` (`stdio.silence_stdout()`) so a closed pipe never dumps a thread traceback. - **`core/sync_stt.py`** + **`core/hotkey.py`** + `commands/dictate/` — `assembly dictate`: push-to-talk dictation over the **Sync STT API** (`Environment.sync_base`, one POST `/transcribe` per utterance with the required `X-AAI-Model: u3-sync-pro` header; 80 ms–120 s of PCM/WAV). `hotkey.TerminalKeys` scopes stdin into cbreak (Ctrl-C still signals) and reads single keypresses; `dictate_exec._record` polls it with a zero timeout between ~100 ms mic chunks. All three boundaries (keys, mic, HTTP) are injectable, so the suite never needs a real terminal — `tests/test_hotkey.py` drives a pty pair for the termios behavior. - **`agent/`** — full-duplex voice agent (mic in, TTS out via `voices.py`). +- **`agent_framework/`** + `commands/agent_framework/` — `assembly agent-framework`: the same live terminal conversation as `assembly agent`, but **client-orchestrated** — `engine.run_cascade` wires Streaming STT → the LLM Gateway → streaming TTS itself instead of talking to the Voice Agent endpoint, mirroring what the `agent-framework` `assembly init` template does server-side. **Sandbox-only** (streaming TTS has no prod host; guarded via `tts.session.require_available`). Reuses the agent slice's `DuplexAudio`/`AgentRenderer` and `core.client.stream_audio`/`core.llm.complete`/`tts.session.synthesize`; the three network legs are injected through `engine.CascadeDeps` (the `tts/session.py` seam) so the cascade — greeting, per-sentence TTS, barge-in, history window — is unit-tested against fakes with no sockets/mic/speaker. - **`tts/`** + `commands/speak.py` — `assembly speak` synthesizes text to speech over the sandbox streaming-TTS WebSocket (`streaming-tts.sandbox000.…`). **Sandbox-only:** `session.is_available()` is false in production (empty `Environment.streaming_tts_host`), so the command exits 2 with a `--sandbox` hint. `session.synthesize` drives a Begin→Generate→Flush→Audio→Terminate protocol with an injectable `connect` for hermetic tests (mirrors `agent/session.py`); `audio.py` plays the PCM (default) or writes a WAV (`--out`). - **`code_gen/`** — backs `--show-code` on `transcribe`/`stream`/`agent`: builds a ready-to-run Python SDK script from exactly the flags passed (no API key needed; generated code reads `ASSEMBLYAI_API_KEY`). - **`auth/`** — browser-assisted `assembly login` via AMS + **Stytch B2B OAuth discovery** (`discovery.py`, `flow.py`, `loopback.py`, `ams.py`). Not Stytch Connected Apps. diff --git a/aai_cli/agent_framework/__init__.py b/aai_cli/agent_framework/__init__.py new file mode 100644 index 00000000..a8a89173 --- /dev/null +++ b/aai_cli/agent_framework/__init__.py @@ -0,0 +1,15 @@ +"""The terminal *agent framework* slice: a client-orchestrated voice cascade. + +`assembly agent-framework` holds the same kind of live voice conversation as +`assembly agent`, but where `agent` talks to AssemblyAI's single Voice Agent +endpoint, this slice wires the three primitives together itself — Streaming STT +-> the LLM Gateway -> streaming TTS — exactly like the ``agent-framework`` +``assembly init`` template does server-side. Because it uses streaming TTS it is +sandbox-only. + +`engine.run_cascade` is the orchestrator; it takes injected dependencies +(`CascadeDeps`) so tests drive the whole cascade against fakes, the same seam +`aai_cli/tts/session.py` uses. +""" + +from __future__ import annotations diff --git a/aai_cli/agent_framework/config.py b/aai_cli/agent_framework/config.py new file mode 100644 index 00000000..4d41fc3c --- /dev/null +++ b/aai_cli/agent_framework/config.py @@ -0,0 +1,34 @@ +"""Per-run configuration for the terminal voice cascade. + +Defaults mirror the ``agent-framework`` ``assembly init`` template's +``api/settings.py`` so the CLI conversation and the scaffolded app behave the +same out of the box. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from aai_cli.agent_framework.voices import DEFAULT_VOICE +from aai_cli.core import llm + +DEFAULT_MODEL = llm.DEFAULT_MODEL +DEFAULT_SYSTEM_PROMPT = ( + "You are a friendly, concise voice assistant. Keep replies short and " + "conversational. Your reply is read aloud by a text-to-speech engine, so " + "write plain spoken prose — no markdown, emoji, bullet lists, or code." +) +DEFAULT_GREETING = "Hi! I'm your AssemblyAI voice agent. What can I help you with?" +# Sliding-window size: keep the last N messages of conversation as LLM context. +DEFAULT_MAX_HISTORY = 40 + + +@dataclass(frozen=True) +class CascadeConfig: + """The static knobs for one cascade run, fixed once the flags are parsed.""" + + voice: str = DEFAULT_VOICE + system_prompt: str = DEFAULT_SYSTEM_PROMPT + greeting: str = DEFAULT_GREETING + model: str = DEFAULT_MODEL + max_history: int = DEFAULT_MAX_HISTORY diff --git a/aai_cli/agent_framework/engine.py b/aai_cli/agent_framework/engine.py new file mode 100644 index 00000000..d43431e0 --- /dev/null +++ b/aai_cli/agent_framework/engine.py @@ -0,0 +1,290 @@ +"""The terminal voice cascade: Streaming STT -> LLM Gateway -> streaming TTS. + +``run_cascade`` greets the user, then drives a live conversation by reading STT +turns and, for each finalized turn, streaming an LLM reply out through TTS +sentence-by-sentence. A new turn barges in on a reply that is still playing. + +All three network legs are injected through ``CascadeDeps`` (the same seam +``aai_cli/tts/session.py`` uses), so the orchestration is unit-tested against +fakes with no sockets, microphone, or speaker. +""" + +from __future__ import annotations + +import contextlib +import threading +from abc import abstractmethod +from collections.abc import Callable, Iterable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Protocol + +from aai_cli.agent_framework.config import CascadeConfig +from aai_cli.agent_framework.text import split_sentences, trim_history +from aai_cli.core import client, config_builder, llm +from aai_cli.core.errors import CLIError +from aai_cli.tts import session as tts_session +from aai_cli.tts.session import SpeakConfig +from aai_cli.ui import output + +if TYPE_CHECKING: + from assemblyai.streaming.v3 import StreamingParameters + from openai.types.chat import ChatCompletionMessageParam + +# Streaming TTS synthesizes at 24 kHz, the rate the live player is opened at. +TTS_SAMPLE_RATE = 24000 + + +class _Worker(Protocol): + """The slice of a thread the session drives: started already, queryable, joinable.""" + + @abstractmethod + def is_alive(self) -> bool: + """Whether the reply worker is still running.""" + + def join(self) -> None: + """Block until the reply worker finishes.""" + + +class Renderer(Protocol): + """The conversation-rendering surface the cascade drives (AgentRenderer satisfies it).""" + + def connected(self) -> None: + """Announce the session is live and listening.""" + + def user_partial(self, text: str) -> None: + """Show an interim user transcript.""" + + def user_final(self, text: str) -> None: + """Show a finalized user transcript.""" + + def reply_started(self) -> None: + """Mark the start of an agent reply.""" + + def agent_transcript(self, text: str, *, interrupted: bool) -> None: + """Show a line of the agent's reply.""" + + def reply_done(self, *, interrupted: bool) -> None: + """Mark the end of an agent reply.""" + + +class Player(Protocol): + """The speaker the cascade enqueues TTS audio into (DuplexAudio/NullPlayer satisfy it).""" + + def start(self) -> None: + """Open the output stream.""" + + def enqueue(self, pcm: bytes) -> None: + """Queue PCM audio for playback.""" + + def flush(self) -> None: + """Drop any queued-but-unplayed audio (used on barge-in).""" + + def close(self) -> None: + """Close the output stream.""" + + +def _new_history() -> list[ChatCompletionMessageParam]: + """Typed empty-history factory (ChatCompletionMessageParam is import-time-only).""" + return [] + + +def _spawn_thread(target: Callable[[], None]) -> _Worker: + """Start ``target`` on a daemon thread so a reply is generated without blocking + the STT reader (which must stay free to detect a barge-in).""" + thread = threading.Thread(target=target, daemon=True) # pragma: no mutate + thread.start() + return thread + + +# The realtime model the cascade transcribes with (same as the agent-framework template). +STT_SPEECH_MODEL = "u3-rt-pro" + + +def _stt_params(sample_rate: int) -> StreamingParameters: + """Streaming v3 params for the cascade: PCM at ``sample_rate`` with formatted turns + (so ``turn_is_formatted`` marks the cue to reply).""" + merged = config_builder.merge_streaming_params( + flags={ + "sample_rate": sample_rate, + "format_turns": True, + "speech_model": STT_SPEECH_MODEL, + } + ) + return config_builder.construct_streaming_params(merged) + + +@dataclass +class CascadeDeps: + """The cascade's three network legs plus its thread spawner, all injectable. + + ``CascadeDeps.real`` wires the live STT/LLM/TTS clients; tests pass fakes with + the same shapes (and a synchronous ``spawn``) to drive the orchestration. + """ + + run_stt: Callable[[Callable[[object], None]], None] + complete_reply: Callable[[list[ChatCompletionMessageParam]], str] + synthesize: Callable[[str], bytes] + spawn: Callable[[Callable[[], None]], _Worker] = _spawn_thread + + @classmethod + def real( + cls, + api_key: str, + config: CascadeConfig, + *, + audio: Iterable[bytes], + sample_rate: int, + ) -> CascadeDeps: + def run_stt(on_turn: Callable[[object], None]) -> None: + client.stream_audio(api_key, audio, params=_stt_params(sample_rate), on_turn=on_turn) + + def complete_reply(messages: list[ChatCompletionMessageParam]) -> str: + response = llm.complete(api_key, model=config.model, messages=messages) + return llm.content_of(response) + + def synthesize(text: str) -> bytes: + spec = SpeakConfig(text=text, voice=config.voice, sample_rate=TTS_SAMPLE_RATE) + return tts_session.synthesize(api_key, spec).pcm + + return cls(run_stt=run_stt, complete_reply=complete_reply, synthesize=synthesize) + + +@dataclass +class CascadeSession: + """Per-conversation state: the running history and the in-flight reply worker.""" + + deps: CascadeDeps + renderer: Renderer + player: Player + config: CascadeConfig + history: list[ChatCompletionMessageParam] = field(default_factory=_new_history) + # First leg failure (LLM/TTS). Recorded on the reply worker thread, where raising + # would dump a thread traceback, and re-raised from the main thread to fail cleanly. + error: CLIError | None = None + _reply: _Worker | None = field(default=None, init=False) # pragma: no mutate + _stop: threading.Event = field(default_factory=threading.Event, init=False) # pragma: no mutate + + def greet(self) -> None: + """Speak the opening greeting (if any) and seed it into the history so the + model has a record of its own first line.""" + greeting = self.config.greeting + if not greeting: + return + self.history.append({"role": "assistant", "content": greeting}) + self.renderer.agent_transcript(greeting, interrupted=False) + try: + self.player.enqueue(self.deps.synthesize(greeting)) + except CLIError as exc: + self._record_error(exc) + + def on_turn(self, event: object) -> None: + """Handle one STT turn: reply to a finalized turn, otherwise just barge in. + + Runs on the STT reader thread. An interim turn only interrupts a playing + reply; a finalized, formatted turn is shown and answered. + """ + text = (getattr(event, "transcript", "") or "").strip() + if not text: + return + if _is_final_turn(event): + self.renderer.user_final(text) + self._barge_in() + self.history.append({"role": "user", "content": text}) + trim_history(self.history, self.config.max_history) + self._start_reply() + else: + self.renderer.user_partial(text) + self._barge_in() + + def _barge_in(self) -> None: + """Stop a reply that is still playing: flush the queued audio and cancel the + worker (the player flush is what silences the browser-equivalent local buffer).""" + if self._reply is not None and self._reply.is_alive(): + self._stop.set() + self.player.flush() + self._join_reply() + + def _join_reply(self) -> None: + """Wait for the current reply worker (if any) to unwind, then drop the handle.""" + worker = self._reply + if worker is not None and worker.is_alive(): + worker.join() + self._reply = None + + def _start_reply(self) -> None: + self._stop.clear() + self._reply = self.deps.spawn(self._generate_reply) + + def _generate_reply(self) -> None: + """Stream the LLM reply, speak it sentence-by-sentence, and record what was + actually spoken (so a barge-in still leaves the history alternating).""" + messages: list[ChatCompletionMessageParam] = [ + {"role": "system", "content": self.config.system_prompt}, + *self.history, + ] + try: + reply = self.deps.complete_reply(messages) + except CLIError as exc: + self._record_error(exc) + return + self.renderer.reply_started() + spoken: list[str] = [] + for sentence in split_sentences(reply): + if self._stop.is_set(): + break + self.renderer.agent_transcript(sentence, interrupted=False) + try: + pcm = self.deps.synthesize(sentence) + except CLIError as exc: + self._record_error(exc) + break + if self._stop.is_set(): + break + self.player.enqueue(pcm) + spoken.append(sentence) + spoken_text = " ".join(spoken).strip() + if spoken_text: + self.history.append({"role": "assistant", "content": spoken_text}) + trim_history(self.history, self.config.max_history) + self.renderer.reply_done(interrupted=self._stop.is_set()) + + def _record_error(self, exc: CLIError) -> None: + """Keep the first leg failure (to re-raise on the main thread) and warn now, + since the worker thread can't surface an exit code itself.""" + if self.error is None: + self.error = exc + output.error_console.print(f"[aai.warn]agent-framework:[/aai.warn] {exc.message}") + + def shutdown(self) -> None: + """Stop and join any in-flight reply worker (run on every exit path).""" + self._stop.set() + self._join_reply() + + +def _is_final_turn(event: object) -> bool: + """True for a finalized, formatted end-of-turn — the cue to generate a reply.""" + return bool(getattr(event, "end_of_turn", False)) and bool( + getattr(event, "turn_is_formatted", False) + ) + + +def run_cascade( + *, renderer: Renderer, player: Player, config: CascadeConfig, deps: CascadeDeps +) -> None: + """Run one terminal cascade conversation until STT closes or the user stops. + + Greets, then pumps STT turns through the LLM+TTS reply path. A recorded leg + failure is re-raised here so the command exits with the right code. + """ + session = CascadeSession(deps=deps, renderer=renderer, player=player, config=config) + player.start() + try: + session.greet() + renderer.connected() + deps.run_stt(session.on_turn) + finally: + session.shutdown() + with contextlib.suppress(Exception): + player.close() + if session.error is not None: + raise session.error diff --git a/aai_cli/agent_framework/text.py b/aai_cli/agent_framework/text.py new file mode 100644 index 00000000..66b38ea7 --- /dev/null +++ b/aai_cli/agent_framework/text.py @@ -0,0 +1,41 @@ +"""Pure text helpers for the cascade: sentence splitting and history trimming. + +Kept Rich-free and dependency-light so the orchestration logic in ``engine`` can +be unit-tested without any I/O. +""" + +from __future__ import annotations + +# A reply is spoken sentence-by-sentence so the first audio plays before the whole +# answer is synthesized; a sentence ends at one of these terminators. +_TERMINATORS = ".!?" + + +def split_sentences(text: str) -> list[str]: + """Split ``text`` into sentences, each ending in ``.``/``!``/``?``. + + A trailing fragment with no terminal punctuation is kept as a final sentence, + so no text is ever dropped; empty/whitespace-only pieces are discarded. + """ + sentences: list[str] = [] + start = 0 + for index, char in enumerate(text): + if char in _TERMINATORS: + # The slice always includes the terminator at ``index``, so it is never + # blank after stripping the inter-sentence whitespace. + sentences.append(text[start : index + 1].strip()) + start = index + 1 + tail = text[start:].strip() + if tail: + sentences.append(tail) + return sentences + + +def trim_history[T](history: list[T], max_messages: int) -> None: + """Cap ``history`` to its most recent ``max_messages`` entries, in place. + + A sliding window over the conversation so an unbounded chat doesn't grow the + context (and the per-turn token cost) without limit. + """ + if len(history) > max_messages: + del history[: len(history) - max_messages] diff --git a/aai_cli/agent_framework/voices.py b/aai_cli/agent_framework/voices.py new file mode 100644 index 00000000..de55bb6f --- /dev/null +++ b/aai_cli/agent_framework/voices.py @@ -0,0 +1,43 @@ +"""The voices `assembly agent-framework` speaks with. + +The cascade's audio comes from streaming TTS, so its voices are the TTS catalog +(`aai_cli.tts.voices`) — not the Voice Agent voices `assembly agent` uses. This +module is the thin presentation layer over that catalog: the membership list +that catches a typo'd ``--voice``, the completion callback, and the grouped +``--list-voices`` rendering. +""" + +from __future__ import annotations + +from aai_cli.tts import voices as tts_voices + +DEFAULT_VOICE = "jane" + +# The selectable voice ids, sorted for a stable --list-voices / completion order. +VOICE_NAMES: list[str] = sorted(tts_voices.VOICE_LANGUAGES) + +# ISO 639-1 code -> the heading --list-voices groups that language's voices under. +_LANGUAGE_LABELS: dict[str, str] = { + "en": "English", + "fr": "French", + "it": "Italian", + "de": "German", + "es": "Spanish", + "pt": "Portuguese", +} + + +def complete_voice(incomplete: str) -> list[str]: + """Shell-completion callback for ``--voice``: catalog ids matching the prefix.""" + return [name for name in VOICE_NAMES if name.startswith(incomplete)] + + +def format_voice_list() -> str: + """Human-readable voice ids for ``--list-voices``, grouped by language.""" + blocks: list[str] = [] + for code, label in _LANGUAGE_LABELS.items(): + names = [name for name in VOICE_NAMES if tts_voices.VOICE_LANGUAGES[name] == code] + if names: + listing = "\n".join(f" {name}" for name in names) + blocks.append(f"{label}:\n{listing}") + return "\n\n".join(blocks) diff --git a/aai_cli/commands/agent_framework/__init__.py b/aai_cli/commands/agent_framework/__init__.py new file mode 100644 index 00000000..6b8c8e61 --- /dev/null +++ b/aai_cli/commands/agent_framework/__init__.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from pathlib import Path + +import typer + +from aai_cli import command_registry, help_panels, options +from aai_cli.agent_framework import voices +from aai_cli.agent_framework.config import ( + DEFAULT_GREETING, + DEFAULT_MODEL, + DEFAULT_SYSTEM_PROMPT, +) +from aai_cli.agent_framework.voices import DEFAULT_VOICE +from aai_cli.app.context import AppState, run_command, run_with_options +from aai_cli.commands.agent_framework import _exec as agent_framework_exec +from aai_cli.core import choices, llm +from aai_cli.ui import output +from aai_cli.ui.help_text import examples_epilog + +app = typer.Typer() + +SPEC = command_registry.CommandModuleSpec( + panel=help_panels.TRANSCRIPTION, + order=45, # pragma: no mutate -- sparse rank; a +-1 shift is order-equivalent + commands=("agent-framework",), +) + + +def _emit_voice_list(_state: AppState, json_mode: bool) -> None: + """--list-voices body, routed through run_command so --json yields a machine-readable + array instead of the human list; needs no auth.""" + payload = [{"name": name} for name in voices.VOICE_NAMES] + output.emit(payload, lambda _voices: voices.format_voice_list(), json_mode=json_mode) + + +@app.command( + name="agent-framework", + rich_help_panel=help_panels.TRANSCRIPTION, + epilog=examples_epilog( + [ + ("Start a live cascade conversation", "assembly --sandbox agent-framework"), + ( + "Pick a voice and opening line", + 'assembly --sandbox agent-framework --voice michael --greeting "Hi there"', + ), + ( + "Give the agent a persona", + 'assembly --sandbox agent-framework --system-prompt "You are a terse pirate."', + ), + ("See available voices", "assembly --sandbox agent-framework --list-voices"), + ] + ), +) +def agent_framework( + ctx: typer.Context, + source: str | None = typer.Argument( + None, help="Audio file path or URL to speak to the agent. Omit to use the microphone." + ), + sample: bool = typer.Option( + False, "--sample", help="Speak the hosted wildfires.mp3 sample to the agent" + ), + voice: str = typer.Option( + DEFAULT_VOICE, + "--voice", + help="TTS voice. See --list-voices.", + autocompletion=voices.complete_voice, + ), + model: str = typer.Option( + DEFAULT_MODEL, + "--model", + help="LLM Gateway model that powers the agent's replies", + autocompletion=llm.complete_model, + ), + system_prompt: str = typer.Option( + DEFAULT_SYSTEM_PROMPT, "--system-prompt", help="System prompt (the agent's persona)" + ), + system_prompt_file: Path | None = typer.Option( + None, + "--system-prompt-file", + help="Read the system prompt from a file (overrides --system-prompt)", + exists=True, + dir_okay=False, + ), + greeting: str = typer.Option(DEFAULT_GREETING, "--greeting", help="Spoken greeting"), + device: int | None = typer.Option(None, "--device", help="Microphone device index"), + list_voices: bool = typer.Option(False, "--list-voices", help="Print known voices and exit"), + json_out: bool = options.json_option("Emit newline-delimited JSON events"), + output_field: choices.TextOrJson | None = typer.Option( + None, + "-o", + "--output", + help="Output mode: text (you:/agent: lines as plain stdout, pipe-friendly) or json", + ), +) -> None: + """\\[sandbox] Hold a live voice conversation through a self-wired cascade + + Like 'assembly agent', but instead of AssemblyAI's Voice Agent endpoint this + wires the three primitives together itself — Streaming STT, the LLM Gateway, + and streaming TTS — exactly like the 'agent-framework' init template does + server-side. Because it uses streaming TTS it only runs in the sandbox: run + it as 'assembly --sandbox agent-framework' (--sandbox goes before the + subcommand). + + Use headphones: the mic stays open while the agent speaks, so on speakers it + would hear itself and loop. Pass an audio file/URL (or --sample) to speak a + recorded clip instead of the microphone; the session then ends after the + agent's reply. + + This only runs a conversation in the terminal — it writes no code. To build + an agent-framework app, run 'assembly init agent-framework' instead. + """ + + if list_voices: + run_command(ctx, _emit_voice_list, json=json_out) + return + + opts = agent_framework_exec.AgentFrameworkOptions( + source=source, + sample=sample, + voice=voice, + model=model, + system_prompt=system_prompt, + system_prompt_file=system_prompt_file, + greeting=greeting, + device=device, + output_field=output_field, + ) + run_with_options(ctx, agent_framework_exec.run_agent_framework, opts, json=json_out) diff --git a/aai_cli/commands/agent_framework/_exec.py b/aai_cli/commands/agent_framework/_exec.py new file mode 100644 index 00000000..71ee6d8d --- /dev/null +++ b/aai_cli/commands/agent_framework/_exec.py @@ -0,0 +1,129 @@ +"""Run logic for `assembly agent-framework`: the options/run split (see AGENTS.md). + +The command module parses argv into an ``AgentFrameworkOptions`` and hands it to +``run_agent_framework``, so tests drive validation and the cascade wiring by +constructing options directly rather than round-tripping through ``CliRunner``. +""" + +from __future__ import annotations + +import contextlib +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path + +import typer + +from aai_cli.agent.audio import SAMPLE_RATE, DuplexAudio, NullPlayer +from aai_cli.agent.render import AgentRenderer +from aai_cli.agent_framework import engine, voices +from aai_cli.agent_framework.config import CascadeConfig +from aai_cli.app.context import AppState +from aai_cli.core import choices, client +from aai_cli.core.errors import CLIError, UsageError +from aai_cli.streaming.session import resolve_output_modes +from aai_cli.streaming.sources import FileSource +from aai_cli.tts import session as tts_session + + +@dataclass(frozen=True) +class AgentFrameworkOptions: + """Every `assembly agent-framework` conversation flag as plain data. + + ``--list-voices`` is excluded: it dispatches to its own auth-free body in the + command module. ``--json`` is excluded: run_command resolves it into the + ``json_mode`` argument. + """ + + source: str | None + sample: bool + voice: str + model: str + system_prompt: str + system_prompt_file: Path | None + greeting: str + device: int | None + output_field: choices.TextOrJson | None + + +def _resolve_system_prompt(system_prompt: str, system_prompt_file: Path | None) -> str: + """The persona text: a --system-prompt-file (if given) overrides --system-prompt.""" + if system_prompt_file is None: + return system_prompt + try: + return system_prompt_file.read_text(encoding="utf-8") + except OSError as exc: + raise CLIError( + f"Could not read --system-prompt-file {system_prompt_file}: {exc}", + error_type="file_not_found", + exit_code=2, + suggestion="Check the path and that the file is readable.", + ) from exc + + +def _open_audio( + renderer: AgentRenderer, + *, + source: str | None, + sample: bool, + device: int | None, + from_file: bool, +) -> tuple[Iterable[bytes], engine.Player, int]: + """Build the (audio, player, sample_rate) triple for file- or mic-driven input.""" + if from_file: + # Stream the clip as the user's speech; no listener, so discard the reply audio. + file_source = FileSource(client.resolve_audio_source(source, sample=sample)) + return file_source, NullPlayer(), file_source.sample_rate + # One full-duplex stream for mic + speaker: macOS rejects two separate streams on + # one device, which silently kills capture. + duplex = DuplexAudio(target_rate=SAMPLE_RATE, device=device) + renderer.notice( + "Use headphones — the mic stays open while the agent speaks, " + "so speakers would let it hear itself.\n" + ) + return duplex.mic, duplex.player, SAMPLE_RATE + + +def run_agent_framework(opts: AgentFrameworkOptions, state: AppState, *, json_mode: bool) -> None: + """Execute one `assembly agent-framework` cascade from already-parsed flags.""" + text_mode, json_mode = resolve_output_modes(opts.output_field, json_mode=json_mode) + if opts.voice not in voices.VOICE_NAMES: + raise UsageError( + f"Unknown voice {opts.voice!r}.", + suggestion="Run 'assembly agent-framework --list-voices' to see the options.", + ) + # Streaming TTS has no production host, so the whole cascade is sandbox-only. + tts_session.require_available("agent-framework") + system_prompt_text = _resolve_system_prompt(opts.system_prompt, opts.system_prompt_file) + + from_file = bool(opts.source) or opts.sample + if from_file and opts.device is not None: + raise UsageError("--device applies only to microphone input.") + if from_file: + # Existence-check the clip before credentials, so a typo'd path reads as + # "file not found" instead of triggering a login. + client.resolve_audio_source(opts.source, sample=opts.sample) + api_key = state.resolve_api_key() + + config = CascadeConfig( + voice=opts.voice, + system_prompt=system_prompt_text, + # File-driven runs speak a clip and end after the reply, so skip the greeting. + greeting="" if from_file else opts.greeting, + model=opts.model, + ) + renderer = AgentRenderer(json_mode=json_mode, text_mode=text_mode, mic_input=not from_file) + audio, player, sample_rate = _open_audio( + renderer, source=opts.source, sample=opts.sample, device=opts.device, from_file=from_file + ) + deps = engine.CascadeDeps.real(api_key, config, audio=audio, sample_rate=sample_rate) + try: + engine.run_cascade(renderer=renderer, player=player, config=config, deps=deps) + except KeyboardInterrupt: + renderer.stopped() + except BrokenPipeError as exc: + # Downstream consumer (e.g. `| head`) closed the pipe; stop quietly. + raise typer.Exit(code=0) from exc + finally: + with contextlib.suppress(BrokenPipeError): + renderer.close() diff --git a/aai_cli/init/templates/agent_framework/api/cascade.py b/aai_cli/init/templates/agent_framework/api/cascade.py index e2b27bcf..5c41bca3 100644 --- a/aai_cli/init/templates/agent_framework/api/cascade.py +++ b/aai_cli/init/templates/agent_framework/api/cascade.py @@ -150,14 +150,14 @@ async def cancel_reply(self) -> None: if task is not None and not task.done(): task.cancel() with contextlib.suppress(asyncio.CancelledError, Exception): - await task + await asyncio.gather(task) async def drain(self) -> None: """Await the in-flight reply to natural completion (used when STT closes).""" task = self.reply_task if task is not None: with contextlib.suppress(Exception): - await task + await asyncio.gather(task) async def _connect_stt(settings: _Settings) -> ClientConnection: @@ -347,7 +347,7 @@ class _SessionClosed(Exception): async def _until_closed(pump: Awaitable[None]) -> None: """Run a pump to its natural end, then raise to close the session TaskGroup.""" - await pump + await asyncio.gather(pump) raise _SessionClosed diff --git a/aai_cli/init/templates/audio_transcription/api/index.py b/aai_cli/init/templates/audio_transcription/api/index.py index 5a29fd14..9c938a37 100644 --- a/aai_cli/init/templates/audio_transcription/api/index.py +++ b/aai_cli/init/templates/audio_transcription/api/index.py @@ -15,6 +15,7 @@ import tempfile import uuid +from abc import abstractmethod from pathlib import Path from typing import Protocol @@ -118,7 +119,9 @@ def ask(transcript_id: str = Body(...), question: str = Body(...)) -> dict[str, class _Serializable(Protocol): """The pydantic-model surface we use: a `.dict()` returning the full JSON.""" - def dict(self) -> dict[str, object]: ... + @abstractmethod + def dict(self) -> dict[str, object]: + """Return the model's full JSON as a dict.""" def _to_payload(model: _Serializable) -> dict[str, object]: diff --git a/tests/__snapshots__/test_snapshots_help_root.ambr b/tests/__snapshots__/test_snapshots_help_root.ambr index 742790c9..5095b6db 100644 --- a/tests/__snapshots__/test_snapshots_help_root.ambr +++ b/tests/__snapshots__/test_snapshots_help_root.ambr @@ -32,52 +32,58 @@ │ exit. │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Quick Start ────────────────────────────────────────────────────────────────╮ - │ onboard Guided setup: sign in and run your first transcription │ + │ onboard Guided setup: sign in and run your first transcription │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Build an App ───────────────────────────────────────────────────────────────╮ - │ init Scaffold a new app from a template and launch it │ - │ dev Run the dev server for the app in the current directory │ - │ share Expose the local app on a public URL via a cloudflared tunnel │ - │ deploy Deploy the current project to Vercel, Railway, or Fly.io │ + │ init Scaffold a new app from a template and launch it │ + │ dev Run the dev server for the app in the current directory │ + │ share Expose the local app on a public URL via a cloudflared │ + │ tunnel │ + │ deploy Deploy the current project to Vercel, Railway, or Fly.io │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Run AssemblyAI ─────────────────────────────────────────────────────────────╮ - │ transcribe Transcribe a file, URL, or YouTube/podcast link — or a whole │ - │ batch │ - │ stream Transcribe live audio in real time from a mic, file, URL, or │ - │ pipe │ - │ dictate Push-to-talk dictation: record the mic, get the transcript back │ - │ agent Hold a live two-way voice conversation with a voice agent │ - │ speak [sandbox] Synthesize speech from text with AssemblyAI streaming │ - │ TTS │ - │ llm Send a prompt to AssemblyAI's LLM Gateway and print the reply │ - │ clip Cut clips from media by speaker, text match, LLM pick, or time │ - │ range │ - │ dub [sandbox] Dub a video or audio file into another language │ - │ caption Burn always-visible captions into a video │ - │ eval Transcribe a dataset and score WER against its reference texts │ - │ webhooks Receive webhook deliveries on a public dev URL │ + │ transcribe Transcribe a file, URL, or YouTube/podcast link — or a │ + │ whole batch │ + │ stream Transcribe live audio in real time from a mic, file, URL, │ + │ or pipe │ + │ dictate Push-to-talk dictation: record the mic, get the transcript │ + │ back │ + │ agent Hold a live two-way voice conversation with a voice agent │ + │ agent-framework [sandbox] Hold a live voice conversation through a │ + │ self-wired cascade │ + │ speak [sandbox] Synthesize speech from text with AssemblyAI │ + │ streaming TTS │ + │ llm Send a prompt to AssemblyAI's LLM Gateway and print the │ + │ reply │ + │ clip Cut clips from media by speaker, text match, LLM pick, or │ + │ time range │ + │ dub [sandbox] Dub a video or audio file into another language │ + │ caption Burn always-visible captions into a video │ + │ eval Transcribe a dataset and score WER against its reference │ + │ texts │ + │ webhooks Receive webhook deliveries on a public dev URL │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Setup & Tools ──────────────────────────────────────────────────────────────╮ - │ doctor Check that your environment is ready for AssemblyAI │ - │ setup Set up your coding agent for AssemblyAI (docs MCP + skills) │ - │ config Inspect and edit persisted CLI settings (profiles, env, │ - │ telemetry) │ - │ update Update the CLI to the latest release (brew/pipx/uv) │ - │ telemetry Anonymous usage telemetry: status, enable, disable │ + │ doctor Check that your environment is ready for AssemblyAI │ + │ setup Set up your coding agent for AssemblyAI (docs MCP + skills) │ + │ config Inspect and edit persisted CLI settings (profiles, env, │ + │ telemetry) │ + │ update Update the CLI to the latest release (brew/pipx/uv) │ + │ telemetry Anonymous usage telemetry: status, enable, disable │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ History ────────────────────────────────────────────────────────────────────╮ - │ transcripts Browse and fetch past transcripts │ - │ sessions Browse your past streaming (real-time) sessions │ + │ transcripts Browse and fetch past transcripts │ + │ sessions Browse your past streaming (real-time) sessions │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Account ────────────────────────────────────────────────────────────────────╮ - │ login Authenticate via your browser and store a CLI API key │ - │ logout Clear stored credentials for the active profile │ - │ whoami Show the active profile and whether its key works │ - │ balance Show your remaining account balance │ - │ usage Show usage over a date range (default: last 30 days) │ - │ limits Show your account's rate limits per service │ - │ keys List, create, and rename your AssemblyAI API keys │ - │ audit List recent audit-log entries for your account │ + │ login Authenticate via your browser and store a CLI API key │ + │ logout Clear stored credentials for the active profile │ + │ whoami Show the active profile and whether its key works │ + │ balance Show your remaining account balance │ + │ usage Show usage over a date range (default: last 30 days) │ + │ limits Show your account's rate limits per service │ + │ keys List, create, and rename your AssemblyAI API keys │ + │ audit List recent audit-log entries for your account │ ╰──────────────────────────────────────────────────────────────────────────────╯ Examples diff --git a/tests/__snapshots__/test_snapshots_help_run.ambr b/tests/__snapshots__/test_snapshots_help_run.ambr index 56e8e369..cf7c2d79 100644 --- a/tests/__snapshots__/test_snapshots_help_run.ambr +++ b/tests/__snapshots__/test_snapshots_help_run.ambr @@ -1,4 +1,78 @@ # serializer version: 1 +# name: test_command_help_matches_snapshot[agent-framework] + ''' + + Usage: assembly agent-framework [OPTIONS] [SOURCE] + + [sandbox] Hold a live voice conversation through a self-wired cascade + + Like 'assembly agent', but instead of AssemblyAI's Voice Agent endpoint this + wires the three primitives together itself — Streaming STT, the LLM Gateway, + and streaming TTS — exactly like the 'agent-framework' init template does + server-side. Because it uses streaming TTS it only runs in the sandbox: run + it as 'assembly --sandbox agent-framework' (--sandbox goes before the + subcommand). + + Use headphones: the mic stays open while the agent speaks, so on speakers it + would hear itself and loop. Pass an audio file/URL (or --sample) to speak a + recorded clip instead of the microphone; the session then ends after the + agent's reply. + + This only runs a conversation in the terminal — it writes no code. To build + an agent-framework app, run 'assembly init agent-framework' instead. + + ╭─ Arguments ──────────────────────────────────────────────────────────────────╮ + │ source [SOURCE] Audio file path or URL to speak to the agent. Omit │ + │ to use the microphone. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ Options ────────────────────────────────────────────────────────────────────╮ + │ --sample Speak the hosted wildfires.mp3 │ + │ sample to the agent │ + │ --voice TEXT TTS voice. See --list-voices. │ + │ [default: jane] │ + │ --model TEXT LLM Gateway model that powers the │ + │ agent's replies │ + │ [default: │ + │ claude-haiku-4-5-20251001] │ + │ --system-prompt TEXT System prompt (the agent's │ + │ persona) │ + │ [default: You are a friendly, │ + │ concise voice assistant. Keep │ + │ replies short and conversational. │ + │ Your reply is read aloud by a │ + │ text-to-speech engine, so write │ + │ plain spoken prose — no markdown, │ + │ emoji, bullet lists, or code.] │ + │ --system-prompt-file FILE Read the system prompt from a │ + │ file (overrides --system-prompt) │ + │ --greeting TEXT Spoken greeting │ + │ [default: Hi! I'm your AssemblyAI │ + │ voice agent. What can I help you │ + │ with?] │ + │ --device INTEGER Microphone device index │ + │ --list-voices Print known voices and exit │ + │ --json -j Emit newline-delimited JSON │ + │ events │ + │ --output -o [text|json] Output mode: text (you:/agent: │ + │ lines as plain stdout, │ + │ pipe-friendly) or json │ + │ --help Show this message and exit. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + Examples + Start a live cascade conversation + $ assembly --sandbox agent-framework + Pick a voice and opening line + $ assembly --sandbox agent-framework --voice michael --greeting "Hi there" + Give the agent a persona + $ assembly --sandbox agent-framework --system-prompt "You are a terse pirate." + See available voices + $ assembly --sandbox agent-framework --list-voices + + + + ''' +# --- # name: test_command_help_matches_snapshot[agent] ''' diff --git a/tests/test_agent_framework_command.py b/tests/test_agent_framework_command.py new file mode 100644 index 00000000..cd4d4f52 --- /dev/null +++ b/tests/test_agent_framework_command.py @@ -0,0 +1,274 @@ +"""Command + wiring tests for `assembly agent-framework`. + +Covers the argv -> options seam, the validation guards, _open_audio source +selection, and CascadeDeps.real's three live legs (all driven against fakes). +""" + +from __future__ import annotations + +import dataclasses +import types + +import pytest +import typer +from typer.testing import CliRunner + +from aai_cli.agent.render import AgentRenderer +from aai_cli.agent_framework import engine +from aai_cli.agent_framework.config import CascadeConfig +from aai_cli.agent_framework.engine import CascadeDeps +from aai_cli.app.context import AppState +from aai_cli.commands.agent_framework import _exec +from aai_cli.commands.agent_framework._exec import AgentFrameworkOptions, run_agent_framework +from aai_cli.core import config +from aai_cli.core.errors import CLIError, UsageError +from aai_cli.main import app + +runner = CliRunner() + + +_DEFAULTS = AgentFrameworkOptions( + source=None, + sample=False, + voice="jane", + model="claude-haiku-4-5-20251001", + system_prompt="be nice", + system_prompt_file=None, + greeting="hello", + device=None, + output_field=None, +) + + +def _opts(**overrides) -> AgentFrameworkOptions: + return dataclasses.replace(_DEFAULTS, **overrides) + + +# --- help / list-voices ------------------------------------------------------ + + +def test_list_voices_human_lists_catalog(): + result = runner.invoke(app, ["agent-framework", "--list-voices"]) + assert result.exit_code == 0 + assert "jane" in result.output + assert "English:" in result.output + + +def test_list_voices_json_emits_array(): + result = runner.invoke(app, ["agent-framework", "--list-voices", "--json"]) + assert result.exit_code == 0 + assert result.output.lstrip().startswith("[") + assert '"jane"' in result.output + + +# --- validation guards ------------------------------------------------------- + + +def test_options_are_frozen(): + attr = "voice" # not a literal, so ruff's B010 leaves the setattr in place + with pytest.raises(dataclasses.FrozenInstanceError): + setattr(_DEFAULTS, attr, "other") + + +def test_unknown_voice_is_a_usage_error(): + with pytest.raises(UsageError, match="Unknown voice"): + run_agent_framework(_opts(voice="nope"), AppState(), json_mode=False) + + +def test_missing_system_prompt_file_is_rejected_by_typer(): + # exists=True on the option makes Typer reject a nonexistent path before the body, + # so the sandbox guard (the other exit-2 path) never runs. Asserting the guard's + # message is absent kills the exists=True mutant without depending on the Rich error + # text, which CI renders with ANSI + width ellipsis. + result = runner.invoke(app, ["agent-framework", "--system-prompt-file", "/no/such/file"]) + assert result.exit_code == 2 + assert "sandbox" not in result.output.lower() + + +def test_production_env_is_rejected_with_sandbox_hint(): + # Default env is production, which has no streaming-TTS host. + result = runner.invoke(app, ["agent-framework", "--voice", "jane"]) + assert result.exit_code == 2 + assert "only available in the sandbox" in result.output + + +def test_device_with_file_source_is_rejected(monkeypatch): + monkeypatch.setattr(_exec.tts_session, "require_available", lambda _c: None) + with pytest.raises(UsageError, match="--device applies only to microphone"): + run_agent_framework(_opts(source="clip.wav", device=2), AppState(), json_mode=False) + + +# --- system prompt resolution ------------------------------------------------ + + +def test_resolve_system_prompt_prefers_file(tmp_path): + path = tmp_path / "persona.txt" + path.write_text("you are a pirate", encoding="utf-8") + assert _exec._resolve_system_prompt("ignored", path) == "you are a pirate" + + +def test_resolve_system_prompt_without_file_passes_through(): + assert _exec._resolve_system_prompt("default persona", None) == "default persona" + + +def test_resolve_system_prompt_unreadable_file_errors(tmp_path): + missing = tmp_path / "nope.txt" + with pytest.raises(CLIError, match="Could not read --system-prompt-file") as exc: + _exec._resolve_system_prompt("x", missing) + assert exc.value.exit_code == 2 + + +# --- _open_audio ------------------------------------------------------------- + + +def _renderer() -> AgentRenderer: + return AgentRenderer(json_mode=False, text_mode=False) + + +def test_open_audio_file_uses_nullplayer_and_source_rate(monkeypatch): + fake_source = types.SimpleNamespace(sample_rate=16000) + monkeypatch.setattr(_exec, "FileSource", lambda src: fake_source) + monkeypatch.setattr(_exec.client, "resolve_audio_source", lambda source, sample: "clip.wav") + audio, player, rate = _exec._open_audio( + _renderer(), source="clip.wav", sample=False, device=None, from_file=True + ) + assert audio is fake_source + assert isinstance(player, _exec.NullPlayer) + assert rate == 16000 + + +def test_open_audio_mic_warns_and_uses_duplex_rate(monkeypatch): + fake_duplex = types.SimpleNamespace(mic=object(), player=object()) + monkeypatch.setattr(_exec, "DuplexAudio", lambda **kwargs: fake_duplex) + renderer = _renderer() + notices: list[str] = [] + monkeypatch.setattr(renderer, "notice", notices.append) + audio, player, rate = _exec._open_audio( + renderer, source=None, sample=False, device=None, from_file=False + ) + assert audio is fake_duplex.mic + assert player is fake_duplex.player + assert rate == _exec.SAMPLE_RATE + assert any("headphones" in note for note in notices) + + +# --- run_agent_framework wiring ---------------------------------------------- + + +def test_run_wires_deps_and_invokes_cascade(monkeypatch): + monkeypatch.setattr(_exec.tts_session, "require_available", lambda _c: None) + monkeypatch.setattr(config, "resolve_api_key", lambda **_: "test-key") + fake_source = types.SimpleNamespace(sample_rate=16000) + monkeypatch.setattr(_exec, "FileSource", lambda src: fake_source) + monkeypatch.setattr(_exec.client, "resolve_audio_source", lambda source, sample: "clip.wav") + captured = {} + + def fake_run_cascade(*, renderer, player, config, deps): + captured["config"] = config + captured["deps"] = deps + + monkeypatch.setattr(_exec.engine, "run_cascade", fake_run_cascade) + run_agent_framework( + _opts(source="clip.wav", voice="michael", greeting="hi there"), AppState(), json_mode=False + ) + # File-driven runs drop the greeting and carry the chosen voice into the config. + assert captured["config"].greeting == "" + assert captured["config"].voice == "michael" + assert isinstance(captured["deps"], CascadeDeps) + + +class _RecordingRenderer: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.stopped_called = False + self.closed = False + + def notice(self, text): + pass + + def stopped(self): + self.stopped_called = True + + def close(self): + self.closed = True + + +def _wire_run(monkeypatch, run_cascade): + """Stub out auth/audio/cascade so run_agent_framework reaches the run_cascade call.""" + monkeypatch.setattr(_exec.tts_session, "require_available", lambda _c: None) + monkeypatch.setattr(config, "resolve_api_key", lambda **_: "k") + monkeypatch.setattr(_exec, "FileSource", lambda src: types.SimpleNamespace(sample_rate=16000)) + monkeypatch.setattr(_exec.client, "resolve_audio_source", lambda source, sample: "clip.wav") + monkeypatch.setattr(_exec.engine, "run_cascade", run_cascade) + rendered = {} + monkeypatch.setattr( + _exec, "AgentRenderer", lambda **kw: rendered.setdefault("r", _RecordingRenderer(**kw)) + ) + return rendered + + +def test_keyboard_interrupt_stops_cleanly(monkeypatch): + def boom(**kwargs): + raise KeyboardInterrupt + + rendered = _wire_run(monkeypatch, boom) + run_agent_framework(_opts(source="clip.wav"), AppState(), json_mode=False) + assert rendered["r"].stopped_called is True + assert rendered["r"].closed is True + + +def test_broken_pipe_exits_zero(monkeypatch): + def boom(**kwargs): + raise BrokenPipeError + + rendered = _wire_run(monkeypatch, boom) + with pytest.raises(typer.Exit) as exc: + run_agent_framework(_opts(source="clip.wav"), AppState(), json_mode=False) + assert exc.value.exit_code == 0 + assert rendered["r"].closed is True + + +# --- CascadeDeps.real (the three live legs) ---------------------------------- + + +def test_deps_real_run_stt_passes_formatted_params(monkeypatch): + captured = {} + + def fake_stream_audio(api_key, source, *, params, on_turn): + captured["api_key"] = api_key + captured["source"] = source + captured["params"] = params + + monkeypatch.setattr(engine.client, "stream_audio", fake_stream_audio) + audio: list[bytes] = [] + deps = CascadeDeps.real("k", CascadeConfig(), audio=audio, sample_rate=16000) + deps.run_stt(lambda event: None) + assert captured["api_key"] == "k" + assert captured["source"] is audio + assert captured["params"].sample_rate == 16000 + assert captured["params"].format_turns is True + + +def test_deps_real_complete_reply_returns_content(monkeypatch): + monkeypatch.setattr(engine.llm, "complete", lambda api_key, **kwargs: "raw-response") + monkeypatch.setattr(engine.llm, "content_of", lambda response: response.upper()) + deps = CascadeDeps.real("k", CascadeConfig(model="m"), audio=[], sample_rate=16000) + assert deps.complete_reply([{"role": "user", "content": "hi"}]) == "RAW-RESPONSE" + + +def test_deps_real_synthesize_returns_pcm(monkeypatch): + captured = {} + + def fake_synth(api_key, spec): + captured["voice"] = spec.voice + captured["text"] = spec.text + captured["sample_rate"] = spec.sample_rate + return types.SimpleNamespace(pcm=b"AUDIO") + + monkeypatch.setattr(engine.tts_session, "synthesize", fake_synth) + deps = CascadeDeps.real("k", CascadeConfig(voice="vera"), audio=[], sample_rate=16000) + assert deps.synthesize("say this") == b"AUDIO" + assert captured["voice"] == "vera" + assert captured["text"] == "say this" + # TTS always synthesizes at the 24 kHz the live player is opened at. + assert captured["sample_rate"] == engine.TTS_SAMPLE_RATE == 24000 diff --git a/tests/test_agent_framework_config.py b/tests/test_agent_framework_config.py new file mode 100644 index 00000000..8efa8eb1 --- /dev/null +++ b/tests/test_agent_framework_config.py @@ -0,0 +1,29 @@ +"""Tests for the cascade's per-run configuration defaults.""" + +from __future__ import annotations + +import dataclasses + +import pytest + +from aai_cli.agent_framework.config import DEFAULT_GREETING, DEFAULT_MAX_HISTORY, CascadeConfig +from aai_cli.agent_framework.voices import DEFAULT_VOICE +from aai_cli.core import llm + + +def test_default_config_values(): + config = CascadeConfig() + assert config.voice == DEFAULT_VOICE + assert config.model == llm.DEFAULT_MODEL + assert config.greeting == DEFAULT_GREETING + # The sliding-window default keeps the last 40 messages of context. + assert config.max_history == 40 + assert DEFAULT_MAX_HISTORY == 40 + + +def test_config_is_frozen(): + # Frozen so a parsed run config can't be mutated mid-conversation. + config = CascadeConfig() + attr = "voice" # not a literal, so ruff's B010 leaves the setattr in place + with pytest.raises(dataclasses.FrozenInstanceError): + setattr(config, attr, "other") diff --git a/tests/test_agent_framework_engine.py b/tests/test_agent_framework_engine.py new file mode 100644 index 00000000..e1fb3413 --- /dev/null +++ b/tests/test_agent_framework_engine.py @@ -0,0 +1,440 @@ +"""Orchestration tests for the terminal voice cascade (aai_cli.agent_framework.engine). + +The cascade's three network legs and its thread spawner are injected through +CascadeDeps, so every test here runs against fakes — no sockets, mic, or speaker. +""" + +from __future__ import annotations + +import threading +import types + +import pytest + +from aai_cli.agent_framework import engine +from aai_cli.agent_framework.config import CascadeConfig +from aai_cli.agent_framework.engine import CascadeDeps, CascadeSession, run_cascade +from aai_cli.core.errors import APIError + + +class FakeRenderer: + def __init__(self): + self.calls = [] + + def connected(self): + self.calls.append(("connected",)) + + def user_partial(self, text): + self.calls.append(("user_partial", text)) + + def user_final(self, text): + self.calls.append(("user_final", text)) + + def reply_started(self): + self.calls.append(("reply_started",)) + + def agent_transcript(self, text, *, interrupted): + self.calls.append(("agent_transcript", text, interrupted)) + + def reply_done(self, *, interrupted): + self.calls.append(("reply_done", interrupted)) + + +class FakePlayer: + def __init__(self): + self.enqueued = [] + self.flushed = 0 + self.started = False + self.closed = False + + def start(self): + self.started = True + + def enqueue(self, pcm): + self.enqueued.append(pcm) + + def flush(self): + self.flushed += 1 + + def close(self): + self.closed = True + + +class FakeWorker: + def __init__(self, *, alive): + self._alive = alive + self.joined = 0 + + def is_alive(self): + return self._alive + + def join(self): + self.joined += 1 + self._alive = False + + +def _sync_spawn(target): + """Run the reply body inline and hand back a finished worker, so the cascade is + driven deterministically without real threads.""" + target() + return FakeWorker(alive=False) + + +def _turn(text, *, end_of_turn=True, turn_is_formatted=True): + return types.SimpleNamespace( + transcript=text, end_of_turn=end_of_turn, turn_is_formatted=turn_is_formatted + ) + + +def make_session( + *, + complete_reply=lambda messages: "Hello there.", + synthesize=lambda text: b"pcm:" + text.encode(), + spawn=_sync_spawn, + run_stt=lambda on_turn: None, + config=None, +): + deps = CascadeDeps( + run_stt=run_stt, complete_reply=complete_reply, synthesize=synthesize, spawn=spawn + ) + renderer = FakeRenderer() + player = FakePlayer() + session = CascadeSession( + deps=deps, renderer=renderer, player=player, config=config or CascadeConfig() + ) + return session, renderer, player + + +# --- greeting ---------------------------------------------------------------- + + +def test_greet_speaks_and_seeds_history(): + session, renderer, player = make_session() + session.greet() + assert session.history == [{"role": "assistant", "content": session.config.greeting}] + assert ("agent_transcript", session.config.greeting, False) in renderer.calls + assert player.enqueued == [b"pcm:" + session.config.greeting.encode()] + + +def test_greet_empty_greeting_is_silent(): + session, renderer, player = make_session(config=CascadeConfig(greeting="")) + session.greet() + assert session.history == [] + assert renderer.calls == [] + assert player.enqueued == [] + + +def test_greet_records_tts_failure(): + def boom(text): + raise APIError("tts down") + + session, _renderer, player = make_session(synthesize=boom) + session.greet() + assert isinstance(session.error, APIError) + assert session.error.message == "tts down" + assert player.enqueued == [] # the failed greeting enqueued nothing + + +# --- turn dispatch ----------------------------------------------------------- + + +def test_on_turn_blank_transcript_ignored(): + session, renderer, _player = make_session() + session.on_turn(_turn(" ")) + assert renderer.calls == [] + assert session.history == [] + + +def test_on_turn_final_renders_and_replies(): + session, renderer, player = make_session(complete_reply=lambda m: "Sure thing.") + session.on_turn(_turn("what time is it")) + assert ("user_final", "what time is it") in renderer.calls + assert {"role": "user", "content": "what time is it"} in session.history + assert {"role": "assistant", "content": "Sure thing."} in session.history + assert player.enqueued == [b"pcm:Sure thing."] + assert ("reply_done", False) in renderer.calls + + +def test_on_turn_interim_shows_partial_and_does_not_reply(): + replies = [] + session, renderer, _player = make_session(complete_reply=lambda m: replies.append(m) or "x") + session.on_turn(_turn("partial words", end_of_turn=False)) + assert ("user_partial", "partial words") in renderer.calls + assert replies == [] # no reply generated for an interim turn + assert session.history == [] + + +def test_on_turn_interim_barges_in_on_live_reply(): + session, _renderer, player = make_session() + session._reply = FakeWorker(alive=True) + session.on_turn(_turn("uh", end_of_turn=False)) + assert player.flushed == 1 + assert session._reply is None + + +# --- reply generation -------------------------------------------------------- + + +def test_generate_reply_speaks_each_sentence(): + spoken = [] + session, renderer, player = make_session( + complete_reply=lambda m: "One. Two! Three?", + synthesize=lambda text: spoken.append(text) or text.encode(), + ) + session._generate_reply() + assert spoken == ["One.", "Two!", "Three?"] + assert player.enqueued == [b"One.", b"Two!", b"Three?"] + assert ("reply_started",) in renderer.calls + assert ("agent_transcript", "One.", False) in renderer.calls + assert session.history[-1] == {"role": "assistant", "content": "One. Two! Three?"} + assert ("reply_done", False) in renderer.calls + + +def test_generate_reply_threads_system_prompt_and_history(): + captured = {} + + def capture(messages): + captured["messages"] = messages + return "Ok." + + session, _renderer, _player = make_session( + complete_reply=capture, config=CascadeConfig(system_prompt="be terse") + ) + session.history.append({"role": "user", "content": "prior"}) + session._generate_reply() + assert captured["messages"][0] == {"role": "system", "content": "be terse"} + assert {"role": "user", "content": "prior"} in captured["messages"] + + +def test_generate_reply_trims_history_window(): + session, _renderer, _player = make_session( + complete_reply=lambda m: "a. b.", config=CascadeConfig(max_history=1) + ) + session.history.append({"role": "user", "content": "hi"}) + session._generate_reply() + # user + assistant would be 2; the window caps it to the most recent 1. + assert session.history == [{"role": "assistant", "content": "a. b."}] + + +def test_on_turn_trims_history_window(): + # An empty reply adds no assistant turn, so only on_turn's own trim caps the list. + session, _renderer, _player = make_session( + complete_reply=lambda m: "", config=CascadeConfig(max_history=1) + ) + session.history.append({"role": "assistant", "content": "old"}) + session.on_turn(_turn("newest")) + assert session.history == [{"role": "user", "content": "newest"}] + + +def test_generate_reply_stop_after_first_sentence_records_partial(): + def synth(text): + if text == "Two.": + session._stop.set() + return text.encode() + + session, renderer, player = make_session(complete_reply=lambda m: "One. Two. Three.") + session.deps.synthesize = synth + session._generate_reply() + # Only the first sentence finished enqueuing before the barge-in stop landed. + assert player.enqueued == [b"One."] + assert session.history[-1] == {"role": "assistant", "content": "One."} + assert ("reply_done", True) in renderer.calls + + +def test_generate_reply_stop_before_first_sentence_speaks_nothing(): + session, renderer, player = make_session(complete_reply=lambda m: "One. Two.") + session._stop.set() + session._generate_reply() + assert player.enqueued == [] + # nothing spoken -> no assistant turn recorded + assert all(item.get("role") != "assistant" for item in session.history) + assert ("reply_done", True) in renderer.calls + + +def test_generate_reply_llm_failure_is_recorded_and_aborts(): + def boom(messages): + raise APIError("gateway down") + + session, renderer, _player = make_session(complete_reply=boom) + session._generate_reply() + assert isinstance(session.error, APIError) + assert ("reply_started",) not in renderer.calls # aborted before speaking + + +def test_generate_reply_tts_failure_midway_is_recorded(): + def boom(text): + raise APIError("tts down") + + session, renderer, player = make_session(complete_reply=lambda m: "Hi.", synthesize=boom) + session._generate_reply() + assert isinstance(session.error, APIError) + assert player.enqueued == [] + assert ("reply_started",) in renderer.calls + assert ("reply_done", False) in renderer.calls + + +def test_record_error_keeps_first_and_warns(monkeypatch): + printed = [] + monkeypatch.setattr(engine.output.error_console, "print", lambda msg: printed.append(msg)) + session, _renderer, _player = make_session() + session._record_error(APIError("first")) + session._record_error(APIError("second")) + assert isinstance(session.error, APIError) + assert session.error.message == "first" + assert any("first" in str(msg) for msg in printed) + + +# --- barge-in / shutdown ----------------------------------------------------- + + +def test_barge_in_cancels_and_flushes_live_worker(): + session, _renderer, player = make_session() + worker = FakeWorker(alive=True) + session._reply = worker + session._barge_in() + assert session._stop.is_set() + assert player.flushed == 1 + assert worker.joined == 1 + assert session._reply is None + + +def test_barge_in_no_worker_does_not_flush(): + session, _renderer, player = make_session() + session._barge_in() + assert player.flushed == 0 + + +def test_barge_in_finished_worker_does_not_flush(): + session, _renderer, player = make_session() + session._reply = FakeWorker(alive=False) + session._barge_in() + assert player.flushed == 0 + assert session._reply is None + + +def test_shutdown_joins_live_worker(): + session, _renderer, _player = make_session() + worker = FakeWorker(alive=True) + session._reply = worker + session.shutdown() + assert session._stop.is_set() + assert worker.joined == 1 + assert session._reply is None + + +def test_shutdown_without_worker_is_safe(): + session, _renderer, _player = make_session() + session.shutdown() # no worker spawned + assert session._reply is None + + +# --- helpers ----------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("end_of_turn", "formatted", "expected"), + [(True, True, True), (True, False, False), (False, True, False), (False, False, False)], +) +def test_is_final_turn(end_of_turn, formatted, expected): + event = _turn("hi", end_of_turn=end_of_turn, turn_is_formatted=formatted) + assert engine._is_final_turn(event) is expected + + +def test_is_final_turn_defaults_missing_attrs_to_not_final(): + # A formatted turn missing end_of_turn (and vice versa) must read as not-final, so + # each absent field defaults to False rather than being treated as present-and-true. + assert engine._is_final_turn(types.SimpleNamespace(turn_is_formatted=True)) is False + assert engine._is_final_turn(types.SimpleNamespace(end_of_turn=True)) is False + + +def test_spawn_thread_runs_target(): + ran = threading.Event() + worker = engine._spawn_thread(ran.set) + worker.join() + assert ran.is_set() + assert worker.is_alive() is False + + +# --- run_cascade ------------------------------------------------------------- + + +def test_run_cascade_greets_then_pumps_turns(): + def run_stt(on_turn): + on_turn(_turn("hello")) + + session_box = {} + + def complete_reply(messages): + session_box["messages"] = messages + return "Hi back." + + renderer = FakeRenderer() + player = FakePlayer() + config = CascadeConfig(greeting="Welcome.") + deps = CascadeDeps( + run_stt=run_stt, + complete_reply=complete_reply, + synthesize=lambda text: text.encode(), + spawn=_sync_spawn, + ) + run_cascade(renderer=renderer, player=player, config=config, deps=deps) + assert player.started is True + assert player.closed is True + assert ("connected",) in renderer.calls + assert ("agent_transcript", "Welcome.", False) in renderer.calls + assert ("user_final", "hello") in renderer.calls + # The greeting is threaded into the LLM call as prior context. + assert {"role": "assistant", "content": "Welcome."} in session_box["messages"] + + +def test_run_cascade_shuts_down_inflight_worker(): + worker = FakeWorker(alive=True) + + def lazy_spawn(target): + # Leave the reply "running" so shutdown is what joins it. + return worker + + def run_stt(on_turn): + on_turn(_turn("hello")) + + deps = CascadeDeps( + run_stt=run_stt, complete_reply=lambda m: "hi", synthesize=lambda t: b"", spawn=lazy_spawn + ) + run_cascade( + renderer=FakeRenderer(), player=FakePlayer(), config=CascadeConfig(greeting=""), deps=deps + ) + assert worker.joined == 1 + + +def test_run_cascade_reraises_recorded_leg_error(): + def run_stt(on_turn): + on_turn(_turn("hi")) + + def boom(messages): + raise APIError("gateway down") + + deps = CascadeDeps( + run_stt=run_stt, complete_reply=boom, synthesize=lambda t: b"", spawn=_sync_spawn + ) + with pytest.raises(APIError, match="gateway down"): + run_cascade( + renderer=FakeRenderer(), + player=FakePlayer(), + config=CascadeConfig(greeting=""), + deps=deps, + ) + + +def test_run_cascade_closes_player_when_stt_raises(): + def run_stt(on_turn): + raise APIError("stt failed") + + player = FakePlayer() + deps = CascadeDeps( + run_stt=run_stt, complete_reply=lambda m: "", synthesize=lambda t: b"", spawn=_sync_spawn + ) + with pytest.raises(APIError, match="stt failed"): + run_cascade( + renderer=FakeRenderer(), player=player, config=CascadeConfig(greeting=""), deps=deps + ) + assert player.closed is True diff --git a/tests/test_agent_framework_text.py b/tests/test_agent_framework_text.py new file mode 100644 index 00000000..be85f92b --- /dev/null +++ b/tests/test_agent_framework_text.py @@ -0,0 +1,45 @@ +"""Tests for the cascade's pure text helpers.""" + +from __future__ import annotations + +from aai_cli.agent_framework.text import split_sentences, trim_history + + +def test_split_sentences_breaks_on_terminators(): + assert split_sentences("One. Two! Three?") == ["One.", "Two!", "Three?"] + + +def test_split_sentences_keeps_unterminated_tail(): + assert split_sentences("Done. And more") == ["Done.", "And more"] + + +def test_split_sentences_strips_whitespace_and_drops_empties(): + assert split_sentences(" Hi. ") == ["Hi."] + + +def test_split_sentences_empty_string_is_empty_list(): + assert split_sentences("") == [] + + +def test_split_sentences_each_terminator_ends_a_sentence(): + # Every terminator closes the current chunk, so consecutive ones each yield one. + assert split_sentences("...") == [".", ".", "."] + assert split_sentences(" . ") == ["."] + + +def test_trim_history_drops_oldest_beyond_limit(): + history = [{"role": "user", "content": str(i)} for i in range(5)] + trim_history(history, 3) + assert [item["content"] for item in history] == ["2", "3", "4"] + + +def test_trim_history_leaves_short_history_untouched(): + history = [{"role": "user", "content": "a"}, {"role": "user", "content": "b"}] + trim_history(history, 3) + assert len(history) == 2 + + +def test_trim_history_at_limit_is_untouched(): + history = [{"role": "user", "content": str(i)} for i in range(3)] + trim_history(history, 3) + assert len(history) == 3 diff --git a/tests/test_agent_framework_voices.py b/tests/test_agent_framework_voices.py new file mode 100644 index 00000000..77cda3b6 --- /dev/null +++ b/tests/test_agent_framework_voices.py @@ -0,0 +1,31 @@ +"""Tests for the cascade's voice catalog presentation.""" + +from __future__ import annotations + +from aai_cli.agent_framework import voices +from aai_cli.tts import voices as tts_voices + + +def test_voice_names_are_the_tts_catalog_sorted(): + assert sorted(tts_voices.VOICE_LANGUAGES) == voices.VOICE_NAMES + + +def test_default_voice_is_in_catalog(): + assert voices.DEFAULT_VOICE in voices.VOICE_NAMES + + +def test_complete_voice_filters_by_prefix(): + completions = voices.complete_voice("ja") + assert "jane" in completions + assert all(name.startswith("ja") for name in completions) + + +def test_format_voice_list_groups_by_language(): + listing = voices.format_voice_list() + blocks = {block.split(":", 1)[0]: block for block in listing.split("\n\n")} + # Each voice is filed strictly under the language it actually speaks: the English + # block lists jane but not the Italian-only giovanni, and vice versa. + assert "jane" in blocks["English"] + assert "giovanni" not in blocks["English"] + assert "giovanni" in blocks["Italian"] + assert "jane" not in blocks["Italian"] diff --git a/tests/test_init_template_agent_framework_api.py b/tests/test_init_template_agent_framework_api.py index c2ba5442..a0da1698 100644 --- a/tests/test_init_template_agent_framework_api.py +++ b/tests/test_init_template_agent_framework_api.py @@ -181,7 +181,7 @@ async def drive(): await asyncio.sleep(0) # let it start and block on the LLM task.cancel() with pytest.raises(asyncio.CancelledError): - await task + await asyncio.gather(task) asyncio.run(drive()) # Cancellation must NOT be turned into a session.error. diff --git a/tests/test_init_template_agent_framework_reply.py b/tests/test_init_template_agent_framework_reply.py index b17d92bc..902594e6 100644 --- a/tests/test_init_template_agent_framework_reply.py +++ b/tests/test_init_template_agent_framework_reply.py @@ -135,7 +135,7 @@ async def drive(): await asyncio.sleep(0) task.cancel() with pytest.raises(asyncio.CancelledError): - await task + await asyncio.gather(task) asyncio.run(drive()) assert not any(e["type"] == "session.error" for e in browser.sent) @@ -246,7 +246,7 @@ async def drive(): await asyncio.sleep(0) # let it stream + synthesize the sentence, then block task.cancel() with pytest.raises(asyncio.CancelledError): - await task + await asyncio.gather(task) asyncio.run(drive()) assert session.history == [{"role": "assistant", "content": "First sentence."}] diff --git a/tests/test_init_template_contract.py b/tests/test_init_template_contract.py index 9d0bf965..e4010ce5 100644 --- a/tests/test_init_template_contract.py +++ b/tests/test_init_template_contract.py @@ -12,6 +12,7 @@ # Map an import name to its PyPI distribution where they differ. _PKG_MAP = {"dotenv": "python-dotenv", "multipart": "python-multipart"} _STDLIB = { + "abc", "os", "tempfile", "uuid", diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 683e2e4b..5e50ced8 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -155,6 +155,7 @@ def test_help_lists_commands_in_workflow_order(): "stream", "dictate", "agent", + "agent-framework", "speak", "llm", "clip",