Skip to content

Commit 472b69f

Browse files
authored
Split oversized modules and gate file length at 500 lines (#38)
1 parent 9bba7d5 commit 472b69f

11 files changed

Lines changed: 1203 additions & 1058 deletions

aai_cli/commands/stream.py

Lines changed: 7 additions & 218 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,6 @@
11
from __future__ import annotations
22

3-
import queue
43
import tempfile
5-
import threading
6-
from collections.abc import Callable, Iterable
7-
from dataclasses import dataclass, field
84
from pathlib import Path
95

106
import typer
@@ -22,228 +18,21 @@
2218
youtube,
2319
)
2420
from aai_cli.context import AppState, run_command
25-
from aai_cli.errors import CLIError, UsageError
21+
from aai_cli.errors import UsageError
2622
from aai_cli.follow import FollowRenderer
2723
from aai_cli.help_text import examples_epilog
2824
from aai_cli.microphone import MicrophoneSource
2925
from aai_cli.streaming.macos import MacSystemAudioSource
3026
from aai_cli.streaming.render import StreamRenderer
27+
from aai_cli.streaming.session import SourceOptions, StreamSession, validate_sources
3128
from aai_cli.streaming.sources import TARGET_RATE, FileSource, StdinSource
3229

3330
app = typer.Typer()
3431

3532
DEFAULT_SPEECH_MODEL = SpeechModel.u3_rt_pro
3633

37-
# Sources that can be transcribed in parallel sessions: (label, audio chunks, sample rate).
38-
_ParallelStreams = list[tuple[str, Iterable[bytes], int]]
3934

