Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions aai_cli/auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,47 @@
import os

from aai_cli import environments
from aai_cli.errors import CLIError

# Constant across environments.
STYTCH_OAUTH_PROVIDER = "google"
CLI_TOKEN_NAME = "AssemblyAI CLI" # noqa: S105 - display name, not a credential

# Fixed loopback (Stytch does exact-match redirect validation; 8585 is registered).
LOOPBACK_HOST = "127.0.0.1"
LOOPBACK_PORT = int(os.environ.get("AAI_AUTH_PORT", "8585"))
LOOPBACK_PATH = "/callback"
_DEFAULT_LOOPBACK_PORT = 8585
_MAX_PORT = 65535 # highest valid TCP port


def _invalid_auth_port(raw: str) -> CLIError:
return CLIError(
f"AAI_AUTH_PORT must be a port number in 1-65535, got {raw!r}.",
error_type="invalid_env",
exit_code=2,
suggestion="Unset AAI_AUTH_PORT, or set it to a free port in 1-65535.",
)


def loopback_port() -> int:
"""The loopback callback port, overridable via ``AAI_AUTH_PORT`` (dev/test only).

Resolved lazily — never at import — and validated, so a malformed override
surfaces as a clean CLIError on the login path instead of the raw ``ValueError``
a module-level ``int(...)`` would raise. This module sits on the CLI's import hot
path, so that ValueError would otherwise crash *every* ``aai`` command (even
``--help``), not just ``aai login``.
"""
raw = os.environ.get("AAI_AUTH_PORT")
if raw is None:
return _DEFAULT_LOOPBACK_PORT
try:
port = int(raw)
except ValueError as exc:
raise _invalid_auth_port(raw) from exc
if not 1 <= port <= _MAX_PORT:
raise _invalid_auth_port(raw)
return port


# Environment-specific values resolve from the active environment (see
Expand All @@ -36,4 +68,4 @@ def signup_url() -> str:

def redirect_uri() -> str:
"""The exact loopback redirect URL registered in Stytch."""
return f"http://{LOOPBACK_HOST}:{LOOPBACK_PORT}{LOOPBACK_PATH}"
return f"http://{LOOPBACK_HOST}:{loopback_port()}{LOOPBACK_PATH}"
5 changes: 3 additions & 2 deletions aai_cli/auth/loopback.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@ def do_GET(self) -> None: # stdlib API name
def log_message(self, format: str, *args: object) -> None: # silence stderr logging
pass

