Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 63 additions & 34 deletions aai_cli/agent/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import threading
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

from aai_cli import environments
Expand All @@ -26,15 +27,39 @@ def _ws_url() -> str:
# session.error codes that mean the connection is unauthorized -> exit 2.
_AUTH_ERROR_CODES = {"UNAUTHORIZED", "FORBIDDEN"}

# The websocket connection, the `connect` factory, and the renderer/player/mic I/O
# objects come from libraries/modules with no usable type stubs. Alias that untyped
# boundary here so each role is named in signatures and `Any` stays in one place.
# (Server event payloads remain `dict[str, Any]`.)
_WebSocket = Any
_Connect = Any
_IO = Any


@dataclass(frozen=True)
class AgentRunConfig:
"""The static (per-run) configuration for a Voice Agent session.

Bundled into one value so `run_session`'s signature stays small: the I/O
objects (renderer/player/mic) vary per call, but these knobs are fixed once
the command has parsed its flags.
"""

voice: str
system_prompt: str
greeting: str
full_duplex: bool = False
exit_after_reply: bool = False


class VoiceAgentSession:
"""Routes Voice Agent server events to the renderer, player, and duplex state."""

def __init__(
self,
*,
renderer: Any,
player: Any,
renderer: _IO,
player: _IO,
full_duplex: bool = False,
exit_after_reply: bool = False,
ready_event: threading.Event | None = None,
Expand Down Expand Up @@ -144,7 +169,7 @@ def raise_error(self, event: dict[str, Any]) -> None:
}


def _send_audio_loop(ws: Any, session: VoiceAgentSession, mic: Any) -> None:
def _send_audio_loop(ws: _WebSocket, session: VoiceAgentSession, mic: _IO) -> None:
"""Forward mic PCM as input.audio while the session gate allows it."""
# File-driven runs wait for session.ready before consuming the source, so a
# finite clip isn't partly drained (and dropped) before the server accepts it.
Expand All @@ -168,43 +193,61 @@ def _auth_or_api_error(exc: Exception, message: str) -> CLIError:
return APIError(f"{message}: {exc}")


def _open_ws(connect: Any, api_key: str) -> Any:
def _open_ws(connect: _Connect, api_key: str) -> _WebSocket:
"""Open the Voice Agent socket, mapping a connect failure to a clean CLIError."""
try:
return connect(_ws_url(), additional_headers={"Authorization": f"Bearer {api_key}"})
except Exception as exc:
raise _auth_or_api_error(exc, "Could not connect to the voice agent") from exc


def _session_update_message(config: AgentRunConfig) -> str:
"""The initial session.update payload as a JSON string: persona, greeting, voice."""
return json.dumps(
{
"type": "session.update",
"session": {
"system_prompt": config.system_prompt,
"greeting": config.greeting,
"output": {"voice": config.voice},
},
}
)


def _receive_loop(ws: _WebSocket, session: VoiceAgentSession) -> None:
"""Dispatch inbound server events until the socket closes or the run finishes."""
for raw in ws:
session.dispatch(json.loads(raw))
if session.finished:
break


def run_session(
api_key: str,
*,
renderer: Any,
player: Any,
mic: Any,
voice: str,
system_prompt: str,
greeting: str,
full_duplex: bool = False,
exit_after_reply: bool = False,
connect: Any = None,
renderer: _IO,
player: _IO,
mic: _IO,
config: AgentRunConfig,
connect: _Connect = None,
) -> None:
"""Open the Voice Agent WebSocket and run the bidirectional loop until close.

`connect` defaults to websockets' synchronous client; injectable for tests.
When `exit_after_reply` is set (file-driven runs), the loop stops after the
agent's first reply to the spoken input and the capture thread waits for
When `config.exit_after_reply` is set (file-driven runs), the loop stops after
the agent's first reply to the spoken input and the capture thread waits for
session.ready before streaming the source.
"""
if connect is None:
from websockets.sync.client import connect

ready_event = threading.Event() if exit_after_reply else None
ready_event = threading.Event() if config.exit_after_reply else None
session = VoiceAgentSession(
renderer=renderer,
player=player,
full_duplex=full_duplex,
exit_after_reply=exit_after_reply,
full_duplex=config.full_duplex,
exit_after_reply=config.exit_after_reply,
ready_event=ready_event,
)

