Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions scripts/mutation_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from __future__ import annotations

import ast
import contextlib
import importlib.util
import re
import subprocess
import sys
Expand Down Expand Up @@ -235,15 +237,32 @@ def _run_tests(nodeids: list[str]) -> bool:
return proc.returncode != 0


def _invalidate_bytecode(path: Path) -> None:
"""Drop the module's cached ``.pyc`` so the test subprocess recompiles from the
source we just wrote.

Consecutive mutants ``ast.unparse`` to files that differ by a single token, so
they're usually byte-for-byte the same length and can be written within the same
mtime-second. CPython's default timestamp-based cache validates a ``.pyc`` by
exact (mtime, size) match, so without this it can serve the previous mutant's
(or the original's) bytecode and run *unmutated* code — a false survivor.
"""
cached = importlib.util.cache_from_source(str(path))
with contextlib.suppress(OSError):
Path(cached).unlink()


def _survives(
path: Path, tree: ast.Module, src: str, mutant: _Mutant, data: coverage.CoverageData
) -> bool:
mutant.apply()
try:
path.write_text(ast.unparse(tree), encoding="utf-8")
_invalidate_bytecode(path)
killed = _run_tests(_covering_tests(data, path, mutant.linenos))
finally:
path.write_text(src, encoding="utf-8")
_invalidate_bytecode(path)
mutant.undo()
return not killed

Expand Down
20 changes: 20 additions & 0 deletions tests/test_account_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def fake_usage(jwt, start, end, window):
# (AMS rejects naive datetimes with a 400).
for bound in (captured["start"], captured["end"]):
assert bound.endswith("+00:00") and "T" in bound, bound
# The default range spans exactly the last 30 days (pins `today - timedelta(days=30)`).
from datetime import datetime as _dt

start_day = _dt.fromisoformat(captured["start"]).date()
end_day = _dt.fromisoformat(captured["end"]).date()
assert (end_day - start_day).days == 30
data = json.loads(result.output)
assert data["usage_items"][0]["total"] == 12.5

Expand Down Expand Up @@ -117,6 +123,20 @@ def test_usage_helpers_format_windows_and_line_items():
)
== "2026-01-01 to 2026-01-03"
)
# Exactly one parseable bound falls back to the single start-day label (pins the
# `start is None or end is None` guard; an `and` would dereference the None end).
assert account._window_label({"start_timestamp": "2026-01-01T00:00:00Z"}) == "2026-01-01"
# A one-day window (end == start + 1 day) collapses to a single day, not a range
# (pins the `start.date() + timedelta(days=1)`).
assert (
account._window_label(
{
"start_timestamp": "2026-01-01T00:00:00Z",
"end_timestamp": "2026-01-02T00:00:00Z",
}
)
== "2026-01-01"
)
assert account._line_item_label({"name": "minutes", "total": "12.500"}) == "minutes: 12.5"
assert account._line_item_label({"product": "streaming"}) == "streaming"
assert account._line_item_label({"quantity": 3}) == "3"
Expand Down
42 changes: 40 additions & 2 deletions tests/test_agent_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,32 @@ def test_duplex_opens_at_device_rate_and_closes():
def factory(*, rate, blocksize, callback, device):
seen["rate"] = rate
seen["device"] = device
seen["blocksize"] = blocksize
return fake

d = DuplexAudio(device=3, device_rate=48000, stream_factory=factory)
d.player.start()
assert seen["rate"] == 48000 and seen["device"] == 3 # one stream at device rate
assert seen["blocksize"] == 4800 # ~100 ms at 48 kHz (device_rate // 10)
d.close()
assert fake.stopped and fake.closed


def test_duplex_restart_after_close_reopens_stream():
calls = {"n": 0}

def factory(**_k):
calls["n"] += 1
return FakeStream()

d = DuplexAudio(device_rate=16000, stream_factory=factory)
d.start()
assert calls["n"] == 1
d.close()
d.start() # close() cleared the started flag, so this reopens the stream
assert calls["n"] == 2