40-
41-
@dataclass(frozen=True)
42-
class _SourceOptions:
43-
"""Where the audio comes from, distilled from the CLI flags.
44-
45-
Centralizes the "which input?" predicates so the validation and dispatch helpers
46-
below read off one object instead of re-deriving the same booleans.
47-
"""
48-
49-
source: str | None
50-
sample: bool
51-
sample_rate: int | None
52-
device: int | None
53-
system_audio: bool
54-
system_audio_only: bool
55-
56-
@property
57-
def from_stdin(self) -> bool:
58-
return self.source == "-"
59-
60-
@property
61-
def from_file(self) -> bool:
62-
return bool(self.source) or self.sample
63-
64-
@property
65-
def from_system_audio(self) -> bool:
66-
return self.system_audio or self.system_audio_only
67-
68-
@property
69-
def has_capture_overrides(self) -> bool:
70-
"""Whether a microphone-only flag (--sample-rate or --device) was given."""
71-
return self.sample_rate is not None or self.device is not None
72-
73-
74-
def _validate_sources(opts: _SourceOptions, *, has_llm: bool, text_mode: bool) -> None:
75-
"""Reject flag combinations that can't be honored, before any audio is opened."""
76-
if opts.system_audio and opts.system_audio_only:
77-
raise UsageError("Use either --system-audio or --system-audio-only, not both.")
78-
_validate_input_source(opts)
79-
if has_llm and text_mode:
80-
raise UsageError(
81-
"--llm renders a live panel (or NDJSON when piped); it can't be combined with -o text."
82-
)
83-
84-
85-
def _validate_input_source(opts: _SourceOptions) -> None:
86-
"""Reject --sample-rate/--device/source combinations the chosen input can't accept."""
87-
if opts.from_system_audio:
88-
if opts.from_file:
89-
raise UsageError("--system-audio cannot be combined with an audio source or --sample.")
90-
if opts.system_audio_only and opts.has_capture_overrides:
91-
raise UsageError(
92-
"--sample-rate and --device require microphone input; use --system-audio."
93-
)
94-
elif opts.from_stdin:
95-
if opts.device is not None:
96-
raise UsageError("--device applies only to microphone input.")
97-
elif opts.from_file and opts.has_capture_overrides:
98-
raise UsageError("--sample-rate and --device apply only to microphone input.")
99-
100-
101-
@dataclass
102-
class _StreamSession:
103-
"""Owns one streaming run: the renderers, the LLM-chain state, and the audio
104-
plumbing shared across single- and parallel-source streaming.
105-
106-
Holding this as an object (rather than a nest of closures inside the command body)
107-
keeps each step a small, independently readable method, and collapses the ~25
108-
per-call flags into one ``base_flags`` dict that only varies by sample rate.
109-
"""
110-
111-
api_key: str
112-
base_flags: dict[str, object]
113-
overrides: list[str] | None
114-
config_file: str | Path | None
115-
renderer: StreamRenderer
116-
follow: FollowRenderer | None
117-
llm_prompts: list[str]
118-
model: str
119-
max_tokens: int
120-
transcript: list[str] = field(default_factory=list[str])
121-
_callback_lock: threading.RLock = field(default_factory=threading.RLock)
122-
_listening_lock: threading.Lock = field(default_factory=threading.Lock)
123-
_listening_started: bool = False
124-
125-
@property
126-
def on_open(self) -> Callable[[], None]:
127-
"""First-audio callback: announce "Listening…" once — unless the FollowRenderer
128-
owns the screen in --llm mode, where the notice would clutter the live panel."""
129-
return (lambda: None) if self.follow is not None else self._listening_once
130-
131-
def _listening_once(self) -> None:
132-
with self._listening_lock:
133-
if self._listening_started:
134-
return
135-
self._listening_started = True
136-
self.renderer.listening()
137-
138-
def on_turn(self, event: object, *, source_label: str | None = None) -> None:
139-
with self._callback_lock:
140-
if self.follow is None:
141-
self.renderer.turn(event, source=source_label)
142-
else:
143-
self._refresh_answer(event, source_label)
144-
145-
def _refresh_answer(self, event: object, source_label: str | None) -> None:
146-
"""Live --llm mode: re-run the prompt chain over the growing transcript on every
147-
finalized turn, refreshing one evolving answer (partials are ignored)."""
148-
follow = self.follow
149-
if follow is None or not getattr(event, "end_of_turn", False):
150-
return
151-
text = getattr(event, "transcript", "") or ""
152-
if not text:
153-
return
154-
if source_label is not None:
155-
display_source = {"system": "System", "you": "You"}.get(source_label, source_label)
156-
text = f"{display_source}: {text}"
157-
self.transcript.append(text)
158-
answer = llm.run_chain(
159-
self.api_key,
160-
self.llm_prompts,
161-
transcript_text=" ".join(self.transcript),
162-
model=self.model,
163-
max_tokens=self.max_tokens,
164-
)
165-
follow(answer, len(self.transcript))
166-
167-
def stream_one(
168-
self, audio: Iterable[bytes], rate: int, *, source_label: str | None = None
169-
) -> None:
170-
merged = config_builder.merge_streaming_params(
171-
flags=self.base_flags | {"sample_rate": rate},
172-
overrides=self.overrides,
173-
config_file=self.config_file,
174-
)
175-
params = config_builder.construct_streaming_params(merged)
176-
client.stream_audio(
177-
self.api_key,
178-
audio,
179-
params=params,
180-
on_begin=(
181-
None
182-
if self.follow is not None
183-
else lambda event: self.renderer.begin(event, source=source_label)
184-
),
185-
on_turn=lambda event: self.on_turn(event, source_label=source_label),
186-
on_termination=(
187-
None
188-
if self.follow is not None
189-
else lambda event: self.renderer.termination(event, source=source_label)
190-
),
191-
)
192-
193-
def _guarded(self, work: Callable[[], None]) -> None:
194-
"""Run a streaming body with the shared lifecycle handling: enter the
195-
FollowRenderer's live panel if present, treat Ctrl-C as a clean stop, exit 0 on
196-
a closed downstream pipe, and always close the renderer."""
197-
try:
198-
if self.follow is not None:
199-
with self.follow:
200-
work()
201-
else:
202-
work()
203-
except KeyboardInterrupt:
204-
# Ctrl-C is a normal "user stopped" signal -> exit 0.
205-
if self.follow is None:
206-
self.renderer.close()
207-
self.renderer.stopped()
208-
except BrokenPipeError:
209-
# Downstream consumer (e.g. `| head`) closed the pipe; stop quietly.
210-
raise typer.Exit(code=0) from None
211-
finally:
212-
if self.follow is None:
213-
self.renderer.close()
214-
215-
def run(self, audio: Iterable[bytes], rate: int, *, source_label: str | None = None) -> None:
216-
self._guarded(lambda: self.stream_one(audio, rate, source_label=source_label))
217-
218-
def run_parallel(self, streams: _ParallelStreams) -> None:
219-
self._guarded(lambda: self._drive(streams))
220-
221-
def _drive(self, streams: _ParallelStreams) -> None:
222-
"""Stream every source concurrently, surfacing the first worker error."""
223-
errors: queue.Queue[Exception] = queue.Queue()
224-
225-
def worker(source_label: str, audio: Iterable[bytes], rate: int) -> None:
226-
try:
227-
self.stream_one(audio, rate, source_label=source_label)
228-
except (CLIError, BrokenPipeError) as exc:
229-
errors.put(exc)
230-
231-
threads = [
232-
threading.Thread(target=worker, args=(label, audio, rate), daemon=True)
233-
for label, audio, rate in streams
234-
]
235-
for thread in threads:
236-
thread.start()
237-
while any(thread.is_alive() for thread in threads):
238-
for thread in threads:
239-
thread.join(timeout=0.1)
240-
if not errors.empty():
241-
raise errors.get()
242-
if not errors.empty():
243-
raise errors.get()
244-
245-
246-
def _dispatch(session: _StreamSession, opts: _SourceOptions) -> None:
35+
def _dispatch(session: StreamSession, opts: SourceOptions) -> None:
24736
"""Open the right audio source(s) for the flags and stream them."""
24837
if opts.from_system_audio:
24938
system = MacSystemAudioSource(on_open=session.on_open)
@@ -274,7 +63,7 @@ def _dispatch(session: _StreamSession, opts: _SourceOptions) -> None:
27463
else:
27564
# Capture at the device's native rate (or --sample-rate override) and tell the
27665
# streaming API that rate, rather than forcing one the device may reject.
277-
# "Listening…" is announced once the device is open (see _StreamSession.on_open),
66+
# "Listening…" is announced once the device is open (see StreamSession.on_open),
27867
# not when the session opens — so early speech isn't lost in the gap.
27968
mic = MicrophoneSource(
28069
device=opts.device, capture_rate=opts.sample_rate, on_open=session.on_open
@@ -522,7 +311,7 @@ def stream(
522311

523312
def body(state: AppState, json_mode: bool) -> None:
524313
text_mode, json_mode = output.stream_output_modes(output_field, json_mode=json_mode)
525-
opts = _SourceOptions(
314+
opts = SourceOptions(
526315
source=source,
527316
sample=sample,
528317
sample_rate=sample_rate,
@@ -573,10 +362,10 @@ def body(state: AppState, json_mode: bool) -> None:
573362
return
574363

575364
api_key = config.resolve_api_key(profile=state.profile)
576-
_validate_sources(opts, has_llm=bool(llm_prompt), text_mode=text_mode)
365+
validate_sources(opts, has_llm=bool(llm_prompt), text_mode=text_mode)
577366

578367
llm_prompts = list(llm_prompt or [])
579-
session = _StreamSession(
368+
session = StreamSession(
580369
api_key=api_key,
581370
base_flags=base_flags,
582371
overrides=config_kv,

0 commit comments

Comments
 (0)