Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions aai_cli/auth/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,28 @@
from aai_cli.auth import endpoints


def build_start_url() -> str:
def build_start_url(state: str) -> str:
"""The Stytch B2B OAuth discovery *start* URL the browser opens.

Client-side endpoint authenticated by the project's public token. After the
user authenticates with the provider, Stytch redirects to our loopback
`discovery_redirect_url` with `?stytch_token_type=discovery_oauth&token=...`.
No custom state is appended: the redirect is exact-match validated and the
server is loopback-only, single-shot.

`state` is a single-use random nonce carried as a query parameter *on the
redirect URL*. Stytch validates the redirect URL by scheme/host/port/path and
preserves query parameters through the flow, so the nonce comes back on the
callback unchanged; `loopback.capture_callback` rejects any callback whose
`state` doesn't match. This stops a malicious page from completing someone
else's login at the loopback server (login CSRF / account confusion) by
replaying an attacker-minted discovery token while `aai login` is waiting.
"""
base = (
f"{endpoints.stytch_domain()}"
f"/v1/b2b/public/oauth/{endpoints.STYTCH_OAUTH_PROVIDER}/discovery/start"
)
redirect_with_state = f"{endpoints.redirect_uri()}?{urlencode({'state': state})}"
params = {
"public_token": endpoints.stytch_public_token(),
"discovery_redirect_url": endpoints.redirect_uri(),
"discovery_redirect_url": redirect_with_state,
}
return f"{base}?{urlencode(params)}"
13 changes: 9 additions & 4 deletions aai_cli/auth/flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import secrets
import webbrowser
from dataclasses import dataclass

Expand Down Expand Up @@ -97,8 +98,8 @@ def _open_browser(url: str) -> None:
)


def _capture() -> loopback.CallbackResult:
return loopback.capture_callback()
def _capture(state: str) -> loopback.CallbackResult:
return loopback.capture_callback(state)


def _reusable_cli_key(token: _Token) -> str | None:
Expand Down Expand Up @@ -137,8 +138,12 @@ def find_or_create_cli_key(account_id: int, session_jwt: str) -> str:

def run_login_flow() -> LoginResult:
"""Drive the full browser + AMS login and return a LoginResult."""
_open_browser(discovery.build_start_url())
result = _capture()
# A fresh single-use nonce binds the browser hand-off to this loopback capture:
# build_start_url carries it to Stytch, capture_callback only accepts a callback
# that echoes it back. token_urlsafe(32) is 256 bits of entropy — unguessable.
state = secrets.token_urlsafe(32) # pragma: no mutate (any large nonce works; 32 isn't magic)
_open_browser(discovery.build_start_url(state))
result = _capture(state)