port = endpoints.loopback_port()
try:
server = HTTPServer((endpoints.LOOPBACK_HOST, endpoints.LOOPBACK_PORT), Handler)
server = HTTPServer((endpoints.LOOPBACK_HOST, port), Handler)
except OSError as exc:
raise APIError(
f"Could not start the login callback server on "
f"{endpoints.LOOPBACK_HOST}:{endpoints.LOOPBACK_PORT} ({exc}). "
f"{endpoints.LOOPBACK_HOST}:{port} ({exc}). "
"Close whatever is using that port and run 'aai login' again."
) from exc
thread = threading.Thread(target=server.serve_forever, daemon=True)
Expand Down
23 changes: 15 additions & 8 deletions aai_cli/init/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,21 @@ def spawn(
used to capture cloudflared's output for URL discovery. Without it, stdio is inherited.
"""
if log_path is not None:
return subprocess.Popen(
command,
cwd=cwd,
env=env,
stdout=log_path.open("w"),
stderr=subprocess.STDOUT,
text=True,
)
# The child gets its own dup of the fd once Popen returns, so close the
# parent's handle straight away instead of leaking it for the (long-lived)
# process's whole lifetime.
log = log_path.open("w")
try:
return subprocess.Popen(
command,
cwd=cwd,
env=env,
stdout=log,
stderr=subprocess.STDOUT,
text=True,
)
finally:
log.close()
return subprocess.Popen(command, cwd=cwd, env=env, text=True)


Expand Down
37 changes: 36 additions & 1 deletion tests/test_auth_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import pytest

from aai_cli.auth import endpoints
from aai_cli.errors import CLIError


def test_redirect_uri_is_fixed_loopback():
Expand All @@ -19,5 +22,37 @@ def test_constants_are_environment_independent():


def test_env_override_changes_redirect_uri(monkeypatch):
monkeypatch.setattr(endpoints, "LOOPBACK_PORT", 9999)
monkeypatch.setenv("AAI_AUTH_PORT", "9999")
assert endpoints.redirect_uri() == "http://127.0.0.1:9999/callback"


def test_loopback_port_rejects_non_integer(monkeypatch):
# A typo'd AAI_AUTH_PORT must surface as a clean CLIError on the login path,
# not a raw ValueError that would crash every command at import time.
monkeypatch.setenv("AAI_AUTH_PORT", "abc")
with pytest.raises(CLIError) as excinfo:
endpoints.loopback_port()
assert excinfo.value.exit_code == 2
assert "AAI_AUTH_PORT" in str(excinfo.value)


def test_loopback_port_rejects_above_max(monkeypatch):
# 65535 is the highest valid TCP port; one past it must be rejected.
monkeypatch.setenv("AAI_AUTH_PORT", "65536")
with pytest.raises(CLIError):
endpoints.loopback_port()


def test_loopback_port_accepts_boundary_values(monkeypatch):
# The valid range is exactly 1..65535 inclusive.
monkeypatch.setenv("AAI_AUTH_PORT", "1")
assert endpoints.loopback_port() == 1
monkeypatch.setenv("AAI_AUTH_PORT", "65535")
assert endpoints.loopback_port() == 65535


def test_loopback_port_rejects_zero(monkeypatch):
# Port 0 (OS-assign) is meaningless for a fixed, pre-registered redirect URI.
monkeypatch.setenv("AAI_AUTH_PORT", "0")
with pytest.raises(CLIError):
endpoints.loopback_port()
8 changes: 4 additions & 4 deletions tests/test_auth_loopback.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _unique_loopback_port(monkeypatch):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
probe.bind((endpoints.LOOPBACK_HOST, 0))
port = probe.getsockname()[1]
monkeypatch.setattr(endpoints, "LOOPBACK_PORT", port)
monkeypatch.setenv("AAI_AUTH_PORT", str(port))


def _hit(path: str) -> int | None:
Expand All @@ -40,7 +40,7 @@ def _hit(path: str) -> int | None:
# Retry briefly until the server thread is bound.
for _ in range(50):
conn = http.client.HTTPConnection(
endpoints.LOOPBACK_HOST, endpoints.LOOPBACK_PORT, timeout=2
endpoints.LOOPBACK_HOST, endpoints.loopback_port(), timeout=2
)
try:
conn.request("GET", path)
Expand Down Expand Up @@ -96,7 +96,7 @@ def _body(path: str) -> bytes:
Callers first confirm the server is bound via `_hit`, so no readiness loop is
needed here.
"""
conn = http.client.HTTPConnection(endpoints.LOOPBACK_HOST, endpoints.LOOPBACK_PORT, timeout=2)
conn = http.client.HTTPConnection(endpoints.LOOPBACK_HOST, endpoints.loopback_port(), timeout=2)
try:
conn.request("GET", path)
return conn.getresponse().read()
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_capture_raises_clean_error_when_port_unavailable(monkeypatch):
busy.bind((endpoints.LOOPBACK_HOST, 0))
busy.listen(1)
port = busy.getsockname()[1]
monkeypatch.setattr(endpoints, "LOOPBACK_PORT", port)
monkeypatch.setenv("AAI_AUTH_PORT", str(port))
try:
with pytest.raises(APIError):
loopback.capture_callback(timeout=1.0)
Expand Down
9 changes: 6 additions & 3 deletions tests/test_init_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,12 @@ def fake_popen(cmd, **kwargs):
runner.spawn(["cloudflared"], cwd=tmp_path, log_path=log)
assert captured["kwargs"]["stderr"] is runner.subprocess.STDOUT
assert captured["kwargs"]["text"] is True
# stdout is an open writable handle to the log file
assert captured["kwargs"]["stdout"].writable()
captured["kwargs"]["stdout"].close()
stdout = captured["kwargs"]["stdout"]
# spawn writes the child's stdout to the log file...
assert stdout.name == str(log)
# ...and closes the parent's handle once Popen returns (the child keeps its dup),
# so the file descriptor isn't leaked for the process's whole lifetime.
assert stdout.closed is True


def test_run_server_passes_command_and_env(monkeypatch):
Expand Down
Loading