Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions aai_cli/argscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ def requests_json(raw_args: list[str]) -> bool:
if token in ("-o", "--output") and raw_args[index + 1 : index + 2] == ["json"]:
return True
return False


def requests_quiet(raw_args: list[str]) -> bool:
    """Report whether quiet output was requested via ``--quiet`` or ``-q``.

    Only whole-token matches count: each entry of ``raw_args`` is compared
    against the two flag spellings exactly as given.
    """
    quiet_flags = ("--quiet", "-q")
    for arg in raw_args:
        if arg in quiet_flags:
            return True
    return False
11 changes: 10 additions & 1 deletion aai_cli/auth/ams.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,16 @@ def _raise_for_error(resp: httpx.Response) -> None:

def _json_or_raise(resp: httpx.Response) -> object:
    """Return the decoded JSON body of ``resp``, or raise a clean ``APIError``.

    ``_raise_for_error`` is applied to the response first. The body is then
    parsed exactly once, inside the ``try``, so a decode failure can never
    escape before the handler below has a chance to translate it.

    Raises:
        APIError: when the body cannot be parsed as JSON.
    """
    _raise_for_error(resp)
    try:
        data: object = resp.json()
    except ValueError as exc:
        # A 2xx with an unparseable body (proxy interference, truncation) must
        # surface as a clean AMS error, not escape as a raw JSONDecodeError that
        # run_command can only report as an internal bug.
        raise APIError(
            f"AMS returned a response that is not valid JSON (HTTP {resp.status_code}).",
            suggestion="Check your network and try again; if it persists, contact support.",
        ) from exc
    return data


Expand Down
29 changes: 24 additions & 5 deletions aai_cli/auth/loopback.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class CallbackCapture:
done: threading.Event
server: HTTPServer
thread: threading.Thread
lock: threading.Lock

