From 759b0907105f920f50135327a87d99c0da7cc2a6 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 6 Jun 2026 20:16:19 +0000 Subject: [PATCH 1/2] Remove code duplication across CLI modules - config_builder: extract _load_json_object (config-file/custom-spelling loaders) and _construct (transcription/streaming config builders) - auth/ams: extract _scalar_params shared by list_streaming/list_audit_logs - context/login: extract persist_browser_login, used by both the login command and the auto-login path - audit: drop redundant __-separated action-label variants already handled by _normalize_action - llm.run_chain_steps: collapse the transcript_id/transcript_text branches into one transform_transcript call --- aai_cli/auth/ams.py | 15 ++++++---- aai_cli/commands/audit.py | 11 ++----- aai_cli/commands/login.py | 13 ++------ aai_cli/config_builder.py | 62 ++++++++++++++++++++++----------------- aai_cli/context.py | 9 ++++-- aai_cli/llm.py | 26 +++++++--------- tests/test_login.py | 12 ++++---- tests/test_transcribe.py | 2 +- 8 files changed, 72 insertions(+), 78 deletions(-) diff --git a/aai_cli/auth/ams.py b/aai_cli/auth/ams.py index 44833a52..a81e93bf 100644 --- a/aai_cli/auth/ams.py +++ b/aai_cli/auth/ams.py @@ -55,6 +55,13 @@ def _json_object_list_or_raise(resp: httpx.Response) -> list[dict[str, object]]: return objects +def _scalar_params(filters: dict[str, object]) -> dict[str, str | int | float | bool]: + """Keep only scalar filter values, suitable as query-string params.""" + return { + key: value for key, value in filters.items() if isinstance(value, str | int | float | bool) + } + + def _client(session_jwt: str | None = None) -> httpx.Client: """An AMS HTTP client; pass a session JWT to send the authenticated cookie.""" cookies = {"stytch_session_jwt": session_jwt} if session_jwt else None @@ -144,9 +151,7 @@ def rename_token(account_id: int, token_id: int, token_name: str, session_jwt: s def list_streaming(session_jwt: str, **filters: object) -> dict[str, object]: """GET /v1/users/streaming -> {page_details, data: [StreamingSessionSchema]}.""" - params = { - key: value for key, value in filters.items() if isinstance(value, str | int | float | bool) - } + params = _scalar_params(filters) with _client(session_jwt) as client: resp = client.get("/v1/users/streaming", params=params) return _json_object_or_raise(resp) @@ -161,9 +166,7 @@ def get_streaming(session_id: str, session_jwt: str) -> dict[str, object]: def list_audit_logs(session_jwt: str, **filters: object) -> dict[str, object]: """GET /v2/user/audit-logs -> {page_details, data: [AuditLogResponse]}.""" - params = { - key: value for key, value in filters.items() if isinstance(value, str | int | float | bool) - } + params = _scalar_params(filters) with _client(session_jwt) as client: resp = client.get("/v2/user/audit-logs", params=params) return _json_object_or_raise(resp) diff --git a/aai_cli/commands/audit.py b/aai_cli/commands/audit.py index 7af1c651..e8999299 100644 --- a/aai_cli/commands/audit.py +++ b/aai_cli/commands/audit.py @@ -15,29 +15,24 @@ app = typer.Typer(help="View your account's audit log.") -_LOGIN_ACTIONS = {"login", "login.succeeded", "login__succeeded"} +# `__`-separated variants are handled by _normalize_action (which maps `__` -> `.`), +# so only the dotted forms need entries here. +_LOGIN_ACTIONS = {"login", "login.succeeded"} _ACTION_LABELS = { "account.create": "Account created", "account.created": "Account created", - "account__created": "Account created", "account.tos_accepted": "Terms accepted", "account.tos.accepted": "Terms accepted", - "account__tos_accepted": "Terms accepted", "account.upgrade": "Account upgraded", "account.upgraded": "Account upgraded", - "account__upgraded": "Account upgraded", "member.create": "Member created", "member.created": "Member created", - "member__created": "Member created", "token.create": "API key created", "token.created": "API key created", - "token__created": "API key created", "token.rename": "API key renamed", "token.renamed": "API key renamed", - "token__renamed": "API key renamed", "login": "Login", "login.succeeded": "Login succeeded", - "login__succeeded": "Login succeeded", } diff --git a/aai_cli/commands/login.py b/aai_cli/commands/login.py index 41262a33..d0e3aef5 100644 --- a/aai_cli/commands/login.py +++ b/aai_cli/commands/login.py @@ -5,8 +5,7 @@ from rich.table import Table from aai_cli import client, config, environments, help_panels, output -from aai_cli.auth import run_login_flow -from aai_cli.context import AppState, resolve_profile, run_command +from aai_cli.context import AppState, persist_browser_login, resolve_profile, run_command from aai_cli.errors import APIError, NotAuthenticated from aai_cli.help_text import examples_epilog @@ -47,15 +46,7 @@ def body(state: AppState, json_mode: bool) -> None: # login rather than silently reusing the old (possibly different) identity. config.clear_session(profile) else: - result = run_login_flow() - config.set_api_key(profile, result.api_key) - config.set_profile_env(profile, env) - config.set_session( - profile, - session_jwt=result.session_jwt, - session_token=result.session_token, - account_id=result.account_id, - ) + persist_browser_login(profile, env) output.emit( {"authenticated": True, "profile": profile, "env": env}, lambda _d: ( diff --git a/aai_cli/config_builder.py b/aai_cli/config_builder.py index 8491fd63..af38062e 100644 --- a/aai_cli/config_builder.py +++ b/aai_cli/config_builder.py @@ -255,17 +255,26 @@ def parse_config_overrides( return out -def load_config_file(path: str | Path, fields: dict[str, str]) -> dict[str, object]: - """Load a JSON config file and validate its keys against `fields`.""" +def _load_json_object(path: str | Path, *, label: str) -> dict[str, object]: + """Read `path` as a JSON object, surfacing read/parse/shape errors as usage errors. + + `label` (e.g. "Config file") prefixes the error messages. + """ try: data: object = json.loads(Path(path).read_text()) except FileNotFoundError as exc: - raise UsageError(f"Config file not found: {path}") from exc + raise UsageError(f"{label} not found: {path}") from exc except json.JSONDecodeError as exc: - raise UsageError(f"Config file is not valid JSON: {exc}") from exc - config = jsonshape.as_mapping(data) - if config is None: - raise UsageError("Config file must contain a JSON object.") + raise UsageError(f"{label} is not valid JSON: {exc}") from exc + mapping = jsonshape.as_mapping(data) + if mapping is None: + raise UsageError(f"{label} must contain a JSON object.") + return mapping + + +def load_config_file(path: str | Path, fields: dict[str, str]) -> dict[str, object]: + """Load a JSON config file and validate its keys against `fields`.""" + config = _load_json_object(path, label="Config file") unknown = [key for key in config if key not in fields] if unknown: valid = ", ".join(sorted(fields)) @@ -297,14 +306,27 @@ def merge_transcribe_config( return _merge(TRANSCRIBE_FIELDS, flags, overrides, config_file) -def construct_transcription_config(merged: dict[str, typing.Any]) -> aai.TranscriptionConfig: - """Build a TranscriptionConfig from a merged kwargs dict, surfacing errors as usage.""" +_ConfigT = typing.TypeVar("_ConfigT") + + +def _construct( + model_cls: Callable[..., _ConfigT], merged: dict[str, typing.Any], *, label: str +) -> _ConfigT: + """Build `model_cls(**merged)`, surfacing SDK validation as a usage error. + + `label` (e.g. "transcription") names the config in the error message. + """ try: - return aai.TranscriptionConfig(**merged) + return model_cls(**merged) except UsageError: raise except Exception as exc: # surface SDK validation as a usage error - raise UsageError(f"Invalid transcription config: {exc}") from exc + raise UsageError(f"Invalid {label} config: {exc}") from exc + + +def construct_transcription_config(merged: dict[str, typing.Any]) -> aai.TranscriptionConfig: + """Build a TranscriptionConfig from a merged kwargs dict, surfacing errors as usage.""" + return _construct(aai.TranscriptionConfig, merged, label="transcription") def merge_streaming_params( @@ -329,12 +351,7 @@ def merge_streaming_params( def construct_streaming_params(merged: dict[str, typing.Any]) -> StreamingParameters: """Build StreamingParameters from a merged kwargs dict, surfacing errors as usage.""" - try: - return StreamingParameters(**merged) - except UsageError: - raise - except Exception as exc: - raise UsageError(f"Invalid streaming config: {exc}") from exc + return _construct(StreamingParameters, merged, label="streaming") def split_csv(value: str | None) -> list[str] | None: @@ -365,16 +382,7 @@ def auth_header_flags(value: str | None) -> dict[str, object]: def load_custom_spelling(path: str) -> dict[str, object]: """Load a custom-spelling JSON map (e.g. {"AssemblyAI": ["assembly ai"]}).""" - try: - data: object = json.loads(Path(path).read_text()) - except FileNotFoundError as exc: - raise UsageError(f"Custom spelling file not found: {path}") from exc - except json.JSONDecodeError as exc: - raise UsageError(f"Custom spelling file is not valid JSON: {exc}") from exc - mapping = jsonshape.as_mapping(data) - if mapping is None: - raise UsageError("Custom spelling file must contain a JSON object.") - return mapping + return _load_json_object(path, label="Custom spelling file") def translation_request(languages: list[str]) -> dict[str, object]: diff --git a/aai_cli/context.py b/aai_cli/context.py index f5127428..35cf79a3 100644 --- a/aai_cli/context.py +++ b/aai_cli/context.py @@ -87,9 +87,8 @@ def env_override_warning(state: AppState) -> str | None: return state.env_override_warning() -def _persist_browser_login(state: AppState) -> None: - profile = state.resolve_profile() - env = environments.active().name +def persist_browser_login(profile: str, env: str) -> None: + """Run the browser login flow and persist its credentials for `profile`/`env`.""" result = run_login_flow() config.set_api_key(profile, result.api_key) config.set_profile_env(profile, env) @@ -101,6 +100,10 @@ def _persist_browser_login(state: AppState) -> None: ) +def _persist_browser_login(state: AppState) -> None: + persist_browser_login(state.resolve_profile(), environments.active().name) + + def _login_persistence_error(exc: object) -> APIError: return APIError( f"Signed in, but could not save the credentials locally: {exc}", diff --git a/aai_cli/llm.py b/aai_cli/llm.py index de31a105..5ba3cede 100644 --- a/aai_cli/llm.py +++ b/aai_cli/llm.py @@ -182,22 +182,16 @@ def run_chain_steps( if not prompts: return [] - if transcript_id is not None: - output = transform_transcript( - api_key, - prompt=prompts[0], - model=model, - max_tokens=max_tokens, - transcript_id=transcript_id, - ) - else: - output = transform_transcript( - api_key, - prompt=prompts[0], - model=model, - max_tokens=max_tokens, - transcript_text=transcript_text, - ) + # Exactly one of transcript_id / transcript_text is set by callers; pass both + # through (build_messages prefers the id) so the two cases share one call. + output = transform_transcript( + api_key, + prompt=prompts[0], + model=model, + max_tokens=max_tokens, + transcript_id=transcript_id, + transcript_text=transcript_text, + ) steps = [{"prompt": prompts[0], "output": output}] for prompt in prompts[1:]: diff --git a/tests/test_login.py b/tests/test_login.py index 04363e43..999214a1 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -77,14 +77,14 @@ def test_logout_clears_key(): def test_login_oauth_flow_stores_returned_key(monkeypatch): - monkeypatch.setattr("aai_cli.commands.login.run_login_flow", _fake_login_result) + monkeypatch.setattr("aai_cli.context.run_login_flow", _fake_login_result) result = runner.invoke(app, ["login"]) assert result.exit_code == 0 assert config.get_api_key("default") == "sk_from_oauth" def test_login_oauth_persists_session(monkeypatch): - monkeypatch.setattr("aai_cli.commands.login.run_login_flow", _fake_login_result) + monkeypatch.setattr("aai_cli.context.run_login_flow", _fake_login_result) result = runner.invoke(app, ["login"]) assert result.exit_code == 0 assert config.get_session("default") == {"jwt": "jwt_x", "token": "tok_x"} @@ -124,7 +124,7 @@ def test_login_oauth_flow_failure_exits_nonzero(monkeypatch): def boom(): raise APIError("Login timed out waiting for the browser.") - monkeypatch.setattr("aai_cli.commands.login.run_login_flow", boom) + monkeypatch.setattr("aai_cli.context.run_login_flow", boom) result = runner.invoke(app, ["login"]) assert result.exit_code != 0 assert config.get_api_key("default") is None @@ -132,7 +132,7 @@ def boom(): def test_login_api_key_flag_still_bypasses_oauth(monkeypatch): monkeypatch.setattr( - "aai_cli.commands.login.run_login_flow", + "aai_cli.context.run_login_flow", lambda: (_ for _ in ()).throw(AssertionError("OAuth must not run with --api-key")), ) with patch("aai_cli.commands.login.client.validate_key", return_value=True): @@ -142,7 +142,7 @@ def test_login_api_key_flag_still_bypasses_oauth(monkeypatch): def test_login_binds_env_to_profile(monkeypatch): - monkeypatch.setattr("aai_cli.commands.login.run_login_flow", _fake_login_result) + monkeypatch.setattr("aai_cli.context.run_login_flow", _fake_login_result) result = runner.invoke(app, ["--env", "sandbox000", "login"]) assert result.exit_code == 0 assert config.get_api_key("default") == "sk_from_oauth" @@ -150,7 +150,7 @@ def test_login_binds_env_to_profile(monkeypatch): def test_sandbox_flag_is_shortcut_for_env(monkeypatch): - monkeypatch.setattr("aai_cli.commands.login.run_login_flow", lambda: _fake_login_result("sk_x")) + monkeypatch.setattr("aai_cli.context.run_login_flow", lambda: _fake_login_result("sk_x")) result = runner.invoke(app, ["--sandbox", "login"]) assert result.exit_code == 0 assert config.get_profile_env("default") == "sandbox000" diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 24a02558..4bdc2227 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -165,7 +165,7 @@ def test_transcribe_prompt_transforms_json(monkeypatch): _auth() seen = {} - def fake_transform(api_key, *, prompt, model, transcript_id, max_tokens): + def fake_transform(api_key, *, prompt, model, transcript_id, max_tokens, transcript_text=None): seen["prompt"] = prompt seen["model"] = model seen["transcript_id"] = transcript_id From ff1bee5e1d1f9f3a971b8e07a9d0ae884e12146d Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 6 Jun 2026 20:27:20 +0000 Subject: [PATCH 2/2] Use PEP 695 type parameter syntax for _construct The Python floor moved to 3.12 (PR #26), so ruff's UP047 now requires the inline generic syntax instead of a module-level TypeVar. --- aai_cli/config_builder.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/aai_cli/config_builder.py b/aai_cli/config_builder.py index af38062e..837e4810 100644 --- a/aai_cli/config_builder.py +++ b/aai_cli/config_builder.py @@ -306,12 +306,9 @@ def merge_transcribe_config( return _merge(TRANSCRIBE_FIELDS, flags, overrides, config_file) -_ConfigT = typing.TypeVar("_ConfigT") - - -def _construct( - model_cls: Callable[..., _ConfigT], merged: dict[str, typing.Any], *, label: str -) -> _ConfigT: +def _construct[ConfigT]( + model_cls: Callable[..., ConfigT], merged: dict[str, typing.Any], *, label: str +) -> ConfigT: """Build `model_cls(**merged)`, surfacing SDK validation as a usage error. `label` (e.g. "transcription") names the config in the error message.