From 2e2880091ffc0e10bb885e35ef77ebdc56d41f1d Mon Sep 17 00:00:00 2001 From: Alex Kroman Date: Mon, 8 Jun 2026 15:43:15 -0700 Subject: [PATCH] Add stream --llm-interval throttle and fix show-code LLM-chain parity Bake the --llm-interval cadence into the live `stream --llm` loop and its generated `--show-code` script: refresh the prompt chain at most once per interval (0 = every turn) with a closing flush so tail turns aren't lost. Also fix `transcribe --show-code` chained prompts to wrap follow-up steps under the same "Transcript:" label the CLI's run_chain_steps uses, so the generated script matches a real run. Co-Authored-By: Claude Opus 4.8 (1M context) --- aai_cli/code_gen/__init__.py | 23 +++- aai_cli/code_gen/stream.py | 39 +++++- aai_cli/code_gen/transcribe.py | 2 +- aai_cli/commands/stream.py | 18 ++- aai_cli/streaming/render.py | 77 ++++++++---- aai_cli/streaming/session.py | 83 +++++++++--- aai_cli/theme.py | 11 +- .../test_cli_output_snapshots.ambr | 26 ++-- tests/test_code_gen.py | 39 +++++- tests/test_stream_llm.py | 118 ++++++++++++++++++ tests/test_stream_session.py | 33 +++++ tests/test_streaming_render.py | 73 ++++++++++- tests/test_theme.py | 9 ++ 13 files changed, 477 insertions(+), 74 deletions(-) diff --git a/aai_cli/code_gen/__init__.py b/aai_cli/code_gen/__init__.py index ecc14cc9..5ed4cbaa 100644 --- a/aai_cli/code_gen/__init__.py +++ b/aai_cli/code_gen/__init__.py @@ -5,11 +5,22 @@ from aai_cli.code_gen import transcribe as _transcribe -def gateway_options(prompts: list[str], model: str, max_tokens: int) -> dict[str, object] | None: - """The LLM-gateway options dict consumed by `transcribe`/`stream`, or None if no prompts.""" +def gateway_options( + prompts: list[str], model: str, max_tokens: int, *, interval: float = 0.0 +) -> dict[str, object] | None: + """The LLM-gateway options dict consumed by `transcribe`/`stream`, or None if no prompts. 
+ + `interval` (streaming only) is the seconds between summary refreshes baked into the + generated `stream --llm` loop; 0 refreshes on every turn. `transcribe` ignores it. + """ if not prompts: return None - return {"prompts": list(prompts), "model": model, "max_tokens": max_tokens} + return { + "prompts": list(prompts), + "model": model, + "max_tokens": max_tokens, + "interval": interval, + } def agent(voice: str, system_prompt: str, greeting: str) -> str: @@ -34,8 +45,8 @@ def stream( ) -> str: """Generate runnable Python that reproduces this streaming invocation. - With `llm` (a dict of ``prompts``/``model``/``max_tokens``), the script refreshes a - prompt-chain over the growing transcript on every finalized turn — the live sibling - of `transcribe --llm` — mirroring how `stream --llm` runs. + With `llm` (a dict of ``prompts``/``model``/``max_tokens``/``interval``), the script + refreshes a prompt-chain over the growing transcript every ``interval`` seconds (0 = + every turn) — the live sibling of `transcribe --llm` — mirroring how `stream --llm` runs. """ return _stream.render(merged, llm=llm) diff --git a/aai_cli/code_gen/stream.py b/aai_cli/code_gen/stream.py index 02465eae..cb54d3ad 100644 --- a/aai_cli/code_gen/stream.py +++ b/aai_cli/code_gen/stream.py @@ -40,6 +40,7 @@ def on_turn(client: StreamingClient, event: TurnEvent) -> None: """ _LLM_PREAMBLE = """import os +import time import assemblyai as aai from assemblyai.streaming.v3 import ( @@ -57,7 +58,12 @@ def on_turn(client: StreamingClient, event: TurnEvent) -> None: PROMPTS = [ {prompts} ] +# Turns accumulate continuously; the prompt chain re-runs at most once every +# LLM_INTERVAL seconds (0 = on every finalized turn). 
+LLM_INTERVAL = {interval} transcript: list[str] = [] +_summarized = 0 +_last_summary = float("-inf") def run_chain(text: str) -> str: @@ -73,12 +79,26 @@ def run_chain(text: str) -> str: return result +def summarize(*, final: bool = False) -> None: + # Refresh the answer over the growing transcript, throttled to LLM_INTERVAL. `final` + # forces a closing refresh so turns since the last tick aren't lost on stop. + global _summarized, _last_summary + turns = len(transcript) + if turns <= _summarized: + return + now = time.monotonic() + if not final and LLM_INTERVAL > 0 and now - _last_summary < LLM_INTERVAL: + return + _summarized = turns + _last_summary = now + print(run_chain(" ".join(transcript))) + + def on_turn(client: StreamingClient, event: TurnEvent) -> None: - # Refresh the answer on every finalized turn, over the growing transcript. if not event.end_of_turn or not event.transcript: return transcript.append(event.transcript) - print(run_chain(" ".join(transcript))) + summarize() client = StreamingClient( @@ -95,6 +115,17 @@ def on_turn(client: StreamingClient, event: TurnEvent) -> None: client.disconnect(terminate=True) """ +# Same as _FOOTER, but flushes a closing summary (incl. on Ctrl-C) so the turns since the +# last interval tick are reflected before disconnecting. 
+_LLM_FOOTER = """ +print("Listening… press Ctrl-C to stop.") +try: + client.stream(aai.extras.MicrophoneStream(sample_rate={rate})) +finally: + summarize(final=True) + client.disconnect(terminate=True) +""" + def _imports_block(merged: dict[str, object]) -> str: """Sorted streaming-class import lines; SpeechModel only when a model kwarg is emitted.""" @@ -114,6 +145,7 @@ def _build_preamble(imports: str, llm: dict[str, object] | None) -> str: prompts=prompts, model=llm["model"], max_tokens=llm["max_tokens"], + interval=llm.get("interval", 0.0), ) return _PREAMBLE.format(imports=imports) @@ -138,4 +170,5 @@ def render(merged: dict[str, object], *, llm: dict[str, object] | None = None) - # Mic capture rate must match StreamingParameters.sample_rate, else audio is corrupt. rate = merged.get("sample_rate", 16000) connect = _build_connect(merged) - return preamble + "\n" + connect + "\n" + _FOOTER.format(rate=rate) + footer = _LLM_FOOTER if llm else _FOOTER + return preamble + "\n" + connect + "\n" + footer.format(rate=rate) diff --git a/aai_cli/code_gen/transcribe.py b/aai_cli/code_gen/transcribe.py index cdb53531..3d97439b 100644 --- a/aai_cli/code_gen/transcribe.py +++ b/aai_cli/code_gen/transcribe.py @@ -86,7 +86,7 @@ def _llm_gateway_block(llm_gateway: dict[str, object]) -> list[str]: f' content = prompt + "\\n\\n{llm.TRANSCRIPT_TAG}"', ' extra = {"transcript_id": transcript.id}', " else:", - ' content = prompt + "\\n\\n" + result', + ' content = prompt + "\\n\\nTranscript:\\n" + result', " extra = None", " response = gateway.chat.completions.create(", f" model={llm_gateway['model']!r},", diff --git a/aai_cli/commands/stream.py b/aai_cli/commands/stream.py index 5836434b..76fd81ba 100644 --- a/aai_cli/commands/stream.py +++ b/aai_cli/commands/stream.py @@ -195,7 +195,11 @@ def stream( ), # speakers speaker_labels: bool | None = typer.Option( - None, "--speaker-labels", help="Label speakers.", rich_help_panel=help_panels.OPT_SPEAKERS + None, + "--speaker-labels", + 
help='Diarize speakers. With system audio the mic stays "You"; only the system ' + "audio is split into speakers.", + rich_help_panel=help_panels.OPT_SPEAKERS, ), max_speakers: int | None = typer.Option( None, @@ -270,6 +274,13 @@ def stream( "one's response (a chain).", rich_help_panel=help_panels.OPT_LLM, ), + llm_interval: float = typer.Option( + 30.0, + "--llm-interval", + help="Seconds between --llm summary refreshes (0 refreshes on every turn).", + min=0.0, + rich_help_panel=help_panels.OPT_LLM, + ), model: str = typer.Option( llm.DEFAULT_MODEL, "--model", @@ -367,7 +378,9 @@ def body(state: AppState, json_mode: bool) -> None: overrides=config_kv, config_file=config_file, ) - gateway = code_gen.gateway_options(list(llm_prompt or []), model, max_tokens) + gateway = code_gen.gateway_options( + list(llm_prompt or []), model, max_tokens, interval=llm_interval + ) output.print_code(code_gen.stream(merged, llm=gateway)) return @@ -385,6 +398,7 @@ def body(state: AppState, json_mode: bool) -> None: llm_prompts=llm_prompts, model=model, max_tokens=max_tokens, + llm_interval=llm_interval, ) _dispatch(session, opts) diff --git a/aai_cli/streaming/render.py b/aai_cli/streaming/render.py index 420c9f99..2070e275 100644 --- a/aai_cli/streaming/render.py +++ b/aai_cli/streaming/render.py @@ -6,8 +6,39 @@ from rich.console import Console from rich.text import Text +from aai_cli import theme from aai_cli.render import BaseRenderer +# Source label -> (display text, Rich style). System audio borrows the agent color; +# the microphone ("you") its own. Unknown sources fall back to the raw label. +_SOURCE_LABELS: dict[str, tuple[str, str]] = { + "system": ("System", "aai.agent"), + "you": ("You", "aai.you"), +} + + +def speaker_prefix(source: str | None, speaker: str | None) -> tuple[str, str] | None: + """The lead-in label and Rich style for a turn, or None when it has neither a + source nor a diarized speaker. 
+ + - source + speaker -> "System (A)" (system audio diarized via --speaker-labels) + - source only -> "System" (parallel system/you streams) + - speaker only -> "Speaker A" (single-stream diarization, no source label) + + When a speaker is present the whole label is tinted by `theme.speaker_style` so each + speaker reads in its own color (matching batch transcribe's diarized output); a + sourced turn with no speaker keeps the source's own color. + """ + label, style = (None, "aai.label") + if source is not None: + label, style = _SOURCE_LABELS.get(source, (source, "aai.label")) + if speaker is not None: + style = theme.speaker_style(speaker) + return (f"{label} ({speaker})" if label is not None else f"Speaker {speaker}"), style + if label is not None: + return label, style + return None + class StreamRenderer(BaseRenderer): """Renders streaming events in one of three modes. @@ -46,25 +77,16 @@ def _with_source(payload: dict[str, object], source: str | None) -> dict[str, ob return payload @staticmethod - def _source_label(source: str) -> tuple[str, str]: - labels = { - "system": ("System", "aai.agent"), - "you": ("You", "aai.you"), - } - return labels.get(source, (source, "aai.label")) - - @classmethod - def _label(cls, text: str, source: str | None) -> str: - if source is None: - return text - label, _style = cls._source_label(source) - return f"{label}: {text}" + def _label(text: str, source: str | None, speaker: str | None = None) -> str: + prefix = speaker_prefix(source, speaker) + return text if prefix is None else f"{prefix[0]}: {text}" - @classmethod - def _styled_label(cls, text: str, source: str | None) -> str | Text: - if source is None: + @staticmethod + def _styled_label(text: str, source: str | None, speaker: str | None = None) -> str | Text: + prefix = speaker_prefix(source, speaker) + if prefix is None: return text - label, style = cls._source_label(source) + label, style = prefix rendered = Text() rendered.append(f"{label}: ", style=style) 
rendered.append(text) @@ -90,21 +112,24 @@ def listening(self) -> None: def turn(self, event: object, *, source: str | None = None) -> None: text = getattr(event, "transcript", "") or "" end = bool(getattr(event, "end_of_turn", False)) + speaker = getattr(event, "speaker_label", None) # set when --speaker-labels diarizes with self._lock: if self.json_mode: - self._emit( - self._with_source( - {"type": "turn", "transcript": text, "end_of_turn": end}, - source, - ) - ) + payload: dict[str, object] = { + "type": "turn", + "transcript": text, + "end_of_turn": end, + } + if speaker is not None: + payload["speaker"] = speaker + self._emit(self._with_source(payload, source)) elif self.text_mode: if end and text: - self._write(self._label(text, source) + "\n") # plain finalized line + self._write(self._label(text, source, speaker) + "\n") # plain finalized line elif end: - self._finalize_line(self._styled_label(text, source)) + self._finalize_line(self._styled_label(text, source, speaker)) else: - self._update_line(self._styled_label(text, source)) + self._update_line(self._styled_label(text, source, speaker)) def termination(self, event: object, *, source: str | None = None) -> None: with self._lock: diff --git a/aai_cli/streaming/session.py b/aai_cli/streaming/session.py index d4dd3d0b..0ad62c94 100644 --- a/aai_cli/streaming/session.py +++ b/aai_cli/streaming/session.py @@ -2,6 +2,7 @@ import queue import threading +import time from collections.abc import Callable, Iterable from dataclasses import dataclass, field from pathlib import Path @@ -11,7 +12,7 @@ from aai_cli import client, config_builder, llm from aai_cli.errors import CLIError, UsageError from aai_cli.follow import FollowRenderer -from aai_cli.streaming.render import StreamRenderer +from aai_cli.streaming.render import StreamRenderer, speaker_prefix # Sources that can be transcribed in parallel sessions: (label, audio chunks, sample rate). 
_ParallelStreams = list[tuple[str, Iterable[bytes], int]] @@ -96,10 +97,18 @@ class StreamSession: llm_prompts: list[str] model: str max_tokens: int + # Seconds between --llm summary refreshes; <=0 re-runs the chain on every turn. + llm_interval: float = 0.0 + # Monotonic clock, injectable so the interval throttle is deterministic in tests. + clock: Callable[[], float] = time.monotonic transcript: list[str] = field(default_factory=list[str]) _callback_lock: threading.RLock = field(default_factory=threading.RLock) _listening_lock: threading.Lock = field(default_factory=threading.Lock) _listening_started: bool = False + # How many turns the last refresh covered, and when it ran (monotonic seconds). + # -inf so the very first finalized turn always produces an immediate summary. + _summarized_len: int = 0 + _last_summary_at: float = float("-inf") @property def on_open(self) -> Callable[[], None]: @@ -115,39 +124,70 @@ def _listening_once(self) -> None: self.renderer.listening() def on_turn(self, event: object, *, source_label: str | None = None) -> None: - with self._callback_lock: - if self.follow is None: + if self.follow is None: + with self._callback_lock: self.renderer.turn(event, source=source_label) - else: - self._refresh_answer(event, source_label) + else: + # --llm mode locks only to record the turn; the chain re-runs (network) are + # left unlocked so the other source's turns keep flowing during a refresh. 
+ self._record_turn(event, source_label) - def _refresh_answer(self, event: object, source_label: str | None) -> None: - """Live --llm mode: re-run the prompt chain over the growing transcript on every - finalized turn, refreshing one evolving answer (partials are ignored).""" - follow = self.follow - if follow is None or not getattr(event, "end_of_turn", False): - return + def _record_turn(self, event: object, source_label: str | None) -> None: + """Append a finalized turn to the running transcript, then refresh the --llm + answer if a refresh is due (every turn, or once per ``llm_interval`` seconds).""" + if not getattr(event, "end_of_turn", False): + return # partials don't change the transcript text = getattr(event, "transcript", "") or "" if not text: return - if source_label is not None: - display_source = {"system": "System", "you": "You"}.get(source_label, source_label) - text = f"{display_source}: {text}" - self.transcript.append(text) + prefix = speaker_prefix(source_label, getattr(event, "speaker_label", None)) + line = f"{prefix[0]}: {text}" if prefix is not None else text + with self._callback_lock: + self.transcript.append(line) + self._maybe_summarize() + + def _maybe_summarize(self, *, final: bool = False) -> None: + """Re-run the prompt chain over the transcript so far and refresh the answer. + + Claims the work under the lock — bumping ``_summarized_len``/``_last_summary_at`` + before releasing — so concurrent source threads never double-run the chain or + race the throttle. No-op when nothing new has been transcribed, or (unless + ``final``) when fewer than ``llm_interval`` seconds have elapsed since the last + refresh. 
``final`` forces the closing flush so the tail turns aren't lost.""" + follow = self.follow + if follow is None: + return + with self._callback_lock: + turns = len(self.transcript) + if turns <= self._summarized_len: + return # nothing new since the last refresh + now = self.clock() + throttled = self.llm_interval > 0 and now - self._last_summary_at < self.llm_interval + if throttled and not final: + return + transcript_text = " ".join(self.transcript) + self._summarized_len = turns + self._last_summary_at = now answer = llm.run_chain( self.api_key, self.llm_prompts, - transcript_text=" ".join(self.transcript), + transcript_text=transcript_text, model=self.model, max_tokens=self.max_tokens, ) - follow(answer, len(self.transcript)) + follow(answer, turns) def stream_one( self, audio: Iterable[bytes], rate: int, *, source_label: str | None = None ) -> None: + flags = self.base_flags | {"sample_rate": rate} + if source_label == "you": + # The microphone captures you alone, so never diarize it into separate + # speakers — force speaker_labels off so the mic stays labeled "You" even + # when --speaker-labels splits the system audio into speakers. + flags = flags | {"speaker_labels": False, "max_speakers": None} merged = config_builder.merge_streaming_params( - flags=self.base_flags | {"sample_rate": rate}, + flags=flags, overrides=self.overrides, config_file=self.config_file, ) @@ -176,7 +216,12 @@ def _guarded(self, work: Callable[[], None]) -> None: try: if self.follow is not None: with self.follow: - work() + try: + work() + finally: + # Flush a closing summary (incl. on Ctrl-C) so turns since the + # last interval tick are reflected, while the panel's still live. + self._maybe_summarize(final=True) else: work() except KeyboardInterrupt: diff --git a/aai_cli/theme.py b/aai_cli/theme.py index 16206db5..8ebcf269 100644 --- a/aai_cli/theme.py +++ b/aai_cli/theme.py @@ -17,7 +17,9 @@ SYMBOL_WARN = "!" 
SYMBOL_HINT = "›" # noqa: RUF001 — deliberate angle-quote glyph, not a '>' typo -# Per-speaker label colors, rotated deterministically by speaker_style(). +# Per-speaker label colors, rotated deterministically by speaker_style(). Deliberately +# excludes the brand blue: that hue is reserved for "you" (aai.you) so a diarized system +# speaker can never be tinted the same color as your own mic. SPEAKER_STYLES: tuple[str, ...] = ( "aai.speaker.0", "aai.speaker.1", @@ -31,8 +33,9 @@ "aai.brand": f"bold {BRAND}", "aai.heading": f"bold {BRAND}", "aai.label": BRAND, - # Conversation labels: the human keeps the brand accent, the agent gets a - # distinct hue so "you:" and "agent:" are easy to tell apart at a glance. + # Conversation labels: the human keeps the brand accent (reserved — never reused + # for a diarized speaker, see SPEAKER_STYLES), the agent gets a distinct hue so + # "you:" and "agent:" are easy to tell apart at a glance. "aai.you": BRAND, "aai.agent": "cyan", # Links/URLs in cyan, the convention both the Vercel and Supabase CLIs use so @@ -45,7 +48,7 @@ "aai.error": "bold red", "aai.warn": "yellow", "aai.muted": "dim", - "aai.speaker.0": BRAND, + "aai.speaker.0": "dark_orange", "aai.speaker.1": "cyan", "aai.speaker.2": "magenta", "aai.speaker.3": "green", diff --git a/tests/__snapshots__/test_cli_output_snapshots.ambr b/tests/__snapshots__/test_cli_output_snapshots.ambr index e867dca5..a64bdb72 100644 --- a/tests/__snapshots__/test_cli_output_snapshots.ambr +++ b/tests/__snapshots__/test_cli_output_snapshots.ambr @@ -582,7 +582,10 @@ │ turns. │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Speakers & Channels ────────────────────────────────────────────────────────╮ - │ --speaker-labels Label speakers. │ + │ --speaker-labels Diarize speakers. With system │ + │ audio the mic stays "You"; │ + │ only the system audio is split │ + │ into speakers. │ │ --max-speakers INTEGER RANGE [x>=1] Max speakers. 
│ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Features ───────────────────────────────────────────────────────────────────╮ @@ -605,13 +608,20 @@ │ --webhook-auth-header NAME:VALUE Webhook auth header as NAME:VALUE. │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ LLM Transform ──────────────────────────────────────────────────────────────╮ - │ --llm TEXT Run a prompt over the live transcript through │ - │ LLM Gateway, refreshing the answer on every │ - │ finalized turn. Repeatable: each prompt runs on │ - │ the previous one's response (a chain). │ - │ --model TEXT LLM Gateway model. │ - │ [default: claude-haiku-4-5-20251001] │ - │ --max-tokens INTEGER Max tokens. [default: 1000] │ + │ --llm TEXT Run a prompt over the live │ + │ transcript through LLM Gateway, │ + │ refreshing the answer on every │ + │ finalized turn. Repeatable: each │ + │ prompt runs on the previous │ + │ one's response (a chain). │ + │ --llm-interval FLOAT RANGE [x>=0.0] Seconds between --llm summary │ + │ refreshes (0 refreshes on every │ + │ turn). │ + │ [default: 30.0] │ + │ --model TEXT LLM Gateway model. │ + │ [default: │ + │ claude-haiku-4-5-20251001] │ + │ --max-tokens INTEGER Max tokens. [default: 1000] │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Advanced ───────────────────────────────────────────────────────────────────╮ │ --config KEY=VALUE Set any StreamingParameters field as │ diff --git a/tests/test_code_gen.py b/tests/test_code_gen.py index c92d26b5..79627f26 100644 --- a/tests/test_code_gen.py +++ b/tests/test_code_gen.py @@ -335,9 +335,10 @@ def test_transcribe_show_code_chains_multiple_llm_gateway_prompts(): assert "'summarize'," in code assert "'translate the summary to Spanish'," in code assert "for i, prompt in enumerate(prompts):" in code - # First step uses the transcript; later steps chain on the previous result. 
+ # First step uses the transcript; later steps chain on the previous result, + # wrapped under the same "Transcript:" label the CLI's run_chain_steps uses. assert '"transcript_id": transcript.id' in code - assert 'content = prompt + "\\n\\n" + result' in code + assert 'content = prompt + "\\n\\nTranscript:\\n" + result' in code def test_transcribe_show_code_without_gateway_has_no_openai_import(): @@ -365,6 +366,7 @@ def test_stream_show_code_includes_llm_follow_loop(): "prompts": ["summarize", "translate to french"], "model": "claude-haiku-4-5-20251001", "max_tokens": 500, + "interval": 30.0, }, ) ast.parse(code) @@ -372,10 +374,41 @@ def test_stream_show_code_includes_llm_follow_loop(): assert "llm-gateway.assemblyai.com" in code # Both prompts appear, in order, for the chain. assert code.index("summarize") < code.index("translate to french") - # Still streams from the mic, refreshing the answer on each finalized turn. + # Still streams from the mic, refreshing the answer on the interval. assert "MicrophoneStream" in code assert "end_of_turn" in code assert "claude-haiku-4-5-20251001" in code + # The generated loop mirrors --llm-interval: a baked-in throttle plus a closing flush. + assert "LLM_INTERVAL = 30.0" in code + assert "now - _last_summary < LLM_INTERVAL" in code + assert "summarize(final=True)" in code + + +def test_gateway_options_defaults_interval_to_per_turn(): + # Called without an explicit interval (transcribe's path), the baked-in cadence is + # per-turn (0.0); pins the default so it can't drift. + opts = code_gen.gateway_options(["summarize"], "m", 100) + assert opts is not None + assert opts["interval"] == 0.0 + + +def test_stream_show_code_defaults_interval_when_absent(): + # An llm dict with no "interval" key falls back to per-turn (LLM_INTERVAL = 0.0). 
+ code = code_gen.stream({}, llm={"prompts": ["s"], "model": "m", "max_tokens": 1}) + ast.parse(code) + assert "LLM_INTERVAL = 0.0" in code + + +def test_stream_show_code_llm_interval_zero_is_per_turn(): + # --llm-interval 0 bakes in the legacy per-turn cadence (LLM_INTERVAL = 0.0 makes the + # throttle a no-op), while still emitting the closing flush. + code = code_gen.stream( + {}, + llm=code_gen.gateway_options(["summarize"], "m", 100, interval=0.0), + ) + ast.parse(code) + assert "LLM_INTERVAL = 0.0" in code + assert "summarize(final=True)" in code def test_stream_show_code_without_llm_is_plain_scaffold(): diff --git a/tests/test_stream_llm.py b/tests/test_stream_llm.py index 3e8058fa..a466792b 100644 --- a/tests/test_stream_llm.py +++ b/tests/test_stream_llm.py @@ -115,3 +115,121 @@ def _boom(*a, **k): assert "from openai import OpenAI" in result.output assert "summarize" in result.output assert "run_chain" in result.output # the live transcribe->LLM-per-turn loop + + +def _eot_turn(text): + # A finalized turn with no diarized speaker (the shape SDK emits without --speaker-labels). + return types.SimpleNamespace(transcript=text, end_of_turn=True, speaker_label=None) + + +def _llm_session(*, interval, clock, monkeypatch, emitted): + import io + + from aai_cli.commands.stream import StreamSession + from aai_cli.follow import FollowRenderer + from aai_cli.streaming.render import StreamRenderer + + # Capture each follow refresh (json mode emits one NDJSON object per refresh) and + # make run_chain echo the transcript it summarized so assertions read the cadence. 
+ monkeypatch.setattr("aai_cli.follow.output.emit_ndjson", lambda obj: emitted.append(obj)) + monkeypatch.setattr( + "aai_cli.streaming.session.llm.run_chain", + lambda api_key, prompts, *, transcript_text, model, max_tokens: transcript_text, + ) + return StreamSession( + api_key="sk", + base_flags={}, + overrides=None, + config_file=None, + renderer=StreamRenderer(json_mode=True, out=io.StringIO()), + follow=FollowRenderer(json_mode=True), + llm_prompts=["sum"], + model="m", + max_tokens=10, + llm_interval=interval, + clock=clock, + ) + + +def test_stream_llm_interval_throttles_between_ticks(monkeypatch): + # --llm-interval re-runs the chain on the first turn, skips turns inside the window, + # and runs again once the interval has elapsed (turns still accumulate throughout). + now = {"t": 1000.0} + emitted: list[dict] = [] + session = _llm_session( + interval=30.0, clock=lambda: now["t"], monkeypatch=monkeypatch, emitted=emitted + ) + session.on_turn(_eot_turn("one")) # first turn -> immediate summary + now["t"] = 1010.0 + session.on_turn(_eot_turn("two")) # +10s within the window -> throttled + now["t"] = 1040.0 + session.on_turn(_eot_turn("three")) # 40s since the last refresh -> summary + session._maybe_summarize(final=True) # nothing new since -> no-op + assert [e["output"] for e in emitted] == ["one", "one two three"] + assert [e["turns"] for e in emitted] == [1, 3] + + +def test_stream_llm_final_flush_summarizes_tail(monkeypatch): + # A closing flush summarizes turns that arrived after the last interval tick, so the + # tail of the conversation is never dropped when the stream stops mid-window. 
+ now = {"t": 1000.0} + emitted: list[dict] = [] + session = _llm_session( + interval=30.0, clock=lambda: now["t"], monkeypatch=monkeypatch, emitted=emitted + ) + session.on_turn(_eot_turn("a")) # immediate summary "a" + session.on_turn(_eot_turn("b")) # same clock -> throttled + session._maybe_summarize(final=True) # flush the tail -> "a b" + assert [e["output"] for e in emitted] == ["a", "a b"] + + +def test_stream_llm_interval_below_one_second_still_throttles(monkeypatch): + # A sub-second interval is still interval mode (llm_interval > 0): turns inside the + # window are batched into one closing flush, not emitted per turn. Pins the `> 0` + # boundary so it can't drift to `> 1` and silently treat 0 < interval < 1 as per-turn. + now = {"t": 1000.0} + emitted: list[dict] = [] + session = _llm_session( + interval=0.5, clock=lambda: now["t"], monkeypatch=monkeypatch, emitted=emitted + ) + session.on_turn(_eot_turn("a")) # immediate summary -> "a" + session.on_turn(_eot_turn("b")) # within the 0.5s window -> throttled + session.on_turn(_eot_turn("c")) # still within the window -> throttled + session._maybe_summarize(final=True) # flush the batch -> "a b c" + assert [e["output"] for e in emitted] == ["a", "a b c"] + + +def test_stream_llm_interval_zero_summarizes_every_turn(monkeypatch): + # --llm-interval 0 keeps the legacy per-turn cadence: every finalized turn refreshes. + now = {"t": 1000.0} + emitted: list[dict] = [] + session = _llm_session( + interval=0.0, clock=lambda: now["t"], monkeypatch=monkeypatch, emitted=emitted + ) + session.on_turn(_eot_turn("a")) + session.on_turn(_eot_turn("b")) # not throttled despite the unchanged clock + assert [e["output"] for e in emitted] == ["a", "a b"] + + +def test_maybe_summarize_is_noop_without_follow(): + # Defensive guard: with no FollowRenderer there's nothing to refresh, so the chain + # is never run (no gateway call) regardless of transcript content. 
+ import io + + from aai_cli.commands.stream import StreamSession + from aai_cli.streaming.render import StreamRenderer + + session = StreamSession( + api_key="sk", + base_flags={}, + overrides=None, + config_file=None, + renderer=StreamRenderer(json_mode=True, out=io.StringIO()), + follow=None, + llm_prompts=["sum"], + model="m", + max_tokens=10, + ) + assert session.llm_interval == 0.0 # the default cadence is per-turn until set + session.transcript.append("x") + session._maybe_summarize(final=True) # must not raise or call the gateway diff --git a/tests/test_stream_session.py b/tests/test_stream_session.py index 70440533..36743628 100644 --- a/tests/test_stream_session.py +++ b/tests/test_stream_session.py @@ -241,6 +241,39 @@ def fake_run_chain(api_key, prompts, *, transcript_text, model, max_tokens): assert any("You: FakeMic" in value for value in transcript_inputs) +def test_stream_system_audio_speaker_labels_only_diarizes_system(monkeypatch): + # --speaker-labels diarizes the system audio but never the mic: the "you" session + # is forced to speaker_labels=False so the mic stays a single "You". 
+ config.set_api_key("default", "sk_live") + speaker_labels_by_chunk = {} + + class FakeSystemAudio: + def __init__(self, *, on_open=None): + self.sample_rate = 16000 + + def __iter__(self): + return iter([b"system"]) + + class FakeMic: + def __init__(self, *, target_rate=None, device=None, capture_rate=None, on_open=None): + self.sample_rate = target_rate + + def __iter__(self): + return iter([b"mic"]) + + def fake_stream_audio(api_key, source, *, params, **_kwargs): + chunk = next(iter(source)) + speaker_labels_by_chunk[chunk] = params.speaker_labels + + monkeypatch.setattr("aai_cli.commands.stream.MacSystemAudioSource", FakeSystemAudio) + monkeypatch.setattr("aai_cli.commands.stream.MicrophoneSource", FakeMic) + monkeypatch.setattr("aai_cli.commands.stream.client.stream_audio", fake_stream_audio) + result = runner.invoke(app, ["stream", "--system-audio", "--speaker-labels", "--json"]) + assert result.exit_code == 0 + assert speaker_labels_by_chunk[b"system"] is True + assert speaker_labels_by_chunk[b"mic"] is False + + def test_stream_system_audio_parallel_final_worker_error_surfaces(monkeypatch): config.set_api_key("default", "sk_live") diff --git a/tests/test_streaming_render.py b/tests/test_streaming_render.py index 7bfc20b5..e4050f13 100644 --- a/tests/test_streaming_render.py +++ b/tests/test_streaming_render.py @@ -9,8 +9,10 @@ from aai_cli.streaming.render import StreamRenderer -def _turn(transcript, end_of_turn): - return types.SimpleNamespace(transcript=transcript, end_of_turn=end_of_turn) +def _turn(transcript, end_of_turn, speaker_label=None): + return types.SimpleNamespace( + transcript=transcript, end_of_turn=end_of_turn, speaker_label=speaker_label + ) def _human(width=80, color_system=None): @@ -51,6 +53,53 @@ def test_human_turn_labels_parallel_sources(): assert "\x1b[" in out +def test_human_turn_labels_system_speaker_with_source(): + # --speaker-labels diarizes the system audio: each system turn carries a + # speaker_label, rendered as "System 
(A):" alongside the source. + r, buf = _human(color_system="truecolor") + r.turn(_turn("first speaker", True, speaker_label="A"), source="system") + r.turn(_turn("second speaker", True, speaker_label="B"), source="system") + r.close() + out = buf.getvalue() + assert "System (A):" in out + assert "System (B):" in out + + +def test_speaker_prefix_tints_each_speaker_distinctly(): + from aai_cli import theme + from aai_cli.streaming.render import speaker_prefix + + # A speaker label is colored by theme.speaker_style, so different speakers in the + # same source render in different colors; an unlabeled sourced turn keeps the + # source's own style. + prefix_a = speaker_prefix("system", "A") + prefix_b = speaker_prefix("system", "B") + assert prefix_a is not None and prefix_b is not None + assert (prefix_a[0], prefix_b[0]) == ("System (A)", "System (B)") + assert prefix_a[1] == theme.speaker_style("A") + assert prefix_b[1] == theme.speaker_style("B") + assert prefix_a[1] != prefix_b[1] + assert speaker_prefix("system", None) == ("System", "aai.agent") + assert speaker_prefix("you", None) == ("You", "aai.you") + + +def test_human_turn_keeps_you_label_without_speaker(): + # The mic session never diarizes, so a "you" turn has no speaker_label and + # stays "You:" even while the system audio is split into speakers. + r, buf = _human(color_system="truecolor") + r.turn(_turn("hello from me", True), source="you") + r.close() + assert "You:" in buf.getvalue() + + +def test_human_turn_labels_bare_speaker_without_source(): + # Single-stream diarization (no source label) shows "Speaker A:". 
+ r, buf = _human() + r.turn(_turn("solo", True, speaker_label="A")) + r.close() + assert "Speaker A:" in buf.getvalue() + + def test_text_mode_labels_sources_and_statuses_to_stderr(): out = io.StringIO() err = io.StringIO() @@ -126,6 +175,26 @@ def test_json_mode_emits_source_when_labeled(): } +def test_json_mode_emits_speaker_when_diarized(): + out = io.StringIO() + r = StreamRenderer(json_mode=True, out=out) + r.turn(_turn("hi", True, speaker_label="A"), source="system") + assert json.loads(out.getvalue()) == { + "type": "turn", + "transcript": "hi", + "end_of_turn": True, + "speaker": "A", + "source": "system", + } + + +def test_text_mode_labels_system_speaker(): + out = io.StringIO() + r = StreamRenderer(json_mode=False, text_mode=True, out=out, err=io.StringIO()) + r.turn(_turn("hi", True, speaker_label="A"), source="system") + assert out.getvalue() == "System (A): hi\n" + + def test_termination_json_emits_duration(): out = io.StringIO() r = StreamRenderer(json_mode=True, out=out) diff --git a/tests/test_theme.py b/tests/test_theme.py index 94666b38..d300b49b 100644 --- a/tests/test_theme.py +++ b/tests/test_theme.py @@ -46,6 +46,15 @@ def test_speaker_style_deterministic_and_in_palette(): assert theme.speaker_style("A") != theme.speaker_style("B") +def test_you_color_reserved_outside_speaker_palette(): + # "You" keeps the brand accent; no diarized speaker may resolve to that same color, + # so the mic is always visually distinct from a system speaker. + console = theme.make_console() + you_color = console.get_style("aai.you").color + speaker_colors = {console.get_style(name).color for name in theme.SPEAKER_STYLES} + assert you_color not in speaker_colors + + def test_output_console_is_themed_and_error_is_styled(monkeypatch): from aai_cli import output, theme from aai_cli.errors import CLIError