From 6d6e1acf1750a0b456a051587a3a57045a18f1f7 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 12 Jun 2026 21:32:39 +0000 Subject: [PATCH] Add assembly dictate: hotkey push-to-talk over the Sync STT API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A new run command for hands-free terminal dictation: press Enter (or Space) to start recording the microphone, press again to stop, and the utterance is POSTed to the Sync API (sync.assemblyai.com/transcribe, X-AAI-Model: u3-sync-pro) which returns the transcript in the response body — no polling. q/Esc/Ctrl-C ends the session. - sync_stt.py: the Sync API HTTP boundary (httpx2, multipart raw-PCM + JSON config), normalizing 401/403 to auth_failure(), 429/503 to a retryable APIError, and error_code/message bodies into clean messages. Environment gains a sync_base host (prod + sandbox). - hotkey.py: TerminalKeys reads single keypresses with cbreak scoped to a with-block (Ctrl-C still signals); clean not-a-tty / no-termios errors instead of tracebacks. - dictate_exec.py: the options/run split. Capture is resampled to 16 kHz PCM16; the key poll runs with zero timeout between ~100 ms mic chunks; recordings are capped at the API's 120 s limit and ones under its 80 ms floor are skipped with a warning instead of a server 400. --language (comma list for code-switching), --prompt, --word-boost, --device, --once, --max-seconds, --json (one NDJSON object per utterance). Tests cover the HTTP boundary via MockTransport, the termios behavior via a real pty pair, and the session loop via injected key/mic/HTTP seams; the terminal requirement is validated before credentials. 
https://claude.ai/code/session_01FCXQLAyo8xpZiXrQ7hCMAf --- .importlinter | 6 + AGENTS.md | 5 +- README.md | 3 +- aai_cli/commands/dictate.py | 74 +++++ aai_cli/dictate_exec.py | 187 +++++++++++ aai_cli/environments.py | 3 + aai_cli/hotkey.py | 87 +++++ aai_cli/main.py | 3 + aai_cli/sync_stt.py | 141 ++++++++ .../test_snapshots_help_run.ambr | 50 +++ tests/_snapshot_surface.py | 4 +- tests/test_dictate_command.py | 84 +++++ tests/test_dictate_exec.py | 312 ++++++++++++++++++ tests/test_environments.py | 5 + tests/test_hotkey.py | 119 +++++++ tests/test_smoke.py | 1 + tests/test_sync_stt.py | 222 +++++++++++++ 17 files changed, 1302 insertions(+), 4 deletions(-) create mode 100644 aai_cli/commands/dictate.py create mode 100644 aai_cli/dictate_exec.py create mode 100644 aai_cli/hotkey.py create mode 100644 aai_cli/sync_stt.py create mode 100644 tests/test_dictate_command.py create mode 100644 tests/test_dictate_exec.py create mode 100644 tests/test_hotkey.py create mode 100644 tests/test_sync_stt.py diff --git a/.importlinter b/.importlinter index d5a80c7c..e4a7f75d 100644 --- a/.importlinter +++ b/.importlinter @@ -17,12 +17,14 @@ source_modules = aai_cli.config_builder aai_cli.context aai_cli.debuglog + aai_cli.dictate_exec aai_cli.environments aai_cli.errors aai_cli.eval_data aai_cli.follow aai_cli.help_panels aai_cli.help_text + aai_cli.hotkey aai_cli.init aai_cli.llm aai_cli.llm_exec @@ -35,6 +37,7 @@ source_modules = aai_cli.stdio aai_cli.stream_exec aai_cli.streaming + aai_cli.sync_stt aai_cli.telemetry aai_cli.theme aai_cli.transcribe_batch @@ -56,6 +59,7 @@ modules = aai_cli.commands.audit aai_cli.commands.deploy aai_cli.commands.dev + aai_cli.commands.dictate aai_cli.commands.doctor aai_cli.commands.evaluate aai_cli.commands.init @@ -83,7 +87,9 @@ source_modules = aai_cli.environments aai_cli.errors aai_cli.eval_data + aai_cli.hotkey aai_cli.llm + aai_cli.sync_stt aai_cli.telemetry aai_cli.wer forbidden_modules = diff --git a/AGENTS.md b/AGENTS.md index 
67a520a0..21b9cd7d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -162,9 +162,9 @@ A Typer CLI. `aai_cli/main.py` builds the `app`, registers each command sub-app, ### Command layer -Each file in `aai_cli/commands/` is a Typer sub-app (`transcribe`, `stream`, `agent`, `speak`, `llm`, `transcripts`, `login` (login/logout/whoami), `doctor`, `init`, `dev`, `share`, `deploy`, `setup`, `onboard`, `account` (balance/usage/limits), `keys`, `sessions`, `audit`, `telemetry` (status/enable/disable), `webhooks` (listen)). Command bodies run through `context.run_command(ctx, fn, json=...)`, which maps any `CLIError` to clean stderr output + the error's exit code. Commands never print tracebacks for expected failures. +Each file in `aai_cli/commands/` is a Typer sub-app (`transcribe`, `stream`, `dictate`, `agent`, `speak`, `llm`, `transcripts`, `login` (login/logout/whoami), `doctor`, `init`, `dev`, `share`, `deploy`, `setup`, `onboard`, `account` (balance/usage/limits), `keys`, `sessions`, `audit`, `telemetry` (status/enable/disable), `webhooks` (listen)). Command bodies run through `context.run_command(ctx, fn, json=...)`, which maps any `CLIError` to clean stderr output + the error's exit code. Commands never print tracebacks for expected failures. -**Options/run split for flag-heavy commands** (gh-CLI style): the Typer function only parses argv into a frozen `Options` dataclass and hands it to a module-level `run_(opts, state, *, json_mode)` through a thin lambda adapter in `run_command(ctx, ..., json=...)`. The five run commands follow it — `aai_cli/stream_exec.py` (the reference implementation), `transcribe_exec.py`, `agent_exec.py`, `speak_exec.py`, `llm_exec.py`. 
Because the run path is a plain function of data, tests construct options directly (`dataclasses.replace` off a defaults instance, see `tests/test_stream_exec.py` and `tests/test_command_options_seam.py`) instead of round-tripping argv through `CliRunner` — which is also the cheap way to kill mutation-gate mutants on orchestration lines. Follow this for new or heavily-reworked commands with long bodies; small commands keep the inline `body()` closure — the dataclass is pure ceremony there. +**Options/run split for flag-heavy commands** (gh-CLI style): the Typer function only parses argv into a frozen `Options` dataclass and hands it to a module-level `run_(opts, state, *, json_mode)` through a thin lambda adapter in `run_command(ctx, ..., json=...)`. The six run commands follow it — `aai_cli/stream_exec.py` (the reference implementation), `transcribe_exec.py`, `agent_exec.py`, `speak_exec.py`, `llm_exec.py`, `dictate_exec.py`. Because the run path is a plain function of data, tests construct options directly (`dataclasses.replace` off a defaults instance, see `tests/test_stream_exec.py` and `tests/test_command_options_seam.py`) instead of round-tripping argv through `CliRunner` — which is also the cheap way to kill mutation-gate mutants on orchestration lines. Follow this for new or heavily-reworked commands with long bodies; small commands keep the inline `body()` closure — the dataclass is pure ceremony there. ### Cross-cutting state (resolution order matters) @@ -178,6 +178,7 @@ Each file in `aai_cli/commands/` is a Typer sub-app (`transcribe`, `stream`, `ag ### Feature subsystems - **`streaming/`** + `client.stream_audio` — v3 realtime API. Event callbacks run on the SDK reader thread and guard against `BrokenPipeError` (`stdio.silence_stdout()`) so a closed pipe never dumps a thread traceback. 
+- **`sync_stt.py`** + **`hotkey.py`** + `commands/dictate.py` — `assembly dictate`: push-to-talk dictation over the **Sync STT API** (`Environment.sync_base`, one POST `/transcribe` per utterance with the required `X-AAI-Model: u3-sync-pro` header; 80 ms–120 s of PCM/WAV). `hotkey.TerminalKeys` scopes stdin into cbreak (Ctrl-C still signals) and reads single keypresses; `dictate_exec._record` polls it with a zero timeout between ~100 ms mic chunks. All three boundaries (keys, mic, HTTP) are injectable, so the suite never needs a real terminal — `tests/test_hotkey.py` drives a pty pair for the termios behavior. - **`agent/`** — full-duplex voice agent (mic in, TTS out via `voices.py`). - **`tts/`** + `commands/speak.py` — `assembly speak` synthesizes text to speech over the sandbox streaming-TTS WebSocket (`streaming-tts.sandbox000.…`). **Sandbox-only:** `session.is_available()` is false in production (empty `Environment.streaming_tts_host`), so the command exits 2 with a `--sandbox` hint. `session.synthesize` drives a Begin→Generate→Flush→Audio→Terminate protocol with an injectable `connect` for hermetic tests (mirrors `agent/session.py`); `audio.py` plays the PCM (default) or writes a WAV (`--out`). - **`code_gen/`** — backs `--show-code` on `transcribe`/`stream`/`agent`: builds a ready-to-run Python SDK script from exactly the flags passed (no API key needed; generated code reads `ASSEMBLYAI_API_KEY`). diff --git a/README.md b/README.md index 9279a73f..828f624c 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ uv tool install "git+https://github.com/AssemblyAI/cli.git" If your default interpreter is older than Python 3.12, add `--python python3.12` (pipx) or `--python 3.12` (uv) to the install command. 
-Only the live-audio commands need anything extra: `stream` and `agent` use PortAudio for +Only the live-audio commands need anything extra: `stream`, `dictate`, and `agent` use PortAudio for microphone capture (Debian/Ubuntu: `sudo apt-get install libportaudio2`; Fedora: `sudo dnf install portaudio`) and [`ffmpeg`](https://ffmpeg.org) on `PATH` to stream non-WAV audio. Plain `transcribe` uploads your file directly and needs neither. @@ -164,6 +164,7 @@ assembly init # scaffold a starter app - **Transcription**: `assembly transcribe` handles files, URLs, and YouTube/podcast pages, with flags for speaker labels, PII redaction, summarization, sentiment, chapters, and more. - **Batch transcription**: point `assembly transcribe` at a directory or glob (or pipe paths with `--from-stdin`) to transcribe everything concurrently, with sidecar files that make re-runs resumable. Add `--llm "prompt"` to run an LLM prompt over each finished transcript, saved into the sidecars. - **Real-time streaming**: `assembly stream` transcribes the microphone, a file, or a URL live — on macOS it can capture system audio too. +- **Dictation**: `assembly dictate` is push-to-talk for your terminal — press Enter to record, Enter again to get the utterance back instantly from the Sync API (up to 120 s per utterance). - **Voice agent**: `assembly agent` runs a full-duplex spoken conversation in your terminal. - **LLM Gateway**: `assembly llm` prompts an LLM over a transcript, stdin, or a live stream (`assembly stream --llm "summarize as I talk"`). - **Model evaluation**: `assembly eval` transcribes a Hugging Face dataset (with built-in aliases for common benchmarks: `assembly eval tedlium`) or a local `.csv`/`.jsonl` manifest and scores WER against its references — handy for picking a speech model. 
diff --git a/aai_cli/commands/dictate.py b/aai_cli/commands/dictate.py new file mode 100644 index 00000000..34261a20 --- /dev/null +++ b/aai_cli/commands/dictate.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import typer + +from aai_cli import dictate_exec, help_panels, options +from aai_cli.context import run_command +from aai_cli.help_text import examples_epilog +from aai_cli.sync_stt import MAX_AUDIO_SECONDS + +app = typer.Typer() + + +@app.command( + rich_help_panel=help_panels.TRANSCRIPTION, + epilog=examples_epilog( + [ + ("Dictate: Enter starts a recording, Enter transcribes it", "assembly dictate"), + ("One utterance, then exit", "assembly dictate --once"), + ("Dictate in Spanish", "assembly dictate --language es"), + ( + "Bias recognition toward tricky terms", + "assembly dictate --word-boost AssemblyAI --word-boost LeMUR", + ), + ("One JSON object per utterance", "assembly dictate --json"), + ] + ), +) +def dictate( + ctx: typer.Context, + language: str | None = typer.Option( + None, + "--language", + help="ISO 639-1 language code, or a comma-separated list for " + "code-switching audio (default: en).", + ), + prompt: str | None = typer.Option( + None, + "--prompt", + help="Custom transcription prompt (overrides --language).", + ), + word_boost: list[str] | None = typer.Option( + None, "--word-boost", help="Bias recognition toward a term (repeatable)." + ), + device: int | None = typer.Option(None, "--device", help="Microphone device index."), + once: bool = typer.Option(False, "--once", help="Transcribe one utterance, then exit."), + max_seconds: float = typer.Option( + float(MAX_AUDIO_SECONDS), + "--max-seconds", + help="Auto-stop a recording after this many seconds.", + min=1.0, + max=float(MAX_AUDIO_SECONDS), + ), + json_out: bool = options.json_option("Emit one JSON object per utterance."), +) -> None: + """Dictate with a hotkey: record the mic, get the transcript back instantly. 
+ + Press Enter (or Space) to start recording and press it again to stop; the + utterance is sent to the AssemblyAI Sync API and the transcript prints + immediately — no polling. Press q (or Esc/Ctrl-C) to finish. Each utterance + can be up to 120 seconds long. + """ + opts = dictate_exec.DictateOptions( + language=language, + prompt=prompt, + word_boost=word_boost, + device=device, + once=once, + max_seconds=max_seconds, + ) + run_command( + ctx, + lambda state, json_mode: dictate_exec.run_dictate(opts, state, json_mode=json_mode), + json=json_out, + ) diff --git a/aai_cli/dictate_exec.py b/aai_cli/dictate_exec.py new file mode 100644 index 00000000..896ca78e --- /dev/null +++ b/aai_cli/dictate_exec.py @@ -0,0 +1,187 @@ +"""Run logic for `assembly dictate`: the options/run split (see AGENTS.md). + +Push-to-talk dictation over the Sync STT API: wait for a hotkey, record the +microphone until the hotkey is pressed again (or the duration cap), POST the +utterance to the Sync API, print the transcript, repeat. The command module +(aai_cli/commands/dictate.py) only parses argv into a ``DictateOptions``; tests +drive the session by constructing options directly and injecting the key/mic/ +HTTP boundaries, with no CliRunner argv round-trip and no real terminal. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from aai_cli import output, sync_stt +from aai_cli.context import AppState +from aai_cli.hotkey import CTRL_C, CTRL_D, ESC, TerminalKeys +from aai_cli.microphone import MicrophoneSource + +# Capture is resampled to one rate the Sync API accepts; 16 kHz mono PCM16 keeps +# a 120 s utterance well under the 40 MB upload cap. +TARGET_RATE = 16000 +_BYTES_PER_SECOND = TARGET_RATE * 2 # PCM16 mono + +# Enter or Space toggles recording; q / Esc / Ctrl-D ends the session at the +# idle prompt (Ctrl-C works anywhere — cbreak mode keeps SIGINT delivery). 
+TOGGLE_KEYS = frozenset({"\r", "\n", " "}) +QUIT_KEYS = frozenset({"q", "Q", ESC, CTRL_C, CTRL_D}) + + +@dataclass(frozen=True) +class DictateOptions: + """Every `assembly dictate` flag as plain data (``--json`` excluded: run_command + resolves it into the ``json_mode`` argument).""" + + language: str | None + prompt: str | None + word_boost: list[str] | None + device: int | None + once: bool + max_seconds: float + + +def _note(message: str, *, json_mode: bool, quiet: bool) -> None: + """A muted stderr hint guiding the interactive session; silent under --json + (stderr must stay machine-readable) and --quiet.""" + if json_mode or quiet: + return + output.error_console.print(f"[aai.muted]{message}[/aai.muted]") + + +def _languages(language: str | None) -> str | list[str] | None: + """Fold --language into the config shape: one ISO code as a string, a + comma-separated list (code-switching audio) as a list, blank as unset.""" + if language is None: + return None + codes = [code.strip() for code in language.split(",") if code.strip()] + if not codes: + return None + return codes[0] if len(codes) == 1 else codes + + +def _record(keys: TerminalKeys, mic: MicrophoneSource, *, max_seconds: float) -> bytes: + """Capture PCM until a hotkey is pressed again or the duration cap is hit. + + The key poll runs between ~100 ms mic chunks with a zero timeout, so the mic + read loop is never blocked waiting on the keyboard. + """ + pcm = bytearray() + frames = iter(mic) + try: + for chunk in frames: + pcm += chunk + if len(pcm) >= int(max_seconds * _BYTES_PER_SECOND): + break + # None (no key pending) is simply not in either set. + if keys.read(0) in TOGGLE_KEYS | QUIT_KEYS: + break + finally: + # MicrophoneSource yields from a generator whose cleanup releases the + # device; close it deterministically instead of waiting on GC. Injected + # fakes (a plain list iterator) may not have close(). 
+ close = getattr(frames, "close", None) + if callable(close): + close() + return bytes(pcm) + + +def _emit(result: sync_stt.SyncTranscript, *, json_mode: bool) -> None: + """One utterance to stdout: the bare transcript text, or one NDJSON object.""" + if json_mode: + output.emit_ndjson( + { + "text": result.text, + "confidence": result.confidence, + "audio_duration_ms": result.audio_duration_ms, + "session_id": result.session_id, + } + ) + else: + output.emit_text(result.text) + + +def _transcribe_utterance( + api_key: str, + pcm: bytes, + opts: DictateOptions, + state: AppState, + *, + json_mode: bool, +) -> None: + """Send one recorded utterance to the Sync API and print the transcript. + + A recording below the API's 80 ms floor (a double-tapped hotkey) is skipped + with a warning rather than bounced off the server as a 400. + """ + if len(pcm) < sync_stt.MIN_AUDIO_MS * _BYTES_PER_SECOND // 1000: + output.emit_warning( + f"Recording was shorter than {sync_stt.MIN_AUDIO_MS} ms; nothing to transcribe.", + json_mode=json_mode, + ) + return + with output.status("Transcribing…", json_mode=json_mode, quiet=state.quiet): + result = sync_stt.transcribe_pcm( + api_key, + pcm, + sample_rate=TARGET_RATE, + language_code=_languages(opts.language), + prompt=opts.prompt, + word_boost=opts.word_boost, + ) + _emit(result, json_mode=json_mode) + + +def _session( + keys: TerminalKeys, + api_key: str, + opts: DictateOptions, + state: AppState, + *, + json_mode: bool, +) -> None: + """The dictation loop: idle until a toggle key, record, transcribe, repeat.""" + while True: + key = keys.read(None) + if key is None or key in QUIT_KEYS: + return + if key not in TOGGLE_KEYS: + continue + mic = MicrophoneSource( + target_rate=TARGET_RATE, + device=opts.device, + on_open=lambda: _note( + "● Recording — press Enter to stop.", json_mode=json_mode, quiet=state.quiet + ), + ) + pcm = _record(keys, mic, max_seconds=opts.max_seconds) + _transcribe_utterance(api_key, pcm, opts, state, 
json_mode=json_mode) + if opts.once: + return + + +def run_dictate(opts: DictateOptions, state: AppState, *, json_mode: bool) -> None: + """Execute one `assembly dictate` invocation from already-parsed flags.""" + try: + # Entering TerminalKeys validates the terminal (a usage precondition) + # before credentials, so a piped stdin reads as "needs a terminal" — not + # as a login prompt. + with TerminalKeys() as keys: + api_key = state.resolve_api_key() + if opts.prompt and opts.language: + # The server ignores language_code whenever a custom prompt is set; + # never drop a requested flag silently (mirrors the speak warnings). + output.emit_warning( + "--language is ignored when --prompt is set; " + "state the language inside the prompt.", + json_mode=json_mode, + ) + _note( + "Press Enter to start recording, Enter again to transcribe. q quits.", + json_mode=json_mode, + quiet=state.quiet, + ) + _session(keys, api_key, opts, state, json_mode=json_mode) + except KeyboardInterrupt: + # Ctrl-C is the normal "done dictating" signal: end cleanly, not as an error. 
+ return diff --git a/aai_cli/environments.py b/aai_cli/environments.py index 980e9f08..ad9dd509 100644 --- a/aai_cli/environments.py +++ b/aai_cli/environments.py @@ -16,6 +16,7 @@ class Environment: name: str api_base: str # SDK base_url for /v2/upload + /v2/transcript + sync_base: str # Sync STT API base (one-shot POST /transcribe, used by `assembly dictate`) streaming_host: str # StreamingClientOptions.api_host (SDK builds wss://host/v3/ws) streaming_tts_host: str # streaming TTS host; empty when TTS isn't available (prod) agents_host: str # Voice Agent host; the agent client builds wss://host/v1/ws @@ -37,6 +38,7 @@ class Environment: "production": Environment( name="production", api_base="https://api.assemblyai.com", + sync_base="https://sync.assemblyai.com", streaming_host="streaming.assemblyai.com", streaming_tts_host="", agents_host="agents.assemblyai.com", @@ -49,6 +51,7 @@ class Environment: "sandbox000": Environment( name="sandbox000", api_base="https://api.sandbox000.assemblyai-labs.com", + sync_base="https://sync.sandbox000.assemblyai-labs.com", streaming_host="streaming.sandbox000.assemblyai-labs.com", streaming_tts_host="streaming-tts.sandbox000.assemblyai-labs.com", agents_host="agents.sandbox000.assemblyai-labs.com", diff --git a/aai_cli/hotkey.py b/aai_cli/hotkey.py new file mode 100644 index 00000000..a525d960 --- /dev/null +++ b/aai_cli/hotkey.py @@ -0,0 +1,87 @@ +"""Single-keypress input for hotkey-driven commands (`assembly dictate`). + +``TerminalKeys`` switches stdin into cbreak mode for the lifetime of a ``with`` +block, so individual keypresses arrive without Enter — while Ctrl-C still raises +KeyboardInterrupt (cbreak keeps ISIG, unlike full raw mode). POSIX-only: there +is no termios on Windows, so entering the context raises a clean CLIError there +instead of an ImportError traceback. Stdlib-only on purpose, mirroring the other +non-rendering layers. 
+""" + +from __future__ import annotations + +import os +import select +import sys + +from aai_cli.errors import CLIError + +# Control characters hotkey-driven commands treat as "end the session". +CTRL_C = "\x03" +CTRL_D = "\x04" +ESC = "\x1b" + + +def _stdin_fd() -> int: + """The stdin file descriptor, or -1 when stdin has none (a captured/replaced + stream in an embedding or test harness) — os.isatty(-1) is False, so that + case falls into the clean not-a-tty error instead of a fileno traceback.""" + try: + return sys.stdin.fileno() + except (ValueError, OSError): # ValueError covers io.UnsupportedOperation + return -1 + + +class TerminalKeys: + """Reads single keypresses from a terminal fd, cbreak-scoped via ``with``. + + The fd is injectable (tests drive it through a pty pair); it defaults to + the process's stdin. + """ + + def __init__(self, fd: int | None = None) -> None: + self._fd = fd if fd is not None else _stdin_fd() + # termios.tcgetattr's attribute list (typeshed's exact shape). + self._saved: list[int | list[bytes | int]] | None = None + + def __enter__(self) -> TerminalKeys: + try: + import termios + import tty + except ImportError as exc: + raise CLIError( + "Hotkey input is not supported on this platform (no termios).", + error_type="unsupported_platform", + exit_code=2, + ) from exc + if not os.isatty(self._fd): + raise CLIError( + "This command needs an interactive terminal: it waits for hotkey presses on stdin.", + error_type="not_a_tty", + exit_code=2, + suggestion="Run it directly in a terminal, without piping or redirecting stdin.", + ) + self._saved = termios.tcgetattr(self._fd) + tty.setcbreak(self._fd) + return self + + def __exit__(self, *exc: object) -> None: + if self._saved is not None: + import termios + + termios.tcsetattr(self._fd, termios.TCSADRAIN, self._saved) + self._saved = None + + def read(self, timeout: float | None) -> str | None: + """One keypress, or None when ``timeout`` elapses or stdin hits EOF. 
+ + ``timeout=None`` blocks until a key arrives; ``timeout=0`` polls without + waiting (the in-recording check between audio chunks). + """ + ready, _, _ = select.select([self._fd], [], [], timeout) + if not ready: + return None + data = os.read(self._fd, 1) + if not data: + return None + return data.decode("utf-8", "replace") diff --git a/aai_cli/main.py b/aai_cli/main.py index 2b287831..ec1af00a 100644 --- a/aai_cli/main.py +++ b/aai_cli/main.py @@ -27,6 +27,7 @@ audit, deploy, dev, + dictate, doctor, evaluate, init, @@ -65,6 +66,7 @@ # Run AssemblyAI — use AssemblyAI directly from the terminal "transcribe", "stream", + "dictate", "agent", "speak", "llm", @@ -400,6 +402,7 @@ def main( # panel is controlled by `_COMMAND_ORDER` via `_OrderedGroup`, not registration order. app.add_typer(transcribe.app) app.add_typer(stream.app) +app.add_typer(dictate.app) app.add_typer(transcripts.app, name="transcripts", rich_help_panel=help_panels.HISTORY) app.add_typer(sessions.app, name="sessions", rich_help_panel=help_panels.HISTORY) app.add_typer(audit.app) # audit diff --git a/aai_cli/sync_stt.py b/aai_cli/sync_stt.py new file mode 100644 index 00000000..213c6e54 --- /dev/null +++ b/aai_cli/sync_stt.py @@ -0,0 +1,141 @@ +"""HTTP boundary for the Sync STT API (sync.assemblyai.com). + +One POST per utterance — multipart audio + JSON config in, the transcript back in +the response body. No polling, no session management; the Universal-3 Pro sync +model handles audio from 80 ms up to 120 s. Backs `assembly dictate`. Like +client.py/ams.py, this module normalizes every failure into the CLIError +hierarchy and stays Rich-free (import-linter contract 3). +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass + +import httpx2 as httpx + +from aai_cli import environments, jsonshape +from aai_cli.errors import APIError, auth_failure + +# The X-AAI-Model header is required on every Sync API request. 
+SYNC_MODEL = "u3-sync-pro" +# Documented audio constraints; callers gate on these before uploading. +MIN_AUDIO_MS = 80 +MAX_AUDIO_SECONDS = 120 +# The request itself has a 30 s inference deadline server-side; leave headroom +# for the upload of up to 40 MB of audio. +_TIMEOUT = 60.0 +_RATE_LIMIT_STATUSES = (429, 503) +_HTTP_ERROR_MIN_STATUS = 400 + + +@dataclass(frozen=True) +class SyncTranscript: + """The fields of a Sync API response the CLI renders.""" + + text: str + confidence: float | None + audio_duration_ms: int | None + session_id: str | None + + +def _detail(resp: httpx.Response) -> str: + """A human-readable failure detail: the body's message/detail (with its + error_code when present), else the raw body, else the bare HTTP status.""" + fallback = resp.text or f"HTTP {resp.status_code}" + try: + body: object = resp.json() + except ValueError: + return fallback + mapping = jsonshape.as_mapping(body) + if mapping is None: + return fallback + message = mapping.get("message") or mapping.get("detail") + if message is None: + return fallback + code = mapping.get("error_code") + return f"{message} ({code})" if code else str(message) + + +def _raise_for_error(resp: httpx.Response) -> None: + if resp.status_code in (401, 403): + raise auth_failure() + if resp.status_code in _RATE_LIMIT_STATUSES: + # 429 (rate limit) and 503 (capacity / model cold start) are both + # transient: the documented recovery is simply to retry shortly. + raise APIError( + f"The Sync API is busy ({resp.status_code}): {_detail(resp)}", + suggestion="Wait a moment and try again.", + ) + if resp.status_code >= _HTTP_ERROR_MIN_STATUS: + raise APIError(f"Sync transcription failed ({resp.status_code}): {_detail(resp)}") + + +def _parse(resp: httpx.Response) -> SyncTranscript: + try: + body: object = resp.json() + except ValueError as exc: + raise APIError( + f"The Sync API returned a response that is not valid JSON (HTTP {resp.status_code})." 
+ ) from exc + mapping = jsonshape.as_mapping(body) + if mapping is None or "text" not in mapping: + raise APIError("The Sync API returned an unexpected response shape (no transcript text).") + confidence = mapping.get("confidence") + duration = mapping.get("audio_duration_ms") + session_id = mapping.get("session_id") + return SyncTranscript( + text=str(mapping["text"]), + confidence=jsonshape.as_float(confidence) if confidence is not None else None, + audio_duration_ms=jsonshape.as_int(duration) if duration is not None else None, + session_id=str(session_id) if session_id is not None else None, + ) + + +def _config( + sample_rate: int, + channels: int, + language_code: str | list[str] | None, + prompt: str | None, + word_boost: list[str] | None, +) -> dict[str, object]: + """The JSON config part: PCM geometry always, model knobs only when set.""" + cfg: dict[str, object] = {"sample_rate": sample_rate, "channels": channels} + if language_code: + cfg["language_code"] = language_code + if prompt: + cfg["prompt"] = prompt + if word_boost: + cfg["word_boost"] = list(word_boost) + return cfg + + +def transcribe_pcm( + api_key: str, + pcm: bytes, + *, + sample_rate: int, + channels: int = 1, + language_code: str | list[str] | None = None, + prompt: str | None = None, + word_boost: list[str] | None = None, +) -> SyncTranscript: + """POST raw PCM (S16LE) to the Sync API and return the finished transcript.""" + cfg = _config(sample_rate, channels, language_code, prompt, word_boost) + files: dict[str, tuple[str | None, bytes | str, str | None]] = { + "audio": ("audio.pcm", pcm, "audio/pcm"), + "config": (None, json.dumps(cfg), "application/json"), + } + headers = {"authorization": api_key, "x-aai-model": SYNC_MODEL} + try: + with httpx.Client(timeout=_TIMEOUT) as client: + resp = client.post( + f"{environments.active().sync_base}/transcribe", headers=headers, files=files + ) + except httpx.HTTPError as exc: + raise APIError( + f"Could not reach the Sync API: {exc}", + 
suggestion="Check your network connection and try again.", + ) from exc + _raise_for_error(resp) + return _parse(resp) diff --git a/tests/__snapshots__/test_snapshots_help_run.ambr b/tests/__snapshots__/test_snapshots_help_run.ambr index 29d7c2f1..1ac7eba4 100644 --- a/tests/__snapshots__/test_snapshots_help_run.ambr +++ b/tests/__snapshots__/test_snapshots_help_run.ambr @@ -65,6 +65,56 @@ + ''' +# --- +# name: test_command_help_matches_snapshot[dictate] + ''' + + Usage: assembly dictate [OPTIONS] + + Dictate with a hotkey: record the mic, get the transcript back instantly. + + Press Enter (or Space) to start recording and press it again to stop; the + utterance is sent to the AssemblyAI Sync API and the transcript prints + immediately — no polling. Press q (or Esc/Ctrl-C) to finish. Each utterance + can be up to 120 seconds long. + + ╭─ Options ────────────────────────────────────────────────────────────────────╮ + │ --language TEXT ISO 639-1 language code, │ + │ or a comma-separated list │ + │ for code-switching audio │ + │ (default: en). │ + │ --prompt TEXT Custom transcription │ + │ prompt (overrides │ + │ --language). │ + │ --word-boost TEXT Bias recognition toward a │ + │ term (repeatable). │ + │ --device INTEGER Microphone device index. │ + │ --once Transcribe one utterance, │ + │ then exit. │ + │ --max-seconds FLOAT RANGE Auto-stop a recording │ + │ [1.0<=x<=120.0] after this many seconds. │ + │ [default: 120.0] │ + │ --json -j Emit one JSON object per │ + │ utterance. │ + │ --help Show this message and │ + │ exit. 
│ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + Examples + Dictate: Enter starts a recording, Enter transcribes it + $ assembly dictate + One utterance, then exit + $ assembly dictate --once + Dictate in Spanish + $ assembly dictate --language es + Bias recognition toward tricky terms + $ assembly dictate --word-boost AssemblyAI --word-boost LeMUR + One JSON object per utterance + $ assembly dictate --json + + + ''' # --- # name: test_command_help_matches_snapshot[eval] diff --git a/tests/_snapshot_surface.py b/tests/_snapshot_surface.py index d28cd75c..35608e4b 100644 --- a/tests/_snapshot_surface.py +++ b/tests/_snapshot_surface.py @@ -23,7 +23,9 @@ # ``tests/test_snapshots_help_.py`` module suffixes. HELP_GROUPS: dict[str, frozenset[str]] = { "build": frozenset({"onboard", "init", "dev", "share", "deploy"}), - "run": frozenset({"transcribe", "stream", "agent", "speak", "llm", "eval", "webhooks"}), + "run": frozenset( + {"transcribe", "stream", "dictate", "agent", "speak", "llm", "eval", "webhooks"} + ), "tools": frozenset({"doctor", "setup", "telemetry", "_update-check"}), "history": frozenset({"transcripts", "sessions"}), "account": frozenset( diff --git a/tests/test_dictate_command.py b/tests/test_dictate_command.py new file mode 100644 index 00000000..63f12088 --- /dev/null +++ b/tests/test_dictate_command.py @@ -0,0 +1,84 @@ +"""The `assembly dictate` Typer surface: argv -> DictateOptions mapping and the +non-terminal failure mode. 
Session behavior lives in test_dictate_exec.py.""" + +from typer.testing import CliRunner + +from aai_cli import dictate_exec +from aai_cli.main import app + +runner = CliRunner() + + +def _capture_run(monkeypatch): + seen = {} + + def fake_run(opts, state, *, json_mode): + seen["opts"] = opts + seen["json_mode"] = json_mode + + monkeypatch.setattr(dictate_exec, "run_dictate", fake_run) + return seen + + +def test_defaults_map_to_options(monkeypatch): + seen = _capture_run(monkeypatch) + result = runner.invoke(app, ["dictate"]) + assert result.exit_code == 0 + assert seen["opts"] == dictate_exec.DictateOptions( + language=None, + prompt=None, + word_boost=None, + device=None, + once=False, + max_seconds=120.0, + ) + assert seen["json_mode"] is False + + +def test_every_flag_maps_to_its_option_field(monkeypatch): + seen = _capture_run(monkeypatch) + result = runner.invoke( + app, + [ + "dictate", + "--language", + "es", + "--prompt", + "Verbatim.", + "--word-boost", + "AssemblyAI", + "--word-boost", + "LeMUR", + "--device", + "2", + "--once", + "--max-seconds", + "30", + "--json", + ], + ) + assert result.exit_code == 0 + assert seen["opts"] == dictate_exec.DictateOptions( + language="es", + prompt="Verbatim.", + word_boost=["AssemblyAI", "LeMUR"], + device=2, + once=True, + max_seconds=30.0, + ) + assert seen["json_mode"] is True + + +def test_max_seconds_is_capped_at_the_api_limit(): + result = runner.invoke(app, ["dictate", "--max-seconds", "200"]) + assert result.exit_code == 2 + assert "120" in result.output + + +def test_outside_a_terminal_is_a_usage_error_not_a_login(): + # CliRunner's stdin is not a terminal and no credentials are configured: the + # whole stack (command -> run_dictate -> TerminalKeys) must surface the + # terminal requirement, not start an authentication flow. 
+ result = runner.invoke(app, ["dictate"]) + assert result.exit_code == 2 + assert "interactive terminal" in result.output diff --git a/tests/test_dictate_exec.py b/tests/test_dictate_exec.py new file mode 100644 index 00000000..2b6b185f --- /dev/null +++ b/tests/test_dictate_exec.py @@ -0,0 +1,312 @@ +"""Direct tests of the `assembly dictate` options/run seam (dictate_exec). + +The session is driven by constructing DictateOptions and injecting the three +boundaries — TerminalKeys (scripted keys), MicrophoneSource (canned PCM), and +sync_stt.transcribe_pcm (recorded calls) — so no test needs a real terminal, +microphone, or network. +""" + +from __future__ import annotations + +import contextlib +import dataclasses +import json + +import pytest + +from aai_cli import config, dictate_exec, sync_stt +from aai_cli.context import AppState +from aai_cli.errors import CLIError + +DICTATE_DEFAULTS = dictate_exec.DictateOptions( + language=None, + prompt=None, + word_boost=None, + device=None, + once=False, + max_seconds=120.0, +) + +# One ~100 ms chunk of 16 kHz PCM16 — comfortably above the 80 ms upload floor. 
+CHUNK = b"\x01\x00" * 1600 + +RESULT = sync_stt.SyncTranscript( + text="hello world", confidence=0.9, audio_duration_ms=1500, session_id="sess-1" +) + + +class FakeKeys: + """A scripted TerminalKeys: each read() pops the next key (None = no key yet); + an exhausted script reads as EOF, which ends the session.""" + + def __init__(self, script): + self.script = list(script) + self.timeouts = [] + self.entered = False + self.exited = False + + def __enter__(self): + self.entered = True + return self + + def __exit__(self, *exc): + self.exited = True + + def read(self, timeout): + self.timeouts.append(timeout) + return self.script.pop(0) if self.script else None + + +class RaisingKeys(FakeKeys): + def read(self, timeout): + raise KeyboardInterrupt + + +@pytest.fixture +def seams(monkeypatch): + """Wire all three boundaries; returns the mutable harness state.""" + config.set_api_key("default", "sk_live") + harness = {"keys": FakeKeys([]), "chunks": [CHUNK, CHUNK], "mic": {}, "calls": []} + + monkeypatch.setattr(dictate_exec, "TerminalKeys", lambda: harness["keys"]) + + def fake_mic(*, target_rate, device=None, on_open=None): + harness["mic"].update(target_rate=target_rate, device=device) + if on_open is not None: + on_open() + return iter(harness["chunks"]) + + monkeypatch.setattr(dictate_exec, "MicrophoneSource", fake_mic) + + def fake_transcribe(api_key, pcm, *, sample_rate, channels=1, **kwargs): + harness["calls"].append( + {"api_key": api_key, "pcm": pcm, "sample_rate": sample_rate, "channels": channels} + | kwargs + ) + return RESULT + + monkeypatch.setattr(dictate_exec.sync_stt, "transcribe_pcm", fake_transcribe) + return harness + + +def _run(opts=DICTATE_DEFAULTS, state=None, *, json_mode=False): + dictate_exec.run_dictate(opts, state or AppState(), json_mode=json_mode) + + +def test_options_are_immutable(): + field_name = dataclasses.fields(DICTATE_DEFAULTS)[0].name + with pytest.raises(dataclasses.FrozenInstanceError): + setattr(DICTATE_DEFAULTS, field_name, 
None) + + +def test_hotkey_records_then_prints_bare_transcript(seams, capsys): + # Enter starts; the in-recording poll sees nothing after chunk 1, Enter after + # chunk 2 stops; q at the idle prompt quits. + seams["keys"] = FakeKeys(["\r", None, "\r", "q"]) + _run() + # Both chunks were captured and uploaded as one utterance at the resampled rate. + assert seams["calls"] == [ + { + "api_key": "sk_live", + "pcm": CHUNK + CHUNK, + "sample_rate": 16000, + "channels": 1, + "language_code": None, + "prompt": None, + "word_boost": None, + } + ] + captured = capsys.readouterr() + # Human mode: the bare text on stdout (pipe-friendly), not a JSON object. + assert captured.out.strip() == "hello world" + # The interactive hints (idle prompt + recording note) go to stderr only. + assert "Press Enter to start recording" in captured.err + assert "Recording — press Enter to stop" in captured.err + assert seams["mic"] == {"target_rate": 16000, "device": None} + assert seams["keys"].entered and seams["keys"].exited # terminal restored + # Idle waits block (None); in-recording polls must not wait at all (0), or + # every audio chunk would stall behind the keyboard. + assert seams["keys"].timeouts == [None, 0, 0, None] + + +def test_json_mode_emits_one_ndjson_object_per_utterance(seams, capsys): + seams["keys"] = FakeKeys(["\r", "\r"]) + _run(json_mode=True) + captured = capsys.readouterr() + assert json.loads(captured.out) == { + "text": "hello world", + "confidence": 0.9, + "audio_duration_ms": 1500, + "session_id": "sess-1", + } + # --json keeps stderr machine-readable: no human hints. 
+ assert captured.err == "" + + +def test_quiet_suppresses_the_interactive_hints(seams, capsys): + seams["keys"] = FakeKeys(["\r", "\r"]) + _run(state=AppState(quiet=True)) + captured = capsys.readouterr() + assert captured.out.strip() == "hello world" + assert captured.err == "" + + +def test_once_exits_after_a_single_utterance(seams): + seams["keys"] = FakeKeys(["\r", "\r", "\r", "\r", "\r", "\r"]) + _run(dataclasses.replace(DICTATE_DEFAULTS, once=True)) + assert len(seams["calls"]) == 1 + # The session ended on --once, not by draining the key script. + assert seams["keys"].script + + +@pytest.mark.parametrize("quit_key", ["q", "Q", "\x1b", "\x04"]) +def test_quit_keys_end_the_session_without_recording(seams, quit_key, capsys): + seams["keys"] = FakeKeys([quit_key, "\r", "\r"]) + _run() + assert seams["calls"] == [] + assert capsys.readouterr().out == "" + + +def test_unbound_keys_are_ignored_at_the_idle_prompt(seams): + seams["keys"] = FakeKeys(["x", "7", "q"]) + _run() + assert seams["calls"] == [] + + +def test_space_also_toggles_recording(seams): + seams["keys"] = FakeKeys([" ", " ", "q"]) + _run() + assert len(seams["calls"]) == 1 + + +def test_unbound_keys_during_recording_do_not_stop_capture(seams): + # A stray keystroke mid-utterance is ignored; only Enter/Space (or a quit + # key) ends the capture. + seams["keys"] = FakeKeys(["\r", "x", "\r", "q"]) + seams["chunks"] = [CHUNK, CHUNK, CHUNK] + _run() + assert seams["calls"][0]["pcm"] == CHUNK + CHUNK + + +def test_quit_key_during_recording_still_transcribes_the_utterance(seams): + seams["keys"] = FakeKeys(["\r", "q"]) + _run() + assert len(seams["calls"]) == 1 + assert seams["calls"][0]["pcm"] == CHUNK # stopped after the first chunk + + +def test_recording_stops_at_the_duration_cap(seams): + # 0.2 s at 16 kHz PCM16 = 6400 bytes = exactly two chunks; the poll never + # reports a key, so only the cap can stop the capture. 
+ seams["keys"] = FakeKeys(["\r"]) + seams["chunks"] = [CHUNK] * 5 + _run(dataclasses.replace(DICTATE_DEFAULTS, max_seconds=0.2)) + assert len(seams["calls"]) == 1 + assert seams["calls"][0]["pcm"] == CHUNK + CHUNK + + +def test_recording_closes_the_mic_generator(seams): + closed = [] + + def chunk_gen(): + try: + yield CHUNK + yield CHUNK + yield CHUNK + finally: + closed.append(True) + + seams["keys"] = FakeKeys(["\r", "\r", "q"]) + seams["chunks"] = chunk_gen() + _run() + assert closed == [True] # the device-releasing cleanup ran at stop, not at GC + + +@pytest.mark.parametrize("size", [200, 2558]) # 2558: just under the exact 2560-byte floor +def test_too_short_recording_is_skipped_with_a_warning(seams, capsys, size): + seams["keys"] = FakeKeys(["\r", "\r", "q"]) + seams["chunks"] = [b"\x01" * size] # below 80 ms of 16 kHz PCM16 (2560 bytes) + _run() + assert seams["calls"] == [] + captured = capsys.readouterr() + assert captured.out == "" + assert "shorter than 80 ms" in captured.err + + +def test_recording_at_the_80ms_floor_is_transcribed(seams): + seams["keys"] = FakeKeys(["\r", "\r", "q"]) + seams["chunks"] = [b"\x01" * 2560] # exactly 80 ms: allowed, not skipped + _run() + assert len(seams["calls"]) == 1 + + +def test_language_and_boost_flags_are_forwarded(seams): + seams["keys"] = FakeKeys(["\r", "\r", "q"]) + _run(dataclasses.replace(DICTATE_DEFAULTS, language="es", word_boost=["AssemblyAI"])) + assert seams["calls"][0]["language_code"] == "es" + assert seams["calls"][0]["word_boost"] == ["AssemblyAI"] + + +def test_comma_separated_languages_become_a_list(seams): + seams["keys"] = FakeKeys(["\r", "\r", "q"]) + _run(dataclasses.replace(DICTATE_DEFAULTS, language="en, es")) + assert seams["calls"][0]["language_code"] == ["en", "es"] + + +def test_blank_language_reads_as_unset(seams): + seams["keys"] = FakeKeys(["\r", "\r", "q"]) + _run(dataclasses.replace(DICTATE_DEFAULTS, language=" , ")) + assert seams["calls"][0]["language_code"] is None + + +def 
test_prompt_with_language_warns_that_language_is_ignored(seams, capsys): + seams["keys"] = FakeKeys(["q"]) + _run(dataclasses.replace(DICTATE_DEFAULTS, prompt="Verbatim.", language="es")) + assert "--language is ignored when --prompt is set" in capsys.readouterr().err + + +def test_prompt_alone_is_forwarded_without_warning(seams, capsys): + seams["keys"] = FakeKeys(["\r", "\r", "q"]) + _run(dataclasses.replace(DICTATE_DEFAULTS, prompt="Verbatim.")) + assert seams["calls"][0]["prompt"] == "Verbatim." + assert "ignored" not in capsys.readouterr().err + + +def test_transcription_runs_under_the_status_spinner(seams, monkeypatch): + seen = {} + + @contextlib.contextmanager + def fake_status(message, *, json_mode, quiet=False): + seen.update(message=message, json_mode=json_mode, quiet=quiet) + yield + + monkeypatch.setattr(dictate_exec.output, "status", fake_status) + seams["keys"] = FakeKeys(["\r", "\r", "q"]) + _run(state=AppState(quiet=True)) + assert seen == {"message": "Transcribing…", "json_mode": False, "quiet": True} + + +def test_ctrl_c_ends_the_session_cleanly(seams): + keys = RaisingKeys([]) + seams["keys"] = keys + _run() # no exception + assert keys.exited # the with-block unwound, restoring the terminal + + +def test_terminal_is_validated_before_credentials(seams, monkeypatch): + # No key is configured and TerminalKeys rejects the terminal: the usage + # error must win, not NotAuthenticated (validation before credentials). 
+ config.clear_api_key("default") + + class NoTty: + def __enter__(self): + raise CLIError("This command needs an interactive terminal.", exit_code=2) + + def __exit__(self, *exc): + return None + + monkeypatch.setattr(dictate_exec, "TerminalKeys", NoTty) + with pytest.raises(CLIError) as exc: + _run() + assert exc.value.exit_code == 2 + assert "interactive terminal" in exc.value.message diff --git a/tests/test_environments.py b/tests/test_environments.py index 051bcc68..e21b3da2 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -17,6 +17,11 @@ def test_production_uses_production_ams_endpoint(): assert env.ams_base == "https://ams.internal.assemblyai-labs.com" +def test_sync_base_per_environment(): + assert environments.get("production").sync_base == "https://sync.assemblyai.com" + assert environments.get("sandbox000").sync_base == "https://sync.sandbox000.assemblyai-labs.com" + + def test_get_unknown_raises_cli_error(): with pytest.raises(CLIError) as exc: environments.get("nope") diff --git a/tests/test_hotkey.py b/tests/test_hotkey.py new file mode 100644 index 00000000..515929cf --- /dev/null +++ b/tests/test_hotkey.py @@ -0,0 +1,119 @@ +"""TerminalKeys: cbreak scoping, single-key reads, and the clean failure modes. + +The terminal tests drive a real pty pair (os.openpty), so termios behavior is +exercised for real without touching the test runner's stdin. +""" + +import os +import sys +import termios + +import pytest + +from aai_cli.errors import CLIError +from aai_cli.hotkey import TerminalKeys, _stdin_fd + + +@pytest.fixture +def pty_pair(): + master, slave = os.openpty() + yield master, slave + os.close(master) + os.close(slave) + + +def test_reads_single_keypresses_without_enter(pty_pair): + master, slave = pty_pair + with TerminalKeys(fd=slave) as keys: + os.write(master, b"ab") + # One keypress per read, even when several are queued. 
+ assert keys.read(5.0) == "a" + assert keys.read(5.0) == "b" + + +def test_poll_returns_none_when_no_key_is_pending(pty_pair): + _, slave = pty_pair + with TerminalKeys(fd=slave) as keys: + assert keys.read(0) is None + + +def test_cbreak_is_scoped_to_the_context(pty_pair): + _, slave = pty_pair + lflag_index = 3 + assert termios.tcgetattr(slave)[lflag_index] & termios.ICANON + with TerminalKeys(fd=slave): + inside = termios.tcgetattr(slave)[lflag_index] + assert not inside & termios.ICANON # keys arrive without Enter + assert inside & termios.ISIG # but Ctrl-C still raises KeyboardInterrupt + assert termios.tcgetattr(slave)[lflag_index] & termios.ICANON # restored + + +def test_exit_without_enter_restores_nothing(pty_pair): + # __exit__ is a no-op when the cbreak switch never happened (or already ran): + # exiting twice must not call tcsetattr with stale state. + _, slave = pty_pair + keys = TerminalKeys(fd=slave) + keys.__exit__(None, None, None) # never entered: nothing to restore + with keys: + pass + keys.__exit__(None, None, None) # second exit after restore: still a no-op + + +def test_non_tty_fd_is_a_clean_usage_error(tmp_path): + with (tmp_path / "plain-file").open("w") as f: + with pytest.raises(CLIError) as exc: + with TerminalKeys(fd=f.fileno()): + pass + assert exc.value.exit_code == 2 + assert exc.value.error_type == "not_a_tty" + assert "interactive terminal" in exc.value.message + + +def test_platform_without_termios_is_a_clean_error(pty_pair, monkeypatch): + # Windows has no termios; None in sys.modules makes the import raise. + _, slave = pty_pair + monkeypatch.setitem(sys.modules, "termios", None) + with pytest.raises(CLIError) as exc: + with TerminalKeys(fd=slave): + pass + assert exc.value.exit_code == 2 + assert exc.value.error_type == "unsupported_platform" + + +def test_read_returns_none_at_eof(): + # A pipe stands in for a hung-up terminal: select reports readable, the + # read yields no bytes. 
(read() itself doesn't require a tty; only the + # cbreak context does.) + read_end, write_end = os.pipe() + try: + os.write(write_end, b"z") + os.close(write_end) + keys = TerminalKeys(fd=read_end) + assert keys.read(0) == "z" # drains the last byte + assert keys.read(0) is None # then EOF + finally: + os.close(read_end) + + +def test_stdin_fd_defaults_to_real_stdin_or_minus_one(monkeypatch): + class NoFileno: + def fileno(self): + raise OSError("no underlying file") + + monkeypatch.setattr(sys, "stdin", NoFileno()) + assert _stdin_fd() == -1 + assert TerminalKeys()._fd == -1 + + class CapturedStdin: + def fileno(self): + raise ValueError("I/O operation on captured stream") + + monkeypatch.setattr(sys, "stdin", CapturedStdin()) + assert _stdin_fd() == -1 + + class RealStdin: + def fileno(self): + return 42 + + monkeypatch.setattr(sys, "stdin", RealStdin()) + assert _stdin_fd() == 42 diff --git a/tests/test_smoke.py b/tests/test_smoke.py index a1ec7119..bcdb97b9 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -153,6 +153,7 @@ def test_help_lists_commands_in_workflow_order(): # Run AssemblyAI "transcribe", "stream", + "dictate", "agent", "speak", "llm", diff --git a/tests/test_sync_stt.py b/tests/test_sync_stt.py new file mode 100644 index 00000000..2688c0ac --- /dev/null +++ b/tests/test_sync_stt.py @@ -0,0 +1,222 @@ +"""The Sync STT HTTP boundary: request shape, error normalization, parsing.""" + +import dataclasses + +import httpx2 as httpx +import pytest + +from aai_cli import environments, sync_stt +from aai_cli.errors import APIError, NotAuthenticated + + +def _patch_transport(monkeypatch, handler): + real_client = httpx.Client + seen_kwargs = {} + + def fake_client(*args, **kwargs): + seen_kwargs.update(kwargs) + kwargs["transport"] = httpx.MockTransport(handler) + return real_client(*args, **kwargs) + + monkeypatch.setattr(sync_stt.httpx, "Client", fake_client) + return seen_kwargs + + +def _ok_handler(seen): + def handler(request: 
httpx.Request) -> httpx.Response: + seen["url"] = str(request.url) + seen["auth"] = request.headers.get("authorization") + seen["model"] = request.headers.get("x-aai-model") + seen["body"] = request.read() + return httpx.Response( + 200, + json={ + "text": "Hi, I'm calling about my order.", + "confidence": 0.87, + "audio_duration_ms": 1500, + "session_id": "eb92c4ff", + }, + ) + + return handler + + +def test_posts_pcm_and_config_to_the_active_environment(monkeypatch): + seen = {} + client_kwargs = _patch_transport(monkeypatch, _ok_handler(seen)) + result = sync_stt.transcribe_pcm("sk_key_pcm", b"\x01\x02pcm-bytes", sample_rate=16000) + assert seen["url"] == "https://sync.assemblyai.com/transcribe" + assert seen["auth"] == "sk_key_pcm" + assert seen["model"] == "u3-sync-pro" + # Multipart body: a raw-PCM audio part plus the JSON config part. + assert b"\x01\x02pcm-bytes" in seen["body"] + assert b"audio/pcm" in seen["body"] + assert b'"sample_rate": 16000' in seen["body"] + assert b'"channels": 1' in seen["body"] + # Optional knobs are omitted entirely when unset. + assert b"language_code" not in seen["body"] + assert b"prompt" not in seen["body"] + assert b"word_boost" not in seen["body"] + # Generous timeout: the upload can carry up to 2 minutes of audio. 
+ assert client_kwargs["timeout"] == 60.0 + assert result == sync_stt.SyncTranscript( + text="Hi, I'm calling about my order.", + confidence=0.87, + audio_duration_ms=1500, + session_id="eb92c4ff", + ) + + +def test_targets_the_sandbox_host_when_active(monkeypatch): + seen = {} + _patch_transport(monkeypatch, _ok_handler(seen)) + environments.set_active(environments.get("sandbox000")) + sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000) + assert seen["url"] == "https://sync.sandbox000.assemblyai-labs.com/transcribe" + + +def test_optional_config_fields_are_sent_when_set(monkeypatch): + seen = {} + _patch_transport(monkeypatch, _ok_handler(seen)) + sync_stt.transcribe_pcm( + "sk", + b"pcm", + sample_rate=44100, + channels=2, + language_code=["en", "es"], + prompt="Transcribe verbatim.", + word_boost=["AssemblyAI", "LeMUR"], + ) + assert b'"sample_rate": 44100' in seen["body"] + assert b'"channels": 2' in seen["body"] + assert b'"language_code": ["en", "es"]' in seen["body"] + assert b'"prompt": "Transcribe verbatim."' in seen["body"] + assert b'"word_boost": ["AssemblyAI", "LeMUR"]' in seen["body"] + + +def test_single_language_code_is_sent_as_a_string(monkeypatch): + seen = {} + _patch_transport(monkeypatch, _ok_handler(seen)) + sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000, language_code="es") + assert b'"language_code": "es"' in seen["body"] + + +def test_sync_transcript_is_immutable(monkeypatch): + seen = {} + _patch_transport(monkeypatch, _ok_handler(seen)) + result = sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000) + field_name = dataclasses.fields(result)[0].name + with pytest.raises(dataclasses.FrozenInstanceError): + setattr(result, field_name, "mutated") + + +def test_missing_optional_response_fields_parse_as_none(monkeypatch): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={"text": "bare"}) + + _patch_transport(monkeypatch, handler) + result = sync_stt.transcribe_pcm("sk", b"pcm", 
sample_rate=16000) + assert result == sync_stt.SyncTranscript( + text="bare", confidence=None, audio_duration_ms=None, session_id=None + ) + + +@pytest.mark.parametrize("status", [401, 403]) +def test_auth_rejection_raises_not_authenticated(monkeypatch, status): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(status, json={"detail": "Invalid API key"}) + + _patch_transport(monkeypatch, handler) + with pytest.raises(NotAuthenticated) as exc: + sync_stt.transcribe_pcm("bad", b"pcm", sample_rate=16000) + assert exc.value.rejected_key is True + + +@pytest.mark.parametrize("status", [429, 503]) +def test_rate_limit_and_capacity_are_retryable_api_errors(monkeypatch, status): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response( + status, json={"error_code": "capacity_exceeded", "message": "server at cap"} + ) + + _patch_transport(monkeypatch, handler) + with pytest.raises(APIError) as exc: + sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000) + assert "busy" in exc.value.message + assert str(status) in exc.value.message + assert "server at cap (capacity_exceeded)" in exc.value.message + assert "try again" in (exc.value.suggestion or "") + + +def test_audio_error_carries_error_code_and_message(monkeypatch): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 400, json={"error_code": "audio_too_short", "message": "audio below 80 ms"} + ) + + _patch_transport(monkeypatch, handler) + with pytest.raises(APIError) as exc: + sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000) + assert "Sync transcription failed (400)" in exc.value.message + assert "audio below 80 ms (audio_too_short)" in exc.value.message + + +def test_error_detail_reads_detail_field(monkeypatch): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(400, json={"detail": "missing audio part"}) + + _patch_transport(monkeypatch, handler) + with pytest.raises(APIError) as exc: + 
sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000) + assert "missing audio part" in exc.value.message + + +@pytest.mark.parametrize( + ("body", "expected"), + [ + (b"upstream proxy says no", "upstream proxy says no"), # non-JSON body + (b"", "HTTP 500"), # empty body -> bare status + (b'["weird"]', "weird"), # JSON but not an object -> raw text + (b'{"unrelated": true}', '{"unrelated": true}'), # object without message/detail + ], +) +def test_error_detail_falls_back_to_raw_body_or_status(monkeypatch, body, expected): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(500, content=body) + + _patch_transport(monkeypatch, handler) + with pytest.raises(APIError) as exc: + sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000) + assert expected in exc.value.message + + +def test_success_with_unparseable_body_is_a_clean_api_error(monkeypatch): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=b"not-json") + + _patch_transport(monkeypatch, handler) + with pytest.raises(APIError) as exc: + sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000) + assert "not valid JSON" in exc.value.message + + +@pytest.mark.parametrize("payload", [{"words": []}, ["list"]]) +def test_success_without_transcript_text_is_a_clean_api_error(monkeypatch, payload): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json=payload) + + _patch_transport(monkeypatch, handler) + with pytest.raises(APIError) as exc: + sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000) + assert "unexpected response shape" in exc.value.message + + +def test_network_failure_is_a_clean_api_error(monkeypatch): + def handler(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("connection refused") + + _patch_transport(monkeypatch, handler) + with pytest.raises(APIError) as exc: + sync_stt.transcribe_pcm("sk", b"pcm", sample_rate=16000) + assert "Could not reach the Sync API" in 
exc.value.message + assert "network" in (exc.value.suggestion or "")