Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions aai_cli/code_gen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from aai_cli.code_gen import agent_cascade as _agent_cascade
from aai_cli.code_gen import stream as _stream
from aai_cli.code_gen import transcribe as _transcribe
from aai_cli.code_gen.serialize import GatewayOptions

if TYPE_CHECKING:
from aai_cli.agent_cascade.config import CascadeConfig


def gateway_options(
prompts: list[str], model: str, max_tokens: int, *, interval: float = 0.0
) -> dict[str, object] | None:
) -> GatewayOptions | None:
"""The LLM-gateway options dict consumed by `transcribe`/`stream`, or None if no prompts.

`interval` (streaming only) is the seconds between summary refreshes baked into the
Expand Down Expand Up @@ -49,7 +50,7 @@ def transcribe(
merged: dict[str, object],
source: str,
*,
llm_gateway: dict[str, object] | None = None,
llm_gateway: GatewayOptions | None = None,
download_sections: list[str] | None = None,
) -> str:
"""Generate runnable Python that reproduces this transcribe invocation."""
Expand All @@ -61,7 +62,7 @@ def transcribe(
def stream(
merged: dict[str, object],
*,
llm: dict[str, object] | None = None,
llm: GatewayOptions | None = None,
source: str | None = None,
) -> str:
"""Generate runnable Python that reproduces this streaming invocation.
Expand Down
16 changes: 16 additions & 0 deletions aai_cli/code_gen/serialize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
from __future__ import annotations

from typing import NotRequired, TypedDict

from assemblyai.streaming.v3 import SpeechModel


class GatewayOptions(TypedDict):
"""The LLM-Gateway prompt-chain options threaded into generated transcribe/stream code.

Built once by ``code_gen.gateway_options`` and consumed by the renderers, so the
fields stay typed end-to-end instead of forcing a ``cast`` at every subscript.
``interval`` (streaming-only refresh cadence) is optional — ``transcribe`` omits it.
"""

prompts: list[str]
model: str
max_tokens: int
interval: NotRequired[float]


def py_literal(value: object) -> str:
"""Render a coerced config value as Python source.

Expand Down
9 changes: 4 additions & 5 deletions aai_cli/code_gen/stream.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

from typing import cast

from aai_cli.code_gen import serialize
from aai_cli.code_gen.serialize import GatewayOptions
from aai_cli.core import environments

# Streaming-class imports always used by the generated scaffold. SpeechModel is added
Expand Down Expand Up @@ -176,15 +175,15 @@ def _imports_block(merged: dict[str, object]) -> str:
return "\n".join(f" {name}," for name in sorted(names))


def _build_preamble(imports: str, llm: dict[str, object] | None, stdlib_imports: str) -> str:
def _build_preamble(imports: str, llm: GatewayOptions | None, stdlib_imports: str) -> str:
"""Pick and fill the plain vs. LLM-Gateway preamble for the given imports.

Hosts come from the active environment, so a sandbox run generates a script
that targets the sandbox its key was minted for, not production.
"""
env = environments.active()
if llm:
prompts = "\n".join(f" {p!r}," for p in cast("list[str]", llm["prompts"]))
prompts = "\n".join(f" {p!r}," for p in llm["prompts"])
return _LLM_PREAMBLE.format(
stdlib_imports=stdlib_imports,
imports=imports,
Expand Down Expand Up @@ -240,7 +239,7 @@ def _source_parts(source: str | None, rate: object) -> tuple[set[str], str, str,
def render(
merged: dict[str, object],
*,
llm: dict[str, object] | None = None,
llm: GatewayOptions | None = None,
source: str | None = None,
) -> str:
"""Generate a runnable streaming script with the given params.
Expand Down
13 changes: 6 additions & 7 deletions aai_cli/code_gen/transcribe.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

from typing import cast

from aai_cli.code_gen import serialize, snippets
from aai_cli.code_gen.serialize import GatewayOptions
from aai_cli.core import environments, llm, youtube

# ``-o/--output`` choice -> printed-result code, mirroring the run path's
Expand All @@ -28,7 +27,7 @@ def render(
merged: dict[str, object],
source: str,
*,
llm_gateway: dict[str, object] | None = None,
llm_gateway: GatewayOptions | None = None,
output: str | None = None,
chars_per_caption: int | None = None,
download_sections: list[str] | None = None,
Expand Down Expand Up @@ -95,7 +94,7 @@ def _download_ranges(sections: list[str]) -> tuple[str | None, bool]:


def _header_block(
llm_gateway: dict[str, object] | None,
llm_gateway: GatewayOptions | None,
output: str | None,
*,
needs_download: bool,
Expand Down Expand Up @@ -189,7 +188,7 @@ def _transcribe_block(

def _result_block(
merged: dict[str, object],
llm_gateway: dict[str, object] | None,
llm_gateway: GatewayOptions | None,
output: str | None,
chars_per_caption: int | None,
) -> list[str]:
Expand All @@ -207,14 +206,14 @@ def _result_block(
return [snippets.result_handling(merged)]


def _llm_gateway_block(llm_gateway: dict[str, object]) -> list[str]:
def _llm_gateway_block(llm_gateway: GatewayOptions) -> list[str]:
"""Emit a chained OpenAI-compatible LLM Gateway transform over the transcript.

The generated script loops over the prompts: the first runs over the transcript
(injected server-side via ``transcript_id`` wherever the ``{{ transcript }}`` tag
appears), and each subsequent prompt runs over the previous response.
"""
prompts = cast("list[str]", llm_gateway["prompts"])
prompts = llm_gateway["prompts"]
prompt_lines = "\n".join(f" {p!r}," for p in prompts)
return [
"# Transform the transcript through AssemblyAI's LLM Gateway (OpenAI-compatible).",
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_cli_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _transcribe_sample(key: str, *flags: str, timeout: int = 180) -> dict[str, A
"""Transcribe the hosted sample with `flags`, asserting success, return JSON."""
proc = _run_cli(["transcribe", "--sample", *flags, "--json"], key, timeout=timeout)
assert proc.returncode == 0, f"args={flags} stderr:\n{proc.stderr}"
return json.loads(proc.stdout) # type: ignore[no-any-return]
return dict(json.loads(proc.stdout))


# --- Batch transcription --------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion tests/test_code_gen_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from hypothesis import strategies as st

from aai_cli import code_gen
from aai_cli.code_gen.serialize import GatewayOptions

_LLM = {"prompts": ["summarize"], "model": "m", "max_tokens": 100, "interval": 5.0}
_LLM: GatewayOptions = {"prompts": ["summarize"], "model": "m", "max_tokens": 100, "interval": 5.0}


def _compiles(code: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_code_gen_stream_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_gateway_options_defaults_interval_to_per_turn():
# per-turn (0.0); pins the default so it can't drift.
opts = code_gen.gateway_options(["summarize"], "m", 100)
assert opts is not None
assert opts["interval"] == 0.0
assert opts.get("interval") == 0.0


def test_stream_show_code_defaults_interval_when_absent():
Expand Down
Loading