if result.error == "timeout":
raise APIError(
Expand Down
28 changes: 26 additions & 2 deletions aai_cli/auth/loopback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import secrets
import threading
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, HTTPServer
Expand All @@ -8,8 +9,16 @@
from aai_cli.auth import endpoints
from aai_cli.errors import APIError

# The callback URL carries the single-use OAuth token (and the state nonce) in its
# query string, so it would otherwise linger in the browser's history and address
# bar. Scrub it from the current history entry with replaceState the moment the page
# loads — no extra request to race the server shutdown, unlike a redirect. The token
# is already spent server-side by the time the user reads this, but keeping it out of
# history is the OAuth-for-native-apps (RFC 8252) hygiene. The page reflects no query
# data, so there is nothing to inject; the script is a static literal.
_SUCCESS_HTML = (
b"<html><body style='font-family:sans-serif'>"
b"<script>history.replaceState(null,'',location.pathname)</script>"
b"<h2>Signed in.</h2><p>You can close this tab and return to the terminal.</p>"
b"</body></html>"
)
Expand All @@ -22,10 +31,17 @@ class CallbackResult:
error: str | None = None


def capture_callback(timeout: float = 120.0) -> CallbackResult:
def capture_callback(
expected_state: str,
timeout: float = 120.0, # pragma: no mutate (default window; tests pass explicit timeouts)
) -> CallbackResult:
"""Bind the fixed loopback port, capture one OAuth callback, return its token.

Returns a CallbackResult; `error="timeout"` if no callback arrives in time.
Only a callback whose `state` query parameter equals `expected_state` is
accepted; any other request (wrong/missing state, or a different path) gets a
4xx and the server keeps waiting, so a forged callback can't complete someone
else's login. Returns a CallbackResult; `error="timeout"` if no matching
callback arrives in time.
"""
result = CallbackResult()
done = threading.Event()
Expand All @@ -38,6 +54,14 @@ def do_GET(self) -> None: # stdlib API name
self.end_headers()
return
qs = parse_qs(parsed.query)
state = next(iter(qs.get("state", [])), None)
# Constant-time compare so a forged callback can't probe the nonce by
# timing. A mismatch is rejected without ending the capture: the real
# callback can still arrive (otherwise it falls through to timeout).
if state is None or not secrets.compare_digest(state, expected_state):
self.send_response(400)
self.end_headers()
return
result.token = next(iter(qs.get("token", [])), None)
result.token_type = next(iter(qs.get("stytch_token_type", [])), None)
self.send_response(200)
Expand Down
14 changes: 13 additions & 1 deletion aai_cli/init/scaffold.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# aai_cli/init/scaffold.py
from __future__ import annotations

import os
import stat
from importlib import resources
from pathlib import Path
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -84,5 +86,15 @@ def scaffold(
_copy_tree(root, target)
lines = [f"ASSEMBLYAI_API_KEY={api_key or PLACEHOLDER_KEY}"]
lines += [f"{k}={v}" for k, v in (env_vars or {}).items()]
(target / ".env").write_text("\n".join(lines) + "\n")
env_path = target / ".env"
# The .env holds the real API key, so create it readable/writable by the owner
# only (0600) instead of the umask default (commonly 0644) — otherwise the key
# would be world/group-readable on a shared host. Open with the 0600 mode so the
# secret is never briefly world-readable; the explicit chmod then also tightens an
# existing file when `aai init --force` overwrites one (O_CREAT's mode is ignored
# for a file that already exists).
fd = os.open(env_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IRUSR | stat.S_IWUSR)
with os.fdopen(fd, "w") as fh:
fh.write("\n".join(lines) + "\n")
env_path.chmod(stat.S_IRUSR | stat.S_IWUSR)
return target
17 changes: 14 additions & 3 deletions tests/test_auth_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,26 @@


def test_build_start_url_targets_b2b_discovery_for_provider():
url = discovery.build_start_url()
url = discovery.build_start_url("state-xyz")
parsed = urlparse(url)
assert parsed.scheme == "https"
assert parsed.path == "/v1/b2b/public/oauth/google/discovery/start"
assert url.startswith(endpoints.stytch_domain())


def test_build_start_url_includes_public_token_and_redirect():
url = discovery.build_start_url()
url = discovery.build_start_url("state-xyz")
qs = parse_qs(urlparse(url).query)
assert qs["public_token"] == [endpoints.stytch_public_token()]
assert qs["discovery_redirect_url"] == ["http://127.0.0.1:8585/callback"]
# The state nonce rides as a query param on the (still path-exact) redirect URL.
assert qs["discovery_redirect_url"] == ["http://127.0.0.1:8585/callback?state=state-xyz"]


def test_build_start_url_carries_state_into_redirect():
url = discovery.build_start_url("nonce-123")
redirect = parse_qs(urlparse(url).query)["discovery_redirect_url"][0]
redirect_parsed = urlparse(redirect)
# The redirect path stays exactly /callback (what Stytch validates); only the
# query string gains the nonce, so redirect-URL matching is unaffected.
assert redirect_parsed.path == "/callback"
assert parse_qs(redirect_parsed.query)["state"] == ["nonce-123"]
65 changes: 54 additions & 11 deletions tests/test_auth_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,57 @@ def test_find_or_create_raises_when_no_projects(monkeypatch):

def test_capture_delegates_to_loopback(monkeypatch):
sentinel = loopback.CallbackResult(token="tok", token_type="discovery_oauth")
monkeypatch.setattr(flow.loopback, "capture_callback", lambda: sentinel)
assert flow._capture() is sentinel
captured = {}

def fake_capture(state):
captured["state"] = state
return sentinel

monkeypatch.setattr(flow.loopback, "capture_callback", fake_capture)
assert flow._capture("nonce-1") is sentinel
assert captured["state"] == "nonce-1" # the nonce is forwarded to the loopback


def test_run_login_flow_binds_state_nonce(monkeypatch):
# The nonce build_start_url() carries to Stytch must be the same one the loopback
# capture is told to expect, or a genuine callback would never be accepted.
seen = {}
monkeypatch.setattr(
flow.discovery, "build_start_url", lambda state: seen.setdefault("url_state", state) or "u"
)
monkeypatch.setattr(flow, "_open_browser", lambda url: None)

def fake_capture(state):
seen["capture_state"] = state
return loopback.CallbackResult(token="tok", token_type="discovery_oauth")

monkeypatch.setattr(flow, "_capture", fake_capture)
monkeypatch.setattr(
flow.ams,
"discover",
lambda token: {
"organizations": [{"organization_id": "org_1"}],
"intermediate_session_token": "ist",
},
)
monkeypatch.setattr(
flow.ams,
"exchange",
lambda ist, org: {"account": {"id": 9}, "session_jwt": "jwt", "session_token": "t"},
)
monkeypatch.setattr(flow, "find_or_create_cli_key", lambda acct, jwt: "sk_final")

flow.run_login_flow()
assert seen["url_state"] == seen["capture_state"]
assert len(seen["capture_state"]) >= 32 # token_urlsafe(32) -> unguessable nonce


def test_run_login_flow_rejects_wrong_token_type(monkeypatch):
monkeypatch.setattr(flow, "_open_browser", lambda url: None)
monkeypatch.setattr(
flow,
"_capture",
lambda: loopback.CallbackResult(token="tok", token_type="something_else"),
lambda _state: loopback.CallbackResult(token="tok", token_type="something_else"),
)
with pytest.raises(APIError) as exc:
flow.run_login_flow()
Expand All @@ -96,7 +137,7 @@ def test_run_login_flow_happy_path(monkeypatch):
monkeypatch.setattr(
flow,
"_capture",
lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
)
monkeypatch.setattr(
flow.ams,
Expand All @@ -120,7 +161,7 @@ def test_run_login_flow_happy_path(monkeypatch):

def test_run_login_flow_timeout_raises(monkeypatch):
monkeypatch.setattr(flow, "_open_browser", lambda url: None)
monkeypatch.setattr(flow, "_capture", lambda: loopback.CallbackResult(error="timeout"))
monkeypatch.setattr(flow, "_capture", lambda _state: loopback.CallbackResult(error="timeout"))
with pytest.raises(APIError) as exc:
flow.run_login_flow()
assert exc.value.message == "Login timed out waiting for the browser."
Expand Down Expand Up @@ -162,7 +203,7 @@ def test_run_login_flow_uses_exchange_account(monkeypatch):
monkeypatch.setattr(
flow,
"_capture",
lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
)
monkeypatch.setattr(
flow.ams,
Expand Down Expand Up @@ -193,7 +234,7 @@ def test_run_login_flow_multi_org_notes_selection(monkeypatch, capsys):
monkeypatch.setattr(
flow,
"_capture",
lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
)
monkeypatch.setattr(
flow.ams,
Expand Down Expand Up @@ -233,7 +274,7 @@ def test_run_login_flow_missing_session_token_raises_api_error(monkeypatch):
monkeypatch.setattr(
flow,
"_capture",
lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
)
monkeypatch.setattr(
flow.ams,
Expand All @@ -249,7 +290,7 @@ def test_run_login_flow_org_missing_id_raises_api_error(monkeypatch):
monkeypatch.setattr(
flow,
"_capture",
lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
)
monkeypatch.setattr(
flow.ams,
Expand All @@ -268,7 +309,7 @@ def test_run_login_flow_zero_orgs_raises(monkeypatch):
monkeypatch.setattr(
flow,
"_capture",
lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
lambda _state: loopback.CallbackResult(token="tok", token_type="discovery_oauth"),
)
monkeypatch.setattr(
flow.ams,
Expand All @@ -288,7 +329,9 @@ def test_run_login_flow_returns_session_material(monkeypatch):
monkeypatch.setattr(
flow,
"_capture",
lambda: loopback.CallbackResult(token="tok", token_type="discovery_oauth", error=None),
lambda _state: loopback.CallbackResult(
token="tok", token_type="discovery_oauth", error=None
),
)
monkeypatch.setattr(
flow.ams,
Expand Down
Loading
Loading