Expand All @@ -231,22 +274,8 @@ def _capture() -> None:
player.start() # opens the speaker stream; CLIError here if sounddevice can't load
player_started = True
threading.Thread(target=_capture, daemon=True).start()
ws.send(
json.dumps(
{
"type": "session.update",
"session": {
"system_prompt": system_prompt,
"greeting": greeting,
"output": {"voice": voice},
},
}
)
)
for raw in ws:
session.dispatch(json.loads(raw))
if session.finished:
break
ws.send(_session_update_message(config))
_receive_loop(ws, session)
except (CLIError, KeyboardInterrupt, BrokenPipeError):
raise # clean CLI errors, user Ctrl-C, and a closed pipe are handled upstream
except Exception as exc:
Expand Down
46 changes: 27 additions & 19 deletions aai_cli/code_gen/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,38 +96,46 @@ def on_turn(client: StreamingClient, event: TurnEvent) -> None:
"""


def render(merged: dict[str, object], *, llm: dict[str, object] | None = None) -> str:
"""Generate a runnable microphone-streaming script with the given params.

With `llm`, the script transforms the live transcript through the LLM Gateway,
refreshing a prompt chain on every finalized turn (the live sibling of
`transcribe --llm`).
"""
def _imports_block(merged: dict[str, object]) -> str:
"""Sorted streaming-class import lines; SpeechModel only when a model kwarg is emitted."""
names = list(_BASE_IMPORTS)
if "speech_model" in merged:
names.append("SpeechModel")
imports = "\n".join(f" {name}," for name in sorted(names))
return "\n".join(f" {name}," for name in sorted(names))


def _build_preamble(imports: str, llm: dict[str, object] | None) -> str:
"""Pick and fill the plain vs. LLM-Gateway preamble for the given imports."""
if llm:
prompts = "\n".join(f" {p!r}," for p in cast("list[str]", llm["prompts"]))
preamble = _LLM_PREAMBLE.format(
return _LLM_PREAMBLE.format(
imports=imports,
base_url=gateway.GATEWAY_BASE_URL,
prompts=prompts,
model=llm["model"],
max_tokens=llm["max_tokens"],
)
else:
preamble = _PREAMBLE.format(imports=imports)
return _PREAMBLE.format(imports=imports)

# Mic capture rate must match StreamingParameters.sample_rate, else audio is corrupt.
rate = merged.get("sample_rate", 16000)

if merged:
# indent=8: 4 for connect(), 4 more for the StreamingParameters() args.
kwargs = "\n".join(serialize.config_kwarg_lines(merged, indent=8))
connect = f"client.connect(\n StreamingParameters(\n{kwargs}\n )\n)"
else:
connect = "client.connect(StreamingParameters())"
def _build_connect(merged: dict[str, object]) -> str:
"""The `client.connect(StreamingParameters(...))` call for the given params."""
if not merged:
return "client.connect(StreamingParameters())"
# indent=8: 4 for connect(), 4 more for the StreamingParameters() args.
kwargs = "\n".join(serialize.config_kwarg_lines(merged, indent=8))
return f"client.connect(\n StreamingParameters(\n{kwargs}\n )\n)"


def render(merged: dict[str, object], *, llm: dict[str, object] | None = None) -> str:
"""Generate a runnable microphone-streaming script with the given params.