def test_duplex_callback_captures_input_and_zero_fills_idle_output():
cb = {}

Expand Down Expand Up @@ -78,6 +95,26 @@ def factory(*, rate, blocksize, callback, device):
d.close()


def test_duplex_callback_partial_buffer_zero_fills_exact_remainder():
cb = {}

def factory(*, rate, blocksize, callback, device):
cb["fn"] = callback
return FakeStream()

# device == target so playback bytes pass through unresampled and are easy to count.
d = DuplexAudio(target_rate=16000, device_rate=16000, stream_factory=factory)
d.player.start()
d.player.enqueue(b"\x01\x02" * 5) # 10 bytes buffered
outdata = bytearray(20) # request 20 bytes -> 10 real + 10 zero-filled
cb["fn"](b"\x00\x00" * 5, outdata, 5, None, None)
# The shortfall is filled with exactly `need - len(take)` zero bytes: the buffer
# plays out first, then silence, and the output stays exactly `need` bytes long.
assert len(outdata) == 20
assert bytes(outdata) == b"\x01\x02" * 5 + b"\x00" * 10
d.close()


def test_duplex_mic_ends_after_close():
d = DuplexAudio(target_rate=16000, device_rate=16000, stream_factory=lambda **k: FakeStream())
d.player.start()
Expand All @@ -102,8 +139,8 @@ def test_duplex_player_facade_flush_and_close():
fake = FakeStream()
d = DuplexAudio(target_rate=16000, device_rate=16000, stream_factory=lambda **k: fake)
d.player.start()
d.player.enqueue(b"\x01\x02" * 8)
assert d.player.pending() > 0
d.player.enqueue(b"\x01\x02" * 8) # 16 bytes, no resample (device == target)
assert d.player.pending() == 8 # pending() reports samples = bytes // 2
d.player.flush()
assert d.player.pending() == 0
d.player.close()
Expand Down Expand Up @@ -161,3 +198,4 @@ def boom(**kw):
with pytest.raises(CLIError) as exc:
_default_duplex_stream(rate=24000, blocksize=2400, callback=lambda *a: None, device=None)
assert exc.value.error_type == "audio_output_error"
assert exc.value.exit_code == 1
24 changes: 24 additions & 0 deletions tests/test_agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,20 @@ def test_transcripts_routed_to_renderer():
s.dispatch({"type": "transcript.user.delta", "text": "what"})
s.dispatch({"type": "transcript.user", "text": "what time"})
s.dispatch({"type": "transcript.agent", "text": "noon", "interrupted": False})
# An agent transcript with no "interrupted" key defaults to False (pins the default).
s.dispatch({"type": "transcript.agent", "text": "later"})
assert ("user_partial", "what") in s.renderer.calls
assert ("user_final", "what time") in s.renderer.calls
assert ("agent_transcript", "noon", False) in s.renderer.calls
assert ("agent_transcript", "later", False) in s.renderer.calls


def test_unauthorized_error_raises_cli_error_exit_2():
s = _session()
with pytest.raises(CLIError) as excinfo:
s.dispatch({"type": "session.error", "code": "UNAUTHORIZED", "message": "bad key"})
assert excinfo.value.exit_code == 2
assert "bad key" in str(excinfo.value) # the server message wins over code/fallback


def test_other_session_error_raises_api_error():
Expand Down Expand Up @@ -285,6 +289,26 @@ def close(self):
assert exc.value.exit_code == 1 # the real mic failure reaches the user, not a hang


def test_run_session_does_not_close_player_that_failed_to_open():
# If opening the speaker stream raises, the cleanup must NOT call close() on a
# player that never started (pins the player_started=False initializer).
class _FailingPlayer(FakePlayer):
def start(self):
raise CLIError("speaker busy", error_type="audio_output_error", exit_code=1)

player = _FailingPlayer()
with pytest.raises(CLIError):
run_session(
"sk",
renderer=FakeRenderer(),
player=player,
mic=[],
config=AgentRunConfig(voice="ivy", system_prompt="x", greeting="hi"),
connect=lambda url, **kwargs: _RecordingWS(),
)
assert player.closed is False # never opened, so never closed


