24 changes: 22 additions & 2 deletions AGENTS.md
Expand Up @@ -58,6 +58,26 @@ The post-edit hook (`.claude/settings.json`) runs `ruff check --fix --unfixable

The suite is hermetic by construction, enforced three ways (`tests/conftest.py` + `pyproject.toml` `[tool.pytest.ini_options]`): **pytest-randomly** shuffles order, an autouse `pin_timezone` fixture pins `TZ` to a fixed non-UTC zone (UTC-normalized rendering must be unaffected; use **time-machine** to freeze `now`), and **pytest-socket** (`--disable-socket`) blocks real network so an unmocked SDK/HTTP call fails loudly instead of hitting the API. A test that only binds a loopback server opts back in with the tight `@pytest.mark.allow_hosts(["127.0.0.1"])` (still blocks external hosts). The `e2e`/`install`/`install_script` marker suites legitimately reach the real network in-process (PyPI reachability probes, real-API runs), so a `pytest_collection_modifyitems` hook in `conftest.py` auto-grants them full sockets — adding a network marker is all that's needed, no per-test `enable_socket`.

### Manual QA / running the CLI in sandboxed sessions

Lessons that cost time in agent sessions — read before exercising `uv run aai` by hand:

- **Probe network reachability first.** Remote/sandboxed environments often allowlist
PyPI but block `api.assemblyai.com` / `streaming.assemblyai.com` / `llm-gateway.assemblyai.com`.
If `curl -s https://api.assemblyai.com/v2/transcript -H "authorization: $ASSEMBLYAI_API_KEY"`
returns a proxy 403 such as "Host not in allowlist", **no** real-API path can work:
test error handling and `--show-code` instead of burning time on happy paths.
- **Isolate the config dir per test run.** The CLI persists profiles in
`platformdirs`-resolved `config.toml` (e.g. `~/.config/assemblyai/`). Concurrent or
destructive manual tests (corrupt-config probes, profile/env switches) stomp each other
through that shared file — set `XDG_CONFIG_HOME=$(mktemp -d)` per run instead.
- **Write scratch output to `/tmp`, never the repo root.** Redirects like `cmd > out.txt`
in the repo show up as untracked files and trip commit hooks/gates.
- **Headless boxes have no mic/speakers/browser.** `aai stream`/`aai agent` mic paths and
`aai login`'s browser flow can't complete; wrap exploratory runs in `timeout 30 …` so a
blocking path can't wedge the session. For pytest, `--timeout N` (pytest-timeout, in the
dev group) does the same per-test.
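
The isolation and timeout advice above can be combined in one small helper. This is a minimal sketch, not part of the repository: `isolated_env` and `run_bounded` are hypothetical names, and it assumes a platform where `platformdirs` honors `XDG_CONFIG_HOME` (as the bullet above implies).

```python
import os
import subprocess
import sys
import tempfile


def isolated_env(base=None):
    """Copy an environment and point XDG_CONFIG_HOME at a fresh temp dir.

    Hypothetical helper: each call gets its own config dir, so concurrent
    or destructive runs cannot touch a shared config.toml.
    """
    env = dict(os.environ if base is None else base)
    env["XDG_CONFIG_HOME"] = tempfile.mkdtemp(prefix="aai-qa-")
    return env


def run_bounded(args, timeout=30.0):
    """Run a command with an isolated config dir and a hard time limit.

    Returns the CompletedProcess, or None when the command was killed for
    exceeding `timeout` (e.g. a path blocked on a mic or a browser).
    """
    try:
        return subprocess.run(
            args,
            env=isolated_env(),
            capture_output=True,  # keep scratch output out of the repo root
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return None


if __name__ == "__main__":
    # Stand-in command; substitute the CLI invocation being exercised.
    done = run_bounded([sys.executable, "-c", "print('ok')"])
    print(done.stdout if done else "timed out")
```

Capturing output in memory serves the same purpose as redirecting to `/tmp`: nothing lands in the working tree.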

## Naming & packaging gotchas

- The **package/module** is `aai_cli`; the **distribution** name is `aai-cli`; the **console command** is `aai` (`[project.scripts] aai = "aai_cli.main:run"`).
Expand All @@ -70,15 +90,15 @@ A Typer CLI. `aai_cli/main.py` builds the `app`, registers each command sub-app,

### Command layer

Each file in `aai_cli/commands/` is a Typer sub-app (`transcribe`, `stream`, `transcripts`, `agent`, `llm`, `login`, `doctor`, `init`, `claude`). Command bodies run through `context.run_command(ctx, fn, json=...)`, which maps any `CLIError` to clean stderr output + the error's exit code. Commands never print tracebacks for expected failures.
Each file in `aai_cli/commands/` is a Typer sub-app (`transcribe`, `stream`, `transcripts`, `agent`, `llm`, `login` (login/logout/whoami), `doctor`, `init`, `dev`, `share`, `deploy`, `setup`, `onboard`, `account` (balance/usage/limits), `keys`, `sessions`, `audit`). Command bodies run through `context.run_command(ctx, fn, json=...)`, which maps any `CLIError` to clean stderr output + the error's exit code. Commands never print tracebacks for expected failures.

### Cross-cutting state (resolution order matters)

- **`context.py`** — `AppState` (profile, env) is attached to the Typer context in the root `@app.callback()`. `run_command` is the standard command wrapper.
- **`config.py`** — profiles persisted in `config.toml` (via `platformdirs`); the **API key lives only in the OS keyring** (`KEYRING_SERVICE = "assemblyai-cli"`), never in a dotfile. Key resolution order: `--api-key` flag (validation paths only) → `ASSEMBLYAI_API_KEY` env → keyring. **Run commands deliberately expose no `--api-key` flag** so keys can't leak into `ps`/shell history.
- **`environments.py`** — a frozen `Environment` (api_base, streaming_host, llm_gateway_base, ams_base, stytch_*). `DEFAULT_ENV` is **`production`**; use `--sandbox` (or `--env sandbox000` / `AAI_ENV`) to target the sandbox. The active environment is a process-global set once at startup; precedence: `--env` → `AAI_ENV` → profile's stored env → default. A credential is only valid against the environment that minted it.
- **`client.py`** — thin wrappers over the `assemblyai` SDK (`transcribe`, `list_transcripts`, `stream_audio`, etc.). It normalizes SDK exceptions: auth failures become a single clean `auth_failure()` `CLIError`; everything else becomes `APIError`. New SDK calls should follow this try/except shape.
- **`errors.py`** — the `CLIError` hierarchy (each with `error_type` + `exit_code`). `output.py` emits errors to **stderr**; stdout stays clean for pipelines. `--json` (auto-enabled when piped/agent-run) switches to machine-readable output.
- **`errors.py`** — the `CLIError` hierarchy (each with `error_type` + `exit_code`). `output.py` emits errors to **stderr**; stdout stays clean for pipelines. `--json` switches to machine-readable output; it is never auto-enabled — `output.resolve_json()` deliberately keeps human text the default even when piped or agent-run.
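
The two precedence chains described above (API key and environment) both follow a "first non-empty value wins" rule. The sketch below illustrates that rule only. The function names and signatures are hypothetical and are not the actual code in `config.py` or `environments.py`; the keyring lookup is passed in as a plain value rather than read from the OS keyring.

```python
DEFAULT_ENV = "production"


def resolve_environment(flag, environ, profile_env):
    """--env flag, then AAI_ENV, then the profile's stored env, then default."""
    for candidate in (flag, environ.get("AAI_ENV"), profile_env):
        if candidate:
            return candidate
    return DEFAULT_ENV


def resolve_api_key(flag_key, environ, keyring_key):
    """--api-key flag (validation paths only), then env var, then keyring."""
    for candidate in (flag_key, environ.get("ASSEMBLYAI_API_KEY"), keyring_key):
        if candidate:
            return candidate
    return None  # caller decides how to report "not authenticated"


if __name__ == "__main__":
    print(resolve_environment(None, {"AAI_ENV": "sandbox000"}, "production"))
    print(resolve_api_key(None, {}, "key-from-keyring"))
```

Because a credential is only valid against the environment that minted it, the two lookups need to be resolved consistently for a given invocation.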

### Feature subsystems

Expand Down
12 changes: 7 additions & 5 deletions aai_cli/agent/render.py
Expand Up @@ -40,13 +40,15 @@ def connected(self) -> None:
self._line(Text("Connected — start talking. (Ctrl-C to stop)", style="aai.muted"))

def notice(self, text: str) -> None:
"""Print a human-facing notice (suppressed in JSON; to stderr in text mode)."""
"""Print a human-facing notice: suppressed in JSON, to stderr otherwise.

Stderr in *every* non-JSON mode (not just ``-o text``): the default human
mode is also piped sometimes (``aai agent | head``), and a notice on stdout
would be consumed as transcript data there.
"""
if self.json_mode:
return
if self.text_mode:
self._status(text.rstrip("\n"))
else:
self._line(text.rstrip("\n"))
self._status(text.rstrip("\n"))

# --- user --------------------------------------------------------------
def user_partial(self, text: str) -> None:
Expand Down
41 changes: 40 additions & 1 deletion aai_cli/agent/session.py
Expand Up @@ -3,6 +3,7 @@
import base64
import contextlib
import json
import logging
import threading
from collections.abc import Callable
from dataclasses import dataclass
Expand Down Expand Up @@ -31,6 +32,9 @@ def ws_url() -> str:
# session.error codes that mean the connection is unauthorized -> exit 2.
_AUTH_ERROR_CODES = {"UNAUTHORIZED", "FORBIDDEN"}

# A pre-upgrade HTTP 403 on the WebSocket handshake (see _is_rejected_key).
_HTTP_FORBIDDEN = 403

# The websocket connection, the `connect` factory, and the renderer/player/mic I/O
# objects come from libraries/modules with no usable type stubs. Alias that untyped
# boundary here so each role is named in signatures and `Any` stays in one place.
Expand Down Expand Up @@ -189,10 +193,44 @@ def _send_audio_loop(ws: _WebSocket, session: VoiceAgentSession, mic: _IO) -> No
return


# The sync websockets client logs through these; both are silenced for the session
# (the parent covers any future child logger, the client logger is the one that fires).
_WEBSOCKETS_LOGGERS = ("websockets", "websockets.client")


def _silence_websockets_logging() -> None:
"""Keep websockets' internal logging off the user's stderr for the session.

The sync client's background reader thread logs unhandled teardown errors (e.g.
``EOFError: stream ended``) as "unexpected internal error" + traceback through the
``websockets.client`` logger, which would land on stderr right next to our clean
CLIError. Those internals are never user-actionable from the CLI, so raise the
loggers above every level they emit at. Idempotent: re-setting the level is a no-op.
"""
for name in _WEBSOCKETS_LOGGERS:
logging.getLogger(name).setLevel(logging.CRITICAL)


def _is_rejected_key(exc: Exception) -> bool:
"""Is this connect/session failure auth-shaped (the key itself was rejected)?

Mirrors how `stream` classifies handshake failures: a plain HTTP 403 on the
WebSocket upgrade stays an API error there ("Streaming error: WebSocket handshake
rejected (HTTP 403)"), so it must not become "Your API key was rejected" here —
403 also covers non-credential blocks (WAF, region, plan). Only 401, the Voice
Agent's 1008 policy-violation close, or an explicitly auth-worded message
(`is_auth_failure`'s text hints) count as a rejected key.
"""
status = getattr(getattr(exc, "response", None), "status_code", None)
if status == _HTTP_FORBIDDEN:
return False
return is_auth_failure(exc)


def _auth_or_api_error(exc: Exception, message: str) -> CLIError:
"""Map a connect/session exception to the right CLIError: a rejected key becomes
auth_failure(), anything else becomes APIError(f"{message}: {exc}")."""
if is_auth_failure(exc):
if _is_rejected_key(exc):
return auth_failure()
return APIError(f"{message}: {exc}")

Expand Down Expand Up @@ -243,6 +281,7 @@ def run_session(
the agent's first reply to the spoken input and the capture thread waits for
session.ready before streaming the source.
"""
_silence_websockets_logging()
if connect is None:
from websockets.sync.client import connect

Expand Down
20 changes: 14 additions & 6 deletions aai_cli/auth/flow.py
Expand Up @@ -8,7 +8,7 @@

from aai_cli import output
from aai_cli.auth import ams, discovery, endpoints, loopback
from aai_cli.errors import APIError
from aai_cli.errors import APIError, NotAuthenticated


@dataclass
Expand Down Expand Up @@ -97,8 +97,8 @@ def _open_browser(url: str) -> None:
)


def _capture() -> loopback.CallbackResult:
return loopback.capture_callback()
def _start_capture() -> loopback.CallbackCapture:
return loopback.start_capture()


def _reusable_cli_key(token: _Token) -> str | None:
Expand Down Expand Up @@ -137,13 +137,21 @@ def find_or_create_cli_key(account_id: int, session_jwt: str) -> str:

def run_login_flow() -> LoginResult:
"""Drive the full browser + AMS login and return a LoginResult."""
# Bind the loopback callback server *before* opening the browser: if the port is
# taken, fail cleanly now instead of stranding the user mid-OAuth in a flow that
# can never call back.
capture = _start_capture()
_open_browser(discovery.build_start_url())
result = _capture()
output.error_console.print(
"[aai.muted]Waiting up to 2 minutes for you to finish signing in…[/aai.muted]\n"
"[aai.muted]No browser here? Run 'aai login --api-key <KEY>' instead.[/aai.muted]"
)
result = capture.wait()

if result.error == "timeout":
raise APIError(
raise NotAuthenticated(
"Login timed out waiting for the browser.",
suggestion="Run 'aai login' again.",
suggestion="Run 'aai login' again, or use 'aai login --api-key <KEY>'.",
)
if result.token_type != "discovery_oauth" or not result.token: # noqa: S105
raise APIError(
Expand Down
64 changes: 48 additions & 16 deletions aai_cli/auth/loopback.py
Expand Up @@ -30,15 +30,47 @@ class CallbackResult:
error: str | None = None


def capture_callback(
timeout: float = 120.0, # pragma: no mutate (default window; tests pass explicit timeouts)
) -> CallbackResult:
"""Bind the fixed loopback port, capture one OAuth callback, return its token.
@dataclass
class CallbackCapture:
"""A loopback callback server that is already bound and serving.

Only a callback to the registered path that carries a `token` is accepted; any
other request (a different path, or no token) gets a 4xx and the server keeps
waiting, so a stray request can't end the capture early. Returns a
CallbackResult; `error="timeout"` if no matching callback arrives in time.
Splitting the bind (`start_capture`) from the blocking wait lets the login flow
fail on a taken port *before* it sends the user's browser into the OAuth flow.
`wait()` blocks for one matching callback and always shuts the server down.
"""

result: CallbackResult
done: threading.Event
server: HTTPServer
thread: threading.Thread

def wait(
self,
timeout: float = 120.0, # pragma: no mutate (default window; tests pass explicit timeouts)
) -> CallbackResult:
"""Block for one OAuth callback (or the timeout), then shut the server down.

Returns the CallbackResult; `error="timeout"` if no matching callback
arrived in time.
"""
try:
if not self.done.wait(timeout):
self.result.error = "timeout"
finally:
self.server.shutdown() # stop serve_forever()
self.thread.join(timeout=5) # pragma: no mutate (cleanup grace period only)
self.server.server_close() # close the listening socket (shutdown() leaves it open)
return self.result


def start_capture() -> CallbackCapture:
"""Bind the fixed loopback port and start serving; the returned capture's
``wait()`` collects one OAuth callback.

Raises a clean APIError when the bind fails (port taken) so callers can abort
before opening the browser. Only a callback to the registered path that carries
a `token` is accepted; any other request (a different path, or no token) gets a
4xx and the server keeps waiting, so a stray request can't end the capture early.
"""
result = CallbackResult()
done = threading.Event()
Expand Down Expand Up @@ -81,11 +113,11 @@ def log_message(self, format: str, *args: object) -> None: # silence stderr log
) from exc
thread = threading.Thread(target=server.serve_forever, daemon=True)
thread.start()
try:
if not done.wait(timeout):
result.error = "timeout"
finally:
server.shutdown() # stop serve_forever()
thread.join(timeout=5)
server.server_close() # close the listening socket (shutdown() leaves it open)
return result
return CallbackCapture(result=result, done=done, server=server, thread=thread)


def capture_callback(
timeout: float = 120.0, # pragma: no mutate (default window; tests pass explicit timeouts)
) -> CallbackResult:
"""Bind the port, capture one OAuth callback, and shut down (one-shot helper)."""
return start_capture().wait(timeout)
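
The bind-then-wait split in `loopback.py` can be reproduced in isolation. The sketch below is a simplified stand-in under stated assumptions, not the module's code: it binds an ephemeral port (`port=0`) instead of the fixed registered port, accepts any path that carries a `token`, stores the result in a plain dict, and lets a failed bind surface as `OSError` rather than wrapping it in `APIError`. The names `Capture` and `start` are hypothetical.

```python
import threading
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs, urlparse


@dataclass
class Capture:
    """An already-bound, already-serving callback server."""

    server: HTTPServer
    thread: threading.Thread
    done: threading.Event
    box: dict

    def wait(self, timeout):
        """Block for one callback or the timeout, then always clean up."""
        try:
            if not self.done.wait(timeout):
                self.box["error"] = "timeout"
        finally:
            self.server.shutdown()       # stop serve_forever()
            self.thread.join(timeout=5)  # cleanup grace period
            self.server.server_close()   # release the listening socket
        return self.box


def start(port=0):
    """Bind and start serving; a taken port raises OSError right here."""
    box = {}
    done = threading.Event()

    class Handler(BaseHTTPRequestHandler):
        def do_GET(self):
            token = parse_qs(urlparse(self.path).query).get("token", [""])[0]
            if not token:
                # Stray request: reject it and keep waiting.
                self.send_response(400)
                self.end_headers()
                return
            box["token"] = token
            self.send_response(200)
            self.end_headers()
            done.set()

        def log_message(self, format, *args):
            pass  # keep the server quiet on stderr

    server = HTTPServer(("127.0.0.1", port), Handler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return Capture(server, thread, done, box)


if __name__ == "__main__":
    capture = start()            # fails fast if the port cannot be bound
    # ... a real flow would open the browser only after this point ...
    print(capture.wait(0.2))     # no callback arrives in this demo
```

The ordering is the point of the design: any bind error is raised by `start()`, before the caller has sent the user anywhere, and `wait()` owns the shutdown so the socket is released on every path.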
64 changes: 53 additions & 11 deletions aai_cli/client.py
Expand Up @@ -49,19 +49,36 @@ def resolve_audio_source(source: str | None, *, sample: bool, check_local: bool
don't have yet is legitimate.
"""
if sample:
if source:
# Never silently prefer one over the other: the user asked for both.
raise UsageError(
"An audio source and --sample cannot be combined.",
suggestion="Pass the file/URL or --sample, not both.",
)
return SAMPLE_AUDIO_URL
if not source:
raise UsageError(
"Provide an audio path or URL.",
suggestion="Or pass --sample to use the hosted demo file.",
)
if check_local and not source.startswith(("http://", "https://")) and not Path(source).exists():
raise CLIError(
f"File not found: {source}",
error_type="file_not_found",
exit_code=2,
suggestion="Check the path. For remote audio, pass an http(s):// URL.",
)
if check_local and not source.startswith(("http://", "https://")):
path = Path(source)
if not path.exists():
raise CLIError(
f"File not found: {source}",
error_type="file_not_found",
exit_code=2,
suggestion="Check the path. For remote audio, pass an http(s):// URL.",
)
if not path.is_file():
# A directory (or socket/FIFO) would otherwise fall through to credential
# resolution and fail much later as an opaque upload error.
raise CLIError(
f"Not a file: {source}",
error_type="not_a_file",
exit_code=2,
suggestion="Pass an audio file, not a directory.",
)
return source


Expand Down Expand Up @@ -90,17 +107,42 @@ def _sdk_errors(message: str) -> Generator[None]:
raise APIError(f"{message}: {exc}") from exc


def _list_transcript_params(limit: int) -> aai.ListTranscriptParameters:
"""List-transcripts params that serialize without the spurious ``model_config`` key.

assemblyai==0.64.4 under pydantic==2.13.4: the SDK's pydantic-v1-shim request model
picks up the v2-style ``model_config`` class attribute as a regular field, so the
``.dict(exclude_none=True)`` the SDK puts on the query string ships a junk
``?model_config=...`` param on every request. Null the bogus field out so
``exclude_none`` drops it from the wire.
"""
params = aai.ListTranscriptParameters(limit=limit)
object.__setattr__(params, "model_config", None)
return params


# httpx-backed SDK errors embed a multi-line repr ("…\nReason: …\nRequest: <Request(…)>").
_REQUEST_REPR_RE = re.compile(r"Request: <[^>]*>")


def _compact_reason(exc: object) -> str:
"""``str(exc)`` as a single clean line: drop the trailing ``Request: <…>`` repr and
collapse all whitespace/newlines, keeping the informative reason text."""
text = _REQUEST_REPR_RE.sub("", str(exc))
return re.sub(r"\s+", " ", text).strip()


def validate_key(api_key: str) -> bool:
"""True if the key authenticates, False on an auth failure. Raises APIError otherwise."""
_configure(api_key)
try:
aai.Transcriber().list_transcripts(aai.ListTranscriptParameters(limit=1))
aai.Transcriber().list_transcripts(_list_transcript_params(1))
except aai.types.AssemblyAIError as exc:
if is_auth_failure(exc):
return False
raise APIError(f"Could not validate key: {exc}") from exc
raise APIError(f"Could not validate key: {_compact_reason(exc)}") from exc
except Exception as exc:
raise APIError(f"Network error contacting AssemblyAI: {exc}") from exc
raise APIError(f"Network error contacting AssemblyAI: {_compact_reason(exc)}") from exc
return True


Expand All @@ -114,7 +156,7 @@ def _item_to_dict(item: Any) -> dict[str, Any]:
def list_transcripts(api_key: str, *, limit: int = 10) -> list[dict[str, object]]:
_configure(api_key)
with _sdk_errors("Could not list transcripts"):
resp = aai.Transcriber().list_transcripts(aai.ListTranscriptParameters(limit=limit))
resp = aai.Transcriber().list_transcripts(_list_transcript_params(limit))
return [_item_to_dict(item) for item in resp.transcripts]
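
To make the effect of `_compact_reason` concrete, here is a standalone copy of its two regex steps applied to a made-up exception. The error text and URL are invented for illustration and are not real SDK output; only the regex logic is taken from the diff above.

```python
import re

# Same pattern as in the diff: the trailing "Request: <...>" repr.
_REQUEST_REPR_RE = re.compile(r"Request: <[^>]*>")


def compact_reason(exc):
    """Drop the request repr, then collapse all whitespace to single spaces."""
    text = _REQUEST_REPR_RE.sub("", str(exc))
    return re.sub(r"\s+", " ", text).strip()


if __name__ == "__main__":
    # Hypothetical multi-line error message, shaped like the comment describes.
    err = RuntimeError(
        "Connection failed\n"
        "Reason: timed out\n"
        "Request: <Request('GET', 'https://example.invalid/v2/transcript')>"
    )
    print(compact_reason(err))  # Connection failed Reason: timed out
```

One limitation follows from `[^>]*`: a repr that itself contains a `>` would be cut at the first one, leaving a fragment behind.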


Expand Down
8 changes: 6 additions & 2 deletions aai_cli/code_gen/__init__.py
Expand Up @@ -42,11 +42,15 @@ def stream(
merged: dict[str, object],
*,
llm: dict[str, object] | None = None,
source: str | None = None,
) -> str:
"""Generate runnable Python that reproduces this streaming invocation.

With `llm` (a dict of ``prompts``/``model``/``max_tokens``/``interval``), the script
``source`` mirrors the CLI argument: ``None`` streams the microphone, ``"-"``
reads raw PCM16 from stdin, and anything else is a file path/URL decoded through
ffmpeg — so the generated script reads the same input the real run would. With
`llm` (a dict of ``prompts``/``model``/``max_tokens``/``interval``), the script
refreshes a prompt-chain over the growing transcript every ``interval`` seconds (0 =
every turn) — the live sibling of `transcribe --llm` — mirroring how `stream --llm` runs.
"""
return _stream.render(merged, llm=llm)
return _stream.render(merged, llm=llm, source=source)