With `llm`, the script transforms the live transcript through the LLM Gateway,
refreshing a prompt chain on every finalized turn (the live sibling of
`transcribe --llm`).
"""
preamble = _build_preamble(_imports_block(merged), llm)
# Mic capture rate must match StreamingParameters.sample_rate, else audio is corrupt.
rate = merged.get("sample_rate", 16000)
connect = _build_connect(merged)
return preamble + "\n" + connect + "\n" + _FOOTER.format(rate=rate)
101 changes: 58 additions & 43 deletions aai_cli/commands/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from aai_cli import client, code_gen, config, help_panels, output
from aai_cli.agent.audio import SAMPLE_RATE, DuplexAudio, NullPlayer
from aai_cli.agent.render import AgentRenderer
from aai_cli.agent.session import DEFAULT_GREETING, DEFAULT_PROMPT, run_session
from aai_cli.agent.session import (
DEFAULT_GREETING,
DEFAULT_PROMPT,
AgentRunConfig,
run_session,
)
from aai_cli.agent.voices import DEFAULT_VOICE, VOICES, format_voice_list
from aai_cli.context import AppState, run_command
from aai_cli.errors import CLIError, UsageError
Expand All @@ -19,6 +24,46 @@
app = typer.Typer()


def _resolve_system_prompt(system_prompt: str, system_prompt_file: Path | None) -> str:
"""The persona text: a --system-prompt-file (if given) overrides --system-prompt."""
if system_prompt_file is None:
return system_prompt
try:
return system_prompt_file.read_text(encoding="utf-8")
except OSError as exc:
raise CLIError(
f"Could not read --system-prompt-file {system_prompt_file}: {exc}",
error_type="file_not_found",
exit_code=2,
suggestion="Check the path and that the file is readable.",
) from exc


def _open_audio(
renderer: AgentRenderer,
*,
source: str | None,
sample: bool,
device: int | None,
from_file: bool,
) -> tuple[Any, Any]:
"""Build the (mic, player) pair for either file-driven or live-mic input."""
if from_file:
# Stream the clip as the user's speech and stop after the agent replies.
# No greeting and full-duplex so no part of the clip is muted/dropped,
# and a NullPlayer since there is no listener for the reply audio.
return FileSource(client.resolve_audio_source(source, sample=sample)), NullPlayer()
# One full-duplex stream for mic + speaker: macOS rejects two separate
# streams on a device, which silently kills capture.
duplex = DuplexAudio(target_rate=SAMPLE_RATE, device=device)
# notice() self-suppresses in JSON mode and routes to stderr in text mode.
renderer.notice(
"Use headphones — the mic stays open while the agent speaks, "
"so speakers would let it hear itself.\n"
)
return duplex.mic, duplex.player


@app.command(
rich_help_panel=help_panels.TRANSCRIPTION,
epilog=examples_epilog(
Expand Down Expand Up @@ -83,18 +128,7 @@ def body(state: AppState, json_mode: bool) -> None:
f"Unknown voice {voice!r}.",
suggestion="Run 'aai agent --list-voices' to see the options.",
)
if system_prompt_file is not None:
try:
system_prompt_text = system_prompt_file.read_text(encoding="utf-8")
except OSError as exc:
raise CLIError(
f"Could not read --system-prompt-file {system_prompt_file}: {exc}",
error_type="file_not_found",
exit_code=2,
suggestion="Check the path and that the file is readable.",
) from exc
else:
system_prompt_text = system_prompt
system_prompt_text = _resolve_system_prompt(system_prompt, system_prompt_file)

if show_code:
# Print-only: emit the equivalent agent script from the flags and exit
Expand All @@ -112,37 +146,18 @@ def body(state: AppState, json_mode: bool) -> None:
text_mode=text_mode,
mic_input=not from_file,
)
audio: Any
player: Any
if from_file:
# Stream the clip as the user's speech and stop after the agent replies.
# No greeting and full-duplex so no part of the clip is muted/dropped,
# and a NullPlayer since there is no listener for the reply audio.
audio = FileSource(client.resolve_audio_source(source, sample=sample))
player = NullPlayer()
else:
# One full-duplex stream for mic + speaker: macOS rejects two separate
# streams on a device, which silently kills capture.
duplex = DuplexAudio(target_rate=SAMPLE_RATE, device=device)
audio = duplex.mic
player = duplex.player
# notice() self-suppresses in JSON mode and routes to stderr in text mode.
renderer.notice(
"Use headphones — the mic stays open while the agent speaks, "
"so speakers would let it hear itself.\n"
)
audio, player = _open_audio(
renderer, source=source, sample=sample, device=device, from_file=from_file
)
run_config = AgentRunConfig(
voice=voice,
system_prompt=system_prompt_text,
greeting="" if from_file else greeting,
full_duplex=True, # one duplex stream -> mic always open (use headphones)
exit_after_reply=from_file,
)
try:
run_session(
api_key,
renderer=renderer,
player=player,
mic=audio,
voice=voice,
system_prompt=system_prompt_text,
greeting="" if from_file else greeting,
full_duplex=True, # one duplex stream -> mic always open (use headphones)
exit_after_reply=from_file,
)
run_session(api_key, renderer=renderer, player=player, mic=audio, config=run_config)
except KeyboardInterrupt:
renderer.stopped()
except BrokenPipeError as exc:
Expand Down
Loading
Loading