Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions aai_cli/auth/ams.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ def _json_object_list_or_raise(resp: httpx.Response) -> list[dict[str, object]]:
return objects


def _scalar_params(filters: dict[str, object]) -> dict[str, str | int | float | bool]:
"""Keep only scalar filter values, suitable as query-string params."""
return {
key: value for key, value in filters.items() if isinstance(value, str | int | float | bool)
}


def _client(session_jwt: str | None = None) -> httpx.Client:
"""An AMS HTTP client; pass a session JWT to send the authenticated cookie."""
cookies = {"stytch_session_jwt": session_jwt} if session_jwt else None
Expand Down Expand Up @@ -144,9 +151,7 @@ def rename_token(account_id: int, token_id: int, token_name: str, session_jwt: s

def list_streaming(session_jwt: str, **filters: object) -> dict[str, object]:
"""GET /v1/users/streaming -> {page_details, data: [StreamingSessionSchema]}."""
params = {
key: value for key, value in filters.items() if isinstance(value, str | int | float | bool)
}
params = _scalar_params(filters)
with _client(session_jwt) as client:
resp = client.get("/v1/users/streaming", params=params)
return _json_object_or_raise(resp)
Expand All @@ -161,9 +166,7 @@ def get_streaming(session_id: str, session_jwt: str) -> dict[str, object]:

def list_audit_logs(session_jwt: str, **filters: object) -> dict[str, object]:
"""GET /v2/user/audit-logs -> {page_details, data: [AuditLogResponse]}."""
params = {
key: value for key, value in filters.items() if isinstance(value, str | int | float | bool)
}
params = _scalar_params(filters)
with _client(session_jwt) as client:
resp = client.get("/v2/user/audit-logs", params=params)
return _json_object_or_raise(resp)
11 changes: 3 additions & 8 deletions aai_cli/commands/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,24 @@

app = typer.Typer(help="View your account's audit log.")

_LOGIN_ACTIONS = {"login", "login.succeeded", "login__succeeded"}
# `__`-separated variants are handled by _normalize_action (which maps `__` -> `.`),
# so only the dotted forms need entries here.
_LOGIN_ACTIONS = {"login", "login.succeeded"}
_ACTION_LABELS = {
"account.create": "Account created",
"account.created": "Account created",
"account__created": "Account created",
"account.tos_accepted": "Terms accepted",
"account.tos.accepted": "Terms accepted",
"account__tos_accepted": "Terms accepted",
"account.upgrade": "Account upgraded",
"account.upgraded": "Account upgraded",
"account__upgraded": "Account upgraded",
"member.create": "Member created",
"member.created": "Member created",
"member__created": "Member created",
"token.create": "API key created",
"token.created": "API key created",
"token__created": "API key created",
"token.rename": "API key renamed",
"token.renamed": "API key renamed",
"token__renamed": "API key renamed",
"login": "Login",
"login.succeeded": "Login succeeded",
"login__succeeded": "Login succeeded",
}


Expand Down
13 changes: 2 additions & 11 deletions aai_cli/commands/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from rich.table import Table

from aai_cli import client, config, environments, help_panels, output
from aai_cli.auth import run_login_flow
from aai_cli.context import AppState, resolve_profile, run_command
from aai_cli.context import AppState, persist_browser_login, resolve_profile, run_command
from aai_cli.errors import APIError, NotAuthenticated
from aai_cli.help_text import examples_epilog

Expand Down Expand Up @@ -47,15 +46,7 @@ def body(state: AppState, json_mode: bool) -> None:
# login rather than silently reusing the old (possibly different) identity.
config.clear_session(profile)
else:
result = run_login_flow()
config.set_api_key(profile, result.api_key)
config.set_profile_env(profile, env)
config.set_session(
profile,
session_jwt=result.session_jwt,
session_token=result.session_token,
account_id=result.account_id,
)
persist_browser_login(profile, env)
output.emit(
{"authenticated": True, "profile": profile, "env": env},
lambda _d: (
Expand Down
59 changes: 32 additions & 27 deletions aai_cli/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,17 +255,26 @@ def parse_config_overrides(
return out


def load_config_file(path: str | Path, fields: dict[str, str]) -> dict[str, object]:
"""Load a JSON config file and validate its keys against `fields`."""
def _load_json_object(path: str | Path, *, label: str) -> dict[str, object]:
"""Read `path` as a JSON object, surfacing read/parse/shape errors as usage errors.

`label` (e.g. "Config file") prefixes the error messages.
"""
try:
data: object = json.loads(Path(path).read_text())
except FileNotFoundError as exc:
raise UsageError(f"Config file not found: {path}") from exc
raise UsageError(f"{label} not found: {path}") from exc
except json.JSONDecodeError as exc:
raise UsageError(f"Config file is not valid JSON: {exc}") from exc
config = jsonshape.as_mapping(data)
if config is None:
raise UsageError("Config file must contain a JSON object.")
raise UsageError(f"{label} is not valid JSON: {exc}") from exc
mapping = jsonshape.as_mapping(data)
if mapping is None:
raise UsageError(f"{label} must contain a JSON object.")
return mapping


def load_config_file(path: str | Path, fields: dict[str, str]) -> dict[str, object]:
"""Load a JSON config file and validate its keys against `fields`."""
config = _load_json_object(path, label="Config file")
unknown = [key for key in config if key not in fields]
if unknown:
valid = ", ".join(sorted(fields))
Expand Down Expand Up @@ -297,14 +306,24 @@ def merge_transcribe_config(
return _merge(TRANSCRIBE_FIELDS, flags, overrides, config_file)


def construct_transcription_config(merged: dict[str, typing.Any]) -> aai.TranscriptionConfig:
"""Build a TranscriptionConfig from a merged kwargs dict, surfacing errors as usage."""
def _construct[ConfigT](
model_cls: Callable[..., ConfigT], merged: dict[str, typing.Any], *, label: str
) -> ConfigT:
"""Build `model_cls(**merged)`, surfacing SDK validation as a usage error.

`label` (e.g. "transcription") names the config in the error message.
"""
try:
return aai.TranscriptionConfig(**merged)
return model_cls(**merged)
except UsageError:
raise
except Exception as exc: # surface SDK validation as a usage error
raise UsageError(f"Invalid transcription config: {exc}") from exc
raise UsageError(f"Invalid {label} config: {exc}") from exc


def construct_transcription_config(merged: dict[str, typing.Any]) -> aai.TranscriptionConfig:
"""Build a TranscriptionConfig from a merged kwargs dict, surfacing errors as usage."""
return _construct(aai.TranscriptionConfig, merged, label="transcription")


def merge_streaming_params(
Expand All @@ -329,12 +348,7 @@ def merge_streaming_params(

def construct_streaming_params(merged: dict[str, typing.Any]) -> StreamingParameters:
"""Build StreamingParameters from a merged kwargs dict, surfacing errors as usage."""
try:
return StreamingParameters(**merged)
except UsageError:
raise
except Exception as exc:
raise UsageError(f"Invalid streaming config: {exc}") from exc
return _construct(StreamingParameters, merged, label="streaming")


def split_csv(value: str | None) -> list[str] | None:
Expand Down Expand Up @@ -365,16 +379,7 @@ def auth_header_flags(value: str | None) -> dict[str, object]:

def load_custom_spelling(path: str) -> dict[str, object]:
"""Load a custom-spelling JSON map (e.g. {"AssemblyAI": ["assembly ai"]})."""
try:
data: object = json.loads(Path(path).read_text())
except FileNotFoundError as exc:
raise UsageError(f"Custom spelling file not found: {path}") from exc
except json.JSONDecodeError as exc:
raise UsageError(f"Custom spelling file is not valid JSON: {exc}") from exc
mapping = jsonshape.as_mapping(data)
if mapping is None:
raise UsageError("Custom spelling file must contain a JSON object.")
return mapping
return _load_json_object(path, label="Custom spelling file")


def translation_request(languages: list[str]) -> dict[str, object]:
Expand Down
9 changes: 6 additions & 3 deletions aai_cli/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,8 @@ def env_override_warning(state: AppState) -> str | None:
return state.env_override_warning()


def _persist_browser_login(state: AppState) -> None:
profile = state.resolve_profile()
env = environments.active().name
def persist_browser_login(profile: str, env: str) -> None:
"""Run the browser login flow and persist its credentials for `profile`/`env`."""
result = run_login_flow()
config.set_api_key(profile, result.api_key)
config.set_profile_env(profile, env)
Expand All @@ -101,6 +100,10 @@ def _persist_browser_login(state: AppState) -> None:
)


def _persist_browser_login(state: AppState) -> None:
persist_browser_login(state.resolve_profile(), environments.active().name)


def _login_persistence_error(exc: object) -> APIError:
return APIError(
f"Signed in, but could not save the credentials locally: {exc}",
Expand Down
26 changes: 10 additions & 16 deletions aai_cli/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,22 +182,16 @@ def run_chain_steps(
if not prompts:
return []

if transcript_id is not None:
output = transform_transcript(
api_key,
prompt=prompts[0],
model=model,
max_tokens=max_tokens,
transcript_id=transcript_id,
)
else:
output = transform_transcript(
api_key,
prompt=prompts[0],
model=model,
max_tokens=max_tokens,
transcript_text=transcript_text,
)
# Exactly one of transcript_id / transcript_text is set by callers; pass both
# through (build_messages prefers the id) so the two cases share one call.
output = transform_transcript(
api_key,
prompt=prompts[0],
model=model,
max_tokens=max_tokens,
transcript_id=transcript_id,
transcript_text=transcript_text,
)
steps = [{"prompt": prompts[0], "output": output}]

for prompt in prompts[1:]:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ def test_logout_clears_key():


def test_login_oauth_flow_stores_returned_key(monkeypatch):
monkeypatch.setattr("aai_cli.commands.login.run_login_flow", _fake_login_result)
monkeypatch.setattr("aai_cli.context.run_login_flow", _fake_login_result)
result = runner.invoke(app, ["login"])
assert result.exit_code == 0
assert config.get_api_key("default") == "sk_from_oauth"


def test_login_oauth_persists_session(monkeypatch):
monkeypatch.setattr("aai_cli.commands.login.run_login_flow", _fake_login_result)
monkeypatch.setattr("aai_cli.context.run_login_flow", _fake_login_result)
result = runner.invoke(app, ["login"])
assert result.exit_code == 0
assert config.get_session("default") == {"jwt": "jwt_x", "token": "tok_x"}
Expand Down Expand Up @@ -124,15 +124,15 @@ def test_login_oauth_flow_failure_exits_nonzero(monkeypatch):
def boom():
raise APIError("Login timed out waiting for the browser.")

monkeypatch.setattr("aai_cli.commands.login.run_login_flow", boom)
monkeypatch.setattr("aai_cli.context.run_login_flow", boom)
result = runner.invoke(app, ["login"])
assert result.exit_code != 0
assert config.get_api_key("default") is None


def test_login_api_key_flag_still_bypasses_oauth(monkeypatch):
monkeypatch.setattr(
"aai_cli.commands.login.run_login_flow",
"aai_cli.context.run_login_flow",
lambda: (_ for _ in ()).throw(AssertionError("OAuth must not run with --api-key")),
)
with patch("aai_cli.commands.login.client.validate_key", return_value=True):
Expand All @@ -142,15 +142,15 @@ def test_login_api_key_flag_still_bypasses_oauth(monkeypatch):


def test_login_binds_env_to_profile(monkeypatch):
monkeypatch.setattr("aai_cli.commands.login.run_login_flow", _fake_login_result)
monkeypatch.setattr("aai_cli.context.run_login_flow", _fake_login_result)
result = runner.invoke(app, ["--env", "sandbox000", "login"])
assert result.exit_code == 0
assert config.get_api_key("default") == "sk_from_oauth"
assert config.get_profile_env("default") == "sandbox000"


def test_sandbox_flag_is_shortcut_for_env(monkeypatch):
monkeypatch.setattr("aai_cli.commands.login.run_login_flow", lambda: _fake_login_result("sk_x"))
monkeypatch.setattr("aai_cli.context.run_login_flow", lambda: _fake_login_result("sk_x"))
result = runner.invoke(app, ["--sandbox", "login"])
assert result.exit_code == 0
assert config.get_profile_env("default") == "sandbox000"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_transcribe_prompt_transforms_json(monkeypatch):
_auth()
seen = {}

def fake_transform(api_key, *, prompt, model, transcript_id, max_tokens):
def fake_transform(api_key, *, prompt, model, transcript_id, max_tokens, transcript_text=None):
seen["prompt"] = prompt
seen["model"] = model
seen["transcript_id"] = transcript_id
Expand Down
Loading