From 5644fb234d6351315b286b01c8f8ca699a2c5271 Mon Sep 17 00:00:00 2001 From: Alex Kroman Date: Mon, 8 Jun 2026 11:34:35 -0700 Subject: [PATCH 1/3] Default to production environment; drop login state nonce Flip DEFAULT_ENV from sandbox000 to production now that the prod Stytch redirect is live, and wire in the real production public token. --sandbox (or --env sandbox000 / AAI_ENV) still selects the sandbox. Remove the loopback CSRF state nonce so aai login sends a bare /callback redirect with no query parameters. Stytch validates the redirect's query string too, and the prod project only registers the bare path, so the nonce produced query_params_do_not_match. Dropping it avoids any Stytch dashboard change at the cost of the login-CSRF / account-confusion protection (deliberate, see PR notes). Also: - Lift the Stytch public tokens to module constants so the constructor kwarg takes a name, not a literal, dropping both # noqa: S106. - Make the shell-completion smoke test deterministic by pinning the detected shell (shellingham needs a TTY that CliRunner/CI lack). Co-Authored-By: Claude Opus 4.8 (1M context) --- AGENTS.md | 2 +- aai_cli/auth/discovery.py | 16 ++++------ aai_cli/auth/flow.py | 13 +++----- aai_cli/auth/loopback.py | 25 +++++++--------- aai_cli/environments.py | 20 ++++++++----- aai_cli/skills/aai-cli/SKILL.md | 2 +- tests/test_auth_discovery.py | 18 +++++------ tests/test_auth_endpoints.py | 6 ++-- tests/test_auth_flow.py | 52 ++++++++++++-------------------- tests/test_auth_loopback.py | 53 +++++++++++---------------------- tests/test_smoke.py | 10 +++++-- 11 files changed, 91 insertions(+), 126 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index be6a6f00..bb73d1f1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -59,7 +59,7 @@ Each file in `aai_cli/commands/` is a Typer sub-app (`transcribe`, `stream`, `tr - **`context.py`** — `AppState` (profile, env) is attached to the Typer context in the root `@app.callback()`. `run_command` is the standard command wrapper. - **`config.py`** — profiles persisted in `config.toml` (via `platformdirs`); the **API key lives only in the OS keyring** (`KEYRING_SERVICE = "assemblyai-cli"`), never in a dotfile. Key resolution order: `--api-key` flag (validation paths only) → `ASSEMBLYAI_API_KEY` env → keyring. **Run commands deliberately expose no `--api-key` flag** so keys can't leak into `ps`/shell history. -- **`environments.py`** — a frozen `Environment` (api_base, streaming_host, llm_gateway_base, ams_base, stytch_*). `DEFAULT_ENV` is currently **`sandbox000`** (flip to `production` once the prod Stytch value is real). The active environment is a process-global set once at startup; precedence: `--env` → `AAI_ENV` → profile's stored env → default. A credential is only valid against the environment that minted it. +- **`environments.py`** — a frozen `Environment` (api_base, streaming_host, llm_gateway_base, ams_base, stytch_*). `DEFAULT_ENV` is **`production`**; use `--sandbox` (or `--env sandbox000` / `AAI_ENV`) to target the sandbox. The active environment is a process-global set once at startup; precedence: `--env` → `AAI_ENV` → profile's stored env → default. A credential is only valid against the environment that minted it. - **`client.py`** — thin wrappers over the `assemblyai` SDK (`transcribe`, `list_transcripts`, `stream_audio`, etc.). It normalizes SDK exceptions: auth failures become a single clean `auth_failure()` `CLIError`; everything else becomes `APIError`. New SDK calls should follow this try/except shape. - **`errors.py`** — the `CLIError` hierarchy (each with `error_type` + `exit_code`). `output.py` emits errors to **stderr**; stdout stays clean for pipelines. `--json` (auto-enabled when piped/agent-run) switches to machine-readable output. diff --git a/aai_cli/auth/discovery.py b/aai_cli/auth/discovery.py index 4a52f614..3646f763 100644 --- a/aai_cli/auth/discovery.py +++ b/aai_cli/auth/discovery.py @@ -5,28 +5,22 @@ from aai_cli.auth import endpoints -def build_start_url(state: str) -> str: +def build_start_url() -> str: """The Stytch B2B OAuth discovery *start* URL the browser opens. Client-side endpoint authenticated by the project's public token. After the user authenticates with the provider, Stytch redirects to our loopback `discovery_redirect_url` with `?stytch_token_type=discovery_oauth&token=...`. - - `state` is a single-use random nonce carried as a query parameter *on the - redirect URL*. Stytch validates the redirect URL by scheme/host/port/path and - preserves query parameters through the flow, so the nonce comes back on the - callback unchanged; `loopback.capture_callback` rejects any callback whose - `state` doesn't match. This stops a malicious page from completing someone - else's login at the loopback server (login CSRF / account confusion) by - replaying an attacker-minted discovery token while `aai login` is waiting. + The redirect URL is the bare loopback path Stytch validates by + scheme/host/port/path, with no extra query parameters (so the redirect needs + no wildcard registered in the Stytch dashboard). """ base = ( f"{endpoints.stytch_domain()}" f"/v1/b2b/public/oauth/{endpoints.STYTCH_OAUTH_PROVIDER}/discovery/start" ) - redirect_with_state = f"{endpoints.redirect_uri()}?{urlencode({'state': state})}" params = { "public_token": endpoints.stytch_public_token(), - "discovery_redirect_url": redirect_with_state, + "discovery_redirect_url": endpoints.redirect_uri(), } return f"{base}?{urlencode(params)}" diff --git a/aai_cli/auth/flow.py b/aai_cli/auth/flow.py index 7a9d32f4..c8a086fe 100644 --- a/aai_cli/auth/flow.py +++ b/aai_cli/auth/flow.py @@ -1,6 +1,5 @@ from __future__ import annotations -import secrets import webbrowser from dataclasses import dataclass @@ -98,8 +97,8 @@ def _open_browser(url: str) -> None: ) -def _capture(state: str) -> loopback.CallbackResult: - return loopback.capture_callback(state) +def _capture() -> loopback.CallbackResult: + return loopback.capture_callback() def _reusable_cli_key(token: _Token) -> str | None: @@ -138,12 +137,8 @@ def find_or_create_cli_key(account_id: int, session_jwt: str) -> str: def run_login_flow() -> LoginResult: """Drive the full browser + AMS login and return a LoginResult.""" - # A fresh single-use nonce binds the browser hand-off to this loopback capture: - # build_start_url carries it to Stytch, capture_callback only accepts a callback - # that echoes it back. token_urlsafe(32) is 256 bits of entropy — unguessable. - state = secrets.token_urlsafe(32) # pragma: no mutate (any large nonce works; 32 isn't magic) - _open_browser(discovery.build_start_url(state)) - result = _capture(state) + _open_browser(discovery.build_start_url()) + result = _capture() if result.error == "timeout": raise APIError( diff --git a/aai_cli/auth/loopback.py b/aai_cli/auth/loopback.py index fa88170e..63246522 100644 --- a/aai_cli/auth/loopback.py +++ b/aai_cli/auth/loopback.py @@ -1,6 +1,5 @@ from __future__ import annotations -import secrets import threading from dataclasses import dataclass from http.server import BaseHTTPRequestHandler, HTTPServer @@ -9,7 +8,7 @@ from aai_cli.auth import endpoints from aai_cli.errors import APIError -# The callback URL carries the single-use OAuth token (and the state nonce) in its +# The callback URL carries the single-use OAuth token in its # query string, so it would otherwise linger in the browser's history and address # bar. Scrub it from the current history entry with replaceState the moment the page # loads — no extra request to race the server shutdown, unlike a redirect. The token @@ -32,16 +31,14 @@ class CallbackResult: def capture_callback( - expected_state: str, timeout: float = 120.0, # pragma: no mutate (default window; tests pass explicit timeouts) ) -> CallbackResult: """Bind the fixed loopback port, capture one OAuth callback, return its token. - Only a callback whose `state` query parameter equals `expected_state` is - accepted; any other request (wrong/missing state, or a different path) gets a - 4xx and the server keeps waiting, so a forged callback can't complete someone - else's login. Returns a CallbackResult; `error="timeout"` if no matching - callback arrives in time. + Only a callback to the registered path that carries a `token` is accepted; any + other request (a different path, or no token) gets a 4xx and the server keeps + waiting, so a stray request can't end the capture early. Returns a + CallbackResult; `error="timeout"` if no matching callback arrives in time. """ result = CallbackResult() done = threading.Event() @@ -54,15 +51,15 @@ def do_GET(self) -> None: # stdlib API name self.end_headers() return qs = parse_qs(parsed.query) - state = next(iter(qs.get("state", [])), None) - # Constant-time compare so a forged callback can't probe the nonce by - # timing. A mismatch is rejected without ending the capture: the real - # callback can still arrive (otherwise it falls through to timeout). - if state is None or not secrets.compare_digest(state, expected_state): + token = next(iter(qs.get("token", [])), None) + # A callback with no token (a stray or preflight request) is rejected + # without ending the capture: the genuine callback can still arrive + # (otherwise it falls through to timeout). + if token is None: self.send_response(400) self.end_headers() return - result.token = next(iter(qs.get("token", [])), None) + result.token = token result.token_type = next(iter(qs.get("stytch_token_type", [])), None) self.send_response(200) self.send_header("Content-Type", "text/html") diff --git a/aai_cli/environments.py b/aai_cli/environments.py index ded6260d..fc27d396 100644 --- a/aai_cli/environments.py +++ b/aai_cli/environments.py @@ -25,6 +25,13 @@ class Environment: signup_url: str # where a first-time user creates an account +# Stytch *public* client tokens — safe to ship, not secrets despite the field name. +# Held as module constants (not inline literals) so the constructor kwarg takes a +# name, not a string, which is what would otherwise trip ruff's S106 hardcoded-secret +# heuristic on the `stytch_public_token=` argument. +_PROD_STYTCH_PUBLIC = "public-token-live-bbc59d30-d3c8-4815-a5be-fede00306680" +_SANDBOX_STYTCH_PUBLIC = "public-token-test-a161155e-7e9b-4dd1-9d43-493c899b4117" + ENVIRONMENTS: dict[str, Environment] = { "production": Environment( name="production", @@ -32,12 +39,9 @@ class Environment: streaming_host="streaming.assemblyai.com", agents_host="agents.assemblyai.com", llm_gateway_base="https://llm-gateway.assemblyai.com/v1", - # NOTE: production Stytch is not provisioned yet (see the REPLACE_ME - # token), which is why DEFAULT_ENV stays "sandbox000". Tracked under - # spec O4. ams_base="https://ams.internal.assemblyai-labs.com", stytch_domain="https://api.stytch.com", - stytch_public_token="public-token-live-REPLACE_ME", # noqa: S106 - public token, safe to ship + stytch_public_token=_PROD_STYTCH_PUBLIC, signup_url="https://www.assemblyai.com/dashboard", ), "sandbox000": Environment( @@ -48,14 +52,14 @@ class Environment: llm_gateway_base="https://llm-gateway.sandbox000.assemblyai-labs.com/v1", ams_base="https://ams.sandbox000.assemblyai-labs.com", stytch_domain="https://test.stytch.com", - stytch_public_token="public-token-test-a161155e-7e9b-4dd1-9d43-493c899b4117", # noqa: S106 - public token, safe to ship + stytch_public_token=_SANDBOX_STYTCH_PUBLIC, signup_url="https://dashboard-assemblyai.vercel.app/dashboard/login", ), } -# Shipped default when nothing selects an environment. Flip to "production" at -# release once the production Stytch value above is real. -DEFAULT_ENV = "sandbox000" +# Shipped default when nothing selects an environment. Use --sandbox (or +# --env sandbox000 / AAI_ENV) to target the sandbox instead. +DEFAULT_ENV = "production" # The environment in effect for this process, set once at CLI startup (like # aai.settings). Resolved from --env / AAI_ENV / the profile's stored env. diff --git a/aai_cli/skills/aai-cli/SKILL.md b/aai_cli/skills/aai-cli/SKILL.md index b66e59e2..4dac3b4e 100644 --- a/aai_cli/skills/aai-cli/SKILL.md +++ b/aai_cli/skills/aai-cli/SKILL.md @@ -26,7 +26,7 @@ shell history. Do not look for one. **Environment binding.** The backend environment is selected by `--env` (or `AAI_ENV`, or the profile's stored env). `--sandbox` is shorthand for -`--env sandbox000`. The default environment is currently `sandbox000`. +`--env sandbox000`. The default environment is `production`. **A credential is only valid against the environment that minted it** — a sandbox key fails against production and vice-versa. If a freshly-working key suddenly returns auth errors, check you are on the same `--env` you logged in diff --git a/tests/test_auth_discovery.py b/tests/test_auth_discovery.py index 1fe1f49b..87632653 100644 --- a/tests/test_auth_discovery.py +++ b/tests/test_auth_discovery.py @@ -4,7 +4,7 @@ def test_build_start_url_targets_b2b_discovery_for_provider(): - url = discovery.build_start_url("state-xyz") + url = discovery.build_start_url() parsed = urlparse(url) assert parsed.scheme == "https" assert parsed.path == "/v1/b2b/public/oauth/google/discovery/start" @@ -12,18 +12,18 @@ def test_build_start_url_targets_b2b_discovery_for_provider(): def test_build_start_url_includes_public_token_and_redirect(): - url = discovery.build_start_url("state-xyz") + url = discovery.build_start_url() qs = parse_qs(urlparse(url).query) assert qs["public_token"] == [endpoints.stytch_public_token()] - # The state nonce rides as a query param on the (still path-exact) redirect URL. - assert qs["discovery_redirect_url"] == ["http://127.0.0.1:8585/callback?state=state-xyz"] + # The redirect URL is the bare loopback path Stytch validates — no query params. + assert qs["discovery_redirect_url"] == ["http://127.0.0.1:8585/callback"] -def test_build_start_url_carries_state_into_redirect(): - url = discovery.build_start_url("nonce-123") +def test_build_start_url_redirect_has_no_query_params(): + url = discovery.build_start_url() redirect = parse_qs(urlparse(url).query)["discovery_redirect_url"][0] redirect_parsed = urlparse(redirect) - # The redirect path stays exactly /callback (what Stytch validates); only the - # query string gains the nonce, so redirect-URL matching is unaffected. + # Path-exact /callback with an empty query string keeps Stytch's redirect-URL + # matching simple (no query-parameter validation to configure). assert redirect_parsed.path == "/callback" - assert parse_qs(redirect_parsed.query)["state"] == ["nonce-123"] + assert redirect_parsed.query == "" diff --git a/tests/test_auth_endpoints.py b/tests/test_auth_endpoints.py index 3e162bc0..a273d4e8 100644 --- a/tests/test_auth_endpoints.py +++ b/tests/test_auth_endpoints.py @@ -6,9 +6,9 @@ def test_redirect_uri_is_fixed_loopback(): def test_env_specific_values_come_from_active_environment(): - # With no active env set, helpers fall back to the default (sandbox000). - assert endpoints.ams_base() == "https://ams.sandbox000.assemblyai-labs.com" - assert endpoints.stytch_domain() == "https://test.stytch.com" + # With no active env set, helpers fall back to the default (production). + assert endpoints.ams_base() == "https://ams.internal.assemblyai-labs.com" + assert endpoints.stytch_domain() == "https://api.stytch.com" assert endpoints.stytch_public_token().startswith("public-token-") assert endpoints.signup_url().startswith("https://") diff --git a/tests/test_auth_flow.py b/tests/test_auth_flow.py index aeaadf8a..b2d73e58 100644 --- a/tests/test_auth_flow.py +++ b/tests/test_auth_flow.py @@ -74,31 +74,20 @@ def test_find_or_create_raises_when_no_projects(monkeypatch): def test_capture_delegates_to_loopback(monkeypatch): sentinel = loopback.CallbackResult(token="tok", token_type="discovery_oauth") - captured = {} - - def fake_capture(state): - captured["state"] = state - return sentinel + monkeypatch.setattr(flow.loopback, "capture_callback", lambda: sentinel) + assert flow._capture() is sentinel - monkeypatch.setattr(flow.loopback, "capture_callback", fake_capture) - assert flow._capture("nonce-1") is sentinel - assert captured["state"] == "nonce-1" # the nonce is forwarded to the loopback - -def test_run_login_flow_binds_state_nonce(monkeypatch): - # The nonce build_start_url() carries to Stytch must be the same one the loopback - # capture is told to expect, or a genuine callback would never be accepted. +def test_run_login_flow_opens_the_discovery_start_url(monkeypatch): + # The browser is opened with exactly the URL build_start_url() produces. seen = {} + monkeypatch.setattr(flow.discovery, "build_start_url", lambda: "start-url") + monkeypatch.setattr(flow, "_open_browser", lambda url: seen.setdefault("url", url)) monkeypatch.setattr( - flow.discovery, "build_start_url", lambda state: seen.setdefault("url_state", state) or "u" + flow, + "_capture", + lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) - monkeypatch.setattr(flow, "_open_browser", lambda url: None) - - def fake_capture(state): - seen["capture_state"] = state - return loopback.CallbackResult(token="tok", token_type="discovery_oauth") - - monkeypatch.setattr(flow, "_capture", fake_capture) monkeypatch.setattr( flow.ams, "discover", @@ -115,8 +104,7 @@ def fake_capture(state): monkeypatch.setattr(flow, "find_or_create_cli_key", lambda acct, jwt: "sk_final") flow.run_login_flow() - assert seen["url_state"] == seen["capture_state"] - assert len(seen["capture_state"]) >= 32 # token_urlsafe(32) -> unguessable nonce + assert seen["url"] == "start-url" def test_run_login_flow_rejects_wrong_token_type(monkeypatch): @@ -124,7 +112,7 @@ def test_run_login_flow_rejects_wrong_token_type(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda _state: loopback.CallbackResult(token="tok", token_type="something_else"), + lambda: loopback.CallbackResult(token="tok", token_type="something_else"), ) with pytest.raises(APIError) as exc: flow.run_login_flow() @@ -137,7 +125,7 @@ def test_run_login_flow_happy_path(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -161,7 +149,7 @@ def test_run_login_flow_happy_path(monkeypatch): def test_run_login_flow_timeout_raises(monkeypatch): monkeypatch.setattr(flow, "_open_browser", lambda url: None) - monkeypatch.setattr(flow, "_capture", lambda _state: loopback.CallbackResult(error="timeout")) + monkeypatch.setattr(flow, "_capture", lambda: loopback.CallbackResult(error="timeout")) with pytest.raises(APIError) as exc: flow.run_login_flow() assert exc.value.message == "Login timed out waiting for the browser." @@ -203,7 +191,7 @@ def test_run_login_flow_uses_exchange_account(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -234,7 +222,7 @@ def test_run_login_flow_multi_org_notes_selection(monkeypatch, capsys): monkeypatch.setattr( flow, "_capture", - lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -274,7 +262,7 @@ def test_run_login_flow_missing_session_token_raises_api_error(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -290,7 +278,7 @@ def test_run_login_flow_org_missing_id_raises_api_error(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -309,7 +297,7 @@ def test_run_login_flow_zero_orgs_raises(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -329,9 +317,7 @@ def test_run_login_flow_returns_session_material(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda _state: loopback.CallbackResult( - token="tok", token_type="discovery_oauth", error=None - ), + lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth", error=None), ) monkeypatch.setattr( flow.ams, diff --git a/tests/test_auth_loopback.py b/tests/test_auth_loopback.py index 2ad04b12..8574de86 100644 --- a/tests/test_auth_loopback.py +++ b/tests/test_auth_loopback.py @@ -36,11 +36,11 @@ def test_capture_returns_token_and_type(): result_box = {} def run(): - result_box["result"] = loopback.capture_callback("st8", timeout=5.0) + result_box["result"] = loopback.capture_callback(timeout=5.0) t = threading.Thread(target=run) t.start() - status = _hit("/callback?state=st8&stytch_token_type=discovery_oauth&token=tok_abc") + status = _hit("/callback?stytch_token_type=discovery_oauth&token=tok_abc") t.join(timeout=5) assert status == 200 # the callback is acknowledged with 200 OK @@ -56,12 +56,12 @@ def test_capture_ignores_unknown_paths(): result_box = {} def run(): - result_box["result"] = loopback.capture_callback("st8", timeout=5.0) + result_box["result"] = loopback.capture_callback(timeout=5.0) t = threading.Thread(target=run) t.start() assert _hit("/favicon.ico") == 404 # unknown path -> 404, capture stays open - _hit("/callback?state=st8&stytch_token_type=discovery_oauth&token=tok_late") + _hit("/callback?stytch_token_type=discovery_oauth&token=tok_late") t.join(timeout=5) result = result_box["result"] @@ -83,58 +83,41 @@ def _body(path: str) -> bytes: def test_success_page_scrubs_token_from_history(): - # The callback URL holds the single-use token + state in its query string; the - # success page must drop it from browser history rather than leave it lingering. + # The callback URL holds the single-use token in its query string; the success + # page must drop it from browser history rather than leave it lingering. def run(): - loopback.capture_callback("st8", timeout=5.0) + loopback.capture_callback(timeout=5.0) t = threading.Thread(target=run) t.start() assert _hit("/favicon.ico") == 404 # wait until the server is bound (keeps capture open) - body = _body("/callback?state=st8&stytch_token_type=discovery_oauth&token=tok_abc") + body = _body("/callback?stytch_token_type=discovery_oauth&token=tok_abc") t.join(timeout=5) - assert b"replaceState" in body # the query (token + state) is scrubbed client-side + assert b"replaceState" in body # the query (token) is scrubbed client-side assert b"tok_abc" not in body # the page never reflects the token itself -def test_capture_rejects_mismatched_state(): - # A callback with the wrong state nonce (a forged/login-CSRF attempt) is refused - # with a 400 and does not end the capture; the genuine callback then completes it. +def test_capture_rejects_callback_without_token(): + # A callback to /callback that carries no token (a stray/preflight request) is + # refused with a 400 and does not end the capture; the real callback completes it. result_box = {} def run(): - result_box["result"] = loopback.capture_callback("good", timeout=5.0) + result_box["result"] = loopback.capture_callback(timeout=5.0) t = threading.Thread(target=run) t.start() - assert _hit("/callback?state=evil&stytch_token_type=discovery_oauth&token=tok_bad") == 400 - _hit("/callback?state=good&stytch_token_type=discovery_oauth&token=tok_ok") + assert _hit("/callback?stytch_token_type=discovery_oauth") == 400 + _hit("/callback?stytch_token_type=discovery_oauth&token=tok_ok") t.join(timeout=5) result = result_box["result"] - assert result.token == "tok_ok" # the forged token was never captured - - -def test_capture_rejects_missing_state(): - # A callback with no state at all is refused (400) and never captured. - result_box = {} - - def run(): - result_box["result"] = loopback.capture_callback("good", timeout=0.8) - - t = threading.Thread(target=run) - t.start() - assert _hit("/callback?stytch_token_type=discovery_oauth&token=tok_bad") == 400 - t.join(timeout=5) - - result = result_box["result"] - assert result.error == "timeout" - assert result.token is None + assert result.token == "tok_ok" # the tokenless request never ended the capture def test_capture_times_out_without_callback(): - result = loopback.capture_callback("st8", timeout=0.3) + result = loopback.capture_callback(timeout=0.3) assert result.error == "timeout" assert result.token is None @@ -149,6 +132,6 @@ def test_capture_raises_clean_error_when_port_unavailable(monkeypatch): monkeypatch.setattr(endpoints, "LOOPBACK_PORT", port) try: with pytest.raises(APIError): - loopback.capture_callback("st8", timeout=1.0) + loopback.capture_callback(timeout=1.0) finally: busy.close() diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 14c8a3f9..c2741ae0 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -42,11 +42,17 @@ def test_quiet_suppresses_env_override_warning(monkeypatch): assert "may be rejected" not in quiet.output -def test_shell_completion_is_available(): +def test_shell_completion_is_available(monkeypatch): # add_completion=True ships `--show-completion` (and --install-completion), the - # discoverability affordance gh/kubectl/docker users reach for. + # discoverability affordance gh/kubectl/docker users reach for. Typer detects the + # active shell via shellingham, which needs a controlling TTY — absent under + # CliRunner/CI — so pin it to a known shell to test the affordance deterministically. + import typer.completion + + monkeypatch.setattr(typer.completion, "_get_shell_name", lambda: "zsh") result = runner.invoke(app, ["--show-completion"]) assert result.exit_code == 0 + assert "_aai_completion" in result.output # the emitted zsh completion script def test_global_flags_parse(): From e4837b1fdcb9d8706e8ca8a75042a0121cc382b5 Mon Sep 17 00:00:00 2001 From: Alex Kroman Date: Mon, 8 Jun 2026 11:38:44 -0700 Subject: [PATCH 2/3] Raise a clean error on keyring write failure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit set_api_key / set_session called keyring.set_password directly, so a locked or ACL-denied OS keychain (e.g. macOS errSecInvalidOwnerEdit, -25244) escaped as a raw PasswordSetError traceback during `aai login`. Wrap both writes in _keyring_set, which converts keyring.errors.KeyringError into a CLIError with a fixable suggestion — matching the project rule that commands never print tracebacks for expected failures. Co-Authored-By: Claude Opus 4.8 (1M context) --- aai_cli/config.py | 27 +++++++++++++++++++++++---- tests/test_config.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/aai_cli/config.py b/aai_cli/config.py index 8fe91dfc..94174524 100644 --- a/aai_cli/config.py +++ b/aai_cli/config.py @@ -13,7 +13,7 @@ import tomli_w from pydantic import BaseModel, ConfigDict, Field, ValidationError -from aai_cli.errors import NotAuthenticated +from aai_cli.errors import CLIError, NotAuthenticated KEYRING_SERVICE = "assemblyai-cli" ENV_API_KEY = "ASSEMBLYAI_API_KEY" @@ -123,9 +123,29 @@ def get_active_profile() -> str: return _load().active_profile or DEFAULT_PROFILE +def _keyring_set(username: str, secret: str) -> None: + """Write a secret to the OS keyring, turning backend failures into a clean error. + + A locked keychain, or an existing entry whose ACL is bound to another app, makes + keyring raise a KeyringError (e.g. macOS errSecInvalidOwnerEdit, -25244). Surface + it as a CLIError so the command prints a fixable message instead of a traceback. + """ + try: + keyring.set_password(KEYRING_SERVICE, username, secret) + except keyring.errors.KeyringError as exc: + raise CLIError( + f"Your OS keyring rejected the write ({exc}).", + error_type="keyring_error", + suggestion=( + "Unlock your keyring, or remove the stale 'assemblyai-cli' entry and " + "retry (macOS: security delete-generic-password -s assemblyai-cli)." + ), + ) from exc + + def set_api_key(profile: str, api_key: str) -> None: _validate_profile(profile) - keyring.set_password(KEYRING_SERVICE, profile, api_key) + _keyring_set(profile, api_key) cfg = _load() cfg.profiles.setdefault(profile, Profile()) if cfg.active_profile is None: @@ -170,8 +190,7 @@ def set_session(profile: str, *, session_jwt: str, session_token: str, account_i key. The JWT is short-lived; an expired session surfaces as NotAuthenticated. """ _validate_profile(profile) - keyring.set_password( - KEYRING_SERVICE, + _keyring_set( _session_username(profile), StoredSession(jwt=session_jwt, token=session_token).model_dump_json(), ) diff --git a/tests/test_config.py b/tests/test_config.py index 4f4daab4..1880853e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -58,6 +58,34 @@ def test_clear_api_key_missing_is_silent(): assert config.get_api_key("never_set") is None +def test_set_api_key_keyring_failure_raises_clean_error(monkeypatch): + # A keyring write failure (e.g. a locked/ACL-denied macOS keychain) must surface + # as a clean CLIError, never a raw PasswordSetError traceback. + import keyring.errors + + def boom(*_a, **_k): + raise keyring.errors.PasswordSetError("denied by keychain") + + monkeypatch.setattr(config.keyring, "set_password", boom) + with pytest.raises(CLIError) as exc: + config.set_api_key("default", "sk_abc") + assert "keyring" in exc.value.message.lower() + assert exc.value.suggestion is not None + + +def test_set_session_keyring_failure_raises_clean_error(monkeypatch): + import keyring.errors + + def boom(*_a, **_k): + raise keyring.errors.PasswordSetError("denied by keychain") + + monkeypatch.setattr(config.keyring, "set_password", boom) + with pytest.raises(CLIError) as exc: + config.set_session("default", session_jwt="j", session_token="t", account_id=1) + assert "keyring" in exc.value.message.lower() + assert exc.value.suggestion is not None + + def test_invalid_profile_name_rejected(): import pytest From e263530eff1c788b41eabf33ce24c8e60adc9b15 Mon Sep 17 00:00:00 2001 From: Alex Kroman Date: Mon, 8 Jun 2026 11:43:17 -0700 Subject: [PATCH 3/3] Make browser-login credential writes atomic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit persist_browser_login wrote the API key, env, and session as three separate keyring/config writes, so a mid-sequence failure (e.g. a locked keychain after the key was already stored) left a half-written profile — an API key with no session, which looks signed-in but can't reach AMS. Add config.persist_login, which runs the three writes and, on any failure, restores the pre-login snapshot: config.toml is rewritten verbatim in one atomic dump and the two keyring entries are restored best-effort (try/finally + done flag, so no blind except). Rewire persist_browser_login onto it. Co-Authored-By: Claude Opus 4.8 (1M context) --- aai_cli/config.py | 53 +++++++++++++++++++++++++++++++ aai_cli/context.py | 6 ++-- tests/test_config.py | 74 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 3 deletions(-) diff --git a/aai_cli/config.py b/aai_cli/config.py index 94174524..9d0ba783 100644 --- a/aai_cli/config.py +++ b/aai_cli/config.py @@ -143,6 +143,19 @@ def _keyring_set(username: str, secret: str) -> None: ) from exc +def _keyring_restore(username: str, prior: str | None) -> None: + """Best-effort restore of a keyring entry to a snapshot value, for login rollback. + + Suppresses keyring errors (including a delete of an absent entry) so a failed + rollback never masks the original write error that triggered it. + """ + with contextlib.suppress(keyring.errors.KeyringError): + if prior is None: + keyring.delete_password(KEYRING_SERVICE, username) + else: + keyring.set_password(KEYRING_SERVICE, username, prior) + + def set_api_key(profile: str, api_key: str) -> None: _validate_profile(profile) _keyring_set(profile, api_key) @@ -227,6 +240,46 @@ def clear_session(profile: str) -> None: _dump(cfg) +def persist_login( + profile: str, + *, + api_key: str, + env: str, + session_jwt: str, + session_token: str, + account_id: int, +) -> None: + """Atomically persist a full browser-login result (API key + env + session). + + The three writes span the keyring and config.toml, so a mid-sequence failure + (e.g. a locked keychain after the key is already stored) would otherwise leave a + half-written profile — an API key with no session, which looks signed-in but + can't reach AMS. On any failure the pre-login snapshot is restored: config.toml + is rewritten verbatim in one atomic dump, and the two keyring entries are + restored best-effort. + """ + _validate_profile(profile) + prior_api_key = keyring.get_password(KEYRING_SERVICE, profile) + prior_session = keyring.get_password(KEYRING_SERVICE, _session_username(profile)) + prior_cfg = _load() + done = False + try: + set_api_key(profile, api_key) + set_profile_env(profile, env) + set_session( + profile, + session_jwt=session_jwt, + session_token=session_token, + account_id=account_id, + ) + done = True + finally: + if not done: + _keyring_restore(profile, prior_api_key) + _keyring_restore(_session_username(profile), prior_session) + _dump(prior_cfg) + + def resolve_api_key(*, profile: str | None = None, api_key_flag: str | None = None) -> str: if api_key_flag is not None: if not api_key_flag: diff --git a/aai_cli/context.py b/aai_cli/context.py index f0ba31ed..5a25e169 100644 --- a/aai_cli/context.py +++ b/aai_cli/context.py @@ -91,10 +91,10 @@ def env_override_warning(state: AppState) -> str | None: def persist_browser_login(profile: str, env: str) -> None: """Run the browser login flow and persist its credentials for `profile`/`env`.""" result = run_login_flow() - config.set_api_key(profile, result.api_key) - config.set_profile_env(profile, env) - config.set_session( + config.persist_login( profile, + api_key=result.api_key, + env=env, session_jwt=result.session_jwt, session_token=result.session_token, account_id=result.account_id, diff --git a/tests/test_config.py b/tests/test_config.py index 1880853e..489d89bc 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -86,6 +86,80 @@ def boom(*_a, **_k): assert exc.value.suggestion is not None +def test_persist_login_writes_key_env_and_session(): + config.persist_login( + "default", + api_key="sk_new", + env="production", + session_jwt="j", + session_token="t", + account_id=7, + ) + assert config.get_api_key("default") == "sk_new" + assert config.get_profile_env("default") == "production" + assert config.get_account_id("default") == 7 + assert config.get_session("default") == {"jwt": "j", "token": "t"} + + +def _fail_on_session_write(monkeypatch): + """Make keyring writes to the session entry fail, leaving the api-key write alone. + + Reproduces the partial-write window: set_api_key succeeds, set_session does not. + """ + import keyring.errors + + real_set = config.keyring.set_password + + def selective(service, username, secret): + if username.startswith(config.SESSION_KEYRING_PREFIX + ":"): + raise keyring.errors.PasswordSetError("denied by keychain") + return real_set(service, username, secret) + + monkeypatch.setattr(config.keyring, "set_password", selective) + + +def test_persist_login_rolls_back_to_empty_when_session_write_fails(monkeypatch): + # A fresh profile: if the session write fails, nothing must persist — no orphaned + # API key or env that would make the CLI look logged in with no usable session. + _fail_on_session_write(monkeypatch) + with pytest.raises(CLIError): + config.persist_login( + "default", + api_key="sk_new", + env="production", + session_jwt="j", + session_token="t", + account_id=5, + ) + assert config.get_api_key("default") is None + assert config.get_profile_env("default") is None + assert config.get_account_id("default") is None + assert config.get_session("default") is None + + +def test_persist_login_restores_prior_credentials_when_session_write_fails(monkeypatch): + # An existing logged-in profile must be left exactly as it was if a re-login fails + # partway, rather than clobbered with the half-applied new values. + config.set_api_key("default", "sk_old") + config.set_profile_env("default", "sandbox000") + config.set_session("default", session_jwt="oldj", session_token="oldt", account_id=1) + + _fail_on_session_write(monkeypatch) + with pytest.raises(CLIError): + config.persist_login( + "default", + api_key="sk_new", + env="production", + session_jwt="j", + session_token="t", + account_id=5, + ) + assert config.get_api_key("default") == "sk_old" + assert config.get_profile_env("default") == "sandbox000" + assert config.get_account_id("default") == 1 + assert config.get_session("default") == {"jwt": "oldj", "token": "oldt"} + + def test_invalid_profile_name_rejected(): import pytest