From 204ec243eb87af947f3419b3458a04232fabe3ec Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 17 Jun 2026 04:41:39 +0000 Subject: [PATCH] refactor(code_gen): type LLM-Gateway options to drop cast() escape hatches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a `GatewayOptions` TypedDict (in code_gen/serialize.py) for the prompt-chain options that `code_gen.gateway_options` builds and the transcribe/stream renderers consume. Threading the typed shape end-to-end lets pyright/mypy know `llm["prompts"]` is `list[str]`, removing the two `cast("list[str]", …)` calls in transcribe.py and stream.py. `interval` is `NotRequired` (transcribe omits it; the streaming refresh-loop test deliberately exercises the absent case). Also drops the `# type: ignore[no-any-return]` in the e2e helper by returning a constructed dict instead of the bare `json.loads` result (no new `Any` token). Net escape hatches: cast() 13->11, type:ignore/noqa/no-cover 20->19. 
Co-Authored-By: Claude Opus 4.8 (1M context) Claude-Session: https://claude.ai/code/session_01FunrXSdcsWguki9hyeeMMe --- aai_cli/code_gen/__init__.py | 7 ++++--- aai_cli/code_gen/serialize.py | 16 ++++++++++++++++ aai_cli/code_gen/stream.py | 9 ++++----- aai_cli/code_gen/transcribe.py | 13 ++++++------- tests/e2e/test_cli_e2e.py | 2 +- tests/test_code_gen_stream.py | 3 ++- tests/test_code_gen_stream_agent.py | 2 +- 7 files changed, 34 insertions(+), 18 deletions(-) diff --git a/aai_cli/code_gen/__init__.py b/aai_cli/code_gen/__init__.py index 4bd15304..91af6f4f 100644 --- a/aai_cli/code_gen/__init__.py +++ b/aai_cli/code_gen/__init__.py @@ -6,6 +6,7 @@ from aai_cli.code_gen import agent_cascade as _agent_cascade from aai_cli.code_gen import stream as _stream from aai_cli.code_gen import transcribe as _transcribe +from aai_cli.code_gen.serialize import GatewayOptions if TYPE_CHECKING: from aai_cli.agent_cascade.config import CascadeConfig @@ -13,7 +14,7 @@ def gateway_options( prompts: list[str], model: str, max_tokens: int, *, interval: float = 0.0 -) -> dict[str, object] | None: +) -> GatewayOptions | None: """The LLM-gateway options dict consumed by `transcribe`/`stream`, or None if no prompts. `interval` (streaming only) is the seconds between summary refreshes baked into the @@ -49,7 +50,7 @@ def transcribe( merged: dict[str, object], source: str, *, - llm_gateway: dict[str, object] | None = None, + llm_gateway: GatewayOptions | None = None, download_sections: list[str] | None = None, ) -> str: """Generate runnable Python that reproduces this transcribe invocation.""" @@ -61,7 +62,7 @@ def transcribe( def stream( merged: dict[str, object], *, - llm: dict[str, object] | None = None, + llm: GatewayOptions | None = None, source: str | None = None, ) -> str: """Generate runnable Python that reproduces this streaming invocation. 
diff --git a/aai_cli/code_gen/serialize.py b/aai_cli/code_gen/serialize.py index 111c3a52..162f6e02 100644 --- a/aai_cli/code_gen/serialize.py +++ b/aai_cli/code_gen/serialize.py @@ -1,8 +1,24 @@ from __future__ import annotations +from typing import NotRequired, TypedDict + from assemblyai.streaming.v3 import SpeechModel +class GatewayOptions(TypedDict): + """The LLM-Gateway prompt-chain options threaded into generated transcribe/stream code. + + Built once by ``code_gen.gateway_options`` and consumed by the renderers, so the + fields stay typed end-to-end instead of forcing a ``cast`` at every subscript. + ``interval`` (streaming-only refresh cadence) is optional — ``transcribe`` omits it. + """ + + prompts: list[str] + model: str + max_tokens: int + interval: NotRequired[float] + + def py_literal(value: object) -> str: """Render a coerced config value as Python source. diff --git a/aai_cli/code_gen/stream.py b/aai_cli/code_gen/stream.py index 7e852cef..568e5fe2 100644 --- a/aai_cli/code_gen/stream.py +++ b/aai_cli/code_gen/stream.py @@ -1,8 +1,7 @@ from __future__ import annotations -from typing import cast - from aai_cli.code_gen import serialize +from aai_cli.code_gen.serialize import GatewayOptions from aai_cli.core import environments # Streaming-class imports always used by the generated scaffold. SpeechModel is added @@ -176,7 +175,7 @@ def _imports_block(merged: dict[str, object]) -> str: return "\n".join(f" {name}," for name in sorted(names)) -def _build_preamble(imports: str, llm: dict[str, object] | None, stdlib_imports: str) -> str: +def _build_preamble(imports: str, llm: GatewayOptions | None, stdlib_imports: str) -> str: """Pick and fill the plain vs. LLM-Gateway preamble for the given imports. 
Hosts come from the active environment, so a sandbox run generates a script @@ -184,7 +183,7 @@ def _build_preamble(imports: str, llm: dict[str, object] | None, stdlib_imports: """ env = environments.active() if llm: - prompts = "\n".join(f" {p!r}," for p in cast("list[str]", llm["prompts"])) + prompts = "\n".join(f" {p!r}," for p in llm["prompts"]) return _LLM_PREAMBLE.format( stdlib_imports=stdlib_imports, imports=imports, @@ -240,7 +239,7 @@ def _source_parts(source: str | None, rate: object) -> tuple[set[str], str, str, def render( merged: dict[str, object], *, - llm: dict[str, object] | None = None, + llm: GatewayOptions | None = None, source: str | None = None, ) -> str: """Generate a runnable streaming script with the given params. diff --git a/aai_cli/code_gen/transcribe.py b/aai_cli/code_gen/transcribe.py index 4a5bc553..244ca328 100644 --- a/aai_cli/code_gen/transcribe.py +++ b/aai_cli/code_gen/transcribe.py @@ -1,8 +1,7 @@ from __future__ import annotations -from typing import cast - from aai_cli.code_gen import serialize, snippets +from aai_cli.code_gen.serialize import GatewayOptions from aai_cli.core import environments, llm, youtube # ``-o/--output`` choice -> printed-result code, mirroring the run path's @@ -28,7 +27,7 @@ def render( merged: dict[str, object], source: str, *, - llm_gateway: dict[str, object] | None = None, + llm_gateway: GatewayOptions | None = None, output: str | None = None, chars_per_caption: int | None = None, download_sections: list[str] | None = None, @@ -95,7 +94,7 @@ def _download_ranges(sections: list[str]) -> tuple[str | None, bool]: def _header_block( - llm_gateway: dict[str, object] | None, + llm_gateway: GatewayOptions | None, output: str | None, *, needs_download: bool, @@ -189,7 +188,7 @@ def _transcribe_block( def _result_block( merged: dict[str, object], - llm_gateway: dict[str, object] | None, + llm_gateway: GatewayOptions | None, output: str | None, chars_per_caption: int | None, ) -> list[str]: @@ -207,14 +206,14 
@@ def _result_block( return [snippets.result_handling(merged)] -def _llm_gateway_block(llm_gateway: dict[str, object]) -> list[str]: +def _llm_gateway_block(llm_gateway: GatewayOptions) -> list[str]: """Emit a chained OpenAI-compatible LLM Gateway transform over the transcript. The generated script loops over the prompts: the first runs over the transcript (injected server-side via ``transcript_id`` wherever the ``{{ transcript }}`` tag appears), and each subsequent prompt runs over the previous response. """ - prompts = cast("list[str]", llm_gateway["prompts"]) + prompts = llm_gateway["prompts"] prompt_lines = "\n".join(f" {p!r}," for p in prompts) return [ "# Transform the transcript through AssemblyAI's LLM Gateway (OpenAI-compatible).", diff --git a/tests/e2e/test_cli_e2e.py b/tests/e2e/test_cli_e2e.py index 8f7f8193..5e36619f 100644 --- a/tests/e2e/test_cli_e2e.py +++ b/tests/e2e/test_cli_e2e.py @@ -56,7 +56,7 @@ def _transcribe_sample(key: str, *flags: str, timeout: int = 180) -> dict[str, A """Transcribe the hosted sample with `flags`, asserting success, return JSON.""" proc = _run_cli(["transcribe", "--sample", *flags, "--json"], key, timeout=timeout) assert proc.returncode == 0, f"args={flags} stderr:\n{proc.stderr}" - return json.loads(proc.stdout) # type: ignore[no-any-return] + return dict(json.loads(proc.stdout)) # --- Batch transcription -------------------------------------------------- diff --git a/tests/test_code_gen_stream.py b/tests/test_code_gen_stream.py index 83da8876..2b44dae0 100644 --- a/tests/test_code_gen_stream.py +++ b/tests/test_code_gen_stream.py @@ -10,8 +10,9 @@ from hypothesis import strategies as st from aai_cli import code_gen +from aai_cli.code_gen.serialize import GatewayOptions -_LLM = {"prompts": ["summarize"], "model": "m", "max_tokens": 100, "interval": 5.0} +_LLM: GatewayOptions = {"prompts": ["summarize"], "model": "m", "max_tokens": 100, "interval": 5.0} def _compiles(code: str) -> None: diff --git 
a/tests/test_code_gen_stream_agent.py b/tests/test_code_gen_stream_agent.py index 431db191..b8912c99 100644 --- a/tests/test_code_gen_stream_agent.py +++ b/tests/test_code_gen_stream_agent.py @@ -112,7 +112,7 @@ def test_gateway_options_defaults_interval_to_per_turn(): # per-turn (0.0); pins the default so it can't drift. opts = code_gen.gateway_options(["summarize"], "m", 100) assert opts is not None - assert opts["interval"] == 0.0 + assert opts.get("interval") == 0.0 def test_stream_show_code_defaults_interval_when_absent():