diff --git a/aai_cli/agent/session.py b/aai_cli/agent/session.py index 44b96220..27a08df6 100644 --- a/aai_cli/agent/session.py +++ b/aai_cli/agent/session.py @@ -5,6 +5,7 @@ import json import threading from collections.abc import Callable +from dataclasses import dataclass from typing import Any from aai_cli import environments @@ -26,6 +27,30 @@ def _ws_url() -> str: # session.error codes that mean the connection is unauthorized -> exit 2. _AUTH_ERROR_CODES = {"UNAUTHORIZED", "FORBIDDEN"} +# The websocket connection, the `connect` factory, and the renderer/player/mic I/O +# objects come from libraries/modules with no usable type stubs. Alias that untyped +# boundary here so each role is named in signatures and `Any` stays in one place. +# (Server event payloads remain `dict[str, Any]`.) +_WebSocket = Any +_Connect = Any +_IO = Any + + +@dataclass(frozen=True) +class AgentRunConfig: + """The static (per-run) configuration for a Voice Agent session. + + Bundled into one value so `run_session`'s signature stays small: the I/O + objects (renderer/player/mic) vary per call, but these knobs are fixed once + the command has parsed its flags. + """ + + voice: str + system_prompt: str + greeting: str + full_duplex: bool = False + exit_after_reply: bool = False + class VoiceAgentSession: """Routes Voice Agent server events to the renderer, player, and duplex state.""" @@ -33,8 +58,8 @@ class VoiceAgentSession: def __init__( self, *, - renderer: Any, - player: Any, + renderer: _IO, + player: _IO, full_duplex: bool = False, exit_after_reply: bool = False, ready_event: threading.Event | None = None, @@ -144,7 +169,7 @@ def raise_error(self, event: dict[str, Any]) -> None: } -def _send_audio_loop(ws: Any, session: VoiceAgentSession, mic: Any) -> None: +def _send_audio_loop(ws: _WebSocket, session: VoiceAgentSession, mic: _IO) -> None: """Forward mic PCM as input.audio while the session gate allows it.""" # File-driven runs wait for session.ready before consuming the source, so a # finite clip isn't partly drained (and dropped) before the server accepts it. @@ -168,7 +193,7 @@ def _auth_or_api_error(exc: Exception, message: str) -> CLIError: return APIError(f"{message}: {exc}") -def _open_ws(connect: Any, api_key: str) -> Any: +def _open_ws(connect: _Connect, api_key: str) -> _WebSocket: """Open the Voice Agent socket, mapping a connect failure to a clean CLIError.""" try: return connect(_ws_url(), additional_headers={"Authorization": f"Bearer {api_key}"}) @@ -176,35 +201,53 @@ def _open_ws(connect: Any, api_key: str) -> Any: raise _auth_or_api_error(exc, "Could not connect to the voice agent") from exc +def _session_update_message(config: AgentRunConfig) -> str: + """The initial session.update payload as a JSON string: persona, greeting, voice.""" + return json.dumps( + { + "type": "session.update", + "session": { + "system_prompt": config.system_prompt, + "greeting": config.greeting, + "output": {"voice": config.voice}, + }, + } + ) + + +def _receive_loop(ws: _WebSocket, session: VoiceAgentSession) -> None: + """Dispatch inbound server events until the socket closes or the run finishes.""" + for raw in ws: + session.dispatch(json.loads(raw)) + if session.finished: + break + + def run_session( api_key: str, *, - renderer: Any, - player: Any, - mic: Any, - voice: str, - system_prompt: str, - greeting: str, - full_duplex: bool = False, - exit_after_reply: bool = False, - connect: Any = None, + renderer: _IO, + player: _IO, + mic: _IO, + config: AgentRunConfig, + connect: _Connect = None, ) -> None: """Open the Voice Agent WebSocket and run the bidirectional loop until close. `connect` defaults to websockets' synchronous client; injectable for tests. - When `exit_after_reply` is set (file-driven runs), the loop stops after the - agent's first reply to the spoken input and the capture thread waits for + When `config.exit_after_reply` is set (file-driven runs), the loop stops after + the agent's first reply to the spoken input and the capture thread waits for session.ready before streaming the source. """ if connect is None: from websockets.sync.client import connect - ready_event = threading.Event() if exit_after_reply else None + ready_event = threading.Event() if config.exit_after_reply else None session = VoiceAgentSession( renderer=renderer, player=player, - full_duplex=full_duplex, - exit_after_reply=exit_after_reply, + full_duplex=config.full_duplex, + exit_after_reply=config.exit_after_reply, ready_event=ready_event, ) @@ -231,22 +274,8 @@ def _capture() -> None: player.start() # opens the speaker stream; CLIError here if sounddevice can't load player_started = True threading.Thread(target=_capture, daemon=True).start() - ws.send( - json.dumps( - { - "type": "session.update", - "session": { - "system_prompt": system_prompt, - "greeting": greeting, - "output": {"voice": voice}, - }, - } - ) - ) - for raw in ws: - session.dispatch(json.loads(raw)) - if session.finished: - break + ws.send(_session_update_message(config)) + _receive_loop(ws, session) except (CLIError, KeyboardInterrupt, BrokenPipeError): raise # clean CLI errors, user Ctrl-C, and a closed pipe are handled upstream except Exception as exc: diff --git a/aai_cli/code_gen/stream.py b/aai_cli/code_gen/stream.py index 631f32f3..02465eae 100644 --- a/aai_cli/code_gen/stream.py +++ b/aai_cli/code_gen/stream.py @@ -96,38 +96,46 @@ def on_turn(client: StreamingClient, event: TurnEvent) -> None: """ -def render(merged: dict[str, object], *, llm: dict[str, object] | None = None) -> str: - """Generate a runnable microphone-streaming script with the given params. - - With `llm`, the script transforms the live transcript through the LLM Gateway, - refreshing a prompt chain on every finalized turn (the live sibling of - `transcribe --llm`). - """ +def _imports_block(merged: dict[str, object]) -> str: + """Sorted streaming-class import lines; SpeechModel only when a model kwarg is emitted.""" names = list(_BASE_IMPORTS) if "speech_model" in merged: names.append("SpeechModel") - imports = "\n".join(f" {name}," for name in sorted(names)) + return "\n".join(f" {name}," for name in sorted(names)) + +def _build_preamble(imports: str, llm: dict[str, object] | None) -> str: + """Pick and fill the plain vs. LLM-Gateway preamble for the given imports.""" if llm: prompts = "\n".join(f" {p!r}," for p in cast("list[str]", llm["prompts"])) - preamble = _LLM_PREAMBLE.format( + return _LLM_PREAMBLE.format( imports=imports, base_url=gateway.GATEWAY_BASE_URL, prompts=prompts, model=llm["model"], max_tokens=llm["max_tokens"], ) - else: - preamble = _PREAMBLE.format(imports=imports) + return _PREAMBLE.format(imports=imports) - # Mic capture rate must match StreamingParameters.sample_rate, else audio is corrupt. - rate = merged.get("sample_rate", 16000) - if merged: - # indent=8: 4 for connect(), 4 more for the StreamingParameters() args. - kwargs = "\n".join(serialize.config_kwarg_lines(merged, indent=8)) - connect = f"client.connect(\n StreamingParameters(\n{kwargs}\n )\n)" - else: - connect = "client.connect(StreamingParameters())" +def _build_connect(merged: dict[str, object]) -> str: + """The `client.connect(StreamingParameters(...))` call for the given params.""" + if not merged: + return "client.connect(StreamingParameters())" + # indent=8: 4 for connect(), 4 more for the StreamingParameters() args. + kwargs = "\n".join(serialize.config_kwarg_lines(merged, indent=8)) + return f"client.connect(\n StreamingParameters(\n{kwargs}\n )\n)" + +def render(merged: dict[str, object], *, llm: dict[str, object] | None = None) -> str: + """Generate a runnable microphone-streaming script with the given params. + + With `llm`, the script transforms the live transcript through the LLM Gateway, + refreshing a prompt chain on every finalized turn (the live sibling of + `transcribe --llm`). + """ + preamble = _build_preamble(_imports_block(merged), llm) + # Mic capture rate must match StreamingParameters.sample_rate, else audio is corrupt. + rate = merged.get("sample_rate", 16000) + connect = _build_connect(merged) return preamble + "\n" + connect + "\n" + _FOOTER.format(rate=rate) diff --git a/aai_cli/commands/agent.py b/aai_cli/commands/agent.py index cfb73950..2e041172 100644 --- a/aai_cli/commands/agent.py +++ b/aai_cli/commands/agent.py @@ -9,7 +9,12 @@ from aai_cli import client, code_gen, config, help_panels, output from aai_cli.agent.audio import SAMPLE_RATE, DuplexAudio, NullPlayer from aai_cli.agent.render import AgentRenderer -from aai_cli.agent.session import DEFAULT_GREETING, DEFAULT_PROMPT, run_session +from aai_cli.agent.session import ( + DEFAULT_GREETING, + DEFAULT_PROMPT, + AgentRunConfig, + run_session, +) from aai_cli.agent.voices import DEFAULT_VOICE, VOICES, format_voice_list from aai_cli.context import AppState, run_command from aai_cli.errors import CLIError, UsageError @@ -19,6 +24,46 @@ app = typer.Typer() +def _resolve_system_prompt(system_prompt: str, system_prompt_file: Path | None) -> str: + """The persona text: a --system-prompt-file (if given) overrides --system-prompt.""" + if system_prompt_file is None: + return system_prompt + try: + return system_prompt_file.read_text(encoding="utf-8") + except OSError as exc: + raise CLIError( + f"Could not read --system-prompt-file {system_prompt_file}: {exc}", + error_type="file_not_found", + exit_code=2, + suggestion="Check the path and that the file is readable.", + ) from exc + + +def _open_audio( + renderer: AgentRenderer, + *, + source: str | None, + sample: bool, + device: int | None, + from_file: bool, +) -> tuple[Any, Any]: + """Build the (mic, player) pair for either file-driven or live-mic input.""" + if from_file: + # Stream the clip as the user's speech and stop after the agent replies. + # No greeting and full-duplex so no part of the clip is muted/dropped, + # and a NullPlayer since there is no listener for the reply audio. + return FileSource(client.resolve_audio_source(source, sample=sample)), NullPlayer() + # One full-duplex stream for mic + speaker: macOS rejects two separate + # streams on a device, which silently kills capture. + duplex = DuplexAudio(target_rate=SAMPLE_RATE, device=device) + # notice() self-suppresses in JSON mode and routes to stderr in text mode. + renderer.notice( + "Use headphones — the mic stays open while the agent speaks, " + "so speakers would let it hear itself.\n" + ) + return duplex.mic, duplex.player + + @app.command( rich_help_panel=help_panels.TRANSCRIPTION, epilog=examples_epilog( @@ -83,18 +128,7 @@ def body(state: AppState, json_mode: bool) -> None: f"Unknown voice {voice!r}.", suggestion="Run 'aai agent --list-voices' to see the options.", ) - if system_prompt_file is not None: - try: - system_prompt_text = system_prompt_file.read_text(encoding="utf-8") - except OSError as exc: - raise CLIError( - f"Could not read --system-prompt-file {system_prompt_file}: {exc}", - error_type="file_not_found", - exit_code=2, - suggestion="Check the path and that the file is readable.", - ) from exc - else: - system_prompt_text = system_prompt + system_prompt_text = _resolve_system_prompt(system_prompt, system_prompt_file) if show_code: # Print-only: emit the equivalent agent script from the flags and exit @@ -112,37 +146,18 @@ def body(state: AppState, json_mode: bool) -> None: text_mode=text_mode, mic_input=not from_file, ) - audio: Any - player: Any - if from_file: - # Stream the clip as the user's speech and stop after the agent replies. - # No greeting and full-duplex so no part of the clip is muted/dropped, - # and a NullPlayer since there is no listener for the reply audio. - audio = FileSource(client.resolve_audio_source(source, sample=sample)) - player = NullPlayer() - else: - # One full-duplex stream for mic + speaker: macOS rejects two separate - # streams on a device, which silently kills capture. - duplex = DuplexAudio(target_rate=SAMPLE_RATE, device=device) - audio = duplex.mic - player = duplex.player - # notice() self-suppresses in JSON mode and routes to stderr in text mode. - renderer.notice( - "Use headphones — the mic stays open while the agent speaks, " - "so speakers would let it hear itself.\n" - ) + audio, player = _open_audio( + renderer, source=source, sample=sample, device=device, from_file=from_file + ) + run_config = AgentRunConfig( + voice=voice, + system_prompt=system_prompt_text, + greeting="" if from_file else greeting, + full_duplex=True, # one duplex stream -> mic always open (use headphones) + exit_after_reply=from_file, + ) try: - run_session( - api_key, - renderer=renderer, - player=player, - mic=audio, - voice=voice, - system_prompt=system_prompt_text, - greeting="" if from_file else greeting, - full_duplex=True, # one duplex stream -> mic always open (use headphones) - exit_after_reply=from_file, - ) + run_session(api_key, renderer=renderer, player=player, mic=audio, config=run_config) except KeyboardInterrupt: renderer.stopped() except BrokenPipeError as exc: diff --git a/aai_cli/commands/init.py b/aai_cli/commands/init.py index e075dc38..d1ebbffe 100644 --- a/aai_cli/commands/init.py +++ b/aai_cli/commands/init.py @@ -59,6 +59,110 @@ def _resolve_dir(directory: str | None, template: str, *, here: bool) -> Path: return Path.cwd() / template +def _resolve_template(template: str | None) -> str: + """Resolve the template name: the picker when omitted, else validate the arg.""" + chosen = template if template is not None else _pick_template() + if not templates.is_template(chosen): + raise CLIError( + f"Unknown template {chosen!r}. Choose one of: {', '.join(templates.TEMPLATE_ORDER)}.", + error_type="usage_error", + exit_code=1, + ) + return chosen + + +def _active_env_vars() -> dict[str, str]: + """Pin the scaffolded app to the active environment's hosts. + + A sandbox key (minted by `aai login` against a non-prod env) would otherwise be + rejected by the production defaults baked into the template. + """ + env = environments.active() + return { + "ASSEMBLYAI_BASE_URL": env.api_base, + "ASSEMBLYAI_LLM_GATEWAY_URL": env.llm_gateway_base, + "ASSEMBLYAI_STREAMING_HOST": env.streaming_host, + # Voice Agent host mirrors the streaming host's naming across environments. + "ASSEMBLYAI_AGENTS_HOST": env.streaming_host.replace("streaming", "agents", 1), + } + + +def _install_step( + target: Path, *, no_install: bool, api_key: str | None, use_uv: bool +) -> tuple[list[steps.Step], bool]: + """Run (or skip) dependency install, returning the report rows and whether to launch. + + Launch only happens when deps are installed and there's a key; an install failure + flips `will_launch` off so the caller exits non-zero instead of starting a server. + """ + will_launch = not no_install and api_key is not None + if no_install: + return [{"name": "install", "status": "skipped", "detail": "--no-install"}], will_launch + setup = runner.run_setup(target, use_uv=use_uv) + if setup.returncode != 0: + row: steps.Step = { + "name": "install", + "status": "failed", + "detail": (setup.stderr or setup.stdout).strip()[:300], + } + return [row], False + return [ + { + "name": "install", + "status": "installed", + "detail": "uv" if use_uv else "venv + pip", + } + ], will_launch + + +def _resolve_target(directory: str | None, chosen: str, *, here: bool, force: bool) -> Path: + """Resolve the target directory and reject --here+DIRECTORY or a non-empty conflict.""" + if here and directory: + raise CLIError( + "Pass either a DIRECTORY or --here, not both.", + error_type="usage_error", + exit_code=1, + ) + target = _resolve_dir(directory, chosen, here=here) + if scaffold.target_conflict(target) and not force: + raise CLIError( + f"{target} already exists and is not empty. " + f"Use --force to overwrite or pick another directory.", + error_type="usage_error", + exit_code=1, + ) + return target + + +def _scaffold_report(chosen: str, target: Path, api_key: str | None) -> list[steps.Step]: + """Write the template to `target` and return the opening report rows.""" + scaffold.scaffold(chosen, target, api_key=api_key, env_vars=_active_env_vars()) + report: list[steps.Step] = [{"name": "scaffold", "status": "created", "detail": str(target)}] + if api_key is None: + report.append( + { + "name": "key", + "status": "skipped", + "detail": "no API key found; wrote a placeholder to .env (run `aai login`)", + } + ) + return report + + +def _launch(target: Path, *, port: int, use_uv: bool, no_open: bool, json_mode: bool) -> None: + """Start the scaffolded app on a free port and open the browser, then block.""" + chosen_port = runner.find_free_port(port) + url = f"http://localhost:{chosen_port}" + if not json_mode: + output.console.print( + f"[aai.heading]Starting[/aai.heading] [aai.url]{escape(url)}[/aai.url]" + " [aai.muted](Ctrl-C to stop)[/aai.muted]" + ) + code = runner.launch_and_open(target, port=chosen_port, use_uv=use_uv, open_browser=not no_open) + if code: + raise typer.Exit(code=code) + + @app.command( rich_help_panel=help_panels.QUICK_START, epilog=examples_epilog( @@ -102,80 +206,17 @@ def body(state: AppState, json_mode: bool) -> None: output.console.print( f"[aai.heading]AssemblyAI CLI[/aai.heading] [aai.muted]{__version__}[/aai.muted]" ) - chosen = template - if chosen is None: - chosen = _pick_template() - if not templates.is_template(chosen): - raise CLIError( - f"Unknown template {chosen!r}. Choose one of: " - f"{', '.join(templates.TEMPLATE_ORDER)}.", - error_type="usage_error", - exit_code=1, - ) - - if here and directory: - raise CLIError( - "Pass either a DIRECTORY or --here, not both.", - error_type="usage_error", - exit_code=1, - ) - target = _resolve_dir(directory, chosen, here=here) - if scaffold.target_conflict(target) and not force: - raise CLIError( - f"{target} already exists and is not empty. " - f"Use --force to overwrite or pick another directory.", - error_type="usage_error", - exit_code=1, - ) + chosen = _resolve_template(template) + target = _resolve_target(directory, chosen, here=here, force=force) api_key = keys.resolve_optional_api_key(profile=state.profile) - # Pin the app to the active environment's hosts so a sandbox key (minted by - # `aai login` against a non-prod env) isn't rejected by the production defaults. - env = environments.active() - env_vars = { - "ASSEMBLYAI_BASE_URL": env.api_base, - "ASSEMBLYAI_LLM_GATEWAY_URL": env.llm_gateway_base, - "ASSEMBLYAI_STREAMING_HOST": env.streaming_host, - # Voice Agent host mirrors the streaming host's naming across environments. - "ASSEMBLYAI_AGENTS_HOST": env.streaming_host.replace("streaming", "agents", 1), - } - scaffold.scaffold(chosen, target, api_key=api_key, env_vars=env_vars) - - report: list[steps.Step] = [ - {"name": "scaffold", "status": "created", "detail": str(target)} - ] - if api_key is None: - report.append( - { - "name": "key", - "status": "skipped", - "detail": "no API key found; wrote a placeholder to .env (run `aai login`)", - } - ) + report = _scaffold_report(chosen, target, api_key) use_uv = runner.has_uv() - will_launch = not no_install and api_key is not None - if no_install: - report.append({"name": "install", "status": "skipped", "detail": "--no-install"}) - else: - setup = runner.run_setup(target, use_uv=use_uv) - if setup.returncode != 0: - report.append( - { - "name": "install", - "status": "failed", - "detail": (setup.stderr or setup.stdout).strip()[:300], - } - ) - will_launch = False - else: - report.append( - { - "name": "install", - "status": "installed", - "detail": "uv" if use_uv else "venv + pip", - } - ) + install_rows, will_launch = _install_step( + target, no_install=no_install, api_key=api_key, use_uv=use_uv + ) + report.extend(install_rows) # Deps are installed but there's no key, so the server can't start — say so # rather than exiting silently. @@ -193,18 +234,7 @@ def body(state: AppState, json_mode: bool) -> None: raise typer.Exit(code=1) if will_launch: - chosen_port = runner.find_free_port(port) - url = f"http://localhost:{chosen_port}" - if not json_mode: - output.console.print( - f"[aai.heading]Starting[/aai.heading] [aai.url]{escape(url)}[/aai.url]" - " [aai.muted](Ctrl-C to stop)[/aai.muted]" - ) - code = runner.launch_and_open( - target, port=chosen_port, use_uv=use_uv, open_browser=not no_open - ) - if code: - raise typer.Exit(code=code) + _launch(target, port=port, use_uv=use_uv, no_open=no_open, json_mode=json_mode) elif not json_mode: # Scaffolded but not launched (no key, or --no-install): leave the user with # the one command that starts their app, the way `vercel`/`supabase` sign off. diff --git a/aai_cli/commands/llm.py b/aai_cli/commands/llm.py index 4ecd10e7..e372755f 100644 --- a/aai_cli/commands/llm.py +++ b/aai_cli/commands/llm.py @@ -13,6 +13,33 @@ app = typer.Typer() +def _validate_follow_args( + prompt: str | None, output_field: str | None, transcript_id: str | None +) -> str: + """Reject flag combinations that don't apply to --follow's live-panel mode. + + Returns the validated (non-empty) prompt so the caller has a plain ``str``. + """ + if not prompt: + raise UsageError("Provide a prompt to run over the streamed transcript.") + if output_field is not None: + raise UsageError( + "--output applies to one-shot mode; --follow renders a live panel " + "(or NDJSON when piped)." + ) + if transcript_id: + raise UsageError( + "--follow runs over live transcript text piped on stdin; it can't be " + "combined with --transcript-id." + ) + if not stdio.stdin_is_piped(): + raise UsageError( + "--follow needs transcript text piped on stdin, e.g. " + '`aai stream -o text | aai llm -f "summarize action items as I talk"`.' + ) + return prompt + + @app.command( rich_help_panel=help_panels.TRANSCRIPTION, epilog=examples_epilog( @@ -65,24 +92,7 @@ def llm( raise typer.Exit(code=0) def follow_body(state: AppState, json_mode: bool) -> None: - if not prompt: - raise UsageError("Provide a prompt to run over the streamed transcript.") - prompt_text = prompt - if output_field is not None: - raise UsageError( - "--output applies to one-shot mode; --follow renders a live panel " - "(or NDJSON when piped)." - ) - if transcript_id: - raise UsageError( - "--follow runs over live transcript text piped on stdin; it can't be " - "combined with --transcript-id." - ) - if not stdio.stdin_is_piped(): - raise UsageError( - "--follow needs transcript text piped on stdin, e.g. " - '`aai stream -o text | aai llm -f "summarize action items as I talk"`.' - ) + prompt_text = _validate_follow_args(prompt, output_field, transcript_id) api_key = config.resolve_api_key(profile=state.profile) def ask(transcript_text: str) -> str: diff --git a/aai_cli/commands/stream.py b/aai_cli/commands/stream.py index 09118225..4477bf72 100644 --- a/aai_cli/commands/stream.py +++ b/aai_cli/commands/stream.py @@ -3,7 +3,8 @@ import queue import tempfile import threading -from collections.abc import Iterable +from collections.abc import Callable, Iterable +from dataclasses import dataclass, field from pathlib import Path import typer @@ -23,6 +24,253 @@ DEFAULT_SPEECH_MODEL = SpeechModel.u3_rt_pro.value +# Sources that can be transcribed in parallel sessions: (label, audio chunks, sample rate). +_ParallelStreams = list[tuple[str, Iterable[bytes], int]] + + +@dataclass(frozen=True) +class _SourceOptions: + """Where the audio comes from, distilled from the CLI flags. + + Centralizes the "which input?" predicates so the validation and dispatch helpers + below read off one object instead of re-deriving the same booleans. + """ + + source: str | None + sample: bool + sample_rate: int | None + device: int | None + system_audio: bool + system_audio_only: bool + + @property + def from_stdin(self) -> bool: + return self.source == "-" + + @property + def from_file(self) -> bool: + return bool(self.source) or self.sample + + @property + def from_system_audio(self) -> bool: + return self.system_audio or self.system_audio_only + + @property + def has_capture_overrides(self) -> bool: + """Whether a microphone-only flag (--sample-rate or --device) was given.""" + return self.sample_rate is not None or self.device is not None + + +def _validate_sources(opts: _SourceOptions, *, has_llm: bool, text_mode: bool) -> None: + """Reject flag combinations that can't be honored, before any audio is opened.""" + if opts.system_audio and opts.system_audio_only: + raise UsageError("Use either --system-audio or --system-audio-only, not both.") + _validate_input_source(opts) + if has_llm and text_mode: + raise UsageError( + "--llm renders a live panel (or NDJSON when piped); it can't be combined with -o text." + ) + + +def _validate_input_source(opts: _SourceOptions) -> None: + """Reject --sample-rate/--device/source combinations the chosen input can't accept.""" + if opts.from_system_audio: + if opts.from_file: + raise UsageError("--system-audio cannot be combined with an audio source or --sample.") + if opts.system_audio_only and opts.has_capture_overrides: + raise UsageError( + "--sample-rate and --device require microphone input; use --system-audio." + ) + elif opts.from_stdin: + if opts.device is not None: + raise UsageError("--device applies only to microphone input.") + elif opts.from_file and opts.has_capture_overrides: + raise UsageError("--sample-rate and --device apply only to microphone input.") + + +@dataclass +class _StreamSession: + """Owns one streaming run: the renderers, the LLM-chain state, and the audio + plumbing shared across single- and parallel-source streaming. + + Holding this as an object (rather than a nest of closures inside the command body) + keeps each step a small, independently readable method, and collapses the ~25 + per-call flags into one ``base_flags`` dict that only varies by sample rate. + """ + + api_key: str + base_flags: dict[str, object] + overrides: list[str] | None + config_file: str | None + renderer: StreamRenderer + follow: FollowRenderer | None + llm_prompts: list[str] + model: str + max_tokens: int + transcript: list[str] = field(default_factory=list[str]) + _callback_lock: threading.RLock = field(default_factory=threading.RLock) + _listening_lock: threading.Lock = field(default_factory=threading.Lock) + _listening_started: bool = False + + @property + def on_open(self) -> Callable[[], None]: + """First-audio callback: announce "Listening…" once — unless the FollowRenderer + owns the screen in --llm mode, where the notice would clutter the live panel.""" + return (lambda: None) if self.follow is not None else self._listening_once + + def _listening_once(self) -> None: + with self._listening_lock: + if self._listening_started: + return + self._listening_started = True + self.renderer.listening() + + def on_turn(self, event: object, *, source_label: str | None = None) -> None: + with self._callback_lock: + if self.follow is None: + self.renderer.turn(event, source=source_label) + else: + self._refresh_answer(event, source_label) + + def _refresh_answer(self, event: object, source_label: str | None) -> None: + """Live --llm mode: re-run the prompt chain over the growing transcript on every + finalized turn, refreshing one evolving answer (partials are ignored).""" + follow = self.follow + if follow is None or not getattr(event, "end_of_turn", False): + return + text = getattr(event, "transcript", "") or "" + if not text: + return + if source_label is not None: + display_source = {"system": "System", "you": "You"}.get(source_label, source_label) + text = f"{display_source}: {text}" + self.transcript.append(text) + answer = llm.run_chain( + self.api_key, + self.llm_prompts, + transcript_text=" ".join(self.transcript), + model=self.model, + max_tokens=self.max_tokens, + ) + follow(answer, len(self.transcript)) + + def stream_one( + self, audio: Iterable[bytes], rate: int, *, source_label: str | None = None + ) -> None: + merged = config_builder.merge_streaming_params( + flags={**self.base_flags, "sample_rate": rate}, + overrides=self.overrides, + config_file=self.config_file, + ) + params = config_builder.construct_streaming_params(merged) + client.stream_audio( + self.api_key, + audio, + params=params, + on_begin=( + None + if self.follow is not None + else lambda event: self.renderer.begin(event, source=source_label) + ), + on_turn=lambda event: self.on_turn(event, source_label=source_label), + on_termination=( + None + if self.follow is not None + else lambda event: self.renderer.termination(event, source=source_label) + ), + ) + + def _guarded(self, work: Callable[[], None]) -> None: + """Run a streaming body with the shared lifecycle handling: enter the + FollowRenderer's live panel if present, treat Ctrl-C as a clean stop, exit 0 on + a closed downstream pipe, and always close the renderer.""" + try: + if self.follow is not None: + with self.follow: + work() + else: + work() + except KeyboardInterrupt: + # Ctrl-C is a normal "user stopped" signal -> exit 0. + if self.follow is None: + self.renderer.close() + self.renderer.stopped() + except BrokenPipeError: + # Downstream consumer (e.g. `| head`) closed the pipe; stop quietly. + raise typer.Exit(code=0) from None + finally: + if self.follow is None: + self.renderer.close() + + def run(self, audio: Iterable[bytes], rate: int, *, source_label: str | None = None) -> None: + self._guarded(lambda: self.stream_one(audio, rate, source_label=source_label)) + + def run_parallel(self, streams: _ParallelStreams) -> None: + self._guarded(lambda: self._drive(streams)) + + def _drive(self, streams: _ParallelStreams) -> None: + """Stream every source concurrently, surfacing the first worker error.""" + errors: queue.Queue[Exception] = queue.Queue() + + def worker(source_label: str, audio: Iterable[bytes], rate: int) -> None: + try: + self.stream_one(audio, rate, source_label=source_label) + except (CLIError, BrokenPipeError) as exc: + errors.put(exc) + + threads = [ + threading.Thread(target=worker, args=(label, audio, rate), daemon=True) + for label, audio, rate in streams + ] + for thread in threads: + thread.start() + while any(thread.is_alive() for thread in threads): + for thread in threads: + thread.join(timeout=0.1) + if not errors.empty(): + raise errors.get() + if not errors.empty(): + raise errors.get() + + +def _dispatch(session: _StreamSession, opts: _SourceOptions) -> None: + """Open the right audio source(s) for the flags and stream them.""" + if opts.from_system_audio: + system = MacSystemAudioSource(on_open=session.on_open) + if opts.system_audio_only: + session.run(system, system.sample_rate, source_label="system") + else: + mic = MicrophoneSource( + target_rate=TARGET_RATE, + device=opts.device, + capture_rate=opts.sample_rate, + on_open=session.on_open, + ) + session.run_parallel( + [("system", system, system.sample_rate), ("you", mic, mic.sample_rate)] + ) + elif opts.from_stdin: + # Raw PCM16 mono piped on stdin (e.g. `ffmpeg … -f s16le - | aai stream -`). + stdin_src = StdinSource(sample_rate=opts.sample_rate or TARGET_RATE) + session.run(stdin_src, stdin_src.sample_rate) + elif opts.source and youtube.is_youtube_url(opts.source): + # Fetch the audio first, then stream the local file in real time. + with tempfile.TemporaryDirectory(prefix="aai-yt-") as td: + local = youtube.download_audio(opts.source, Path(td)) + session.run(FileSource(str(local)), TARGET_RATE) + elif opts.from_file: + file_audio = FileSource(client.resolve_audio_source(opts.source, sample=opts.sample)) + session.run(file_audio, file_audio.sample_rate) + else: + # Capture at the device's native rate (or --sample-rate override) and tell the + # streaming API that rate, rather than forcing one the device may reject. + # "Listening…" is announced once the device is open (see _StreamSession.on_open), + # not when the session opens — so early speech isn't lost in the gap. + mic = MicrophoneSource( + device=opts.device, capture_rate=opts.sample_rate, on_open=session.on_open + ) + session.run(mic, mic.sample_rate) + @app.command( rich_help_panel=help_panels.TRANSCRIPTION, @@ -163,47 +411,49 @@ def stream( def body(state: AppState, json_mode: bool) -> None: text_mode, json_mode = output.stream_output_modes(output_field, json_mode=json_mode) - from_stdin = source == "-" - from_file = bool(source) or sample - from_system_audio = system_audio or system_audio_only - - def make_flags(rate: int) -> dict[str, object]: - flags: dict[str, object] = { - "sample_rate": rate, - "speech_model": speech_model, - "format_turns": format_turns if format_turns is not None else True, - "encoding": encoding, - "language_detection": language_detection, - "domain": domain, - "end_of_turn_confidence_threshold": end_of_turn_confidence_threshold, - "min_turn_silence": min_turn_silence, - "max_turn_silence": max_turn_silence, - "vad_threshold": vad_threshold, - "include_partial_turns": include_partial_turns, - "keyterms_prompt": list(keyterms_prompt) if keyterms_prompt else None, - "filter_profanity": filter_profanity, - "speaker_labels": speaker_labels, - "max_speakers": max_speakers, - "voice_focus": voice_focus, - "voice_focus_threshold": voice_focus_threshold, - "redact_pii": redact_pii, - "redact_pii_policies": config_builder.split_csv(redact_pii_policy), - "redact_pii_sub": redact_pii_sub, - "inactivity_timeout": inactivity_timeout, - "webhook_url": webhook_url, - "prompt": prompt, - } - flags.update(config_builder.auth_header_flags(webhook_auth_header)) - return flags + opts = _SourceOptions( + source=source, + sample=sample, + sample_rate=sample_rate, + device=device, + system_audio=system_audio, + system_audio_only=system_audio_only, + ) + # Every streaming flag except sample_rate, which is set per source at stream time. + base_flags: dict[str, object] = { + "speech_model": speech_model, + "format_turns": format_turns if format_turns is not None else True, + "encoding": encoding, + "language_detection": language_detection, + "domain": domain, + "end_of_turn_confidence_threshold": end_of_turn_confidence_threshold, + "min_turn_silence": min_turn_silence, + "max_turn_silence": max_turn_silence, + "vad_threshold": vad_threshold, + "include_partial_turns": include_partial_turns, + "keyterms_prompt": list(keyterms_prompt) if keyterms_prompt else None, + "filter_profanity": filter_profanity, + "speaker_labels": speaker_labels, + "max_speakers": max_speakers, + "voice_focus": voice_focus, + "voice_focus_threshold": voice_focus_threshold, + "redact_pii": redact_pii, + "redact_pii_policies": config_builder.split_csv(redact_pii_policy), + "redact_pii_sub": redact_pii_sub, + "inactivity_timeout": inactivity_timeout, + "webhook_url": webhook_url, + "prompt": prompt, + } + base_flags.update(config_builder.auth_header_flags(webhook_auth_header)) if show_code: # Print-only: emit the canonical microphone-streaming script (16 kHz) from # the flags and exit without opening audio or authenticating. Raw stdout so # `--show-code > script.py` yields a runnable file. - if from_system_audio: + if opts.from_system_audio: raise UsageError("--show-code does not support macOS system audio capture yet.") merged = config_builder.merge_streaming_params( - flags=make_flags(TARGET_RATE), + flags={**base_flags, "sample_rate": TARGET_RATE}, overrides=config_kv, config_file=config_file, ) @@ -212,209 +462,20 @@ def make_flags(rate: int) -> dict[str, object]: return api_key = config.resolve_api_key(profile=state.profile) - if system_audio and system_audio_only: - raise UsageError("Use either --system-audio or --system-audio-only, not both.") - if from_system_audio: - if from_file: - raise UsageError( - "--system-audio cannot be combined with an audio source or --sample." - ) - if system_audio_only and (sample_rate is not None or device is not None): - raise UsageError( - "--sample-rate and --device require microphone input; use --system-audio." - ) - elif from_stdin: - if device is not None: - raise UsageError("--device applies only to microphone input.") - elif from_file and (sample_rate is not None or device is not None): - raise UsageError("--sample-rate and --device apply only to microphone input.") - - if llm_prompt and text_mode: - raise UsageError( - "--llm renders a live panel (or NDJSON when piped); it can't be combined " - "with -o text." - ) + _validate_sources(opts, has_llm=bool(llm_prompt), text_mode=text_mode) - renderer = StreamRenderer(json_mode=json_mode, text_mode=text_mode) - # In --llm mode the answer is rendered live by a FollowRenderer instead of the - # raw turns; transcript accumulates the finalized turns we re-run the chain over. llm_prompts = list(llm_prompt or []) - follow = FollowRenderer(json_mode=json_mode) if llm_prompts else None - transcript: list[str] = [] - callback_lock = threading.RLock() - listening_lock = threading.Lock() - listening_started = False - - def listening_once() -> None: - nonlocal listening_started - with listening_lock: - if listening_started: - return - listening_started = True - renderer.listening() - - def on_turn(event: object, *, source_label: str | None = None) -> None: - with callback_lock: - if follow is None: - renderer.turn(event, source=source_label) - return - # Live LLM mode: re-run the prompt chain over the growing transcript on every - # finalized turn, refreshing one evolving answer (partials are ignored). - if not getattr(event, "end_of_turn", False): - return - text = getattr(event, "transcript", "") or "" - if not text: - return - if source_label is not None: - display_source = {"system": "System", "you": "You"}.get( - source_label, - source_label, - ) - text = f"{display_source}: {text}" - transcript.append(text) - answer = llm.run_chain( - api_key, - llm_prompts, - transcript_text=" ".join(transcript), - model=model, - max_tokens=max_tokens, - ) - follow(answer, len(transcript)) - - def stream_one( - audio: Iterable[bytes], - rate: int, - *, - source_label: str | None = None, - ) -> None: - merged = config_builder.merge_streaming_params( - flags=make_flags(rate), overrides=config_kv, config_file=config_file - ) - params = config_builder.construct_streaming_params(merged) - client.stream_audio( - api_key, - audio, - params=params, - on_begin=( - None - if follow is not None - else lambda event: renderer.begin(event, source=source_label) - ), - on_turn=lambda event: on_turn(event, source_label=source_label), - on_termination=( - None - if follow is not None - else lambda event: renderer.termination(event, source=source_label) - ), - ) - - def run(audio: Iterable[bytes], rate: int, *, source_label: str | None = None) -> None: - try: - # The FollowRenderer is a context manager (it owns the live panel); enter it - # around the whole session so it stops cleanly and prints the final answer. - if follow is not None: - with follow: - stream_one(audio, rate, source_label=source_label) - else: - stream_one(audio, rate, source_label=source_label) - except KeyboardInterrupt: - # Ctrl-C is a normal "user stopped" signal -> exit 0. - if follow is None: - renderer.close() - renderer.stopped() - except BrokenPipeError: - # Downstream consumer (e.g. `| head`) closed the pipe; stop quietly. - raise typer.Exit(code=0) from None - finally: - if follow is None: - renderer.close() - - def run_parallel(streams: list[tuple[str, Iterable[bytes], int]]) -> None: - errors: queue.Queue[Exception] = queue.Queue() - - def worker(source_label: str, audio: Iterable[bytes], rate: int) -> None: - try: - stream_one(audio, rate, source_label=source_label) - except (CLIError, BrokenPipeError) as exc: - errors.put(exc) - - def drive() -> None: - threads = [ - threading.Thread( - target=worker, - args=(source_label, audio, rate), - daemon=True, - ) - for source_label, audio, rate in streams - ] - for thread in threads: - thread.start() - while any(thread.is_alive() for thread in threads): - for thread in threads: - thread.join(timeout=0.1) - if not errors.empty(): - raise errors.get() - if not errors.empty(): - raise errors.get() - - try: - if follow is not None: - with follow: - drive() - else: - drive() - except KeyboardInterrupt: - if follow is None: - renderer.close() - renderer.stopped() - except BrokenPipeError: - raise typer.Exit(code=0) from None - finally: - if follow is None: - renderer.close() - - if from_system_audio: - system = MacSystemAudioSource( - on_open=(lambda: None) if follow is not None else listening_once, - ) - if system_audio_only: - run(system, system.sample_rate, source_label="system") - else: - mic = MicrophoneSource( - target_rate=TARGET_RATE, - device=device, - capture_rate=sample_rate, - on_open=(lambda: None) if follow is not None else listening_once, - ) - run_parallel( - [ - ("system", system, system.sample_rate), - ("you", mic, mic.sample_rate), - ] - ) - elif from_stdin: - # Raw PCM16 mono piped on stdin (e.g. `ffmpeg … -f s16le - | aai stream -`). - stdin_src = StdinSource(sample_rate=sample_rate or TARGET_RATE) - run(stdin_src, stdin_src.sample_rate) - elif source and youtube.is_youtube_url(source): - # Fetch the audio first, then stream the local file in real time. - with tempfile.TemporaryDirectory(prefix="aai-yt-") as td: - local = youtube.download_audio(source, Path(td)) - run(FileSource(str(local)), TARGET_RATE) - elif from_file: - file_audio = FileSource(client.resolve_audio_source(source, sample=sample)) - run(file_audio, file_audio.sample_rate) - else: - # Capture at the device's native rate (or --sample-rate override) and tell - # the streaming API that rate, rather than forcing one the device may reject. - # Announce "Listening…" only once the device is open and recording, - # not when the session opens — so early speech isn't lost in the gap. - mic = MicrophoneSource( - device=device, - capture_rate=sample_rate, - # In --llm mode the FollowRenderer owns the screen, so skip the notice. - on_open=(lambda: None) if follow is not None else listening_once, - ) - run(mic, mic.sample_rate) + session = _StreamSession( + api_key=api_key, + base_flags=base_flags, + overrides=config_kv, + config_file=config_file, + renderer=StreamRenderer(json_mode=json_mode, text_mode=text_mode), + follow=FollowRenderer(json_mode=json_mode) if llm_prompts else None, + llm_prompts=llm_prompts, + model=model, + max_tokens=max_tokens, + ) + _dispatch(session, opts) run_command(ctx, body, json=json_out) diff --git a/aai_cli/timeparse.py b/aai_cli/timeparse.py index 5ce18760..d7475436 100644 --- a/aai_cli/timeparse.py +++ b/aai_cli/timeparse.py @@ -3,18 +3,27 @@ from datetime import UTC, date, datetime, time -def parse_iso_utc(value: object) -> datetime | None: - if not isinstance(value, str) or not value: - return None +def _parse_iso_datetime(value: str) -> datetime | None: + """Parse an ISO date or datetime string to a (possibly naive) datetime. + + A bare date (no ``T``) becomes midnight; ``Z`` is accepted as the UTC suffix. + Returns ``None`` when the string isn't a valid ISO date/datetime. + """ text = value[:-1] + "+00:00" if value.endswith("Z") else value try: - parsed = ( - datetime.fromisoformat(text) - if "T" in text - else datetime.combine(date.fromisoformat(text), time.min) - ) + if "T" in text: + return datetime.fromisoformat(text) + return datetime.combine(date.fromisoformat(text), time.min) except ValueError: return None + + +def parse_iso_utc(value: object) -> datetime | None: + if not isinstance(value, str) or not value: + return None + parsed = _parse_iso_datetime(value) + if parsed is None: + return None if parsed.tzinfo is None: return parsed.replace(tzinfo=UTC) return parsed.astimezone(UTC) diff --git a/pyproject.toml b/pyproject.toml index eba774d6..41f49b64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,8 +165,8 @@ target-version = "py312" # boolean traps, pylint-style design issues, centralized raw output, pytest style, # small simplifications, performance footguns, and type-only import hygiene. select = ["E", "F", "I", "UP", "B", "BLE", "C4", "SIM", "RET", "PTH", "ARG", "S", "RUF", - "PGH", "ERA", "TRY", "TD", "FIX", "A", "N", "FBT", "PL", "T20", "PT", "PIE", - "PERF", "TCH"] + "PGH", "ERA", "TRY", "TD", "FIX", "A", "N", "FBT", "PL", "C90", "T20", "PT", + "PIE", "PERF", "TCH"] # E501: line length is owned by the formatter. # B008: Typer uses function calls (typer.Option/Argument) as parameter defaults. # S603/S607: we intentionally shell out to `claude`/`npx` with controlled args. @@ -179,8 +179,18 @@ select = ["E", "F", "I", "UP", "B", "BLE", "C4", "SIM", "RET", "PTH", "ARG", "S" ignore = ["E501", "B008", "S603", "S607", "TRY003", "N818", "PLC0415", "TC001", "TC002", "TC003"] +# Function-size pressure, tuned to keep functions small enough to read and edit in +# one screen (the friction a coding agent hits most). These complement xenon's +# cyclomatic-complexity gate in check.sh: mccabe (C901) and max-branches bound +# branchiness; max-statements bounds raw length; max-args bounds signatures. +[tool.ruff.lint.mccabe] +max-complexity = 10 # matches xenon's grade-B ceiling (CC <= 10) so the two agree + [tool.ruff.lint.pylint] -max-args = 10 +max-args = 8 +max-branches = 12 +max-returns = 6 +max-statements = 40 [tool.ruff.lint.per-file-ignores] # Tests assert freely, use throwaway args/temp paths, and don't need pathlib/security lints. @@ -190,8 +200,10 @@ max-args = 10 "tests/**" = ["S101", "S105", "S106", "S107", "S108", "ARG001", "ARG002", "ARG005", "PTH123", "SIM117", "TRY300", "FBT", "PLR2004", "PLC0415", "PLR0913", "PLW1510", "N806", "N818", "PLW0108", "PT018", "TCH"] -# Typer command functions naturally have many boolean options and broad signatures. -"aai_cli/commands/**" = ["FBT001", "FBT003", "PLR0912", "PLR0913", "PLR0915"] +# Typer command functions naturally have many boolean options and broad signatures +# (PLR0913/FBT). Their *bodies*, though, are held to the same length/branch limits as +# the rest of the package: PLR0912/PLR0915/C901 are deliberately NOT ignored here. +"aai_cli/commands/**" = ["FBT001", "FBT003", "PLR0913"] # The root callback is also a Typer command signature. "aai_cli/main.py" = ["FBT001", "FBT003"] # Raw stdout/stderr writes are centralized here; command modules call output helpers. diff --git a/scripts/check.sh b/scripts/check.sh index 8fe87b71..8c862d4f 100755 --- a/scripts/check.sh +++ b/scripts/check.sh @@ -48,11 +48,16 @@ uv run lint-imports echo "==> xenon (cyclomatic complexity gate, src only)" # Fail the build if any function gets too branchy. Grades map to cyclomatic # complexity: A=1-5, B=6-10, C=11-20, ... Thresholds: -# --max-absolute B : no single function may exceed CC 10 (grade B). -# --max-modules B : no file's average may exceed grade B. +# --max-absolute B : no single function may exceed CC 10 (grade B). Pairs with ruff's +# mccabe max-complexity=10 (C901); xenon/radon also counts boolean +# operators, so it's the stricter of the two on the same number. +# Raw length/arg limits live in ruff (PLR0915/C901/PLR0913) — +# xenon only measures branching. +# --max-modules A : no file's *average* may exceed grade A (CC <= 5), so no single +# module is allowed to become a complexity hotspot on average. # --max-average A : the project-wide average must stay grade A (CC <= 5). # Tests are excluded (not shipped); only the aai_cli package is gated. -uv run xenon --max-absolute B --max-modules B --max-average A aai_cli +uv run xenon --max-absolute B --max-modules A --max-average A aai_cli echo "==> swiftlint (macOS audio helper)" if command -v swiftlint >/dev/null 2>&1; then diff --git a/tests/test_agent_command.py b/tests/test_agent_command.py index d59e9518..075c461b 100644 --- a/tests/test_agent_command.py +++ b/tests/test_agent_command.py @@ -51,18 +51,7 @@ def fake_run_session(api_key, **_kwargs): def test_agent_drives_renderer_json(monkeypatch): config.set_api_key("default", "sk_live") - def fake_run_session( - api_key, - *, - renderer, - player, - mic, - voice, - system_prompt, - greeting, - full_duplex=False, - exit_after_reply=False, - ): + def fake_run_session(api_key, *, renderer, player, mic, config): renderer.connected() renderer.user_final("hello agent") renderer.agent_transcript("hello human", interrupted=False) @@ -79,21 +68,10 @@ def test_agent_passes_voice_and_prompt_file(monkeypatch, tmp_path): config.set_api_key("default", "sk_live") seen = {} - def fake_run_session( - api_key, - *, - renderer, - player, - mic, - voice, - system_prompt, - greeting, - full_duplex=False, - exit_after_reply=False, - ): - seen["voice"] = voice - seen["prompt"] = system_prompt - seen["full_duplex"] = full_duplex + def fake_run_session(api_key, *, renderer, player, mic, config): + seen["voice"] = config.voice + seen["prompt"] = config.system_prompt + seen["full_duplex"] = config.full_duplex monkeypatch.setattr("aai_cli.commands.agent.run_session", fake_run_session) prompt_file = tmp_path / "p.txt" @@ -175,9 +153,9 @@ def test_agent_file_source_streams_clip_and_exits_after_reply(monkeypatch, tmp_p assert result.exit_code == 0 # File input drives a deterministic, headless, self-terminating session. assert seen["mic"] == f"filesrc:{wav}" - assert seen["exit_after_reply"] is True - assert seen["full_duplex"] is True - assert seen["greeting"] == "" + assert seen["config"].exit_after_reply is True + assert seen["config"].full_duplex is True + assert seen["config"].greeting == "" from aai_cli.agent.audio import NullPlayer assert isinstance(seen["player"], NullPlayer) @@ -197,7 +175,7 @@ def fake_file_source(src): result = runner.invoke(app, ["agent", "--sample"]) assert result.exit_code == 0 assert captured["src"].endswith("wildfires.mp3") - assert seen["exit_after_reply"] is True + assert seen["config"].exit_after_reply is True def test_agent_file_source_with_device_exits_2(monkeypatch, tmp_path): diff --git a/tests/test_agent_session.py b/tests/test_agent_session.py index cc7804f5..01a0be86 100644 --- a/tests/test_agent_session.py +++ b/tests/test_agent_session.py @@ -3,7 +3,12 @@ import pytest -from aai_cli.agent.session import VoiceAgentSession, _send_audio_loop, run_session +from aai_cli.agent.session import ( + AgentRunConfig, + VoiceAgentSession, + _send_audio_loop, + run_session, +) from aai_cli.errors import APIError, CLIError, NotAuthenticated @@ -216,9 +221,7 @@ def bad_connect(url, **kwargs): renderer=FakeRenderer(), player=FakePlayer(), mic=[], - voice="ivy", - system_prompt="x", - greeting="hi", + config=AgentRunConfig(voice="ivy", system_prompt="x", greeting="hi"), connect=bad_connect, ) @@ -241,9 +244,7 @@ def close(self): renderer=FakeRenderer(), player=player, mic=[], - voice="ivy", - system_prompt="x", - greeting="hi", + config=AgentRunConfig(voice="ivy", system_prompt="x", greeting="hi"), connect=lambda url, **kwargs: FakeWS(), ) assert player.closed is True # speaker stream still torn down @@ -278,9 +279,7 @@ def close(self): renderer=FakeRenderer(), player=FakePlayer(), mic=_BoomMic(), - voice="ivy", - system_prompt="x", - greeting="hi", + config=AgentRunConfig(voice="ivy", system_prompt="x", greeting="hi"), connect=lambda url, **kwargs: _BlockingWS(), ) assert exc.value.exit_code == 1 # the real mic failure reaches the user, not a hang @@ -296,9 +295,7 @@ def boom(url, **kwargs): renderer=FakeRenderer(), player=FakePlayer(), mic=[], - voice="ivy", - system_prompt="x", - greeting="hi", + config=AgentRunConfig(voice="ivy", system_prompt="x", greeting="hi"), connect=boom, ) @@ -395,11 +392,13 @@ def close(self): renderer=renderer, player=FakePlayer(), mic=[], # capture thread waits for ready, then this empty source ends at once - voice="ivy", - system_prompt="x", - greeting="", - full_duplex=True, - exit_after_reply=True, + config=AgentRunConfig( + voice="ivy", + system_prompt="x", + greeting="", + full_duplex=True, + exit_after_reply=True, + ), connect=lambda url, **kwargs: _ScriptedWS(), ) finals = [c for c in renderer.calls if c[0] == "user_final"] @@ -430,9 +429,7 @@ def capture(url, **kwargs): renderer=FakeRenderer(), player=FakePlayer(), mic=[], - voice="ivy", - system_prompt="x", - greeting="hi", + config=AgentRunConfig(voice="ivy", system_prompt="x", greeting="hi"), connect=capture, ) assert seen["url"] == expected