def test_run_session_non_auth_failure_stays_api_error():
def boom(url, **kwargs):
raise RuntimeError("network unreachable")
Expand Down
3 changes: 3 additions & 0 deletions tests/test_auth_ams.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def handler(request: httpx.Request) -> httpx.Response:
with pytest.raises(APIError) as exc:
ams.discover("x")
assert "Something went wrong" in str(exc.value)
# The "detail" field is extracted, not the raw JSON body: the field name and its
# braces must not leak (pins `mapping is not None and "detail" in mapping`).
assert "detail" not in str(exc.value)


def test_error_with_non_json_body_falls_back_to_text(monkeypatch):
Expand Down
27 changes: 20 additions & 7 deletions tests/test_auth_loopback.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
import http.client
import socket
import threading
import time
import urllib.request

import pytest

from aai_cli.auth import endpoints, loopback
from aai_cli.errors import APIError


def _hit(path: str) -> None:
url = f"http://{endpoints.LOOPBACK_HOST}:{endpoints.LOOPBACK_PORT}{path}"
def _hit(path: str) -> int | None:
"""Request `path` against the loopback server, returning the HTTP status code.

Uses http.client (not urllib) so a 404 comes back as a normal response status
rather than a raised HTTPError, and so no urllib audit suppression is needed.
"""
# Retry briefly until the server thread is bound.
for _ in range(50):
conn = http.client.HTTPConnection(
endpoints.LOOPBACK_HOST, endpoints.LOOPBACK_PORT, timeout=2
)
try:
urllib.request.urlopen(url, timeout=2).read() # noqa: S310 - fixed localhost URL
return
conn.request("GET", path)
resp = conn.getresponse()
resp.read()
return resp.status
except OSError:
time.sleep(0.05)
finally:
conn.close()
return None


def test_capture_returns_token_and_type():
Expand All @@ -28,9 +40,10 @@ def run():

t = threading.Thread(target=run)
t.start()
_hit("/callback?stytch_token_type=discovery_oauth&token=tok_abc")
status = _hit("/callback?stytch_token_type=discovery_oauth&token=tok_abc")
t.join(timeout=5)

assert status == 200 # the callback is acknowledged with 200 OK
result = result_box["result"]
assert result.token == "tok_abc"
assert result.token_type == "discovery_oauth"
Expand All @@ -47,7 +60,7 @@ def run():

t = threading.Thread(target=run)
t.start()
_hit("/favicon.ico") # unknown path -> 404, capture stays open
assert _hit("/favicon.ico") == 404 # unknown path -> 404, capture stays open
_hit("/callback?stytch_token_type=discovery_oauth&token=tok_late")
t.join(timeout=5)

Expand Down
40 changes: 40 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def test_validate_key_true_on_success():
with patch.object(client.aai, "Transcriber") as T:
T.return_value.list_transcripts.return_value = MagicMock()
assert client.validate_key("sk_good") is True
# The probe asks for a single row — it only needs to confirm the key authenticates.
params = T.return_value.list_transcripts.call_args.args[0]
assert params.limit == 1


def test_validate_key_false_on_auth_error():
Expand Down Expand Up @@ -126,6 +129,23 @@ def test_transcribe_raises_on_error_status():
with pytest.raises(APIError) as exc:
client.transcribe("sk", "audio.mp3", config=aai.TranscriptionConfig())
assert exc.value.transcript_id == "t_err"
assert exc.value.message == "decode failed" # surfaces the SDK's error verbatim


def test_transcribe_error_status_without_message_uses_fallback():
# When the SDK reports an error status but no error text, fall back to a generic
# message (pins the `transcript.error or "Transcription failed."`).
fake_transcript = MagicMock()
fake_transcript.status = client.aai.TranscriptStatus.error
fake_transcript.error = None
fake_transcript.id = "t_err"
fake_transcriber = MagicMock()
fake_transcriber.transcribe.return_value = fake_transcript