def wait(
self,
Expand All @@ -55,7 +56,15 @@ def wait(
"""
try:
if not self.done.wait(timeout):
self.result.error = "timeout"
# Claim the capture under the lock: the handler thread may be
# processing a callback that arrived right at the deadline. If it
# already claimed (done set), its token result stands; otherwise the
# timeout claims it, and a late callback can no longer mutate the
# result this method is about to hand to the caller.
with self.lock:
if not self.done.is_set():
self.result.error = "timeout"
self.done.set()
finally:
self.server.shutdown() # stop serve_forever()
self.thread.join(timeout=5) # pragma: no mutate (cleanup grace period only)
Expand All @@ -71,9 +80,13 @@ def start_capture() -> CallbackCapture:
before opening the browser. Only a callback to the registered path that carries
a `token` is accepted; any other request (a different path, or no token) gets a
4xx and the server keeps waiting, so a stray request can't end the capture early.
The first matching callback wins: a duplicate (browser reload/double-click, or
anything else hitting the loopback port afterwards) is acknowledged but can never
overwrite the captured token.
"""
result = CallbackResult()
done = threading.Event()
lock = threading.Lock()

class Handler(BaseHTTPRequestHandler):
def do_GET(self) -> None: # stdlib API name
Expand All @@ -91,13 +104,19 @@ def do_GET(self) -> None: # stdlib API name
self.send_response(400)
self.end_headers()
return
result.token = token
result.token_type = next(iter(qs.get("stytch_token_type", [])), None)
# First claim wins: once the capture is done (a prior callback, or the
# timeout in wait()), the result is already in the caller's hands, so a
# late or duplicate callback must not mutate it. The lock pairs with
# wait()'s timeout claim so the two threads can't interleave mid-write.
with lock:
if not done.is_set():
result.token = token
result.token_type = next(iter(qs.get("stytch_token_type", [])), None)
done.set()
self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
self.wfile.write(_SUCCESS_HTML)
done.set()

def log_message(self, format: str, *args: object) -> None: # silence stderr logging
pass
Expand All @@ -113,7 +132,7 @@ def log_message(self, format: str, *args: object) -> None: # silence stderr log
) from exc
thread = threading.Thread(target=server.serve_forever, daemon=True)
thread.start()
return CallbackCapture(result=result, done=done, server=server, thread=thread)
return CallbackCapture(result=result, done=done, server=server, thread=thread, lock=lock)


def capture_callback(
Expand Down
6 changes: 5 additions & 1 deletion aai_cli/code_gen/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ def on_audio(indata, outdata, _frames, _time, _status):
def send_mic(ws):
while True:
chunk = mic_queue.get()
if ready.is_set():
if not ready.is_set():
continue
try:
ws.send(json.dumps({{"type": "input.audio", "audio": base64.b64encode(chunk).decode()}}))
except Exception:
return # socket closed (session over): end the mic thread quietly


stream = sd.RawStream(
Expand Down
5 changes: 4 additions & 1 deletion aai_cli/commands/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,14 @@ def render(data: list[dict[str, object]]) -> object:
"model",
)
for s in data:
# `is None`, not truthiness: 0 is a legitimate duration (a session
# that connected but streamed no audio) and must render as "0".
duration = s.get("audio_duration_sec")
table.add_row(
escape(str(s["session_id"])),
theme.status_text(str(s["status"])),
escape(timeparse.format_utc_datetime(s.get("created_at"))),
escape(str(s.get("audio_duration_sec") or "")),
escape("" if duration is None else str(duration)),
escape(str(s.get("speech_model") or "")),
)
return table
Expand Down
4 changes: 3 additions & 1 deletion aai_cli/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def _auto_login_and_exit(state: AppState, *, json_mode: bool) -> NoReturn:
except CLIError as login_err:
output.emit_error(login_err, json_mode=json_mode)
raise typer.Exit(code=login_err.exit_code) from None
except (OSError, RuntimeError, keyring.errors.KeyringError) as exc:
except (OSError, RuntimeError, TypeError, keyring.errors.KeyringError) as exc:
# TypeError covers a value the TOML writer can't serialize: the login itself
# succeeded, so the user must see "could not save", not "unexpected error".
persistence_err = _login_persistence_error(exc)
output.emit_error(persistence_err, json_mode=json_mode)
raise typer.Exit(code=persistence_err.exit_code) from None
Expand Down
20 changes: 1 addition & 19 deletions aai_cli/streaming/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
# CLIError already carries the message, so the logger is raised above ERROR.
SDK_STREAMING_LOGGER = "assemblyai.streaming"

# Handshake statuses that mean the server refused the connection outright.
_HANDSHAKE_AUTH_STATUSES = (401, 403)
_UNAUTHORIZED = 401


Expand All @@ -46,22 +44,6 @@ def handshake_suggestion(host: str) -> str:
)


def _handshake_status(error: object) -> int | None:
    """Return the 401/403 status of a rejected WebSocket handshake, else None.

    Only structured attributes are consulted, never the message text:
    ``.code`` (the shape the assemblyai SDK's StreamingError uses) is checked
    first, then ``.response.status_code`` (the shape websockets' InvalidStatus
    uses).
    """
    sdk_code = getattr(error, "code", None)
    if sdk_code in _HANDSHAKE_AUTH_STATUSES:
        return int(sdk_code)

    response = getattr(error, "response", None)
    http_status = getattr(response, "status_code", None)
    if http_status in _HANDSHAKE_AUTH_STATUSES:
        return int(http_status)

    return None


def handshake_error(error: object, message: str, *, host: str) -> CLIError | None:
"""An auth-flavored CLIError for a handshake 401/403, else None.

Expand All @@ -70,7 +52,7 @@ def handshake_error(error: object, message: str, *, host: str) -> CLIError | Non
stays an APIError — it also covers WAF/region/plan blocks, mirroring
``aai_cli.ws.is_rejected_key`` — but now carries the same suggestion.
"""
status = _handshake_status(error)
status = wsutil.handshake_status(error)
if status is None:
return None
if status == _UNAUTHORIZED:
Expand Down
12 changes: 10 additions & 2 deletions aai_cli/streaming/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import typer

from aai_cli import choices, client, config_builder, llm, output
from aai_cli.errors import CLIError, UsageError
from aai_cli.errors import APIError, CLIError, UsageError
from aai_cli.follow import FollowRenderer
from aai_cli.streaming.render import StreamRenderer, speaker_prefix

Expand Down Expand Up @@ -284,7 +284,15 @@ def _drive(self, streams: _ParallelStreams) -> None:

def worker(source_label: str, audio: Iterable[bytes], rate: int) -> None:
try:
self.stream_one(audio, rate, source_label=source_label)
try:
self.stream_one(audio, rate, source_label=source_label)
except (CLIError, BrokenPipeError):
raise
except Exception as exc:
# A non-CLIError here is a bug, but it must still fail the run:
# uncaught, it dies with this daemon thread and the command
# exits 0 for a stream that actually failed.
raise APIError(f"Streaming worker ({source_label}) failed: {exc}") from exc
except (CLIError, BrokenPipeError) as exc:
errors.put(exc)

Expand Down
2 changes: 1 addition & 1 deletion aai_cli/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _notice_suppressed(raw_args: list[str]) -> bool:
``--quiet`` run nor pollute the machine-readable stderr a ``--json`` (or
``-o json``) pipeline relies on.
"""
return any(token in ("--quiet", "-q") for token in raw_args) or argscan.requests_json(raw_args)
return argscan.requests_quiet(raw_args) or argscan.requests_json(raw_args)


def _maybe_emit_first_run_notice() -> None:
Expand Down
23 changes: 20 additions & 3 deletions aai_cli/tts/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class _WebSocket(Protocol):
"""The slice of a websockets sync connection this module drives — named as a
Protocol so the untyped library boundary is structurally typed, not opaque."""

def recv(self) -> str | bytes: ...
def recv(self, timeout: float | None = None) -> str | bytes: ...
def send(self, data: str, /) -> None: ... # positional-only: matches ws send(message)
def close(self) -> None: ...

Expand All @@ -38,6 +38,11 @@ def close(self) -> None: ...
# Pause inserted between speaker turns in a multi-voice dialogue, for natural pacing.
_INTER_TURN_SILENCE_SECONDS = 0.25

# Bound on the wait for each protocol frame. The server streams frames continuously
# while a synthesis is in flight, so a gap this long means it went silent mid-session;
# without a bound, `assembly speak` would hang forever instead of failing cleanly.
_RECV_TIMEOUT_SECONDS = 60.0


@dataclass(frozen=True)
class SpeakConfig:
Expand Down Expand Up @@ -100,6 +105,18 @@ def _decode_audio_frame(msg: dict[str, object]) -> bytes:
raise APIError(f"TTS service sent an Audio frame that is not valid base64: {exc}") from exc


def _recv_raw(ws: _WebSocket) -> str | bytes:
    """Read a single frame from ``ws``, waiting at most ``_RECV_TIMEOUT_SECONDS``.

    The bound exists so that a server which goes silent mid-session (for
    example, never sending the final Audio frame) fails the command instead of
    leaving it blocked forever in ``recv()``.

    Raises:
        APIError: when no frame arrives before the timeout elapses.
    """
    try:
        frame = ws.recv(timeout=_RECV_TIMEOUT_SECONDS)
    except TimeoutError as exc:
        message = f"TTS service stopped responding (no frame for {_RECV_TIMEOUT_SECONDS:g}s)."
        raise APIError(message) from exc
    return frame


def _default_connect(
url: str, *, additional_headers: dict[str, str], max_size: int | None
) -> _WebSocket:
Expand Down Expand Up @@ -135,7 +152,7 @@ def _run_protocol(
ws: _WebSocket, config: SpeakConfig, on_warning: Callable[[str], None] | None
) -> SpeakResult:
"""Send Generate + ForceFlushTextBuffer, collect Audio until is_final, then Terminate."""
begin = json.loads(ws.recv())
begin = json.loads(_recv_raw(ws))
if begin.get("type") != "Begin":
raise APIError(f"TTS service did not start the session (got {begin.get('type')!r}).")
sample_rate = int(begin.get("configuration", {}).get("sample_rate", _DEFAULT_SAMPLE_RATE))
Expand All @@ -145,7 +162,7 @@ def _run_protocol(

pcm = bytearray()
while True:
msg = json.loads(ws.recv())
msg = json.loads(_recv_raw(ws))
mtype = msg.get("type")
if mtype == "Audio":
pcm.extend(_decode_audio_frame(msg))
Expand Down
23 changes: 21 additions & 2 deletions aai_cli/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# covers WAF/region/plan blocks) — mirrors how `stream` classifies handshakes.
_HTTP_FORBIDDEN = 403

# Handshake statuses that mean the server refused the connection outright.
_HANDSHAKE_AUTH_STATUSES = (401, 403)

# The sync websockets client logs through these; both are silenced for a session
# (the parent covers any future child logger, the client logger is the one that fires).
WEBSOCKETS_LOGGERS = ("websockets", "websockets.client")
Expand All @@ -33,6 +36,23 @@ def silence_websockets_logging() -> None:
logging.getLogger(name).setLevel(logging.CRITICAL)


def handshake_status(exc: object) -> int | None:
    """Return the HTTP status (401 or 403) of a rejected handshake, else None.

    Two structured shapes are recognised, and the message text is never
    inspected: the assemblyai SDK's StreamingError exposes the status on
    ``.code``, while websockets' InvalidStatus exposes it on
    ``.response.status_code``. Every realtime path (stream, agent, speak)
    classifies through this one function so 401-vs-403 handling cannot drift.
    """
    # SDK shape: the status sits directly on the exception.
    sdk_code = getattr(exc, "code", None)
    if sdk_code in _HANDSHAKE_AUTH_STATUSES:
        return int(sdk_code)

    # websockets shape: the status is nested under the attached response.
    response = getattr(exc, "response", None)
    http_status = getattr(response, "status_code", None)
    return int(http_status) if http_status in _HANDSHAKE_AUTH_STATUSES else None


def is_rejected_key(exc: Exception) -> bool:
"""Is this connect/session failure auth-shaped (the key itself was rejected)?

Expand All @@ -43,8 +63,7 @@ def is_rejected_key(exc: Exception) -> bool:
Agent's 1008 policy-violation close, or an explicitly auth-worded message
(`is_auth_failure`'s text hints) count as a rejected key.
"""
status = getattr(getattr(exc, "response", None), "status_code", None)
if status == _HTTP_FORBIDDEN:
if handshake_status(exc) == _HTTP_FORBIDDEN:
return False
return is_auth_failure(exc)

Expand Down
15 changes: 15 additions & 0 deletions tests/test_auth_ams.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ def handler(request: httpx.Request) -> httpx.Response:
assert "detail" not in str(exc.value)


def test_ok_response_with_non_json_body_raises_clean_api_error(monkeypatch):
    """A 2xx reply whose body is not JSON must raise a clean ``APIError``.

    Proxy interference or truncation can produce such a body; it must not
    escape as a raw JSONDecodeError that run_command can only report as an
    internal bug.
    """

    def handler(request: httpx.Request) -> httpx.Response:
        return httpx.Response(200, text="<html>proxy intercepted</html>")

    _patch_transport(monkeypatch, handler)
    with pytest.raises(APIError) as caught:
        ams.discover("x")

    error = caught.value
    message = str(error)
    assert "not valid JSON" in message
    assert "HTTP 200" in message
    assert error.suggestion is not None and "support" in error.suggestion


def test_error_with_non_json_body_falls_back_to_text(monkeypatch):
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(500, text="upstream is down")
Expand Down
23 changes: 23 additions & 0 deletions tests/test_auth_loopback.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,29 @@ def test_capture_times_out_without_callback():
assert result.token is None


def test_first_callback_wins_duplicate_never_overwrites_the_token():
    """The first callback's token stands; a later duplicate is only acknowledged.

    The second hit models a browser reload/double-click, or anything else that
    can reach the loopback port. It still gets a 200 but must not replace the
    token the genuine first callback delivered.
    """
    capture = loopback.start_capture()
    callbacks = (
        "/callback?stytch_token_type=discovery_oauth&token=tok_first",
        "/callback?stytch_token_type=discovery_oauth&token=tok_second",
    )
    for path in callbacks:
        assert _hit(path) == 200

    outcome = capture.wait(timeout=5.0)
    assert outcome.token == "tok_first"
    assert outcome.token_type == "discovery_oauth"
    assert outcome.error is None


def test_timeout_claims_the_capture():
    """A timed-out ``wait()`` reports "timeout" and leaves the done event set.

    With the event set by the timeout path, a callback that loses the race
    takes the already-claimed branch and can no longer mutate the result the
    caller is about to receive.
    """
    capture = loopback.start_capture()
    outcome = capture.wait(timeout=0.05)
    assert outcome.error == "timeout"
    assert capture.done.is_set()


def test_capture_server_thread_is_daemon_and_joined_with_timeout(monkeypatch):
# The serve_forever thread must be a daemon (so it can't block process exit) and the
# cleanup join must be bounded (5s) so a wedged server can't hang shutdown. The
Expand Down
10 changes: 10 additions & 0 deletions tests/test_code_gen_stream_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ def test_agent_render_parses_and_injects_session_fields():
assert 'os.environ["ASSEMBLYAI_API_KEY"]' in code


def test_agent_render_mic_thread_ends_quietly_when_the_socket_closes():
    """The generated ``send_mic`` loop must guard ``ws.send()`` against a closed socket.

    The generated daemon thread blocks on ws.send(); once the session ends and
    the socket closes, that send raises. Without the guard the thread would
    dump a traceback to stderr on every normal exit of the sample script.
    """
    rendered = code_gen.agent(voice="ivy", system_prompt="Be terse.", greeting="Hi")
    ast.parse(rendered)  # the sample must still be valid Python
    expected_snippets = (
        "except Exception:",
        "return # socket closed (session over): end the mic thread quietly",
    )
    for snippet in expected_snippets:
        assert snippet in rendered


def test_agent_render_escapes_quotes_in_prompt():
import json as _json

Expand Down
Loading
Loading