Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion aai_cli/agent_cascade/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@

from __future__ import annotations

from dataclasses import dataclass
from collections.abc import Mapping
from dataclasses import dataclass, field

from aai_cli.agent_cascade.voices import DEFAULT_VOICE
from aai_cli.core import llm

DEFAULT_MODEL = llm.DEFAULT_MODEL
DEFAULT_MAX_TOKENS = llm.DEFAULT_MAX_TOKENS
# The realtime model the cascade transcribes with (same as the agent-cascade template).
DEFAULT_SPEECH_MODEL = "u3-rt-pro"
DEFAULT_SYSTEM_PROMPT = (
"You are a friendly, concise voice assistant. Keep replies short and "
"conversational. Your reply is read aloud by a text-to-speech engine, so "
Expand All @@ -32,3 +36,13 @@ class CascadeConfig:
greeting: str = DEFAULT_GREETING
model: str = DEFAULT_MODEL
max_history: int = DEFAULT_MAX_HISTORY
# TTS language (None lets the server pick from the voice).
language: str | None = None
# LLM: cap per-reply tokens and pass through any extra gateway request fields.
max_tokens: int = DEFAULT_MAX_TOKENS
llm_extra: Mapping[str, object] = field(default_factory=dict[str, object])
# Extra streaming-TTS query params (the --tts-config escape hatch).
tts_extra: Mapping[str, str] = field(default_factory=dict[str, str])
# Whether STT formats finalized turns. The reply trigger waits for the formatted
# turn when on; with it off, an unformatted end-of-turn is the cue instead.
format_turns: bool = True
56 changes: 28 additions & 28 deletions aai_cli/agent_cascade/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from aai_cli.agent_cascade.config import CascadeConfig
from aai_cli.agent_cascade.text import split_sentences, trim_history
from aai_cli.core import client, config_builder, llm
from aai_cli.core import client, llm
from aai_cli.core.errors import CLIError
from aai_cli.tts import session as tts_session
from aai_cli.tts.session import SpeakConfig
Expand Down Expand Up @@ -96,23 +96,6 @@ def _spawn_thread(target: Callable[[], None]) -> _Worker:
return thread


# The realtime model the cascade transcribes with (same as the agent-cascade template).
STT_SPEECH_MODEL = "u3-rt-pro"


def _stt_params(sample_rate: int) -> StreamingParameters:
"""Streaming v3 params for the cascade: PCM at ``sample_rate`` with formatted turns
(so ``turn_is_formatted`` marks the cue to reply)."""
merged = config_builder.merge_streaming_params(
flags={
"sample_rate": sample_rate,
"format_turns": True,
"speech_model": STT_SPEECH_MODEL,
}
)
return config_builder.construct_streaming_params(merged)


@dataclass
class CascadeDeps:
"""The cascade's three network legs plus its thread spawner, all injectable.
Expand All @@ -133,17 +116,29 @@ def real(
config: CascadeConfig,
*,
audio: Iterable[bytes],
sample_rate: int,
stt_params: StreamingParameters,
) -> CascadeDeps:
def run_stt(on_turn: Callable[[object], None]) -> None:
client.stream_audio(api_key, audio, params=_stt_params(sample_rate), on_turn=on_turn)
client.stream_audio(api_key, audio, params=stt_params, on_turn=on_turn)

def complete_reply(messages: list[ChatCompletionMessageParam]) -> str:
response = llm.complete(api_key, model=config.model, messages=messages)
response = llm.complete(
api_key,
model=config.model,
messages=messages,
max_tokens=config.max_tokens,
extra=dict(config.llm_extra) or None,
)
return llm.content_of(response)

def synthesize(text: str) -> bytes:
spec = SpeakConfig(text=text, voice=config.voice, sample_rate=TTS_SAMPLE_RATE)
spec = SpeakConfig(
text=text,
voice=config.voice,
language=config.language,
sample_rate=TTS_SAMPLE_RATE,
extra=config.tts_extra,
)
return tts_session.synthesize(api_key, spec).pcm

return cls(run_stt=run_stt, complete_reply=complete_reply, synthesize=synthesize)
Expand Down Expand Up @@ -186,7 +181,7 @@ def on_turn(self, event: object) -> None:
text = (getattr(event, "transcript", "") or "").strip()
if not text:
return
if _is_final_turn(event):
if _is_final_turn(event, format_turns=self.config.format_turns):
self.renderer.user_final(text)
self._barge_in()
self.history.append({"role": "user", "content": text})
Expand Down Expand Up @@ -261,11 +256,16 @@ def shutdown(self) -> None:
self._join_reply()


def _is_final_turn(event: object) -> bool:
"""True for a finalized, formatted end-of-turn — the cue to generate a reply."""
return bool(getattr(event, "end_of_turn", False)) and bool(
getattr(event, "turn_is_formatted", False)
)
def _is_final_turn(event: object, *, format_turns: bool) -> bool:
"""True for an end-of-turn that's the cue to generate a reply.

With formatting on, wait for the *formatted* turn (better text for the LLM);
with it off the server never sets ``turn_is_formatted``, so a bare end-of-turn
is the cue — otherwise ``--no-format-turns`` would make the agent never reply.
"""
if not bool(getattr(event, "end_of_turn", False)):
return False
return bool(getattr(event, "turn_is_formatted", False)) or not format_turns


def run_cascade(
Expand Down
76 changes: 76 additions & 0 deletions aai_cli/commands/agent_cascade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,24 @@
from aai_cli.agent_cascade import voices
from aai_cli.agent_cascade.config import (
DEFAULT_GREETING,
DEFAULT_MAX_TOKENS,
DEFAULT_MODEL,
DEFAULT_SPEECH_MODEL,
DEFAULT_SYSTEM_PROMPT,
)
from aai_cli.agent_cascade.voices import DEFAULT_VOICE
from aai_cli.app.context import AppState, run_command, run_with_options
from aai_cli.commands.agent_cascade import _exec as agent_cascade_exec
from aai_cli.core import choices, llm
from aai_cli.streaming.turn_presets import TurnDetectionPreset
from aai_cli.ui import output
from aai_cli.ui.help_text import examples_epilog

# Option panels that group the per-leg knobs in `--help` instead of one flat wall.
_PANEL_STT = "Speech-to-text"
_PANEL_LLM = "Language model"
_PANEL_TTS = "Text-to-speech"

app = typer.Typer()

SPEC = command_registry.CommandModuleSpec(
Expand Down Expand Up @@ -65,12 +73,71 @@ def agent_cascade(
"--voice",
help="TTS voice. See --list-voices.",
autocompletion=voices.complete_voice,
rich_help_panel=_PANEL_TTS,
),
language: str | None = typer.Option(
None,
"--language",
help="TTS language (defaults to the voice's language)",
rich_help_panel=_PANEL_TTS,
),
tts_config: list[str] | None = typer.Option(
None,
"--tts-config",
help="Set any extra streaming-TTS query field as KEY=VALUE (repeatable)",
rich_help_panel=_PANEL_TTS,
),
model: str = typer.Option(
DEFAULT_MODEL,
"--model",
help="LLM Gateway model that powers the agent's replies",
autocompletion=llm.complete_model,
rich_help_panel=_PANEL_LLM,
),
max_tokens: int = typer.Option(
DEFAULT_MAX_TOKENS,
"--max-tokens",
help="Max tokens per reply",
min=1,
rich_help_panel=_PANEL_LLM,
),
llm_config: list[str] | None = typer.Option(
None,
"--llm-config",
help="Set any LLM Gateway request field as KEY=VALUE (repeatable)",
rich_help_panel=_PANEL_LLM,
),
speech_model: str = typer.Option(
DEFAULT_SPEECH_MODEL,
"--speech-model",
help="Streaming speech model",
rich_help_panel=_PANEL_STT,
),
format_turns: bool = typer.Option(
True,
"--format-turns/--no-format-turns",
help="Format (punctuate) finalized turns before replying",
rich_help_panel=_PANEL_STT,
),
turn_detection: TurnDetectionPreset | None = typer.Option(
None,
"--turn-detection",
help="Turn-detection sensitivity preset",
rich_help_panel=_PANEL_STT,
),
stt_config: list[str] | None = typer.Option(
None,
"--stt-config",
help="Set any StreamingParameters field as KEY=VALUE (repeatable)",
rich_help_panel=_PANEL_STT,
),
stt_config_file: Path | None = typer.Option(
None,
"--stt-config-file",
help="JSON file of streaming fields",
exists=True,
dir_okay=False,
rich_help_panel=_PANEL_STT,
),
system_prompt: str = typer.Option(
DEFAULT_SYSTEM_PROMPT, "--system-prompt", help="System prompt (the agent's persona)"
Expand Down Expand Up @@ -125,5 +192,14 @@ def agent_cascade(
greeting=greeting,
device=device,
output_field=output_field,
speech_model=speech_model,
format_turns=format_turns,
turn_detection=turn_detection,
stt_config=tuple(stt_config or ()),
stt_config_file=stt_config_file,
max_tokens=max_tokens,
llm_config=tuple(llm_config or ()),
language=language,
tts_config=tuple(tts_config or ()),
)
run_with_options(ctx, agent_cascade_exec.run_agent_cascade, opts, json=json_out)
79 changes: 77 additions & 2 deletions aai_cli/commands/agent_cascade/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

import typer

Expand All @@ -20,12 +21,24 @@
from aai_cli.agent_cascade.config import CascadeConfig
from aai_cli.app.agent_shared import resolve_system_prompt as _resolve_system_prompt
from aai_cli.app.context import AppState
from aai_cli.core import choices, client
from aai_cli.core import choices, client, config_builder, llm
from aai_cli.core.errors import UsageError
from aai_cli.streaming import turn_presets
from aai_cli.streaming.session import resolve_output_modes
from aai_cli.streaming.sources import FileSource
from aai_cli.tts import session as tts_session

if TYPE_CHECKING:
from assemblyai.streaming.v3 import StreamingParameters

# A --tts-config key that has its own named flag (or is owned by the cascade), with the
# message steering the user to the right place instead of silently fighting the cascade.
_RESERVED_TTS_KEYS: dict[str, str] = {
"voice": "Set the voice with --voice, not --tts-config.",
"language": "Set the language with --language, not --tts-config.",
"sample_rate": "TTS sample rate is fixed to match the live speaker and can't be overridden.",
}


@dataclass(frozen=True)
class AgentCascadeOptions:
Expand All @@ -45,6 +58,58 @@ class AgentCascadeOptions:
greeting: str
device: int | None
output_field: choices.TextOrJson | None
# Speech-to-text: common knobs named, everything else via --stt-config(-file).
speech_model: str
format_turns: bool
turn_detection: turn_presets.TurnDetectionPreset | None
stt_config: tuple[str, ...]
stt_config_file: Path | None
# Language model: token cap plus any extra gateway request field.
max_tokens: int
llm_config: tuple[str, ...]
# Text-to-speech: language named, any other query param via --tts-config.
language: str | None
tts_config: tuple[str, ...]


def _build_stt_params(opts: AgentCascadeOptions, sample_rate: int) -> StreamingParameters:
"""Construct the cascade's StreamingParameters from the STT flags + escape hatch.

A turn-detection preset expands into the three end-of-turn knobs; --stt-config /
--stt-config-file then override any field (including those knobs). sample_rate is
fixed by the audio source, so it's merged in here rather than user-set."""
eot, min_silence, max_silence = turn_presets.resolve(opts.turn_detection, None, None, None)
flags: dict[str, object] = {
"speech_model": opts.speech_model,
"format_turns": opts.format_turns,
"end_of_turn_confidence_threshold": eot,
"min_turn_silence": min_silence,
"max_turn_silence": max_silence,
}
merged = config_builder.merge_streaming_params(
flags=flags | {"sample_rate": sample_rate},
overrides=opts.stt_config or None,
config_file=opts.stt_config_file,
)
return config_builder.construct_streaming_params(merged)


def _parse_tts_config(pairs: tuple[str, ...]) -> dict[str, str]:
"""Parse --tts-config KEY=VALUE pairs into extra streaming-TTS query params,
rejecting keys that have a named flag (or are cascade-owned)."""
extra: dict[str, str] = {}
for pair in pairs:
key, sep, value = pair.partition("=")
key = key.strip()
if not sep or not key:
raise UsageError(
f"--tts-config expects KEY=VALUE, got {pair!r}.",
suggestion="e.g. --tts-config chunk_size_ms=100",
)
if key in _RESERVED_TTS_KEYS:
raise UsageError(_RESERVED_TTS_KEYS[key])
extra[key] = value
return extra


def _open_audio(
Expand Down Expand Up @@ -89,6 +154,10 @@ def run_agent_cascade(opts: AgentCascadeOptions, state: AppState, *, json_mode:
# Existence-check the clip before credentials, so a typo'd path reads as
# "file not found" instead of triggering a login.
client.resolve_audio_source(opts.source, sample=opts.sample)
# Parse the LLM/TTS escape hatches before opening the device, so a bad KEY=VALUE
# fails fast instead of after the mic is live.
llm_extra = llm.parse_gateway_overrides(opts.llm_config)
tts_extra = _parse_tts_config(opts.tts_config)
api_key = state.resolve_api_key()

config = CascadeConfig(
Expand All @@ -97,12 +166,18 @@ def run_agent_cascade(opts: AgentCascadeOptions, state: AppState, *, json_mode:
# File-driven runs speak a clip and end after the reply, so skip the greeting.
greeting="" if from_file else opts.greeting,
model=opts.model,
language=opts.language,
max_tokens=opts.max_tokens,
format_turns=opts.format_turns,
llm_extra=llm_extra,
tts_extra=tts_extra,
)
renderer = AgentRenderer(json_mode=json_mode, text_mode=text_mode, mic_input=not from_file)
audio, player, sample_rate = _open_audio(
renderer, source=opts.source, sample=opts.sample, device=opts.device, from_file=from_file
)
deps = engine.CascadeDeps.real(api_key, config, audio=audio, sample_rate=sample_rate)
stt_params = _build_stt_params(opts, sample_rate)
deps = engine.CascadeDeps.real(api_key, config, audio=audio, stt_params=stt_params)
try:
engine.run_cascade(renderer=renderer, player=player, config=config, deps=deps)
except KeyboardInterrupt:
Expand Down
14 changes: 10 additions & 4 deletions aai_cli/streaming/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,23 @@ def open_authorized_ws[T](
*,
message: str,
host: str,
bearer: bool = True,
**connect_kwargs: object,
) -> T:
"""Open a Bearer-authorized WebSocket, mapping a connect failure via ``classify_error``.
"""Open an ``Authorization``-headered WebSocket, mapping a connect failure via
``classify_error``.

The one connect path for the raw-websocket sessions (agent, speak), so a
rejected handshake (HTTP 401/403) carries the same actionable suggestion in
both and everything else keeps the shared classification.

``bearer`` selects the AssemblyAI auth scheme for the endpoint: the Voice Agent
socket expects a ``Bearer <key>`` token (the default), while the streaming
sockets (STT, TTS) authenticate with the **raw** key — pass ``bearer=False``
for those, or the server refuses the session with an in-band Error frame.
"""
token = f"Bearer {api_key}" if bearer else api_key
try:
return connect(
url, additional_headers={"Authorization": f"Bearer {api_key}"}, **connect_kwargs
)
return connect(url, additional_headers={"Authorization": token}, **connect_kwargs)
except Exception as exc:
raise classify_error(exc, message, host=host) from exc
Loading
Loading