diff --git a/aai_cli/argscan.py b/aai_cli/argscan.py index 28349c72..ea101470 100644 --- a/aai_cli/argscan.py +++ b/aai_cli/argscan.py @@ -20,3 +20,8 @@ def requests_json(raw_args: list[str]) -> bool: if token in ("-o", "--output") and raw_args[index + 1 : index + 2] == ["json"]: return True return False + + +def requests_quiet(raw_args: list[str]) -> bool: + """Whether the token list asked for quiet output: ``--quiet`` or ``-q``.""" + return any(token in ("--quiet", "-q") for token in raw_args) diff --git a/aai_cli/auth/ams.py b/aai_cli/auth/ams.py index 19be95ea..69681fab 100644 --- a/aai_cli/auth/ams.py +++ b/aai_cli/auth/ams.py @@ -37,7 +37,16 @@ def _raise_for_error(resp: httpx.Response) -> None: def _json_or_raise(resp: httpx.Response) -> object: _raise_for_error(resp) - data: object = resp.json() + try: + data: object = resp.json() + except ValueError as exc: + # A 2xx with an unparseable body (proxy interference, truncation) must + # surface as a clean AMS error, not escape as a raw JSONDecodeError that + # run_command can only report as an internal bug. + raise APIError( + f"AMS returned a response that is not valid JSON (HTTP {resp.status_code}).", + suggestion="Check your network and try again; if it persists, contact support.", + ) from exc return data diff --git a/aai_cli/auth/loopback.py b/aai_cli/auth/loopback.py index 2cf792f8..749a87e7 100644 --- a/aai_cli/auth/loopback.py +++ b/aai_cli/auth/loopback.py @@ -43,6 +43,7 @@ class CallbackCapture: done: threading.Event server: HTTPServer thread: threading.Thread + lock: threading.Lock def wait( self, @@ -55,7 +56,15 @@ def wait( """ try: if not self.done.wait(timeout): - self.result.error = "timeout" + # Claim the capture under the lock: the handler thread may be + # processing a callback that arrived right at the deadline. If it + # already claimed (done set), its token result stands; otherwise the + # timeout claims it, and a late callback can no longer mutate the + # result this method is about to hand to the caller. + with self.lock: + if not self.done.is_set(): + self.result.error = "timeout" + self.done.set() finally: self.server.shutdown() # stop serve_forever() self.thread.join(timeout=5) # pragma: no mutate (cleanup grace period only) @@ -71,9 +80,13 @@ def start_capture() -> CallbackCapture: before opening the browser. Only a callback to the registered path that carries a `token` is accepted; any other request (a different path, or no token) gets a 4xx and the server keeps waiting, so a stray request can't end the capture early. + The first matching callback wins: a duplicate (browser reload/double-click, or + anything else hitting the loopback port afterwards) is acknowledged but can never + overwrite the captured token. """ result = CallbackResult() done = threading.Event() + lock = threading.Lock() class Handler(BaseHTTPRequestHandler): def do_GET(self) -> None: # stdlib API name @@ -91,13 +104,19 @@ def do_GET(self) -> None: # stdlib API name self.send_response(400) self.end_headers() return - result.token = token - result.token_type = next(iter(qs.get("stytch_token_type", [])), None) + # First claim wins: once the capture is done (a prior callback, or the + # timeout in wait()), the result is already in the caller's hands, so a + # late or duplicate callback must not mutate it. The lock pairs with + # wait()'s timeout claim so the two threads can't interleave mid-write. + with lock: + if not done.is_set(): + result.token = token + result.token_type = next(iter(qs.get("stytch_token_type", [])), None) + done.set() self.send_response(200) self.send_header("Content-Type", "text/html") self.end_headers() self.wfile.write(_SUCCESS_HTML) - done.set() def log_message(self, format: str, *args: object) -> None: # silence stderr logging pass @@ -113,7 +132,7 @@ def log_message(self, format: str, *args: object) -> None: # silence stderr log ) from exc thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - return CallbackCapture(result=result, done=done, server=server, thread=thread) + return CallbackCapture(result=result, done=done, server=server, thread=thread, lock=lock) def capture_callback( diff --git a/aai_cli/code_gen/agent.py b/aai_cli/code_gen/agent.py index 1264aeca..ba05af80 100644 --- a/aai_cli/code_gen/agent.py +++ b/aai_cli/code_gen/agent.py @@ -47,8 +47,12 @@ def on_audio(indata, outdata, _frames, _time, _status): def send_mic(ws): while True: chunk = mic_queue.get() - if ready.is_set(): + if not ready.is_set(): + continue + try: ws.send(json.dumps({{"type": "input.audio", "audio": base64.b64encode(chunk).decode()}})) + except Exception: + return # socket closed (session over): end the mic thread quietly stream = sd.RawStream( diff --git a/aai_cli/commands/sessions.py b/aai_cli/commands/sessions.py index 266211f2..145a2971 100644 --- a/aai_cli/commands/sessions.py +++ b/aai_cli/commands/sessions.py @@ -87,11 +87,14 @@ def render(data: list[dict[str, object]]) -> object: "model", ) for s in data: + # `is None`, not truthiness: 0 is a legitimate duration (a session + # that connected but streamed no audio) and must render as "0". + duration = s.get("audio_duration_sec") table.add_row( escape(str(s["session_id"])), theme.status_text(str(s["status"])), escape(timeparse.format_utc_datetime(s.get("created_at"))), - escape(str(s.get("audio_duration_sec") or "")), + escape("" if duration is None else str(duration)), escape(str(s.get("speech_model") or "")), ) return table diff --git a/aai_cli/context.py b/aai_cli/context.py index 798a7c6d..1b70af82 100644 --- a/aai_cli/context.py +++ b/aai_cli/context.py @@ -192,7 +192,9 @@ def _auto_login_and_exit(state: AppState, *, json_mode: bool) -> NoReturn: except CLIError as login_err: output.emit_error(login_err, json_mode=json_mode) raise typer.Exit(code=login_err.exit_code) from None - except (OSError, RuntimeError, keyring.errors.KeyringError) as exc: + except (OSError, RuntimeError, TypeError, keyring.errors.KeyringError) as exc: + # TypeError covers a value the TOML writer can't serialize: the login itself + # succeeded, so the user must see "could not save", not "unexpected error". persistence_err = _login_persistence_error(exc) output.emit_error(persistence_err, json_mode=json_mode) raise typer.Exit(code=persistence_err.exit_code) from None diff --git a/aai_cli/streaming/diagnostics.py b/aai_cli/streaming/diagnostics.py index bb5a427e..c6a99a6c 100644 --- a/aai_cli/streaming/diagnostics.py +++ b/aai_cli/streaming/diagnostics.py @@ -22,8 +22,6 @@ # CLIError already carries the message, so the logger is raised above ERROR. SDK_STREAMING_LOGGER = "assemblyai.streaming" -# Handshake statuses that mean the server refused the connection outright. -_HANDSHAKE_AUTH_STATUSES = (401, 403) _UNAUTHORIZED = 401 @@ -46,22 +44,6 @@ def handshake_suggestion(host: str) -> str: ) -def _handshake_status(error: object) -> int | None: - """The HTTP status of a rejected WebSocket handshake (401/403), else None. - - Reads the two structured shapes only — the assemblyai SDK's StreamingError - carries the status on ``.code``; websockets' InvalidStatus carries it on - ``.response.status_code`` — never the message text. - """ - code = getattr(error, "code", None) - if code in _HANDSHAKE_AUTH_STATUSES: - return int(code) - status = getattr(getattr(error, "response", None), "status_code", None) - if status in _HANDSHAKE_AUTH_STATUSES: - return int(status) - return None - - def handshake_error(error: object, message: str, *, host: str) -> CLIError | None: """An auth-flavored CLIError for a handshake 401/403, else None. @@ -70,7 +52,7 @@ def handshake_error(error: object, message: str, *, host: str) -> CLIError | Non stays an APIError — it also covers WAF/region/plan blocks, mirroring ``aai_cli.ws.is_rejected_key`` — but now carries the same suggestion. """ - status = _handshake_status(error) + status = wsutil.handshake_status(error) if status is None: return None if status == _UNAUTHORIZED: diff --git a/aai_cli/streaming/session.py b/aai_cli/streaming/session.py index ff01d7ac..1c028e70 100644 --- a/aai_cli/streaming/session.py +++ b/aai_cli/streaming/session.py @@ -10,7 +10,7 @@ import typer from aai_cli import choices, client, config_builder, llm, output -from aai_cli.errors import CLIError, UsageError +from aai_cli.errors import APIError, CLIError, UsageError from aai_cli.follow import FollowRenderer from aai_cli.streaming.render import StreamRenderer, speaker_prefix @@ -284,7 +284,15 @@ def _drive(self, streams: _ParallelStreams) -> None: def worker(source_label: str, audio: Iterable[bytes], rate: int) -> None: try: - self.stream_one(audio, rate, source_label=source_label) + try: + self.stream_one(audio, rate, source_label=source_label) + except (CLIError, BrokenPipeError): + raise + except Exception as exc: + # A non-CLIError here is a bug, but it must still fail the run: + # uncaught, it dies with this daemon thread and the command + # exits 0 for a stream that actually failed. + raise APIError(f"Streaming worker ({source_label}) failed: {exc}") from exc except (CLIError, BrokenPipeError) as exc: errors.put(exc) diff --git a/aai_cli/telemetry.py b/aai_cli/telemetry.py index c64bab7f..082e1e20 100644 --- a/aai_cli/telemetry.py +++ b/aai_cli/telemetry.py @@ -106,7 +106,7 @@ def _notice_suppressed(raw_args: list[str]) -> bool: ``--quiet`` run nor pollute the machine-readable stderr a ``--json`` (or ``-o json``) pipeline relies on. """ - return any(token in ("--quiet", "-q") for token in raw_args) or argscan.requests_json(raw_args) + return argscan.requests_quiet(raw_args) or argscan.requests_json(raw_args) def _maybe_emit_first_run_notice() -> None: diff --git a/aai_cli/tts/session.py b/aai_cli/tts/session.py index ebde6e43..263b79d4 100644 --- a/aai_cli/tts/session.py +++ b/aai_cli/tts/session.py @@ -20,7 +20,7 @@ class _WebSocket(Protocol): """The slice of a websockets sync connection this module drives — named as a Protocol so the untyped library boundary is structurally typed, not opaque.""" - def recv(self) -> str | bytes: ... + def recv(self, timeout: float | None = None) -> str | bytes: ... def send(self, data: str, /) -> None: ... # positional-only: matches ws send(message) def close(self) -> None: ... @@ -38,6 +38,11 @@ def close(self) -> None: ... # Pause inserted between speaker turns in a multi-voice dialogue, for natural pacing. _INTER_TURN_SILENCE_SECONDS = 0.25 +# Bound on the wait for each protocol frame. The server streams frames continuously +# while a synthesis is in flight, so a gap this long means it went silent mid-session; +# without a bound, `assembly speak` would hang forever instead of failing cleanly. +_RECV_TIMEOUT_SECONDS = 60.0 + @dataclass(frozen=True) class SpeakConfig: @@ -100,6 +105,18 @@ def _decode_audio_frame(msg: dict[str, object]) -> bytes: raise APIError(f"TTS service sent an Audio frame that is not valid base64: {exc}") from exc +def _recv_raw(ws: _WebSocket) -> str | bytes: + """One frame off the socket, with a bounded wait: a server that goes silent + mid-session (e.g. never sends the final Audio frame) must fail the command, + not hang it forever on an unbounded recv().""" + try: + return ws.recv(timeout=_RECV_TIMEOUT_SECONDS) + except TimeoutError as exc: + raise APIError( + f"TTS service stopped responding (no frame for {_RECV_TIMEOUT_SECONDS:g}s)." + ) from exc + + def _default_connect( url: str, *, additional_headers: dict[str, str], max_size: int | None ) -> _WebSocket: @@ -135,7 +152,7 @@ def _run_protocol( ws: _WebSocket, config: SpeakConfig, on_warning: Callable[[str], None] | None ) -> SpeakResult: """Send Generate + ForceFlushTextBuffer, collect Audio until is_final, then Terminate.""" - begin = json.loads(ws.recv()) + begin = json.loads(_recv_raw(ws)) if begin.get("type") != "Begin": raise APIError(f"TTS service did not start the session (got {begin.get('type')!r}).") sample_rate = int(begin.get("configuration", {}).get("sample_rate", _DEFAULT_SAMPLE_RATE)) @@ -145,7 +162,7 @@ def _run_protocol( pcm = bytearray() while True: - msg = json.loads(ws.recv()) + msg = json.loads(_recv_raw(ws)) mtype = msg.get("type") if mtype == "Audio": pcm.extend(_decode_audio_frame(msg)) diff --git a/aai_cli/ws.py b/aai_cli/ws.py index 6e5fa10a..8bf4ad8d 100644 --- a/aai_cli/ws.py +++ b/aai_cli/ws.py @@ -15,6 +15,9 @@ # covers WAF/region/plan blocks) — mirrors how `stream` classifies handshakes. _HTTP_FORBIDDEN = 403 +# Handshake statuses that mean the server refused the connection outright. +_HANDSHAKE_AUTH_STATUSES = (401, 403) + # The sync websockets client logs through these; both are silenced for a session # (the parent covers any future child logger, the client logger is the one that fires). WEBSOCKETS_LOGGERS = ("websockets", "websockets.client") @@ -33,6 +36,23 @@ def silence_websockets_logging() -> None: logging.getLogger(name).setLevel(logging.CRITICAL) +def handshake_status(exc: object) -> int | None: + """The HTTP status of a rejected WebSocket handshake (401/403), else None. + + Reads the two structured shapes only — the assemblyai SDK's StreamingError + carries the status on ``.code``; websockets' InvalidStatus carries it on + ``.response.status_code`` — never the message text. The single classifier for + every realtime path (stream, agent, speak), so 401-vs-403 handling can't drift. + """ + code = getattr(exc, "code", None) + if code in _HANDSHAKE_AUTH_STATUSES: + return int(code) + status = getattr(getattr(exc, "response", None), "status_code", None) + if status in _HANDSHAKE_AUTH_STATUSES: + return int(status) + return None + + def is_rejected_key(exc: Exception) -> bool: """Is this connect/session failure auth-shaped (the key itself was rejected)? @@ -43,8 +63,7 @@ def is_rejected_key(exc: Exception) -> bool: Agent's 1008 policy-violation close, or an explicitly auth-worded message (`is_auth_failure`'s text hints) count as a rejected key. """ - status = getattr(getattr(exc, "response", None), "status_code", None) - if status == _HTTP_FORBIDDEN: + if handshake_status(exc) == _HTTP_FORBIDDEN: return False return is_auth_failure(exc) diff --git a/tests/test_auth_ams.py b/tests/test_auth_ams.py index 3785b1be..61c5bea7 100644 --- a/tests/test_auth_ams.py +++ b/tests/test_auth_ams.py @@ -64,6 +64,21 @@ def handler(request: httpx.Request) -> httpx.Response: assert "detail" not in str(exc.value) +def test_ok_response_with_non_json_body_raises_clean_api_error(monkeypatch): + # A 2xx whose body isn't JSON (proxy interference, truncation) must surface as + # a clean AMS error, not escape as a raw JSONDecodeError that run_command can + # only report as an internal bug. + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, text="proxy intercepted") + + _patch_transport(monkeypatch, handler) + with pytest.raises(APIError) as exc: + ams.discover("x") + assert "not valid JSON" in str(exc.value) + assert "HTTP 200" in str(exc.value) + assert exc.value.suggestion is not None and "support" in exc.value.suggestion + + def test_error_with_non_json_body_falls_back_to_text(monkeypatch): def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(500, text="upstream is down") diff --git a/tests/test_auth_loopback.py b/tests/test_auth_loopback.py index 3530265d..8a5dee19 100644 --- a/tests/test_auth_loopback.py +++ b/tests/test_auth_loopback.py @@ -144,6 +144,29 @@ def test_capture_times_out_without_callback(): assert result.token is None +def test_first_callback_wins_duplicate_never_overwrites_the_token(): + # A second callback (browser reload/double-click — or anything else that can + # reach the loopback port) is acknowledged but must not replace the token the + # genuine first callback delivered. + capture = loopback.start_capture() + assert _hit("/callback?stytch_token_type=discovery_oauth&token=tok_first") == 200 + assert _hit("/callback?stytch_token_type=discovery_oauth&token=tok_second") == 200 + result = capture.wait(timeout=5.0) + assert result.token == "tok_first" + assert result.token_type == "discovery_oauth" + assert result.error is None + + +def test_timeout_claims_the_capture(): + # On timeout, wait() claims the capture (sets the done event) under the lock, so + # a callback losing the race takes the already-claimed branch and can no longer + # mutate the result the caller is about to receive. + capture = loopback.start_capture() + result = capture.wait(timeout=0.05) + assert result.error == "timeout" + assert capture.done.is_set() + + def test_capture_server_thread_is_daemon_and_joined_with_timeout(monkeypatch): # The serve_forever thread must be a daemon (so it can't block process exit) and the # cleanup join must be bounded (5s) so a wedged server can't hang shutdown. The diff --git a/tests/test_code_gen_stream_agent.py b/tests/test_code_gen_stream_agent.py index d93b94f9..7372b8d1 100644 --- a/tests/test_code_gen_stream_agent.py +++ b/tests/test_code_gen_stream_agent.py @@ -50,6 +50,16 @@ def test_agent_render_parses_and_injects_session_fields(): assert 'os.environ["ASSEMBLYAI_API_KEY"]' in code +def test_agent_render_mic_thread_ends_quietly_when_the_socket_closes(): + # The generated send_mic daemon thread blocks on ws.send(); when the session + # ends and the socket closes, that send raises. Without the guard, the thread + # would dump a traceback to stderr on every normal exit of the sample script. + code = code_gen.agent(voice="ivy", system_prompt="Be terse.", greeting="Hi") + ast.parse(code) + assert "except Exception:" in code + assert "return # socket closed (session over): end the mic thread quietly" in code + + def test_agent_render_escapes_quotes_in_prompt(): import json as _json diff --git a/tests/test_context.py b/tests/test_context.py index c9e3e5db..43c0b324 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -178,6 +178,34 @@ def body(state, json_mode): assert "Traceback" not in result.output +def test_run_command_auto_login_persistence_type_error_is_clean(monkeypatch): + # The TOML writer raises TypeError for a value it can't serialize. The sign-in + # itself succeeded, so the user must see the "could not save the credentials" + # message — not the generic "Unexpected error" internal-bug line. + _force_interactive(monkeypatch) + monkeypatch.setattr( + "aai_cli.context.run_login_flow", + lambda **_: LoginResult( + api_key="sk_auto", + session_jwt="jwt_auto", + session_token="tok_auto", + account_id=42, + ), + ) + monkeypatch.setattr( + "aai_cli.context.config.set_api_key", + lambda *_args: (_ for _ in ()).throw(TypeError("not TOML-serializable")), + ) + + def body(state, json_mode): + raise NotAuthenticated() + + result = runner.invoke(_make_app(body), ["go"]) + assert result.exit_code == 1 + assert "could not save the credentials" in result.output + assert "Unexpected error" not in result.output + + def test_run_command_auto_login_failure_is_clean(monkeypatch): _force_interactive(monkeypatch) diff --git a/tests/test_sessions_command.py b/tests/test_sessions_command.py index 16380d73..ab65df67 100644 --- a/tests/test_sessions_command.py +++ b/tests/test_sessions_command.py @@ -89,6 +89,35 @@ def test_sessions_list_renders_table_human(monkeypatch, mocker): assert "12.0" in result.output +def test_sessions_list_renders_zero_duration_as_zero(monkeypatch, mocker): + _auth() + _human(monkeypatch) + # 0 is a legitimate duration (a session that connected but streamed no audio): + # it must render as "0", not be coerced to a blank cell like a missing value. + # Neither row carries created_at, so the duration is the only digit in the table. + payload = { + "data": [ + { + "session_id": "s_one", + "status": "completed", + "audio_duration_sec": 0, + "speech_model": "universal", + }, + { + "session_id": "s_two", + "status": "completed", + "speech_model": "universal", + }, + ] + } + mocker.patch( + "aai_cli.commands.sessions.ams.list_streaming", autospec=True, return_value=payload + ) + result = runner.invoke(app, ["sessions", "list"]) + assert result.exit_code == 0 + assert "0" in result.output + + def test_sessions_list_empty_shows_human_empty_state(monkeypatch, mocker): _auth() _human(monkeypatch) diff --git a/tests/test_stream_session.py b/tests/test_stream_session.py index 740e98f1..72873d8c 100644 --- a/tests/test_stream_session.py +++ b/tests/test_stream_session.py @@ -322,6 +322,55 @@ def fake_stream_audio(api_key, source, *, params, **_kwargs): assert daemons and all(d is True for d in daemons) +def test_stream_system_audio_parallel_unexpected_worker_error_fails_the_run(monkeypatch): + # A non-CLIError bug inside a worker must still fail the run with a clean error: + # uncaught, it would die with the daemon thread and the command would exit 0 + # for a stream that actually failed. + config.set_api_key("default", "sk_live") + + class FakeSystemAudio: + def __init__(self, *, on_open=None): + self.sample_rate = 16000 + + def __iter__(self): + return iter([b"system"]) + + class FakeMic: + def __init__(self, *, target_rate=None, device=None, capture_rate=None, on_open=None): + self.sample_rate = target_rate + + def __iter__(self): + return iter([b"mic"]) + + class ImmediateThread: + def __init__(self, *, target, args, daemon): + self._target = target + self._args = args + + def start(self): + self._target(*self._args) + + def is_alive(self): + return False + + def join(self, timeout=None): + return None + + def fake_stream_audio(api_key, source, *, params, **_kwargs): + raise RuntimeError("event parsing blew up") + + monkeypatch.setattr("aai_cli.commands.stream.MacSystemAudioSource", FakeSystemAudio) + monkeypatch.setattr("aai_cli.commands.stream.MicrophoneSource", FakeMic) + monkeypatch.setattr("aai_cli.commands.stream.client.stream_audio", fake_stream_audio) + monkeypatch.setattr("aai_cli.streaming.session.threading.Thread", ImmediateThread) + result = runner.invoke(app, ["stream", "--system-audio", "--json"]) + assert result.exit_code == 1 + # Normalized to a clean worker error that names the source and the cause. + assert "Streaming worker" in result.output + assert "event parsing blew up" in result.output + assert "Traceback" not in result.output + + def test_stream_system_audio_parallel_keyboard_interrupt_exits_cleanly(monkeypatch): config.set_api_key("default", "sk_live") monkeypatch.setattr("aai_cli.output.resolve_json", lambda *, explicit: False) diff --git a/tests/test_tts_session.py b/tests/test_tts_session.py index 5ac6621e..123e4439 100644 --- a/tests/test_tts_session.py +++ b/tests/test_tts_session.py @@ -71,8 +71,10 @@ def __init__(self, incoming: list[str]) -> None: self._incoming = list(incoming) self.sent: list[dict[str, object]] = [] self.closed = False + self.recv_timeouts: list[float | None] = [] - def recv(self) -> str: + def recv(self, timeout: float | None = None) -> str: + self.recv_timeouts.append(timeout) return self._incoming.pop(0) def send(self, data: str) -> None: @@ -202,6 +204,38 @@ def test_synthesize_maps_error_frame_to_api_error(): session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) +def test_synthesize_bounds_every_recv_with_the_frame_timeout(): + # Each frame wait is bounded (60s): an unbounded recv() would hang the command + # forever if the server went silent mid-session. + ws = FakeWS([_begin_frame(), _audio_frame(b"\x01\x02", final=True)]) + session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) + assert ws.recv_timeouts == [60.0, 60.0] + + +def test_synthesize_maps_silent_server_to_clean_api_error(): + # websockets' sync recv raises TimeoutError when the bound expires; that must + # surface as a clean APIError naming the stall, not a hang or a raw traceback. + class SilentWS(FakeWS): + def recv(self, timeout: float | None = None) -> str: + raise TimeoutError + + ws = SilentWS([]) + with pytest.raises(APIError, match="stopped responding"): + session.synthesize("k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: ws) + assert ws.closed is True + + +def test_synthesize_silent_server_error_names_the_timeout_window(): + class SilentWS(FakeWS): + def recv(self, timeout: float | None = None) -> str: + raise TimeoutError + + with pytest.raises(APIError, match="no frame for 60s"): + session.synthesize( + "k", session.SpeakConfig(text="hi"), connect=lambda *a, **k: SilentWS([]) + ) + + def test_synthesize_invokes_on_warning_then_continues(): seen: list[str] = [] ws = FakeWS( diff --git a/tests/test_ws.py b/tests/test_ws.py index 3c923cb3..223bc273 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -13,6 +13,7 @@ from aai_cli.ws import ( WEBSOCKETS_LOGGERS, auth_or_api_error, + handshake_status, is_rejected_key, silence_websockets_logging, ) @@ -26,11 +27,38 @@ def __init__(self, status: int) -> None: self.response = types.SimpleNamespace(status_code=status) +class _SdkHandshakeRejected(Exception): + """Mimics the assemblyai SDK's StreamingError: the HTTP status on ``.code``. + + The message deliberately contains an auth hint ("forbidden") so the tests can + pin that the structured 403 veto wins over the text heuristic. + """ + + def __init__(self, status: int) -> None: + super().__init__(f"WebSocket handshake rejected: forbidden (HTTP {status})") + self.code = status + + def test_is_rejected_key_false_for_handshake_403(): # 403 also covers WAF/region/plan blocks, so it must NOT read as a rejected key. assert is_rejected_key(_HandshakeRejected(403)) is False +def test_is_rejected_key_false_for_sdk_handshake_403(): + # The SDK shape (status on ``.code``) follows the same 403-is-not-auth rule as + # the websockets shape; without the structured veto, the auth-worded message + # ("forbidden") would misclassify this as a rejected key. + assert is_rejected_key(_SdkHandshakeRejected(403)) is False + + +def test_handshake_status_reads_both_structured_shapes(): + assert handshake_status(_SdkHandshakeRejected(401)) == 401 + assert handshake_status(_HandshakeRejected(403)) == 403 + assert handshake_status(RuntimeError("network unreachable")) is None + # A WebSocket close code (e.g. 1008 policy violation) is not a handshake status. + assert handshake_status(types.SimpleNamespace(code=1008)) is None + + def test_is_rejected_key_true_for_handshake_401(): assert is_rejected_key(_HandshakeRejected(401)) is True