with patch.object(client.aai, "Transcriber", return_value=fake_transcriber):
with pytest.raises(APIError) as exc:
client.transcribe("sk", "audio.mp3", config=aai.TranscriptionConfig())
assert exc.value.message == "Transcription failed."


def test_select_transcript_field_utterances_formats_speakers():
Expand Down Expand Up @@ -272,6 +292,26 @@ def test_stream_audio_wires_handlers_and_streams(monkeypatch):
assert last.terminate is True # graceful flush requested


def test_stream_audio_registers_begin_handler_when_provided(monkeypatch):
# A provided on_begin must actually be wired to the Begin event (pins
# `if on_begin is not None`); inverting it would leave Begin unhandled.
class BeginClient(_FakeStreamingClient):
def stream(self, source):
from assemblyai.streaming.v3 import StreamingEvents

self.handlers[StreamingEvents.Begin](self, _types.SimpleNamespace(id="sess_1"))

monkeypatch.setattr(client, "StreamingClient", BeginClient)
begins = []
client.stream_audio(
"sk",
[b"\x00"],
params=_stream_params(),
on_begin=lambda e: begins.append(e.id),
)
assert begins == ["sess_1"]


def test_stream_audio_raises_on_error_event(monkeypatch):
class ErrClient(_FakeStreamingClient):
def stream(self, source):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def test_transcribe_render_parses_and_uses_env_key():
assert "https://assembly.ai/wildfires.mp3" in code
assert "transcript.utterances" in code # result handling for speaker_labels
assert "{{API_KEY}}" not in code # never echo a real key
# config kwargs are rendered 4-space indented inside the TranscriptionConfig call
assert "aai.TranscriptionConfig(\n speaker_labels=True,\n)" in code


def test_transcribe_render_no_config_is_minimal():
Expand Down
23 changes: 23 additions & 0 deletions tests/test_config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,18 @@ def test_split_csv():
def test_parse_auth_header():
assert cb.parse_auth_header("Authorization:Bearer x") == ("Authorization", "Bearer x")
assert cb.parse_auth_header(None) is None
# Only the first ':' separates NAME from VALUE; colons in the value are preserved.
assert cb.parse_auth_header("X-Auth:Bearer a:b:c") == ("X-Auth", "Bearer a:b:c")
with pytest.raises(UsageError):
cb.parse_auth_header("no-colon")


def test_parse_config_overrides_splits_on_first_equals_only():
# A value may itself contain '='; only the first '=' separates key from value.
out = cb.parse_config_overrides(cb.TRANSCRIBE_FIELDS, ["keyterms_prompt=a=b,c"])
assert out["keyterms_prompt"] == ["a=b", "c"]


def test_load_custom_spelling(tmp_path):
p = tmp_path / "spell.json"
p.write_text('{"AssemblyAI": ["assembly ai", "assemblyai"]}')
Expand Down Expand Up @@ -398,6 +406,21 @@ def test_derive_kind_dict_origin_is_json():
assert cb._derive_kind(dict[str, int]) == "json"


def test_derive_kind_unwraps_optional_and_classifies_bare_scalars():
import typing

# A bare scalar is classified by its type, not treated as a dict/json value (pins
# the `origin is dict` check) and a list origin -> "list".
assert cb._derive_kind(int) == "int"
assert cb._derive_kind(list[str]) == "list"
# Optional[int] must unwrap to its single inner type (pins the `a is not None`
# filter). Build the Union via typing.__dict__ so ruff's UP007 ("use X | Y")
# stays quiet — the unwrap path specifically keys on the typing.Union origin,
# which `int | None` doesn't share.
optional_int = typing.__dict__["Union"][int, None]
assert cb._derive_kind(optional_int) == "int"


def test_coerce_table_unknown_field_defaults_to_str():
# A curated name the SDK model doesn't expose passes through as a string
# rather than crashing at import time.
Expand Down
Loading
Loading