|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import queue |
4 | 3 | import tempfile |
5 | | -import threading |
6 | | -from collections.abc import Callable, Iterable |
7 | | -from dataclasses import dataclass, field |
8 | 4 | from pathlib import Path |
9 | 5 |
|
10 | 6 | import typer |
|
22 | 18 | youtube, |
23 | 19 | ) |
24 | 20 | from aai_cli.context import AppState, run_command |
25 | | -from aai_cli.errors import CLIError, UsageError |
| 21 | +from aai_cli.errors import UsageError |
26 | 22 | from aai_cli.follow import FollowRenderer |
27 | 23 | from aai_cli.help_text import examples_epilog |
28 | 24 | from aai_cli.microphone import MicrophoneSource |
29 | 25 | from aai_cli.streaming.macos import MacSystemAudioSource |
30 | 26 | from aai_cli.streaming.render import StreamRenderer |
| 27 | +from aai_cli.streaming.session import SourceOptions, StreamSession, validate_sources |
31 | 28 | from aai_cli.streaming.sources import TARGET_RATE, FileSource, StdinSource |
32 | 29 |
|
33 | 30 | app = typer.Typer() |
34 | 31 |
|
35 | 32 | DEFAULT_SPEECH_MODEL = SpeechModel.u3_rt_pro |
36 | 33 |
|
37 | | -# Sources that can be transcribed in parallel sessions: (label, audio chunks, sample rate). |
38 | | -_ParallelStreams = list[tuple[str, Iterable[bytes], int]] |
39 | 34 |
|
40 | | - |
41 | | -@dataclass(frozen=True) |
42 | | -class _SourceOptions: |
43 | | - """Where the audio comes from, distilled from the CLI flags. |
44 | | -
|
45 | | - Centralizes the "which input?" predicates so the validation and dispatch helpers |
46 | | - below read off one object instead of re-deriving the same booleans. |
47 | | - """ |
48 | | - |
49 | | - source: str | None |
50 | | - sample: bool |
51 | | - sample_rate: int | None |
52 | | - device: int | None |
53 | | - system_audio: bool |
54 | | - system_audio_only: bool |
55 | | - |
56 | | - @property |
57 | | - def from_stdin(self) -> bool: |
58 | | - return self.source == "-" |
59 | | - |
60 | | - @property |
61 | | - def from_file(self) -> bool: |
62 | | - return bool(self.source) or self.sample |
63 | | - |
64 | | - @property |
65 | | - def from_system_audio(self) -> bool: |
66 | | - return self.system_audio or self.system_audio_only |
67 | | - |
68 | | - @property |
69 | | - def has_capture_overrides(self) -> bool: |
70 | | - """Whether a microphone-only flag (--sample-rate or --device) was given.""" |
71 | | - return self.sample_rate is not None or self.device is not None |
72 | | - |
73 | | - |
74 | | -def _validate_sources(opts: _SourceOptions, *, has_llm: bool, text_mode: bool) -> None: |
75 | | - """Reject flag combinations that can't be honored, before any audio is opened.""" |
76 | | - if opts.system_audio and opts.system_audio_only: |
77 | | - raise UsageError("Use either --system-audio or --system-audio-only, not both.") |
78 | | - _validate_input_source(opts) |
79 | | - if has_llm and text_mode: |
80 | | - raise UsageError( |
81 | | - "--llm renders a live panel (or NDJSON when piped); it can't be combined with -o text." |
82 | | - ) |
83 | | - |
84 | | - |
85 | | -def _validate_input_source(opts: _SourceOptions) -> None: |
86 | | - """Reject --sample-rate/--device/source combinations the chosen input can't accept.""" |
87 | | - if opts.from_system_audio: |
88 | | - if opts.from_file: |
89 | | - raise UsageError("--system-audio cannot be combined with an audio source or --sample.") |
90 | | - if opts.system_audio_only and opts.has_capture_overrides: |
91 | | - raise UsageError( |
92 | | - "--sample-rate and --device require microphone input; use --system-audio." |
93 | | - ) |
94 | | - elif opts.from_stdin: |
95 | | - if opts.device is not None: |
96 | | - raise UsageError("--device applies only to microphone input.") |
97 | | - elif opts.from_file and opts.has_capture_overrides: |
98 | | - raise UsageError("--sample-rate and --device apply only to microphone input.") |
99 | | - |
100 | | - |
101 | | -@dataclass |
102 | | -class _StreamSession: |
103 | | - """Owns one streaming run: the renderers, the LLM-chain state, and the audio |
104 | | - plumbing shared across single- and parallel-source streaming. |
105 | | -
|
106 | | - Holding this as an object (rather than a nest of closures inside the command body) |
107 | | - keeps each step a small, independently readable method, and collapses the ~25 |
108 | | - per-call flags into one ``base_flags`` dict that only varies by sample rate. |
109 | | - """ |
110 | | - |
111 | | - api_key: str |
112 | | - base_flags: dict[str, object] |
113 | | - overrides: list[str] | None |
114 | | - config_file: str | Path | None |
115 | | - renderer: StreamRenderer |
116 | | - follow: FollowRenderer | None |
117 | | - llm_prompts: list[str] |
118 | | - model: str |
119 | | - max_tokens: int |
120 | | - transcript: list[str] = field(default_factory=list[str]) |
121 | | - _callback_lock: threading.RLock = field(default_factory=threading.RLock) |
122 | | - _listening_lock: threading.Lock = field(default_factory=threading.Lock) |
123 | | - _listening_started: bool = False |
124 | | - |
125 | | - @property |
126 | | - def on_open(self) -> Callable[[], None]: |
127 | | - """First-audio callback: announce "Listening…" once — unless the FollowRenderer |
128 | | - owns the screen in --llm mode, where the notice would clutter the live panel.""" |
129 | | - return (lambda: None) if self.follow is not None else self._listening_once |
130 | | - |
131 | | - def _listening_once(self) -> None: |
132 | | - with self._listening_lock: |
133 | | - if self._listening_started: |
134 | | - return |
135 | | - self._listening_started = True |
136 | | - self.renderer.listening() |
137 | | - |
138 | | - def on_turn(self, event: object, *, source_label: str | None = None) -> None: |
139 | | - with self._callback_lock: |
140 | | - if self.follow is None: |
141 | | - self.renderer.turn(event, source=source_label) |
142 | | - else: |
143 | | - self._refresh_answer(event, source_label) |
144 | | - |
145 | | - def _refresh_answer(self, event: object, source_label: str | None) -> None: |
146 | | - """Live --llm mode: re-run the prompt chain over the growing transcript on every |
147 | | - finalized turn, refreshing one evolving answer (partials are ignored).""" |
148 | | - follow = self.follow |
149 | | - if follow is None or not getattr(event, "end_of_turn", False): |
150 | | - return |
151 | | - text = getattr(event, "transcript", "") or "" |
152 | | - if not text: |
153 | | - return |
154 | | - if source_label is not None: |
155 | | - display_source = {"system": "System", "you": "You"}.get(source_label, source_label) |
156 | | - text = f"{display_source}: {text}" |
157 | | - self.transcript.append(text) |
158 | | - answer = llm.run_chain( |
159 | | - self.api_key, |
160 | | - self.llm_prompts, |
161 | | - transcript_text=" ".join(self.transcript), |
162 | | - model=self.model, |
163 | | - max_tokens=self.max_tokens, |
164 | | - ) |
165 | | - follow(answer, len(self.transcript)) |
166 | | - |
167 | | - def stream_one( |
168 | | - self, audio: Iterable[bytes], rate: int, *, source_label: str | None = None |
169 | | - ) -> None: |
170 | | - merged = config_builder.merge_streaming_params( |
171 | | - flags=self.base_flags | {"sample_rate": rate}, |
172 | | - overrides=self.overrides, |
173 | | - config_file=self.config_file, |
174 | | - ) |
175 | | - params = config_builder.construct_streaming_params(merged) |
176 | | - client.stream_audio( |
177 | | - self.api_key, |
178 | | - audio, |
179 | | - params=params, |
180 | | - on_begin=( |
181 | | - None |
182 | | - if self.follow is not None |
183 | | - else lambda event: self.renderer.begin(event, source=source_label) |
184 | | - ), |
185 | | - on_turn=lambda event: self.on_turn(event, source_label=source_label), |
186 | | - on_termination=( |
187 | | - None |
188 | | - if self.follow is not None |
189 | | - else lambda event: self.renderer.termination(event, source=source_label) |
190 | | - ), |
191 | | - ) |
192 | | - |
193 | | - def _guarded(self, work: Callable[[], None]) -> None: |
194 | | - """Run a streaming body with the shared lifecycle handling: enter the |
195 | | - FollowRenderer's live panel if present, treat Ctrl-C as a clean stop, exit 0 on |
196 | | - a closed downstream pipe, and always close the renderer.""" |
197 | | - try: |
198 | | - if self.follow is not None: |
199 | | - with self.follow: |
200 | | - work() |
201 | | - else: |
202 | | - work() |
203 | | - except KeyboardInterrupt: |
204 | | - # Ctrl-C is a normal "user stopped" signal -> exit 0. |
205 | | - if self.follow is None: |
206 | | - self.renderer.close() |
207 | | - self.renderer.stopped() |
208 | | - except BrokenPipeError: |
209 | | - # Downstream consumer (e.g. `| head`) closed the pipe; stop quietly. |
210 | | - raise typer.Exit(code=0) from None |
211 | | - finally: |
212 | | - if self.follow is None: |
213 | | - self.renderer.close() |
214 | | - |
215 | | - def run(self, audio: Iterable[bytes], rate: int, *, source_label: str | None = None) -> None: |
216 | | - self._guarded(lambda: self.stream_one(audio, rate, source_label=source_label)) |
217 | | - |
218 | | - def run_parallel(self, streams: _ParallelStreams) -> None: |
219 | | - self._guarded(lambda: self._drive(streams)) |
220 | | - |
221 | | - def _drive(self, streams: _ParallelStreams) -> None: |
222 | | - """Stream every source concurrently, surfacing the first worker error.""" |
223 | | - errors: queue.Queue[Exception] = queue.Queue() |
224 | | - |
225 | | - def worker(source_label: str, audio: Iterable[bytes], rate: int) -> None: |
226 | | - try: |
227 | | - self.stream_one(audio, rate, source_label=source_label) |
228 | | - except (CLIError, BrokenPipeError) as exc: |
229 | | - errors.put(exc) |
230 | | - |
231 | | - threads = [ |
232 | | - threading.Thread(target=worker, args=(label, audio, rate), daemon=True) |
233 | | - for label, audio, rate in streams |
234 | | - ] |
235 | | - for thread in threads: |
236 | | - thread.start() |
237 | | - while any(thread.is_alive() for thread in threads): |
238 | | - for thread in threads: |
239 | | - thread.join(timeout=0.1) |
240 | | - if not errors.empty(): |
241 | | - raise errors.get() |
242 | | - if not errors.empty(): |
243 | | - raise errors.get() |
244 | | - |
245 | | - |
246 | | -def _dispatch(session: _StreamSession, opts: _SourceOptions) -> None: |
| 35 | +def _dispatch(session: StreamSession, opts: SourceOptions) -> None: |
247 | 36 | """Open the right audio source(s) for the flags and stream them.""" |
248 | 37 | if opts.from_system_audio: |
249 | 38 | system = MacSystemAudioSource(on_open=session.on_open) |
@@ -274,7 +63,7 @@ def _dispatch(session: _StreamSession, opts: _SourceOptions) -> None: |
274 | 63 | else: |
275 | 64 | # Capture at the device's native rate (or --sample-rate override) and tell the |
276 | 65 | # streaming API that rate, rather than forcing one the device may reject. |
277 | | - # "Listening…" is announced once the device is open (see _StreamSession.on_open), |
| 66 | + # "Listening…" is announced once the device is open (see StreamSession.on_open), |
278 | 67 | # not when the session opens — so early speech isn't lost in the gap. |
279 | 68 | mic = MicrophoneSource( |
280 | 69 | device=opts.device, capture_rate=opts.sample_rate, on_open=session.on_open |
@@ -522,7 +311,7 @@ def stream( |
522 | 311 |
|
523 | 312 | def body(state: AppState, json_mode: bool) -> None: |
524 | 313 | text_mode, json_mode = output.stream_output_modes(output_field, json_mode=json_mode) |
525 | | - opts = _SourceOptions( |
| 314 | + opts = SourceOptions( |
526 | 315 | source=source, |
527 | 316 | sample=sample, |
528 | 317 | sample_rate=sample_rate, |
@@ -573,10 +362,10 @@ def body(state: AppState, json_mode: bool) -> None: |
573 | 362 | return |
574 | 363 |
|
575 | 364 | api_key = config.resolve_api_key(profile=state.profile) |
576 | | - _validate_sources(opts, has_llm=bool(llm_prompt), text_mode=text_mode) |
| 365 | + validate_sources(opts, has_llm=bool(llm_prompt), text_mode=text_mode) |
577 | 366 |
|
578 | 367 | llm_prompts = list(llm_prompt or []) |
579 | | - session = _StreamSession( |
| 368 | + session = StreamSession( |
580 | 369 | api_key=api_key, |
581 | 370 | base_flags=base_flags, |
582 | 371 | overrides=config_kv, |
|
0 commit comments