Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions aai_cli/agent/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,20 +205,6 @@ def _send_audio_loop(ws: _WebSocket, session: VoiceAgentSession, mic: _IO) -> No
return


def _open_ws(connect: _Connect, api_key: str) -> _WebSocket:
"""Open the Voice Agent socket, mapping a connect failure to a clean CLIError.

A rejected handshake (HTTP 401/403) gets the shared actionable suggestion
(whoami / environment / network); anything else keeps the wsutil mapping.
"""
try:
return connect(ws_url(), additional_headers={"Authorization": f"Bearer {api_key}"})
except Exception as exc:
raise diagnostics.classify_error(
exc, "Could not connect to the voice agent", host=environments.active().agents_host
) from exc


def _session_update_message(config: AgentRunConfig) -> str:
"""The initial session.update payload as a JSON string: persona, greeting, voice."""
return json.dumps(
Expand Down Expand Up @@ -270,7 +256,13 @@ def run_session(
ready_event=ready_event,
)

ws = _open_ws(connect, api_key)
ws = diagnostics.open_authorized_ws(
connect,
api_key,
ws_url(),
message="Could not connect to the voice agent",
host=environments.active().agents_host,
)

# The mic opens lazily on first iteration, inside the capture thread; a failure
# there (no device, sounddevice missing) must reach the user instead of vanishing
Expand Down
24 changes: 24 additions & 0 deletions aai_cli/streaming/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from __future__ import annotations

import logging
from collections.abc import Callable

from aai_cli import ws as wsutil
from aai_cli.errors import APIError, CLIError, NotAuthenticated
Expand Down Expand Up @@ -77,3 +78,26 @@ def classify_error(error: object, message: str, *, host: str) -> CLIError:
if rejected is not None:
return rejected
return wsutil.auth_or_api_error(error, message)


def open_authorized_ws[T](
connect: Callable[..., T],
api_key: str,
url: str,
*,
message: str,
host: str,
**connect_kwargs: object,
) -> T:
"""Open a Bearer-authorized WebSocket, mapping a connect failure via ``classify_error``.

The one connect path for the raw-websocket sessions (agent, speak), so a
rejected handshake (HTTP 401/403) carries the same actionable suggestion in
both and everything else keeps the shared classification.
"""
try:
return connect(
url, additional_headers={"Authorization": f"Bearer {api_key}"}, **connect_kwargs
)
except Exception as exc:
raise classify_error(exc, message, host=host) from exc
29 changes: 8 additions & 21 deletions aai_cli/tts/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,26 +133,6 @@ def _default_connect(
return connect(url, additional_headers=additional_headers, max_size=max_size)


def _open_ws(connect: _Connect, api_key: str, url: str) -> _WebSocket:
"""Open the TTS socket, mapping a connect failure to a clean CLIError.

A rejected handshake (HTTP 401/403) gets the shared actionable suggestion
(whoami / environment / network); anything else keeps the wsutil mapping.
"""
try:
return connect(
url,
additional_headers={"Authorization": f"Bearer {api_key}"},
max_size=None,
)
except Exception as exc:
raise diagnostics.classify_error(
exc,
"Could not connect to the TTS service",
host=environments.active().streaming_tts_host,
) from exc


def _run_protocol(
ws: _WebSocket, config: SpeakConfig, on_warning: Callable[[str], None] | None
) -> SpeakResult:
Expand Down Expand Up @@ -202,7 +182,14 @@ def synthesize(
if connect is None:
connect = _default_connect

ws = _open_ws(connect, api_key, ws_url(config.query_params()))
ws = diagnostics.open_authorized_ws(
connect,
api_key,
ws_url(config.query_params()),
message="Could not connect to the TTS service",
host=environments.active().streaming_tts_host,
max_size=None, # no frame cap: a synthesis's Audio frames can exceed the 1 MiB default
)
try:
return _run_protocol(ws, config, on_warning)
except (CLIError, KeyboardInterrupt, BrokenPipeError):
Expand Down
10 changes: 8 additions & 2 deletions tests/test_tts_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def test_synthesize_drives_the_full_protocol():
assert result.audio_duration_seconds == pytest.approx(6 / 2 / 24000)
# Auth header carries the key as a Bearer token.
assert captured["kwargs"]["additional_headers"]["Authorization"] == "Bearer k"
# The frame-size cap is lifted: Audio frames can exceed websockets' 1 MiB default.
assert captured["kwargs"]["max_size"] is None
assert ws.closed is True


Expand Down Expand Up @@ -280,6 +282,8 @@ def _connect(*_a, **_k):


def test_synthesize_maps_forbidden_connect_to_api_error():
_use_env("sandbox000")

class Resp:
status_code = 403

Expand All @@ -292,10 +296,11 @@ def _connect(*_a, **_k):
with pytest.raises(APIError) as exc:
session.synthesize("k", session.SpeakConfig(text="hi"), connect=_connect)
assert "Could not connect to the TTS service" in exc.value.message
# The rejected handshake carries the actionable next steps.
# The rejected handshake carries the actionable next steps, env host included.
assert exc.value.suggestion is not None
assert "assembly whoami" in exc.value.suggestion
assert "--sandbox" in exc.value.suggestion
assert "streaming-tts.sandbox000" in exc.value.suggestion


def test_synthesize_handshake_401_is_not_authenticated_with_suggestion():
Expand Down Expand Up @@ -340,7 +345,8 @@ def test_synthesize_maps_unexpected_protocol_error_to_api_error():
def test_synthesize_without_connect_uses_real_client_and_fails_cleanly():
# No `connect` provided: synthesize imports websockets' real sync client and
# attempts a connection. pytest-socket blocks socket creation, so this must
# surface as a clean CLIError (mapped in _open_ws), never a raw socket error.
# surface as a clean CLIError (mapped in diagnostics.open_authorized_ws),
# never a raw socket error.
_use_env("sandbox000")
with pytest.raises(CLIError):
session.synthesize("k", session.SpeakConfig(text="hi"))
Expand Down
Loading