diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c3ea29d..7033dae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: pull_request: jobs: - lint-and-test-default: + lint-and-test-portable: runs-on: ubuntu-latest steps: - name: Checkout @@ -29,11 +29,11 @@ jobs: - name: Check formatting run: cargo fmt --all -- --check - - name: Clippy (default features) - run: cargo clippy --all-targets -- -D warnings + - name: Clippy (CPU-only OSD build) + run: cargo clippy --all-targets --no-default-features --features osd -- -D warnings - - name: Test (default features) - run: cargo test + - name: Test (CPU-only OSD build) + run: cargo test --no-default-features --features osd feature-checks: runs-on: ubuntu-latest @@ -55,13 +55,26 @@ jobs: - name: Check no default features run: cargo check --no-default-features - - name: Check osd feature only + - name: Check CPU-only OSD feature set run: cargo check --no-default-features --features osd - - name: Verify package (default publish surface) - run: cargo package --locked + - name: Verify package + run: | + if command -v nvcc >/dev/null 2>&1; then + cargo package --locked + else + cargo package --locked --no-verify + fi + + - name: Check default feature set (if toolkit available) + run: | + if command -v nvcc >/dev/null 2>&1; then + cargo check + else + echo "CUDA toolkit not available on this runner; skipping default feature check" + fi - - name: Check cuda feature only (if toolkit available) + - name: Check CUDA-only feature set (if toolkit available) run: | if command -v nvcc >/dev/null 2>&1; then cargo check --no-default-features --features cuda diff --git a/Cargo.lock b/Cargo.lock index 1d3e336..c93d0b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "alsa" version = "0.9.1" @@ -358,26 +364,6 @@ dependencies = [ "regex", ] -[[package]] -name = "bindgen" -version = "0.71.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" -dependencies = [ - "bitflags 2.11.0", - "cexpr", - "clang-sys", - "itertools 0.13.0", - "log", - "prettyplease", - "proc-macro2", - "quote", - "regex", - "rustc-hash", - "shlex", - "syn 2.0.115", -] - [[package]] name = "bindgen" version = "0.72.1" @@ -1016,6 +1002,28 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "font8x8" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "875488b8711a968268c7cf5d139578713097ca4635a76044e8fe8eedf831d07e" + +[[package]] +name = "fontdue" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e57e16b3fe8ff4364c0661fdaac543fb38b29ea9bc9c2f45612d90adf931d2b" +dependencies = [ + "hashbrown 0.15.5", + "ttf-parser", +] + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -1175,6 +1183,17 @@ dependencies = [ "byteorder", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + [[package]] name = "hashbrown" version = "0.16.1" @@ -1490,7 +1509,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", ] [[package]] @@ -1704,7 +1723,7 @@ version = "0.1.138" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84a529006bf16af70c7485ba957820dc2bc9467d75697e97970c81d2da73c76f" dependencies = [ - "bindgen 0.72.1", + "bindgen", "cc", "cmake", "find_cuda_helper", @@ -2596,6 +2615,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "seq-macro" version = "0.3.6" @@ -3272,6 +3297,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "ttf-parser" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c591d83f69777866b9126b24c6dd9a18351f177e49d625920d19f989fd31cf8" + [[package]] name = "typenum" version = "1.19.0" @@ -3569,28 +3600,29 @@ dependencies = [ [[package]] name = "whisper-rs" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71ea5d2401f30f51d08126a2d133fee4c1955136519d7ac6cf6f5ac0a91e6bc8" +checksum = "2088172d00f936c348d6a72f488dc2660ab3f507263a195df308a3c2383229f6" dependencies = [ "whisper-rs-sys", ] [[package]] name = "whisper-rs-sys" -version = "0.14.1" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e2a6e06e7ac7b8f53c53a5f50bb0bc823ba69b63ecd887339f807a5598bbd2" +checksum = "6986c0fe081241d391f09b9a071fbcbb59720c3563628c3c829057cf69f2a56f" dependencies = [ - "bindgen 0.71.1", + "bindgen", "cfg-if", "cmake", "fs_extra", + "semver", ] [[package]] name = "whispers" -version = "0.1.0" +version = "0.2.0" dependencies = [ "base64 0.22.1", "clap", @@ -3602,6 +3634,8 @@ dependencies = [ "encoding_rs", "evdev", "flacenc", + "font8x8", + "fontdue", "futures-util", "httpmock", "indicatif", diff --git a/Cargo.toml b/Cargo.toml index 8e0dced..e887bce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "whispers" -version = "0.1.0" +version = "0.2.0" edition = "2024" rust-version = "1.85" description = "Speech-to-text dictation tool for Wayland" @@ -18,7 +18,7 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal", "tim cpal = "0.17" # Whisper transcription -whisper-rs = "0.15" +whisper-rs = "0.16" llama-cpp-2 = "0.1.138" # uinput virtual keyboard for paste keystroke @@ -63,11 +63,19 @@ console = "0.16" wayland-client = { version = "0.31", optional = true } wayland-protocols = { version = "0.32", features = ["client"], optional = true } wayland-protocols-wlr = { version = "0.3", features = ["client"], optional = true } +font8x8 = { version = "0.3", optional = true } +fontdue = { version = "0.9", optional = true } [features] -default = ["osd"] +default = ["cuda", "osd"] cuda = ["whisper-rs/cuda", "llama-cpp-2/cuda"] -osd = ["dep:wayland-client", "dep:wayland-protocols", "dep:wayland-protocols-wlr"] +osd = [ + "dep:wayland-client", + "dep:wayland-protocols", + "dep:wayland-protocols-wlr", + "dep:font8x8", + "dep:fontdue", +] [[bin]] name = "whispers" diff --git a/README.md b/README.md index c372fb1..7eaec3b 100644 --- a/README.md +++ b/README.md @@ -24,14 +24,13 @@ The two invocations communicate via PID file + `SIGUSR1` — no daemon, no IPC s `whispers` now has three main dictation modes: - `raw` keeps output close to the direct transcription result and is the default -- `advanced_local` enables the smart rewrite pipeline after transcription; `[rewrite].backend` chooses whether that rewrite runs locally or in the cloud -- `agentic_rewrite` uses the same local/cloud rewrite backends as `advanced_local`, but adds app-aware policy rules, technical glossary guidance, and a stricter conservative acceptance guard +- `rewrite` runs the LLM-based rewrite pipeline after transcription; `[rewrite].backend` chooses whether that rewrite runs locally or in the cloud, and the same config also carries app-aware policy rules, glossary guidance, and correction-policy settings The older heuristic cleanup path is still available as deprecated `legacy_basic` for existing configs that already use `[cleanup]`. The local rewrite path is managed by `whispers` itself through an internal helper binary installed alongside the main executable, so there is no separate tool or daemon to install manually. When a rewrite mode is enabled with `rewrite.backend = "local"`, `whispers` keeps a hidden rewrite worker warm for a short idle window so repeated dictation is much faster without becoming a permanent background daemon. Managed rewrite models are the default path. If you point `rewrite.model_path` at your own GGUF, it should be a chat-capable model with an embedded template that `llama.cpp` can apply at runtime. -Deterministic personalization rules apply in all modes: dictionary replacements and spoken snippets. Custom rewrite instructions apply to both rewrite modes, and `agentic_rewrite` can additionally load app rules and glossary entries from separate TOML files. +Deterministic personalization rules apply in all modes: dictionary replacements and spoken snippets. Rewrite instructions, app rules, glossary entries, and correction-policy defaults all live under `[rewrite]`. Older `advanced_local` and `agentic_rewrite` mode names are still accepted as deprecated aliases when reading existing configs. Cloud ASR and cloud rewrite are both optional. Local remains the default. For file transcription, `whispers transcribe --raw ` always prints the plain ASR transcript without any post-processing. @@ -52,32 +51,32 @@ For file transcription, `whispers transcribe --raw ` always prints the pla ### From crates.io ```sh -# Default install: CPU build with Wayland OSD +# Default install: CUDA build with Wayland OSD cargo install whispers -# Enable CUDA acceleration explicitly -cargo install whispers --features cuda +# CPU-only build with Wayland OSD +cargo install whispers --no-default-features --features osd -# Build without the OSD overlay +# CPU-only build without the OSD overlay cargo install whispers --no-default-features ``` ### From git ```sh -# Default install: CPU build with Wayland OSD +# Default install: CUDA build with Wayland OSD cargo install --git https://github.com/OneNoted/whispers -# Enable CUDA acceleration explicitly -cargo install --git https://github.com/OneNoted/whispers --features cuda +# CPU-only build with Wayland OSD +cargo install --git https://github.com/OneNoted/whispers --no-default-features --features osd -# Build without the OSD overlay +# CPU-only build without the OSD overlay cargo install --git https://github.com/OneNoted/whispers --no-default-features ``` ### Setup -Run the interactive setup wizard to download a local ASR model, generate config, and optionally enable local or cloud advanced dictation. Recommended local models are shown first, and experimental backends like Parakeet are called out explicitly before you opt into them: +Run the interactive setup wizard to download a local ASR model, generate config, optionally enable local or cloud advanced dictation, and offer shell completion install for the supported shells it finds on your `PATH`. Recommended local models are shown first, and experimental backends like Parakeet are called out explicitly before you opt into them: ```sh whispers setup @@ -129,7 +128,7 @@ That still remains a single install: `whispers` manages local ASR models, the op ## Shell completions -Print completion scripts to `stdout`: +`whispers setup` can detect supported shells on your `PATH` and install completions for one shell or all detected shells. You can also print completion scripts to `stdout` manually: ```sh # auto-detect from $SHELL (falls back to parent process name) @@ -196,7 +195,7 @@ flash_attn = true # only used when use_gpu=true idle_timeout_ms = 120000 [postprocess] -mode = "raw" # or "advanced_local" / "agentic_rewrite"; deprecated: "legacy_basic" +mode = "raw" # or "rewrite"; deprecated: "legacy_basic" [session] enabled = true @@ -220,8 +219,6 @@ timeout_ms = 30000 idle_timeout_ms = 120000 max_output_chars = 1200 max_tokens = 256 - -[agentic_rewrite] policy_path = "~/.local/share/whispers/app-rewrite-policy.toml" glossary_path = "~/.local/share/whispers/technical-glossary.toml" default_correction_policy = "balanced" @@ -250,7 +247,7 @@ start_sound = "" # empty = bundled sound stop_sound = "" ``` -When `advanced_local` or `agentic_rewrite` is enabled, `whispers` also keeps a short-lived local session ledger in the runtime directory so immediate follow-up corrections like `scratch that` can safely replace the most recent dictation entry when focus has not changed. That session behavior is local either way; only the semantic rewrite stage may be cloud-backed. +When `rewrite` is enabled, `whispers` also keeps a short-lived local session ledger in the runtime directory so immediate follow-up corrections like `scratch that` can safely replace the most recent dictation entry when focus has not changed. That session behavior is local either way; only the semantic rewrite stage may be cloud-backed. ## Cloud Modes @@ -306,13 +303,13 @@ Custom rewrite models should include a chat template that `llama.cpp` can read f ## Personalization -Dictionary replacements apply deterministically in `raw`, `advanced_local`, and `agentic_rewrite`, with normalization for case and punctuation but no fuzzy matching. In the rewrite modes, dictionary replacements are applied before the rewrite model and again on the final output so exact names and product terms stay stable. +Dictionary replacements apply deterministically in `raw` and `rewrite`, with normalization for case and punctuation but no fuzzy matching. In rewrite mode, dictionary replacements are applied before the rewrite model and again on the final output so exact names and product terms stay stable. Spoken snippets also work in all modes. By default, saying `insert ` expands the configured snippet text verbatim after post-processing finishes, so the rewrite model cannot paraphrase it. Change the trigger phrase with `personalization.snippet_trigger`. Custom rewrite instructions live in a separate plain-text file referenced by `rewrite.instructions_path`. `whispers` appends that file to the built-in rewrite prompt for both rewrite modes while still enforcing the same final-text-only output contract. The file is optional, and a missing file is ignored. -`agentic_rewrite` also reads layered app rules from `agentic_rewrite.policy_path` and scoped glossary entries from `agentic_rewrite.glossary_path`. `whispers setup` creates commented starter files for both when you choose the agentic mode, and the minimal CRUD commands above are available for path/list/add/remove workflows. +`rewrite` also reads layered app rules from `rewrite.policy_path` and scoped glossary entries from `rewrite.glossary_path`. `whispers setup` creates commented starter files for both when you choose rewrite mode, and the minimal CRUD commands above are available for path/list/add/remove workflows. ## Faster Whisper diff --git a/config.example.toml b/config.example.toml index 0bcb96a..ed1985a 100644 --- a/config.example.toml +++ b/config.example.toml @@ -40,8 +40,7 @@ idle_timeout_ms = 120000 [postprocess] # "raw" keeps output close to Whisper -# "advanced_local" enables the rewrite model -# "agentic_rewrite" adds app-aware policy and glossary guidance on top of the rewrite model +# "rewrite" enables the unified rewrite model, app-aware policy, and glossary guidance # "legacy_basic" is deprecated and only kept for older cleanup-based configs mode = "raw" @@ -68,7 +67,7 @@ snippet_trigger = "insert" backend = "local" # Cloud rewrite failure policy ("local" or "none") fallback = "local" -# Managed rewrite model name for advanced_local mode +# Managed rewrite model name for rewrite mode selected_model = "qwen-3.5-4b-q4_k_m" # Manual GGUF path override (empty = use selected managed model) # Custom rewrite models should be chat-capable GGUFs with an embedded @@ -84,11 +83,9 @@ timeout_ms = 30000 # How long the hidden rewrite worker stays warm without requests (0 = never expire) idle_timeout_ms = 120000 # Maximum characters accepted from the rewrite model -max_output_chars = 1200 +max_output_chars = 8192 # Maximum tokens to generate for rewritten output -max_tokens = 256 - -[agentic_rewrite] +max_tokens = 768 # App-aware rewrite policy TOML file policy_path = "~/.local/share/whispers/app-rewrite-policy.toml" # Technical glossary TOML file @@ -132,3 +129,19 @@ enabled = true # Custom sound file paths (empty = use bundled sounds) start_sound = "" stop_sound = "" + +[voice] +# Experimental live voice mode. Leave live injection off until you trust it. +live_inject = false +# Show a rewrite-preview line in the OSD while dictating +live_rewrite = false +# How often to refresh the live ASR preview while recording +partial_interval_ms = 400 +# Minimum time between live rewrite preview updates +rewrite_interval_ms = 1400 +# Audio tail window retranscribed for each live ASR refresh +context_window_ms = 8000 +# Minimum recorded audio before live ASR starts updating +min_chunk_ms = 650 +# Freeze live injection if focus changes during the session +freeze_on_focus_change = true diff --git a/src/agentic_rewrite.rs b/src/agentic_rewrite.rs index 3e19b76..a6b732c 100644 --- a/src/agentic_rewrite.rs +++ b/src/agentic_rewrite.rs @@ -2,7 +2,7 @@ use std::path::Path; use serde::{Deserialize, Serialize}; -use crate::config::{Config, PostprocessMode}; +use crate::config::Config; use crate::error::{Result, WhsprError}; use crate::rewrite_protocol::{ RewriteCandidate, RewriteCandidateKind, RewriteCorrectionPolicy, RewritePolicyContext, @@ -13,7 +13,7 @@ const DEFAULT_POLICY_PATH: &str = "~/.local/share/whispers/app-rewrite-policy.to const DEFAULT_GLOSSARY_PATH: &str = "~/.local/share/whispers/technical-glossary.toml"; const MAX_GLOSSARY_CANDIDATES: usize = 4; -const POLICY_STARTER: &str = r#"# App-aware rewrite policy for whispers agentic_rewrite mode. +const POLICY_STARTER: &str = r#"# App-aware rewrite policy for whispers rewrite mode. # Rules are layered, not first-match. Matching rules apply in this order: # global defaults, surface_kind, app_id, window_title_contains, browser_domain_contains. # Later, more specific rules override earlier fields. @@ -38,7 +38,7 @@ const POLICY_STARTER: &str = r#"# App-aware rewrite policy for whispers agentic_ # instructions = "Preserve identifiers, filenames, snake_case, camelCase, and Rust terminology." "#; -const GLOSSARY_STARTER: &str = r#"# Technical glossary for whispers agentic_rewrite mode. +const GLOSSARY_STARTER: &str = r#"# Technical glossary for whispers rewrite mode. # Each entry defines a canonical term plus likely spoken or mis-transcribed aliases. # # Uncomment and edit the examples below. @@ -240,7 +240,7 @@ pub fn apply_runtime_policy(config: &Config, transcript: &mut RewriteTranscript) let glossary_entries = load_glossary_file_for_runtime(&config.resolved_agentic_glossary_path()); let policy_context = resolve_policy_context( - config.agentic_rewrite.default_correction_policy, + config.rewrite.default_correction_policy, transcript.typing_context.as_ref(), &transcript.rewrite_candidates, &policy_rules, @@ -259,6 +259,50 @@ pub fn apply_runtime_policy(config: &Config, transcript: &mut RewriteTranscript) } transcript.policy_context = policy_context; + promote_policy_preferred_candidate(transcript); +} + +fn promote_policy_preferred_candidate(transcript: &mut RewriteTranscript) { + if matches!( + transcript.policy_context.correction_policy, + RewriteCorrectionPolicy::Conservative + ) { + return; + } + + let preferred = transcript + .policy_context + .glossary_candidates + .iter() + .find(|candidate| { + let text = candidate.text.trim(); + !text.is_empty() && text != transcript.correction_aware_text.trim() + }) + .cloned(); + + let Some(preferred) = preferred else { + return; + }; + + if transcript.recommended_candidate.is_some() { + return; + } + + transcript.recommended_candidate = Some(preferred.clone()); + tracing::debug!( + preferred_candidate = preferred.text, + "promoted glossary-backed rewrite candidate to recommended candidate" + ); + + if let Some(index) = transcript + .rewrite_candidates + .iter() + .position(|candidate| candidate.text == preferred.text) + && index > 0 + { + let candidate = transcript.rewrite_candidates.remove(index); + transcript.rewrite_candidates.insert(0, candidate); + } } pub fn conservative_output_allowed(transcript: &RewriteTranscript, text: &str) -> bool { @@ -351,7 +395,7 @@ fn levenshtein_distance(left: &str, right: &str) -> usize { } pub fn ensure_starter_files(config: &Config) -> Result> { - if config.postprocess.mode != PostprocessMode::AgenticRewrite { + if !config.postprocess.mode.uses_rewrite() { return Ok(Vec::new()); } @@ -534,12 +578,11 @@ fn resolve_policy_context( } } - let mut active_glossary_entries = glossary_entries - .iter() + let mut active_glossary_entries = built_in_glossary_entries() + .into_iter() + .chain(glossary_entries.iter().cloned()) .enumerate() - .filter_map(|(index, entry)| { - PreparedGlossaryEntry::new(entry.clone()).map(|entry| (index, entry)) - }) + .filter_map(|(index, entry)| PreparedGlossaryEntry::new(entry).map(|entry| (index, entry))) .filter(|(_, entry)| entry.matcher.matches(context)) .collect::>(); active_glossary_entries @@ -566,7 +609,7 @@ fn built_in_rules(default_policy: RewriteCorrectionPolicy) -> Vec { AppRule::built_in( "baseline/global-default", ContextMatcher::default(), - "Use the active typing context, recent dictation context, glossary terms, and bounded candidates to resolve technical dictation cleanly while keeping the final-text-only contract. When the utterance clearly points to software, tools, APIs, Linux components, product names, or other technical concepts, prefer the most plausible intended technical term over a phonetically similar common word. Use category cues like window manager, editor, language, library, shell, or package manager to disambiguate nearby technical names. If it remains genuinely ambiguous, stay close to the transcript.", + "Use the active typing context, recent dictation context, glossary terms, and bounded candidates to resolve technical dictation cleanly while keeping the final-text-only contract. When the utterance clearly points to software, tools, APIs, Linux components, product names, or other technical concepts, prefer the most plausible intended technical term over a phonetically similar common word. If a token is an obvious phonetic near-miss for a technical term and nearby category words make the intended term clear, proactively normalize it to the canonical spelling. Use category cues like window manager, editor, language, library, shell, or package manager to disambiguate nearby technical names. If multiple plausible technical interpretations remain similarly credible, stay close to the transcript.", Some(default_policy), ), AppRule::built_in( @@ -575,7 +618,7 @@ fn built_in_rules(default_policy: RewriteCorrectionPolicy) -> Vec { surface_kind: Some(RewriteSurfaceKind::Browser), ..ContextMatcher::default() }, - "Favor clean prose and natural punctuation for browser text fields, but stay grounded in the listed candidates, glossary evidence, and the utterance's technical topic when it clearly refers to software or documentation.", + "Favor clean prose and natural punctuation for browser text fields, but stay grounded in the listed candidates, glossary evidence, and the utterance's technical topic when it clearly refers to software or documentation. Correct obvious phonetic misses of technical terms or product names when the surrounding sentence makes the intended term clear.", Some(RewriteCorrectionPolicy::Balanced), ), AppRule::built_in( @@ -584,7 +627,7 @@ fn built_in_rules(default_policy: RewriteCorrectionPolicy) -> Vec { surface_kind: Some(RewriteSurfaceKind::GenericText), ..ContextMatcher::default() }, - "Favor clean prose and natural punctuation for general text entry while staying grounded in the listed candidates and glossary evidence. If the utterance clearly discusses technical tools or software, prefer the most plausible technical term over a phonetically similar common word.", + "Favor clean prose and natural punctuation for general text entry while staying grounded in the listed candidates and glossary evidence. If the utterance clearly discusses technical tools or software, prefer the most plausible technical term over a phonetically similar common word and proactively fix obvious phonetic near-misses to the canonical spelling.", Some(RewriteCorrectionPolicy::Balanced), ), AppRule::built_in( @@ -593,7 +636,7 @@ fn built_in_rules(default_policy: RewriteCorrectionPolicy) -> Vec { surface_kind: Some(RewriteSurfaceKind::Editor), ..ContextMatcher::default() }, - "Preserve identifiers, filenames, API names, symbols, and technical casing. Avoid rewriting technical wording into generic prose. Infer likely technical terms and proper names from the utterance when the topic is clearly code, tooling, or software.", + "Preserve identifiers, filenames, API names, symbols, and technical casing. Avoid rewriting technical wording into generic prose. Infer likely technical terms and proper names from the utterance when the topic is clearly code, tooling, or software, and proactively normalize obvious phonetic misses to the canonical technical spelling.", Some(RewriteCorrectionPolicy::Balanced), ), AppRule::built_in( @@ -608,6 +651,39 @@ fn built_in_rules(default_policy: RewriteCorrectionPolicy) -> Vec { ] } +fn built_in_glossary_entries() -> Vec { + vec![ + GlossaryEntry { + term: "TypeScript".into(), + aliases: vec!["type script".into(), "types script".into()], + matcher: ContextMatcher { + surface_kind: Some(RewriteSurfaceKind::Editor), + ..ContextMatcher::default() + }, + }, + GlossaryEntry { + term: "Neovim".into(), + aliases: vec!["neo vim".into(), "neo-vim".into()], + matcher: ContextMatcher::default(), + }, + GlossaryEntry { + term: "Hyprland".into(), + aliases: vec!["hyperland".into(), "hyper land".into(), "highprland".into()], + matcher: ContextMatcher::default(), + }, + GlossaryEntry { + term: "niri".into(), + aliases: vec!["neary".into(), "niry".into(), "nearie".into()], + matcher: ContextMatcher::default(), + }, + GlossaryEntry { + term: "Sway".into(), + aliases: vec!["sui".into(), "swayy".into()], + matcher: ContextMatcher::default(), + }, + ] +} + fn matching_rules(rules: &[AppRule], context: Option<&RewriteTypingContext>) -> Vec { let mut matches = rules .iter() @@ -785,7 +861,7 @@ fn normalize_word(word: &str) -> String { } fn is_word_char(ch: char) -> bool { - ch.is_alphanumeric() || matches!(ch, '\'' | '-' | '_' | '.') + ch.is_alphanumeric() || matches!(ch, '\'' | '-' | '_') } fn contains_ignore_ascii_case(haystack: Option<&str>, needle: &str) -> bool { @@ -993,6 +1069,7 @@ mod tests { text: "type script and sir dee json".into(), }], recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), } } @@ -1033,7 +1110,14 @@ mod tests { policy .effective_rule_instructions .iter() - .any(|instruction| instruction.contains("phonetically similar common word")) + .any(|instruction| instruction + .contains("proactively fix obvious phonetic near-misses")) + ); + assert!( + policy + .effective_rule_instructions + .iter() + .any(|instruction| instruction.contains("obvious phonetic near-miss")) ); } @@ -1110,7 +1194,18 @@ mod tests { &[], &glossary, ); - assert_eq!(policy.active_glossary_terms.len(), 2); + assert!( + policy + .active_glossary_terms + .iter() + .any(|term| term.term == "TypeScript") + ); + assert!( + policy + .active_glossary_terms + .iter() + .any(|term| term.term == "serde_json") + ); assert_eq!(policy.glossary_candidates.len(), 1); assert_eq!( policy.glossary_candidates[0].text, @@ -1160,10 +1255,16 @@ mod tests { text: "I'm currently using the window manager hyperland.".into(), }], recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), }; - hyperland_transcript.policy_context.correction_policy = - RewriteCorrectionPolicy::Conservative; + hyperland_transcript.policy_context = resolve_policy_context( + RewriteCorrectionPolicy::Conservative, + hyperland_transcript.typing_context.as_ref(), + &hyperland_transcript.rewrite_candidates, + &[], + &[], + ); assert!(conservative_output_allowed( &hyperland_transcript, @@ -1188,9 +1289,16 @@ mod tests { text: "I'm switching from Sui to Hyperland.".into(), }], recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), }; - switch_transcript.policy_context.correction_policy = RewriteCorrectionPolicy::Conservative; + switch_transcript.policy_context = resolve_policy_context( + RewriteCorrectionPolicy::Conservative, + switch_transcript.typing_context.as_ref(), + &switch_transcript.rewrite_candidates, + &[], + &[], + ); assert!(conservative_output_allowed( &switch_transcript, @@ -1218,6 +1326,7 @@ mod tests { text: "cargo clipy".into(), }], recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), }; transcript.policy_context.correction_policy = RewriteCorrectionPolicy::Conservative; @@ -1261,6 +1370,42 @@ mod tests { .iter() .any(|candidate| candidate.text == "TypeScript and sir dee json") ); + assert_eq!( + transcript + .recommended_candidate + .as_ref() + .map(|candidate| candidate.text.as_str()), + Some("TypeScript and sir dee json") + ); + assert_eq!( + transcript + .rewrite_candidates + .first() + .map(|candidate| candidate.text.as_str()), + Some("TypeScript and sir dee json") + ); + } + + #[test] + fn built_in_glossary_candidates_cover_window_manager_terms() { + let context = typing_context(RewriteSurfaceKind::GenericText); + let policy = resolve_policy_context( + RewriteCorrectionPolicy::Balanced, + Some(&context), + &[RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "I'm switching from sui to hyperland and neary.".into(), + }], + &[], + &[], + ); + + assert!( + policy + .glossary_candidates + .iter() + .any(|candidate| candidate.text == "I'm switching from Sway to Hyprland and niri.") + ); } #[test] diff --git a/src/app.rs b/src/app.rs index 7cace96..0e83881 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,9 +1,5 @@ -use std::process::Child; use std::time::Instant; -#[cfg(feature = "osd")] -use std::process::Command; - use crate::asr; use crate::audio::AudioRecorder; use crate::config::Config; @@ -11,6 +7,7 @@ use crate::context; use crate::error::Result; use crate::feedback::FeedbackPlayer; use crate::inject::TextInjector; +use crate::osd::{OsdHandle, OsdMode}; use crate::postprocess; use crate::session; @@ -39,7 +36,7 @@ pub async fn run(config: Config) -> Result<()> { }; let mut recorder = AudioRecorder::new(&config.audio); recorder.start()?; - let mut osd = spawn_osd(); + let mut osd = OsdHandle::spawn(OsdMode::Meter); tracing::info!("recording... (run whispers again to stop)"); let transcriber = asr::prepare_transcriber(&config)?; @@ -55,13 +52,13 @@ pub async fn run(config: Config) -> Result<()> { } _ = tokio::signal::ctrl_c() => { tracing::info!("interrupted, cancelling"); - kill_osd(&mut osd); + osd.kill(); recorder.stop()?; return Ok(()); } _ = sigterm.recv() => { tracing::info!("terminated, cancelling"); - kill_osd(&mut osd); + osd.kill(); recorder.stop()?; return Ok(()); } @@ -69,7 +66,7 @@ pub async fn run(config: Config) -> Result<()> { // Stop recording before playing feedback so the stop sound doesn't // leak into the mic. - kill_osd(&mut osd); + osd.kill(); let audio = recorder.stop()?; feedback.play_stop(); let sample_rate = config.audio.sample_rate; @@ -144,7 +141,7 @@ pub async fn run(config: Config) -> Result<()> { let injector = TextInjector::new(); match finalized.operation { postprocess::FinalizedOperation::Append => { - injector.inject(&finalized.text).await?; + injector.inject(&finalized.text, &injection_context).await?; if session_enabled { session::record_append( &config.session, @@ -159,7 +156,7 @@ pub async fn run(config: Config) -> Result<()> { delete_graphemes, } => { injector - .replace_recent_text(delete_graphemes, &finalized.text) + .replace_recent_text(delete_graphemes, &finalized.text, &injection_context) .await?; if session_enabled { session::record_replace( @@ -180,55 +177,3 @@ pub async fn run(config: Config) -> Result<()> { ); Ok(()) } - -#[cfg(feature = "osd")] -fn spawn_osd() -> Option { - // Look for whispers-osd next to our own binary first, then fall back to PATH - let osd_path = std::env::current_exe() - .ok() - .and_then(|p| p.parent().map(|dir| dir.join("whispers-osd"))) - .filter(|p| p.exists()) - .unwrap_or_else(|| "whispers-osd".into()); - - match Command::new(&osd_path).spawn() { - Ok(child) => { - tracing::debug!("spawned whispers-osd (pid {})", child.id()); - Some(child) - } - Err(e) => { - tracing::warn!( - "failed to spawn whispers-osd from {}: {e}", - osd_path.display() - ); - None - } - } -} - -#[cfg(not(feature = "osd"))] -fn spawn_osd() -> Option { - None -} - -fn kill_osd(child: &mut Option) { - if let Some(mut c) = child.take() { - let pid = c.id() as libc::pid_t; - unsafe { - libc::kill(pid, libc::SIGTERM); - } - let _ = c.wait(); - tracing::debug!("whispers-osd (pid {pid}) terminated"); - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn kill_osd_none_is_noop() { - let mut child: Option = None; - kill_osd(&mut child); - assert!(child.is_none()); - } -} diff --git a/src/asr.rs b/src/asr.rs index 0014bf7..a1da40f 100644 --- a/src/asr.rs +++ b/src/asr.rs @@ -9,6 +9,7 @@ use crate::transcribe::{ }; use std::collections::HashSet; use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; pub enum PreparedTranscriber { Whisper(tokio::task::JoinHandle>), @@ -17,6 +18,13 @@ pub enum PreparedTranscriber { Cloud(CloudService), } +pub enum LiveTranscriber { + Whisper(Arc>), + Faster(FasterWhisperService), + Nemo(NemoAsrService), + Cloud(CloudService), +} + pub fn prepare_transcriber(config: &Config) -> Result { cleanup_stale_transcribers(config)?; @@ -49,6 +57,41 @@ pub fn prepare_transcriber(config: &Config) -> Result { } } +pub async fn prepare_live_transcriber(config: &Config) -> Result { + cleanup_stale_transcribers(config)?; + + match config.transcription.backend { + TranscriptionBackend::WhisperCpp => { + let whisper_config = config.transcription.clone(); + let model_path = config.resolved_model_path(); + let backend = tokio::task::spawn_blocking(move || { + WhisperLocal::new(&whisper_config, &model_path) + }) + .await + .map_err(|e| WhsprError::Transcription(format!("model loading task failed: {e}")))??; + Ok(LiveTranscriber::Whisper(Arc::new(Mutex::new(backend)))) + } + TranscriptionBackend::FasterWhisper => { + faster_whisper::prepare_service(&config.transcription) + .map(LiveTranscriber::Faster) + .ok_or_else(|| { + WhsprError::Transcription( + "faster-whisper backend selected but no model path could be resolved" + .into(), + ) + }) + } + TranscriptionBackend::Nemo => nemo_asr::prepare_service(&config.transcription) + .map(LiveTranscriber::Nemo) + .ok_or_else(|| { + WhsprError::Transcription( + "nemo backend selected but no model reference could be resolved".into(), + ) + }), + TranscriptionBackend::Cloud => Ok(LiveTranscriber::Cloud(CloudService::new(config)?)), + } +} + pub fn cleanup_stale_transcribers(config: &Config) -> Result<()> { let retained = retained_socket_paths(config); let stale_workers = collect_stale_asr_workers(&retained)?; @@ -89,6 +132,20 @@ pub fn prewarm_transcriber(prepared: &PreparedTranscriber, phase: &str) { } } +pub fn prewarm_live_transcriber(prepared: &LiveTranscriber, phase: &str) { + match prepared { + LiveTranscriber::Faster(service) => match service.prewarm() { + Ok(()) => tracing::info!("prewarming faster-whisper worker via {}", phase), + Err(err) => tracing::warn!("failed to prewarm faster-whisper worker: {err}"), + }, + LiveTranscriber::Nemo(service) => match service.prewarm() { + Ok(()) => tracing::info!("prewarming NeMo ASR worker via {}", phase), + Err(err) => tracing::warn!("failed to prewarm NeMo ASR worker: {err}"), + }, + LiveTranscriber::Whisper(_) | LiveTranscriber::Cloud(_) => {} + } +} + pub async fn transcribe_audio( config: &Config, prepared: PreparedTranscriber, @@ -131,6 +188,52 @@ pub async fn transcribe_audio( } } +pub async fn transcribe_live_audio( + config: &Config, + prepared: &LiveTranscriber, + audio: Vec, + sample_rate: u32, +) -> Result { + match prepared { + LiveTranscriber::Whisper(backend) => { + let backend = Arc::clone(backend); + tokio::task::spawn_blocking(move || { + let backend = backend.lock().map_err(|_| { + WhsprError::Transcription("live whisper backend lock poisoned".into()) + })?; + backend.transcribe(&audio, sample_rate) + }) + .await + .map_err(|e| WhsprError::Transcription(format!("transcription task failed: {e}")))? + } + LiveTranscriber::Faster(service) => { + match service.transcribe_live(&audio, sample_rate).await { + Ok(transcript) => Ok(transcript), + Err(err) => { + tracing::warn!("faster-whisper transcription failed: {err}"); + fallback_whisper_cpp_transcribe(config, audio, sample_rate).await + } + } + } + LiveTranscriber::Nemo(service) => match service.transcribe(&audio, sample_rate).await { + Ok(transcript) => Ok(transcript), + Err(err) => { + tracing::warn!("NeMo ASR transcription failed: {err}"); + fallback_whisper_cpp_transcribe(config, audio, sample_rate).await + } + }, + LiveTranscriber::Cloud(service) => { + match service.transcribe_audio(config, &audio, sample_rate).await { + Ok(transcript) => Ok(transcript), + Err(err) => { + tracing::warn!("cloud transcription failed: {err}"); + fallback_local_transcribe(config, audio, sample_rate).await + } + } + } + } +} + async fn fallback_local_transcribe( config: &Config, audio: Vec, diff --git a/src/asr_protocol.rs b/src/asr_protocol.rs index d0a5f0c..e521321 100644 --- a/src/asr_protocol.rs +++ b/src/asr_protocol.rs @@ -8,6 +8,8 @@ pub enum AsrRequest { Transcribe { audio_f32_b64: String, sample_rate: u32, + #[serde(default)] + live: bool, }, } diff --git a/src/audio.rs b/src/audio.rs index fd0f940..e7867f1 100644 --- a/src/audio.rs +++ b/src/audio.rs @@ -185,6 +185,13 @@ impl AudioRecorder { preprocess_audio(&mut buffer, self.config.sample_rate); Ok(buffer) } + + pub fn snapshot(&self) -> Result> { + self.buffer + .lock() + .map(|buffer| buffer.clone()) + .map_err(|_| WhsprError::Audio("audio buffer lock poisoned".into())) + } } pub fn preprocess_audio(samples: &mut Vec, sample_rate: u32) { @@ -213,6 +220,31 @@ pub fn preprocess_audio(samples: &mut Vec, sample_rate: u32) { ); } +pub fn preprocess_live_audio(samples: &mut [f32], sample_rate: u32) { + if samples.is_empty() || sample_rate == 0 { + return; + } + + let before_len = samples.len(); + let before = audio_stats(samples); + + remove_dc_offset(samples); + apply_highpass(samples, sample_rate, HIGHPASS_CUTOFF_HZ); + let gain = normalize_peak(samples); + + let after = audio_stats(samples); + tracing::debug!( + "live audio preprocessing: len {} -> {}, rms {:.4} -> {:.4}, peak {:.4} -> {:.4}, gain {:.2}x", + before_len, + samples.len(), + before.rms, + after.rms, + before.peak, + after.peak, + gain + ); +} + #[derive(Clone, Copy)] struct AudioStats { rms: f32, diff --git a/src/bin/whispers-osd.rs b/src/bin/whispers-osd.rs index ed47233..a5f7e61 100644 --- a/src/bin/whispers-osd.rs +++ b/src/bin/whispers-osd.rs @@ -1,63 +1,83 @@ #[path = "../branding.rs"] mod branding; +#[path = "../osd_protocol.rs"] +mod osd_protocol; use std::ffi::CString; +use std::io::{BufRead, BufReader}; use std::os::fd::AsRawFd; use std::os::unix::io::{AsFd, FromRawFd}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::sync::OnceLock; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::time::Instant; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; +use font8x8::{BASIC_FONTS, UnicodeFonts}; +use fontdue::Font; +use osd_protocol::{OsdEvent, VoiceOsdStatus, VoiceOsdUpdate}; use wayland_client::protocol::{ wl_buffer, wl_compositor, wl_registry, wl_shm, wl_shm_pool, wl_surface, }; use wayland_client::{Connection, Dispatch, QueueHandle, delegate_noop}; use wayland_protocols_wlr::layer_shell::v1::client::{zwlr_layer_shell_v1, zwlr_layer_surface_v1}; -// --- Layout --- -const NUM_BARS: usize = 28; +const NUM_BARS: usize = 12; const BAR_WIDTH: u32 = 3; const BAR_GAP: u32 = 2; -const PAD_X: u32 = 10; -const PAD_Y: u32 = 8; +const PAD_X: u32 = 12; +const PAD_Y: u32 = 6; const BAR_MIN_HEIGHT: f32 = 2.0; -const BAR_MAX_HEIGHT: f32 = 30.0; -const OSD_WIDTH: u32 = PAD_X * 2 + NUM_BARS as u32 * BAR_WIDTH + (NUM_BARS as u32 - 1) * BAR_GAP; -const OSD_HEIGHT: u32 = BAR_MAX_HEIGHT as u32 + PAD_Y * 2; +const BAR_MAX_HEIGHT: f32 = 20.0; +const METER_WIDTH: u32 = PAD_X * 2 + NUM_BARS as u32 * BAR_WIDTH + (NUM_BARS as u32 - 1) * BAR_GAP; +const METER_HEIGHT: u32 = BAR_MAX_HEIGHT as u32 + PAD_Y * 2; +const VOICE_WIDTH: u32 = 760; +const VOICE_HEIGHT: u32 = 248; const MARGIN_BOTTOM: i32 = 40; -const CORNER_RADIUS: u32 = 12; +const CORNER_RADIUS: u32 = 16; const BORDER_WIDTH: u32 = 1; -const RISE_RATE: f32 = 0.55; -const DECAY_RATE: f32 = 0.88; - -// --- Animation --- +const RISE_RATE: f32 = 0.40; +const DECAY_RATE: f32 = 0.92; const FPS: i32 = 30; const FRAME_MS: i32 = 1000 / FPS; -// --- Colors --- -const BG_R: u8 = 18; -const BG_G: u8 = 18; -const BG_B: u8 = 30; -const BG_A: u8 = 185; - -const BORDER_R: u8 = 140; -const BORDER_G: u8 = 180; +const BG_R: u8 = 15; +const BG_G: u8 = 16; +const BG_B: u8 = 24; +const BG_A: u8 = 160; +const METER_BG_A: u8 = 56; +const BORDER_R: u8 = 200; +const BORDER_G: u8 = 215; const BORDER_B: u8 = 255; -const BORDER_A: u8 = 40; - -// Bar gradient: teal → violet -const BAR_LEFT_R: f32 = 0.0; -const BAR_LEFT_G: f32 = 0.82; -const BAR_LEFT_B: f32 = 0.75; -const BAR_RIGHT_R: f32 = 0.65; -const BAR_RIGHT_G: f32 = 0.35; -const BAR_RIGHT_B: f32 = 1.0; +const BORDER_A: u8 = 22; + +const BAR_REST_R: f32 = 0.863; +const BAR_REST_G: f32 = 0.882; +const BAR_REST_B: f32 = 0.922; +const BAR_REST_A: f32 = 0.706; +const BAR_PEAK_R: f32 = 0.392; +const BAR_PEAK_G: f32 = 0.608; +const BAR_PEAK_B: f32 = 1.0; +const BAR_PEAK_A: f32 = 0.941; + +const TEXT_PRIMARY: (u8, u8, u8, u8) = (242, 246, 255, 230); +const TEXT_MUTED: (u8, u8, u8, u8) = (185, 196, 220, 200); +const TEXT_UNSTABLE: (u8, u8, u8, u8) = (115, 235, 226, 240); +const TEXT_REWRITE: (u8, u8, u8, u8) = (255, 208, 126, 220); +const TEXT_REWRITE_PRIMARY: (u8, u8, u8, u8) = (255, 236, 205, 240); +const TEXT_WARNING: (u8, u8, u8, u8) = (255, 153, 134, 235); +const DIVIDER: (u8, u8, u8, u8) = (120, 150, 205, 42); static SHOULD_EXIT: AtomicBool = AtomicBool::new(false); +static OSD_FONT: OnceLock> = OnceLock::new(); + +#[derive(Clone, Copy, PartialEq, Eq)] +enum OverlayMode { + Meter, + Voice, +} -// --- Audio state (shared with capture thread) --- struct AudioLevel { rms_bits: AtomicU32, } @@ -68,41 +88,43 @@ impl AudioLevel { rms_bits: AtomicU32::new(0), } } + fn set(&self, val: f32) { self.rms_bits.store(val.to_bits(), Ordering::Relaxed); } + fn get(&self) -> f32 { f32::from_bits(self.rms_bits.load(Ordering::Relaxed)) } } -// --- Bar animation state --- struct BarState { heights: [f32; NUM_BARS], + smooth_rms: f32, } impl BarState { fn new() -> Self { Self { heights: [BAR_MIN_HEIGHT; NUM_BARS], + smooth_rms: 0.0, } } fn update(&mut self, rms: f32, time: f32) { - // Amplify RMS for visual impact let level = (rms * 5.0).min(1.0); + let rms_target = (rms * 4.0).min(1.0); + self.smooth_rms = self.smooth_rms * 0.85 + rms_target * 0.15; for i in 0..NUM_BARS { let t = i as f32 / NUM_BARS as f32; - // Create wave pattern across bars, driven by audio level - let wave1 = (t * std::f32::consts::PI * 2.5 + time * 3.0).sin() * 0.5 + 0.5; - let wave2 = (t * std::f32::consts::PI * 1.3 - time * 1.8).sin() * 0.3 + 0.5; - let wave3 = (t * std::f32::consts::PI * 4.0 + time * 5.5).sin() * 0.2 + 0.5; + let wave1 = (t * std::f32::consts::PI * 2.0 + time * 2.0).sin() * 0.5 + 0.5; + let wave2 = (t * std::f32::consts::PI * 1.5 - time * 1.2).sin() * 0.3 + 0.5; + let wave3 = (t * std::f32::consts::PI * 3.5 + time * 4.0).sin() * 0.15 + 0.5; - let combined = (wave1 * 0.5 + wave2 * 0.3 + wave3 * 0.2) * level; + let combined = (wave1 * 0.55 + wave2 * 0.30 + wave3 * 0.15) * level; let target = BAR_MIN_HEIGHT + combined * (BAR_MAX_HEIGHT - BAR_MIN_HEIGHT); - // Smooth: fast rise, slow decay if target > self.heights[i] { self.heights[i] += (target - self.heights[i]) * RISE_RATE; } else { @@ -113,7 +135,42 @@ impl BarState { } } -// --- Wayland state --- +#[derive(Debug, Clone)] +struct VoiceOverlayState { + status: VoiceOsdStatus, + stable_text: String, + unstable_text: String, + rewrite_preview: Option, + live_inject: bool, + frozen: bool, +} + +impl Default for VoiceOverlayState { + fn default() -> Self { + Self { + status: VoiceOsdStatus::Listening, + stable_text: String::new(), + unstable_text: String::new(), + rewrite_preview: None, + live_inject: false, + frozen: false, + } + } +} + +impl From for VoiceOverlayState { + fn from(update: VoiceOsdUpdate) -> Self { + Self { + status: update.status, + stable_text: update.stable_text, + unstable_text: update.unstable_text, + rewrite_preview: update.rewrite_preview, + live_inject: update.live_inject, + frozen: update.frozen, + } + } +} + struct OsdState { running: bool, width: u32, @@ -144,23 +201,36 @@ fn main() -> Result<(), Box> { ); } + let mode = if std::env::args().any(|arg| arg == "--voice") { + OverlayMode::Voice + } else { + OverlayMode::Meter + }; + let (width, height) = match mode { + OverlayMode::Meter => (METER_WIDTH, METER_HEIGHT), + OverlayMode::Voice => (VOICE_WIDTH, VOICE_HEIGHT), + }; + let _ = std::fs::write(pid_file_path(), std::process::id().to_string()); - // Start audio capture for visualization let audio_level = Arc::new(AudioLevel::new()); let _audio_stream = start_audio_capture(Arc::clone(&audio_level)); - // Wayland setup + let voice_state = matches!(mode, OverlayMode::Voice) + .then(|| Arc::new(std::sync::Mutex::new(VoiceOverlayState::default()))); + let _voice_reader = voice_state + .as_ref() + .map(|state| start_voice_event_reader(Arc::clone(state))); + let conn = Connection::connect_to_env()?; let mut event_queue = conn.new_event_queue(); let qh = event_queue.handle(); - conn.display().get_registry(&qh, ()); let mut state = OsdState { running: true, - width: OSD_WIDTH, - height: OSD_HEIGHT, + width, + height, compositor: None, shm: None, layer_shell: None, @@ -172,7 +242,6 @@ fn main() -> Result<(), Box> { event_queue.roundtrip(&mut state)?; - // Create layer surface let compositor = state .compositor .as_ref() @@ -192,7 +261,7 @@ fn main() -> Result<(), Box> { (), ); - layer_surface.set_size(OSD_WIDTH, OSD_HEIGHT); + layer_surface.set_size(width, height); layer_surface.set_anchor(zwlr_layer_surface_v1::Anchor::Bottom); layer_surface.set_margin(0, 0, MARGIN_BOTTOM, 0); layer_surface.set_exclusive_zone(-1); @@ -201,19 +270,14 @@ fn main() -> Result<(), Box> { state.surface = Some(surface); state.layer_surface = Some(layer_surface); - event_queue.roundtrip(&mut state)?; - // Animation state let mut bars = BarState::new(); let start_time = Instant::now(); + let mut pixels = vec![0u8; (width * height * 4) as usize]; - // Reusable pixel buffer (avoids alloc/dealloc per frame) - let mut pixels = vec![0u8; (OSD_WIDTH * OSD_HEIGHT * 4) as usize]; - - // Persistent shm pool: create memfd + pool once, reuse each frame - let stride = OSD_WIDTH * 4; - let shm_size = (stride * OSD_HEIGHT) as i32; + let stride = width * 4; + let shm_size = (stride * height) as i32; let shm_name = CString::new(branding::OSD_BINARY).expect("valid OSD memfd name"); let shm_fd = unsafe { libc::memfd_create(shm_name.as_ptr(), libc::MFD_CLOEXEC) }; if shm_fd < 0 { @@ -227,7 +291,6 @@ fn main() -> Result<(), Box> { .ok_or("wl_shm not advertised by wayland server")?; let pool = shm.create_pool(shm_file.as_fd(), shm_size, &qh, ()); - // Main animation loop while state.running && !SHOULD_EXIT.load(Ordering::Relaxed) { conn.flush()?; @@ -253,24 +316,30 @@ fn main() -> Result<(), Box> { continue; } - // Update animation let time = start_time.elapsed().as_secs_f32(); let rms = audio_level.get(); bars.update(rms, time); - // Render frame into reusable buffer - let w = state.width; - let h = state.height; + let voice_snapshot = voice_state + .as_ref() + .and_then(|shared| shared.lock().ok().map(|state| state.clone())); + pixels.fill(0); - render_frame(&mut pixels, w, h, &bars, time); + render_frame( + &mut pixels, + state.width, + state.height, + &bars, + time, + mode, + voice_snapshot.as_ref(), + ); - // Present frame using persistent shm pool - if let Err(e) = present_frame(&mut state, &qh, &pool, &shm_file, &pixels, w, h) { - eprintln!("frame dropped: {e}"); + if let Err(err) = present_frame(&mut state, &qh, &pool, &shm_file, &pixels, width, height) { + eprintln!("frame dropped: {err}"); } } - // Cleanup pool.destroy(); if let Some(ls) = state.layer_surface.take() { ls.destroy(); @@ -289,13 +358,33 @@ unsafe extern "C" fn handle_signal(_sig: libc::c_int) { SHOULD_EXIT.store(true, Ordering::Relaxed); } -// --- Audio capture --- +fn start_voice_event_reader( + state: Arc>, +) -> std::thread::JoinHandle<()> { + std::thread::spawn(move || { + let stdin = std::io::stdin(); + let reader = BufReader::new(stdin.lock()); + for line in reader.lines() { + let Ok(line) = line else { + break; + }; + if line.trim().is_empty() { + continue; + } + let Ok(event) = serde_json::from_str::(&line) else { + continue; + }; + let OsdEvent::VoiceUpdate(update) = event; + if let Ok(mut guard) = state.lock() { + *guard = update.into(); + } + } + }) +} fn start_audio_capture(level: Arc) -> Option { let host = cpal::default_host(); let device = host.default_input_device()?; - - // Try to find a supported config at 16kHz, preferring mono then fewer channels let config = device .supported_input_configs() .ok() @@ -324,7 +413,6 @@ fn start_audio_capture(level: Arc) -> Option { if data.is_empty() { return; } - // Downmix to mono if needed, then compute RMS let sample_count = data.len() / channels.max(1); if sample_count == 0 { return; @@ -351,24 +439,50 @@ fn start_audio_capture(level: Arc) -> Option { Some(stream) } -// --- Rendering --- +fn render_frame( + pixels: &mut [u8], + w: u32, + h: u32, + bars: &BarState, + _time: f32, + mode: OverlayMode, + voice_state: Option<&VoiceOverlayState>, +) { + let shell_alpha = match mode { + OverlayMode::Meter => METER_BG_A, + OverlayMode::Voice => BG_A, + }; + render_overlay_shell(pixels, w, h, shell_alpha); + + match mode { + OverlayMode::Meter => render_meter_overlay(pixels, w, h, bars), + OverlayMode::Voice => render_voice_overlay( + pixels, + w, + h, + bars, + voice_state.unwrap_or(&VoiceOverlayState::default()), + ), + } +} -fn render_frame(pixels: &mut [u8], w: u32, h: u32, bars: &BarState, _time: f32) { - // Glassmorphic background - draw_rounded_rect( - pixels, - w, - h, - 0, - 0, - w, - h, - CORNER_RADIUS, - BG_R, - BG_G, - BG_B, - BG_A, - ); +fn render_overlay_shell(pixels: &mut [u8], w: u32, h: u32, fill_alpha: u8) { + if fill_alpha > 0 { + draw_rounded_rect( + pixels, + w, + h, + 0, + 0, + w, + h, + CORNER_RADIUS, + BG_R, + BG_G, + BG_B, + fill_alpha, + ); + } draw_rounded_border( pixels, w, @@ -380,36 +494,267 @@ fn render_frame(pixels: &mut [u8], w: u32, h: u32, bars: &BarState, _time: f32) BORDER_B, BORDER_A, ); - - // Top highlight (glass reflection) for x in (CORNER_RADIUS + 2)..(w.saturating_sub(CORNER_RADIUS + 2)) { - set_pixel_blend(pixels, w, h, x, 1, 255, 255, 255, 18); + set_pixel_blend(pixels, w, h, x, 1, 255, 255, 255, 12); } +} + +fn render_meter_overlay(pixels: &mut [u8], w: u32, h: u32, bars: &BarState) { + render_meter_bars(pixels, w, h, bars, h / 2); +} - // Visualizer bars - let center_y = h / 2; +fn render_voice_overlay( + pixels: &mut [u8], + w: u32, + h: u32, + bars: &BarState, + voice: &VoiceOverlayState, +) { + let pad = 20i32; + let header_y = 16i32; + let transcript_y = 50i32; + let status_font = 16.0; + let badge_font = 13.0; + let transcript_font = 20.0; + let footer_font = 14.0; + let transcript_line_height = line_height(transcript_font) + 4; + let footer_line_height = line_height(footer_font) + 2; + let transcript_width = w.saturating_sub((pad as u32) * 2); + let raw_live_text = combined_voice_text(&voice.stable_text, &voice.unstable_text); + let rewrite_available = voice + .rewrite_preview + .as_deref() + .map(str::trim) + .filter(|text| !text.is_empty()); + + let status_label = status_label(voice.status); + draw_text( + pixels, + w, + h, + pad, + header_y, + status_font, + status_label, + TEXT_PRIMARY, + ); + + let badge_text = if voice.live_inject { + "LIVE INJECT" + } else { + "PREVIEW ONLY" + }; + let badge_width = text_width(badge_text, badge_font); + draw_text( + pixels, + w, + h, + w.saturating_sub(pad as u32 + badge_width) as i32, + header_y, + badge_font, + badge_text, + if voice.live_inject { + TEXT_UNSTABLE + } else { + TEXT_MUTED + }, + ); + + if let Some(rewrite_text) = rewrite_available { + let preview_label = "Live rewrite preview"; + draw_text( + pixels, + w, + h, + pad, + transcript_y - footer_line_height - 2, + footer_font, + preview_label, + TEXT_REWRITE, + ); + + let rewrite_lines = wrap_text(rewrite_text, transcript_width, 4, transcript_font); + for (idx, line) in rewrite_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + pad, + transcript_y + idx as i32 * transcript_line_height, + transcript_font, + line, + TEXT_REWRITE_PRIMARY, + ); + } + } else { + let stable_lines = wrap_text(&voice.stable_text, transcript_width, 3, transcript_font); + let unstable_lines = wrap_text(&voice.unstable_text, transcript_width, 2, transcript_font); + + if stable_lines.is_empty() && unstable_lines.is_empty() { + draw_text( + pixels, + w, + h, + pad, + transcript_y, + transcript_font, + "Listening for speech...", + TEXT_MUTED, + ); + } else { + for (idx, line) in stable_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + pad, + transcript_y + idx as i32 * transcript_line_height, + transcript_font, + line, + TEXT_PRIMARY, + ); + } + + let unstable_y = transcript_y + stable_lines.len() as i32 * transcript_line_height; + for (idx, line) in unstable_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + pad, + unstable_y + idx as i32 * transcript_line_height, + transcript_font, + line, + TEXT_UNSTABLE, + ); + } + } + } + + let divider_y = h.saturating_sub(74); + for x in pad as u32..w.saturating_sub(pad as u32) { + set_pixel_blend( + pixels, w, h, x, divider_y, DIVIDER.0, DIVIDER.1, DIVIDER.2, DIVIDER.3, + ); + } + + if voice.frozen { + let warning_lines = wrap_text( + "Focus changed. Live injection is frozen for this take.", + transcript_width, + 2, + footer_font, + ); + for (idx, line) in warning_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + pad, + divider_y as i32 + 10 + idx as i32 * footer_line_height, + footer_font, + line, + TEXT_WARNING, + ); + } + } else if rewrite_available.is_some() { + draw_text( + pixels, + w, + h, + pad, + divider_y as i32 + 10, + footer_font, + "Raw live hypothesis", + TEXT_MUTED, + ); + let raw_lines = wrap_text(&raw_live_text, transcript_width, 2, footer_font); + for (idx, line) in raw_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + pad, + divider_y as i32 + 10 + footer_line_height + idx as i32 * footer_line_height, + footer_font, + line, + TEXT_MUTED, + ); + } + } + + render_voice_bars(pixels, w, h, bars, h.saturating_sub(24)); +} + +fn render_meter_bars(pixels: &mut [u8], w: u32, h: u32, bars: &BarState, center_y: u32) { + let color_t = bars.smooth_rms.clamp(0.0, 1.0); + let cr = (lerp(BAR_REST_R, BAR_PEAK_R, color_t) * 255.0) as u8; + let cg = (lerp(BAR_REST_G, BAR_PEAK_G, color_t) * 255.0) as u8; + let cb = (lerp(BAR_REST_B, BAR_PEAK_B, color_t) * 255.0) as u8; + let base_alpha = lerp(BAR_REST_A, BAR_PEAK_A, color_t); + + let glow_expand = 1 + (color_t * 2.0) as u32; + let glow_alpha = (15.0 + color_t * 25.0) as u8; + + let total_width = NUM_BARS as u32 * BAR_WIDTH + (NUM_BARS as u32 - 1) * BAR_GAP; + let start_x = w.saturating_sub(total_width) / 2; for i in 0..NUM_BARS { - let bx = PAD_X + i as u32 * (BAR_WIDTH + BAR_GAP); + let bx = start_x + i as u32 * (BAR_WIDTH + BAR_GAP); + let bar_h = bars.heights[i] as u32; + let half_h = bar_h / 2; + let top_y = center_y.saturating_sub(half_h); + + for gy in top_y.saturating_sub(glow_expand) + ..=(top_y + bar_h + glow_expand).min(h.saturating_sub(1)) + { + for gx in bx.saturating_sub(glow_expand) + ..=(bx + BAR_WIDTH + glow_expand).min(w.saturating_sub(1)) + { + set_pixel_blend(pixels, w, h, gx, gy, cr, cg, cb, glow_alpha); + } + } + + for y in top_y..(top_y + bar_h).min(h) { + let vy = (y as f32 - top_y as f32) / bar_h.max(1) as f32; + let brightness = 1.0 - (vy - 0.5).abs() * 0.5; + let a = (brightness * base_alpha * 255.0) as u8; + for x in bx..(bx + BAR_WIDTH).min(w) { + set_pixel_blend(pixels, w, h, x, y, cr, cg, cb, a); + } + } + } +} + +fn render_voice_bars(pixels: &mut [u8], w: u32, h: u32, bars: &BarState, center_y: u32) { + const VOICE_BAR_LEFT_R: f32 = 0.0; + const VOICE_BAR_LEFT_G: f32 = 0.82; + const VOICE_BAR_LEFT_B: f32 = 0.75; + const VOICE_BAR_RIGHT_R: f32 = 0.65; + const VOICE_BAR_RIGHT_G: f32 = 0.35; + const VOICE_BAR_RIGHT_B: f32 = 1.0; + + let total_width = NUM_BARS as u32 * BAR_WIDTH + (NUM_BARS as u32 - 1) * BAR_GAP; + let start_x = w.saturating_sub(total_width) / 2; + for i in 0..NUM_BARS { + let bx = start_x + i as u32 * (BAR_WIDTH + BAR_GAP); let bar_h = bars.heights[i] as u32; let half_h = bar_h / 2; let top_y = center_y.saturating_sub(half_h); let t = i as f32 / (NUM_BARS - 1) as f32; - let r = lerp(BAR_LEFT_R, BAR_RIGHT_R, t); - let g = lerp(BAR_LEFT_G, BAR_RIGHT_G, t); - let b = lerp(BAR_LEFT_B, BAR_RIGHT_B, t); + let r = lerp(VOICE_BAR_LEFT_R, VOICE_BAR_RIGHT_R, t); + let g = lerp(VOICE_BAR_LEFT_G, VOICE_BAR_RIGHT_G, t); + let b = lerp(VOICE_BAR_LEFT_B, VOICE_BAR_RIGHT_B, t); let cr = (r * 255.0) as u8; let cg = (g * 255.0) as u8; let cb = (b * 255.0) as u8; - // Glow - for gy in top_y.saturating_sub(2)..=(top_y + bar_h + 2).min(h - 1) { - for gx in bx.saturating_sub(1)..=(bx + BAR_WIDTH).min(w - 1) { + for gy in top_y.saturating_sub(2)..=(top_y + bar_h + 2).min(h.saturating_sub(1)) { + for gx in bx.saturating_sub(1)..=(bx + BAR_WIDTH).min(w.saturating_sub(1)) { set_pixel_blend(pixels, w, h, gx, gy, cr, cg, cb, 25); } } - // Bar body with vertical brightness gradient for y in top_y..(top_y + bar_h).min(h) { let vy = (y as f32 - top_y as f32) / bar_h.max(1) as f32; let brightness = 1.0 - (vy - 0.5).abs() * 0.6; @@ -421,6 +766,301 @@ fn render_frame(pixels: &mut [u8], w: u32, h: u32, bars: &BarState, _time: f32) } } +fn status_label(status: VoiceOsdStatus) -> &'static str { + match status { + VoiceOsdStatus::Listening => "Listening", + VoiceOsdStatus::Transcribing => "Transcribing", + VoiceOsdStatus::Rewriting => "Rewriting", + VoiceOsdStatus::Finalizing => "Finalizing", + VoiceOsdStatus::Frozen => "Frozen", + } +} + +fn load_osd_font() -> Option { + const FONT_CANDIDATES: &[&str] = &[ + "/usr/share/fonts/noto/NotoSans-Regular.ttf", + "/usr/share/fonts/noto/NotoSans-Medium.ttf", + "/usr/share/fonts/TTF/NotoSansMNerdFont-Regular.ttf", + "/usr/share/fonts/liberation/LiberationSans-Regular.ttf", + "/usr/share/fonts/TTF/ArimoNerdFont-Regular.ttf", + "/usr/share/fonts/TTF/DejaVuSansMNerdFont-Regular.ttf", + ]; + + FONT_CANDIDATES + .iter() + .find_map(|path| load_font_from_path(path)) +} + +fn load_font_from_path(path: &str) -> Option { + let bytes = std::fs::read(Path::new(path)).ok()?; + Font::from_bytes(bytes, fontdue::FontSettings::default()).ok() +} + +fn osd_font() -> Option<&'static Font> { + OSD_FONT.get_or_init(load_osd_font).as_ref() +} + +fn line_height(px_size: f32) -> i32 { + if let Some(font) = osd_font() { + if let Some(metrics) = font.horizontal_line_metrics(px_size) { + return metrics.new_line_size.ceil().max(px_size.ceil()) as i32; + } + } + let scale = ((px_size / 8.0).round() as i32).max(1); + 8 * scale + scale + 2 +} + +fn wrap_text(text: &str, max_width_px: u32, max_lines: usize, px_size: f32) -> Vec { + if max_width_px == 0 || max_lines == 0 { + return Vec::new(); + } + + let text = sanitize_text(text); + if text.is_empty() { + return Vec::new(); + } + + let mut lines = Vec::new(); + let mut current = String::new(); + + for word in text.split_whitespace() { + if word.is_empty() { + continue; + } + let candidate = if current.is_empty() { + word.to_string() + } else { + format!("{current} {word}") + }; + + if text_width(&candidate, px_size) <= max_width_px { + current = candidate; + continue; + } + + if !current.is_empty() { + lines.push(current); + } + + current = truncate_word_to_width(word, max_width_px, px_size); + } + + if !current.is_empty() { + lines.push(current); + } + + if lines.len() > max_lines { + let mut tail = lines.split_off(lines.len() - max_lines); + if let Some(first) = tail.first_mut() { + *first = fit_text_to_width(&format!("…{first}"), max_width_px, px_size, true); + } + return tail; + } + + lines +} + +fn combined_voice_text(stable_text: &str, unstable_text: &str) -> String { + match (stable_text.trim(), unstable_text.trim()) { + ("", "") => String::new(), + ("", unstable) => unstable.to_string(), + (stable, "") => stable.to_string(), + (stable, unstable) => format!("{stable} {unstable}"), + } +} + +fn sanitize_text(text: &str) -> String { + let mut normalized = String::new(); + let mut pending_space = false; + + for ch in text.chars() { + if ch.is_control() && !ch.is_whitespace() { + continue; + } + if ch.is_whitespace() { + if !normalized.is_empty() { + pending_space = true; + } + continue; + } + if pending_space { + normalized.push(' '); + pending_space = false; + } + normalized.push(ch); + } + + normalized +} + +fn truncate_word_to_width(word: &str, max_width_px: u32, px_size: f32) -> String { + fit_text_to_width(word, max_width_px, px_size, false) +} + +fn fit_text_to_width(text: &str, max_width_px: u32, px_size: f32, keep_tail: bool) -> String { + let sanitized = sanitize_text(text); + if sanitized.is_empty() { + return String::new(); + } + if text_width(&sanitized, px_size) <= max_width_px { + return sanitized; + } + + let ellipsis = "…"; + let chars: Vec = sanitized.chars().collect(); + if keep_tail { + for start in 0..chars.len() { + let candidate: String = std::iter::once('…') + .chain(chars[start..].iter().copied()) + .collect(); + if text_width(&candidate, px_size) <= max_width_px { + return candidate; + } + } + } else { + let mut candidate = String::new(); + for ch in chars { + let next = if candidate.is_empty() { + format!("{ch}{ellipsis}") + } else { + format!("{candidate}{ch}{ellipsis}") + }; + if text_width(&next, px_size) > max_width_px { + break; + } + candidate.push(ch); + } + if !candidate.is_empty() { + candidate.push('…'); + return candidate; + } + } + + ellipsis.to_string() +} + +fn text_width(text: &str, px_size: f32) -> u32 { + let sanitized = sanitize_text(text); + if sanitized.is_empty() { + return 0; + } + if let Some(font) = osd_font() { + let width: f32 = sanitized + .chars() + .map(|ch| font.metrics(ch, px_size).advance_width) + .sum(); + return width.ceil().max(0.0) as u32; + } + + let scale = ((px_size / 8.0).round() as i32).max(1); + let glyph_w = (8 * scale + scale).max(1) as u32; + sanitized.chars().count() as u32 * glyph_w +} + +#[allow(clippy::too_many_arguments)] +fn draw_text( + pixels: &mut [u8], + w: u32, + h: u32, + x: i32, + y: i32, + px_size: f32, + text: &str, + color: (u8, u8, u8, u8), +) { + let sanitized = sanitize_text(text); + if sanitized.is_empty() { + return; + } + + if let Some(font) = osd_font() { + let baseline = y as f32 + + font + .horizontal_line_metrics(px_size) + .map(|metrics| metrics.ascent) + .unwrap_or(px_size * 0.92); + let mut pen_x = x as f32; + for ch in sanitized.chars() { + let (metrics, bitmap) = font.rasterize(ch, px_size); + let glyph_x = pen_x.round() as i32 + metrics.xmin; + // fontdue reports ymin from the baseline to the glyph's bottom edge. + // In our positive-Y-down compositor space, the top pixel row sits at + // baseline - height - ymin. + let glyph_y = (baseline - metrics.height as f32 - metrics.ymin as f32).round() as i32; + for row in 0..metrics.height { + for col in 0..metrics.width { + let alpha = bitmap[row * metrics.width + col]; + if alpha == 0 { + continue; + } + let px = glyph_x + col as i32; + let py = glyph_y + row as i32; + if px >= 0 && py >= 0 { + let blended_alpha = ((alpha as u16 * color.3 as u16) / 255) as u8; + set_pixel_blend( + pixels, + w, + h, + px as u32, + py as u32, + color.0, + color.1, + color.2, + blended_alpha, + ); + } + } + } + pen_x += metrics.advance_width; + } + return; + } + + let scale = ((px_size / 8.0).round() as i32).max(1); + let mut cursor_x = x; + let glyph_advance = (8 * scale + scale).max(1); + for ch in sanitized.chars() { + draw_bitmap_char(pixels, w, h, cursor_x, y, scale, ch, color); + cursor_x += glyph_advance; + } +} + +#[allow(clippy::too_many_arguments)] +fn draw_bitmap_char( + pixels: &mut [u8], + w: u32, + h: u32, + x: i32, + y: i32, + scale: i32, + ch: char, + color: (u8, u8, u8, u8), +) { + let glyph = BASIC_FONTS.get(ch).or_else(|| BASIC_FONTS.get('?')); + let Some(glyph) = glyph else { + return; + }; + + for (row_idx, row_bits) in glyph.iter().enumerate() { + for col_idx in 0..8 { + if (row_bits >> col_idx) & 1 == 0 { + continue; + } + for sy in 0..scale.max(1) { + for sx in 0..scale.max(1) { + let px = x + (col_idx * scale as usize + sx as usize) as i32; + let py = y + (row_idx * scale as usize + sy as usize) as i32; + if px >= 0 && py >= 0 { + set_pixel_blend( + pixels, w, h, px as u32, py as u32, color.0, color.1, color.2, color.3, + ); + } + } + } + } + } +} + fn present_frame( state: &mut OsdState, qh: &QueueHandle, @@ -437,7 +1077,6 @@ fn present_frame( writer.seek(std::io::SeekFrom::Start(0))?; writer.write_all(pixels)?; - // Destroy previous buffer if let Some(old) = state.buffer.take() { old.destroy(); } @@ -464,8 +1103,6 @@ fn present_frame( Ok(()) } -// --- Drawing primitives --- - #[allow(clippy::too_many_arguments)] #[inline] fn set_pixel_blend(pixels: &mut [u8], w: u32, h: u32, x: u32, y: u32, r: u8, g: u8, b: u8, a: u8) { @@ -474,7 +1111,6 @@ fn set_pixel_blend(pixels: &mut [u8], w: u32, h: u32, x: u32, y: u32, r: u8, g: } let idx = ((y * w + x) * 4) as usize; if a == 255 { - // Premultiplied: BGRA pixels[idx] = b; pixels[idx + 1] = g; pixels[idx + 2] = r; @@ -483,7 +1119,6 @@ fn set_pixel_blend(pixels: &mut [u8], w: u32, h: u32, x: u32, y: u32, r: u8, g: } let sa = a as u32; let inv = 255 - sa; - // Premultiply source, blend with existing premultiplied dest pixels[idx] = ((sa * b as u32 + inv * pixels[idx] as u32) / 255) as u8; pixels[idx + 1] = ((sa * g as u32 + inv * pixels[idx + 1] as u32) / 255) as u8; pixels[idx + 2] = ((sa * r as u32 + inv * pixels[idx + 2] as u32) / 255) as u8; @@ -553,7 +1188,6 @@ fn is_inside_rounded_rect(x: u32, y: u32, w: u32, h: u32, r: u32) -> bool { if r == 0 || w == 0 || h == 0 { return x < w && y < h; } - // Check only corner regions let in_left = x < r; let in_right = x >= w - r; let in_top = y < r; @@ -587,8 +1221,6 @@ fn lerp(a: f32, b: f32, t: f32) -> f32 { a + (b - a) * t } -// --- Dispatch implementations --- - impl Dispatch for OsdState { fn event( state: &mut Self, diff --git a/src/bin/whispers-rewrite-worker.rs b/src/bin/whispers-rewrite-worker.rs index 8a32989..db91069 100644 --- a/src/bin/whispers-rewrite-worker.rs +++ b/src/bin/whispers-rewrite-worker.rs @@ -2,6 +2,8 @@ mod branding; #[path = "../rewrite.rs"] mod rewrite; +#[path = "../rewrite_local.rs"] +mod rewrite_local; #[path = "../rewrite_profile.rs"] mod rewrite_profile; #[path = "../rewrite_protocol.rs"] @@ -14,7 +16,7 @@ use clap::Parser; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::net::{UnixListener, UnixStream}; -use crate::rewrite::LocalRewriter; +use crate::rewrite_local::LocalRewriter; use crate::rewrite_profile::ResolvedRewriteProfile; use crate::rewrite_protocol::{WorkerRequest, WorkerResponse}; @@ -30,10 +32,10 @@ struct Cli { #[arg(long, value_enum, default_value_t = ResolvedRewriteProfile::Generic)] profile: ResolvedRewriteProfile, - #[arg(long, default_value_t = 256)] + #[arg(long, default_value_t = 768)] max_tokens: usize, - #[arg(long, default_value_t = 1200)] + #[arg(long, default_value_t = 8192)] max_output_chars: usize, #[arg(long, default_value_t = 120000)] diff --git a/src/cli.rs b/src/cli.rs index d431781..fa0ac37 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -31,11 +31,41 @@ pub enum CompletionShell { Nushell, } +impl CompletionShell { + pub const fn all() -> [Self; 4] { + [Self::Bash, Self::Zsh, Self::Fish, Self::Nushell] + } + + pub fn as_str(self) -> &'static str { + match self { + Self::Bash => "bash", + Self::Zsh => "zsh", + Self::Fish => "fish", + Self::Nushell => "nushell", + } + } + + pub fn binary_names(self) -> &'static [&'static str] { + match self { + Self::Bash => &["bash"], + Self::Zsh => &["zsh"], + Self::Fish => &["fish"], + Self::Nushell => &["nu", "nushell"], + } + } +} + #[derive(Subcommand, Debug)] pub enum Command { /// Guided setup for local, cloud, and experimental dictation paths Setup, + /// Show the active configuration, selected models, and runtime options + Status, + + /// Experimental live voice mode with partial transcription preview + Voice, + /// Transcribe an audio file (wav, mp3, flac, ogg, mp4/m4a) Transcribe { /// Path to the audio file @@ -334,6 +364,18 @@ mod tests { )); } + #[test] + fn parses_voice_command() { + let cli = Cli::try_parse_from(["whispers", "voice"]).unwrap(); + assert!(matches!(cli.command, Some(Command::Voice))); + } + + #[test] + fn parses_status_command() { + let cli = Cli::try_parse_from(["whispers", "status"]).unwrap(); + assert!(matches!(cli.command, Some(Command::Status))); + } + #[test] fn parses_rewrite_model_subcommand() { let cli = Cli::try_parse_from(["whispers", "rewrite-model", "list"]).unwrap(); diff --git a/src/cloud.rs b/src/cloud.rs index fa04ea8..f8d24de 100644 --- a/src/cloud.rs +++ b/src/cloud.rs @@ -679,7 +679,7 @@ mod tests { crate::test_support::set_env("OPENAI_API_KEY", "test-key"); let mut config = Config::default(); - config.postprocess.mode = crate::config::PostprocessMode::AdvancedLocal; + config.postprocess.mode = crate::config::PostprocessMode::Rewrite; config.rewrite.backend = RewriteBackend::Cloud; config.cloud.base_url = format!("{}/v1", server.base_url()); let service = CloudService::new(&config).expect("service"); @@ -701,6 +701,7 @@ mod tests { edit_hypotheses: Vec::new(), rewrite_candidates: Vec::new(), recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: crate::rewrite_protocol::RewritePolicyContext::default(), }, None, diff --git a/src/completions.rs b/src/completions.rs index b632c20..22614e3 100644 --- a/src/completions.rs +++ b/src/completions.rs @@ -1,5 +1,6 @@ use std::io::Write; -use std::path::Path; +use std::path::{Path, PathBuf}; +use std::process::Command; use clap::CommandFactory; use clap_complete::{generate, shells}; @@ -22,10 +23,52 @@ pub fn run_completions(shell_arg: Option) -> Result<()> { Ok(()) } -fn detect_shell() -> Option { +pub fn detect_shell() -> Option { detect_shell_from_env().or_else(detect_shell_from_parent_process) } +pub fn detect_installed_shells() -> Vec { + detect_installed_shells_in_path(std::env::var_os("PATH").as_deref()) +} + +pub fn install_completions(shell: CompletionShell) -> Result { + let path = install_path(shell); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(|e| { + WhsprError::Config(format!( + "failed to create completion directory {}: {e}", + parent.display() + )) + })?; + } + std::fs::write(&path, render_completions(shell)).map_err(|e| { + WhsprError::Config(format!( + "failed to write {} completions to {}: {e}", + shell.as_str(), + path.display() + )) + })?; + finalize_completion_install(shell, &path)?; + Ok(path) +} + +pub fn install_note(shell: CompletionShell, path: &Path) -> Option { + match shell { + CompletionShell::Bash => { + Some("Restart bash or run `source ~/.bashrc` to refresh completions.".to_string()) + } + CompletionShell::Zsh if path.parent() == Some(home_dir().join(".zfunc").as_path()) => { + Some("Ensure your zsh fpath includes ~/.zfunc before running compinit.".to_string()) + } + CompletionShell::Zsh => Some( + "Restart zsh or run `autoload -Uz compinit && compinit` to refresh completions." + .to_string(), + ), + CompletionShell::Nushell => Some("Restart Nushell to refresh completions.".to_string()), + _ => None, + } +} + fn detect_shell_from_env() -> Option { std::env::var("SHELL") .ok() @@ -41,6 +84,26 @@ fn detect_shell_from_parent_process() -> Option { detect_shell_from_pid(ppid) } +fn detect_installed_shells_in_path(path: Option<&std::ffi::OsStr>) -> Vec { + CompletionShell::all() + .into_iter() + .filter(|shell| { + shell + .binary_names() + .iter() + .any(|binary| binary_exists_in_path(binary, path)) + }) + .collect() +} + +fn binary_exists_in_path(binary: &str, path: Option<&std::ffi::OsStr>) -> bool { + let Some(path) = path else { + return false; + }; + + std::env::split_paths(path).any(|dir| dir.join(binary).is_file()) +} + fn detect_shell_from_pid(pid: libc::pid_t) -> Option { let cmdline_path = format!("/proc/{pid}/cmdline"); if let Ok(cmdline) = std::fs::read(cmdline_path) { @@ -83,14 +146,264 @@ fn shell_from_token(value: &str) -> Option { } fn write_completions(shell: CompletionShell, out: &mut dyn Write) { + out.write_all(render_completions(shell).as_bytes()) + .expect("writing completions to stdout should succeed"); +} + +pub fn render_completions(shell: CompletionShell) -> String { let mut cmd = Cli::command(); let bin_name = cmd.get_name().to_string(); + let mut output = Vec::new(); + + match shell { + CompletionShell::Bash => generate(shells::Bash, &mut cmd, &bin_name, &mut output), + CompletionShell::Zsh => generate(shells::Zsh, &mut cmd, &bin_name, &mut output), + CompletionShell::Fish => generate(shells::Fish, &mut cmd, &bin_name, &mut output), + CompletionShell::Nushell => generate(Nushell, &mut cmd, &bin_name, &mut output), + } + + String::from_utf8(output).expect("completion scripts should be valid UTF-8") +} + +fn install_path(shell: CompletionShell) -> PathBuf { + match shell { + CompletionShell::Bash => xdg_data_home() + .join("bash-completion") + .join("completions") + .join("whispers"), + CompletionShell::Zsh => zsh_install_dir().join("_whispers"), + CompletionShell::Fish => xdg_config_home() + .join("fish") + .join("completions") + .join("whispers.fish"), + CompletionShell::Nushell => nushell_autoload_dir().join("whispers.nu"), + } +} +fn finalize_completion_install(shell: CompletionShell, path: &Path) -> Result<()> { match shell { - CompletionShell::Bash => generate(shells::Bash, &mut cmd, &bin_name, out), - CompletionShell::Zsh => generate(shells::Zsh, &mut cmd, &bin_name, out), - CompletionShell::Fish => generate(shells::Fish, &mut cmd, &bin_name, out), - CompletionShell::Nushell => generate(Nushell, &mut cmd, &bin_name, out), + CompletionShell::Bash => ensure_bashrc_sources_completion(path), + CompletionShell::Nushell => ensure_nushell_config_uses_completion(path), + _ => Ok(()), + } +} + +fn ensure_bashrc_sources_completion(path: &Path) -> Result<()> { + let bashrc_path = home_dir().join(".bashrc"); + let existing = std::fs::read_to_string(&bashrc_path).unwrap_or_default(); + let start_marker = "# >>> whispers bash completions >>>"; + let end_marker = "# <<< whispers bash completions <<<"; + let source_path = shell_double_quote(path); + let block = format!( + "{start_marker}\nif [[ $- == *i* ]] && [[ -r \"{source_path}\" ]]; then\n source \"{source_path}\"\nfi\n{end_marker}\n" + ); + + let updated = if let Some(start) = existing.find(start_marker) { + if let Some(end_rel) = existing[start..].find(end_marker) { + let end = start + end_rel + end_marker.len(); + let mut next = String::new(); + next.push_str(&existing[..start]); + if !next.ends_with('\n') && !next.is_empty() { + next.push('\n'); + } + next.push_str(&block); + let suffix = existing[end..] + .strip_prefix('\n') + .unwrap_or(&existing[end..]); + next.push_str(suffix); + next + } else { + existing.clone() + } + } else if existing.contains(&format!("source \"{source_path}\"")) { + existing + } else { + let mut next = existing; + if !next.ends_with('\n') && !next.is_empty() { + next.push('\n'); + } + if !next.is_empty() { + next.push('\n'); + } + next.push_str(&block); + next + }; + + if updated != std::fs::read_to_string(&bashrc_path).unwrap_or_default() { + std::fs::write(&bashrc_path, updated).map_err(|e| { + WhsprError::Config(format!( + "failed to update bash init at {}: {e}", + bashrc_path.display() + )) + })?; + } + + Ok(()) +} + +fn ensure_nushell_config_uses_completion(path: &Path) -> Result<()> { + let config_path = xdg_config_home().join("nushell").join("config.nu"); + let existing = std::fs::read_to_string(&config_path).unwrap_or_default(); + let start_marker = "# >>> whispers nushell completions >>>"; + let end_marker = "# <<< whispers nushell completions <<<"; + let source_path = nushell_string_literal(path); + let block = format!("{start_marker}\nuse {source_path} *\n{end_marker}\n"); + + let updated = replace_or_prepend_block(&existing, start_marker, end_marker, &block); + if updated != existing { + if let Some(parent) = config_path.parent() { + std::fs::create_dir_all(parent).map_err(|e| { + WhsprError::Config(format!( + "failed to create Nushell config directory {}: {e}", + parent.display() + )) + })?; + } + std::fs::write(&config_path, updated).map_err(|e| { + WhsprError::Config(format!( + "failed to update Nushell config at {}: {e}", + config_path.display() + )) + })?; + } + + Ok(()) +} + +fn replace_or_prepend_block( + existing: &str, + start_marker: &str, + end_marker: &str, + block: &str, +) -> String { + if let Some(start) = existing.find(start_marker) { + if let Some(end_rel) = existing[start..].find(end_marker) { + let end = start + end_rel + end_marker.len(); + let mut next = String::new(); + next.push_str(&existing[..start]); + if !next.ends_with('\n') && !next.is_empty() { + next.push('\n'); + } + next.push_str(block); + let suffix = existing[end..] + .strip_prefix('\n') + .unwrap_or(&existing[end..]); + next.push_str(suffix); + return next; + } + + return existing.to_string(); + } + + if existing.contains(block.trim()) { + return existing.to_string(); + } + + let mut next = String::new(); + next.push_str(block); + if !existing.is_empty() && !block.ends_with('\n') { + next.push('\n'); + } + next.push_str(existing); + next +} + +fn zsh_install_dir() -> PathBuf { + detect_zsh_completion_dir().unwrap_or_else(|| home_dir().join(".zfunc")) +} + +fn detect_zsh_completion_dir() -> Option { + let output = Command::new("zsh") + .arg("-ic") + .arg("print -l -- $fpath") + .output() + .ok()?; + if !output.status.success() { + return None; + } + + preferred_zsh_completion_dir_from_fpath(&String::from_utf8_lossy(&output.stdout)) +} + +fn preferred_zsh_completion_dir_from_fpath(output: &str) -> Option { + let home = home_dir(); + let candidates: Vec = output + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + .map(PathBuf::from) + .filter(|path| path.starts_with(&home)) + .collect(); + + candidates + .iter() + .find(|path| !path.starts_with(home.join(".cache"))) + .cloned() + .or_else(|| candidates.into_iter().next()) +} + +fn nushell_autoload_dir() -> PathBuf { + detect_nushell_autoload_dir() + .unwrap_or_else(|| xdg_config_home().join("nushell").join("autoload")) +} + +fn detect_nushell_autoload_dir() -> Option { + let output = Command::new("nu") + .arg("-c") + .arg("print (($nu.user-autoload-dirs | first) | to text)") + .output() + .ok()?; + if !output.status.success() { + return None; + } + + let path = String::from_utf8_lossy(&output.stdout).trim().to_string(); + if path.is_empty() { + None + } else { + Some(PathBuf::from(path)) + } +} + +fn shell_double_quote(path: &Path) -> String { + path.display() + .to_string() + .replace('\\', "\\\\") + .replace('"', "\\\"") +} + +fn nushell_string_literal(path: &Path) -> String { + format!( + "\"{}\"", + path.display() + .to_string() + .replace('\\', "\\\\") + .replace('"', "\\\"") + ) +} + +fn xdg_config_home() -> PathBuf { + if let Ok(dir) = std::env::var("XDG_CONFIG_HOME") { + PathBuf::from(dir) + } else { + home_dir().join(".config") + } +} + +fn xdg_data_home() -> PathBuf { + if let Ok(dir) = std::env::var("XDG_DATA_HOME") { + PathBuf::from(dir) + } else { + home_dir().join(".local").join("share") + } +} + +fn home_dir() -> PathBuf { + if let Ok(home) = std::env::var("HOME") { + PathBuf::from(home) + } else { + tracing::warn!("HOME is not set, falling back to /tmp for shell completion install path"); + PathBuf::from("/tmp") } } @@ -133,9 +446,7 @@ mod tests { } fn generate_to_string(shell: CompletionShell) -> String { - let mut output = Vec::new(); - write_completions(shell, &mut output); - String::from_utf8(output).unwrap() + render_completions(shell) } #[test] @@ -165,4 +476,129 @@ mod tests { assert!(script.contains("whispers")); assert!(script.contains("export extern")); } + + #[test] + fn install_path_uses_expected_xdg_locations() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&[ + "HOME", + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + ]); + let root = crate::test_support::unique_temp_dir("completions-paths"); + crate::test_support::set_env("HOME", &root.to_string_lossy()); + crate::test_support::remove_env("XDG_CONFIG_HOME"); + crate::test_support::remove_env("XDG_DATA_HOME"); + + assert_eq!( + install_path(CompletionShell::Fish), + root.join(".config/fish/completions/whispers.fish") + ); + assert_eq!( + install_path(CompletionShell::Bash), + root.join(".local/share/bash-completion/completions/whispers") + ); + assert_eq!( + install_path(CompletionShell::Zsh), + root.join(".zfunc/_whispers") + ); + assert_eq!( + install_path(CompletionShell::Nushell), + root.join(".config/nushell/autoload/whispers.nu") + ); + } + + #[test] + fn install_completions_writes_script_to_target_path() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&[ + "HOME", + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + ]); + let root = crate::test_support::unique_temp_dir("completions-install"); + crate::test_support::set_env("HOME", &root.to_string_lossy()); + crate::test_support::remove_env("XDG_CONFIG_HOME"); + crate::test_support::remove_env("XDG_DATA_HOME"); + + let path = install_completions(CompletionShell::Fish).expect("install fish completions"); + let script = std::fs::read_to_string(&path).expect("read installed fish completions"); + assert!(script.contains("complete -c whispers")); + assert_eq!(path, root.join(".config/fish/completions/whispers.fish")); + } + + #[test] + fn detect_installed_shells_finds_supported_binaries_on_path() { + let root = crate::test_support::unique_temp_dir("completions-shell-detect"); + for binary in ["fish", "zsh", "nu"] { + let path = root.join(binary); + std::fs::write(path, "#!/bin/sh\n").expect("write fake shell"); + } + + let detected = detect_installed_shells_in_path(Some(root.as_os_str())); + assert_eq!( + detected, + vec![ + CompletionShell::Zsh, + CompletionShell::Fish, + CompletionShell::Nushell, + ] + ); + } + + #[test] + fn preferred_zsh_completion_dir_uses_active_user_fpath_over_fallback() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&["HOME"]); + let root = crate::test_support::unique_temp_dir("completions-zsh-fpath"); + crate::test_support::set_env("HOME", &root.to_string_lossy()); + + let fpath = format!( + "{}/.cache/zinit/completions\n{}/.local/share/zinit/completions\n/usr/share/zsh/site-functions\n", + root.display(), + root.display() + ); + assert_eq!( + preferred_zsh_completion_dir_from_fpath(&fpath), + Some(root.join(".local/share/zinit/completions")) + ); + } + + #[test] + fn ensure_bashrc_sources_completion_adds_guarded_block() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&["HOME"]); + let root = crate::test_support::unique_temp_dir("completions-bashrc"); + crate::test_support::set_env("HOME", &root.to_string_lossy()); + + let completion_path = root.join(".local/share/bash-completion/completions/whispers"); + std::fs::create_dir_all(completion_path.parent().expect("parent")).expect("mkdir"); + std::fs::write(&completion_path, "# test\n").expect("write completion"); + + ensure_bashrc_sources_completion(&completion_path).expect("update bashrc"); + let bashrc = std::fs::read_to_string(root.join(".bashrc")).expect("read bashrc"); + assert!(bashrc.contains("# >>> whispers bash completions >>>")); + assert!(bashrc.contains("source \"/")); + assert!(bashrc.contains("bash-completion/completions/whispers")); + } + + #[test] + fn ensure_nushell_config_uses_completion_adds_guarded_block() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&["HOME", "XDG_CONFIG_HOME"]); + let root = crate::test_support::unique_temp_dir("completions-nushell-config"); + crate::test_support::set_env("HOME", &root.to_string_lossy()); + crate::test_support::set_env("XDG_CONFIG_HOME", &root.join(".config").to_string_lossy()); + + let completion_path = root.join(".config/nushell/autoload/whispers.nu"); + std::fs::create_dir_all(completion_path.parent().expect("parent")).expect("mkdir"); + std::fs::write(&completion_path, "# test\n").expect("write completion"); + + ensure_nushell_config_uses_completion(&completion_path).expect("update nushell config"); + let config = + std::fs::read_to_string(root.join(".config/nushell/config.nu")).expect("read config"); + assert!(config.contains("# >>> whispers nushell completions >>>")); + assert!(config.contains("use \"")); + assert!(config.contains("autoload/whispers.nu")); + } } diff --git a/src/config.rs b/src/config.rs index e0c75b6..c45ee9c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -16,11 +16,13 @@ pub struct Config { pub session: SessionConfig, pub personalization: PersonalizationConfig, pub rewrite: RewriteConfig, - pub agentic_rewrite: AgenticRewriteConfig, + #[serde(default, rename = "agentic_rewrite")] + legacy_agentic_rewrite: LegacyAgenticRewriteConfig, pub cloud: CloudConfig, pub cleanup: CleanupConfig, pub inject: InjectConfig, pub feedback: FeedbackConfig, + pub voice: VoiceConfig, } #[derive(Debug, Clone, Deserialize)] @@ -82,8 +84,8 @@ pub struct PostprocessConfig { pub enum PostprocessMode { #[default] Raw, - AdvancedLocal, - AgenticRewrite, + #[serde(alias = "advanced_local", alias = "agentic_rewrite")] + Rewrite, LegacyBasic, } @@ -116,11 +118,14 @@ pub struct RewriteConfig { pub idle_timeout_ms: u64, pub max_output_chars: usize, pub max_tokens: usize, + pub policy_path: String, + pub glossary_path: String, + pub default_correction_policy: RewriteCorrectionPolicy, } #[derive(Debug, Clone, Deserialize, PartialEq, Eq)] #[serde(default)] -pub struct AgenticRewriteConfig { +struct LegacyAgenticRewriteConfig { pub policy_path: String, pub glossary_path: String, pub default_correction_policy: RewriteCorrectionPolicy, @@ -236,6 +241,18 @@ pub struct FeedbackConfig { pub stop_sound: String, } +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[serde(default)] +pub struct VoiceConfig { + pub live_inject: bool, + pub live_rewrite: bool, + pub partial_interval_ms: u64, + pub rewrite_interval_ms: u64, + pub context_window_ms: u64, + pub min_chunk_ms: u64, + pub freeze_on_focus_change: bool, +} + impl Default for AudioConfig { fn default() -> Self { Self { @@ -285,14 +302,13 @@ impl PostprocessMode { pub fn as_str(self) -> &'static str { match self { Self::Raw => "raw", - Self::AdvancedLocal => "advanced_local", - Self::AgenticRewrite => "agentic_rewrite", + Self::Rewrite => "rewrite", Self::LegacyBasic => "legacy_basic", } } pub fn uses_rewrite(self) -> bool { - matches!(self, Self::AdvancedLocal | Self::AgenticRewrite) + matches!(self, Self::Rewrite) } } @@ -363,13 +379,16 @@ impl Default for RewriteConfig { profile: RewriteProfile::Auto, timeout_ms: 30000, idle_timeout_ms: 120000, - max_output_chars: 1200, - max_tokens: 256, + max_output_chars: 8192, + max_tokens: 768, + policy_path: crate::agentic_rewrite::default_policy_path().into(), + glossary_path: crate::agentic_rewrite::default_glossary_path().into(), + default_correction_policy: RewriteCorrectionPolicy::Balanced, } } } -impl Default for AgenticRewriteConfig { +impl Default for LegacyAgenticRewriteConfig { fn default() -> Self { Self { policy_path: crate::agentic_rewrite::default_policy_path().into(), @@ -457,6 +476,20 @@ impl Default for FeedbackConfig { } } +impl Default for VoiceConfig { + fn default() -> Self { + Self { + live_inject: false, + live_rewrite: false, + partial_interval_ms: 400, + rewrite_interval_ms: 1400, + context_window_ms: 8000, + min_chunk_ms: 650, + freeze_on_focus_change: true, + } + } +} + impl TranscriptionConfig { pub fn resolved_local_backend(&self) -> TranscriptionBackend { match self.local_backend { @@ -490,6 +523,7 @@ impl Config { config.apply_legacy_transcription_migration(&contents, &config_path); config.apply_legacy_cleanup_migration(&contents, &config_path); + config.apply_legacy_rewrite_migration(&contents, &config_path); config.apply_cloud_sanitization(); Ok(config) } @@ -516,12 +550,20 @@ impl Config { PathBuf::from(expand_tilde(&self.personalization.snippets_path)) } + pub fn resolved_rewrite_policy_path(&self) -> PathBuf { + PathBuf::from(expand_tilde(&self.rewrite.policy_path)) + } + + pub fn resolved_rewrite_glossary_path(&self) -> PathBuf { + PathBuf::from(expand_tilde(&self.rewrite.glossary_path)) + } + pub fn resolved_agentic_policy_path(&self) -> PathBuf { - PathBuf::from(expand_tilde(&self.agentic_rewrite.policy_path)) + self.resolved_rewrite_policy_path() } pub fn resolved_agentic_glossary_path(&self) -> PathBuf { - PathBuf::from(expand_tilde(&self.agentic_rewrite.glossary_path)) + self.resolved_rewrite_glossary_path() } fn apply_legacy_transcription_migration(&mut self, contents: &str, config_path: &Path) { @@ -570,6 +612,44 @@ impl Config { } } + fn apply_legacy_rewrite_migration(&mut self, contents: &str, config_path: &Path) { + let Ok(doc) = contents.parse::() else { + return; + }; + + if !section_present(contents, "agentic_rewrite") { + return; + } + + let rewrite_has_policy = table_key_present(&doc, "rewrite", "policy_path"); + let rewrite_has_glossary = table_key_present(&doc, "rewrite", "glossary_path"); + let rewrite_has_default_policy = + table_key_present(&doc, "rewrite", "default_correction_policy"); + + if !rewrite_has_policy { + self.rewrite.policy_path = self.legacy_agentic_rewrite.policy_path.clone(); + } + if !rewrite_has_glossary { + self.rewrite.glossary_path = self.legacy_agentic_rewrite.glossary_path.clone(); + } + if !rewrite_has_default_policy { + self.rewrite.default_correction_policy = + self.legacy_agentic_rewrite.default_correction_policy; + } + + if rewrite_has_policy || rewrite_has_glossary || rewrite_has_default_policy { + tracing::warn!( + "config {} contains deprecated [agentic_rewrite]; [rewrite] takes precedence", + config_path.display() + ); + } else { + tracing::warn!( + "config {} uses deprecated [agentic_rewrite]; mapping to [rewrite]", + config_path.display() + ); + } + } + fn apply_cloud_sanitization(&mut self) { if self.transcription.local_backend == TranscriptionBackend::Cloud { tracing::warn!( @@ -688,7 +768,7 @@ flash_attn = true idle_timeout_ms = 120000 [postprocess] -# "raw" (default), "advanced_local", "agentic_rewrite", or "legacy_basic" for deprecated cleanup configs +# "raw" (default), "rewrite", or "legacy_basic" for deprecated cleanup configs mode = "raw" [session] @@ -714,7 +794,7 @@ snippet_trigger = "insert" backend = "local" # Cloud fallback behavior ("local" or "none") fallback = "local" -# Managed rewrite model name for advanced_local mode +# Managed rewrite model name for rewrite mode selected_model = "qwen-3.5-4b-q4_k_m" # Manual GGUF path override (empty = use selected managed model) # Custom rewrite models should be chat-capable GGUFs with an embedded @@ -729,14 +809,12 @@ timeout_ms = 30000 # How long the hidden rewrite worker stays warm without requests idle_timeout_ms = 120000 # Maximum characters accepted from the rewrite model -max_output_chars = 1200 +max_output_chars = 8192 # Maximum tokens to generate for rewritten output -max_tokens = 256 - -[agentic_rewrite] -# App-aware rewrite policy rules used by postprocess.mode = "agentic_rewrite" +max_tokens = 768 +# App-aware rewrite policy rules used by postprocess.mode = "rewrite" policy_path = "~/.local/share/whispers/app-rewrite-policy.toml" -# Technical glossary used by postprocess.mode = "agentic_rewrite" +# Technical glossary used by postprocess.mode = "rewrite" glossary_path = "~/.local/share/whispers/technical-glossary.toml" # Default correction policy ("conservative", "balanced", or "aggressive") default_correction_policy = "balanced" @@ -777,6 +855,22 @@ enabled = true # Custom sound file paths (empty = use bundled sounds) start_sound = "" stop_sound = "" + +[voice] +# Experimental live voice mode: mutate the target app while recording +live_inject = false +# Experimental live preview rewrite line in the OSD +live_rewrite = false +# How often to refresh the live ASR preview while recording +partial_interval_ms = 400 +# Minimum time between live rewrite preview updates +rewrite_interval_ms = 1400 +# Audio tail window retranscribed for each live ASR refresh +context_window_ms = 8000 +# Minimum recorded audio before live ASR starts updating +min_chunk_ms = 650 +# Freeze live injection if focus changes during the session +freeze_on_focus_change = true "# ); @@ -837,6 +931,7 @@ pub fn update_config_postprocess_mode(config_path: &Path, mode: PostprocessMode) .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; ensure_standard_postprocess_tables(&mut doc); + doc.as_table_mut().remove("agentic_rewrite"); doc["postprocess"]["mode"] = toml_edit::value(mode.as_str()); std::fs::write(config_path, doc.to_string()) @@ -854,11 +949,8 @@ pub fn update_config_rewrite_selection(config_path: &Path, selected_model: &str) .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; ensure_standard_postprocess_tables(&mut doc); - let mode = match doc["postprocess"]["mode"].as_str() { - Some("agentic_rewrite") => PostprocessMode::AgenticRewrite, - _ => PostprocessMode::AdvancedLocal, - }; - doc["postprocess"]["mode"] = toml_edit::value(mode.as_str()); + doc.as_table_mut().remove("agentic_rewrite"); + doc["postprocess"]["mode"] = toml_edit::value(PostprocessMode::Rewrite.as_str()); let rewrite_backend = doc["rewrite"] .as_table_like() .and_then(|table| table.get("backend")) @@ -874,13 +966,12 @@ pub fn update_config_rewrite_selection(config_path: &Path, selected_model: &str) doc["rewrite"]["profile"] = toml_edit::value(RewriteProfile::Auto.as_str()); doc["rewrite"]["timeout_ms"] = toml_edit::value(30000); doc["rewrite"]["idle_timeout_ms"] = toml_edit::value(120000); - doc["rewrite"]["max_output_chars"] = toml_edit::value(1200); - doc["rewrite"]["max_tokens"] = toml_edit::value(256); - doc["agentic_rewrite"]["policy_path"] = - toml_edit::value(crate::agentic_rewrite::default_policy_path()); - doc["agentic_rewrite"]["glossary_path"] = + doc["rewrite"]["max_output_chars"] = toml_edit::value(8192); + doc["rewrite"]["max_tokens"] = toml_edit::value(768); + doc["rewrite"]["policy_path"] = toml_edit::value(crate::agentic_rewrite::default_policy_path()); + doc["rewrite"]["glossary_path"] = toml_edit::value(crate::agentic_rewrite::default_glossary_path()); - doc["agentic_rewrite"]["default_correction_policy"] = + doc["rewrite"]["default_correction_policy"] = toml_edit::value(RewriteCorrectionPolicy::Balanced.as_str()); std::fs::write(config_path, doc.to_string()) @@ -964,25 +1055,46 @@ pub fn update_config_cloud_settings( Ok(()) } +pub fn update_config_voice_settings(config_path: &Path, voice: &VoiceConfig) -> Result<()> { + let contents = std::fs::read_to_string(config_path) + .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; + let mut doc = contents + .parse::() + .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; + + ensure_standard_postprocess_tables(&mut doc); + doc["voice"]["live_inject"] = toml_edit::value(voice.live_inject); + doc["voice"]["live_rewrite"] = toml_edit::value(voice.live_rewrite); + doc["voice"]["partial_interval_ms"] = toml_edit::value(voice.partial_interval_ms as i64); + doc["voice"]["rewrite_interval_ms"] = toml_edit::value(voice.rewrite_interval_ms as i64); + doc["voice"]["context_window_ms"] = toml_edit::value(voice.context_window_ms as i64); + doc["voice"]["min_chunk_ms"] = toml_edit::value(voice.min_chunk_ms as i64); + doc["voice"]["freeze_on_focus_change"] = toml_edit::value(voice.freeze_on_focus_change); + + std::fs::write(config_path, doc.to_string()) + .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; + Ok(()) +} + fn ensure_standard_postprocess_tables(doc: &mut toml_edit::DocumentMut) { ensure_root_table(doc, "transcription"); ensure_root_table(doc, "postprocess"); ensure_root_table(doc, "session"); ensure_root_table(doc, "rewrite"); - ensure_root_table(doc, "agentic_rewrite"); ensure_root_table(doc, "cloud"); ensure_nested_table(doc, "cloud", "transcription"); ensure_nested_table(doc, "cloud", "rewrite"); ensure_root_table(doc, "personalization"); + ensure_root_table(doc, "voice"); } fn normalize_postprocess_mode(doc: &mut toml_edit::DocumentMut) { let current = doc["postprocess"]["mode"].as_str().unwrap_or_default(); if !matches!( current, - "raw" | "advanced_local" | "agentic_rewrite" | "legacy_basic" + "raw" | "rewrite" | "advanced_local" | "agentic_rewrite" | "legacy_basic" ) { - doc["postprocess"]["mode"] = toml_edit::value(PostprocessMode::AdvancedLocal.as_str()); + doc["postprocess"]["mode"] = toml_edit::value(PostprocessMode::Rewrite.as_str()); } } @@ -1017,6 +1129,13 @@ fn cleanup_section_present(contents: &str) -> bool { section_present(contents, "cleanup") } +fn table_key_present(doc: &toml_edit::DocumentMut, table: &str, key: &str) -> bool { + doc[table] + .as_table_like() + .and_then(|table| table.get(key)) + .is_some() +} + fn section_present(contents: &str, name: &str) -> bool { toml::from_str::(contents) .ok() @@ -1047,6 +1166,7 @@ mod tests { assert_eq!(config.postprocess.mode, PostprocessMode::Raw); assert_eq!(config.personalization.snippet_trigger, "insert"); assert_eq!(config.rewrite.selected_model, "qwen-3.5-4b-q4_k_m"); + assert_eq!(config.voice, VoiceConfig::default()); } #[test] @@ -1122,6 +1242,7 @@ mod tests { assert!(raw.contains("[postprocess]")); assert!(raw.contains("[session]")); assert!(raw.contains("[rewrite]")); + assert!(raw.contains("[voice]")); assert!(!raw.contains("[whisper]")); } @@ -1191,7 +1312,7 @@ remove_fillers = false } #[test] - fn update_rewrite_selection_enables_advanced_mode() { + fn update_rewrite_selection_enables_rewrite_mode() { let dir = crate::test_support::unique_temp_dir("config-rewrite-select"); let config_path = dir.join("config.toml"); write_default_config(&config_path, "~/model.bin").expect("write config"); @@ -1200,7 +1321,7 @@ remove_fillers = false .expect("select rewrite model"); let loaded = Config::load(Some(&config_path)).expect("load config"); - assert_eq!(loaded.postprocess.mode, PostprocessMode::AdvancedLocal); + assert_eq!(loaded.postprocess.mode, PostprocessMode::Rewrite); assert_eq!(loaded.rewrite.selected_model, "qwen-3.5-2b-q4_k_m"); assert!(loaded.rewrite.model_path.is_empty()); assert_eq!( @@ -1244,13 +1365,34 @@ language = "auto" TranscriptionBackend::WhisperCpp ); assert_eq!(loaded.transcription.selected_model, "large-v3-turbo"); - assert_eq!(loaded.postprocess.mode, PostprocessMode::AdvancedLocal); + assert_eq!(loaded.postprocess.mode, PostprocessMode::Rewrite); assert_eq!(loaded.rewrite.selected_model, "qwen-3.5-4b-q4_k_m"); let raw = std::fs::read_to_string(&config_path).expect("read upgraded config"); assert!(!raw.contains("[whisper]")); } + #[test] + fn update_voice_settings_roundtrip() { + let config_path = crate::test_support::unique_temp_path("config-voice-settings", "toml"); + write_default_config(&config_path, "~/model.bin").expect("write config"); + + let voice = VoiceConfig { + live_inject: true, + live_rewrite: true, + partial_interval_ms: 900, + rewrite_interval_ms: 2500, + context_window_ms: 15000, + min_chunk_ms: 1800, + freeze_on_focus_change: false, + }; + + update_config_voice_settings(&config_path, &voice).expect("update voice settings"); + + let loaded = Config::load(Some(&config_path)).expect("load config"); + assert_eq!(loaded.voice, voice); + } + #[test] fn load_cloud_literal_key_from_legacy_api_key_env() { let path = crate::test_support::unique_temp_path("config-cloud-literal-key", "toml"); diff --git a/src/faster_whisper.rs b/src/faster_whisper.rs index 9c01336..1da5c35 100644 --- a/src/faster_whisper.rs +++ b/src/faster_whisper.rs @@ -265,6 +265,19 @@ impl FasterWhisperService { } pub async fn transcribe(&self, audio: &[f32], sample_rate: u32) -> Result { + self.transcribe_mode(audio, sample_rate, false).await + } + + pub async fn transcribe_live(&self, audio: &[f32], sample_rate: u32) -> Result { + self.transcribe_mode(audio, sample_rate, true).await + } + + async fn transcribe_mode( + &self, + audio: &[f32], + sample_rate: u32, + live: bool, + ) -> Result { let timeout = Duration::from_millis(60_000); self.ensure_running(timeout).await?; @@ -290,6 +303,7 @@ impl FasterWhisperService { let mut payload = serde_json::to_vec(&AsrRequest::Transcribe { audio_f32_b64: base64::engine::general_purpose::STANDARD.encode(audio_bytes), sample_rate, + live, }) .map_err(|e| WhsprError::Transcription(format!("failed to encode ASR request: {e}")))?; payload.push(b'\n'); diff --git a/src/faster_whisper_worker.py b/src/faster_whisper_worker.py index d3c8731..fb8491c 100644 --- a/src/faster_whisper_worker.py +++ b/src/faster_whisper_worker.py @@ -14,6 +14,9 @@ BEAM_SIZE = 3 BEST_OF = 3 CONDITION_ON_PREVIOUS_TEXT_MIN_SECONDS = 6.0 +LIVE_BEAM_SIZE = 1 +LIVE_BEST_OF = 1 +LIVE_MIN_DURATION_SECONDS = 0.25 def cmd_download(repo_id: str, model_dir: str) -> int: @@ -30,21 +33,45 @@ def make_model(model_dir: str, device: str, compute_type: str) -> WhisperModel: return WhisperModel(model_dir, device=device, compute_type=compute_type) -def transcribe(model: WhisperModel, audio_f32_b64: str, sample_rate: int, language: str) -> dict: +def empty_transcript(language: str) -> dict: + detected_language = None if language == "auto" else language + return { + "type": "transcript", + "transcript": { + "raw_text": "", + "detected_language": detected_language, + "segments": [], + }, + } + + +def transcribe( + model: WhisperModel, + audio_f32_b64: str, + sample_rate: int, + language: str, + live: bool, +) -> dict: if sample_rate != 16000: raise ValueError(f"unsupported sample rate {sample_rate}; expected 16000") audio_bytes = base64.b64decode(audio_f32_b64.encode("ascii")) audio = np.frombuffer(audio_bytes, dtype=np.float32) duration_seconds = float(audio.size) / float(sample_rate) + if live: + if duration_seconds < LIVE_MIN_DURATION_SECONDS: + return empty_transcript(language) + segments_iter, info = model.transcribe( audio, language=None if language == "auto" else language, task="transcribe", - beam_size=BEAM_SIZE, - best_of=BEST_OF, + beam_size=LIVE_BEAM_SIZE if live else BEAM_SIZE, + best_of=LIVE_BEST_OF if live else BEST_OF, condition_on_previous_text=( - duration_seconds >= CONDITION_ON_PREVIOUS_TEXT_MIN_SECONDS + False + if live + else duration_seconds >= CONDITION_ON_PREVIOUS_TEXT_MIN_SECONDS ), vad_filter=False, word_timestamps=False, @@ -93,6 +120,7 @@ def handle_connection(conn: socket.socket, model: WhisperModel, language: str) - request["audio_f32_b64"], int(request["sample_rate"]), language, + bool(request.get("live", False)), ) except Exception as exc: # noqa: BLE001 response = {"type": "error", "message": str(exc)} diff --git a/src/inject.rs b/src/inject.rs index 4fd91d2..51abf16 100644 --- a/src/inject.rs +++ b/src/inject.rs @@ -1,14 +1,36 @@ use std::process::{Command, Stdio}; +use std::sync::{Mutex, OnceLock}; use std::time::Duration; use evdev::uinput::VirtualDevice; use evdev::{AttributeSet, EventType, InputEvent, KeyCode}; +use crate::context::{SurfaceKind, TypingContext}; use crate::error::{Result, WhsprError}; -const DEVICE_READY_DELAY: Duration = Duration::from_millis(120); -const CLIPBOARD_READY_DELAY: Duration = Duration::from_millis(180); -const POST_DELETE_SETTLE_DELAY: Duration = Duration::from_millis(30); +const DEVICE_READY_DELAY: Duration = Duration::from_millis(45); +const PASTE_KEY_DELAY: Duration = Duration::from_millis(4); + +static INJECT_DEVICE: OnceLock>> = OnceLock::new(); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PasteShortcut { + CtrlV, + CtrlShiftV, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct InjectionPolicy { + paste_shortcut: PasteShortcut, + surface_label: &'static str, + backspace_key_delay: Duration, + backspace_burst_len: usize, + backspace_burst_pause: Duration, + clipboard_ready_delay: Duration, + post_delete_settle_delay: Duration, + live_destructive_delete_limit: usize, + destructive_correction_confirmations: usize, +} pub struct TextInjector { wl_copy_bin: String, @@ -31,7 +53,7 @@ impl TextInjector { } } - pub async fn inject(&self, text: &str) -> Result<()> { + pub async fn inject(&self, text: &str, context: &TypingContext) -> Result<()> { if text.is_empty() { tracing::warn!("empty text, nothing to inject"); return Ok(()); @@ -41,48 +63,154 @@ impl TextInjector { let text_len = text.len(); let wl_copy_bin = self.wl_copy_bin.clone(); let wl_copy_args = self.wl_copy_args.clone(); - tokio::task::spawn_blocking(move || inject_sync(&wl_copy_bin, &wl_copy_args, &text)) - .await - .map_err(|e| WhsprError::Injection(format!("injection task panicked: {e}")))??; + let policy = InjectionPolicy::for_context(context); + tokio::task::spawn_blocking(move || { + inject_sync(&wl_copy_bin, &wl_copy_args, &text, policy) + }) + .await + .map_err(|e| WhsprError::Injection(format!("injection task panicked: {e}")))??; - tracing::info!("injected {} chars via wl-copy + Ctrl+Shift+V", text_len); + tracing::info!( + paste_shortcut = policy.paste_shortcut().label(), + surface_policy = policy.label(), + "injected {} chars via wl-copy + paste shortcut", + text_len + ); Ok(()) } - pub async fn replace_recent_text(&self, delete_graphemes: usize, text: &str) -> Result<()> { + pub async fn replace_recent_text( + &self, + delete_graphemes: usize, + text: &str, + context: &TypingContext, + ) -> Result<()> { if delete_graphemes == 0 { - return self.inject(text).await; + return self.inject(text, context).await; } let text = text.to_string(); let wl_copy_bin = self.wl_copy_bin.clone(); let wl_copy_args = self.wl_copy_args.clone(); + let policy = InjectionPolicy::for_context(context); tokio::task::spawn_blocking(move || { - replace_recent_text_sync(&wl_copy_bin, &wl_copy_args, delete_graphemes, &text) + replace_recent_text_sync(&wl_copy_bin, &wl_copy_args, delete_graphemes, &text, policy) }) .await .map_err(|e| WhsprError::Injection(format!("replace task panicked: {e}")))??; tracing::info!( - "replaced {} graphemes via backspace + wl-copy paste", - delete_graphemes + delete_graphemes, + paste_shortcut = policy.paste_shortcut().label(), + surface_policy = policy.label(), + "replaced graphemes via backspace + wl-copy paste" ); Ok(()) } } -fn inject_sync(wl_copy_bin: &str, wl_copy_args: &[String], text: &str) -> Result<()> { - let mut device = build_virtual_device()?; +impl InjectionPolicy { + pub(crate) fn for_context(context: &TypingContext) -> Self { + match context.surface_kind { + SurfaceKind::Terminal => Self { + paste_shortcut: PasteShortcut::CtrlShiftV, + surface_label: "terminal", + backspace_key_delay: Duration::from_millis(2), + backspace_burst_len: 48, + backspace_burst_pause: Duration::from_millis(4), + clipboard_ready_delay: Duration::from_millis(50), + post_delete_settle_delay: Duration::from_millis(6), + live_destructive_delete_limit: usize::MAX, + destructive_correction_confirmations: 2, + }, + SurfaceKind::Editor => Self { + paste_shortcut: PasteShortcut::CtrlV, + surface_label: "editor", + backspace_key_delay: Duration::from_millis(3), + backspace_burst_len: 32, + backspace_burst_pause: Duration::from_millis(6), + clipboard_ready_delay: Duration::from_millis(55), + post_delete_settle_delay: Duration::from_millis(8), + live_destructive_delete_limit: 24, + destructive_correction_confirmations: 2, + }, + SurfaceKind::Browser => Self { + paste_shortcut: PasteShortcut::CtrlV, + surface_label: "browser", + backspace_key_delay: Duration::from_millis(5), + backspace_burst_len: 16, + backspace_burst_pause: Duration::from_millis(12), + clipboard_ready_delay: Duration::from_millis(70), + post_delete_settle_delay: Duration::from_millis(12), + live_destructive_delete_limit: 12, + destructive_correction_confirmations: 3, + }, + SurfaceKind::GenericText => Self { + paste_shortcut: PasteShortcut::CtrlV, + surface_label: "generic_text", + backspace_key_delay: Duration::from_millis(5), + backspace_burst_len: 12, + backspace_burst_pause: Duration::from_millis(14), + clipboard_ready_delay: Duration::from_millis(75), + post_delete_settle_delay: Duration::from_millis(14), + live_destructive_delete_limit: 8, + destructive_correction_confirmations: 2, + }, + SurfaceKind::Unknown => Self { + paste_shortcut: PasteShortcut::CtrlV, + surface_label: "unknown", + backspace_key_delay: Duration::from_millis(6), + backspace_burst_len: 10, + backspace_burst_pause: Duration::from_millis(16), + clipboard_ready_delay: Duration::from_millis(80), + post_delete_settle_delay: Duration::from_millis(16), + live_destructive_delete_limit: 0, + destructive_correction_confirmations: usize::MAX, + }, + } + } - run_wl_copy(wl_copy_bin, wl_copy_args, text)?; + pub(crate) fn destructive_correction_confirmations(self) -> usize { + self.destructive_correction_confirmations + } - // Wait for compositor to process the clipboard offer. - // The uinput device was created above, so it has already been - // registering during the wl-copy write. - std::thread::sleep(CLIPBOARD_READY_DELAY); - emit_paste_combo(&mut device)?; + pub(crate) fn allows_live_destructive_correction(self, delete_graphemes: usize) -> bool { + self.live_destructive_delete_limit > 0 + && delete_graphemes <= self.live_destructive_delete_limit + } - Ok(()) + pub(crate) fn label(self) -> &'static str { + self.surface_label + } + + fn paste_shortcut(self) -> PasteShortcut { + self.paste_shortcut + } +} + +impl PasteShortcut { + fn label(self) -> &'static str { + match self { + Self::CtrlV => "Ctrl+V", + Self::CtrlShiftV => "Ctrl+Shift+V", + } + } +} + +fn inject_sync( + wl_copy_bin: &str, + wl_copy_args: &[String], + text: &str, + policy: InjectionPolicy, +) -> Result<()> { + with_virtual_device(|device, created| { + if created { + std::thread::sleep(DEVICE_READY_DELAY); + } + run_wl_copy(wl_copy_bin, wl_copy_args, text)?; + std::thread::sleep(policy.clipboard_ready_delay); + emit_paste_combo(device, policy.paste_shortcut()) + }) } fn replace_recent_text_sync( @@ -90,30 +218,50 @@ fn replace_recent_text_sync( wl_copy_args: &[String], delete_graphemes: usize, text: &str, + policy: InjectionPolicy, ) -> Result<()> { - let mut device = build_virtual_device()?; - // Unlike plain injection, replacement can try to backspace immediately - // after creating the uinput device. Give the compositor a moment to - // register it first so the initial backspaces are not dropped. - std::thread::sleep(DEVICE_READY_DELAY); - emit_backspaces(&mut device, delete_graphemes)?; - - if !text.is_empty() { - std::thread::sleep(POST_DELETE_SETTLE_DELAY); - run_wl_copy(wl_copy_bin, wl_copy_args, text)?; - std::thread::sleep(CLIPBOARD_READY_DELAY); - emit_paste_combo(&mut device)?; - } + with_virtual_device(|device, created| { + if created { + // The first use needs a short registration window so the compositor + // doesn't miss the initial backspaces. + std::thread::sleep(DEVICE_READY_DELAY); + } + emit_backspaces(device, delete_graphemes, policy)?; - Ok(()) + if !text.is_empty() { + std::thread::sleep(policy.post_delete_settle_delay); + run_wl_copy(wl_copy_bin, wl_copy_args, text)?; + std::thread::sleep(policy.clipboard_ready_delay); + emit_paste_combo(device, policy.paste_shortcut())?; + } + + Ok(()) + }) +} + +fn with_virtual_device(f: impl FnOnce(&mut VirtualDevice, bool) -> Result) -> Result { + let device_store = INJECT_DEVICE.get_or_init(|| Mutex::new(None)); + let mut guard = device_store + .lock() + .map_err(|_| WhsprError::Injection("uinput device lock poisoned".into()))?; + let created = if guard.is_none() { + *guard = Some(build_virtual_device()?); + true + } else { + false + }; + let device = guard + .as_mut() + .ok_or_else(|| WhsprError::Injection("uinput device unavailable".into()))?; + f(device, created) } fn build_virtual_device() -> Result { let mut keys = AttributeSet::::new(); keys.insert(KeyCode::KEY_LEFTCTRL); keys.insert(KeyCode::KEY_LEFTSHIFT); - keys.insert(KeyCode::KEY_V); keys.insert(KeyCode::KEY_BACKSPACE); + keys.insert(KeyCode::KEY_V); VirtualDevice::builder() .map_err(|e| WhsprError::Injection(format!("uinput: {e}")))? @@ -179,14 +327,23 @@ fn run_wl_copy_with_timeout( Ok(()) } -fn emit_paste_combo(device: &mut VirtualDevice) -> Result<()> { +fn emit_paste_combo(device: &mut VirtualDevice, shortcut: PasteShortcut) -> Result<()> { + let mut modifier_events = vec![InputEvent::new( + EventType::KEY.0, + KeyCode::KEY_LEFTCTRL.0, + 1, + )]; + if matches!(shortcut, PasteShortcut::CtrlShiftV) { + modifier_events.push(InputEvent::new( + EventType::KEY.0, + KeyCode::KEY_LEFTSHIFT.0, + 1, + )); + } device - .emit(&[ - InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTCTRL.0, 1), - InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTSHIFT.0, 1), - ]) + .emit(&modifier_events) .map_err(|e| WhsprError::Injection(format!("paste modifier press: {e}")))?; - std::thread::sleep(Duration::from_millis(12)); + std::thread::sleep(PASTE_KEY_DELAY); device .emit(&[ @@ -194,27 +351,48 @@ fn emit_paste_combo(device: &mut VirtualDevice) -> Result<()> { InputEvent::new(EventType::KEY.0, KeyCode::KEY_V.0, 0), ]) .map_err(|e| WhsprError::Injection(format!("paste key press: {e}")))?; - std::thread::sleep(Duration::from_millis(12)); - + std::thread::sleep(PASTE_KEY_DELAY); + + let mut release_events = Vec::new(); + if matches!(shortcut, PasteShortcut::CtrlShiftV) { + release_events.push(InputEvent::new( + EventType::KEY.0, + KeyCode::KEY_LEFTSHIFT.0, + 0, + )); + } + release_events.push(InputEvent::new( + EventType::KEY.0, + KeyCode::KEY_LEFTCTRL.0, + 0, + )); device - .emit(&[ - InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTSHIFT.0, 0), - InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTCTRL.0, 0), - ]) + .emit(&release_events) .map_err(|e| WhsprError::Injection(format!("paste modifier release: {e}")))?; Ok(()) } -fn emit_backspaces(device: &mut VirtualDevice, count: usize) -> Result<()> { - for _ in 0..count { +fn emit_backspaces( + device: &mut VirtualDevice, + count: usize, + policy: InjectionPolicy, +) -> Result<()> { + for index in 0..count { device .emit(&[ InputEvent::new(EventType::KEY.0, KeyCode::KEY_BACKSPACE.0, 1), InputEvent::new(EventType::KEY.0, KeyCode::KEY_BACKSPACE.0, 0), ]) .map_err(|e| WhsprError::Injection(format!("backspace key press: {e}")))?; - std::thread::sleep(Duration::from_millis(6)); + std::thread::sleep(policy.backspace_key_delay); + let next = index + 1; + if next < count + && next % policy.backspace_burst_len == 0 + && !policy.backspace_burst_pause.is_zero() + { + std::thread::sleep(policy.backspace_burst_pause); + } } Ok(()) @@ -225,6 +403,17 @@ mod tests { use super::*; use crate::error::WhsprError; + fn context(surface_kind: SurfaceKind) -> TypingContext { + TypingContext { + focus_fingerprint: "focus".into(), + app_id: Some("app".into()), + window_title: Some("window".into()), + surface_kind, + browser_domain: None, + captured_at_ms: 0, + } + } + #[test] fn run_wl_copy_reports_spawn_failure() { let err = run_wl_copy("/definitely/missing/wl-copy", &[], "hello") @@ -273,6 +462,32 @@ mod tests { #[tokio::test] async fn inject_empty_text_is_noop() { let injector = TextInjector::with_wl_copy_command("/bin/true", &[]); - injector.inject("").await.expect("empty text should no-op"); + injector + .inject("", &TypingContext::unknown()) + .await + .expect("empty text should no-op"); + } + + #[test] + fn terminal_policy_uses_terminal_paste_shortcut() { + let policy = InjectionPolicy::for_context(&context(SurfaceKind::Terminal)); + assert_eq!(policy.paste_shortcut(), PasteShortcut::CtrlShiftV); + assert!(policy.allows_live_destructive_correction(64)); + assert_eq!(policy.destructive_correction_confirmations(), 2); + } + + #[test] + fn unknown_policy_disables_live_destructive_corrections() { + let policy = InjectionPolicy::for_context(&context(SurfaceKind::Unknown)); + assert_eq!(policy.paste_shortcut(), PasteShortcut::CtrlV); + assert!(!policy.allows_live_destructive_correction(1)); + } + + #[test] + fn browser_policy_requires_more_confirmation_and_smaller_live_rewrites() { + let policy = InjectionPolicy::for_context(&context(SurfaceKind::Browser)); + assert_eq!(policy.destructive_correction_confirmations(), 3); + assert!(policy.allows_live_destructive_correction(12)); + assert!(!policy.allows_live_destructive_correction(13)); } } diff --git a/src/main.rs b/src/main.rs index 046bbd0..13257cb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ mod asr; mod asr_model; mod asr_protocol; mod audio; +mod branding; mod cleanup; mod cli; mod cloud; @@ -17,6 +18,8 @@ mod file_audio; mod inject; mod model; mod nemo_asr; +mod osd; +mod osd_protocol; mod personalization; mod postprocess; mod rewrite; @@ -26,10 +29,12 @@ mod rewrite_protocol; mod rewrite_worker; mod session; mod setup; +mod status; #[cfg(test)] mod test_support; mod transcribe; mod ui; +mod voice; use std::path::{Path, PathBuf}; @@ -262,6 +267,20 @@ async fn run_default(cli: &Cli) -> crate::error::Result<()> { app::run(config).await } +async fn run_voice(cli: &Cli) -> crate::error::Result<()> { + let Some(_pid_lock) = acquire_or_signal_lock()? else { + return Ok(()); + }; + + tracing::info!("whispers v{} (voice mode)", env!("CARGO_PKG_VERSION")); + + let config = Config::load(cli.config.as_deref())?; + asr::validate_transcription_config(&config)?; + tracing::debug!("config loaded: {config:?}"); + + voice::run(config).await +} + #[tokio::main] async fn main() -> crate::error::Result<()> { let cli = Cli::parse(); @@ -270,6 +289,8 @@ async fn main() -> crate::error::Result<()> { match &cli.command { None => run_default(&cli).await, + Some(Command::Status) => status::print_status(cli.config.as_deref()), + Some(Command::Voice) => run_voice(&cli).await, Some(Command::Completions { shell }) => completions::run_completions(*shell), Some(Command::Setup) => setup::run_setup(cli.config.as_deref()).await, Some(Command::Transcribe { file, output, raw }) => { diff --git a/src/nemo_asr.rs b/src/nemo_asr.rs index aa164ab..1a9ad5d 100644 --- a/src/nemo_asr.rs +++ b/src/nemo_asr.rs @@ -372,6 +372,7 @@ impl NemoAsrService { let mut payload = serde_json::to_vec(&AsrRequest::Transcribe { audio_f32_b64: base64::engine::general_purpose::STANDARD.encode(audio_bytes), sample_rate, + live: true, }) .map_err(|e| WhsprError::Transcription(format!("failed to encode ASR request: {e}")))?; payload.push(b'\n'); diff --git a/src/osd.rs b/src/osd.rs new file mode 100644 index 0000000..f939384 --- /dev/null +++ b/src/osd.rs @@ -0,0 +1,117 @@ +#[cfg(feature = "osd")] +use std::io::Write; +use std::process::Child; + +#[cfg(feature = "osd")] +use std::process::{ChildStdin, Command, Stdio}; + +#[cfg(feature = "osd")] +use crate::branding; +#[cfg(feature = "osd")] +use crate::osd_protocol::OsdEvent; + +pub enum OsdMode { + Meter, + Voice, +} + +pub struct OsdHandle { + child: Option, + #[cfg(feature = "osd")] + stdin: Option, +} + +impl OsdHandle { + pub fn spawn(mode: OsdMode) -> Self { + #[cfg(feature = "osd")] + { + let osd_path = branding::resolve_sidecar_executable(&[branding::OSD_BINARY]); + let mut command = Command::new(&osd_path); + if matches!(mode, OsdMode::Voice) { + command.arg("--voice").stdin(Stdio::piped()); + } + + match command.spawn() { + Ok(mut child) => { + tracing::debug!("spawned whispers-osd (pid {})", child.id()); + return Self { + stdin: child.stdin.take(), + child: Some(child), + }; + } + Err(e) => { + tracing::warn!( + "failed to spawn whispers-osd from {}: {e}", + osd_path.display() + ); + } + } + } + + let _ = mode; + Self { + child: None, + #[cfg(feature = "osd")] + stdin: None, + } + } + + pub fn send_voice_update(&mut self, update: &crate::osd_protocol::VoiceOsdUpdate) { + #[cfg(feature = "osd")] + { + let Some(stdin) = self.stdin.as_mut() else { + return; + }; + let Ok(payload) = serde_json::to_string(&OsdEvent::VoiceUpdate(update.clone())) else { + tracing::warn!("failed to encode voice OSD update"); + return; + }; + if let Err(err) = stdin.write_all(payload.as_bytes()) { + tracing::warn!("failed to write voice OSD update: {err}"); + return; + } + if let Err(err) = stdin.write_all(b"\n") { + tracing::warn!("failed to terminate voice OSD update: {err}"); + return; + } + if let Err(err) = stdin.flush() { + tracing::warn!("failed to flush voice OSD update: {err}"); + } + } + + let _ = update; + } + + pub fn kill(&mut self) { + if let Some(mut child) = self.child.take() { + let pid = child.id() as libc::pid_t; + unsafe { + libc::kill(pid, libc::SIGTERM); + } + let _ = child.wait(); + tracing::debug!("whispers-osd (pid {pid}) terminated"); + } + } +} + +impl Drop for OsdHandle { + fn drop(&mut self) { + self.kill(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn kill_without_child_is_noop() { + let mut handle = OsdHandle { + child: None, + #[cfg(feature = "osd")] + stdin: None, + }; + handle.kill(); + assert!(handle.child.is_none()); + } +} diff --git a/src/osd_protocol.rs b/src/osd_protocol.rs new file mode 100644 index 0000000..140732d --- /dev/null +++ b/src/osd_protocol.rs @@ -0,0 +1,49 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum VoiceOsdStatus { + #[default] + Listening, + Transcribing, + Rewriting, + Finalizing, + Frozen, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +pub struct VoiceOsdUpdate { + pub status: VoiceOsdStatus, + pub stable_text: String, + pub unstable_text: String, + pub rewrite_preview: Option, + pub live_inject: bool, + pub frozen: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum OsdEvent { + VoiceUpdate(VoiceOsdUpdate), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn voice_update_roundtrips_as_json() { + let event = OsdEvent::VoiceUpdate(VoiceOsdUpdate { + status: VoiceOsdStatus::Rewriting, + stable_text: "hello".into(), + unstable_text: "world".into(), + rewrite_preview: Some("Hello world.".into()), + live_inject: true, + frozen: false, + }); + + let json = serde_json::to_string(&event).expect("serialize event"); + let parsed: OsdEvent = serde_json::from_str(&json).expect("deserialize event"); + assert_eq!(parsed, event); + } +} diff --git a/src/personalization.rs b/src/personalization.rs index 15ffb10..3532561 100644 --- a/src/personalization.rs +++ b/src/personalization.rs @@ -6,11 +6,11 @@ use crate::cleanup; use crate::config::Config; use crate::error::{Result, WhsprError}; use crate::rewrite_protocol::{ - RewriteCandidate, RewriteCandidateKind, RewriteEditAction, RewriteEditHypothesis, - RewriteEditHypothesisMatchSource, RewriteEditIntent, RewriteEditSignal, RewriteEditSignalKind, - RewriteEditSignalScope, RewriteEditSignalStrength, RewriteIntentConfidence, - RewritePolicyContext, RewriteReplacementScope, RewriteTailShape, RewriteTranscript, - RewriteTranscriptSegment, + RewriteCandidate, RewriteCandidateKind, RewriteEditAction, RewriteEditContext, + RewriteEditHypothesis, RewriteEditHypothesisMatchSource, RewriteEditIntent, RewriteEditSignal, + RewriteEditSignalKind, RewriteEditSignalScope, RewriteEditSignalStrength, + RewriteIntentConfidence, RewritePolicyContext, RewriteReplacementScope, RewriteTailShape, + RewriteTranscript, RewriteTranscriptSegment, }; use crate::transcribe::Transcript; @@ -153,8 +153,12 @@ pub fn build_rewrite_transcript( &analysis.edit_hypotheses, rules, ); - let recommended_candidate = - recommended_candidate(&rewrite_candidates, &analysis.edit_hypotheses); + let edit_context = derive_edit_context(&transcript.raw_text, &analysis.edit_hypotheses); + let recommended_candidate = recommended_candidate( + &rewrite_candidates, + &analysis.edit_hypotheses, + &edit_context, + ); RewriteTranscript { raw_text, @@ -221,6 +225,7 @@ pub fn build_rewrite_transcript( edit_hypotheses, rewrite_candidates, recommended_candidate, + edit_context, policy_context: RewritePolicyContext::default(), } } @@ -386,7 +391,10 @@ fn build_rewrite_candidates( } } - if has_strong_explicit_hypothesis(edit_hypotheses) { + let edit_context = derive_edit_context(raw_text, edit_hypotheses); + if has_strong_explicit_hypothesis(edit_hypotheses) + && !(edit_context.cue_is_utterance_initial && edit_context.courtesy_prefix_detected) + { candidates.sort_by_key(|candidate| candidate_priority(candidate.kind)); } @@ -670,12 +678,70 @@ fn normalize_candidate_spacing(text: &str) -> Option { fn recommended_candidate( rewrite_candidates: &[RewriteCandidate], edit_hypotheses: &[cleanup::EditHypothesis], + edit_context: &RewriteEditContext, ) -> Option { + if edit_context.cue_is_utterance_initial && edit_context.courtesy_prefix_detected { + return None; + } + has_strong_explicit_hypothesis(edit_hypotheses) .then(|| rewrite_candidates.first().cloned()) .flatten() } +fn derive_edit_context( + raw_text: &str, + edit_hypotheses: &[cleanup::EditHypothesis], +) -> RewriteEditContext { + let spans = collect_word_spans(raw_text); + let Some(hypothesis) = earliest_strong_explicit_hypothesis(edit_hypotheses) else { + return RewriteEditContext::default(); + }; + + let prefix_words = spans + .get(..hypothesis.word_start) + .unwrap_or(&[]) + .iter() + .map(|span| span.normalized.as_str()) + .collect::>(); + let courtesy_prefix_word_count = courtesy_prefix_word_count(&prefix_words); + let preceding_content_word_count = prefix_words + .len() + .saturating_sub(courtesy_prefix_word_count); + + RewriteEditContext { + cue_is_utterance_initial: prefix_words.is_empty() || preceding_content_word_count == 0, + preceding_content_word_count, + courtesy_prefix_detected: courtesy_prefix_word_count > 0, + has_recent_same_focus_entry: false, + recommended_session_action_is_replace: false, + } +} + +fn earliest_strong_explicit_hypothesis( + edit_hypotheses: &[cleanup::EditHypothesis], +) -> Option<&cleanup::EditHypothesis> { + edit_hypotheses + .iter() + .filter(|hypothesis| { + hypothesis.strength == cleanup::EditSignalStrength::Strong + && matches!( + hypothesis.match_source, + cleanup::EditHypothesisMatchSource::Exact + | cleanup::EditHypothesisMatchSource::Alias + ) + }) + .min_by_key(|hypothesis| hypothesis.word_start) +} + +fn courtesy_prefix_word_count(words: &[&str]) -> usize { + match words { + ["my", "apologies"] => 2, + ["apologies"] | ["sorry"] => 1, + _ => 0, + } +} + fn has_strong_explicit_hypothesis(edit_hypotheses: &[cleanup::EditHypothesis]) -> bool { edit_hypotheses.iter().any(|hypothesis| { hypothesis.strength == cleanup::EditSignalStrength::Strong @@ -1518,6 +1584,28 @@ mod tests { ); } + #[test] + fn courtesy_prefixed_opening_does_not_force_non_literal_recommendation() { + let transcript = Transcript { + raw_text: "my apologies i meant jonatan".into(), + detected_language: Some("en".into()), + segments: Vec::new(), + }; + + let rewrite = build_rewrite_transcript(&transcript, &rules()); + assert!(rewrite.edit_context.cue_is_utterance_initial); + assert!(rewrite.edit_context.courtesy_prefix_detected); + assert_eq!(rewrite.edit_context.preceding_content_word_count, 0); + assert_eq!( + rewrite + .rewrite_candidates + .first() + .map(|candidate| candidate.kind), + Some(RewriteCandidateKind::Literal) + ); + assert!(rewrite.recommended_candidate.is_none()); + } + #[test] fn load_custom_instructions_tolerates_missing_file() { let mut config = Config::default(); diff --git a/src/postprocess.rs b/src/postprocess.rs index 797ce4d..afe17b1 100644 --- a/src/postprocess.rs +++ b/src/postprocess.rs @@ -40,9 +40,7 @@ pub fn raw_text(transcript: &Transcript) -> String { fn base_text(config: &Config, transcript: &Transcript) -> String { match config.postprocess.mode { PostprocessMode::LegacyBasic => cleanup::clean_transcript(transcript, &config.cleanup), - PostprocessMode::AdvancedLocal | PostprocessMode::AgenticRewrite => { - cleanup::correction_aware_text(transcript) - } + PostprocessMode::Rewrite => cleanup::correction_aware_text(transcript), PostprocessMode::Raw => raw_text(transcript), } } @@ -83,7 +81,7 @@ pub async fn finalize_transcript( }, &rules, ), - PostprocessMode::AdvancedLocal | PostprocessMode::AgenticRewrite => { + PostprocessMode::Rewrite => { rewrite_transcript_or_fallback( config, &transcript, @@ -162,69 +160,46 @@ async fn rewrite_transcript_or_fallback( } let mut rewrite_transcript = personalization::build_rewrite_transcript(transcript, rules); rewrite_transcript.typing_context = typing_context.and_then(session::to_rewrite_typing_context); - if config.postprocess.mode == PostprocessMode::AgenticRewrite { - agentic_rewrite::apply_runtime_policy(config, &mut rewrite_transcript); - } + agentic_rewrite::apply_runtime_policy(config, &mut rewrite_transcript); let session_plan = session::build_backtrack_plan(&rewrite_transcript, recent_session); + rewrite_transcript.edit_context.has_recent_same_focus_entry = recent_session.is_some(); + rewrite_transcript + .edit_context + .recommended_session_action_is_replace = + session_plan.recommended.as_ref().is_some_and(|candidate| { + matches!( + candidate.kind, + RewriteSessionBacktrackCandidateKind::ReplaceLastEntry + ) + }); rewrite_transcript.recent_session_entries = session_plan.recent_entries.clone(); rewrite_transcript.session_backtrack_candidates = session_plan.candidates.clone(); rewrite_transcript.recommended_session_candidate = session_plan.recommended.clone(); tracing::debug!( mode = config.postprocess.mode.as_str(), + rewrite_route = crate::rewrite::route_label(&rewrite_transcript), + correction_policy = rewrite_transcript.policy_context.correction_policy.as_str(), edit_hypotheses = rewrite_transcript.edit_hypotheses.len(), rewrite_candidates = rewrite_transcript.rewrite_candidates.len(), + glossary_candidates = rewrite_transcript.policy_context.glossary_candidates.len(), session_backtrack_candidates = rewrite_transcript.session_backtrack_candidates.len(), recommended_candidate = rewrite_transcript .recommended_candidate .as_ref() .map(|candidate| candidate.text.as_str()) .unwrap_or(""), + recommended_session_candidate = rewrite_transcript + .recommended_session_candidate + .as_ref() + .map(|candidate| candidate.text.as_str()) + .unwrap_or(""), "prepared rewrite request" ); + tracing::trace!( + "rewrite diagnostics:\n{}", + crate::rewrite::debug_summary(&rewrite_transcript) + ); let custom_instructions = personalization::custom_instructions(rules); - let deterministic_session_replacement = session_plan.deterministic_replacement_text.clone(); - - if let Some(text) = deterministic_session_replacement { - tracing::debug!( - output_len = text.len(), - mode = config.postprocess.mode.as_str(), - "using deterministic session replacement" - ); - let operation = rewrite_transcript - .recommended_session_candidate - .as_ref() - .and_then(|candidate| { - matches!( - candidate.kind, - RewriteSessionBacktrackCandidateKind::ReplaceLastEntry - ) - .then_some(FinalizedOperation::ReplaceLastEntry { - entry_id: candidate.entry_id?, - delete_graphemes: candidate.delete_graphemes, - }) - }) - .unwrap_or(FinalizedOperation::Append); - return finalize_plain_text( - text, - SessionRewriteSummary { - had_edit_cues: !rewrite_transcript.edit_signals.is_empty() - || !rewrite_transcript.edit_hypotheses.is_empty(), - rewrite_used: false, - recommended_candidate: rewrite_transcript - .recommended_session_candidate - .as_ref() - .map(|candidate| candidate.text.clone()) - .or_else(|| { - rewrite_transcript - .recommended_candidate - .as_ref() - .map(|candidate| candidate.text.clone()) - }), - }, - rules, - ) - .with_operation(operation); - } let rewrite_result = match config.rewrite.backend { RewriteBackend::Local => { @@ -282,18 +257,18 @@ async fn rewrite_transcript_or_fallback( .map(|candidate| candidate.text.clone()) }); - let (base, rewrite_used) = match rewrite_result { + let (base, rewrite_used, decision_source) = match rewrite_result { Ok(text) if rewrite_output_accepted(config, &rewrite_transcript, &text) => { tracing::debug!( output_len = text.len(), mode = config.postprocess.mode.as_str(), "rewrite applied successfully" ); - (text, true) + (text, true, "rewrite_accepted") } Ok(text) if text.trim().is_empty() => { tracing::warn!("rewrite model returned empty text; using fallback"); - (fallback, false) + (fallback, false, "empty_rewrite_output") } Ok(text) => { tracing::warn!( @@ -301,13 +276,33 @@ async fn rewrite_transcript_or_fallback( output_len = text.len(), "rewrite output failed acceptance guard; using fallback" ); - (fallback, false) + (fallback, false, "acceptance_guard_rejected") } Err(err) => { tracing::warn!("rewrite failed: {err}; using fallback"); - (fallback, false) + (fallback, false, "rewrite_error") } }; + let matched_recommended_candidate = recommended_candidate + .as_deref() + .map(|candidate| candidate.trim() == base.trim()) + .unwrap_or(false); + let matched_glossary_candidate = rewrite_transcript + .policy_context + .glossary_candidates + .iter() + .any(|candidate| candidate.text.trim() == base.trim()); + tracing::debug!( + mode = config.postprocess.mode.as_str(), + rewrite_route = crate::rewrite::route_label(&rewrite_transcript), + rewrite_used, + decision_source, + matched_recommended_candidate, + matched_glossary_candidate, + final_chars = base.len(), + "rewrite decision finalized" + ); + tracing::trace!("rewrite final text: {}", base); let operation = rewrite_transcript .recommended_session_candidate .as_ref() @@ -391,7 +386,7 @@ impl FinalizedTranscript { } fn rewrite_output_accepted( - config: &Config, + _config: &Config, rewrite_transcript: &RewriteTranscript, text: &str, ) -> bool { @@ -399,10 +394,6 @@ fn rewrite_output_accepted( return false; } - if config.postprocess.mode != PostprocessMode::AgenticRewrite { - return true; - } - match rewrite_transcript.policy_context.correction_policy { RewriteCorrectionPolicy::Conservative => { agentic_rewrite::conservative_output_allowed(rewrite_transcript, text) diff --git a/src/rewrite.rs b/src/rewrite.rs index ffe0c33..8684465 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -1,278 +1,25 @@ -use std::num::NonZeroU32; -use std::path::Path; -use std::sync::OnceLock; - -use encoding_rs::UTF_8; -use llama_cpp_2::context::params::LlamaContextParams; -use llama_cpp_2::llama_backend::LlamaBackend; -use llama_cpp_2::llama_batch::LlamaBatch; -use llama_cpp_2::model::params::LlamaModelParams; -use llama_cpp_2::model::{AddBos, LlamaChatMessage, LlamaChatTemplate, LlamaModel}; -use llama_cpp_2::openai::OpenAIChatTemplateParams; -use llama_cpp_2::sampling::LlamaSampler; use serde_json::json; use crate::rewrite_profile::ResolvedRewriteProfile; use crate::rewrite_profile::RewriteProfile; use crate::rewrite_protocol::RewriteTranscript; +const LEGACY_REWRITE_MAX_TOKENS: usize = 256; +const RECOMMENDED_REWRITE_MAX_TOKENS: usize = 768; #[allow(dead_code)] -pub struct LocalRewriter { - model: LlamaModel, - chat_template: LlamaChatTemplate, - profile: ResolvedRewriteProfile, - max_tokens: usize, - max_output_chars: usize, -} - +const LEGACY_REWRITE_MAX_OUTPUT_CHARS: usize = 1200; #[allow(dead_code)] -static LLAMA_BACKEND: OnceLock<&'static LlamaBackend> = OnceLock::new(); -#[allow(dead_code)] -static EXTERNAL_LLAMA_BACKEND: LlamaBackend = LlamaBackend {}; +const RECOMMENDED_REWRITE_MAX_OUTPUT_CHARS: usize = 8192; #[derive(Debug, Clone, PartialEq, Eq)] pub struct RewritePrompt { pub system: String, pub user: String, } - -impl LocalRewriter { - #[allow(dead_code)] - pub fn new( - model_path: &Path, - profile: ResolvedRewriteProfile, - max_tokens: usize, - max_output_chars: usize, - ) -> std::result::Result { - if !model_path.exists() { - return Err(format!( - "rewrite model file not found: {}", - model_path.display() - )); - } - - let backend = llama_backend()?; - - let mut model_params = LlamaModelParams::default(); - if cfg!(feature = "cuda") { - model_params = model_params.with_n_gpu_layers(1000); - } - - let model = LlamaModel::load_from_file(backend, model_path, &model_params) - .map_err(|e| format!("failed to load rewrite model: {e}"))?; - let chat_template = model - .chat_template(None) - .map_err(|e| format!("rewrite model does not expose a usable chat template: {e}"))?; - - Ok(Self { - model, - chat_template, - profile, - max_tokens, - max_output_chars, - }) - } - - #[allow(dead_code)] - pub fn rewrite_with_instructions( - &self, - transcript: &RewriteTranscript, - custom_instructions: Option<&str>, - ) -> std::result::Result { - if transcript.raw_text.trim().is_empty() { - return Ok(String::new()); - } - - let prompt = build_rewrite_prompt( - &self.model, - &self.chat_template, - transcript, - self.profile, - custom_instructions, - )?; - let effective_max_tokens = effective_max_tokens(self.max_tokens, transcript); - let prompt_tokens = self - .model - .str_to_token(&prompt, AddBos::Never) - .map_err(|e| format!("failed to tokenize rewrite prompt: {e}"))?; - let behavior = rewrite_behavior(self.profile); - - let n_ctx_tokens = prompt_tokens - .len() - .saturating_add(effective_max_tokens) - .saturating_add(64) - .max(2048) - .min(u32::MAX as usize) as u32; - let n_batch = prompt_tokens - .len() - .max(512) - .min(n_ctx_tokens as usize) - .min(u32::MAX as usize) as u32; - let threads = std::thread::available_parallelism() - .map(|threads| threads.get()) - .unwrap_or(4) - .clamp(1, i32::MAX as usize) as i32; - - let ctx_params = LlamaContextParams::default() - .with_n_ctx(NonZeroU32::new(n_ctx_tokens)) - .with_n_batch(n_batch) - .with_n_ubatch(n_batch) - .with_n_threads(threads) - .with_n_threads_batch(threads); - let backend = llama_backend()?; - let mut ctx = self - .model - .new_context(backend, ctx_params) - .map_err(|e| format!("failed to create rewrite context: {e}"))?; - - let mut prompt_batch = LlamaBatch::new(prompt_tokens.len(), 1); - prompt_batch - .add_sequence(&prompt_tokens, 0, false) - .map_err(|e| format!("failed to enqueue rewrite prompt: {e}"))?; - ctx.decode(&mut prompt_batch) - .map_err(|e| format!("failed to decode rewrite prompt: {e}"))?; - - let mut sampler = LlamaSampler::chain_simple([ - LlamaSampler::top_k(behavior.top_k), - LlamaSampler::top_p(behavior.top_p, 1), - LlamaSampler::temp(behavior.temperature), - LlamaSampler::greedy(), - ]); - sampler.accept_many(prompt_tokens.iter()); - - let mut decoder = UTF_8.new_decoder_without_bom_handling(); - let mut output = String::new(); - let start_pos = i32::try_from(prompt_tokens.len()).unwrap_or(i32::MAX); - - for i in 0..effective_max_tokens { - let mut candidates = ctx.token_data_array(); - candidates.apply_sampler(&sampler); - let token = candidates - .selected_token() - .ok_or_else(|| "rewrite sampler did not select a token".to_string())?; - - if token == self.model.token_eos() { - break; - } - - sampler.accept(token); - - let piece = self - .model - .token_to_piece(token, &mut decoder, true, None) - .map_err(|e| format!("failed to decode rewrite token: {e}"))?; - output.push_str(&piece); - - if output.contains("") || output.len() >= self.max_output_chars { - break; - } - - let mut batch = LlamaBatch::new(1, 1); - batch - .add( - token, - start_pos.saturating_add(i32::try_from(i).unwrap_or(i32::MAX)), - &[0], - true, - ) - .map_err(|e| format!("failed to enqueue rewrite token: {e}"))?; - ctx.decode(&mut batch) - .map_err(|e| format!("failed to decode rewrite token: {e}"))?; - } - - let rewritten = sanitize_rewrite_output(&output); - if rewritten.is_empty() { - return Err("rewrite model returned empty output".into()); - } - - Ok(rewritten) - } -} - -#[allow(dead_code)] -fn llama_backend() -> std::result::Result<&'static LlamaBackend, String> { - if let Some(backend) = LLAMA_BACKEND.get().copied() { - return Ok(backend); - } - - match LlamaBackend::init() { - Ok(backend) => { - let backend = Box::leak(Box::new(backend)); - let _ = LLAMA_BACKEND.set(backend); - Ok(LLAMA_BACKEND - .get() - .copied() - .expect("llama backend initialized")) - } - // Use a static non-dropping token when another part of the process already - // owns llama.cpp global initialization. The worker never drops this token. - Err(llama_cpp_2::LlamaCppError::BackendAlreadyInitialized) => { - let _ = LLAMA_BACKEND.set(&EXTERNAL_LLAMA_BACKEND); - Ok(LLAMA_BACKEND - .get() - .copied() - .expect("external llama backend cached")) - } - Err(err) => Err(format!("failed to initialize llama backend: {err}")), - } -} - #[allow(dead_code)] -fn build_rewrite_prompt( - model: &LlamaModel, - chat_template: &LlamaChatTemplate, - transcript: &RewriteTranscript, - profile: ResolvedRewriteProfile, - custom_instructions: Option<&str>, -) -> std::result::Result { - let prompt = build_prompt(transcript, profile, custom_instructions)?; - if matches!(profile, ResolvedRewriteProfile::Qwen) { - return build_qwen_rewrite_prompt(model, chat_template, &prompt); - } - - let messages = vec![ - LlamaChatMessage::new("system".into(), prompt.system) - .map_err(|e| format!("failed to build rewrite system message: {e}"))?, - LlamaChatMessage::new("user".into(), prompt.user) - .map_err(|e| format!("failed to build rewrite user message: {e}"))?, - ]; - - model - .apply_chat_template(chat_template, &messages, true) - .map_err(|e| format!("failed to apply rewrite chat template: {e}")) -} - -fn build_qwen_rewrite_prompt( - model: &LlamaModel, - chat_template: &LlamaChatTemplate, +pub(crate) fn build_oaicompat_messages_json( prompt: &RewritePrompt, ) -> std::result::Result { - let messages_json = build_oaicompat_messages_json(prompt)?; - let result = model - .apply_chat_template_oaicompat( - chat_template, - &OpenAIChatTemplateParams { - messages_json: &messages_json, - tools_json: None, - tool_choice: None, - json_schema: None, - grammar: None, - reasoning_format: None, - chat_template_kwargs: None, - add_generation_prompt: true, - use_jinja: true, - parallel_tool_calls: false, - enable_thinking: false, - add_bos: false, - add_eos: false, - parse_tool_calls: false, - }, - ) - .map_err(|e| format!("failed to apply Qwen rewrite chat template: {e}"))?; - Ok(result.prompt) -} - -fn build_oaicompat_messages_json(prompt: &RewritePrompt) -> std::result::Result { serde_json::to_string(&vec![ json!({ "role": "system", @@ -346,7 +93,7 @@ fn correction_policy_contract( "Conservative: stay close to explicit rewrite candidates and glossary evidence. If uncertain, prefer candidate-preserving output over freer rewriting." } crate::rewrite_protocol::RewriteCorrectionPolicy::Balanced => { - "Balanced: allow stronger technical correction when the glossary, app context, or utterance semantics support it. Prefer candidate-backed output when it is competitive, but do not keep an obviously wrong technical spelling just because it appears in the candidate list." + "Balanced: proactively fix obvious technical or proper-name misrecognitions when glossary terms, app context, nearby category words, or utterance semantics make the intended term clearly more plausible than the literal transcript. Prefer candidate-backed output when it is competitive, but do not keep an obviously wrong technical spelling just because it appears in the candidate list." } crate::rewrite_protocol::RewriteCorrectionPolicy::Aggressive => { "Aggressive: allow freer technical correction and contextual cleanup when the utterance strongly points to a technical term or proper name. Candidates are useful evidence, not hard limits, as long as you still return only final text within the provided bounds." @@ -362,7 +109,7 @@ fn agentic_latitude_contract( "In conservative mode, treat the candidate list and glossary as the main evidence. Only make a freer technical normalization when the utterance itself makes the intended term unusually clear." } crate::rewrite_protocol::RewriteCorrectionPolicy::Balanced => { - "In balanced mode, you may normalize likely technical terms, product names, commands, libraries, languages, editors, or Linux components even when the literal transcript spelling is noisy or the exact canonical form is not already present in the candidate list, as long as the utterance strongly supports that normalization." + "In balanced mode, you should normalize likely technical terms, product names, commands, libraries, languages, editors, or Linux components when the literal transcript is an obvious phonetic near-miss and the utterance strongly supports the canonical form, even if that exact spelling is not already present in the candidate list." } crate::rewrite_protocol::RewriteCorrectionPolicy::Aggressive => { "In aggressive mode, you may confidently rewrite phonetically similar words into the most plausible technical term or proper name when the utterance semantics, app context, or nearby category cues make that interpretation clearly better than the literal transcript." @@ -370,39 +117,12 @@ fn agentic_latitude_contract( } } -#[allow(dead_code)] -struct RewriteBehavior { - top_k: i32, - top_p: f32, - temperature: f32, -} - -#[allow(dead_code)] -fn rewrite_behavior(profile: ResolvedRewriteProfile) -> RewriteBehavior { - match profile { - ResolvedRewriteProfile::Qwen => RewriteBehavior { - top_k: 24, - top_p: 0.9, - temperature: 0.1, - }, - ResolvedRewriteProfile::Generic => RewriteBehavior { - top_k: 32, - top_p: 0.92, - temperature: 0.12, - }, - ResolvedRewriteProfile::LlamaCompat => RewriteBehavior { - top_k: 40, - top_p: 0.95, - temperature: 0.15, - }, - } -} - fn rewrite_instructions(profile: ResolvedRewriteProfile) -> &'static str { let base = "You clean up dictated speech into the final text the user meant to type. \ Return only the finished text. Do not explain anything. Remove obvious disfluencies when natural. \ -Use the correction-aware transcript as the primary source of truth unless structured edit signals say the \ -utterance may still be ambiguous. The raw transcript may still contain spoken editing phrases or canceled wording. \ +Use the correction-aware transcript as strong heuristic evidence. When structured edit signals are present, treat it \ +as advisory rather than absolute and resolve ambiguity using the raw transcript, session context, and candidates. The \ +raw transcript may still contain spoken editing phrases or canceled wording. \ Never reintroduce text that was removed by an explicit spoken correction cue. Respect any structured edit intents \ provided alongside the transcript. If structured edit signals or edit hypotheses are present, use the candidate \ interpretations as bounded options, choose the best interpretation, and lightly refine it only when needed for natural \ @@ -411,10 +131,13 @@ explicit correction says otherwise. Do not normalize names into more common spel When the utterance clearly refers to software, tools, APIs, libraries, Linux components, product names, or other \ technical concepts, prefer the most plausible intended technical term or proper name over a phonetically similar common \ word. Use nearby category words like window manager, editor, language, library, package manager, shell, or terminal \ -tool to disambiguate technical names. If the utterance remains genuinely ambiguous, stay close to the transcript rather \ +tool to disambiguate technical names. When a dictated word is an obvious phonetic near-miss for a likely technical term \ +and the surrounding context clearly identifies the category, correct it to the canonical technical spelling instead of \ +echoing the miss. If multiple plausible interpretations remain similarly credible, stay close to the transcript rather \ than inventing a niche term. \ -If an edit intent says to replace or cancel previous wording, preserve that edit and do not keep the spoken correction \ -phrase itself unless the transcript clearly still intends it. Examples:\n\ +If an edit intent says to replace or cancel previous wording, preserve that edit when the utterance or same-session \ +context clearly supports it. Preserve utterance-initial courtesy or apology wording when the raw transcript still \ +clearly intends it. Examples:\n\ - raw: Hello there. Scratch that. Hi.\n correction-aware: Hi.\n final: Hi.\n\ - raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ - raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ @@ -423,6 +146,7 @@ phrase itself unless the transcript clearly still intends it. Examples:\n\ - raw: Let's meet tomorrow, or rather Friday.\n correction-aware: Let's meet Friday.\n final: Let's meet Friday.\n\ - raw: I'm currently using the window manager Hyperland.\n correction-aware: I'm currently using the window manager Hyperland.\n final: I'm currently using the window manager Hyprland.\n\ - raw: I'm switching from Sui to Hyperland.\n correction-aware: I'm switching from Sui to Hyperland.\n final: I'm switching from Sway to Hyprland.\n\ +- raw: I moved back to the window manager neary.\n correction-aware: I moved back to the window manager neary.\n final: I moved back to the window manager niri.\n\ - raw: I use type script for backend tooling.\n correction-aware: I use type script for backend tooling.\n final: I use TypeScript for backend tooling.\n\ - raw: I edit the config in neo vim.\n correction-aware: I edit the config in neo vim.\n final: I edit the config in Neovim."; @@ -430,8 +154,9 @@ phrase itself unless the transcript clearly still intends it. Examples:\n\ ResolvedRewriteProfile::Qwen => { "You clean up dictated speech into the final text the user meant to type. \ Return only the finished text. Do not explain anything. Do not emit reasoning, think tags, or XML wrappers. \ -Remove obvious disfluencies when natural. Use the correction-aware transcript as the primary source of truth unless \ -structured edit signals say the utterance may still be ambiguous. The raw transcript may still contain spoken editing \ +Remove obvious disfluencies when natural. Use the correction-aware transcript as strong heuristic evidence. When \ +structured edit signals are present, treat it as advisory rather than absolute and resolve ambiguity using the raw \ +transcript, session context, and candidates. The raw transcript may still contain spoken editing \ phrases or canceled wording. Never reintroduce text that was removed by an explicit spoken correction cue. Respect \ any structured edit intents provided alongside the transcript. If structured edit signals or edit hypotheses are \ present, use the candidate interpretations as bounded options, choose the best interpretation, and lightly refine it \ @@ -440,9 +165,10 @@ unless a user dictionary or explicit correction says otherwise. Do not normalize because they look familiar. When the utterance clearly refers to software, tools, APIs, libraries, Linux components, \ product names, or other technical concepts, prefer the most plausible intended technical term or proper name over a \ phonetically similar common word. Use nearby category words like window manager, editor, language, library, package \ -manager, shell, or terminal tool to disambiguate technical names. If the utterance remains genuinely ambiguous, stay \ -close to the transcript rather than inventing a niche term. If an edit intent says to replace or cancel previous wording, preserve that edit and do \ -not keep the spoken correction phrase itself unless the transcript clearly still intends it. Examples:\n\ +manager, shell, or terminal tool to disambiguate technical names. When a dictated word is an obvious phonetic near-miss \ +for a likely technical term and the surrounding context clearly identifies the category, correct it to the canonical \ +technical spelling instead of echoing the miss. If multiple plausible interpretations remain similarly credible, stay \ +close to the transcript rather than inventing a niche term. If an edit intent says to replace or cancel previous wording, preserve that edit when the utterance or same-session context clearly supports it. Preserve utterance-initial courtesy or apology wording when the raw transcript still clearly intends it. Examples:\n\ - raw: Hello there. Scratch that. Hi.\n correction-aware: Hi.\n final: Hi.\n\ - raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ - raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ @@ -451,6 +177,7 @@ not keep the spoken correction phrase itself unless the transcript clearly still - raw: Let's meet tomorrow, or rather Friday.\n correction-aware: Let's meet Friday.\n final: Let's meet Friday.\n\ - raw: I'm currently using the window manager Hyperland.\n correction-aware: I'm currently using the window manager Hyperland.\n final: I'm currently using the window manager Hyprland.\n\ - raw: I'm switching from Sui to Hyperland.\n correction-aware: I'm switching from Sui to Hyperland.\n final: I'm switching from Sway to Hyprland.\n\ +- raw: I moved back to the window manager neary.\n correction-aware: I moved back to the window manager neary.\n final: I moved back to the window manager niri.\n\ - raw: I use type script for backend tooling.\n correction-aware: I use type script for backend tooling.\n final: I use TypeScript for backend tooling.\n\ - raw: I edit the config in neo vim.\n correction-aware: I edit the config in neo vim.\n final: I edit the config in Neovim." } @@ -460,10 +187,11 @@ not keep the spoken correction phrase itself unless the transcript clearly still #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum RewriteRoute { - Fast, - ResolvedCorrection, SessionCandidateAdjudication, CandidateAdjudication, + AgenticCandidateAdjudication, + ResolvedCorrection, + Fast, } fn build_user_message(transcript: &RewriteTranscript) -> String { @@ -472,6 +200,7 @@ fn build_user_message(transcript: &RewriteTranscript) -> String { let raw = transcript.raw_text.trim(); let edit_intents = render_edit_intents(transcript); let edit_signals = render_edit_signals(transcript); + let edit_context = render_edit_context(transcript); let agentic_context = render_agentic_context(transcript); let route = rewrite_route(transcript); tracing::debug!( @@ -510,12 +239,15 @@ Active typing context:\n\ Recent dictation session:\n\ {recent_session_entries}\ {agentic_policy_context}\ +Structured cue context:\n\ +{edit_context}\ Session backtrack candidates:\n\ {session_candidates}\ {recommended_session_candidate}\ The user may be correcting the most recent prior dictation entry rather than appending new text.\n\ If the recommended session candidate says replace_last_entry, treat your final text as the replacement text for that previous dictation entry, not as newly appended text.\n\ Prefer the recommended session candidate unless another listed session candidate is clearly better.\n\ +If the utterance begins with a courtesy-prefixed correction cue and the session evidence is weak, preserve the courtesy wording instead of assuming replacement.\n\ {surface_guidance}\ Current utterance correction candidate:\n\ {correction_aware}\n\ @@ -533,9 +265,12 @@ Final text:" let recent_segments = render_recent_segments(transcript, 4); let aggressive_candidate = render_aggressive_candidate(transcript); let exact_cue_guidance = if has_strong_explicit_edit_cue(transcript) { - "A strong explicit spoken edit cue was detected. The literal raw transcript probably contains canceled wording. \ -Prefer a candidate interpretation that removes the cue and canceled wording unless doing so would clearly lose intended meaning. \ -If the cue is an exact strong match for phrases like scratch that, never mind, or wait no, do not keep the literal cue text in the final output.\n" + if opening_cue_requires_literal_bias(transcript) { + "A strong edit cue appears at the beginning of the utterance without strong same-session replacement evidence. \ +Do not assume it cancels earlier text. If the opening includes courtesy language such as sorry or my apologies, preserve that courtesy wording unless another candidate is clearly better.\n" + } else { + "A strong explicit spoken edit cue was detected. Prefer a candidate interpretation that preserves the intended edit when the cue clearly corrects earlier same-utterance wording or the most recent same-session dictation.\n" + } } else { "" }; @@ -550,12 +285,13 @@ Structured edit signals:\n\ {edit_signals}\ Structured edit intents:\n\ {edit_intents}\ +Structured cue context:\n\ +{edit_context}\ This utterance likely contains spoken self-corrections or restatements.\n\ Choose the best candidate interpretation and lightly refine it only when needed.\n\ {exact_cue_guidance}\ -When an exact strong edit cue is present, treat the non-literal candidates as more trustworthy than the literal transcript.\n\ -The candidate list is ordered from most likely to least likely by heuristics.\n\ -For exact strong edit cues, the first candidate is the heuristic best guess and should usually win unless another candidate is clearly better.\n\ +When an exact strong edit cue is present, treat non-literal candidates as evidence, not an automatic winner.\n\ +The candidate list is ordered heuristically and may be wrong for utterance-initial or courtesy-prefixed cues.\n\ Prefer the smallest replacement scope that yields a coherent result.\n\ Use span-level replacements when only a key phrase was corrected, clause-level replacements when the correction replaces the surrounding thought, and sentence-level replacements only when the whole sentence was canceled.\n\ Preserve literal wording when the cue is not actually an edit.\n\ @@ -571,6 +307,35 @@ Raw transcript:\n\ {raw}\n\ Recent segments:\n\ {recent_segments}\n\ +Final text:" + ) + } + RewriteRoute::AgenticCandidateAdjudication => { + let rewrite_candidates = render_rewrite_candidates(transcript); + let recommended_candidate = render_recommended_candidate(transcript); + let recent_segments = render_recent_segments(transcript, 4); + let glossary_candidates = render_glossary_candidates(transcript); + let aggressive_candidate = render_aggressive_candidate(transcript); + format!( + "Language: {language}\n\ +{agentic_context}\ +This utterance likely refers to a technical term, product name, command, library, or proper name that may need contextual normalization.\n\ +Choose the most plausible final text from the available candidates and the active glossary/app context.\n\ +Prefer the recommended candidate unless another listed candidate is clearly better.\n\ +If a glossary-backed candidate cleanly resolves an obvious phonetic near-miss, it should usually win.\n\ +Preserve uncommon names when the evidence points to them, but do not invent niche terms when the evidence is weak.\n\ +{recommended_candidate}\ +Candidate interpretations:\n\ +{rewrite_candidates}\ +Glossary-backed candidates:\n\ +{glossary_candidates}\ +Correction candidate:\n\ +{correction_aware}\n\ +{aggressive_candidate}\ +Raw transcript:\n\ +{raw}\n\ +Recent segments:\n\ +{recent_segments}\n\ Final text:" ) } @@ -581,6 +346,8 @@ Structured edit signals:\n\ {edit_signals}\ Structured edit intents:\n\ {edit_intents}\ +Structured cue context:\n\ +{edit_context}\ Self-corrections were already resolved before rewriting.\n\ Use this correction-aware transcript as the main source text. In agentic mode, you may still normalize likely \ technical terms or proper names when the utterance strongly supports them, even if the exact canonical spelling is not \ @@ -600,6 +367,8 @@ Structured edit signals:\n\ {edit_signals}\ Structured edit intents:\n\ {edit_intents}\ +Structured cue context:\n\ +{edit_context}\ Correction-aware transcript:\n\ {correction_aware}\n\ Treat the correction-aware transcript as authoritative for explicit spoken edits and overall meaning, but in agentic \ @@ -651,33 +420,33 @@ Active glossary terms:\n\ } fn render_agentic_runtime_context(transcript: &RewriteTranscript) -> String { - has_policy_context(transcript) - .then(|| { - format!( - "Active typing context:\n\ + if has_policy_context(transcript) { + format!( + "Active typing context:\n\ {}\ Recent dictation session:\n\ {}", - render_typing_context(transcript), - render_recent_session_entries(transcript), - ) - }) - .unwrap_or_default() + render_typing_context(transcript), + render_recent_session_entries(transcript), + ) + } else { + String::new() + } } fn render_agentic_candidates(transcript: &RewriteTranscript) -> String { - has_policy_context(transcript) - .then(|| { - format!( - "Available rewrite candidates (advisory, not exhaustive in agentic mode):\n\ + if has_policy_context(transcript) { + format!( + "Available rewrite candidates (advisory, not exhaustive in agentic mode):\n\ {}\ Glossary-backed candidates:\n\ {}", - render_rewrite_candidates(transcript), - render_glossary_candidates(transcript) - ) - }) - .unwrap_or_default() + render_rewrite_candidates(transcript), + render_glossary_candidates(transcript) + ) + } else { + String::new() + } } fn rewrite_route(transcript: &RewriteTranscript) -> RewriteRoute { @@ -685,6 +454,8 @@ fn rewrite_route(transcript: &RewriteTranscript) -> RewriteRoute { RewriteRoute::SessionCandidateAdjudication } else if requires_candidate_adjudication(transcript) { RewriteRoute::CandidateAdjudication + } else if requires_agentic_candidate_adjudication(transcript) { + RewriteRoute::AgenticCandidateAdjudication } else if transcript.correction_aware_text.trim() != transcript.raw_text.trim() { RewriteRoute::ResolvedCorrection } else { @@ -692,10 +463,90 @@ fn rewrite_route(transcript: &RewriteTranscript) -> RewriteRoute { } } +#[allow(dead_code)] +pub(crate) fn route_label(transcript: &RewriteTranscript) -> &'static str { + match rewrite_route(transcript) { + RewriteRoute::SessionCandidateAdjudication => "session_candidate_adjudication", + RewriteRoute::CandidateAdjudication => "candidate_adjudication", + RewriteRoute::AgenticCandidateAdjudication => "agentic_candidate_adjudication", + RewriteRoute::ResolvedCorrection => "resolved_correction", + RewriteRoute::Fast => "fast", + } +} + +#[allow(dead_code)] +pub(crate) fn debug_summary(transcript: &RewriteTranscript) -> String { + let recommended_candidate = transcript + .recommended_candidate + .as_ref() + .map(|candidate| candidate.text.as_str()) + .unwrap_or("(none)"); + let recommended_session_candidate = transcript + .recommended_session_candidate + .as_ref() + .map(|candidate| candidate.text.as_str()) + .unwrap_or("(none)"); + format!( + "route: {}\n\ +correction_policy: {}\n\ +raw_text: {}\n\ +correction_aware_text: {}\n\ +recommended_candidate: {}\n\ +recommended_session_candidate: {}\n\ +matched_rules:\n\ +{}\ +effective_rule_instructions:\n\ +{}\ +active_glossary_terms:\n\ +{}\ +rewrite_candidates:\n\ +{}\ +glossary_candidates:\n\ +{}", + route_label(transcript), + transcript.policy_context.correction_policy.as_str(), + transcript.raw_text.trim(), + transcript.correction_aware_text.trim(), + recommended_candidate, + recommended_session_candidate, + render_matched_rule_names(transcript), + render_effective_rule_instructions(transcript), + render_active_glossary_terms(transcript), + render_rewrite_candidates(transcript), + render_glossary_candidates(transcript), + ) +} + fn requires_candidate_adjudication(transcript: &RewriteTranscript) -> bool { !transcript.edit_signals.is_empty() || !transcript.edit_hypotheses.is_empty() } +fn requires_agentic_candidate_adjudication(transcript: &RewriteTranscript) -> bool { + if !has_policy_context(transcript) + || matches!( + transcript.policy_context.correction_policy, + crate::rewrite_protocol::RewriteCorrectionPolicy::Conservative + ) + { + return false; + } + + transcript + .recommended_candidate + .as_ref() + .map(|candidate| { + let text = candidate.text.trim(); + !text.is_empty() && text != transcript.correction_aware_text.trim() + }) + .unwrap_or_else(|| { + transcript + .policy_context + .glossary_candidates + .iter() + .any(|candidate| candidate.text.trim() != transcript.correction_aware_text.trim()) + }) +} + fn has_strong_explicit_edit_cue(transcript: &RewriteTranscript) -> bool { transcript.edit_hypotheses.iter().any(|hypothesis| { hypothesis.strength == crate::rewrite_protocol::RewriteEditSignalStrength::Strong @@ -707,6 +558,15 @@ fn has_strong_explicit_edit_cue(transcript: &RewriteTranscript) -> bool { }) } +fn opening_cue_requires_literal_bias(transcript: &RewriteTranscript) -> bool { + transcript.edit_context.cue_is_utterance_initial + && transcript.edit_context.courtesy_prefix_detected + && !(transcript.edit_context.has_recent_same_focus_entry + && transcript + .edit_context + .recommended_session_action_is_replace) +} + fn has_session_backtrack_candidate(transcript: &RewriteTranscript) -> bool { transcript.recommended_session_candidate.is_some() || !transcript.session_backtrack_candidates.is_empty() @@ -820,6 +680,25 @@ fn render_edit_hypotheses(transcript: &RewriteTranscript) -> String { rendered } +fn render_edit_context(transcript: &RewriteTranscript) -> String { + format!( + "- cue_is_utterance_initial: {}\n- preceding_content_word_count: {}\n- courtesy_prefix_detected: {}\n- has_recent_same_focus_entry: {}\n- recommended_session_action_is_replace: {}\n", + yes_no(transcript.edit_context.cue_is_utterance_initial), + transcript.edit_context.preceding_content_word_count, + yes_no(transcript.edit_context.courtesy_prefix_detected), + yes_no(transcript.edit_context.has_recent_same_focus_entry), + yes_no( + transcript + .edit_context + .recommended_session_action_is_replace + ), + ) +} + +fn yes_no(value: bool) -> &'static str { + if value { "yes" } else { "no" } +} + fn render_rewrite_candidates(transcript: &RewriteTranscript) -> String { if transcript.rewrite_candidates.is_empty() { return "- no candidates available\n".to_string(); @@ -1074,7 +953,8 @@ fn has_policy_context(transcript: &RewriteTranscript) -> bool { } #[allow(dead_code)] -fn effective_max_tokens(max_tokens: usize, transcript: &RewriteTranscript) -> usize { +pub(crate) fn effective_max_tokens(max_tokens: usize, transcript: &RewriteTranscript) -> usize { + let max_tokens = normalized_rewrite_max_tokens(max_tokens); let word_count = transcript .correction_aware_text .split_whitespace() @@ -1097,6 +977,46 @@ fn effective_max_tokens(max_tokens: usize, transcript: &RewriteTranscript) -> us derived.clamp(minimum, max_tokens) } +#[allow(dead_code)] +pub(crate) fn effective_max_output_chars( + max_output_chars: usize, + transcript: &RewriteTranscript, +) -> usize { + let max_output_chars = normalized_rewrite_max_output_chars(max_output_chars); + let transcript_chars = transcript + .correction_aware_text + .chars() + .count() + .max(transcript.raw_text.chars().count()); + let minimum = 1200; + let extra_margin = if requires_candidate_adjudication(transcript) { + 768 + } else { + 384 + }; + let derived = transcript_chars + .saturating_mul(2) + .saturating_add(extra_margin); + derived.clamp(minimum, max_output_chars) +} + +fn normalized_rewrite_max_tokens(max_tokens: usize) -> usize { + if max_tokens == LEGACY_REWRITE_MAX_TOKENS { + RECOMMENDED_REWRITE_MAX_TOKENS + } else { + max_tokens.max(64) + } +} + +#[allow(dead_code)] +fn normalized_rewrite_max_output_chars(max_output_chars: usize) -> usize { + if max_output_chars == LEGACY_REWRITE_MAX_OUTPUT_CHARS { + RECOMMENDED_REWRITE_MAX_OUTPUT_CHARS + } else { + max_output_chars.max(1200) + } +} + pub(crate) fn sanitize_rewrite_output(raw: &str) -> String { let mut text = raw.replace("\r\n", "\n"); @@ -1223,6 +1143,7 @@ mod tests { kind: RewriteCandidateKind::Literal, text: "Hi there, this is a test. Wait, no. Hi there.".into(), }), + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), } } @@ -1256,6 +1177,7 @@ mod tests { }, ], recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), } } @@ -1292,6 +1214,7 @@ mod tests { }, ], recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext { correction_policy: RewriteCorrectionPolicy::Balanced, matched_rule_names: vec!["baseline/global-default".into()], @@ -1305,6 +1228,31 @@ mod tests { } } + fn glossary_agentic_transcript() -> RewriteTranscript { + let mut transcript = fast_agentic_transcript(); + transcript.rewrite_candidates.insert( + 0, + RewriteCandidate { + kind: RewriteCandidateKind::GlossaryCorrection, + text: "I'm currently using the window manager Hyprland.".into(), + }, + ); + transcript.recommended_candidate = Some(RewriteCandidate { + kind: RewriteCandidateKind::GlossaryCorrection, + text: "I'm currently using the window manager Hyprland.".into(), + }); + transcript.policy_context.active_glossary_terms = + vec![crate::rewrite_protocol::RewritePolicyGlossaryTerm { + term: "Hyprland".into(), + aliases: vec!["hyperland".into()], + }]; + transcript.policy_context.glossary_candidates = vec![RewriteCandidate { + kind: RewriteCandidateKind::GlossaryCorrection, + text: "I'm currently using the window manager Hyprland.".into(), + }]; + transcript + } + #[test] fn instructions_cover_self_correction_examples() { let instructions = rewrite_instructions(ResolvedRewriteProfile::LlamaCompat); @@ -1313,6 +1261,7 @@ mod tests { assert!(instructions.contains("scratch that, brownies")); assert!(instructions.contains("window manager Hyperland")); assert!(instructions.contains("switching from Sui to Hyperland")); + assert!(instructions.contains("window manager neary")); } #[test] @@ -1327,6 +1276,18 @@ mod tests { let instructions = rewrite_instructions(ResolvedRewriteProfile::LlamaCompat); assert!(instructions.contains("technical concepts")); assert!(instructions.contains("phonetically similar common word")); + assert!(instructions.contains("obvious phonetic near-miss")); + } + + #[test] + fn balanced_policy_contract_pushes_obvious_corrections() { + let contract = correction_policy_contract(RewriteCorrectionPolicy::Balanced); + assert!( + contract.contains("proactively fix obvious technical or proper-name misrecognitions") + ); + let latitude = agentic_latitude_contract(RewriteCorrectionPolicy::Balanced); + assert!(latitude.contains("should normalize likely technical terms")); + assert!(latitude.contains("obvious phonetic near-miss")); } #[test] @@ -1371,9 +1332,11 @@ mod tests { assert!(instructions.contains( "do not keep an obviously wrong technical spelling just because it appears in the candidate list" )); - assert!(instructions.contains( - "even when the literal transcript spelling is noisy or the exact canonical form is not already present in the candidate list" - )); + assert!( + instructions.contains( + "even if that exact spelling is not already present in the candidate list" + ) + ); } #[test] @@ -1391,6 +1354,37 @@ mod tests { ); } + #[test] + fn agentic_glossary_prompt_uses_candidate_adjudication_route() { + let transcript = glossary_agentic_transcript(); + assert!(matches!( + rewrite_route(&transcript), + RewriteRoute::AgenticCandidateAdjudication + )); + let prompt = build_user_message(&transcript); + assert!(prompt.contains( + "This utterance likely refers to a technical term, product name, command, library, or proper name" + )); + assert!(prompt.contains( + "Prefer the recommended candidate unless another listed candidate is clearly better." + )); + assert!(prompt.contains("Glossary-backed candidates:")); + assert!(prompt.contains( + "Recommended interpretation:\nI'm currently using the window manager Hyprland." + )); + assert!( + prompt.contains("Raw transcript:\nI'm currently using the window manager hyperland.") + ); + assert!(prompt.contains("Candidate interpretations:\n")); + } + + #[test] + fn conservative_policy_does_not_use_agentic_candidate_adjudication_route() { + let mut transcript = glossary_agentic_transcript(); + transcript.policy_context.correction_policy = RewriteCorrectionPolicy::Conservative; + assert!(matches!(rewrite_route(&transcript), RewriteRoute::Fast)); + } + #[test] fn cue_prompt_includes_raw_candidate_and_signals() { let prompt = build_user_message(&correction_transcript()); @@ -1405,9 +1399,11 @@ mod tests { assert!(prompt.contains("Candidate interpretations")); assert!(prompt.contains("A strong explicit spoken edit cue was detected")); assert!(prompt.contains( - "The candidate list is ordered from most likely to least likely by heuristics." + "The candidate list is ordered heuristically and may be wrong for utterance-initial or courtesy-prefixed cues." )); - assert!(prompt.contains("the first candidate is the heuristic best guess")); + assert!( + prompt.contains("treat non-literal candidates as evidence, not an automatic winner") + ); assert!(prompt.contains("Recommended interpretation:")); assert!(prompt.contains( "Use this as the default final text unless another candidate is clearly better." @@ -1415,6 +1411,7 @@ mod tests { assert!( prompt.contains("Prefer the smallest replacement scope that yields a coherent result.") ); + assert!(prompt.contains("Structured cue context")); assert!(prompt.contains("- preferred_candidate")); assert!(prompt.contains( "- preferred_candidate literal (keep only if the cue was not actually an edit): Hi there, this is a test. Wait, no. Hi there." @@ -1430,6 +1427,65 @@ mod tests { assert!(prompt.contains("Recent segments")); } + #[test] + fn cue_prompt_preserves_courtesy_prefixed_opening_guidance() { + let transcript = RewriteTranscript { + raw_text: "My apologies, I meant jonatan.".into(), + correction_aware_text: "Jonatan.".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: vec![RewriteEditIntent { + action: RewriteEditAction::ReplacePreviousPhrase, + trigger: "i meant".into(), + confidence: RewriteIntentConfidence::High, + }], + edit_signals: vec![RewriteEditSignal { + trigger: "i meant".into(), + kind: RewriteEditSignalKind::Replace, + scope_hint: RewriteEditSignalScope::Phrase, + strength: RewriteEditSignalStrength::Strong, + }], + edit_hypotheses: vec![RewriteEditHypothesis { + cue_family: "i_meant".into(), + matched_text: "i meant".into(), + match_source: RewriteEditHypothesisMatchSource::Exact, + kind: RewriteEditSignalKind::Replace, + scope_hint: RewriteEditSignalScope::Phrase, + replacement_scope: RewriteReplacementScope::Span, + tail_shape: RewriteTailShape::Phrase, + strength: RewriteEditSignalStrength::Strong, + }], + rewrite_candidates: vec![ + RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "My apologies, I meant jonatan.".into(), + }, + RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "Jonatan.".into(), + }, + ], + recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext { + cue_is_utterance_initial: true, + preceding_content_word_count: 0, + courtesy_prefix_detected: true, + has_recent_same_focus_entry: false, + recommended_session_action_is_replace: false, + }, + policy_context: RewritePolicyContext::default(), + }; + + let prompt = build_user_message(&transcript); + assert!(prompt.contains("preserve that courtesy wording")); + assert!(prompt.contains("courtesy_prefix_detected: yes")); + } + #[test] fn cue_prompt_includes_aggressive_candidate_when_available() { let mut transcript = correction_transcript(); @@ -1479,6 +1535,7 @@ mod tests { text: "Hi there.".into(), }], recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), }; @@ -1511,6 +1568,7 @@ mod tests { text: "hi there".into(), }], recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), }; assert_eq!(effective_max_tokens(256, &short), 48); @@ -1533,6 +1591,7 @@ mod tests { text: "word ".repeat(80), }], recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), }; assert_eq!(effective_max_tokens(256, &long), 184); @@ -1600,6 +1659,7 @@ mod tests { kind: RewriteCandidateKind::SentenceReplacement, text: "Hi".into(), }), + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: RewritePolicyContext::default(), }; diff --git a/src/rewrite_local.rs b/src/rewrite_local.rs new file mode 100644 index 0000000..1f7b701 --- /dev/null +++ b/src/rewrite_local.rs @@ -0,0 +1,293 @@ +use std::num::NonZeroU32; +use std::path::Path; +use std::sync::OnceLock; + +use encoding_rs::UTF_8; +use llama_cpp_2::context::params::LlamaContextParams; +use llama_cpp_2::llama_backend::LlamaBackend; +use llama_cpp_2::llama_batch::LlamaBatch; +use llama_cpp_2::model::params::LlamaModelParams; +use llama_cpp_2::model::{AddBos, LlamaChatMessage, LlamaChatTemplate, LlamaModel}; +use llama_cpp_2::openai::OpenAIChatTemplateParams; +use llama_cpp_2::sampling::LlamaSampler; + +use crate::rewrite::{ + RewritePrompt, build_oaicompat_messages_json, build_prompt, effective_max_output_chars, + effective_max_tokens, sanitize_rewrite_output, +}; +use crate::rewrite_profile::ResolvedRewriteProfile; +use crate::rewrite_protocol::RewriteTranscript; + +#[allow(dead_code)] +pub struct LocalRewriter { + model: LlamaModel, + chat_template: LlamaChatTemplate, + profile: ResolvedRewriteProfile, + max_tokens: usize, + max_output_chars: usize, +} + +#[allow(dead_code)] +static LLAMA_BACKEND: OnceLock<&'static LlamaBackend> = OnceLock::new(); +#[allow(dead_code)] +static EXTERNAL_LLAMA_BACKEND: LlamaBackend = LlamaBackend {}; + +impl LocalRewriter { + #[allow(dead_code)] + pub fn new( + model_path: &Path, + profile: ResolvedRewriteProfile, + max_tokens: usize, + max_output_chars: usize, + ) -> std::result::Result { + if !model_path.exists() { + return Err(format!( + "rewrite model file not found: {}", + model_path.display() + )); + } + + let backend = llama_backend()?; + + let mut model_params = LlamaModelParams::default(); + if cfg!(feature = "cuda") { + model_params = model_params.with_n_gpu_layers(1000); + } + + let model = LlamaModel::load_from_file(backend, model_path, &model_params) + .map_err(|e| format!("failed to load rewrite model: {e}"))?; + let chat_template = model + .chat_template(None) + .map_err(|e| format!("rewrite model does not expose a usable chat template: {e}"))?; + + Ok(Self { + model, + chat_template, + profile, + max_tokens, + max_output_chars, + }) + } + + #[allow(dead_code)] + pub fn rewrite_with_instructions( + &self, + transcript: &RewriteTranscript, + custom_instructions: Option<&str>, + ) -> std::result::Result { + if transcript.raw_text.trim().is_empty() { + return Ok(String::new()); + } + + let prompt = build_rewrite_prompt( + &self.model, + &self.chat_template, + transcript, + self.profile, + custom_instructions, + )?; + let effective_max_tokens = effective_max_tokens(self.max_tokens, transcript); + let effective_max_output_chars = + effective_max_output_chars(self.max_output_chars, transcript); + let prompt_tokens = self + .model + .str_to_token(&prompt, AddBos::Never) + .map_err(|e| format!("failed to tokenize rewrite prompt: {e}"))?; + let behavior = rewrite_behavior(self.profile); + + let n_ctx_tokens = prompt_tokens + .len() + .saturating_add(effective_max_tokens) + .saturating_add(64) + .max(2048) + .min(u32::MAX as usize) as u32; + let n_batch = prompt_tokens + .len() + .max(512) + .min(n_ctx_tokens as usize) + .min(u32::MAX as usize) as u32; + let threads = std::thread::available_parallelism() + .map(|threads| threads.get()) + .unwrap_or(4) + .clamp(1, i32::MAX as usize) as i32; + + let ctx_params = LlamaContextParams::default() + .with_n_ctx(NonZeroU32::new(n_ctx_tokens)) + .with_n_batch(n_batch) + .with_n_ubatch(n_batch) + .with_n_threads(threads) + .with_n_threads_batch(threads); + let backend = llama_backend()?; + let mut ctx = self + .model + .new_context(backend, ctx_params) + .map_err(|e| format!("failed to create rewrite context: {e}"))?; + + let mut prompt_batch = LlamaBatch::new(prompt_tokens.len(), 1); + prompt_batch + .add_sequence(&prompt_tokens, 0, false) + .map_err(|e| format!("failed to enqueue rewrite prompt: {e}"))?; + ctx.decode(&mut prompt_batch) + .map_err(|e| format!("failed to decode rewrite prompt: {e}"))?; + + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::top_k(behavior.top_k), + LlamaSampler::top_p(behavior.top_p, 1), + LlamaSampler::temp(behavior.temperature), + LlamaSampler::greedy(), + ]); + sampler.accept_many(prompt_tokens.iter()); + + let mut decoder = UTF_8.new_decoder_without_bom_handling(); + let mut output = String::new(); + let start_pos = i32::try_from(prompt_tokens.len()).unwrap_or(i32::MAX); + + for i in 0..effective_max_tokens { + let mut candidates = ctx.token_data_array(); + candidates.apply_sampler(&sampler); + let token = candidates + .selected_token() + .ok_or_else(|| "rewrite sampler did not select a token".to_string())?; + + if token == self.model.token_eos() { + break; + } + + sampler.accept(token); + + let piece = self + .model + .token_to_piece(token, &mut decoder, true, None) + .map_err(|e| format!("failed to decode rewrite token: {e}"))?; + output.push_str(&piece); + + if output.contains("") || output.len() >= effective_max_output_chars { + break; + } + + let mut batch = LlamaBatch::new(1, 1); + batch + .add( + token, + start_pos.saturating_add(i32::try_from(i).unwrap_or(i32::MAX)), + &[0], + true, + ) + .map_err(|e| format!("failed to enqueue rewrite token: {e}"))?; + ctx.decode(&mut batch) + .map_err(|e| format!("failed to decode rewrite token: {e}"))?; + } + + let rewritten = sanitize_rewrite_output(&output); + if rewritten.is_empty() { + return Err("rewrite model returned empty output".into()); + } + + Ok(rewritten) + } +} + +fn llama_backend() -> std::result::Result<&'static LlamaBackend, String> { + if let Some(backend) = LLAMA_BACKEND.get().copied() { + return Ok(backend); + } + + match LlamaBackend::init() { + Ok(backend) => { + let backend = Box::leak(Box::new(backend)); + let _ = LLAMA_BACKEND.set(backend); + Ok(LLAMA_BACKEND + .get() + .copied() + .expect("llama backend initialized")) + } + Err(llama_cpp_2::LlamaCppError::BackendAlreadyInitialized) => { + let _ = LLAMA_BACKEND.set(&EXTERNAL_LLAMA_BACKEND); + Ok(LLAMA_BACKEND + .get() + .copied() + .expect("external llama backend cached")) + } + Err(err) => Err(format!("failed to initialize llama backend: {err}")), + } +} + +fn build_rewrite_prompt( + model: &LlamaModel, + chat_template: &LlamaChatTemplate, + transcript: &RewriteTranscript, + profile: ResolvedRewriteProfile, + custom_instructions: Option<&str>, +) -> std::result::Result { + let prompt = build_prompt(transcript, profile, custom_instructions)?; + if matches!(profile, ResolvedRewriteProfile::Qwen) { + return build_qwen_rewrite_prompt(model, chat_template, &prompt); + } + + let messages = vec![ + LlamaChatMessage::new("system".into(), prompt.system) + .map_err(|e| format!("failed to build rewrite system message: {e}"))?, + LlamaChatMessage::new("user".into(), prompt.user) + .map_err(|e| format!("failed to build rewrite user message: {e}"))?, + ]; + + model + .apply_chat_template(chat_template, &messages, true) + .map_err(|e| format!("failed to apply rewrite chat template: {e}")) +} + +fn build_qwen_rewrite_prompt( + model: &LlamaModel, + chat_template: &LlamaChatTemplate, + prompt: &RewritePrompt, +) -> std::result::Result { + let messages_json = build_oaicompat_messages_json(prompt)?; + let result = model + .apply_chat_template_oaicompat( + chat_template, + &OpenAIChatTemplateParams { + messages_json: &messages_json, + tools_json: None, + tool_choice: None, + json_schema: None, + grammar: None, + reasoning_format: None, + chat_template_kwargs: None, + add_generation_prompt: true, + use_jinja: true, + parallel_tool_calls: false, + enable_thinking: false, + add_bos: false, + add_eos: false, + parse_tool_calls: false, + }, + ) + .map_err(|e| format!("failed to apply Qwen rewrite chat template: {e}"))?; + Ok(result.prompt) +} + +struct RewriteBehavior { + top_k: i32, + top_p: f32, + temperature: f32, +} + +fn rewrite_behavior(profile: ResolvedRewriteProfile) -> RewriteBehavior { + match profile { + ResolvedRewriteProfile::Qwen => RewriteBehavior { + top_k: 24, + top_p: 0.9, + temperature: 0.1, + }, + ResolvedRewriteProfile::Generic => RewriteBehavior { + top_k: 32, + top_p: 0.92, + temperature: 0.12, + }, + ResolvedRewriteProfile::LlamaCompat => RewriteBehavior { + top_k: 40, + top_p: 0.95, + temperature: 0.15, + }, + } +} diff --git a/src/rewrite_model.rs b/src/rewrite_model.rs index c7cbb1d..dc70485 100644 --- a/src/rewrite_model.rs +++ b/src/rewrite_model.rs @@ -31,7 +31,7 @@ pub const REWRITE_MODELS: &[RewriteModelInfo] = &[ name: "qwen-3.5-4b-q4_k_m", filename: "Qwen_Qwen3.5-4B-Q4_K_M.gguf", size: "~2.9 GB", - description: "Recommended balance for advanced_local mode", + description: "Recommended balance for rewrite mode", url: "https://huggingface.co/bartowski/Qwen_Qwen3.5-4B-GGUF/resolve/main/Qwen_Qwen3.5-4B-Q4_K_M.gguf", profile: RewriteProfile::Qwen, }, @@ -293,7 +293,7 @@ mod tests { select_model("qwen-3.5-2b-q4_k_m", Some(&config_path)).expect("select model"); let loaded = Config::load(Some(&config_path)).expect("load config"); - assert_eq!(loaded.postprocess.mode, PostprocessMode::AdvancedLocal); + assert_eq!(loaded.postprocess.mode, PostprocessMode::Rewrite); assert_eq!(loaded.rewrite.selected_model, "qwen-3.5-2b-q4_k_m"); } diff --git a/src/rewrite_protocol.rs b/src/rewrite_protocol.rs index edacdc4..33b13ff 100644 --- a/src/rewrite_protocol.rs +++ b/src/rewrite_protocol.rs @@ -18,6 +18,8 @@ pub struct RewriteTranscript { pub rewrite_candidates: Vec, pub recommended_candidate: Option, #[serde(default)] + pub edit_context: RewriteEditContext, + #[serde(default)] pub policy_context: RewritePolicyContext, } @@ -90,6 +92,15 @@ pub struct RewriteCandidate { pub text: String, } +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteEditContext { + pub cue_is_utterance_initial: bool, + pub preceding_content_word_count: usize, + pub courtesy_prefix_detected: bool, + pub has_recent_same_focus_entry: bool, + pub recommended_session_action_is_replace: bool, +} + #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] pub struct RewritePolicyContext { pub correction_policy: RewriteCorrectionPolicy, diff --git a/src/session.rs b/src/session.rs index 4088081..7e5b2c5 100644 --- a/src/session.rs +++ b/src/session.rs @@ -551,6 +551,7 @@ mod tests { kind: crate::rewrite_protocol::RewriteCandidateKind::SentenceReplacement, text: "Hi".into(), }), + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: crate::rewrite_protocol::RewritePolicyContext::default(), }; @@ -606,6 +607,7 @@ mod tests { edit_hypotheses: Vec::new(), rewrite_candidates: Vec::new(), recommended_candidate: None, + edit_context: crate::rewrite_protocol::RewriteEditContext::default(), policy_context: crate::rewrite_protocol::RewritePolicyContext::default(), }; diff --git a/src/setup.rs b/src/setup.rs index 5a44de8..8de13a6 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -1,9 +1,10 @@ use std::path::Path; use crate::asr_model::{self, ASR_MODELS, AsrModelInfo}; +use crate::cli::CompletionShell; use crate::config::{ self, CloudLanguageMode, CloudProvider, CloudSettingsUpdate, PostprocessMode, RewriteBackend, - RewriteFallback, TranscriptionBackend, TranscriptionFallback, resolve_config_path, + RewriteFallback, TranscriptionBackend, TranscriptionFallback, VoiceConfig, resolve_config_path, }; use crate::error::Result; use crate::rewrite_model::{self, REWRITE_MODELS}; @@ -14,6 +15,7 @@ struct SetupSelections { rewrite_model: Option<&'static str>, postprocess_mode: PostprocessMode, cloud: CloudSetup, + voice: VoiceSetup, } struct CloudSetup { @@ -27,6 +29,13 @@ struct CloudSetup { rewrite_fallback: RewriteFallback, } +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +struct VoiceSetup { + enabled: bool, + live_inject: bool, + live_rewrite: bool, +} + impl Default for CloudSetup { fn default() -> Self { Self { @@ -80,12 +89,15 @@ pub async fn run_setup(config_path_override: Option<&Path>) -> Result<()> { if cloud.rewrite_enabled && postprocess_mode == PostprocessMode::Raw { postprocess_mode = choose_rewrite_mode(&ui)?; } + let voice = configure_voice(&ui, postprocess_mode)?; + let completion_shells = choose_completion_shells(&ui)?; let selections = SetupSelections { asr_model, rewrite_model, postprocess_mode, cloud, + voice, }; let config_path = resolve_config_path(config_path_override); @@ -103,8 +115,10 @@ pub async fn run_setup(config_path_override: Option<&Path>) -> Result<()> { &selections.cloud, )?; apply_cloud_settings(&ui, &config_path, &selections.cloud)?; + apply_voice_selection(&ui, &config_path, &selections.voice)?; maybe_create_agentic_starter_files(&ui, &config_path, &selections)?; cleanup_stale_asr_workers(&ui, &config_path)?; + maybe_install_shell_completions(&ui, &completion_shells)?; if let Some(rewrite_model) = selections.rewrite_model { ui.print_ok(format!( @@ -119,7 +133,7 @@ pub async fn run_setup(config_path_override: Option<&Path>) -> Result<()> { ui.blank(); print_setup_summary(&ui, &selections); ui.blank(); - print_setup_complete(&ui); + print_setup_complete(&ui, &selections); Ok(()) } @@ -204,6 +218,21 @@ fn apply_cloud_settings(ui: &SetupUi, config_path: &Path, cloud: &CloudSetup) -> Ok(()) } +fn apply_voice_selection(ui: &SetupUi, config_path: &Path, voice: &VoiceSetup) -> Result<()> { + let voice_config = VoiceConfig { + live_inject: voice.enabled && voice.live_inject, + live_rewrite: voice.enabled && voice.live_rewrite, + ..Default::default() + }; + config::update_config_voice_settings(config_path, &voice_config)?; + + if !voice.enabled { + ui.print_info("Live voice mode: disabled."); + } + + Ok(()) +} + fn maybe_prewarm_experimental_nemo( ui: &SetupUi, config_path: &Path, @@ -303,18 +332,142 @@ async fn choose_rewrite_model( } fn choose_rewrite_mode(ui: &SetupUi) -> Result { + let items = ["rewrite: LLM-based rewrite with policy and technical glossary support"]; + let _selection = ui.select("Choose the rewrite mode", &items, 0)?; + Ok(PostprocessMode::Rewrite) +} + +fn configure_voice(ui: &SetupUi, postprocess_mode: PostprocessMode) -> Result { + if !ui.confirm("Enable experimental live voice mode?", false)? { + return Ok(VoiceSetup::default()); + } + let items = [ - "advanced_local: smart rewrite cleanup with current bounded-candidate behavior", - "agentic_rewrite: app-aware rewrite with policy and technical glossary support", + "Preview-only: live transcript OSD while recording, final insert on stop", + "Live inject: update the target app while recording (freezes on focus change)", ]; - let selection = ui.select("Choose the rewrite mode", &items, 1)?; - Ok(if selection == 0 { - PostprocessMode::AdvancedLocal + let selection = ui.select("Choose the live voice behavior", &items, 0)?; + let live_inject = selection == 1; + let live_rewrite = if postprocess_mode.uses_rewrite() { + ui.confirm( + "Show live rewrite preview in the OSD while recording?", + false, + )? } else { - PostprocessMode::AgenticRewrite + ui.print_info("Live rewrite preview needs a rewrite-enabled postprocess mode."); + false + }; + + Ok(VoiceSetup { + enabled: true, + live_inject, + live_rewrite, + }) +} + +fn choose_completion_shells(ui: &SetupUi) -> Result> { + let detected_shells = crate::completions::detect_installed_shells(); + let current_shell = crate::completions::detect_shell(); + + if detected_shells.is_empty() { + ui.print_info("Could not find any supported shells on PATH automatically."); + if !ui.confirm("Install shell completions anyway?", false)? { + return Ok(Vec::new()); + } + return Ok(vec![choose_shell_manually(ui)?]); + } + + let mut detected_names = detected_shells + .iter() + .map(|shell| shell.as_str()) + .collect::>() + .join(", "); + if let Some(shell) = current_shell { + detected_names.push_str(&format!(" (current shell hint: {})", shell.as_str())); + } + ui.print_info(format!("Detected supported shells: {detected_names}.")); + + if !ui.confirm("Install shell completions?", true)? { + return Ok(Vec::new()); + } + + if detected_shells.len() == 1 { + return Ok(detected_shells); + } + + let mut items = vec!["all detected shells".to_string()]; + items.extend( + detected_shells + .iter() + .map(|shell| shell_choice_label(*shell, current_shell)), + ); + let default = current_shell + .and_then(|shell| { + detected_shells + .iter() + .position(|candidate| *candidate == shell) + }) + .map_or(0, |index| index + 1); + let selection = ui.select("Choose shells for completion install", &items, default)?; + if selection == 0 { + return Ok(detected_shells); + } + + Ok(vec![detected_shells[selection - 1]]) +} + +fn choose_shell_manually(ui: &SetupUi) -> Result { + let items = ["bash", "zsh", "fish", "nushell"]; + let selection = ui.select("Choose a shell for completions", &items, 0)?; + Ok(match selection { + 0 => CompletionShell::Bash, + 1 => CompletionShell::Zsh, + 2 => CompletionShell::Fish, + _ => CompletionShell::Nushell, }) } +fn shell_choice_label(shell: CompletionShell, current_shell: Option) -> String { + if current_shell == Some(shell) { + format!("{} (current shell)", shell.as_str()) + } else { + shell.as_str().to_string() + } +} + +fn maybe_install_shell_completions(ui: &SetupUi, shells: &[CompletionShell]) -> Result<()> { + if shells.is_empty() { + return Ok(()); + } + + for shell in shells { + match crate::completions::install_completions(*shell) { + Ok(path) => { + ui.print_ok(format!( + "Installed {} completions at {}.", + crate::ui::value(shell.as_str()), + crate::ui::value(path.display().to_string()) + )); + if let Some(note) = crate::completions::install_note(*shell, &path) { + ui.print_info(note); + } + } + Err(err) => { + ui.print_warn(format!( + "Failed to install {} completions automatically: {err}", + shell.as_str() + )); + ui.print_info(format!( + "You can still run {} later.", + crate::ui::value(format!("whispers completions {}", shell.as_str())) + )); + } + } + } + + Ok(()) +} + async fn configure_cloud( ui: &SetupUi, rewrite_model: &mut Option<&'static str>, @@ -458,14 +611,46 @@ fn print_setup_summary(ui: &SetupUi, selections: &SetupSelections) { crate::ui::summary_key("NeMo note") ); } + + match ( + selections.voice.enabled, + selections.voice.live_inject, + selections.voice.live_rewrite, + ) { + (false, _, _) => println!(" {}: disabled", crate::ui::summary_key("Voice")), + (true, false, false) => println!( + " {}: preview-only via {}", + crate::ui::summary_key("Voice"), + crate::ui::value("whispers voice") + ), + (true, true, false) => println!( + " {}: live inject via {}", + crate::ui::summary_key("Voice"), + crate::ui::value("whispers voice") + ), + (true, false, true) => println!( + " {}: preview-only via {} with live rewrite preview", + crate::ui::summary_key("Voice"), + crate::ui::value("whispers voice") + ), + (true, true, true) => println!( + " {}: live inject via {} with live rewrite preview", + crate::ui::summary_key("Voice"), + crate::ui::value("whispers voice") + ), + } } -fn print_setup_complete(ui: &SetupUi) { +fn print_setup_complete(ui: &SetupUi, selections: &SetupSelections) { ui.print_header("Setup complete"); println!("You can now use whispers."); ui.print_section("Example keybind"); ui.print_subtle("Bind it to a key in your compositor, e.g. for Hyprland:"); println!(" bind = SUPER ALT, D, exec, whispers"); + if selections.voice.enabled { + println!(" bind = SUPER ALT, V, exec, whispers voice"); + ui.print_subtle("Voice mode is separate so you can keep the existing dictation flow."); + } } fn apply_runtime_backend_selection( @@ -508,7 +693,7 @@ fn maybe_create_agentic_starter_files( config_path: &Path, selections: &SetupSelections, ) -> Result<()> { - if selections.postprocess_mode != PostprocessMode::AgenticRewrite { + if !selections.postprocess_mode.uses_rewrite() { return Ok(()); } @@ -550,4 +735,64 @@ mod tests { TranscriptionFallback::ConfiguredLocal ); } + + #[test] + fn apply_voice_selection_persists_voice_defaults_and_toggles() { + let config_path = crate::test_support::unique_temp_path("setup-voice-selection", "toml"); + config::write_default_config(&config_path, "~/model.bin").expect("write config"); + + let ui = SetupUi::new(); + let voice = VoiceSetup { + enabled: true, + live_inject: true, + live_rewrite: true, + }; + apply_voice_selection(&ui, &config_path, &voice).expect("apply voice selection"); + + let config = Config::load(Some(&config_path)).expect("load config"); + assert!(config.voice.live_inject); + assert!(config.voice.live_rewrite); + assert_eq!( + config.voice.partial_interval_ms, + VoiceConfig::default().partial_interval_ms + ); + assert_eq!( + config.voice.freeze_on_focus_change, + VoiceConfig::default().freeze_on_focus_change + ); + } + + #[test] + fn maybe_install_shell_completions_writes_fish_script() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&[ + "HOME", + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + ]); + let root = crate::test_support::unique_temp_dir("setup-shell-completions"); + crate::test_support::set_env("HOME", &root.to_string_lossy()); + crate::test_support::remove_env("XDG_CONFIG_HOME"); + crate::test_support::remove_env("XDG_DATA_HOME"); + + let ui = SetupUi::new(); + maybe_install_shell_completions(&ui, &[CompletionShell::Fish]) + .expect("install fish completions"); + + let path = root.join(".config/fish/completions/whispers.fish"); + let contents = std::fs::read_to_string(path).expect("read fish completions"); + assert!(contents.contains("complete -c whispers")); + } + + #[test] + fn shell_choice_label_marks_current_shell() { + assert_eq!( + shell_choice_label(CompletionShell::Fish, Some(CompletionShell::Fish)), + "fish (current shell)" + ); + assert_eq!( + shell_choice_label(CompletionShell::Bash, Some(CompletionShell::Fish)), + "bash" + ); + } } diff --git a/src/status.rs b/src/status.rs new file mode 100644 index 0000000..c8e4195 --- /dev/null +++ b/src/status.rs @@ -0,0 +1,673 @@ +use std::fmt::Write; +use std::path::{Path, PathBuf}; + +use console::style; + +use crate::config::{self, CloudProvider, Config, RewriteBackend, TranscriptionBackend}; +use crate::error::Result; + +pub fn print_status(config_path_override: Option<&Path>) -> Result<()> { + let config_path = config::resolve_config_path(config_path_override); + let config_exists = config_path.exists(); + let config = Config::load(Some(&config_path))?; + print!("{}", render_status(&config_path, config_exists, &config)); + Ok(()) +} + +fn render_status(config_path: &Path, config_exists: bool, config: &Config) -> String { + let mut out = String::new(); + + let _ = writeln!(out, "{}", crate::ui::header("Whispers status")); + push_section(&mut out, "Config"); + push_path_field(&mut out, "path", config_path); + push_field( + &mut out, + "source", + if config_exists { + "config file" + } else { + "defaults (config file missing)" + }, + ValueStyle::Status, + ); + push_path_field(&mut out, "data_dir", &config::data_dir()); + + push_section(&mut out, "Transcription"); + push_field( + &mut out, + "backend", + config.transcription.backend.as_str(), + ValueStyle::Backend, + ); + push_field( + &mut out, + "local_backend", + config.transcription.resolved_local_backend().as_str(), + ValueStyle::Backend, + ); + push_field( + &mut out, + "fallback", + config.transcription.fallback.as_str(), + ValueStyle::Value, + ); + push_field( + &mut out, + "selected_model", + &config.transcription.selected_model, + ValueStyle::Value, + ); + push_field( + &mut out, + "model_status", + transcription_model_status(config), + ValueStyle::Status, + ); + push_path_field(&mut out, "model_path", &config.resolved_model_path()); + push_field( + &mut out, + "language", + &config.transcription.language, + ValueStyle::Value, + ); + push_field( + &mut out, + "use_gpu", + yes_no(config.transcription.use_gpu), + ValueStyle::Boolean, + ); + push_field( + &mut out, + "flash_attn", + yes_no(config.transcription.flash_attn), + ValueStyle::Boolean, + ); + push_field( + &mut out, + "idle_timeout_ms", + config.transcription.idle_timeout_ms.to_string(), + ValueStyle::Value, + ); + + push_section(&mut out, "Postprocess"); + push_field( + &mut out, + "mode", + config.postprocess.mode.as_str(), + ValueStyle::Value, + ); + push_field( + &mut out, + "rewrite_enabled", + yes_no(config.postprocess.mode.uses_rewrite()), + ValueStyle::Boolean, + ); + + push_section(&mut out, "Rewrite"); + push_field( + &mut out, + "backend", + config.rewrite.backend.as_str(), + if config.rewrite.backend == RewriteBackend::Cloud { + ValueStyle::Provider + } else { + ValueStyle::Backend + }, + ); + push_field( + &mut out, + "fallback", + config.rewrite.fallback.as_str(), + ValueStyle::Value, + ); + push_field( + &mut out, + "selected_model", + &config.rewrite.selected_model, + ValueStyle::Value, + ); + push_field( + &mut out, + "local_model_status", + rewrite_model_status(config), + ValueStyle::Status, + ); + push_field( + &mut out, + "local_model_path", + optional_path_display(resolve_rewrite_model_path(config).as_deref()), + ValueStyle::Path, + ); + push_field( + &mut out, + "profile", + config.rewrite.profile.as_str(), + ValueStyle::Value, + ); + push_field( + &mut out, + "instructions_status", + path_presence_status(config.resolved_rewrite_instructions_path().as_deref()), + ValueStyle::Status, + ); + push_field( + &mut out, + "instructions_path", + optional_path_display(config.resolved_rewrite_instructions_path().as_deref()), + ValueStyle::Path, + ); + push_field( + &mut out, + "timeout_ms", + config.rewrite.timeout_ms.to_string(), + ValueStyle::Value, + ); + push_field( + &mut out, + "idle_timeout_ms", + config.rewrite.idle_timeout_ms.to_string(), + ValueStyle::Value, + ); + push_field( + &mut out, + "max_output_chars", + config.rewrite.max_output_chars.to_string(), + ValueStyle::Value, + ); + push_field( + &mut out, + "max_tokens", + config.rewrite.max_tokens.to_string(), + ValueStyle::Value, + ); + + push_section(&mut out, "Rewrite Policy"); + push_field( + &mut out, + "enabled", + yes_no(config.postprocess.mode.uses_rewrite()), + ValueStyle::Boolean, + ); + push_field( + &mut out, + "default_correction_policy", + config.rewrite.default_correction_policy.as_str(), + ValueStyle::Value, + ); + push_field( + &mut out, + "policy_status", + path_presence_status(Some(config.resolved_rewrite_policy_path().as_path())), + ValueStyle::Status, + ); + push_path_field( + &mut out, + "policy_path", + &config.resolved_rewrite_policy_path(), + ); + push_field( + &mut out, + "glossary_status", + path_presence_status(Some(config.resolved_rewrite_glossary_path().as_path())), + ValueStyle::Status, + ); + push_path_field( + &mut out, + "glossary_path", + &config.resolved_rewrite_glossary_path(), + ); + + push_section(&mut out, "Cloud"); + push_field( + &mut out, + "transcription_active", + yes_no(config.transcription.backend == TranscriptionBackend::Cloud), + ValueStyle::Boolean, + ); + push_field( + &mut out, + "rewrite_active", + yes_no( + config.postprocess.mode.uses_rewrite() + && config.rewrite.backend == RewriteBackend::Cloud, + ), + ValueStyle::Boolean, + ); + push_field( + &mut out, + "provider", + config.cloud.provider.as_str(), + ValueStyle::Provider, + ); + push_field( + &mut out, + "base_url", + cloud_base_url_display(config), + ValueStyle::Path, + ); + push_field( + &mut out, + "api_key", + cloud_api_key_status(config), + ValueStyle::Status, + ); + push_field( + &mut out, + "transcription_model", + &config.cloud.transcription.model, + ValueStyle::Value, + ); + push_field( + &mut out, + "language_mode", + config.cloud.transcription.language_mode.as_str(), + ValueStyle::Value, + ); + push_field( + &mut out, + "forced_language", + if config.cloud.transcription.language.trim().is_empty() { + "(none)" + } else { + config.cloud.transcription.language.as_str() + }, + ValueStyle::Value, + ); + push_field( + &mut out, + "rewrite_model", + &config.cloud.rewrite.model, + ValueStyle::Value, + ); + push_field( + &mut out, + "rewrite_temperature", + format!("{:.2}", config.cloud.rewrite.temperature), + ValueStyle::Value, + ); + push_field( + &mut out, + "rewrite_max_output_tokens", + config.cloud.rewrite.max_output_tokens.to_string(), + ValueStyle::Value, + ); + push_field( + &mut out, + "connect_timeout_ms", + config.cloud.connect_timeout_ms.to_string(), + ValueStyle::Value, + ); + push_field( + &mut out, + "request_timeout_ms", + config.cloud.request_timeout_ms.to_string(), + ValueStyle::Value, + ); + + push_section(&mut out, "Personalization"); + push_field( + &mut out, + "dictionary_status", + path_presence_status(Some(config.resolved_dictionary_path().as_path())), + ValueStyle::Status, + ); + push_path_field( + &mut out, + "dictionary_path", + &config.resolved_dictionary_path(), + ); + push_field( + &mut out, + "snippets_status", + path_presence_status(Some(config.resolved_snippets_path().as_path())), + ValueStyle::Status, + ); + push_path_field(&mut out, "snippets_path", &config.resolved_snippets_path()); + push_field( + &mut out, + "snippet_trigger", + &config.personalization.snippet_trigger, + ValueStyle::Value, + ); + + push_section(&mut out, "Voice"); + push_field( + &mut out, + "live_inject", + yes_no(config.voice.live_inject), + ValueStyle::Boolean, + ); + push_field( + &mut out, + "live_rewrite", + yes_no(config.voice.live_rewrite), + ValueStyle::Boolean, + ); + push_field( + &mut out, + "partial_interval_ms", + config.voice.partial_interval_ms.to_string(), + ValueStyle::Value, + ); + push_field( + &mut out, + "rewrite_interval_ms", + config.voice.rewrite_interval_ms.to_string(), + ValueStyle::Value, + ); + push_field( + &mut out, + "context_window_ms", + config.voice.context_window_ms.to_string(), + ValueStyle::Value, + ); + push_field( + &mut out, + "min_chunk_ms", + config.voice.min_chunk_ms.to_string(), + ValueStyle::Value, + ); + push_field( + &mut out, + "freeze_on_focus_change", + yes_no(config.voice.freeze_on_focus_change), + ValueStyle::Boolean, + ); + + out +} + +fn push_section(out: &mut String, name: &str) { + let _ = writeln!(out, "\n{}", crate::ui::section(name)); +} + +fn push_field(out: &mut String, label: &str, value: impl AsRef, style: ValueStyle) { + let _ = writeln!( + out, + " {}: {}", + crate::ui::summary_key(label), + style_value(style, value.as_ref()) + ); +} + +fn yes_no(value: bool) -> &'static str { + if value { "yes" } else { "no" } +} + +fn push_path_field(out: &mut String, label: &str, path: &Path) { + push_field(out, label, path.display().to_string(), ValueStyle::Path); +} + +fn path_presence_status(path: Option<&Path>) -> &'static str { + match path { + Some(path) if path.exists() => "present", + Some(_) => "missing", + None => "not configured", + } +} + +fn optional_path_display(path: Option<&Path>) -> String { + path.map(|path| path.display().to_string()) + .unwrap_or_else(|| "(not configured)".into()) +} + +#[derive(Clone, Copy)] +enum ValueStyle { + Value, + Boolean, + Status, + Backend, + Provider, + Path, +} + +fn style_value(style_kind: ValueStyle, value: &str) -> String { + match style_kind { + ValueStyle::Value => crate::ui::value(value), + ValueStyle::Boolean => bool_token(value), + ValueStyle::Status => status_value_token(value), + ValueStyle::Backend => crate::ui::backend_token(value), + ValueStyle::Provider => crate::ui::provider_token(value), + ValueStyle::Path => crate::ui::subtle(value), + } +} + +fn bool_token(value: &str) -> String { + match value.trim() { + "yes" => style(value).bold().green().to_string(), + "no" => style(value).bold().red().to_string(), + _ => crate::ui::value(value), + } +} + +fn status_value_token(value: &str) -> String { + match status_category(value) { + StatusCategory::Good => style(value).bold().green().to_string(), + StatusCategory::Warn => style(value).bold().yellow().to_string(), + StatusCategory::Bad => style(value).bold().red().to_string(), + StatusCategory::Info => crate::ui::provider_token(value), + StatusCategory::Neutral => crate::ui::value(value), + StatusCategory::Subtle => crate::ui::subtle(value), + } +} + +fn status_category(value: &str) -> StatusCategory { + match value.trim() { + "ready" | "present" | "set" | "config file" | "config value set" => StatusCategory::Good, + "cloud" => StatusCategory::Info, + "missing" | "env:OPENAI_API_KEY (missing)" => StatusCategory::Bad, + value if value.contains("(missing)") => StatusCategory::Warn, + "not configured" | "defaults (config file missing)" => StatusCategory::Warn, + _ if value.starts_with("env:") && value.ends_with("(set)") => StatusCategory::Good, + _ if value.starts_with("https://") || value.starts_with("http://") => { + StatusCategory::Subtle + } + _ => StatusCategory::Neutral, + } +} + +#[derive(Clone, Copy)] +enum StatusCategory { + Good, + Warn, + Bad, + Info, + Neutral, + Subtle, +} + +fn transcription_model_status(config: &Config) -> &'static str { + let model_path = config.resolved_model_path(); + match config.transcription.resolved_local_backend() { + TranscriptionBackend::WhisperCpp => { + if model_path.exists() { + "ready" + } else { + "missing" + } + } + TranscriptionBackend::FasterWhisper => { + if crate::faster_whisper::model_dir_is_ready(&model_path) { + "ready" + } else { + "missing" + } + } + TranscriptionBackend::Nemo => { + if crate::nemo_asr::model_dir_is_ready(&model_path) { + "ready" + } else { + "missing" + } + } + TranscriptionBackend::Cloud => "cloud", + } +} + +fn resolve_rewrite_model_path(config: &Config) -> Option { + if let Some(path) = config.resolved_rewrite_model_path() { + return Some(path); + } + crate::rewrite_model::selected_model_path(&config.rewrite.selected_model) +} + +fn rewrite_model_status(config: &Config) -> &'static str { + let Some(model_path) = resolve_rewrite_model_path(config) else { + return "not configured"; + }; + if model_path.exists() { + "ready" + } else { + "missing" + } +} + +fn cloud_base_url_display(config: &Config) -> String { + let configured = config.cloud.base_url.trim(); + if !configured.is_empty() { + return configured.to_string(); + } + + match config.cloud.provider { + CloudProvider::OpenAi => "https://api.openai.com/v1 (default)".into(), + CloudProvider::OpenAiCompatible => "(missing for openai_compatible)".into(), + } +} + +fn cloud_api_key_status(config: &Config) -> String { + if !config.cloud.api_key.trim().is_empty() { + return "config value set".into(); + } + + let env_name = config.cloud.api_key_env.trim(); + if env_name.is_empty() { + return "missing".into(); + } + + let env_status = if std::env::var_os(env_name).is_some() { + "set" + } else { + "missing" + }; + format!("env:{env_name} ({env_status})") +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::config::{PostprocessMode, RewriteBackend, TranscriptionBackend}; + + #[test] + fn render_status_reports_selected_runtime_settings() { + let _env_lock = crate::test_support::env_lock(); + let temp = crate::test_support::unique_temp_dir("status-render"); + let config_path = temp.join("config.toml"); + + let asr_model = temp.join("ggml-large-v3-turbo.bin"); + std::fs::write(&asr_model, "bin").unwrap(); + + let rewrite_model = temp.join("rewrite.gguf"); + std::fs::write(&rewrite_model, "gguf").unwrap(); + + let instructions_path = temp.join("rewrite-instructions.txt"); + std::fs::write(&instructions_path, "Keep terms exact.").unwrap(); + + let policy_path = temp.join("app-rewrite-policy.toml"); + std::fs::write(&policy_path, "").unwrap(); + + let glossary_path = temp.join("technical-glossary.toml"); + std::fs::write(&glossary_path, "").unwrap(); + + let dictionary_path = temp.join("dictionary.toml"); + std::fs::write(&dictionary_path, "").unwrap(); + + let snippets_path = temp.join("snippets.toml"); + std::fs::write(&snippets_path, "").unwrap(); + + crate::test_support::set_env("WHISPERS_STATUS_TEST_KEY", "secret"); + + let mut config = Config::default(); + config.transcription.backend = TranscriptionBackend::WhisperCpp; + config.transcription.local_backend = TranscriptionBackend::WhisperCpp; + config.transcription.model_path = asr_model.display().to_string(); + config.transcription.selected_model = "large-v3-turbo".into(); + config.postprocess.mode = PostprocessMode::Rewrite; + config.rewrite.backend = RewriteBackend::Local; + config.rewrite.model_path = rewrite_model.display().to_string(); + config.rewrite.instructions_path = instructions_path.display().to_string(); + config.rewrite.policy_path = policy_path.display().to_string(); + config.rewrite.glossary_path = glossary_path.display().to_string(); + config.personalization.dictionary_path = dictionary_path.display().to_string(); + config.personalization.snippets_path = snippets_path.display().to_string(); + config.cloud.api_key = String::new(); + config.cloud.api_key_env = "WHISPERS_STATUS_TEST_KEY".into(); + config.voice.live_inject = true; + config.voice.live_rewrite = true; + + let rendered = render_status(&config_path, true, &config); + assert!(rendered.contains("Whispers status")); + assert!(rendered.contains(&crate::ui::section("Config"))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("source"), + style_value(ValueStyle::Status, "config file") + ))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("backend"), + style_value(ValueStyle::Backend, "whisper_cpp") + ))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("model_status"), + style_value(ValueStyle::Status, "ready") + ))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("mode"), + style_value(ValueStyle::Value, "rewrite") + ))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("default_correction_policy"), + style_value(ValueStyle::Value, "balanced") + ))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("api_key"), + style_value(ValueStyle::Status, "env:WHISPERS_STATUS_TEST_KEY (set)") + ))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("live_inject"), + style_value(ValueStyle::Boolean, "yes") + ))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("live_rewrite"), + style_value(ValueStyle::Boolean, "yes") + ))); + } + + #[test] + fn render_status_marks_missing_optional_files() { + let config = Config::default(); + let rendered = render_status(Path::new("/tmp/whispers-status.toml"), false, &config); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("source"), + style_value(ValueStyle::Status, "defaults (config file missing)") + ))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("policy_status"), + style_value(ValueStyle::Status, "missing") + ))); + assert!(rendered.contains(&format!( + "{}: {}", + crate::ui::summary_key("glossary_status"), + style_value(ValueStyle::Status, "missing") + ))); + } +} diff --git a/src/voice.rs b/src/voice.rs new file mode 100644 index 0000000..8fe393f --- /dev/null +++ b/src/voice.rs @@ -0,0 +1,1126 @@ +use std::time::{Duration, Instant}; + +use tokio::time::MissedTickBehavior; +use unicode_segmentation::UnicodeSegmentation; + +use crate::asr::{self, LiveTranscriber}; +use crate::audio::{self, AudioRecorder}; +use crate::config::Config; +use crate::context::{self, TypingContext}; +use crate::error::Result; +use crate::feedback::FeedbackPlayer; +use crate::inject::{InjectionPolicy, TextInjector}; +use crate::osd::{OsdHandle, OsdMode}; +use crate::osd_protocol::{VoiceOsdStatus, VoiceOsdUpdate}; +use crate::postprocess::{self, FinalizedOperation, FinalizedTranscript}; +use crate::rewrite_worker::RewriteService; +use crate::session::{self, EligibleSessionEntry}; +use crate::transcribe::Transcript; + +const UNSTABLE_TAIL_MS: u32 = 3500; +const LIVE_MIN_TRANSCRIBE_DELTA_MS: u64 = 220; +const LIVE_SILENCE_WINDOW_MS: u64 = 450; +const LIVE_SILENCE_RMS_THRESHOLD: f32 = 0.0022; +const LIVE_SILENCE_SETTLE_MS: u64 = 220; + +pub async fn run(config: Config) -> Result<()> { + let activation_started = Instant::now(); + let mut sigusr1 = + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined1())?; + let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?; + + let feedback = FeedbackPlayer::new( + config.feedback.enabled, + &config.feedback.start_sound, + &config.feedback.stop_sound, + ); + + feedback.play_start(); + let recording_context = context::capture_typing_context(); + let session_enabled = config.postprocess.mode.uses_rewrite(); + let recent_session = if session_enabled { + session::load_recent_entry(&config.session, &recording_context)? + } else { + None + }; + let replaceable_prefix_graphemes = recent_session + .as_ref() + .map(|entry| entry.delete_graphemes) + .unwrap_or(0); + + let mut recorder = AudioRecorder::new(&config.audio); + recorder.start()?; + + let mut osd = OsdHandle::spawn(OsdMode::Voice); + let mut accumulator = VoiceTranscriptAccumulator::default(); + let mut live_preview_pacing = LivePreviewPacing::default(); + let mut rewrite_preview = None::; + let mut live_injection = LiveInjectionState::new( + config.voice.live_inject, + config.voice.freeze_on_focus_change, + &recording_context, + replaceable_prefix_graphemes, + ); + osd.send_voice_update(&build_osd_update( + VoiceOsdStatus::Listening, + &accumulator, + rewrite_preview.as_deref(), + &live_injection, + config.voice.live_inject, + )); + + tracing::info!("voice recording... (run whispers voice again to stop)"); + + let transcriber = asr::prepare_live_transcriber(&config).await?; + let rewrite_service = postprocess::prepare_rewrite_service(&config); + asr::prewarm_live_transcriber(&transcriber, "voice recording"); + if let Some(service) = rewrite_service.as_ref() { + postprocess::prewarm_rewrite_service(service, "voice recording"); + } + + let partial_interval_ms = config.voice.partial_interval_ms.max(50); + let mut partial_tick = tokio::time::interval(Duration::from_millis(partial_interval_ms)); + partial_tick.set_missed_tick_behavior(MissedTickBehavior::Skip); + partial_tick.tick().await; + let rewrite_interval = Duration::from_millis(config.voice.rewrite_interval_ms.max(1)); + let mut last_rewrite_at = Instant::now() - rewrite_interval; + + loop { + tokio::select! { + _ = sigusr1.recv() => { + tracing::info!("toggle signal received, stopping voice recording"); + break; + } + _ = tokio::signal::ctrl_c() => { + tracing::info!("interrupted, cancelling voice recording"); + osd.kill(); + recorder.stop()?; + return Ok(()); + } + _ = sigterm.recv() => { + tracing::info!("terminated, cancelling voice recording"); + osd.kill(); + recorder.stop()?; + return Ok(()); + } + _ = partial_tick.tick() => { + if let Err(err) = process_partial_tick( + &config, + &recorder, + &transcriber, + rewrite_service.as_ref(), + recent_session.as_ref(), + &mut accumulator, + &mut live_preview_pacing, + &mut rewrite_preview, + &mut last_rewrite_at, + &mut live_injection, + &mut osd, + ).await { + tracing::warn!("live partial update failed: {err}"); + osd.send_voice_update(&build_osd_update( + if live_injection.is_frozen() { + VoiceOsdStatus::Frozen + } else { + VoiceOsdStatus::Listening + }, + &accumulator, + rewrite_preview.as_deref(), + &live_injection, + config.voice.live_inject, + )); + } + } + } + } + + let audio = recorder.stop()?; + feedback.play_stop(); + let sample_rate = config.audio.sample_rate; + let audio_duration_ms = audio_duration_ms(audio.len(), sample_rate); + osd.send_voice_update(&build_osd_update( + VoiceOsdStatus::Finalizing, + &accumulator, + rewrite_preview.as_deref(), + &live_injection, + config.voice.live_inject, + )); + + tracing::info!( + samples = audio.len(), + sample_rate, + audio_duration_ms, + "transcribing final voice-mode audio" + ); + + let transcript = asr::transcribe_live_audio(&config, &transcriber, audio, sample_rate).await?; + if transcript.is_empty() { + tracing::warn!("final voice-mode transcription returned empty text"); + postprocess::wait_for_feedback_drain().await; + osd.kill(); + return Ok(()); + } + + let injection_context = context::capture_typing_context(); + let recent_session = recent_session.filter(|entry| { + let same_focus = entry.entry.focus_fingerprint == injection_context.focus_fingerprint; + if !same_focus { + tracing::debug!( + previous_focus = entry.entry.focus_fingerprint, + current_focus = injection_context.focus_fingerprint, + "session backtrack blocked because focus changed before final voice injection" + ); + } + same_focus + }); + + let finalized = postprocess::finalize_transcript( + &config, + transcript, + rewrite_service.as_ref(), + Some(&injection_context), + recent_session.as_ref(), + ) + .await; + if finalized.text.is_empty() { + tracing::warn!("final voice-mode post-processing produced empty text"); + postprocess::wait_for_feedback_drain().await; + osd.kill(); + return Ok(()); + } + + let injector = TextInjector::new(); + let injection_applied = if config.voice.live_inject { + apply_final_live_output( + &injector, + &mut live_injection, + &injection_context, + &finalized, + ) + .await? + } else { + inject_final_output(&injector, &injection_context, &finalized).await?; + true + }; + + if injection_applied && session_enabled { + record_final_session(&config, &injection_context, &finalized)?; + } else if !injection_applied { + tracing::warn!("skipping session recording because final live injection was not applied"); + } + + tracing::info!( + total_elapsed_ms = activation_started.elapsed().as_millis(), + final_chars = finalized.text.len(), + "voice dictation pipeline finished" + ); + osd.kill(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +async fn process_partial_tick( + config: &Config, + recorder: &AudioRecorder, + transcriber: &LiveTranscriber, + rewrite_service: Option<&RewriteService>, + recent_session: Option<&EligibleSessionEntry>, + accumulator: &mut VoiceTranscriptAccumulator, + live_preview_pacing: &mut LivePreviewPacing, + rewrite_preview: &mut Option, + last_rewrite_at: &mut Instant, + live_injection: &mut LiveInjectionState, + osd: &mut OsdHandle, +) -> Result<()> { + let snapshot = recorder.snapshot()?; + let total_audio_ms = audio_duration_ms(snapshot.len(), config.audio.sample_rate); + if total_audio_ms < config.voice.min_chunk_ms { + return Ok(()); + } + let new_audio_ms = audio_duration_ms( + snapshot + .len() + .saturating_sub(live_preview_pacing.last_processed_samples), + config.audio.sample_rate, + ); + let recent_rms = recent_audio_rms(&snapshot, config.audio.sample_rate, LIVE_SILENCE_WINDOW_MS); + if recent_rms >= LIVE_SILENCE_RMS_THRESHOLD { + live_preview_pacing.last_voice_activity_at = Some(Instant::now()); + } else { + let settled_silence = live_preview_pacing + .last_voice_activity_at + .map(|at| at.elapsed() >= Duration::from_millis(LIVE_SILENCE_SETTLE_MS)) + .unwrap_or(true); + if settled_silence { + tracing::trace!( + total_audio_ms, + new_audio_ms, + recent_rms, + "skipping live partial update during settled silence" + ); + live_preview_pacing.last_processed_samples = snapshot.len(); + osd.send_voice_update(&build_osd_update( + if live_injection.is_frozen() { + VoiceOsdStatus::Frozen + } else { + VoiceOsdStatus::Listening + }, + accumulator, + rewrite_preview.as_deref(), + live_injection, + config.voice.live_inject, + )); + return Ok(()); + } + } + if new_audio_ms < LIVE_MIN_TRANSCRIBE_DELTA_MS { + return Ok(()); + } + + osd.send_voice_update(&build_osd_update( + VoiceOsdStatus::Transcribing, + accumulator, + rewrite_preview.as_deref(), + live_injection, + config.voice.live_inject, + )); + + let (mut chunk, chunk_start_ms) = clip_audio_tail( + &snapshot, + config.audio.sample_rate, + config.voice.context_window_ms, + ); + audio::preprocess_live_audio(&mut chunk, config.audio.sample_rate); + let mut transcript = + asr::transcribe_live_audio(config, transcriber, chunk, config.audio.sample_rate).await?; + if transcript.is_empty() { + tracing::debug!( + total_audio_ms, + new_audio_ms, + recent_rms, + "live partial transcription returned empty text; preserving previous preview" + ); + osd.send_voice_update(&build_osd_update( + if live_injection.is_frozen() { + VoiceOsdStatus::Frozen + } else { + VoiceOsdStatus::Listening + }, + accumulator, + rewrite_preview.as_deref(), + live_injection, + config.voice.live_inject, + )); + return Ok(()); + } + offset_transcript_segments(&mut transcript, chunk_start_ms); + live_preview_pacing.last_processed_samples = snapshot.len(); + + accumulator.update(&transcript, total_audio_ms as u32); + let live_preview_text = accumulator.full_text(); + let current_context = (config.voice.live_rewrite || config.voice.live_inject) + .then(context::capture_typing_context); + + if config.voice.live_rewrite + && !live_preview_text.is_empty() + && last_rewrite_at.elapsed() + >= Duration::from_millis(config.voice.rewrite_interval_ms.max(1)) + { + osd.send_voice_update(&build_osd_update( + VoiceOsdStatus::Rewriting, + accumulator, + rewrite_preview.as_deref(), + live_injection, + config.voice.live_inject, + )); + let preview_transcript = + build_live_rewrite_transcript(accumulator, &live_preview_text, &transcript); + let live_recent_session = recent_session.filter(|entry| { + current_context + .as_ref() + .map(|context| entry.entry.focus_fingerprint == context.focus_fingerprint) + .unwrap_or(false) + }); + let finalized = postprocess::finalize_transcript( + config, + preview_transcript, + rewrite_service, + current_context.as_ref(), + live_recent_session, + ) + .await; + *rewrite_preview = match finalized.text.trim() { + "" => None, + text if text == live_preview_text => None, + text => Some(text.to_string()), + }; + tracing::debug!( + live_preview_chars = live_preview_text.len(), + rewrite_preview_chars = rewrite_preview.as_ref().map(|text| text.len()).unwrap_or(0), + "updated live rewrite preview" + ); + *last_rewrite_at = Instant::now(); + } + + if config.voice.live_inject && !live_preview_text.is_empty() { + let current_context = current_context + .as_ref() + .cloned() + .unwrap_or_else(context::capture_typing_context); + if let Some(command) = live_injection.plan_update(&live_preview_text, ¤t_context) { + tracing::trace!( + delete_graphemes = command.delete_graphemes, + insert_chars = command.text.len(), + desired_chars = live_preview_text.len(), + "applying live injection command" + ); + if let Err(err) = + apply_injection_command(&TextInjector::new(), &command, ¤t_context).await + { + live_injection.freeze(); + return Err(err); + } + } + } + + osd.send_voice_update(&build_osd_update( + if live_injection.is_frozen() { + VoiceOsdStatus::Frozen + } else { + VoiceOsdStatus::Listening + }, + accumulator, + rewrite_preview.as_deref(), + live_injection, + config.voice.live_inject, + )); + Ok(()) +} + +async fn inject_final_output( + injector: &TextInjector, + injection_context: &TypingContext, + finalized: &FinalizedTranscript, +) -> Result<()> { + match finalized.operation { + FinalizedOperation::Append => injector.inject(&finalized.text, injection_context).await, + FinalizedOperation::ReplaceLastEntry { + delete_graphemes, .. + } => { + injector + .replace_recent_text(delete_graphemes, &finalized.text, injection_context) + .await + } + } +} + +async fn apply_final_live_output( + injector: &TextInjector, + live_injection: &mut LiveInjectionState, + injection_context: &TypingContext, + finalized: &FinalizedTranscript, +) -> Result { + match live_injection.plan_finalize(&finalized.operation, &finalized.text, injection_context) { + FinalizeInjectionDecision::Apply(command) => { + apply_injection_command(injector, &command, injection_context).await?; + Ok(true) + } + FinalizeInjectionDecision::Noop => Ok(true), + FinalizeInjectionDecision::Blocked => Ok(false), + } +} + +fn record_final_session( + config: &Config, + injection_context: &TypingContext, + finalized: &FinalizedTranscript, +) -> Result<()> { + match finalized.operation { + FinalizedOperation::Append => session::record_append( + &config.session, + injection_context, + &finalized.text, + finalized.rewrite_summary.clone(), + ), + FinalizedOperation::ReplaceLastEntry { entry_id, .. } => session::record_replace( + &config.session, + injection_context, + entry_id, + &finalized.text, + finalized.rewrite_summary.clone(), + ), + } +} + +async fn apply_injection_command( + injector: &TextInjector, + command: &InjectionCommand, + current_context: &TypingContext, +) -> Result<()> { + injector + .replace_recent_text(command.delete_graphemes, &command.text, current_context) + .await +} + +fn build_osd_update( + status: VoiceOsdStatus, + accumulator: &VoiceTranscriptAccumulator, + rewrite_preview: Option<&str>, + live_injection: &LiveInjectionState, + live_inject_enabled: bool, +) -> VoiceOsdUpdate { + VoiceOsdUpdate { + status, + stable_text: accumulator.stable_text.clone(), + unstable_text: accumulator.unstable_text.clone(), + rewrite_preview: rewrite_preview.map(str::to_string), + live_inject: live_inject_enabled, + frozen: live_injection.is_frozen(), + } +} + +fn build_live_rewrite_transcript( + accumulator: &VoiceTranscriptAccumulator, + live_preview_text: &str, + transcript: &Transcript, +) -> Transcript { + let mut preview_transcript = transcript.clone(); + preview_transcript.raw_text = live_preview_text.to_string(); + + if preview_transcript.segments.is_empty() && !live_preview_text.trim().is_empty() { + preview_transcript + .segments + .push(crate::transcribe::TranscriptSegment { + text: accumulator.full_text(), + start_ms: 0, + end_ms: accumulator.committed_until_ms.max(1), + }); + } + + preview_transcript +} + +fn clip_audio_tail(samples: &[f32], sample_rate: u32, context_window_ms: u64) -> (Vec, u32) { + if sample_rate == 0 || samples.is_empty() { + return (Vec::new(), 0); + } + + let context_samples = ((context_window_ms as u128 * sample_rate as u128) / 1000) as usize; + let start = if context_samples == 0 { + 0 + } else { + samples.len().saturating_sub(context_samples) + }; + ( + samples[start..].to_vec(), + audio_duration_ms(start, sample_rate) as u32, + ) +} + +fn offset_transcript_segments(transcript: &mut Transcript, offset_ms: u32) { + for segment in &mut transcript.segments { + segment.start_ms = segment.start_ms.saturating_add(offset_ms); + segment.end_ms = segment.end_ms.saturating_add(offset_ms); + } +} + +fn audio_duration_ms(samples: usize, sample_rate: u32) -> u64 { + if sample_rate == 0 { + return 0; + } + ((samples as f64 / sample_rate as f64) * 1000.0).round() as u64 +} + +fn audio_rms(samples: &[f32]) -> f32 { + if samples.is_empty() { + return 0.0; + } + let energy: f32 = samples.iter().map(|sample| sample * sample).sum(); + (energy / samples.len() as f32).sqrt() +} + +fn recent_audio_rms(samples: &[f32], sample_rate: u32, window_ms: u64) -> f32 { + let (tail, _) = clip_audio_tail(samples, sample_rate, window_ms); + audio_rms(&tail) +} + +fn grapheme_count(text: &str) -> usize { + UnicodeSegmentation::graphemes(text, true).count() +} + +fn shared_grapheme_prefix(current: &str, desired: &str) -> (usize, usize) { + let mut current_end = 0; + let mut desired_end = 0; + + for ((current_idx, current_grapheme), (desired_idx, desired_grapheme)) in current + .grapheme_indices(true) + .zip(desired.grapheme_indices(true)) + { + if current_grapheme != desired_grapheme { + break; + } + current_end = current_idx + current_grapheme.len(); + desired_end = desired_idx + desired_grapheme.len(); + } + + (current_end, desired_end) +} + +fn build_suffix_rewrite_command( + current_text: &str, + desired_text: &str, +) -> Option { + if current_text == desired_text { + return None; + } + + let (current_prefix_end, desired_prefix_end) = + shared_grapheme_prefix(current_text, desired_text); + let delete_graphemes = grapheme_count(¤t_text[current_prefix_end..]); + let text = desired_text[desired_prefix_end..].to_string(); + + if delete_graphemes == 0 && text.is_empty() { + return None; + } + + Some(InjectionCommand { + delete_graphemes, + text, + }) +} + +fn append_segment_text(output: &mut String, text: &str) { + let trimmed = text.trim(); + if trimmed.is_empty() { + return; + } + if !output.is_empty() { + output.push(' '); + } + output.push_str(trimmed); +} + +fn join_segment_text(segments: I) -> String +where + I: IntoIterator, + I::Item: AsRef, +{ + let mut joined = String::new(); + for segment in segments { + append_segment_text(&mut joined, segment.as_ref()); + } + joined +} + +#[derive(Debug, Clone, Default)] +struct VoiceTranscriptAccumulator { + stable_text: String, + unstable_text: String, + committed_until_ms: u32, +} + +#[derive(Debug, Clone, Default)] +struct LivePreviewPacing { + last_processed_samples: usize, + last_voice_activity_at: Option, +} + +impl VoiceTranscriptAccumulator { + fn update(&mut self, transcript: &Transcript, total_audio_ms: u32) { + let stable_boundary_ms = total_audio_ms.saturating_sub(UNSTABLE_TAIL_MS); + let committed_until_ms = self.committed_until_ms; + for segment in transcript.segments.iter().filter(|segment| { + segment.end_ms <= stable_boundary_ms && segment.end_ms > committed_until_ms + }) { + append_segment_text(&mut self.stable_text, &segment.text); + self.committed_until_ms = self.committed_until_ms.max(segment.end_ms); + } + + if transcript.segments.is_empty() { + self.unstable_text = transcript.raw_text.trim().to_string(); + return; + } + + self.unstable_text = join_segment_text( + transcript + .segments + .iter() + .filter(|segment| segment.end_ms > self.committed_until_ms) + .map(|segment| segment.text.as_str()), + ); + } + + fn full_text(&self) -> String { + join_segment_text([self.stable_text.as_str(), self.unstable_text.as_str()]) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct InjectionCommand { + delete_graphemes: usize, + text: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum FinalizeInjectionDecision { + Apply(InjectionCommand), + Noop, + Blocked, +} + +#[derive(Debug, Clone)] +struct LiveInjectionState { + enabled: bool, + freeze_on_focus_change: bool, + original_focus: String, + replaceable_prefix_graphemes: usize, + current_text: String, + pending_correction: Option, + frozen: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct PendingCorrection { + desired_text: String, + confirmations: usize, +} + +impl LiveInjectionState { + fn new( + enabled: bool, + freeze_on_focus_change: bool, + recording_context: &TypingContext, + replaceable_prefix_graphemes: usize, + ) -> Self { + Self { + enabled, + freeze_on_focus_change, + original_focus: recording_context.focus_fingerprint.clone(), + replaceable_prefix_graphemes, + current_text: String::new(), + pending_correction: None, + frozen: false, + } + } + + fn is_frozen(&self) -> bool { + self.frozen + } + + fn freeze(&mut self) { + self.frozen = true; + } + + fn plan_update( + &mut self, + desired_text: &str, + current_context: &TypingContext, + ) -> Option { + if !self.enabled || self.frozen { + return None; + } + if self.should_freeze(current_context) { + self.frozen = true; + return None; + } + let policy = InjectionPolicy::for_context(current_context); + let Some(command) = build_suffix_rewrite_command(&self.current_text, desired_text) else { + self.pending_correction = None; + return None; + }; + if command.delete_graphemes > 0 { + if !policy.allows_live_destructive_correction(command.delete_graphemes) { + tracing::trace!( + surface_policy = policy.label(), + delete_graphemes = command.delete_graphemes, + "deferring live destructive correction until final reconciliation" + ); + self.pending_correction = None; + return None; + } + let required_confirmations = policy.destructive_correction_confirmations(); + let should_apply = match self.pending_correction.as_mut() { + Some(pending) if pending.desired_text == desired_text => { + pending.confirmations = pending.confirmations.saturating_add(1); + pending.confirmations >= required_confirmations + } + _ => { + self.pending_correction = Some(PendingCorrection { + desired_text: desired_text.to_string(), + confirmations: 1, + }); + required_confirmations <= 1 + } + }; + if !should_apply { + return None; + } + } + self.pending_correction = None; + self.current_text = desired_text.to_string(); + Some(command) + } + + fn plan_finalize( + &mut self, + operation: &FinalizedOperation, + final_text: &str, + current_context: &TypingContext, + ) -> FinalizeInjectionDecision { + if !self.enabled { + return FinalizeInjectionDecision::Blocked; + } + if self.frozen && self.should_freeze(current_context) { + return FinalizeInjectionDecision::Blocked; + } + + let extra_delete = match operation { + FinalizedOperation::Append => 0, + FinalizedOperation::ReplaceLastEntry { + delete_graphemes, .. + } => { + if *delete_graphemes > self.replaceable_prefix_graphemes { + return FinalizeInjectionDecision::Blocked; + } + *delete_graphemes + } + }; + let delete_graphemes = extra_delete + grapheme_count(&self.current_text); + if extra_delete == 0 { + if let Some(command) = build_suffix_rewrite_command(&self.current_text, final_text) { + self.current_text = final_text.to_string(); + self.frozen = false; + self.pending_correction = None; + return FinalizeInjectionDecision::Apply(command); + } + return FinalizeInjectionDecision::Noop; + } + + let command = InjectionCommand { + delete_graphemes, + text: final_text.to_string(), + }; + self.current_text = final_text.to_string(); + self.pending_correction = None; + self.frozen = false; + FinalizeInjectionDecision::Apply(command) + } + + fn should_freeze(&self, current_context: &TypingContext) -> bool { + self.freeze_on_focus_change + && !self.original_focus.is_empty() + && !current_context.focus_fingerprint.is_empty() + && current_context.focus_fingerprint != self.original_focus + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::context::SurfaceKind; + use crate::transcribe::TranscriptSegment; + + fn transcript(segments: &[(&str, u32, u32)]) -> Transcript { + Transcript { + raw_text: join_segment_text(segments.iter().map(|(text, _, _)| *text)), + detected_language: Some("en".into()), + segments: segments + .iter() + .map(|(text, start_ms, end_ms)| TranscriptSegment { + text: (*text).to_string(), + start_ms: *start_ms, + end_ms: *end_ms, + }) + .collect(), + } + } + + fn context_with_surface(focus: &str, surface_kind: SurfaceKind) -> TypingContext { + TypingContext { + focus_fingerprint: focus.into(), + app_id: Some("app".into()), + window_title: Some("window".into()), + surface_kind, + browser_domain: None, + captured_at_ms: 0, + } + } + + fn context(focus: &str) -> TypingContext { + context_with_surface(focus, SurfaceKind::GenericText) + } + + #[test] + fn accumulator_commits_stable_prefix_and_keeps_tail_mutable() { + let mut accumulator = VoiceTranscriptAccumulator::default(); + accumulator.update( + &transcript(&[ + ("hello", 0, 900), + ("world", 900, 1900), + ("again", 1900, 3300), + ]), + UNSTABLE_TAIL_MS + 2000, + ); + + assert_eq!(accumulator.stable_text, "hello world"); + assert_eq!(accumulator.unstable_text, "again"); + assert_eq!(accumulator.full_text(), "hello world again"); + } + + #[test] + fn accumulator_handles_short_tail_regressions() { + let mut accumulator = VoiceTranscriptAccumulator::default(); + accumulator.update( + &transcript(&[("hello", 0, 900), ("world", 900, 2100)]), + 3000, + ); + assert_eq!(accumulator.full_text(), "hello world"); + + accumulator.update(&transcript(&[("hello", 0, 900)]), 3000); + assert_eq!(accumulator.stable_text, ""); + assert_eq!(accumulator.unstable_text, "hello"); + } + + #[test] + fn accumulator_allows_earlier_correction_inside_mutable_suffix() { + let mut accumulator = VoiceTranscriptAccumulator::default(); + accumulator.update( + &transcript(&[("ship", 1000, 1800), ("it", 1800, 2400)]), + 3000, + ); + assert_eq!(accumulator.full_text(), "ship it"); + + accumulator.update( + &transcript(&[("shift", 1000, 1800), ("it", 1800, 2400)]), + 3000, + ); + assert_eq!(accumulator.full_text(), "shift it"); + } + + #[test] + fn live_rewrite_transcript_uses_accumulated_preview_text() { + let mut accumulator = VoiceTranscriptAccumulator::default(); + accumulator.update( + &transcript(&[ + ("i'm", 0, 400), + ("using", 400, 800), + ("hyperland", 800, 1200), + ]), + UNSTABLE_TAIL_MS + 800, + ); + assert_eq!(accumulator.stable_text, "i'm using"); + assert_eq!(accumulator.unstable_text, "hyperland"); + + let live_preview_text = accumulator.full_text(); + let preview = build_live_rewrite_transcript( + &accumulator, + &live_preview_text, + &transcript(&[("hyperland", 800, 1200)]), + ); + + assert_eq!(preview.raw_text, "i'm using hyperland"); + assert_eq!(preview.segments.len(), 1); + assert_eq!(preview.segments[0].text, "hyperland"); + } + + #[test] + fn final_reconciliation_overrides_last_partial_text() { + let mut state = LiveInjectionState::new(true, true, &context("focus-a"), 0); + let _ = state.plan_update("draft text", &context("focus-a")); + + let decision = state.plan_finalize( + &FinalizedOperation::Append, + "final text", + &context("focus-a"), + ); + assert_eq!( + decision, + FinalizeInjectionDecision::Apply(InjectionCommand { + delete_graphemes: 10, + text: "final text".into(), + }) + ); + } + + #[test] + fn preview_only_mode_never_plans_live_mutation() { + let mut state = LiveInjectionState::new(false, true, &context("focus-a"), 0); + assert_eq!(state.plan_update("hello", &context("focus-a")), None); + } + + #[test] + fn live_inject_only_rewrites_owned_text() { + let mut state = LiveInjectionState::new(true, true, &context("focus-a"), 0); + let first = state + .plan_update("hello", &context("focus-a")) + .expect("first update"); + assert_eq!( + first, + InjectionCommand { + delete_graphemes: 0, + text: "hello".into() + } + ); + + let second = state + .plan_update("hello world", &context("focus-a")) + .expect("second update"); + assert_eq!( + second, + InjectionCommand { + delete_graphemes: 0, + text: " world".into() + } + ); + } + + #[test] + fn live_inject_only_rewrites_changed_suffix_when_correcting() { + let mut state = LiveInjectionState::new(true, true, &context("focus-a"), 0); + let _ = state.plan_update("ship it", &context("focus-a")); + + assert_eq!(state.plan_update("shift it", &context("focus-a")), None); + + let correction = state + .plan_update("shift it", &context("focus-a")) + .expect("confirmed correction update"); + assert_eq!( + correction, + InjectionCommand { + delete_graphemes: 4, + text: "ft it".into() + } + ); + } + + #[test] + fn destructive_correction_confirmation_resets_when_target_changes() { + let mut state = LiveInjectionState::new(true, true, &context("focus-a"), 0); + let _ = state.plan_update("hyperland", &context("focus-a")); + + assert_eq!(state.plan_update("hyprland", &context("focus-a")), None); + assert_eq!(state.plan_update("highprland", &context("focus-a")), None); + assert_eq!(state.plan_update("hyprland", &context("focus-a")), None); + + let correction = state + .plan_update("hyprland", &context("focus-a")) + .expect("correction update"); + assert_eq!( + correction, + InjectionCommand { + delete_graphemes: 6, + text: "rland".into() + } + ); + } + + #[test] + fn destructive_correction_confirmation_resets_when_correction_disappears() { + let mut state = LiveInjectionState::new(true, true, &context("focus-a"), 0); + let _ = state.plan_update("ship it", &context("focus-a")); + + assert_eq!(state.plan_update("shift it", &context("focus-a")), None); + assert_eq!(state.plan_update("ship it", &context("focus-a")), None); + assert_eq!(state.plan_update("shift it", &context("focus-a")), None); + assert!(state.plan_update("shift it", &context("focus-a")).is_some()); + } + + #[test] + fn final_append_only_appends_delta_when_live_text_is_already_correct() { + let mut state = LiveInjectionState::new(true, true, &context("focus-a"), 0); + let _ = state.plan_update("hello", &context("focus-a")); + + let decision = state.plan_finalize( + &FinalizedOperation::Append, + "hello world", + &context("focus-a"), + ); + assert_eq!( + decision, + FinalizeInjectionDecision::Apply(InjectionCommand { + delete_graphemes: 0, + text: " world".into(), + }) + ); + } + + #[test] + fn focus_change_freezes_live_injection_immediately() { + let mut state = LiveInjectionState::new(true, true, &context("focus-a"), 0); + assert_eq!(state.plan_update("hello", &context("focus-b")), None); + assert!(state.is_frozen()); + } + + #[test] + fn browser_surface_requires_extra_confirmation_for_destructive_live_updates() { + let browser = context_with_surface("focus-a", SurfaceKind::Browser); + let mut state = LiveInjectionState::new(true, true, &browser, 0); + let _ = state.plan_update("hyperland", &browser); + + assert_eq!(state.plan_update("hyprland", &browser), None); + assert_eq!(state.plan_update("hyprland", &browser), None); + let correction = state + .plan_update("hyprland", &browser) + .expect("third confirmation should apply"); + assert_eq!( + correction, + InjectionCommand { + delete_graphemes: 6, + text: "rland".into() + } + ); + } + + #[test] + fn unknown_surface_keeps_live_mode_append_only_but_allows_final_reconciliation() { + let unknown = context_with_surface("focus-a", SurfaceKind::Unknown); + let mut state = LiveInjectionState::new(true, true, &unknown, 0); + let _ = state + .plan_update("hello", &unknown) + .expect("initial append"); + + assert_eq!(state.plan_update("help", &unknown), None); + assert_eq!(state.plan_update("help", &unknown), None); + + let decision = state.plan_finalize(&FinalizedOperation::Append, "help", &unknown); + assert_eq!( + decision, + FinalizeInjectionDecision::Apply(InjectionCommand { + delete_graphemes: 2, + text: "p".into(), + }) + ); + } + + #[test] + fn final_reconciliation_respects_owned_delete_bounds() { + let mut state = LiveInjectionState::new(true, true, &context("focus-a"), 3); + let _ = state.plan_update("hello", &context("focus-a")); + + let blocked = state.plan_finalize( + &FinalizedOperation::ReplaceLastEntry { + entry_id: 7, + delete_graphemes: 4, + }, + "replacement", + &context("focus-a"), + ); + assert_eq!(blocked, FinalizeInjectionDecision::Blocked); + + let allowed = state.plan_finalize( + &FinalizedOperation::ReplaceLastEntry { + entry_id: 7, + delete_graphemes: 3, + }, + "replacement", + &context("focus-a"), + ); + assert_eq!( + allowed, + FinalizeInjectionDecision::Apply(InjectionCommand { + delete_graphemes: 8, + text: "replacement".into(), + }) + ); + } +}