diff --git a/aai_cli/auth/discovery.py b/aai_cli/auth/discovery.py index f1c46916..4a52f614 100644 --- a/aai_cli/auth/discovery.py +++ b/aai_cli/auth/discovery.py @@ -5,21 +5,28 @@ from aai_cli.auth import endpoints -def build_start_url() -> str: +def build_start_url(state: str) -> str: """The Stytch B2B OAuth discovery *start* URL the browser opens. Client-side endpoint authenticated by the project's public token. After the user authenticates with the provider, Stytch redirects to our loopback `discovery_redirect_url` with `?stytch_token_type=discovery_oauth&token=...`. - No custom state is appended: the redirect is exact-match validated and the - server is loopback-only, single-shot. + + `state` is a single-use random nonce carried as a query parameter *on the + redirect URL*. Stytch validates the redirect URL by scheme/host/port/path and + preserves query parameters through the flow, so the nonce comes back on the + callback unchanged; `loopback.capture_callback` rejects any callback whose + `state` doesn't match. This stops a malicious page from completing someone + else's login at the loopback server (login CSRF / account confusion) by + replaying an attacker-minted discovery token while `aai login` is waiting. """ base = ( f"{endpoints.stytch_domain()}" f"/v1/b2b/public/oauth/{endpoints.STYTCH_OAUTH_PROVIDER}/discovery/start" ) + redirect_with_state = f"{endpoints.redirect_uri()}?{urlencode({'state': state})}" params = { "public_token": endpoints.stytch_public_token(), - "discovery_redirect_url": endpoints.redirect_uri(), + "discovery_redirect_url": redirect_with_state, } return f"{base}?{urlencode(params)}" diff --git a/aai_cli/auth/flow.py b/aai_cli/auth/flow.py index c8a086fe..7a9d32f4 100644 --- a/aai_cli/auth/flow.py +++ b/aai_cli/auth/flow.py @@ -1,5 +1,6 @@ from __future__ import annotations +import secrets import webbrowser from dataclasses import dataclass @@ -97,8 +98,8 @@ def _open_browser(url: str) -> None: ) -def _capture() -> loopback.CallbackResult: - return loopback.capture_callback() +def _capture(state: str) -> loopback.CallbackResult: + return loopback.capture_callback(state) def _reusable_cli_key(token: _Token) -> str | None: @@ -137,8 +138,12 @@ def find_or_create_cli_key(account_id: int, session_jwt: str) -> str: def run_login_flow() -> LoginResult: """Drive the full browser + AMS login and return a LoginResult.""" - _open_browser(discovery.build_start_url()) - result = _capture() + # A fresh single-use nonce binds the browser hand-off to this loopback capture: + # build_start_url carries it to Stytch, capture_callback only accepts a callback + # that echoes it back. token_urlsafe(32) is 256 bits of entropy — unguessable. + state = secrets.token_urlsafe(32) # pragma: no mutate (any large nonce works; 32 isn't magic) + _open_browser(discovery.build_start_url(state)) + result = _capture(state) if result.error == "timeout": raise APIError( diff --git a/aai_cli/auth/loopback.py b/aai_cli/auth/loopback.py index 7042031d..fa88170e 100644 --- a/aai_cli/auth/loopback.py +++ b/aai_cli/auth/loopback.py @@ -1,5 +1,6 @@ from __future__ import annotations +import secrets import threading from dataclasses import dataclass from http.server import BaseHTTPRequestHandler, HTTPServer @@ -8,8 +9,16 @@ from aai_cli.auth import endpoints from aai_cli.errors import APIError +# The callback URL carries the single-use OAuth token (and the state nonce) in its +# query string, so it would otherwise linger in the browser's history and address +# bar. Scrub it from the current history entry with replaceState the moment the page +# loads — no extra request to race the server shutdown, unlike a redirect. The token +# is already spent server-side by the time the user reads this, but keeping it out of +# history is the OAuth-for-native-apps (RFC 8252) hygiene. The page reflects no query +# data, so there is nothing to inject; the script is a static literal. _SUCCESS_HTML = ( b"
" + b"" b"You can close this tab and return to the terminal.
" b"" ) @@ -22,10 +31,17 @@ class CallbackResult: error: str | None = None -def capture_callback(timeout: float = 120.0) -> CallbackResult: +def capture_callback( + expected_state: str, + timeout: float = 120.0, # pragma: no mutate (default window; tests pass explicit timeouts) +) -> CallbackResult: """Bind the fixed loopback port, capture one OAuth callback, return its token. - Returns a CallbackResult; `error="timeout"` if no callback arrives in time. + Only a callback whose `state` query parameter equals `expected_state` is + accepted; any other request (wrong/missing state, or a different path) gets a + 4xx and the server keeps waiting, so a forged callback can't complete someone + else's login. Returns a CallbackResult; `error="timeout"` if no matching + callback arrives in time. """ result = CallbackResult() done = threading.Event() @@ -38,6 +54,14 @@ def do_GET(self) -> None: # stdlib API name self.end_headers() return qs = parse_qs(parsed.query) + state = next(iter(qs.get("state", [])), None) + # Constant-time compare so a forged callback can't probe the nonce by + # timing. A mismatch is rejected without ending the capture: the real + # callback can still arrive (otherwise it falls through to timeout). + if state is None or not secrets.compare_digest(state, expected_state): + self.send_response(400) + self.end_headers() + return result.token = next(iter(qs.get("token", [])), None) result.token_type = next(iter(qs.get("stytch_token_type", [])), None) self.send_response(200) diff --git a/aai_cli/init/scaffold.py b/aai_cli/init/scaffold.py index 49988e90..813c8259 100644 --- a/aai_cli/init/scaffold.py +++ b/aai_cli/init/scaffold.py @@ -1,6 +1,8 @@ # aai_cli/init/scaffold.py from __future__ import annotations +import os +import stat from importlib import resources from pathlib import Path from typing import TYPE_CHECKING @@ -84,5 +86,15 @@ def scaffold( _copy_tree(root, target) lines = [f"ASSEMBLYAI_API_KEY={api_key or PLACEHOLDER_KEY}"] lines += [f"{k}={v}" for k, v in (env_vars or {}).items()] - (target / ".env").write_text("\n".join(lines) + "\n") + env_path = target / ".env" + # The .env holds the real API key, so create it readable/writable by the owner + # only (0600) instead of the umask default (commonly 0644) — otherwise the key + # would be world/group-readable on a shared host. Open with the 0600 mode so the + # secret is never briefly world-readable; the explicit chmod then also tightens an + # existing file when `aai init --force` overwrites one (O_CREAT's mode is ignored + # for a file that already exists). + fd = os.open(env_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IRUSR | stat.S_IWUSR) + with os.fdopen(fd, "w") as fh: + fh.write("\n".join(lines) + "\n") + env_path.chmod(stat.S_IRUSR | stat.S_IWUSR) return target diff --git a/tests/test_auth_discovery.py b/tests/test_auth_discovery.py index d19383b5..1fe1f49b 100644 --- a/tests/test_auth_discovery.py +++ b/tests/test_auth_discovery.py @@ -4,7 +4,7 @@ def test_build_start_url_targets_b2b_discovery_for_provider(): - url = discovery.build_start_url() + url = discovery.build_start_url("state-xyz") parsed = urlparse(url) assert parsed.scheme == "https" assert parsed.path == "/v1/b2b/public/oauth/google/discovery/start" @@ -12,7 +12,18 @@ def test_build_start_url_targets_b2b_discovery_for_provider(): def test_build_start_url_includes_public_token_and_redirect(): - url = discovery.build_start_url() + url = discovery.build_start_url("state-xyz") qs = parse_qs(urlparse(url).query) assert qs["public_token"] == [endpoints.stytch_public_token()] - assert qs["discovery_redirect_url"] == ["http://127.0.0.1:8585/callback"] + # The state nonce rides as a query param on the (still path-exact) redirect URL. + assert qs["discovery_redirect_url"] == ["http://127.0.0.1:8585/callback?state=state-xyz"] + + +def test_build_start_url_carries_state_into_redirect(): + url = discovery.build_start_url("nonce-123") + redirect = parse_qs(urlparse(url).query)["discovery_redirect_url"][0] + redirect_parsed = urlparse(redirect) + # The redirect path stays exactly /callback (what Stytch validates); only the + # query string gains the nonce, so redirect-URL matching is unaffected. + assert redirect_parsed.path == "/callback" + assert parse_qs(redirect_parsed.query)["state"] == ["nonce-123"] diff --git a/tests/test_auth_flow.py b/tests/test_auth_flow.py index f4e96c51..aeaadf8a 100644 --- a/tests/test_auth_flow.py +++ b/tests/test_auth_flow.py @@ -74,8 +74,49 @@ def test_find_or_create_raises_when_no_projects(monkeypatch): def test_capture_delegates_to_loopback(monkeypatch): sentinel = loopback.CallbackResult(token="tok", token_type="discovery_oauth") - monkeypatch.setattr(flow.loopback, "capture_callback", lambda: sentinel) - assert flow._capture() is sentinel + captured = {} + + def fake_capture(state): + captured["state"] = state + return sentinel + + monkeypatch.setattr(flow.loopback, "capture_callback", fake_capture) + assert flow._capture("nonce-1") is sentinel + assert captured["state"] == "nonce-1" # the nonce is forwarded to the loopback + + +def test_run_login_flow_binds_state_nonce(monkeypatch): + # The nonce build_start_url() carries to Stytch must be the same one the loopback + # capture is told to expect, or a genuine callback would never be accepted. + seen = {} + monkeypatch.setattr( + flow.discovery, "build_start_url", lambda state: seen.setdefault("url_state", state) or "u" + ) + monkeypatch.setattr(flow, "_open_browser", lambda url: None) + + def fake_capture(state): + seen["capture_state"] = state + return loopback.CallbackResult(token="tok", token_type="discovery_oauth") + + monkeypatch.setattr(flow, "_capture", fake_capture) + monkeypatch.setattr( + flow.ams, + "discover", + lambda token: { + "organizations": [{"organization_id": "org_1"}], + "intermediate_session_token": "ist", + }, + ) + monkeypatch.setattr( + flow.ams, + "exchange", + lambda ist, org: {"account": {"id": 9}, "session_jwt": "jwt", "session_token": "t"}, + ) + monkeypatch.setattr(flow, "find_or_create_cli_key", lambda acct, jwt: "sk_final") + + flow.run_login_flow() + assert seen["url_state"] == seen["capture_state"] + assert len(seen["capture_state"]) >= 32 # token_urlsafe(32) -> unguessable nonce def test_run_login_flow_rejects_wrong_token_type(monkeypatch): @@ -83,7 +124,7 @@ def test_run_login_flow_rejects_wrong_token_type(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda: loopback.CallbackResult(token="tok", token_type="something_else"), + lambda _state: loopback.CallbackResult(token="tok", token_type="something_else"), ) with pytest.raises(APIError) as exc: flow.run_login_flow() @@ -96,7 +137,7 @@ def test_run_login_flow_happy_path(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -120,7 +161,7 @@ def test_run_login_flow_happy_path(monkeypatch): def test_run_login_flow_timeout_raises(monkeypatch): monkeypatch.setattr(flow, "_open_browser", lambda url: None) - monkeypatch.setattr(flow, "_capture", lambda: loopback.CallbackResult(error="timeout")) + monkeypatch.setattr(flow, "_capture", lambda _state: loopback.CallbackResult(error="timeout")) with pytest.raises(APIError) as exc: flow.run_login_flow() assert exc.value.message == "Login timed out waiting for the browser." @@ -162,7 +203,7 @@ def test_run_login_flow_uses_exchange_account(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -193,7 +234,7 @@ def test_run_login_flow_multi_org_notes_selection(monkeypatch, capsys): monkeypatch.setattr( flow, "_capture", - lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -233,7 +274,7 @@ def test_run_login_flow_missing_session_token_raises_api_error(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -249,7 +290,7 @@ def test_run_login_flow_org_missing_id_raises_api_error(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -268,7 +309,7 @@ def test_run_login_flow_zero_orgs_raises(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), + lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"), ) monkeypatch.setattr( flow.ams, @@ -288,7 +329,9 @@ def test_run_login_flow_returns_session_material(monkeypatch): monkeypatch.setattr( flow, "_capture", - lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth", error=None), + lambda _state: loopback.CallbackResult( + token="tok", token_type="discovery_oauth", error=None + ), ) monkeypatch.setattr( flow.ams, diff --git a/tests/test_auth_loopback.py b/tests/test_auth_loopback.py index 6be0d088..2ad04b12 100644 --- a/tests/test_auth_loopback.py +++ b/tests/test_auth_loopback.py @@ -36,11 +36,11 @@ def test_capture_returns_token_and_type(): result_box = {} def run(): - result_box["result"] = loopback.capture_callback(timeout=5.0) + result_box["result"] = loopback.capture_callback("st8", timeout=5.0) t = threading.Thread(target=run) t.start() - status = _hit("/callback?stytch_token_type=discovery_oauth&token=tok_abc") + status = _hit("/callback?state=st8&stytch_token_type=discovery_oauth&token=tok_abc") t.join(timeout=5) assert status == 200 # the callback is acknowledged with 200 OK @@ -56,20 +56,85 @@ def test_capture_ignores_unknown_paths(): result_box = {} def run(): - result_box["result"] = loopback.capture_callback(timeout=5.0) + result_box["result"] = loopback.capture_callback("st8", timeout=5.0) t = threading.Thread(target=run) t.start() assert _hit("/favicon.ico") == 404 # unknown path -> 404, capture stays open - _hit("/callback?stytch_token_type=discovery_oauth&token=tok_late") + _hit("/callback?state=st8&stytch_token_type=discovery_oauth&token=tok_late") t.join(timeout=5) result = result_box["result"] assert result.token == "tok_late" +def _body(path: str) -> bytes: + """Fetch `path` once (no retry) and return the response body. + + Callers first confirm the server is bound via `_hit`, so no readiness loop is + needed here. + """ + conn = http.client.HTTPConnection(endpoints.LOOPBACK_HOST, endpoints.LOOPBACK_PORT, timeout=2) + try: + conn.request("GET", path) + return conn.getresponse().read() + finally: + conn.close() + + +def test_success_page_scrubs_token_from_history(): + # The callback URL holds the single-use token + state in its query string; the + # success page must drop it from browser history rather than leave it lingering. + def run(): + loopback.capture_callback("st8", timeout=5.0) + + t = threading.Thread(target=run) + t.start() + assert _hit("/favicon.ico") == 404 # wait until the server is bound (keeps capture open) + body = _body("/callback?state=st8&stytch_token_type=discovery_oauth&token=tok_abc") + t.join(timeout=5) + + assert b"replaceState" in body # the query (token + state) is scrubbed client-side + assert b"tok_abc" not in body # the page never reflects the token itself + + +def test_capture_rejects_mismatched_state(): + # A callback with the wrong state nonce (a forged/login-CSRF attempt) is refused + # with a 400 and does not end the capture; the genuine callback then completes it. + result_box = {} + + def run(): + result_box["result"] = loopback.capture_callback("good", timeout=5.0) + + t = threading.Thread(target=run) + t.start() + assert _hit("/callback?state=evil&stytch_token_type=discovery_oauth&token=tok_bad") == 400 + _hit("/callback?state=good&stytch_token_type=discovery_oauth&token=tok_ok") + t.join(timeout=5) + + result = result_box["result"] + assert result.token == "tok_ok" # the forged token was never captured + + +def test_capture_rejects_missing_state(): + # A callback with no state at all is refused (400) and never captured. + result_box = {} + + def run(): + result_box["result"] = loopback.capture_callback("good", timeout=0.8) + + t = threading.Thread(target=run) + t.start() + assert _hit("/callback?stytch_token_type=discovery_oauth&token=tok_bad") == 400 + t.join(timeout=5) + + result = result_box["result"] + assert result.error == "timeout" + assert result.token is None + + def test_capture_times_out_without_callback(): - result = loopback.capture_callback(timeout=0.3) + result = loopback.capture_callback("st8", timeout=0.3) assert result.error == "timeout" assert result.token is None @@ -84,6 +149,6 @@ def test_capture_raises_clean_error_when_port_unavailable(monkeypatch): monkeypatch.setattr(endpoints, "LOOPBACK_PORT", port) try: with pytest.raises(APIError): - loopback.capture_callback(timeout=1.0) + loopback.capture_callback("st8", timeout=1.0) finally: busy.close() diff --git a/tests/test_init_scaffold.py b/tests/test_init_scaffold.py index c3515793..8678684b 100644 --- a/tests/test_init_scaffold.py +++ b/tests/test_init_scaffold.py @@ -1,9 +1,33 @@ +import stat + import pytest from aai_cli.errors import CLIError from aai_cli.init import scaffold +def test_scaffold_env_is_owner_only_readable(tmp_path): + # The .env holds the real API key, so it must not be world/group-readable. (CI is + # POSIX; the project gates its Windows-specific paths in scripts/check.sh, not here.) + target = tmp_path / "app" + scaffold.scaffold("audio-transcription", target, api_key="sk-real-key") + mode = stat.S_IMODE((target / ".env").stat().st_mode) + assert mode == 0o600 + assert not mode & (stat.S_IRGRP | stat.S_IROTH) # no group/other read of the key + + +def test_scaffold_tightens_existing_env_on_overwrite(tmp_path): + # `aai init --force` re-scaffolds over an existing project; a stale, loosely + # permissioned .env must be tightened to 0600 rather than left as-is. + target = tmp_path / "app" + target.mkdir() + stale = target / ".env" + stale.write_text("ASSEMBLYAI_API_KEY=old\n") + stale.chmod(0o644) + scaffold.scaffold("audio-transcription", target, api_key="sk-real-key") + assert stat.S_IMODE(stale.stat().st_mode) == 0o600 + + def test_scaffold_copies_files_and_renames_dotfiles(tmp_path): target = tmp_path / "app" scaffold.scaffold("audio-transcription", target, api_key="sk-real-key")