diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..a83a16b --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,5 @@ +[env] +WHISPER_DONT_GENERATE_BINDINGS = "1" + +[patch.crates-io] +whisper-rs-sys = { path = "vendor/whisper-rs-sys" } diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c3ea29d..66032c6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,13 +58,17 @@ jobs: - name: Check osd feature only run: cargo check --no-default-features --features osd - - name: Verify package (default publish surface) + - name: Check local rewrite feature only + run: cargo check --no-default-features --features local-rewrite + + - name: Package tarball run: cargo package --locked - name: Check cuda feature only (if toolkit available) run: | if command -v nvcc >/dev/null 2>&1; then cargo check --no-default-features --features cuda + cargo check --no-default-features --features cuda,local-rewrite else echo "CUDA toolkit not available on this runner; skipping cuda feature check" fi diff --git a/Cargo.lock b/Cargo.lock index 1d3e336..9867318 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "alsa" version = "0.9.1" @@ -1016,6 +1022,28 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "font8x8" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"875488b8711a968268c7cf5d139578713097ca4635a76044e8fe8eedf831d07e" + +[[package]] +name = "fontdue" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e57e16b3fe8ff4364c0661fdaac543fb38b29ea9bc9c2f45612d90adf931d2b" +dependencies = [ + "hashbrown 0.15.5", + "ttf-parser", +] + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -1175,6 +1203,17 @@ dependencies = [ "byteorder", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + [[package]] name = "hashbrown" version = "0.16.1" @@ -1490,7 +1529,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", ] [[package]] @@ -3272,6 +3311,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "ttf-parser" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c591d83f69777866b9126b24c6dd9a18351f177e49d625920d19f989fd31cf8" + [[package]] name = "typenum" version = "1.19.0" @@ -3579,8 +3624,6 @@ dependencies = [ [[package]] name = "whisper-rs-sys" version = "0.14.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e2a6e06e7ac7b8f53c53a5f50bb0bc823ba69b63ecd887339f807a5598bbd2" dependencies = [ "bindgen 0.71.1", "cfg-if", @@ -3590,7 +3633,7 @@ dependencies = [ [[package]] name = "whispers" -version = "0.1.0" +version = "0.1.1" dependencies = [ "base64 0.22.1", "clap", @@ -3602,6 +3645,8 @@ dependencies = [ "encoding_rs", "evdev", "flacenc", + "font8x8", + "fontdue", "futures-util", "httpmock", 
"indicatif", diff --git a/Cargo.toml b/Cargo.toml index 8e0dced..ae9b685 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "whispers" -version = "0.1.0" +version = "0.1.1" edition = "2024" rust-version = "1.85" description = "Speech-to-text dictation tool for Wayland" @@ -8,7 +8,18 @@ license = "MPL-2.0" repository = "https://github.com/OneNoted/whispers" keywords = ["wayland", "speech-to-text", "whisper", "dictation", "voice"] categories = ["accessibility", "multimedia"] -exclude = [".jj/", ".git/", "target/"] +include = [ + "Cargo.toml", + "Cargo.lock", + "README.md", + "LICENSE", + "NOTICE", + "config.example.toml", + ".cargo/config.toml", + "sounds/*.wav", + "src/**", + "vendor/whisper-rs-sys/**/*", +] [dependencies] # Async runtime @@ -19,7 +30,7 @@ cpal = "0.17" # Whisper transcription whisper-rs = "0.15" -llama-cpp-2 = "0.1.138" +llama-cpp-2 = { version = "0.1.138", optional = true } # uinput virtual keyboard for paste keystroke evdev = { version = "0.13" } @@ -63,11 +74,20 @@ console = "0.16" wayland-client = { version = "0.31", optional = true } wayland-protocols = { version = "0.32", features = ["client"], optional = true } wayland-protocols-wlr = { version = "0.3", features = ["client"], optional = true } +font8x8 = { version = "0.3", optional = true } +fontdue = { version = "0.9", optional = true } [features] default = ["osd"] -cuda = ["whisper-rs/cuda", "llama-cpp-2/cuda"] -osd = ["dep:wayland-client", "dep:wayland-protocols", "dep:wayland-protocols-wlr"] +cuda = ["whisper-rs/cuda", "llama-cpp-2?/cuda"] +local-rewrite = ["dep:llama-cpp-2"] +osd = [ + "dep:wayland-client", + "dep:wayland-protocols", + "dep:wayland-protocols-wlr", + "dep:font8x8", + "dep:fontdue", +] [[bin]] name = "whispers" @@ -80,7 +100,8 @@ required-features = ["osd"] [[bin]] name = "whispers-rewrite-worker" -path = "src/bin/whispers-rewrite-worker.rs" +path = "src/bin/whispers-rewrite-worker/main.rs" +required-features = ["local-rewrite"] [dev-dependencies] 
httpmock = "0.7" diff --git a/README.md b/README.md index e360c57..57893e8 100644 --- a/README.md +++ b/README.md @@ -1,348 +1,123 @@ # whispers -Fast speech-to-text dictation for Wayland with local-first ASR and optional cloud ASR/rewrite backends. -Press a key to start recording, press it again to transcribe and paste. +Fast speech-to-text dictation for Wayland. -Local mode keeps all inference on your machine. Optional cloud modes can offload ASR, rewrite, or both when configured. - -Inspired by [hyprwhspr](https://github.com/goodroot/hyprwhspr) by goodroot. - -image - - - -## How it works - -1. Bind `whispers` to a key in your compositor -2. First press starts recording (OSD overlay shows audio visualization) -3. Second press stops recording, transcribes, and pastes via `Ctrl+Shift+V` - -The two invocations communicate via PID file + `SIGUSR1` — no daemon, no IPC server. - -## Post-processing modes - -`whispers` now has two main dictation modes: - -- `raw` keeps output close to the direct transcription result and is the default -- `advanced_local` enables the smart rewrite pipeline after transcription; `[rewrite].backend` chooses whether that rewrite runs locally or in the cloud - -The older heuristic cleanup path is still available as deprecated `legacy_basic` for existing configs that already use `[cleanup]`. -The local rewrite path is managed by `whispers` itself through an internal helper binary installed alongside the main executable, so there is no separate tool or daemon to install manually. -When `advanced_local` is enabled with `rewrite.backend = "local"`, `whispers` keeps a hidden rewrite worker warm for a short idle window so repeated dictation is much faster without becoming a permanent background daemon. -Managed rewrite models are the default path. If you point `rewrite.model_path` at your own GGUF, it should be a chat-capable model with an embedded template that `llama.cpp` can apply at runtime. 
-Deterministic personalization rules apply in all modes: dictionary replacements, spoken snippets, and optional append-only custom rewrite instructions for `advanced_local`. -Cloud ASR and cloud rewrite are both optional. Local remains the default. - -For file transcription, `whispers transcribe --raw ` always prints the plain ASR transcript without any post-processing. - -## Requirements - -- Rust 1.85+ (edition 2024) -- Linux with Wayland compositor -- `wl-copy` (from `wl-clipboard`) -- `uinput` access (for virtual keyboard paste) -- NVIDIA GPU + CUDA toolkit (optional, for GPU acceleration) -- `python3` on `PATH` if you want to use the optional `faster-whisper` backend -- `python3.10`, `python3.11`, or `python3.12` on `PATH` if you want to use the experimental NeMo backends -- If no compatible GPU is available, set `transcription.use_gpu = false` in config +`whispers` is local-first by default, with optional cloud ASR and rewrite backends when you want them. The normal flow is simple: press a key to start recording, press it again to transcribe and paste. 
## Install -### From crates.io - ```sh -# Default install: CPU build with Wayland OSD +# default install cargo install whispers -# Enable CUDA acceleration explicitly +# CUDA cargo install whispers --features cuda -# Build without the OSD overlay -cargo install whispers --no-default-features -``` - -### From git +# local rewrite support +cargo install whispers --features local-rewrite -```sh -# Default install: CPU build with Wayland OSD -cargo install --git https://github.com/OneNoted/whispers +# CUDA + local rewrite +cargo install whispers --features cuda,local-rewrite -# Enable CUDA acceleration explicitly -cargo install --git https://github.com/OneNoted/whispers --features cuda - -# Build without the OSD overlay -cargo install --git https://github.com/OneNoted/whispers --no-default-features +# no OSD +cargo install whispers --no-default-features ``` -### Setup - -Run the interactive setup wizard to download a local ASR model, generate config, and optionally enable local or cloud advanced dictation. Recommended local models are shown first, and experimental backends like Parakeet are called out explicitly before you opt into them: +If you want the latest GitHub version instead of crates.io: ```sh -whispers setup +cargo install --git https://github.com/OneNoted/whispers --features cuda,local-rewrite ``` -Normal runs keep output concise. Add `-v` when you want detailed diagnostic logs during setup, downloads, or dictation. 
+## Requirements + +- Linux with Wayland +- `wl-copy` +- access to `/dev/uinput` +- Rust 1.85+ +- CUDA toolkit if you enable the `cuda` feature -Use a custom config file for any command (including `setup` and `asr-model`): +If `/dev/uinput` is blocked, add your user to the `input` group and log back in: ```sh -whispers --config /path/to/config.toml setup -whispers --config /path/to/config.toml asr-model select tiny +sudo usermod -aG input $USER ``` -Or manage ASR models manually: +## Quick Start ```sh -whispers asr-model list -whispers asr-model download large-v3-turbo -whispers asr-model select large-v3-turbo -whispers asr-model download distil-large-v3.5 -whispers asr-model select distil-large-v3.5 -# Experimental NeMo path: -whispers asr-model download parakeet-tdt_ctc-1.1b -whispers asr-model select parakeet-tdt_ctc-1.1b - -# Legacy whisper_cpp-only aliases still work for one release: -whispers model list -whispers model download large-v3-turbo -whispers model select large-v3-turbo +# generate config and download a model +whispers setup -whispers rewrite-model list -whispers rewrite-model download qwen-3.5-4b-q4_k_m -whispers rewrite-model select qwen-3.5-4b-q4_k_m -whispers cloud check +# one-shot dictation +whispers -whispers dictionary add "wisper flow" "Wispr Flow" -whispers dictionary list -whispers snippets add signature "Best regards,\nNotes" -whispers snippets list -whispers rewrite-instructions-path +# live mode +whispers voice ``` -That still remains a single install: `whispers` manages local ASR models, the optional local rewrite worker/model, and the optional cloud configuration from the same package. `faster-whisper` is bootstrapped into a hidden managed runtime when you download or prewarm that backend. 
- -## Shell completions +Default config path: -Print completion scripts to `stdout`: - -```sh -# auto-detect from $SHELL (falls back to parent process name) -whispers completions - -# or specify manually -whispers completions zsh +```text +~/.config/whispers/config.toml ``` -Supported shells: `bash`, `zsh`, `fish`, `nushell`. - -Example install paths: +Canonical example config: -```sh -# bash -mkdir -p ~/.local/share/bash-completion/completions -whispers completions bash > ~/.local/share/bash-completion/completions/whispers - -# zsh -mkdir -p ~/.zfunc -whispers completions zsh > ~/.zfunc/_whispers - -# fish -mkdir -p ~/.config/fish/completions -whispers completions fish > ~/.config/fish/completions/whispers.fish +- [config.example.toml](config.example.toml) -# nushell -mkdir -p ~/.config/nushell/completions -whispers completions nushell > ~/.config/nushell/completions/whispers.nu -``` - -## Compositor keybinding +### Keybinding -### Hyprland +Hyprland: ```conf bind = SUPER ALT, D, exec, whispers ``` -### Sway +Sway: ```conf bindsym $mod+Alt+d exec whispers ``` -## Configuration - -Config lives at `~/.config/whispers/config.toml` by default. Generated automatically by `whispers setup`, or copy from `config.example.toml`: - -```toml -[audio] -device = "" # empty = system default -sample_rate = 16000 - -[transcription] -backend = "whisper_cpp" # or "faster_whisper" / "nemo" / "cloud" -fallback = "configured_local" # or "none" -local_backend = "whisper_cpp" -selected_model = "large-v3-turbo" -model_path = "~/.local/share/whispers/ggml-large-v3-turbo.bin" -language = "auto" # or "en", "fr", "de", etc. 
-use_gpu = true # set false to force CPU -flash_attn = true # only used when use_gpu=true -idle_timeout_ms = 120000 - -[postprocess] -mode = "raw" # or "advanced_local"; deprecated: "legacy_basic" - -[session] -enabled = true -max_entries = 3 -max_age_ms = 8000 -max_replace_graphemes = 400 - -[personalization] -dictionary_path = "~/.local/share/whispers/dictionary.toml" -snippets_path = "~/.local/share/whispers/snippets.toml" -snippet_trigger = "insert" - -[rewrite] -backend = "local" # or "cloud" -fallback = "local" # or "none" -selected_model = "qwen-3.5-4b-q4_k_m" -model_path = "" # optional manual GGUF path override -instructions_path = "~/.local/share/whispers/rewrite-instructions.txt" -profile = "auto" # or "qwen", "generic", "llama_compat" -timeout_ms = 30000 -idle_timeout_ms = 120000 -max_output_chars = 1200 -max_tokens = 256 - -[cloud] -provider = "openai" # or "openai_compatible" -base_url = "" # required for openai_compatible -api_key = "" # optional direct API key; leave empty to use api_key_env -api_key_env = "OPENAI_API_KEY" -connect_timeout_ms = 3000 -request_timeout_ms = 15000 - -[cloud.transcription] -model = "gpt-4o-mini-transcribe" -language_mode = "inherit_local" # or "force" -language = "" - -[cloud.rewrite] -model = "gpt-4.1-mini" -temperature = 0.1 -max_output_tokens = 256 - -[feedback] -enabled = true -start_sound = "" # empty = bundled sound -stop_sound = "" -``` - -When `advanced_local` is enabled, `whispers` also keeps a short-lived local session ledger in the runtime directory so immediate follow-up corrections like `scratch that` can safely replace the most recent dictation entry when focus has not changed. That session behavior is local either way; only the semantic rewrite stage may be cloud-backed. - -## Cloud Modes - -- `transcription.backend = "cloud"` uploads recorded audio to the configured provider for ASR. -- `rewrite.backend = "cloud"` uploads transcript/context JSON to the configured provider for semantic cleanup. 
-- `transcription.fallback = "configured_local"` keeps a local ASR fallback path. -- `rewrite.fallback = "local"` keeps a local rewrite fallback path. -- Use either `cloud.api_key_env` or `cloud.api_key`. `setup` accepts either an env var name or a pasted key. - -Use `whispers cloud check` to validate cloud config, API key resolution, and basic provider connectivity. - -## Managed ASR models - -`whispers` currently ships managed local ASR entries across two backend families: - -| Model | Backend | Scope | Notes | -|-------|---------|-------|-------| -| large-v3-turbo | whisper_cpp | Multilingual | Default path | -| large-v3 | whisper_cpp | Multilingual | Slower, higher accuracy | -| medium / small / base / tiny | whisper_cpp | Multilingual | Smaller/faster tradeoffs | -| *.en variants | whisper_cpp | English only | Smaller English Whisper options | -| distil-large-v3.5 | faster_whisper | English only | Fast English option | -| parakeet-tdt_ctc-1.1b | nemo | English only | Experimental NeMo ASR benchmark path | -| canary-qwen-2.5b | nemo | English only | Experimental NeMo ASR/LLM hybrid (currently blocked) | - -`large-v3-turbo` remains the default multilingual local model. `distil-large-v3.5` is the speed-focused English option on the optional `faster-whisper` backend. `parakeet-tdt_ctc-1.1b` is kept as an experimental English-only NeMo backend for benchmarking against Whisper-family models, not as the default recommendation. Its first warm-up can be much slower than steady-state dictation, so judge it on warm use rather than the first cold start. `canary-qwen-2.5b` remains listed for evaluation, but the managed path is currently blocked by an upstream NeMo/PEFT initialization incompatibility. Cloud ASR models are configured under `[cloud.transcription]` instead of being downloaded locally. 
- -## Whisper Models - -| Model | Size | Speed | Notes | -|-------|------|-------|-------| -| large-v3-turbo | 1.6 GB | Fast | Best balance (recommended) | -| large-v3-turbo-q5_0 | 574 MB | Fast | Quantized, slightly less accurate | -| large-v3 | 3.1 GB | Slow | Most accurate | -| small / small.en | 488 MB | Very fast | Good for English-only | -| tiny / tiny.en | 78 MB | Instant | Least accurate | - -Whisper.cpp models are downloaded from [Hugging Face](https://huggingface.co/ggerganov/whisper.cpp) and stored in `~/.local/share/whispers/`. The managed `faster-whisper` backend stores models and its Python runtime under the same XDG data directory. - -## Managed rewrite models - -When `rewrite.backend = "local"`, `advanced_local` uses a second local model for post-processing. The managed local catalog currently includes: - -| Model | Size | Notes | -|-------|------|-------| -| qwen-3.5-2b-q4_k_m | ~1.3 GB | Fallback for weaker hardware | -| qwen-3.5-4b-q4_k_m | ~2.9 GB | Recommended default | -| qwen-3.5-9b-q4_k_m | ~5.9 GB | Higher quality, heavier | - -If you want to tinker, set `rewrite.model_path` to a custom GGUF file. When `rewrite.model_path` is set, it overrides the managed selection. -`rewrite.profile = "auto"` keeps the prompt/runtime model-aware without requiring manual tuning for managed models, and still falls back safely for custom GGUFs. -Custom rewrite models should include a chat template that `llama.cpp` can read from the GGUF metadata; otherwise rewrite prompting will fail fast instead of silently producing bad output. - -## Personalization +## Commands -Dictionary replacements apply deterministically in both `raw` and `advanced_local`, with normalization for case and punctuation but no fuzzy matching. In `advanced_local`, dictionary replacements are applied before the rewrite model and again on the final output so exact names and product terms stay stable. - -Spoken snippets also work in all modes. 
By default, saying `insert ` expands the configured snippet text verbatim after post-processing finishes, so the rewrite model cannot paraphrase it. Change the trigger phrase with `personalization.snippet_trigger`. - -Custom rewrite instructions live in a separate plain-text file referenced by `rewrite.instructions_path`. `whispers` appends that file to the built-in rewrite prompt for `advanced_local`, while still enforcing the same final-text-only output contract. The file is optional, and a missing file is ignored. - -## Faster Whisper - -`faster-whisper` is optional and intended for users who want the fastest English dictation path. The current managed model for it is `distil-large-v3.5`. - -Notes: -- English dictation is the intended use case -- if it fails at runtime and a local `large-v3-turbo` Whisper model is available, `whispers` falls back to `whisper_cpp` -- `transcription.idle_timeout_ms = 0` keeps the hidden ASR worker warm indefinitely - -## Experimental NeMo backends - -`parakeet-tdt_ctc-1.1b` is available as an experimental English-only ASR option on a managed NeMo backend. `canary-qwen-2.5b` remains under evaluation, but the managed path is currently blocked by an upstream initialization issue. 
+```sh +# setup +whispers setup -Notes: -- they are intended for benchmarking and experimentation, not as the default recommendation -- first warm-up can be much slower than steady-state dictation because the hidden worker and model need to come up -- the first use bootstraps a hidden managed Python runtime under the XDG data directory -- the runtime currently requires Python 3.10, 3.11, or 3.12 on `PATH` -- model downloads are stored as prepared NeMo model directories instead of ggml files -- if a NeMo backend fails at runtime and a local `large-v3-turbo` Whisper model is available, `whispers` falls back to `whisper_cpp` +# dictation +whispers +whispers voice +whispers transcribe audio.wav -## Privacy +# ASR models +whispers asr-model list +whispers asr-model download large-v3-turbo +whispers asr-model select large-v3-turbo -- Local-only: no inference-time network traffic -- Cloud ASR: audio leaves the machine for transcription -- Cloud rewrite: transcript/context leaves the machine for rewrite -- Cloud ASR + rewrite: both leave the machine +# rewrite models +whispers rewrite-model list +whispers rewrite-model download qwen-3.5-4b-q4_k_m +whispers rewrite-model select qwen-3.5-4b-q4_k_m -## uinput permissions +# personalization +whispers dictionary add "wisper flow" "Wispr Flow" +whispers snippets add signature "Best regards,\nNotes" -whispers needs access to `/dev/uinput` for the virtual keyboard paste. Add your user to the `input` group: +# cloud +whispers cloud check -```sh -sudo usermod -aG input $USER +# shell completions +whispers completions zsh ``` -Then log out and back in. - -## Acknowledgements +## Notes -This project is inspired by [hyprwhspr](https://github.com/goodroot/hyprwhspr) by [goodroot](https://github.com/goodroot), which provides native speech-to-text for Linux with support for multiple backends. whispers is a from-scratch Rust reimplementation focused on local-first dictation with minimal dependencies. +- Local ASR is the default. 
+- Local rewrite is installed automatically with `--features local-rewrite`. +- `whispers` installs the helper rewrite worker for you when that feature is enabled. +- Shell completions are printed to `stdout`. ## License diff --git a/config.example.toml b/config.example.toml index 5a70488..ec37c8c 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,5 +1,4 @@ # whispers configuration -# Copy to ~/.config/whispers/config.toml and customize # # Keybinding is handled by your compositor. Example for Hyprland: # bind = SUPER ALT, D, exec, whispers @@ -9,42 +8,36 @@ [audio] # Input device name (empty = system default) device = "" -# Sample rate in Hz (ASR expects 16000) +# Sample rate in Hz (ASR requires 16000) sample_rate = 16000 [transcription] # Active transcription backend ("whisper_cpp", "faster_whisper", "nemo", or "cloud") -# "nemo" is experimental and intended for benchmarking rather than the default recommendation. backend = "whisper_cpp" -# Cloud failure policy ("configured_local" or "none") +# Cloud fallback behavior ("configured_local" or "none") fallback = "configured_local" # Local backend used directly in local mode and as the cloud fallback backend local_backend = "whisper_cpp" -# Managed ASR model name for the local backend. -# Recommended defaults: "large-v3-turbo" for multilingual local use, -# "distil-large-v3.5" for faster English-only local use. +# Managed ASR model name for the selected backend selected_model = "large-v3-turbo" -# Local backend model path. Leave empty to use the managed selection. -# Experimental NeMo models are managed as prepared model directories, not ggml files. -# Their first warm-up can be much slower than steady-state dictation. +# Path to the local backend-specific model or empty to use the selected managed model +# Manage models with: whispers asr-model list / download / select model_path = "~/.local/share/whispers/ggml-large-v3-turbo.bin" -# Language code ("en", "fr", "de", etc.) 
+# Language code ("en", "fr", "de", etc.) or "auto" for auto-detect language = "auto" # Enable GPU acceleration (set false to force CPU) use_gpu = true # Enable flash attention when GPU is enabled flash_attn = true # How long the hidden ASR worker stays warm without requests (0 = never expire) -# Experimental NeMo models usually feel best with a warm worker. idle_timeout_ms = 120000 [postprocess] -# "raw" keeps output close to Whisper, "advanced_local" enables the rewrite model -# "legacy_basic" is deprecated and only kept for older cleanup-based configs +# "raw" (default), "advanced_local", "agentic_rewrite", or "legacy_basic" for deprecated cleanup configs mode = "raw" [session] -# Enable short-lived session backtracking in advanced_local mode +# Enable short-lived session backtracking in rewrite modes enabled = true # How many recent dictation entries to keep in the runtime session ledger max_entries = 3 @@ -64,7 +57,7 @@ snippet_trigger = "insert" [rewrite] # Rewrite backend ("local" or "cloud") backend = "local" -# Cloud rewrite failure policy ("local" or "none") +# Cloud fallback behavior ("local" or "none") fallback = "local" # Managed rewrite model name for advanced_local mode selected_model = "qwen-3.5-4b-q4_k_m" @@ -72,24 +65,31 @@ selected_model = "qwen-3.5-4b-q4_k_m" # Custom rewrite models should be chat-capable GGUFs with an embedded # chat template that llama.cpp can apply at runtime. model_path = "" -# Optional plain-text file with extra rewrite instructions appended to the -# built-in system prompt. Missing files are ignored. 
+# Append-only custom rewrite instructions file (empty = disabled) instructions_path = "~/.local/share/whispers/rewrite-instructions.txt" # Rewrite profile selection ("auto", "qwen", "generic", or "llama_compat") profile = "auto" # Timeout for local rewrite inference in milliseconds timeout_ms = 30000 -# How long the hidden rewrite worker stays warm without requests (0 = never expire) +# How long the hidden rewrite worker stays warm without requests idle_timeout_ms = 120000 # Maximum characters accepted from the rewrite model max_output_chars = 1200 # Maximum tokens to generate for rewritten output max_tokens = 256 +[agentic_rewrite] +# App-aware rewrite policy rules used by postprocess.mode = "agentic_rewrite" +policy_path = "~/.local/share/whispers/app-rewrite-policy.toml" +# Technical glossary used by postprocess.mode = "agentic_rewrite" +glossary_path = "~/.local/share/whispers/technical-glossary.toml" +# Default correction policy ("conservative", "balanced", or "aggressive") +default_correction_policy = "balanced" + [cloud] # Cloud provider ("openai" or "openai_compatible") provider = "openai" -# Base URL for openai_compatible providers. Leave empty for OpenAI. 
+# Custom base URL for openai_compatible providers (empty uses the OpenAI default) base_url = "" # Optional API key stored directly in the config (empty = use api_key_env instead) api_key = "" @@ -103,9 +103,9 @@ request_timeout_ms = 15000 [cloud.transcription] # Cloud transcription model model = "gpt-4o-mini-transcribe" -# "inherit_local" uses [transcription].language when it is not "auto" -# "force" uses the value below instead +# "inherit_local" uses [transcription].language when it is not "auto"; "force" uses the value below language_mode = "inherit_local" +# Language code used when language_mode = "force" language = "" [cloud.rewrite] @@ -113,7 +113,7 @@ language = "" model = "gpt-4.1-mini" # Sampling temperature for cloud rewrite temperature = 0.1 -# Maximum tokens requested from the cloud rewrite backend +# Maximum tokens requested from the cloud rewrite model max_output_tokens = 256 [feedback] diff --git a/docs/refactor-plan.md b/docs/refactor-plan.md new file mode 100644 index 0000000..073b464 --- /dev/null +++ b/docs/refactor-plan.md @@ -0,0 +1,314 @@ +# Whispers Refactor Plan + +Status: complete +Workspace: `refactor-plan` at `/home/notes/Projects/whispers-refactor-plan` +Planning goal: reduce module sprawl and dependency tangles without mixing in feature work or behavior changes. + +## Working Rules + +- Keep refactor work in this workspace, not the shared feature workspace. +- Prefer behavior-preserving extractions first. Delay semantic changes until the new boundaries are in place. +- Keep each checkpoint to one logical concern and one Conventional Commit description. +- Run targeted tests after each checkpoint, then broaden to `cargo test` when the phase is stable. +- Do not start with OSD polish or naming cleanup. Fix structure first. + +## Current Diagnosis + +The main mess is not the top-level flow. The main mess is that a few large modules own too many responsibilities at once: + +- `src/main.rs` is the de facto crate root for almost everything. 
+- `src/bin/whispers-rewrite-worker.rs` and `src/bin/whispers-osd.rs` pull shared code in via `#[path = ...]` instead of a shared library crate. +- `src/postprocess.rs` mixes planning, backend routing, fallback, and finalization. +- `src/agentic_rewrite.rs` mixes runtime policy logic with file-backed CLI admin. +- `src/asr.rs` duplicates backend lifecycle logic across batch and live paths. +- `src/app.rs` mixes orchestration, runtime state, injection policy, and session persistence. +- `src/personalization.rs`, `src/session.rs`, `src/config.rs`, and `src/setup.rs` each bundle multiple separate concerns. + +## Recommended Order + +1. Establish crate boundaries. +2. Fix dependency direction in the runtime path. +3. Split the largest domain modules by responsibility. +4. Split config/setup/model/completion orchestration. +5. Finish with platform adapters and retire stale reporting cleanup if no real surface remains. + +## Phase 1: Crate Boundaries + +Goal: stop sharing code between binaries through `#[path = ...]` includes and give the project a real library surface. + +### Checkpoint 1.1 + +- Commit: `refactor: add library crate and thin binary entrypoints` +- Deliverables: + - Add `src/lib.rs`. + - Move module declarations out of `src/main.rs`. + - Make `src/main.rs` a thin CLI entrypoint. + - Make `src/bin/whispers-rewrite-worker.rs` and `src/bin/whispers-osd.rs` use library modules instead of `#[path = ...]`. +- Validation: + - `cargo test` + - `cargo test --bin whispers` + - `cargo test --bin whispers-rewrite-worker` + +### Checkpoint 1.2 + +- Commit: `refactor: isolate binary-only startup code` +- Deliverables: + - Move PID lock and process signaling helpers into a small runtime support module. + - Keep binary-specific CLI/bootstrap logic out of domain modules. 
+- Validation: + - `cargo test main::tests` + - `cargo test` + +## Phase 2: Runtime Path + +Goal: make the dictation path read as orchestration over smaller components instead of one large cross-module knot. + +### Checkpoint 2.1 + +- Commit: `refactor: extract agentic rewrite runtime policy engine` +- Deliverables: + - Split `src/agentic_rewrite.rs` into runtime policy code and file-backed admin/store code. + - Runtime modules should not print to stdout or mutate files. + - CLI-facing add/list/remove/path helpers should depend on the store layer, not the runtime layer. +- Validation: + - `cargo test agentic_rewrite` + - `cargo test postprocess` + +### Checkpoint 2.2 + +- Commit: `refactor: split postprocess planning and execution` +- Deliverables: + - Extract a planning layer from `src/postprocess.rs` for transcript preparation and session intent. + - Extract an execution layer for local/cloud rewrite calls. + - Keep final acceptance and fallback rules in a smaller decision layer. +- Validation: + - `cargo test postprocess` + - `cargo test session` + - `cargo test personalization` + +### Checkpoint 2.3 + +- Commit: `refactor: unify asr backend lifecycle` +- Deliverables: + - Remove duplicated backend switching across `prepare_transcriber`, `prepare_live_transcriber`, `transcribe_audio`, and `transcribe_live_audio`. + - Centralize fallback policy in one place. +- Validation: + - `cargo test asr` + - `cargo test faster_whisper` + - `cargo test nemo_asr` + +### Checkpoint 2.4 + +- Commit: `refactor: split rewrite routing from prompt rendering` +- Status: + - completed sub-checkpoints: routing split, prompt rendering split, local rewrite engine extraction, output cleanup plus thin facade + - phase status: complete +- Deliverables: + - Separate route selection from prompt/template rendering in `src/rewrite.rs`. + - Keep giant prompt contracts out of routing logic. 
+- Validation: + - `cargo test rewrite` + - `cargo test rewrite_profile` + +### Checkpoint 2.5 + +- Commit: `refactor: split app controller from dictation runtime state` +- Status: + - completed sub-checkpoints: extracted runtime state transitions, isolated OSD helpers, kept `run()` as controller orchestration + - phase status: complete +- Deliverables: + - Keep `src/app.rs` as orchestration. + - Extract dictation runtime state, preview pacing, session updates, and injection decisions into smaller modules. + - Minimize direct side effects inside the main dictation loop. +- Validation: + - `cargo test app` + - `cargo test session` + - targeted manual smoke test for `whispers voice` + +## Phase 3: Domain Modules + +Goal: split large pure-ish logic files by domain instead of by size. + +### Checkpoint 3.1 + +- Commit: `refactor: split personalization store and rewrite candidates` +- Status: + - completed sub-checkpoints: extracted file-backed dictionary/snippet store, moved rewrite transcript and candidate generation out of the facade, kept `crate::personalization::*` call sites stable via re-exports + - phase status: complete +- Deliverables: + - Split `src/personalization.rs` into: + - store and CLI mutation helpers + - text transformation rules + - rewrite candidate building and ranking +- Validation: + - `cargo test personalization` + +### Checkpoint 3.2 + +- Commit: `refactor: split session persistence from backtrack planning` +- Status: + - completed sub-checkpoints: extracted runtime session persistence, isolated backtrack heuristics and typing-context mapping, kept `crate::session::*` paths stable via re-exports + - phase status: complete +- Deliverables: + - Move JSON load/save/prune logic away from backtrack heuristics. + - Make backtrack planning operate on in-memory data structures. 
+- Validation: + - `cargo test session` + - `cargo test postprocess` + +### Checkpoint 3.3 + +- Commit: `refactor: split cleanup lexicon analysis and rendering` +- Status: + - completed sub-checkpoints: extracted cue-family lexicon and hypothesis matching, isolated piece rendering, kept `crate::cleanup::*` public APIs stable at the root + - phase status: complete +- Deliverables: + - Split `src/cleanup.rs` into lexical rules, analysis, and rendering pieces. + - Keep the public cleanup API stable until follow-up cleanup is done. +- Validation: + - `cargo test cleanup` + +## Phase 4: Config and Command Surface + +Goal: remove duplicated sources of truth and reduce direct file mutation from high-level commands. + +### Checkpoint 4.1 + +- Commit: `refactor: split config schema defaults and editing` +- Status: + - completed sub-checkpoints: extracted schema/default types, split load and legacy migration logic from path helpers, isolated TOML mutation helpers behind the root `crate::config::*` facade + - phase status: complete +- Deliverables: + - Split `src/config.rs` into schema, defaults/template, load/migrate, and edit/update modules. + - Put TOML mutation behind a small config editor API. +- Validation: + - `cargo test config` + - `cargo test cli` + +### Checkpoint 4.2 + +- Commit: `refactor: extract setup flow phases` +- Status: + - completed sub-checkpoints: separated interactive selection from config application, isolated post-apply side effects, moved summary and completion rendering out of the flow orchestrator + - phase status: complete +- Deliverables: + - Break `src/setup.rs` into prompt/selection, config apply, side effects, and summary/reporting phases. + - Keep interactive behavior unchanged. 
+- Validation: + - `cargo test setup` + +### Checkpoint 4.3 + +- Commit: `refactor: unify model management workflows` +- Status: + - completed sub-checkpoints: extracted shared model config/bootstrap helpers, centralized common download/status logic, trimmed backend-specific model modules down to catalog and backend behavior + - phase status: complete +- Deliverables: + - Reduce duplication across `src/model.rs`, `src/asr_model.rs`, and `src/rewrite_model.rs`. + - Share download/select/status plumbing where behavior is actually the same. +- Validation: + - `cargo test model` + - `cargo test asr_model` + - `cargo test rewrite_model` + +### Checkpoint 4.4 + +- Commit: `refactor: isolate shell completion installers` +- Status: + - completed sub-checkpoints: split shell detection from completion rendering, kept `run_completions` as the thin entrypoint, noted that the current tree does not yet include install-path or shell-rc mutation logic + - phase status: complete +- Deliverables: + - Separate shell detection, script generation, install-path policy, and shell rc mutation in `src/completions.rs`. +- Validation: + - `cargo test completions` + +### Checkpoint 4.5 + +- Commit: `docs: derive config docs from canonical source` +- Status: + - completed sub-checkpoints: made the config writer template the canonical source, aligned `config.example.toml` with that template, removed the duplicated README config block in favor of referencing the canonical example + - phase status: complete +- Deliverables: + - Stop maintaining defaults separately in code, `config.example.toml`, and the README snippet. + - Pick one canonical source and generate or reuse it everywhere else. +- Validation: + - `cargo test config` + - manual check of `README.md` and `config.example.toml` + +## Phase 5: Platform Adapters and Reporting + +Goal: separate policy from OS effects in smaller but high-value modules. 
+ +### Checkpoint 5.1 + +- Commit: `refactor: extract injection adapter layer` +- Status: + - completed sub-checkpoints: split clipboard process handling from virtual keyboard emission, kept `TextInjector` as the stable policy/orchestration facade for runtime callers + - phase status: complete +- Deliverables: + - Separate injection policy from evdev and clipboard execution in `src/inject.rs`. +- Validation: + - `cargo test inject` + +### Checkpoint 5.2 + +- Commit: `refactor: split audio recorder and dsp helpers` +- Status: + - completed sub-checkpoints: split recorder lifecycle and device/config negotiation from reusable DSP helpers, kept `AudioRecorder` and `preprocess_audio` stable for callers + - phase status: complete +- Deliverables: + - Separate recorder lifecycle and device interaction from reusable audio transforms in `src/audio.rs`. +- Validation: + - `cargo test audio` + +### Checkpoint 5.3 + +- Commit: `docs: retire stale status reporting checkpoint` +- Status: + - completed sub-checkpoints: verified that `src/status.rs` is absent, confirmed earlier checkpoints already split the remaining real reporting surfaces (`setup/report.rs` and model status rendering), retired the stale roadmap item instead of inventing a fake module + - phase status: complete +- Deliverables: + - Confirm whether a standalone status/reporting module still exists in the current tree. + - Retire the stale checkpoint if the earlier refactors already covered the real reporting surfaces. +- Validation: + - manual codebase search for reporting surfaces + +## Not Now + +- Rewriting the user-facing CLI. +- Replacing `tokio` structure or async strategy. +- Changing OSD visuals. +- Large naming-only passes. +- Folding unrelated feature work into refactor commits. + +## Per-Checkpoint Template + +Use this each time work starts on a new item: + +1. Confirm the checkpoint and write the exact Conventional Commit description with `jj desc -m`. +2. Restate the non-goals for that checkpoint. +3. 
Move code without changing behavior. +4. Run targeted tests for touched modules. +5. If the checkpoint is complete, create the next working-copy change with `jj new`. +6. Update this file with status notes before moving on. + +## Progress Log + +- [x] Phase 1.1 complete +- [x] Phase 1.2 complete +- [x] Phase 2.1 complete +- [x] Phase 2.2 complete +- [x] Phase 2.3 complete +- [x] Phase 2.4 complete +- [x] Phase 2.5 complete +- [x] Phase 3.1 complete +- [x] Phase 3.2 complete +- [x] Phase 3.3 complete +- [x] Phase 4.1 complete +- [x] Phase 4.2 complete +- [x] Phase 4.3 complete +- [x] Phase 4.4 complete +- [x] Phase 4.5 complete +- [x] Phase 5.1 complete +- [x] Phase 5.2 complete +- [x] Phase 5.3 complete diff --git a/src/agentic_rewrite/admin.rs b/src/agentic_rewrite/admin.rs new file mode 100644 index 0000000..5037c93 --- /dev/null +++ b/src/agentic_rewrite/admin.rs @@ -0,0 +1,170 @@ +use std::path::Path; + +use crate::config::Config; +use crate::error::Result; +use crate::rewrite_protocol::RewriteCorrectionPolicy; + +use super::{AppRule, ContextMatcher, GlossaryEntry, store}; + +pub(super) fn print_app_rule_path(config_override: Option<&Path>) -> Result<()> { + let config = Config::load(config_override)?; + println!("{}", config.resolved_agentic_policy_path().display()); + Ok(()) +} + +pub(super) fn print_glossary_path(config_override: Option<&Path>) -> Result<()> { + let config = Config::load(config_override)?; + println!("{}", config.resolved_agentic_glossary_path().display()); + Ok(()) +} + +pub(super) fn list_app_rules(config_override: Option<&Path>) -> Result<()> { + let config = Config::load(config_override)?; + let rules = store::read_policy_file(&config.resolved_agentic_policy_path())?; + if rules.is_empty() { + println!("No app rules configured."); + return Ok(()); + } + + for rule in rules { + println!( + "{} | match: {} | correction_policy: {} | instructions: {}", + rule.name, + render_matcher(&rule.matcher), + rule.correction_policy + .map(|policy| 
policy.as_str()) + .unwrap_or("inherit"), + single_line(&rule.instructions) + ); + } + + Ok(()) +} + +pub(super) fn add_app_rule( + config_override: Option<&Path>, + name: &str, + instructions: &str, + matcher: ContextMatcher, + correction_policy: Option, +) -> Result<()> { + let config = Config::load(config_override)?; + let path = config.resolved_agentic_policy_path(); + let mut rules = store::read_policy_file(&path)?; + store::upsert_app_rule( + &mut rules, + AppRule { + name: name.to_string(), + matcher, + instructions: instructions.to_string(), + correction_policy, + }, + ); + store::write_policy_file(&path, &rules)?; + println!("Added app rule: {name}"); + println!("App rules updated: {}", path.display()); + Ok(()) +} + +pub(super) fn remove_app_rule(config_override: Option<&Path>, name: &str) -> Result<()> { + let config = Config::load(config_override)?; + let path = config.resolved_agentic_policy_path(); + let mut rules = store::read_policy_file(&path)?; + let removed = store::remove_app_rule_entry(&mut rules, name); + store::write_policy_file(&path, &rules)?; + if removed { + println!("Removed app rule: {name}"); + } else { + println!("No app rule matched: {name}"); + } + println!("App rules updated: {}", path.display()); + Ok(()) +} + +pub(super) fn list_glossary(config_override: Option<&Path>) -> Result<()> { + let config = Config::load(config_override)?; + let entries = store::read_glossary_file(&config.resolved_agentic_glossary_path())?; + if entries.is_empty() { + println!("No glossary entries configured."); + return Ok(()); + } + + for entry in entries { + let aliases = if entry.aliases.is_empty() { + "-".to_string() + } else { + entry.aliases.join(", ") + }; + println!( + "{} | aliases: {} | match: {}", + entry.term, + aliases, + render_matcher(&entry.matcher) + ); + } + + Ok(()) +} + +pub(super) fn add_glossary_entry( + config_override: Option<&Path>, + term: &str, + aliases: &[String], + matcher: ContextMatcher, +) -> Result<()> { + let config = 
Config::load(config_override)?; + let path = config.resolved_agentic_glossary_path(); + let mut entries = store::read_glossary_file(&path)?; + store::upsert_glossary_entry( + &mut entries, + GlossaryEntry { + term: term.to_string(), + aliases: aliases.to_vec(), + matcher, + }, + ); + store::write_glossary_file(&path, &entries)?; + println!("Added glossary entry: {term}"); + println!("Glossary updated: {}", path.display()); + Ok(()) +} + +pub(super) fn remove_glossary_entry(config_override: Option<&Path>, term: &str) -> Result<()> { + let config = Config::load(config_override)?; + let path = config.resolved_agentic_glossary_path(); + let mut entries = store::read_glossary_file(&path)?; + let removed = store::remove_glossary_entry_by_term(&mut entries, term); + store::write_glossary_file(&path, &entries)?; + if removed { + println!("Removed glossary entry: {term}"); + } else { + println!("No glossary entry matched: {term}"); + } + println!("Glossary updated: {}", path.display()); + Ok(()) +} + +fn single_line(text: &str) -> String { + text.trim().replace('\n', "\\n") +} + +fn render_matcher(matcher: &ContextMatcher) -> String { + let mut parts = Vec::new(); + if let Some(surface_kind) = matcher.surface_kind { + parts.push(format!("surface_kind={}", surface_kind.as_str())); + } + if let Some(app_id) = matcher.app_id.as_deref() { + parts.push(format!("app_id={app_id}")); + } + if let Some(window_title) = matcher.window_title_contains.as_deref() { + parts.push(format!("window_title_contains={window_title}")); + } + if let Some(browser_domain) = matcher.browser_domain_contains.as_deref() { + parts.push(format!("browser_domain_contains={browser_domain}")); + } + if parts.is_empty() { + "global".to_string() + } else { + parts.join(", ") + } +} diff --git a/src/agentic_rewrite/mod.rs b/src/agentic_rewrite/mod.rs new file mode 100644 index 0000000..b387c7b --- /dev/null +++ b/src/agentic_rewrite/mod.rs @@ -0,0 +1,268 @@ +mod admin; +mod runtime; +mod store; + +use 
std::path::Path;
+
+use serde::{Deserialize, Serialize};
+
+use crate::config::Config;
+use crate::error::Result;
+use crate::rewrite_protocol::{RewriteCorrectionPolicy, RewriteSurfaceKind, RewriteTranscript};
+
+#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)]
+#[serde(default)]
+pub struct ContextMatcher {
+    pub surface_kind: Option<RewriteSurfaceKind>,
+    pub app_id: Option<String>,
+    pub window_title_contains: Option<String>,
+    pub browser_domain_contains: Option<String>,
+}
+
+#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)]
+#[serde(default)]
+struct AppRule {
+    name: String,
+    #[serde(flatten)]
+    matcher: ContextMatcher,
+    instructions: String,
+    correction_policy: Option<RewriteCorrectionPolicy>,
+}
+
+#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)]
+#[serde(default)]
+struct GlossaryEntry {
+    term: String,
+    aliases: Vec<String>,
+    #[serde(flatten)]
+    matcher: ContextMatcher,
+}
+
+#[derive(Debug, Clone)]
+struct PreparedGlossaryEntry {
+    term: String,
+    aliases: Vec<String>,
+    matcher: ContextMatcher,
+    normalized_aliases: Vec<Vec<String>>,
+}
+
+pub use runtime::conservative_output_allowed;
+
+pub fn default_policy_path() -> &'static str {
+    store::default_policy_path()
+}
+
+pub fn default_glossary_path() -> &'static str {
+    store::default_glossary_path()
+}
+
+pub fn apply_runtime_policy(config: &Config, transcript: &mut RewriteTranscript) {
+    let policy_rules = store::load_policy_file_for_runtime(&config.resolved_agentic_policy_path());
+    let glossary_entries =
+        store::load_glossary_file_for_runtime(&config.resolved_agentic_glossary_path());
+
+    let policy_context = runtime::resolve_policy_context(
+        config.agentic_rewrite.default_correction_policy,
+        transcript.typing_context.as_ref(),
+        &transcript.rewrite_candidates,
+        &policy_rules,
+        &glossary_entries,
+    );
+
+    for candidate in &policy_context.glossary_candidates {
+        if transcript
+            .rewrite_candidates
+            .iter()
+            .any(|existing| existing.text == candidate.text)
+        {
+            continue;
+        }
+        
transcript.rewrite_candidates.push(candidate.clone()); + } + + transcript.policy_context = policy_context; +} + +pub fn ensure_starter_files(config: &Config) -> Result> { + store::ensure_starter_files(config) +} + +pub fn print_app_rule_path(config_override: Option<&Path>) -> Result<()> { + admin::print_app_rule_path(config_override) +} + +pub fn print_glossary_path(config_override: Option<&Path>) -> Result<()> { + admin::print_glossary_path(config_override) +} + +pub fn list_app_rules(config_override: Option<&Path>) -> Result<()> { + admin::list_app_rules(config_override) +} + +pub fn add_app_rule( + config_override: Option<&Path>, + name: &str, + instructions: &str, + matcher: ContextMatcher, + correction_policy: Option, +) -> Result<()> { + admin::add_app_rule( + config_override, + name, + instructions, + matcher, + correction_policy, + ) +} + +pub fn remove_app_rule(config_override: Option<&Path>, name: &str) -> Result<()> { + admin::remove_app_rule(config_override, name) +} + +pub fn list_glossary(config_override: Option<&Path>) -> Result<()> { + admin::list_glossary(config_override) +} + +pub fn add_glossary_entry( + config_override: Option<&Path>, + term: &str, + aliases: &[String], + matcher: ContextMatcher, +) -> Result<()> { + admin::add_glossary_entry(config_override, term, aliases, matcher) +} + +pub fn remove_glossary_entry(config_override: Option<&Path>, term: &str) -> Result<()> { + admin::remove_glossary_entry(config_override, term) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use crate::rewrite_protocol::{ + RewriteCandidate, RewriteCandidateKind, RewritePolicyContext, RewriteTranscript, + RewriteTypingContext, + }; + + fn typing_context(surface_kind: RewriteSurfaceKind) -> RewriteTypingContext { + RewriteTypingContext { + focus_fingerprint: "focus".into(), + app_id: Some("dev.zed.Zed".into()), + window_title: Some("docs.rs - serde_json".into()), + surface_kind, + browser_domain: Some("docs.rs".into()), + 
captured_at_ms: 42, + } + } + + fn transcript_with_candidates(surface_kind: RewriteSurfaceKind) -> RewriteTranscript { + RewriteTranscript { + raw_text: "type script and sir dee json".into(), + correction_aware_text: "type script and sir dee json".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: Some(typing_context(surface_kind)), + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "type script and sir dee json".into(), + }], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + } + } + + #[test] + fn apply_runtime_policy_adds_glossary_candidates() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&[ + "HOME", + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + ]); + let home = crate::test_support::unique_temp_dir("agentic-runtime-home"); + crate::test_support::set_env("HOME", &home.to_string_lossy()); + crate::test_support::remove_env("XDG_CONFIG_HOME"); + crate::test_support::remove_env("XDG_DATA_HOME"); + + let config = Config::default(); + let glossary_path = config.resolved_agentic_glossary_path(); + store::write_glossary_file( + &glossary_path, + &[GlossaryEntry { + term: "TypeScript".into(), + aliases: vec!["type script".into()], + matcher: ContextMatcher { + surface_kind: Some(RewriteSurfaceKind::Editor), + ..ContextMatcher::default() + }, + }], + ) + .expect("write glossary"); + + let mut transcript = transcript_with_candidates(RewriteSurfaceKind::Editor); + apply_runtime_policy(&config, &mut transcript); + assert!( + transcript + .rewrite_candidates + .iter() + .any(|candidate| candidate.text == "TypeScript and sir dee json") + ); + } + + 
#[test] + fn add_and_remove_roundtrip_for_policy_and_glossary() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&[ + "HOME", + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + ]); + let home = crate::test_support::unique_temp_dir("agentic-cli-home"); + crate::test_support::set_env("HOME", &home.to_string_lossy()); + crate::test_support::remove_env("XDG_CONFIG_HOME"); + crate::test_support::remove_env("XDG_DATA_HOME"); + + add_app_rule( + None, + "zed", + "Preserve Rust identifiers.", + ContextMatcher { + app_id: Some("dev.zed.Zed".into()), + ..ContextMatcher::default() + }, + Some(RewriteCorrectionPolicy::Balanced), + ) + .expect("add app rule"); + let config = Config::load(None).expect("config"); + let rules = store::read_policy_file(&config.resolved_agentic_policy_path()).expect("rules"); + assert_eq!(rules.len(), 1); + + add_glossary_entry( + None, + "serde_json", + &[String::from("sir dee json")], + ContextMatcher::default(), + ) + .expect("add glossary entry"); + let entries = + store::read_glossary_file(&config.resolved_agentic_glossary_path()).expect("entries"); + assert_eq!(entries.len(), 1); + + remove_app_rule(None, "zed").expect("remove app rule"); + remove_glossary_entry(None, "serde_json").expect("remove glossary entry"); + + let rules = store::read_policy_file(&config.resolved_agentic_policy_path()).expect("rules"); + let entries = + store::read_glossary_file(&config.resolved_agentic_glossary_path()).expect("entries"); + assert!(rules.is_empty()); + assert!(entries.is_empty()); + } +} diff --git a/src/agentic_rewrite/runtime.rs b/src/agentic_rewrite/runtime.rs new file mode 100644 index 0000000..ea00a2f --- /dev/null +++ b/src/agentic_rewrite/runtime.rs @@ -0,0 +1,935 @@ +use super::{AppRule, ContextMatcher, GlossaryEntry, PreparedGlossaryEntry}; +use crate::rewrite_protocol::{ + RewriteCandidate, RewriteCandidateKind, RewriteCorrectionPolicy, RewritePolicyContext, + RewritePolicyGlossaryTerm, 
RewriteSurfaceKind, RewriteTranscript, RewriteTypingContext,
+};
+
+const MAX_GLOSSARY_CANDIDATES: usize = 4;
+
+pub(super) fn resolve_policy_context(
+    default_policy: RewriteCorrectionPolicy,
+    context: Option<&RewriteTypingContext>,
+    rewrite_candidates: &[RewriteCandidate],
+    policy_rules: &[AppRule],
+    glossary_entries: &[GlossaryEntry],
+) -> RewritePolicyContext {
+    let mut matched_rule_names = Vec::new();
+    let mut effective_rule_instructions = Vec::new();
+    let mut correction_policy = default_policy;
+
+    for rule in built_in_rules(default_policy)
+        .into_iter()
+        .filter(|rule| rule.matcher.matches(context))
+        .chain(matching_rules(policy_rules, context))
+    {
+        matched_rule_names.push(rule.name.clone());
+        if let Some(policy) = rule.correction_policy {
+            correction_policy = policy;
+        }
+
+        let instructions = rule.instructions.trim();
+        if !instructions.is_empty() {
+            effective_rule_instructions.push(instructions.to_string());
+        }
+    }
+
+    let mut active_glossary_entries = glossary_entries
+        .iter()
+        .enumerate()
+        .filter_map(|(index, entry)| {
+            PreparedGlossaryEntry::new(entry.clone()).map(|entry| (index, entry))
+        })
+        .filter(|(_, entry)| entry.matcher.matches(context))
+        .collect::<Vec<_>>();
+    active_glossary_entries
+        .sort_by_key(|(index, entry)| (entry.matcher.specificity_rank(), *index));
+    let active_glossary_entries = active_glossary_entries
+        .into_iter()
+        .map(|(_, entry)| entry)
+        .collect::<Vec<_>>();
+
+    RewritePolicyContext {
+        correction_policy,
+        matched_rule_names,
+        effective_rule_instructions,
+        active_glossary_terms: collapse_glossary_terms(&active_glossary_entries),
+        glossary_candidates: build_glossary_candidates(
+            rewrite_candidates,
+            &active_glossary_entries,
+        ),
+    }
+}
+
+pub fn conservative_output_allowed(transcript: &RewriteTranscript, text: &str) -> bool {
+    let text = text.trim();
+    if text.is_empty() {
+        return false;
+    }
+
+    transcript
+        .rewrite_candidates
+        .iter()
+        .any(|candidate| candidate_supports_output(&candidate.text, 
text)) + || transcript + .policy_context + .glossary_candidates + .iter() + .any(|candidate| candidate_supports_output(&candidate.text, text)) +} + +impl ContextMatcher { + fn matches(&self, context: Option<&RewriteTypingContext>) -> bool { + if self.is_empty() { + return true; + } + + let Some(context) = context else { + return false; + }; + + if let Some(surface_kind) = self.surface_kind + && context.surface_kind != surface_kind + { + return false; + } + + if let Some(app_id) = self.app_id.as_deref() + && context.app_id.as_deref() != Some(app_id) + { + return false; + } + + if let Some(needle) = self.window_title_contains.as_deref() + && !contains_ignore_ascii_case(context.window_title.as_deref(), needle) + { + return false; + } + + if let Some(needle) = self.browser_domain_contains.as_deref() + && !contains_ignore_ascii_case(context.browser_domain.as_deref(), needle) + { + return false; + } + + true + } + + fn specificity_rank(&self) -> (u8, u8) { + let strongest_layer = if self.browser_domain_contains.is_some() { + 4 + } else if self.window_title_contains.is_some() { + 3 + } else if self.app_id.is_some() { + 2 + } else if self.surface_kind.is_some() { + 1 + } else { + 0 + }; + let matcher_count = [ + self.surface_kind.is_some(), + self.app_id.is_some(), + self.window_title_contains.is_some(), + self.browser_domain_contains.is_some(), + ] + .into_iter() + .filter(|present| *present) + .count() as u8; + (strongest_layer, matcher_count) + } + + fn is_empty(&self) -> bool { + self.surface_kind.is_none() + && self.app_id.is_none() + && self.window_title_contains.is_none() + && self.browser_domain_contains.is_none() + } +} + +impl AppRule { + fn built_in( + name: &str, + matcher: ContextMatcher, + instructions: &str, + correction_policy: Option, + ) -> Self { + Self { + name: name.to_string(), + matcher, + instructions: instructions.to_string(), + correction_policy, + } + } +} + +impl PreparedGlossaryEntry { + fn new(entry: GlossaryEntry) -> Option { + let term = 
entry.term.trim().to_string(); + if term.is_empty() { + return None; + } + + let aliases = entry + .aliases + .into_iter() + .map(|alias| alias.trim().to_string()) + .filter(|alias| !alias.is_empty()) + .collect::>(); + let normalized_aliases = aliases + .iter() + .map(|alias| normalized_words(alias)) + .filter(|words| !words.is_empty()) + .collect::>(); + + Some(Self { + term, + aliases, + matcher: entry.matcher, + normalized_aliases, + }) + } +} + +fn candidate_supports_output(candidate: &str, output: &str) -> bool { + if candidate.trim() == output.trim() { + return true; + } + + let candidate_words = normalized_words(candidate); + let output_words = normalized_words(output); + if candidate_words.is_empty() || output_words.is_empty() { + return false; + } + + if candidate_words == output_words { + return true; + } + + if candidate_words.len() != output_words.len() || candidate_words.len() < 4 { + return false; + } + + let differing_pairs = candidate_words + .iter() + .zip(&output_words) + .filter(|(candidate_word, output_word)| candidate_word != output_word) + .collect::>(); + if differing_pairs.is_empty() || differing_pairs.len() > 2 { + return false; + } + + differing_pairs + .into_iter() + .all(|(candidate_word, output_word)| { + is_minor_term_normalization(candidate_word, output_word) + }) +} + +fn is_minor_term_normalization(candidate_word: &str, output_word: &str) -> bool { + let candidate_len = candidate_word.chars().count(); + let output_len = output_word.chars().count(); + let max_len = candidate_len.max(output_len); + if max_len < 3 { + return false; + } + + let distance = levenshtein_distance(candidate_word, output_word); + if distance == 0 || distance > 3 { + return false; + } + + if phonetic_skeleton(candidate_word) == phonetic_skeleton(output_word) { + return true; + } + + distance * 2 <= max_len + 1 +} + +fn levenshtein_distance(left: &str, right: &str) -> usize { + if left == right { + return 0; + } + + let right_chars = 
right.chars().collect::>(); + let mut previous = (0..=right_chars.len()).collect::>(); + let mut current = vec![0; right_chars.len() + 1]; + + for (left_index, left_char) in left.chars().enumerate() { + current[0] = left_index + 1; + for (right_index, right_char) in right_chars.iter().enumerate() { + let substitution_cost = usize::from(left_char != *right_char); + current[right_index + 1] = (previous[right_index + 1] + 1) + .min(current[right_index] + 1) + .min(previous[right_index] + substitution_cost); + } + std::mem::swap(&mut previous, &mut current); + } + + previous[right_chars.len()] +} + +fn phonetic_skeleton(word: &str) -> String { + let mut chars = word + .chars() + .filter(|ch| is_word_char(*ch)) + .flat_map(|ch| ch.to_lowercase()); + let Some(first) = chars.next() else { + return String::new(); + }; + + let mut skeleton = String::from(first); + let mut previous = first; + for ch in chars { + if matches!(ch, 'a' | 'e' | 'i' | 'o' | 'u' | 'w' | 'y') { + continue; + } + if ch != previous { + skeleton.push(ch); + previous = ch; + } + } + skeleton +} + +fn built_in_rules(default_policy: RewriteCorrectionPolicy) -> Vec { + vec![ + AppRule::built_in( + "baseline/global-default", + ContextMatcher::default(), + "Use the active typing context, recent dictation context, glossary terms, and bounded candidates to resolve technical dictation cleanly while keeping the final-text-only contract. When the utterance clearly points to software, tools, APIs, Linux components, product names, or other technical concepts, prefer the most plausible intended technical term over a phonetically similar common word. Use category cues like window manager, editor, language, library, shell, or package manager to disambiguate nearby technical names. 
If it remains genuinely ambiguous, stay close to the transcript.", + Some(default_policy), + ), + AppRule::built_in( + "baseline/browser", + ContextMatcher { + surface_kind: Some(RewriteSurfaceKind::Browser), + ..ContextMatcher::default() + }, + "Favor clean prose and natural punctuation for browser text fields, but stay grounded in the listed candidates, glossary evidence, and the utterance's technical topic when it clearly refers to software or documentation.", + Some(RewriteCorrectionPolicy::Balanced), + ), + AppRule::built_in( + "baseline/generic-text", + ContextMatcher { + surface_kind: Some(RewriteSurfaceKind::GenericText), + ..ContextMatcher::default() + }, + "Favor clean prose and natural punctuation for general text entry while staying grounded in the listed candidates and glossary evidence. If the utterance clearly discusses technical tools or software, prefer the most plausible technical term over a phonetically similar common word.", + Some(RewriteCorrectionPolicy::Balanced), + ), + AppRule::built_in( + "baseline/editor", + ContextMatcher { + surface_kind: Some(RewriteSurfaceKind::Editor), + ..ContextMatcher::default() + }, + "Preserve identifiers, filenames, API names, symbols, and technical casing. Avoid rewriting technical wording into generic prose. Infer likely technical terms and proper names from the utterance when the topic is clearly code, tooling, or software.", + Some(RewriteCorrectionPolicy::Balanced), + ), + AppRule::built_in( + "baseline/terminal", + ContextMatcher { + surface_kind: Some(RewriteSurfaceKind::Terminal), + ..ContextMatcher::default() + }, + "Preserve commands, flags, paths, package names, environment variables, and punctuation that changes command meaning. Infer technical commands or package names only when the utterance strongly supports them. 
If uncertain, prefer the closest listed candidate.", + Some(RewriteCorrectionPolicy::Conservative), + ), + ] +} + +fn matching_rules(rules: &[AppRule], context: Option<&RewriteTypingContext>) -> Vec { + let mut matches = rules + .iter() + .enumerate() + .filter(|(_, rule)| rule.matcher.matches(context)) + .collect::>(); + matches.sort_by_key(|(index, rule)| (rule.matcher.specificity_rank(), *index)); + matches.into_iter().map(|(_, rule)| rule.clone()).collect() +} + +fn collapse_glossary_terms(entries: &[PreparedGlossaryEntry]) -> Vec { + let mut collapsed = Vec::::new(); + for entry in entries { + if let Some(existing) = collapsed + .iter_mut() + .find(|candidate| candidate.term == entry.term) + { + for alias in &entry.aliases { + if !existing + .aliases + .iter() + .any(|existing_alias| existing_alias == alias) + { + existing.aliases.push(alias.clone()); + } + } + continue; + } + + collapsed.push(RewritePolicyGlossaryTerm { + term: entry.term.clone(), + aliases: entry.aliases.clone(), + }); + } + collapsed +} + +fn build_glossary_candidates( + rewrite_candidates: &[RewriteCandidate], + glossary_entries: &[PreparedGlossaryEntry], +) -> Vec { + let mut generated = Vec::new(); + for candidate in rewrite_candidates { + if generated.len() >= MAX_GLOSSARY_CANDIDATES { + break; + } + + if let Some(text) = apply_glossary_entries(&candidate.text, glossary_entries) + && text != candidate.text + && !generated + .iter() + .any(|existing: &RewriteCandidate| existing.text == text) + && !rewrite_candidates + .iter() + .any(|existing| existing.text == text) + { + generated.push(RewriteCandidate { + kind: RewriteCandidateKind::GlossaryCorrection, + text, + }); + } + } + generated +} + +fn apply_glossary_entries(text: &str, entries: &[PreparedGlossaryEntry]) -> Option { + if text.trim().is_empty() || entries.is_empty() { + return None; + } + + let spans = collect_word_spans(text); + if spans.is_empty() { + return None; + } + + let mut replacements = 
collect_glossary_replacements(text, &spans, entries); + if replacements.is_empty() { + return None; + } + + replacements.sort_by_key(|replacement| replacement.start); + + let mut output = String::new(); + let mut cursor = 0usize; + for replacement in replacements { + output.push_str(&text[cursor..replacement.start]); + output.push_str(&replacement.term); + cursor = replacement.end; + } + output.push_str(&text[cursor..]); + Some(output.trim().to_string()) +} + +fn collect_glossary_replacements( + text: &str, + spans: &[WordSpan], + entries: &[PreparedGlossaryEntry], +) -> Vec { + let mut candidates = Vec::new(); + for (priority, entry) in entries.iter().enumerate() { + if entry.normalized_aliases.is_empty() { + continue; + } + + let mut index = 0usize; + while index < spans.len() { + let Some(alias_len) = best_alias_match(spans, index, &entry.normalized_aliases) else { + index += 1; + continue; + }; + + candidates.push(GlossaryReplacement { + start: spans[index].start, + end: spans[index + alias_len - 1].end, + start_span: index, + end_span: index + alias_len, + term: entry.term.clone(), + priority, + }); + index += alias_len; + } + } + + candidates.sort_by(|left, right| { + right + .priority + .cmp(&left.priority) + .then_with(|| { + (right.end_span - right.start_span).cmp(&(left.end_span - left.start_span)) + }) + .then_with(|| left.start_span.cmp(&right.start_span)) + }); + + let mut selected = Vec::new(); + for candidate in candidates { + if selected.iter().any(|existing: &GlossaryReplacement| { + candidate.start_span < existing.end_span && candidate.end_span > existing.start_span + }) { + continue; + } + selected.push(candidate); + } + + if selected.is_empty() && !text.is_empty() { + return Vec::new(); + } + + selected +} + +fn best_alias_match(spans: &[WordSpan], index: usize, aliases: &[Vec]) -> Option { + aliases + .iter() + .filter(|alias| matches_words(spans, index, alias)) + .map(Vec::len) + .max() +} + +fn matches_words(spans: &[WordSpan], index: usize, 
words: &[String]) -> bool { + if words.is_empty() || index + words.len() > spans.len() { + return false; + } + + spans[index..index + words.len()] + .iter() + .zip(words) + .all(|(span, word)| span.normalized == *word) +} + +fn collect_word_spans(text: &str) -> Vec { + let mut spans = Vec::new(); + let mut current_start = None; + + for (index, ch) in text.char_indices() { + if is_word_char(ch) { + current_start.get_or_insert(index); + continue; + } + + if let Some(start) = current_start.take() { + spans.push(WordSpan { + start, + end: index, + normalized: normalize_word(&text[start..index]), + }); + } + } + + if let Some(start) = current_start { + spans.push(WordSpan { + start, + end: text.len(), + normalized: normalize_word(&text[start..]), + }); + } + + spans +} + +fn normalized_words(text: &str) -> Vec { + collect_word_spans(text) + .into_iter() + .map(|span| span.normalized) + .collect() +} + +fn normalize_word(word: &str) -> String { + word.chars() + .filter(|ch| is_word_char(*ch)) + .flat_map(|ch| ch.to_lowercase()) + .collect() +} + +fn is_word_char(ch: char) -> bool { + ch.is_alphanumeric() || matches!(ch, '\'' | '-' | '_' | '.') +} + +fn contains_ignore_ascii_case(haystack: Option<&str>, needle: &str) -> bool { + let Some(haystack) = haystack else { + return false; + }; + haystack + .to_ascii_lowercase() + .contains(&needle.to_ascii_lowercase()) +} + +#[derive(Debug, Clone)] +struct WordSpan { + start: usize, + end: usize, + normalized: String, +} + +#[derive(Debug, Clone)] +struct GlossaryReplacement { + start: usize, + end: usize, + start_span: usize, + end_span: usize, + term: String, + priority: usize, +} + +#[cfg(test)] +mod tests { + use super::super::{AppRule, ContextMatcher, GlossaryEntry}; + use super::*; + use crate::rewrite_protocol::{ + RewriteCandidate, RewriteCandidateKind, RewritePolicyContext, RewriteSurfaceKind, + RewriteTranscript, RewriteTypingContext, + }; + + fn typing_context(surface_kind: RewriteSurfaceKind) -> RewriteTypingContext { 
+ RewriteTypingContext { + focus_fingerprint: "focus".into(), + app_id: Some("dev.zed.Zed".into()), + window_title: Some("docs.rs - serde_json".into()), + surface_kind, + browser_domain: Some("docs.rs".into()), + captured_at_ms: 42, + } + } + + fn transcript_with_candidates(surface_kind: RewriteSurfaceKind) -> RewriteTranscript { + RewriteTranscript { + raw_text: "type script and sir dee json".into(), + correction_aware_text: "type script and sir dee json".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: Some(typing_context(surface_kind)), + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "type script and sir dee json".into(), + }], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + } + } + + #[test] + fn built_in_terminal_policy_is_conservative() { + let context = typing_context(RewriteSurfaceKind::Terminal); + let policy = resolve_policy_context( + RewriteCorrectionPolicy::Balanced, + Some(&context), + &[], + &[], + &[], + ); + assert_eq!( + policy.correction_policy, + RewriteCorrectionPolicy::Conservative + ); + assert!( + policy + .matched_rule_names + .iter() + .any(|name| name == "baseline/terminal") + ); + } + + #[test] + fn built_in_policy_guides_technical_term_inference() { + let context = typing_context(RewriteSurfaceKind::GenericText); + let policy = resolve_policy_context( + RewriteCorrectionPolicy::Balanced, + Some(&context), + &[], + &[], + &[], + ); + assert!( + policy + .effective_rule_instructions + .iter() + .any(|instruction| instruction.contains("phonetically similar common word")) + ); + } + + #[test] + fn more_specific_rules_override_less_specific_rules() { + let rules = 
vec![ + AppRule { + name: "surface".into(), + matcher: ContextMatcher { + surface_kind: Some(RewriteSurfaceKind::Editor), + ..ContextMatcher::default() + }, + instructions: "surface".into(), + correction_policy: Some(RewriteCorrectionPolicy::Balanced), + }, + AppRule { + name: "app".into(), + matcher: ContextMatcher { + app_id: Some("dev.zed.Zed".into()), + ..ContextMatcher::default() + }, + instructions: "app".into(), + correction_policy: Some(RewriteCorrectionPolicy::Aggressive), + }, + ]; + let context = typing_context(RewriteSurfaceKind::Editor); + let policy = resolve_policy_context( + RewriteCorrectionPolicy::Balanced, + Some(&context), + &[], + &rules, + &[], + ); + assert_eq!( + policy.correction_policy, + RewriteCorrectionPolicy::Aggressive + ); + assert_eq!( + policy + .effective_rule_instructions + .last() + .map(String::as_str), + Some("app") + ); + } + + #[test] + fn higher_precedence_matcher_layers_override_lower_layer_combinations() { + let rules = vec![ + AppRule { + name: "surface-and-app".into(), + matcher: ContextMatcher { + surface_kind: Some(RewriteSurfaceKind::Editor), + app_id: Some("dev.zed.Zed".into()), + ..ContextMatcher::default() + }, + instructions: "surface-and-app".into(), + correction_policy: Some(RewriteCorrectionPolicy::Aggressive), + }, + AppRule { + name: "window-title".into(), + matcher: ContextMatcher { + window_title_contains: Some("serde_json".into()), + ..ContextMatcher::default() + }, + instructions: "window-title".into(), + correction_policy: Some(RewriteCorrectionPolicy::Conservative), + }, + ]; + let context = typing_context(RewriteSurfaceKind::Editor); + let policy = resolve_policy_context( + RewriteCorrectionPolicy::Balanced, + Some(&context), + &[], + &rules, + &[], + ); + assert_eq!( + policy.correction_policy, + RewriteCorrectionPolicy::Conservative + ); + assert_eq!( + policy + .effective_rule_instructions + .last() + .map(String::as_str), + Some("window-title") + ); + } + + #[test] + fn 
glossary_candidates_follow_matching_scope() { + let glossary = vec![ + GlossaryEntry { + term: "TypeScript".into(), + aliases: vec!["type script".into()], + matcher: ContextMatcher { + surface_kind: Some(RewriteSurfaceKind::Editor), + ..ContextMatcher::default() + }, + }, + GlossaryEntry { + term: "serde_json".into(), + aliases: vec!["sir dee json".into()], + matcher: ContextMatcher { + browser_domain_contains: Some("docs.rs".into()), + ..ContextMatcher::default() + }, + }, + ]; + let policy = resolve_policy_context( + RewriteCorrectionPolicy::Balanced, + Some(&typing_context(RewriteSurfaceKind::Editor)), + &[RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "type script and sir dee json".into(), + }], + &[], + &glossary, + ); + assert_eq!(policy.active_glossary_terms.len(), 2); + assert_eq!(policy.glossary_candidates.len(), 1); + assert_eq!( + policy.glossary_candidates[0].text, + "TypeScript and serde_json" + ); + } + + #[test] + fn glossary_candidates_preserve_scoped_alias_overrides() { + let glossary = vec![ + GlossaryEntry { + term: "serde".into(), + aliases: vec!["sir dee".into()], + matcher: ContextMatcher::default(), + }, + GlossaryEntry { + term: "serde_json".into(), + aliases: vec!["sir dee".into()], + matcher: ContextMatcher { + browser_domain_contains: Some("docs.rs".into()), + ..ContextMatcher::default() + }, + }, + ]; + let policy = resolve_policy_context( + RewriteCorrectionPolicy::Balanced, + Some(&typing_context(RewriteSurfaceKind::Editor)), + &[RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "sir dee".into(), + }], + &[], + &glossary, + ); + assert_eq!(policy.glossary_candidates.len(), 1); + assert_eq!(policy.glossary_candidates[0].text, "serde_json"); + } + + #[test] + fn conservative_acceptance_requires_explicit_candidate() { + let mut transcript = transcript_with_candidates(RewriteSurfaceKind::Terminal); + transcript.policy_context.correction_policy = RewriteCorrectionPolicy::Conservative; + 
transcript.policy_context.glossary_candidates = vec![RewriteCandidate { + kind: RewriteCandidateKind::GlossaryCorrection, + text: "TypeScript and serde_json".into(), + }]; + assert!(conservative_output_allowed( + &transcript, + "type script and sir dee json" + )); + assert!(conservative_output_allowed( + &transcript, + "TypeScript and serde_json" + )); + assert!(!conservative_output_allowed( + &transcript, + "A different rewrite" + )); + } + + #[test] + fn conservative_acceptance_allows_sentence_like_minor_term_normalization() { + let mut hyperland_transcript = RewriteTranscript { + raw_text: "I'm currently using the window manager hyperland.".into(), + correction_aware_text: "I'm currently using the window manager hyperland.".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: Some(typing_context(RewriteSurfaceKind::Terminal)), + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "I'm currently using the window manager hyperland.".into(), + }], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + }; + hyperland_transcript.policy_context.correction_policy = + RewriteCorrectionPolicy::Conservative; + + assert!(conservative_output_allowed( + &hyperland_transcript, + "I'm currently using the window manager Hyprland." 
+ )); + + let mut switch_transcript = RewriteTranscript { + raw_text: "I'm switching from Sui to Hyperland.".into(), + correction_aware_text: "I'm switching from Sui to Hyperland.".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: Some(typing_context(RewriteSurfaceKind::Terminal)), + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "I'm switching from Sui to Hyperland.".into(), + }], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + }; + switch_transcript.policy_context.correction_policy = RewriteCorrectionPolicy::Conservative; + + assert!(conservative_output_allowed( + &switch_transcript, + "I'm switching from Sway to Hyprland." 
+ )); + } + + #[test] + fn conservative_acceptance_keeps_short_command_fragments_strict() { + let mut transcript = RewriteTranscript { + raw_text: "cargo clipy".into(), + correction_aware_text: "cargo clipy".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: Some(typing_context(RewriteSurfaceKind::Terminal)), + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "cargo clipy".into(), + }], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + }; + transcript.policy_context.correction_policy = RewriteCorrectionPolicy::Conservative; + + assert!(!conservative_output_allowed(&transcript, "cargo clippy")); + } + + #[test] + fn minor_term_normalization_uses_phonetic_skeleton_without_allowing_unrelated_words() { + assert!(is_minor_term_normalization("sui", "sway")); + assert!(!is_minor_term_normalization("cat", "dog")); + } +} diff --git a/src/agentic_rewrite/store.rs b/src/agentic_rewrite/store.rs new file mode 100644 index 0000000..dc099c9 --- /dev/null +++ b/src/agentic_rewrite/store.rs @@ -0,0 +1,232 @@ +use std::path::Path; + +use serde::{Deserialize, Serialize}; + +use crate::config::{Config, PostprocessMode}; +use crate::error::{Result, WhsprError}; + +use super::{AppRule, GlossaryEntry}; + +const DEFAULT_POLICY_PATH: &str = "~/.local/share/whispers/app-rewrite-policy.toml"; +const DEFAULT_GLOSSARY_PATH: &str = "~/.local/share/whispers/technical-glossary.toml"; + +const POLICY_STARTER: &str = r#"# App-aware rewrite policy for whispers agentic_rewrite mode. +# Rules are layered, not first-match. 
Matching rules apply in this order: +# global defaults, surface_kind, app_id, window_title_contains, browser_domain_contains. +# Later, more specific rules override earlier fields. +# +# Uncomment and edit the examples below. +# +# [[rules]] +# name = "terminal-shell" +# surface_kind = "terminal" +# correction_policy = "conservative" +# instructions = "Preserve commands, flags, paths, package names, and environment variables." +# +# [[rules]] +# name = "docs-rs-browser" +# surface_kind = "browser" +# browser_domain_contains = "docs.rs" +# instructions = "Preserve Rust crate names, module paths, and type identifiers." +# +# [[rules]] +# name = "zed-rust" +# app_id = "dev.zed.Zed" +# instructions = "Preserve identifiers, filenames, snake_case, camelCase, and Rust terminology." +"#; + +const GLOSSARY_STARTER: &str = r#"# Technical glossary for whispers agentic_rewrite mode. +# Each entry defines a canonical term plus likely spoken or mis-transcribed aliases. +# +# Uncomment and edit the examples below. 
+# +# [[entries]] +# term = "TypeScript" +# aliases = ["type script", "types script"] +# surface_kind = "editor" +# +# [[entries]] +# term = "pyproject.toml" +# aliases = ["pie project dot toml", "pie project toml"] +# surface_kind = "terminal" +# +# [[entries]] +# term = "serde_json" +# aliases = ["sir dee json", "serdy json"] +# browser_domain_contains = "docs.rs" +"#; + +#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)] +#[serde(default)] +struct PolicyFile { + rules: Vec, +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)] +#[serde(default)] +struct GlossaryFile { + entries: Vec, +} + +pub(super) fn default_policy_path() -> &'static str { + DEFAULT_POLICY_PATH +} + +pub(super) fn default_glossary_path() -> &'static str { + DEFAULT_GLOSSARY_PATH +} + +pub(super) fn ensure_starter_files(config: &Config) -> Result> { + if config.postprocess.mode != PostprocessMode::AgenticRewrite { + return Ok(Vec::new()); + } + + let mut created = Vec::new(); + let policy_path = config.resolved_agentic_policy_path(); + if ensure_text_file(&policy_path, POLICY_STARTER)? { + created.push(policy_path.display().to_string()); + } + + let glossary_path = config.resolved_agentic_glossary_path(); + if ensure_text_file(&glossary_path, GLOSSARY_STARTER)? 
{ + created.push(glossary_path.display().to_string()); + } + + Ok(created) +} + +fn ensure_text_file(path: &Path, contents: &str) -> Result { + if path.exists() { + return Ok(false); + } + + write_parent(path)?; + std::fs::write(path, contents).map_err(|e| { + WhsprError::Config(format!( + "failed to write starter file {}: {e}", + path.display() + )) + })?; + Ok(true) +} + +pub(super) fn read_policy_file(path: &Path) -> Result> { + if !path.exists() { + return Ok(Vec::new()); + } + + let contents = std::fs::read_to_string(path).map_err(|e| { + WhsprError::Config(format!("failed to read app rules {}: {e}", path.display())) + })?; + if contents.trim().is_empty() { + return Ok(Vec::new()); + } + let file: PolicyFile = toml::from_str(&contents).map_err(|e| { + WhsprError::Config(format!("failed to parse app rules {}: {e}", path.display())) + })?; + Ok(file.rules) +} + +pub(super) fn write_policy_file(path: &Path, rules: &[AppRule]) -> Result<()> { + write_parent(path)?; + let contents = toml::to_string_pretty(&PolicyFile { + rules: rules.to_vec(), + }) + .map_err(|e| WhsprError::Config(format!("failed to encode app rules: {e}")))?; + std::fs::write(path, contents).map_err(|e| { + WhsprError::Config(format!("failed to write app rules {}: {e}", path.display())) + })?; + Ok(()) +} + +pub(super) fn read_glossary_file(path: &Path) -> Result> { + if !path.exists() { + return Ok(Vec::new()); + } + + let contents = std::fs::read_to_string(path).map_err(|e| { + WhsprError::Config(format!("failed to read glossary {}: {e}", path.display())) + })?; + if contents.trim().is_empty() { + return Ok(Vec::new()); + } + let file: GlossaryFile = toml::from_str(&contents).map_err(|e| { + WhsprError::Config(format!("failed to parse glossary {}: {e}", path.display())) + })?; + Ok(file.entries) +} + +pub(super) fn write_glossary_file(path: &Path, entries: &[GlossaryEntry]) -> Result<()> { + write_parent(path)?; + let contents = toml::to_string_pretty(&GlossaryFile { + entries: 
entries.to_vec(), + }) + .map_err(|e| WhsprError::Config(format!("failed to encode glossary: {e}")))?; + std::fs::write(path, contents).map_err(|e| { + WhsprError::Config(format!("failed to write glossary {}: {e}", path.display())) + })?; + Ok(()) +} + +pub(super) fn load_policy_file_for_runtime(path: &Path) -> Vec { + match read_policy_file(path) { + Ok(rules) => rules, + Err(err) => { + tracing::warn!("{err}; using built-in app rewrite defaults"); + Vec::new() + } + } +} + +pub(super) fn load_glossary_file_for_runtime(path: &Path) -> Vec { + match read_glossary_file(path) { + Ok(entries) => entries, + Err(err) => { + tracing::warn!("{err}; ignoring runtime glossary"); + Vec::new() + } + } +} + +fn write_parent(path: &Path) -> Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(|e| { + WhsprError::Config(format!( + "failed to create directory {}: {e}", + parent.display() + )) + })?; + } + Ok(()) +} + +pub(super) fn upsert_app_rule(rules: &mut Vec, rule: AppRule) { + if let Some(existing) = rules.iter_mut().find(|existing| existing.name == rule.name) { + *existing = rule; + return; + } + rules.push(rule); +} + +pub(super) fn remove_app_rule_entry(rules: &mut Vec, name: &str) -> bool { + let before = rules.len(); + rules.retain(|rule| rule.name != name); + before != rules.len() +} + +pub(super) fn upsert_glossary_entry(entries: &mut Vec, entry: GlossaryEntry) { + if let Some(existing) = entries + .iter_mut() + .find(|existing| existing.term == entry.term) + { + *existing = entry; + return; + } + entries.push(entry); +} + +pub(super) fn remove_glossary_entry_by_term(entries: &mut Vec, term: &str) -> bool { + let before = entries.len(); + entries.retain(|entry| entry.term != term); + before != entries.len() +} diff --git a/src/app.rs b/src/app.rs deleted file mode 100644 index 3a99191..0000000 --- a/src/app.rs +++ /dev/null @@ -1,234 +0,0 @@ -use std::process::Child; -use std::time::Instant; - -#[cfg(feature = "osd")] -use 
std::process::Command; - -use crate::asr; -use crate::audio::AudioRecorder; -use crate::config::{Config, PostprocessMode}; -use crate::context; -use crate::error::Result; -use crate::feedback::FeedbackPlayer; -use crate::inject::TextInjector; -use crate::postprocess; -use crate::session; - -pub async fn run(config: Config) -> Result<()> { - let activation_started = Instant::now(); - // Register signals before startup work to minimize early-signal races. - let mut sigusr1 = - tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined1())?; - let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?; - - let feedback = FeedbackPlayer::new( - config.feedback.enabled, - &config.feedback.start_sound, - &config.feedback.stop_sound, - ); - - // Play start sound first (blocking), then start recording so the sound - // doesn't leak into the mic. - feedback.play_start(); - let recording_context = context::capture_typing_context(); - let session_enabled = config.postprocess.mode == PostprocessMode::AdvancedLocal; - let recent_session = if session_enabled { - session::load_recent_entry(&config.session, &recording_context)? - } else { - None - }; - let mut recorder = AudioRecorder::new(&config.audio); - recorder.start()?; - let mut osd = spawn_osd(); - tracing::info!("recording... (run whispers again to stop)"); - - let transcriber = asr::prepare_transcriber(&config)?; - let rewrite_service = postprocess::prepare_rewrite_service(&config); - asr::prewarm_transcriber(&transcriber, "recording"); - if let Some(service) = rewrite_service.as_ref() { - postprocess::prewarm_rewrite_service(service, "recording"); - } - - tokio::select! 
{ - _ = sigusr1.recv() => { - tracing::info!("toggle signal received, stopping recording"); - } - _ = tokio::signal::ctrl_c() => { - tracing::info!("interrupted, cancelling"); - kill_osd(&mut osd); - recorder.stop()?; - return Ok(()); - } - _ = sigterm.recv() => { - tracing::info!("terminated, cancelling"); - kill_osd(&mut osd); - recorder.stop()?; - return Ok(()); - } - } - - // Stop recording before playing feedback so the stop sound doesn't - // leak into the mic. - kill_osd(&mut osd); - let audio = recorder.stop()?; - feedback.play_stop(); - let sample_rate = config.audio.sample_rate; - let audio_duration_ms = ((audio.len() as f64 / sample_rate as f64) * 1000.0).round() as u64; - - tracing::info!( - samples = audio.len(), - sample_rate, - audio_duration_ms, - "transcribing captured audio" - ); - - let transcribe_started = Instant::now(); - let transcript = asr::transcribe_audio(&config, transcriber, audio, sample_rate).await?; - tracing::info!( - elapsed_ms = transcribe_started.elapsed().as_millis(), - transcript_chars = transcript.raw_text.len(), - "transcription stage finished" - ); - - if transcript.is_empty() { - tracing::warn!("transcription returned empty text"); - postprocess::wait_for_feedback_drain().await; - return Ok(()); - } - - let injection_context = context::capture_typing_context(); - let recent_session = recent_session.filter(|entry| { - let same_focus = entry.entry.focus_fingerprint == injection_context.focus_fingerprint; - if !same_focus { - tracing::debug!( - previous_focus = entry.entry.focus_fingerprint, - current_focus = injection_context.focus_fingerprint, - "session backtrack blocked because focus changed before injection" - ); - } - same_focus - }); - let finalize_started = Instant::now(); - let finalized = postprocess::finalize_transcript( - &config, - transcript, - rewrite_service.as_ref(), - Some(&injection_context), - recent_session.as_ref(), - ) - .await; - tracing::info!( - elapsed_ms = finalize_started.elapsed().as_millis(), - 
output_chars = finalized.text.len(), - operation = match finalized.operation { - postprocess::FinalizedOperation::Append => "append", - postprocess::FinalizedOperation::ReplaceLastEntry { .. } => "replace_last_entry", - }, - rewrite_used = finalized.rewrite_summary.rewrite_used, - "post-processing stage finished" - ); - - if finalized.text.is_empty() { - tracing::warn!("post-processing produced empty text"); - // When the RMS/duration gates skip transcription, the process would - // exit almost immediately after play_stop(). PipeWire may still be - // draining the stop sound's last buffer; exiting while it's "warm" - // causes an audible click as the OS closes our audio file descriptors. - // With speech, transcription takes seconds — providing natural drain time. - postprocess::wait_for_feedback_drain().await; - return Ok(()); - } - - // Inject text - tracing::info!("injecting text: {:?}", finalized.text); - let injector = TextInjector::new(); - match finalized.operation { - postprocess::FinalizedOperation::Append => { - injector.inject(&finalized.text).await?; - if session_enabled { - session::record_append( - &config.session, - &injection_context, - &finalized.text, - finalized.rewrite_summary, - )?; - } - } - postprocess::FinalizedOperation::ReplaceLastEntry { - entry_id, - delete_graphemes, - } => { - injector - .replace_recent_text(delete_graphemes, &finalized.text) - .await?; - if session_enabled { - session::record_replace( - &config.session, - &injection_context, - entry_id, - &finalized.text, - finalized.rewrite_summary, - )?; - } - } - } - - tracing::info!("done"); - tracing::info!( - total_elapsed_ms = activation_started.elapsed().as_millis(), - "dictation pipeline finished" - ); - Ok(()) -} - -#[cfg(feature = "osd")] -fn spawn_osd() -> Option { - // Look for whispers-osd next to our own binary first, then fall back to PATH - let osd_path = std::env::current_exe() - .ok() - .and_then(|p| p.parent().map(|dir| dir.join("whispers-osd"))) - .filter(|p| 
p.exists()) - .unwrap_or_else(|| "whispers-osd".into()); - - match Command::new(&osd_path).spawn() { - Ok(child) => { - tracing::debug!("spawned whispers-osd (pid {})", child.id()); - Some(child) - } - Err(e) => { - tracing::warn!( - "failed to spawn whispers-osd from {}: {e}", - osd_path.display() - ); - None - } - } -} - -#[cfg(not(feature = "osd"))] -fn spawn_osd() -> Option { - None -} - -fn kill_osd(child: &mut Option) { - if let Some(mut c) = child.take() { - let pid = c.id() as libc::pid_t; - unsafe { - libc::kill(pid, libc::SIGTERM); - } - let _ = c.wait(); - tracing::debug!("whispers-osd (pid {pid}) terminated"); - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn kill_osd_none_is_noop() { - let mut child: Option = None; - kill_osd(&mut child); - assert!(child.is_none()); - } -} diff --git a/src/app/mod.rs b/src/app/mod.rs new file mode 100644 index 0000000..8a58ae6 --- /dev/null +++ b/src/app/mod.rs @@ -0,0 +1,67 @@ +use std::time::Instant; + +use crate::config::Config; +use crate::error::Result; +use crate::postprocess::finalize; + +mod osd; +mod runtime; + +use runtime::DictationRuntime; + +pub async fn run(config: Config) -> Result<()> { + let activation_started = Instant::now(); + // Register signals before startup work to minimize early-signal races. + let mut sigusr1 = + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined1())?; + let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?; + let mut runtime = DictationRuntime::new(config); + let recording = runtime.start_recording()?; + runtime.prepare_services()?; + + tokio::select! 
{ + _ = sigusr1.recv() => { + tracing::info!("toggle signal received, stopping recording"); + } + _ = tokio::signal::ctrl_c() => { + tracing::info!("interrupted, cancelling"); + runtime.cancel_recording(recording)?; + return Ok(()); + } + _ = sigterm.recv() => { + tracing::info!("terminated, cancelling"); + runtime.cancel_recording(recording)?; + return Ok(()); + } + } + + let captured = runtime.finish_recording(recording)?; + let transcribed = runtime.transcribe_recording(captured).await?; + + if transcribed.is_empty() { + tracing::warn!("transcription returned empty text"); + finalize::wait_for_feedback_drain().await; + return Ok(()); + } + + let finalized = runtime.finalize_recording(transcribed).await; + if finalized.is_empty() { + tracing::warn!("post-processing produced empty text"); + // When the RMS/duration gates skip transcription, the process would + // exit almost immediately after play_stop(). PipeWire may still be + // draining the stop sound's last buffer; exiting while it's "warm" + // causes an audible click as the OS closes our audio file descriptors. + // With speech, transcription takes seconds — providing natural drain time. 
+ finalize::wait_for_feedback_drain().await; + return Ok(()); + } + + runtime.inject_finalized(finalized).await?; + + tracing::info!("done"); + tracing::info!( + total_elapsed_ms = activation_started.elapsed().as_millis(), + "dictation pipeline finished" + ); + Ok(()) +} diff --git a/src/app/osd.rs b/src/app/osd.rs new file mode 100644 index 0000000..d6a4af2 --- /dev/null +++ b/src/app/osd.rs @@ -0,0 +1,57 @@ +use std::process::Child; + +#[cfg(feature = "osd")] +use std::process::Command; + +#[cfg(feature = "osd")] +pub(super) fn spawn_osd() -> Option { + // Look for whispers-osd next to our own binary first, then fall back to PATH + let osd_path = std::env::current_exe() + .ok() + .and_then(|p| p.parent().map(|dir| dir.join("whispers-osd"))) + .filter(|p| p.exists()) + .unwrap_or_else(|| "whispers-osd".into()); + + match Command::new(&osd_path).spawn() { + Ok(child) => { + tracing::debug!("spawned whispers-osd (pid {})", child.id()); + Some(child) + } + Err(e) => { + tracing::warn!( + "failed to spawn whispers-osd from {}: {e}", + osd_path.display() + ); + None + } + } +} + +#[cfg(not(feature = "osd"))] +pub(super) fn spawn_osd() -> Option { + None +} + +pub(super) fn kill_osd(child: &mut Option) { + if let Some(mut c) = child.take() { + let pid = c.id() as libc::pid_t; + unsafe { + libc::kill(pid, libc::SIGTERM); + } + let _ = c.wait(); + tracing::debug!("whispers-osd (pid {pid}) terminated"); + } +} + +#[cfg(test)] +mod tests { + use super::kill_osd; + use std::process::Child; + + #[test] + fn kill_osd_none_is_noop() { + let mut child: Option = None; + kill_osd(&mut child); + assert!(child.is_none()); + } +} diff --git a/src/app/runtime.rs b/src/app/runtime.rs new file mode 100644 index 0000000..ba138b5 --- /dev/null +++ b/src/app/runtime.rs @@ -0,0 +1,263 @@ +use std::process::Child; +use std::time::Instant; + +use crate::asr; +use crate::audio::AudioRecorder; +use crate::config::Config; +use crate::context::{self, TypingContext}; +use crate::error::Result; 
+use crate::feedback::FeedbackPlayer; +use crate::inject::TextInjector; +use crate::postprocess::{execution, finalize}; +use crate::rewrite_worker::RewriteService; +use crate::session::{self, EligibleSessionEntry}; +use crate::transcribe::Transcript; + +pub(super) struct DictationRuntime { + config: Config, + feedback: FeedbackPlayer, + session_enabled: bool, + transcriber: Option, + rewrite_service: Option, +} + +pub(super) struct ActiveRecording { + recorder: AudioRecorder, + osd: Option, + recent_session: Option, +} + +pub(super) struct CapturedRecording { + audio: Vec, + sample_rate: u32, + recent_session: Option, +} + +pub(super) struct TranscribedRecording { + transcript: Transcript, + recent_session: Option, +} + +pub(super) struct ReadyInjection { + finalized: finalize::FinalizedTranscript, + injection_context: TypingContext, +} + +impl DictationRuntime { + pub(super) fn new(config: Config) -> Self { + let feedback = FeedbackPlayer::new( + config.feedback.enabled, + &config.feedback.start_sound, + &config.feedback.stop_sound, + ); + let session_enabled = config.postprocess.mode.uses_rewrite(); + + Self { + config, + feedback, + session_enabled, + transcriber: None, + rewrite_service: None, + } + } + + pub(super) fn start_recording(&self) -> Result { + // Play start sound first (blocking), then start recording so the sound + // doesn't leak into the mic. + self.feedback.play_start(); + let recording_context = context::capture_typing_context(); + let recent_session = if self.session_enabled { + session::load_recent_entry(&self.config.session, &recording_context)? + } else { + None + }; + + let mut recorder = AudioRecorder::new(&self.config.audio); + recorder.start()?; + let osd = super::osd::spawn_osd(); + tracing::info!("recording... 
(run whispers again to stop)"); + + Ok(ActiveRecording { + recorder, + osd, + recent_session, + }) + } + + pub(super) fn prepare_services(&mut self) -> Result<()> { + let transcriber = asr::prepare::prepare_transcriber(&self.config)?; + let rewrite_service = execution::prepare_rewrite_service(&self.config); + asr::prepare::prewarm_transcriber(&transcriber, "recording"); + if let Some(service) = rewrite_service.as_ref() { + execution::prewarm_rewrite_service(service, "recording"); + } + + self.transcriber = Some(transcriber); + self.rewrite_service = rewrite_service; + Ok(()) + } + + pub(super) fn cancel_recording(&self, mut recording: ActiveRecording) -> Result<()> { + super::osd::kill_osd(&mut recording.osd); + recording.recorder.stop()?; + Ok(()) + } + + pub(super) fn finish_recording( + &self, + mut recording: ActiveRecording, + ) -> Result { + // Stop recording before playing feedback so the stop sound doesn't + // leak into the mic. + super::osd::kill_osd(&mut recording.osd); + let audio = recording.recorder.stop()?; + self.feedback.play_stop(); + let sample_rate = self.config.audio.sample_rate; + let audio_duration_ms = ((audio.len() as f64 / sample_rate as f64) * 1000.0).round() as u64; + + tracing::info!( + samples = audio.len(), + sample_rate, + audio_duration_ms, + "transcribing captured audio" + ); + + Ok(CapturedRecording { + audio, + sample_rate, + recent_session: recording.recent_session, + }) + } + + pub(super) async fn transcribe_recording( + &mut self, + recording: CapturedRecording, + ) -> Result { + let transcriber = self + .transcriber + .take() + .expect("transcriber prepared before transcription"); + let transcribe_started = Instant::now(); + let transcript = asr::execute::transcribe_audio( + &self.config, + transcriber, + recording.audio, + recording.sample_rate, + ) + .await?; + + tracing::info!( + elapsed_ms = transcribe_started.elapsed().as_millis(), + transcript_chars = transcript.raw_text.len(), + "transcription stage finished" + ); + + 
Ok(TranscribedRecording { + transcript, + recent_session: recording.recent_session, + }) + } + + pub(super) async fn finalize_recording( + &self, + recording: TranscribedRecording, + ) -> ReadyInjection { + let injection_context = context::capture_typing_context(); + let recent_session = recording.recent_session.filter(|entry| { + let same_focus = entry.entry.focus_fingerprint == injection_context.focus_fingerprint; + if !same_focus { + tracing::debug!( + previous_focus = entry.entry.focus_fingerprint, + current_focus = injection_context.focus_fingerprint, + "session backtrack blocked because focus changed before injection" + ); + } + same_focus + }); + + let finalize_started = Instant::now(); + let finalized = finalize::finalize_transcript( + &self.config, + recording.transcript, + self.rewrite_service.as_ref(), + Some(&injection_context), + recent_session.as_ref(), + ) + .await; + + tracing::info!( + elapsed_ms = finalize_started.elapsed().as_millis(), + output_chars = finalized.text.len(), + operation = match finalized.operation { + finalize::FinalizedOperation::Append => "append", + finalize::FinalizedOperation::ReplaceLastEntry { .. 
} => "replace_last_entry", + }, + rewrite_used = finalized.rewrite_summary.rewrite_used, + "post-processing stage finished" + ); + + ReadyInjection { + finalized, + injection_context, + } + } + + pub(super) async fn inject_finalized(&self, ready: ReadyInjection) -> Result<()> { + let ReadyInjection { + finalized, + injection_context, + } = ready; + let finalize::FinalizedTranscript { + text, + operation, + rewrite_summary, + } = finalized; + + tracing::info!("injecting text: {:?}", text); + let injector = TextInjector::new(); + match operation { + finalize::FinalizedOperation::Append => { + injector.inject(&text).await?; + if self.session_enabled { + session::record_append( + &self.config.session, + &injection_context, + &text, + rewrite_summary, + )?; + } + } + finalize::FinalizedOperation::ReplaceLastEntry { + entry_id, + delete_graphemes, + } => { + injector + .replace_recent_text(delete_graphemes, &text) + .await?; + if self.session_enabled { + session::record_replace( + &self.config.session, + &injection_context, + entry_id, + &text, + rewrite_summary, + )?; + } + } + } + + Ok(()) + } +} + +impl TranscribedRecording { + pub(super) fn is_empty(&self) -> bool { + self.transcript.is_empty() + } +} + +impl ReadyInjection { + pub(super) fn is_empty(&self) -> bool { + self.finalized.text.is_empty() + } +} diff --git a/src/asr.rs b/src/asr.rs deleted file mode 100644 index 0014bf7..0000000 --- a/src/asr.rs +++ /dev/null @@ -1,417 +0,0 @@ -use crate::cloud::CloudService; -use crate::config::{Config, TranscriptionBackend, TranscriptionConfig, TranscriptionFallback}; -use crate::error::{Result, WhsprError}; -use crate::faster_whisper::{self, FasterWhisperService}; -use crate::model; -use crate::nemo_asr::{self, NemoAsrService}; -use crate::transcribe::{ - Transcript, TranscriptionBackend as SyncTranscriptionBackend, WhisperLocal, -}; -use std::collections::HashSet; -use std::path::{Path, PathBuf}; - -pub enum PreparedTranscriber { - Whisper(tokio::task::JoinHandle>), - 
Faster(FasterWhisperService), - Nemo(NemoAsrService), - Cloud(CloudService), -} - -pub fn prepare_transcriber(config: &Config) -> Result { - cleanup_stale_transcribers(config)?; - - match config.transcription.backend { - TranscriptionBackend::WhisperCpp => { - let whisper_config = config.transcription.clone(); - let model_path = config.resolved_model_path(); - Ok(PreparedTranscriber::Whisper(tokio::task::spawn_blocking( - move || WhisperLocal::new(&whisper_config, &model_path), - ))) - } - TranscriptionBackend::FasterWhisper => { - faster_whisper::prepare_service(&config.transcription) - .map(PreparedTranscriber::Faster) - .ok_or_else(|| { - WhsprError::Transcription( - "faster-whisper backend selected but no model path could be resolved" - .into(), - ) - }) - } - TranscriptionBackend::Nemo => nemo_asr::prepare_service(&config.transcription) - .map(PreparedTranscriber::Nemo) - .ok_or_else(|| { - WhsprError::Transcription( - "nemo backend selected but no model reference could be resolved".into(), - ) - }), - TranscriptionBackend::Cloud => Ok(PreparedTranscriber::Cloud(CloudService::new(config)?)), - } -} - -pub fn cleanup_stale_transcribers(config: &Config) -> Result<()> { - let retained = retained_socket_paths(config); - let stale_workers = collect_stale_asr_workers(&retained)?; - for worker in stale_workers { - tracing::info!( - pid = worker.pid, - kind = worker.kind, - socket = %worker.socket_path.display(), - "terminating stale ASR worker" - ); - let result = unsafe { libc::kill(worker.pid, libc::SIGTERM) }; - if result == 0 { - continue; - } - let err = std::io::Error::last_os_error(); - if err.raw_os_error() == Some(libc::ESRCH) { - continue; - } - return Err(WhsprError::Transcription(format!( - "failed to terminate stale {} worker (pid {}): {err}", - worker.kind, worker.pid - ))); - } - Ok(()) -} - -pub fn prewarm_transcriber(prepared: &PreparedTranscriber, phase: &str) { - match prepared { - PreparedTranscriber::Faster(service) => match service.prewarm() { - 
Ok(()) => tracing::info!("prewarming faster-whisper worker via {}", phase), - Err(err) => tracing::warn!("failed to prewarm faster-whisper worker: {err}"), - }, - PreparedTranscriber::Nemo(service) => match service.prewarm() { - Ok(()) => tracing::info!("prewarming NeMo ASR worker via {}", phase), - Err(err) => tracing::warn!("failed to prewarm NeMo ASR worker: {err}"), - }, - _ => {} - } -} - -pub async fn transcribe_audio( - config: &Config, - prepared: PreparedTranscriber, - audio: Vec, - sample_rate: u32, -) -> Result { - match prepared { - PreparedTranscriber::Whisper(handle) => { - let backend = handle.await.map_err(|e| { - WhsprError::Transcription(format!("model loading task failed: {e}")) - })??; - tokio::task::spawn_blocking(move || backend.transcribe(&audio, sample_rate)) - .await - .map_err(|e| WhsprError::Transcription(format!("transcription task failed: {e}")))? - } - PreparedTranscriber::Faster(service) => match service.transcribe(&audio, sample_rate).await - { - Ok(transcript) => Ok(transcript), - Err(err) => { - tracing::warn!("faster-whisper transcription failed: {err}"); - fallback_whisper_cpp_transcribe(config, audio, sample_rate).await - } - }, - PreparedTranscriber::Nemo(service) => match service.transcribe(&audio, sample_rate).await { - Ok(transcript) => Ok(transcript), - Err(err) => { - tracing::warn!("NeMo ASR transcription failed: {err}"); - fallback_whisper_cpp_transcribe(config, audio, sample_rate).await - } - }, - PreparedTranscriber::Cloud(service) => { - match service.transcribe_audio(config, &audio, sample_rate).await { - Ok(transcript) => Ok(transcript), - Err(err) => { - tracing::warn!("cloud transcription failed: {err}"); - fallback_local_transcribe(config, audio, sample_rate).await - } - } - } - } -} - -async fn fallback_local_transcribe( - config: &Config, - audio: Vec, - sample_rate: u32, -) -> Result { - if config.transcription.backend == TranscriptionBackend::Cloud - && config.transcription.fallback == 
TranscriptionFallback::None - { - return Err(WhsprError::Transcription( - "cloud transcription failed and [transcription].fallback = \"none\"".into(), - )); - } - - let mut local_config = config.transcription.clone(); - local_config.backend = config.transcription.resolved_local_backend(); - let model_path = config.resolved_model_path(); - tracing::warn!( - "falling back to local ASR backend '{}' using {}", - local_config.backend.as_str(), - model_path.display() - ); - let prepared = match local_config.backend { - TranscriptionBackend::WhisperCpp => { - let whisper_config = local_config.clone(); - Ok(PreparedTranscriber::Whisper(tokio::task::spawn_blocking( - move || WhisperLocal::new(&whisper_config, &model_path), - ))) - } - TranscriptionBackend::FasterWhisper => faster_whisper::prepare_service(&local_config) - .map(PreparedTranscriber::Faster) - .ok_or_else(|| { - WhsprError::Transcription( - "faster-whisper fallback selected but no model path could be resolved".into(), - ) - }), - TranscriptionBackend::Nemo => nemo_asr::prepare_service(&local_config) - .map(PreparedTranscriber::Nemo) - .ok_or_else(|| { - WhsprError::Transcription( - "nemo fallback selected but no model reference could be resolved".into(), - ) - }), - TranscriptionBackend::Cloud => Err(WhsprError::Transcription( - "cloud backend cannot be prepared as a local transcriber".into(), - )), - }?; - match prepared { - PreparedTranscriber::Whisper(handle) => { - let backend = handle.await.map_err(|e| { - WhsprError::Transcription(format!("fallback model loading task failed: {e}")) - })??; - tokio::task::spawn_blocking(move || backend.transcribe(&audio, sample_rate)) - .await - .map_err(|e| { - WhsprError::Transcription(format!("fallback transcription task failed: {e}")) - })? 
- } - PreparedTranscriber::Faster(service) => service.transcribe(&audio, sample_rate).await, - PreparedTranscriber::Nemo(service) => service.transcribe(&audio, sample_rate).await, - PreparedTranscriber::Cloud(_) => Err(WhsprError::Transcription( - "cloud fallback resolved to cloud backend".into(), - )), - } -} - -async fn fallback_whisper_cpp_transcribe( - config: &Config, - audio: Vec, - sample_rate: u32, -) -> Result { - let Some(model_path) = fallback_whisper_model_path() else { - return Err(WhsprError::Transcription( - "faster-whisper failed and no local large-v3-turbo fallback model is available".into(), - )); - }; - tracing::warn!("falling back to whisper_cpp using {}", model_path.display()); - let whisper_config = whisper_fallback_config(&config.transcription); - let backend = - tokio::task::spawn_blocking(move || WhisperLocal::new(&whisper_config, &model_path)) - .await - .map_err(|e| { - WhsprError::Transcription(format!("fallback model loading task failed: {e}")) - })??; - tokio::task::spawn_blocking(move || backend.transcribe(&audio, sample_rate)) - .await - .map_err(|e| { - WhsprError::Transcription(format!("fallback transcription task failed: {e}")) - })? 
-} - -fn whisper_fallback_config(config: &TranscriptionConfig) -> TranscriptionConfig { - let mut fallback = config.clone(); - fallback.backend = TranscriptionBackend::WhisperCpp; - fallback.local_backend = TranscriptionBackend::WhisperCpp; - fallback.selected_model = "large-v3-turbo".into(); - fallback.model_path = model::model_path_for_config("ggml-large-v3-turbo.bin"); - fallback -} - -fn fallback_whisper_model_path() -> Option { - let path = model::selected_model_local_path("large-v3-turbo")?; - path.exists().then_some(path) -} - -fn retained_socket_paths(config: &Config) -> HashSet { - let mut retained = HashSet::new(); - match config.transcription.backend { - TranscriptionBackend::FasterWhisper => { - if let Some(service) = faster_whisper::prepare_service(&config.transcription) { - retained.insert(service.socket_path().to_path_buf()); - } - } - TranscriptionBackend::Nemo => { - if let Some(service) = nemo_asr::prepare_service(&config.transcription) { - retained.insert(service.socket_path().to_path_buf()); - } - } - TranscriptionBackend::WhisperCpp | TranscriptionBackend::Cloud => {} - } - retained -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct AsrWorkerProcess { - pid: libc::pid_t, - kind: &'static str, - socket_path: PathBuf, -} - -fn collect_stale_asr_workers(retained: &HashSet) -> Result> { - let proc_dir = std::fs::read_dir("/proc") - .map_err(|e| WhsprError::Transcription(format!("failed to inspect /proc: {e}")))?; - let mut stale = Vec::new(); - for entry in proc_dir { - let entry = match entry { - Ok(entry) => entry, - Err(_) => continue, - }; - let file_name = entry.file_name(); - let Some(pid) = file_name.to_string_lossy().parse::().ok() else { - continue; - }; - let cmdline = match std::fs::read(entry.path().join("cmdline")) { - Ok(cmdline) => cmdline, - Err(_) => continue, - }; - let Some((kind, socket_path)) = parse_asr_worker_cmdline(&cmdline) else { - continue; - }; - if retained.contains(&socket_path) { - continue; - } - 
stale.push(AsrWorkerProcess { - pid, - kind, - socket_path, - }); - } - Ok(stale) -} - -fn parse_asr_worker_cmdline(cmdline: &[u8]) -> Option<(&'static str, PathBuf)> { - let args: Vec = cmdline - .split(|byte| *byte == 0) - .filter(|arg| !arg.is_empty()) - .map(|arg| String::from_utf8_lossy(arg).into_owned()) - .collect(); - if args.is_empty() || !args.iter().any(|arg| arg == "serve") { - return None; - } - - let kind = if args.iter().any(|arg| { - Path::new(arg) - .file_name() - .is_some_and(|name| name == "faster_whisper_worker.py") - }) { - "faster_whisper" - } else if args.iter().any(|arg| { - Path::new(arg) - .file_name() - .is_some_and(|name| name == "nemo_asr_worker.py") - }) { - "nemo" - } else { - return None; - }; - - let socket_index = args.iter().position(|arg| arg == "--socket-path")?; - let socket_path = PathBuf::from(args.get(socket_index + 1)?); - let runtime_scope = asr_runtime_scope_dir(); - if !socket_path.starts_with(&runtime_scope) { - return None; - } - let file_name = socket_path.file_name()?.to_string_lossy(); - if !file_name.starts_with("asr-") || !file_name.ends_with(".sock") { - return None; - } - - Some((kind, socket_path)) -} - -fn asr_runtime_scope_dir() -> PathBuf { - let base = std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/tmp".into()); - PathBuf::from(base).join("whispers") -} - -pub fn validate_transcription_config(config: &Config) -> Result<()> { - if config.transcription.backend == TranscriptionBackend::Cloud { - crate::cloud::validate_config(config)?; - } - - if config.transcription.resolved_local_backend() == TranscriptionBackend::FasterWhisper - && !config.transcription.language.eq_ignore_ascii_case("en") - && !config.transcription.language.eq_ignore_ascii_case("auto") - { - return Err(WhsprError::Config( - "faster-whisper managed models are currently English-focused; set [transcription].language = \"en\" or \"auto\"".into(), - )); - } - - if config.transcription.resolved_local_backend() == 
TranscriptionBackend::FasterWhisper - && config.transcription.language.eq_ignore_ascii_case("auto") - { - tracing::warn!( - "faster-whisper backend is configured with language = \"auto\"; English dictation is recommended" - ); - } - - if config.transcription.resolved_local_backend() == TranscriptionBackend::Nemo - && !config.transcription.language.eq_ignore_ascii_case("en") - && !config.transcription.language.eq_ignore_ascii_case("auto") - { - return Err(WhsprError::Config( - "NeMo experimental ASR models are currently English-only; set [transcription].language = \"en\" or \"auto\"".into(), - )); - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_faster_worker_cmdline_extracts_socket_path() { - let socket = asr_runtime_scope_dir().join("asr-faster-123.sock"); - let cmdline = format!( - "/home/user/.local/share/whispers/faster-whisper/venv/bin/python\0/home/user/.local/share/whispers/faster-whisper/faster_whisper_worker.py\0serve\0--socket-path\0{}\0--model-dir\0/tmp/model\0", - socket.display() - ); - let parsed = parse_asr_worker_cmdline(cmdline.as_bytes()).expect("parse worker"); - assert_eq!(parsed.0, "faster_whisper"); - assert_eq!(parsed.1, socket); - } - - #[test] - fn parse_nemo_worker_cmdline_extracts_socket_path() { - let socket = asr_runtime_scope_dir().join("asr-nemo-456.sock"); - let cmdline = format!( - "/home/user/.local/share/whispers/nemo/venv-asr/bin/python\0/home/user/.local/share/whispers/nemo/nemo_asr_worker.py\0serve\0--socket-path\0{}\0--model-ref\0/tmp/model.nemo\0", - socket.display() - ); - let parsed = parse_asr_worker_cmdline(cmdline.as_bytes()).expect("parse worker"); - assert_eq!(parsed.0, "nemo"); - assert_eq!(parsed.1, socket); - } - - #[test] - fn parse_asr_worker_cmdline_ignores_unrelated_processes() { - let socket = asr_runtime_scope_dir().join("asr-other.sock"); - let cmdline = format!( - "/usr/bin/python\0/home/user/script.py\0serve\0--socket-path\0{}\0", - socket.display() - ); - 
assert!(parse_asr_worker_cmdline(cmdline.as_bytes()).is_none()); - } - - #[test] - fn parse_asr_worker_cmdline_ignores_socket_outside_runtime_scope() { - let cmdline = b"/home/user/.local/share/whispers/nemo/venv-asr/bin/python\0/home/user/.local/share/whispers/nemo/nemo_asr_worker.py\0serve\0--socket-path\0/var/run/asr-nemo.sock\0"; - assert!(parse_asr_worker_cmdline(cmdline).is_none()); - } -} diff --git a/src/asr/cleanup.rs b/src/asr/cleanup.rs new file mode 100644 index 0000000..92f6948 --- /dev/null +++ b/src/asr/cleanup.rs @@ -0,0 +1,177 @@ +use crate::config::{Config, TranscriptionBackend}; +use crate::error::{Result, WhsprError}; +use std::collections::HashSet; +use std::path::{Path, PathBuf}; + +pub fn cleanup_stale_transcribers(config: &Config) -> Result<()> { + let retained = retained_socket_paths(config); + let stale_workers = collect_stale_asr_workers(&retained)?; + for worker in stale_workers { + tracing::info!( + pid = worker.pid, + kind = worker.kind, + socket = %worker.socket_path.display(), + "terminating stale ASR worker" + ); + let result = unsafe { libc::kill(worker.pid, libc::SIGTERM) }; + if result == 0 { + continue; + } + let err = std::io::Error::last_os_error(); + if err.raw_os_error() == Some(libc::ESRCH) { + continue; + } + return Err(WhsprError::Transcription(format!( + "failed to terminate stale {} worker (pid {}): {err}", + worker.kind, worker.pid + ))); + } + Ok(()) +} + +fn retained_socket_paths(config: &Config) -> HashSet { + let mut retained = HashSet::new(); + match config.transcription.backend { + TranscriptionBackend::FasterWhisper => { + if let Some(service) = crate::faster_whisper::prepare_service(&config.transcription) { + retained.insert(service.socket_path().to_path_buf()); + } + } + TranscriptionBackend::Nemo => { + if let Some(service) = crate::nemo_asr::prepare_service(&config.transcription) { + retained.insert(service.socket_path().to_path_buf()); + } + } + TranscriptionBackend::WhisperCpp | TranscriptionBackend::Cloud 
=> {} + } + retained +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct AsrWorkerProcess { + pid: libc::pid_t, + kind: &'static str, + socket_path: PathBuf, +} + +fn collect_stale_asr_workers(retained: &HashSet) -> Result> { + let proc_dir = std::fs::read_dir("/proc") + .map_err(|e| WhsprError::Transcription(format!("failed to inspect /proc: {e}")))?; + let mut stale = Vec::new(); + for entry in proc_dir { + let entry = match entry { + Ok(entry) => entry, + Err(_) => continue, + }; + let file_name = entry.file_name(); + let Some(pid) = file_name.to_string_lossy().parse::().ok() else { + continue; + }; + let cmdline = match std::fs::read(entry.path().join("cmdline")) { + Ok(cmdline) => cmdline, + Err(_) => continue, + }; + let Some((kind, socket_path)) = parse_asr_worker_cmdline(&cmdline) else { + continue; + }; + if retained.contains(&socket_path) { + continue; + } + stale.push(AsrWorkerProcess { + pid, + kind, + socket_path, + }); + } + Ok(stale) +} + +fn parse_asr_worker_cmdline(cmdline: &[u8]) -> Option<(&'static str, PathBuf)> { + let args: Vec = cmdline + .split(|byte| *byte == 0) + .filter(|arg| !arg.is_empty()) + .map(|arg| String::from_utf8_lossy(arg).into_owned()) + .collect(); + if args.is_empty() || !args.iter().any(|arg| arg == "serve") { + return None; + } + + let kind = if args.iter().any(|arg| { + Path::new(arg) + .file_name() + .is_some_and(|name| name == "faster_whisper_worker.py") + }) { + "faster_whisper" + } else if args.iter().any(|arg| { + Path::new(arg) + .file_name() + .is_some_and(|name| name == "nemo_asr_worker.py") + }) { + "nemo" + } else { + return None; + }; + + let socket_index = args.iter().position(|arg| arg == "--socket-path")?; + let socket_path = PathBuf::from(args.get(socket_index + 1)?); + let runtime_scope = asr_runtime_scope_dir(); + if !socket_path.starts_with(&runtime_scope) { + return None; + } + let file_name = socket_path.file_name()?.to_string_lossy(); + if !file_name.starts_with("asr-") || 
!file_name.ends_with(".sock") { + return None; + } + + Some((kind, socket_path)) +} + +fn asr_runtime_scope_dir() -> PathBuf { + let base = std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/tmp".into()); + PathBuf::from(base).join("whispers") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_faster_worker_cmdline_extracts_socket_path() { + let socket = asr_runtime_scope_dir().join("asr-faster-123.sock"); + let cmdline = format!( + "/home/user/.local/share/whispers/faster-whisper/venv/bin/python\0/home/user/.local/share/whispers/faster-whisper/faster_whisper_worker.py\0serve\0--socket-path\0{}\0--model-dir\0/tmp/model\0", + socket.display() + ); + let parsed = parse_asr_worker_cmdline(cmdline.as_bytes()).expect("parse worker"); + assert_eq!(parsed.0, "faster_whisper"); + assert_eq!(parsed.1, socket); + } + + #[test] + fn parse_nemo_worker_cmdline_extracts_socket_path() { + let socket = asr_runtime_scope_dir().join("asr-nemo-456.sock"); + let cmdline = format!( + "/home/user/.local/share/whispers/nemo/venv-asr/bin/python\0/home/user/.local/share/whispers/nemo/nemo_asr_worker.py\0serve\0--socket-path\0{}\0--model-ref\0/tmp/model.nemo\0", + socket.display() + ); + let parsed = parse_asr_worker_cmdline(cmdline.as_bytes()).expect("parse worker"); + assert_eq!(parsed.0, "nemo"); + assert_eq!(parsed.1, socket); + } + + #[test] + fn parse_asr_worker_cmdline_ignores_unrelated_processes() { + let socket = asr_runtime_scope_dir().join("asr-other.sock"); + let cmdline = format!( + "/usr/bin/python\0/home/user/script.py\0serve\0--socket-path\0{}\0", + socket.display() + ); + assert!(parse_asr_worker_cmdline(cmdline.as_bytes()).is_none()); + } + + #[test] + fn parse_asr_worker_cmdline_ignores_socket_outside_runtime_scope() { + let cmdline = b"/home/user/.local/share/whispers/nemo/venv-asr/bin/python\0/home/user/.local/share/whispers/nemo/nemo_asr_worker.py\0serve\0--socket-path\0/var/run/asr-nemo.sock\0"; + 
assert!(parse_asr_worker_cmdline(cmdline).is_none()); + } +} diff --git a/src/asr/execute.rs b/src/asr/execute.rs new file mode 100644 index 0000000..2069deb --- /dev/null +++ b/src/asr/execute.rs @@ -0,0 +1,188 @@ +use super::prepare::{self, PreparedTranscriber}; +use crate::config::{Config, TranscriptionBackend, TranscriptionConfig, TranscriptionFallback}; +use crate::error::{Result, WhsprError}; +use crate::model; +use crate::transcribe::{Transcript, TranscriptionBackend as _}; + +pub async fn transcribe_audio( + config: &Config, + prepared: PreparedTranscriber, + audio: Vec, + sample_rate: u32, +) -> Result { + match prepared { + prepared @ PreparedTranscriber::Whisper(_) => { + transcribe_with_prepared(prepared, &audio, sample_rate, "").await + } + prepared @ PreparedTranscriber::Faster(_) => { + match transcribe_with_prepared(prepared, &audio, sample_rate, "").await { + Ok(transcript) => Ok(transcript), + Err(err) => { + tracing::warn!("faster-whisper transcription failed: {err}"); + fallback_whisper_cpp_transcribe(config, audio, sample_rate).await + } + } + } + prepared @ PreparedTranscriber::Nemo(_) => { + match transcribe_with_prepared(prepared, &audio, sample_rate, "").await { + Ok(transcript) => Ok(transcript), + Err(err) => { + tracing::warn!("NeMo ASR transcription failed: {err}"); + fallback_whisper_cpp_transcribe(config, audio, sample_rate).await + } + } + } + PreparedTranscriber::Cloud(service) => { + match service.transcribe_audio(config, &audio, sample_rate).await { + Ok(transcript) => Ok(transcript), + Err(err) => { + tracing::warn!("cloud transcription failed: {err}"); + fallback_local_transcribe(config, audio, sample_rate).await + } + } + } + } +} + +async fn transcribe_with_prepared( + prepared: PreparedTranscriber, + audio: &[f32], + sample_rate: u32, + task_label: &str, +) -> Result { + match prepared { + PreparedTranscriber::Whisper(handle) => { + let audio = audio.to_vec(); + let backend = handle + .await + .map_err(|e| 
transcription_task_error(task_label, "model loading", &e))??; + tokio::task::spawn_blocking(move || backend.transcribe(&audio, sample_rate)) + .await + .map_err(|e| transcription_task_error(task_label, "transcription", &e))? + } + PreparedTranscriber::Faster(service) => service.transcribe(audio, sample_rate).await, + PreparedTranscriber::Nemo(service) => service.transcribe(audio, sample_rate).await, + PreparedTranscriber::Cloud(_) => Err(WhsprError::Transcription( + "cloud transcriber cannot be executed without the caller-owned config".into(), + )), + } +} + +async fn fallback_local_transcribe( + config: &Config, + audio: Vec, + sample_rate: u32, +) -> Result { + let (local_config, model_path) = local_fallback_config(config)?; + tracing::warn!( + "falling back to local ASR backend '{}' using {}", + local_config.backend.as_str(), + model_path.display() + ); + let prepared = prepare::prepare_local_transcriber(&local_config, &model_path)?; + transcribe_with_prepared(prepared, &audio, sample_rate, "fallback").await +} + +async fn fallback_whisper_cpp_transcribe( + config: &Config, + audio: Vec, + sample_rate: u32, +) -> Result { + let Some(model_path) = fallback_whisper_model_path() else { + return Err(WhsprError::Transcription( + "faster-whisper failed and no local large-v3-turbo fallback model is available".into(), + )); + }; + tracing::warn!("falling back to whisper_cpp using {}", model_path.display()); + let whisper_config = whisper_fallback_config(&config.transcription); + let prepared = prepare::prepare_local_transcriber(&whisper_config, &model_path)?; + transcribe_with_prepared(prepared, &audio, sample_rate, "fallback").await +} + +fn local_fallback_config(config: &Config) -> Result<(TranscriptionConfig, std::path::PathBuf)> { + if config.transcription.backend == TranscriptionBackend::Cloud + && config.transcription.fallback == TranscriptionFallback::None + { + return Err(WhsprError::Transcription( + "cloud transcription failed and [transcription].fallback = 
\"none\"".into(), + )); + } + + let mut local_config = config.transcription.clone(); + local_config.backend = config.transcription.resolved_local_backend(); + Ok((local_config, config.resolved_model_path())) +} + +fn whisper_fallback_config(config: &TranscriptionConfig) -> TranscriptionConfig { + let mut fallback = config.clone(); + fallback.backend = TranscriptionBackend::WhisperCpp; + fallback.local_backend = TranscriptionBackend::WhisperCpp; + fallback.selected_model = "large-v3-turbo".into(); + fallback.model_path = model::model_path_for_config("ggml-large-v3-turbo.bin"); + fallback +} + +fn fallback_whisper_model_path() -> Option { + let path = model::selected_model_local_path("large-v3-turbo")?; + path.exists().then_some(path) +} + +fn transcription_task_error( + task_label: &str, + phase: &str, + error: &tokio::task::JoinError, +) -> WhsprError { + let prefix = if task_label.is_empty() { + String::new() + } else { + format!("{task_label} ") + }; + WhsprError::Transcription(format!("{prefix}{phase} task failed: {error}")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn cloud_fallback_none_returns_existing_error() { + let mut config = Config::default(); + config.transcription.backend = TranscriptionBackend::Cloud; + config.transcription.fallback = TranscriptionFallback::None; + + let err = fallback_local_transcribe(&config, vec![0.0; 16], 16_000) + .await + .expect_err("fallback should fail"); + match err { + WhsprError::Transcription(message) => { + assert_eq!( + message, + "cloud transcription failed and [transcription].fallback = \"none\"" + ); + } + other => panic!("unexpected error: {other:?}"), + } + } + + #[test] + fn local_fallback_config_resolves_configured_local_backend() { + let mut config = Config::default(); + config.transcription.backend = TranscriptionBackend::Cloud; + config.transcription.local_backend = TranscriptionBackend::Nemo; + + let (local_config, _) = local_fallback_config(&config).expect("fallback 
config"); + assert_eq!(local_config.backend, TranscriptionBackend::Nemo); + } + + #[test] + fn whisper_fallback_config_pins_whisper_cpp_large_v3_turbo() { + let fallback = whisper_fallback_config(&TranscriptionConfig::default()); + assert_eq!(fallback.backend, TranscriptionBackend::WhisperCpp); + assert_eq!(fallback.local_backend, TranscriptionBackend::WhisperCpp); + assert_eq!(fallback.selected_model, "large-v3-turbo"); + assert_eq!( + fallback.model_path, + model::model_path_for_config("ggml-large-v3-turbo.bin") + ); + } +} diff --git a/src/asr/mod.rs b/src/asr/mod.rs new file mode 100644 index 0000000..fbd90e2 --- /dev/null +++ b/src/asr/mod.rs @@ -0,0 +1,4 @@ +pub mod cleanup; +pub mod execute; +pub mod prepare; +pub mod validation; diff --git a/src/asr/prepare.rs b/src/asr/prepare.rs new file mode 100644 index 0000000..1e24580 --- /dev/null +++ b/src/asr/prepare.rs @@ -0,0 +1,70 @@ +use crate::cloud::CloudService; +use crate::config::{Config, TranscriptionBackend, TranscriptionConfig}; +use crate::error::{Result, WhsprError}; +use crate::faster_whisper::{self, FasterWhisperService}; +use crate::nemo_asr::{self, NemoAsrService}; +use crate::transcribe::WhisperLocal; +use std::path::Path; + +pub enum PreparedTranscriber { + Whisper(tokio::task::JoinHandle>), + Faster(FasterWhisperService), + Nemo(NemoAsrService), + Cloud(CloudService), +} + +pub fn prepare_transcriber(config: &Config) -> Result { + super::cleanup::cleanup_stale_transcribers(config)?; + + if config.transcription.backend == TranscriptionBackend::Cloud { + return Ok(PreparedTranscriber::Cloud(CloudService::new(config)?)); + } + + prepare_local_transcriber(&config.transcription, &config.resolved_model_path()) +} + +pub(crate) fn prepare_local_transcriber( + transcription: &TranscriptionConfig, + model_path: &Path, +) -> Result { + match transcription.backend { + TranscriptionBackend::WhisperCpp => { + let whisper_config = transcription.clone(); + let model_path = model_path.to_path_buf(); + 
Ok(PreparedTranscriber::Whisper(tokio::task::spawn_blocking( + move || WhisperLocal::new(&whisper_config, &model_path), + ))) + } + TranscriptionBackend::FasterWhisper => faster_whisper::prepare_service(transcription) + .map(PreparedTranscriber::Faster) + .ok_or_else(|| { + WhsprError::Transcription( + "faster-whisper backend selected but no model path could be resolved".into(), + ) + }), + TranscriptionBackend::Nemo => nemo_asr::prepare_service(transcription) + .map(PreparedTranscriber::Nemo) + .ok_or_else(|| { + WhsprError::Transcription( + "nemo backend selected but no model reference could be resolved".into(), + ) + }), + TranscriptionBackend::Cloud => Err(WhsprError::Transcription( + "cloud backend cannot be prepared as a local transcriber".into(), + )), + } +} + +pub fn prewarm_transcriber(prepared: &PreparedTranscriber, phase: &str) { + match prepared { + PreparedTranscriber::Faster(service) => match service.prewarm() { + Ok(()) => tracing::info!("prewarming faster-whisper worker via {}", phase), + Err(err) => tracing::warn!("failed to prewarm faster-whisper worker: {err}"), + }, + PreparedTranscriber::Nemo(service) => match service.prewarm() { + Ok(()) => tracing::info!("prewarming NeMo ASR worker via {}", phase), + Err(err) => tracing::warn!("failed to prewarm NeMo ASR worker: {err}"), + }, + _ => {} + } +} diff --git a/src/asr/validation.rs b/src/asr/validation.rs new file mode 100644 index 0000000..21678de --- /dev/null +++ b/src/asr/validation.rs @@ -0,0 +1,77 @@ +use crate::config::{Config, TranscriptionBackend}; +use crate::error::{Result, WhsprError}; + +pub fn validate_transcription_config(config: &Config) -> Result<()> { + if config.transcription.backend == TranscriptionBackend::Cloud { + crate::cloud::validate_config(config)?; + } + + if config.transcription.resolved_local_backend() == TranscriptionBackend::FasterWhisper + && !config.transcription.language.eq_ignore_ascii_case("en") + && !config.transcription.language.eq_ignore_ascii_case("auto") 
+ { + return Err(WhsprError::Config( + "faster-whisper managed models are currently English-focused; set [transcription].language = \"en\" or \"auto\"".into(), + )); + } + + if config.transcription.resolved_local_backend() == TranscriptionBackend::FasterWhisper + && config.transcription.language.eq_ignore_ascii_case("auto") + { + tracing::warn!( + "faster-whisper backend is configured with language = \"auto\"; English dictation is recommended" + ); + } + + if config.transcription.resolved_local_backend() == TranscriptionBackend::Nemo + && !config.transcription.language.eq_ignore_ascii_case("en") + && !config.transcription.language.eq_ignore_ascii_case("auto") + { + return Err(WhsprError::Config( + "NeMo experimental ASR models are currently English-only; set [transcription].language = \"en\" or \"auto\"".into(), + )); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn faster_whisper_rejects_non_english_explicit_language() { + let mut config = Config::default(); + config.transcription.backend = TranscriptionBackend::FasterWhisper; + config.transcription.local_backend = TranscriptionBackend::FasterWhisper; + config.transcription.language = "sv".into(); + + let err = validate_transcription_config(&config).expect_err("config should fail"); + match err { + WhsprError::Config(message) => { + assert!( + message.contains("faster-whisper managed models are currently English-focused") + ); + } + other => panic!("unexpected error: {other:?}"), + } + } + + #[test] + fn nemo_rejects_non_english_explicit_language() { + let mut config = Config::default(); + config.transcription.backend = TranscriptionBackend::Nemo; + config.transcription.local_backend = TranscriptionBackend::Nemo; + config.transcription.language = "sv".into(); + + let err = validate_transcription_config(&config).expect_err("config should fail"); + match err { + WhsprError::Config(message) => { + assert!( + message.contains("NeMo experimental ASR models are currently English-only") + 
); + } + other => panic!("unexpected error: {other:?}"), + } + } +} diff --git a/src/asr_model.rs b/src/asr_model.rs index 3e0b9b4..d2a0ad5 100644 --- a/src/asr_model.rs +++ b/src/asr_model.rs @@ -1,9 +1,8 @@ use std::path::{Path, PathBuf}; -use crate::config::{ - self, TranscriptionBackend, resolve_config_path, update_config_transcription_selection, -}; +use crate::config::{TranscriptionBackend, update_config_transcription_selection}; use crate::error::{Result, WhsprError}; +use crate::model_support; use crate::{faster_whisper, model, nemo_asr}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -258,11 +257,7 @@ pub fn selected_model_path(name: &str) -> Option { fn active_model_name( config_path_override: Option<&Path>, ) -> Option<(TranscriptionBackend, String)> { - let config_path = resolve_config_path(config_path_override); - if !config_path.exists() { - return None; - } - let config = config::Config::load(Some(&config_path)).ok()?; + let config = model_support::load_config_if_exists(config_path_override)?; Some(( config.transcription.resolved_local_backend(), config.transcription.selected_model, @@ -285,12 +280,7 @@ fn model_status(info: &AsrModelInfo, active: Option<(TranscriptionBackend, &str) }) .unwrap_or(false); - match (is_active, is_local) { - (true, true) => "active", - (true, false) => "active (missing)", - (_, true) => "local", - _ => "remote", - } + model_support::managed_download_status(is_active, is_local) } pub fn list_models(config_path_override: Option<&Path>) { @@ -387,13 +377,10 @@ pub fn select_model(name: &str, config_path_override: Option<&Path>) -> Result<( ))); } - let config_path = resolve_config_path(config_path_override); - if !config_path.exists() { - config::write_default_config( - &config_path, - &model::model_path_for_config("ggml-large-v3-turbo.bin"), - )?; - } + let (config_path, _) = model_support::ensure_default_config( + config_path_override, + &model::model_path_for_config("ggml-large-v3-turbo.bin"), + )?; let 
config_model_path = match info.backend { TranscriptionBackend::WhisperCpp => model::model_path_for_config( @@ -410,7 +397,7 @@ pub fn select_model(name: &str, config_path_override: Option<&Path>) -> Result<( info.backend, info.name, &config_model_path, - config::Config::load(Some(&config_path)) + model_support::load_config_at_if_exists(&config_path) .map(|config| config.transcription.backend != TranscriptionBackend::Cloud) .unwrap_or(true), )?; diff --git a/src/audio/dsp.rs b/src/audio/dsp.rs new file mode 100644 index 0000000..0a720ea --- /dev/null +++ b/src/audio/dsp.rs @@ -0,0 +1,166 @@ +const HIGHPASS_CUTOFF_HZ: f32 = 80.0; +const TRIM_FRAME_MS: usize = 10; +const TRIM_PADDING_MS: usize = 40; +const TRIM_MIN_RMS: f32 = 0.002; +const TRIM_RELATIVE_RMS: f32 = 0.08; +pub(super) const NORMALIZE_TARGET_PEAK: f32 = 0.85; +const NORMALIZE_MAX_GAIN: f32 = 2.5; +const NORMALIZE_MIN_PEAK: f32 = 0.005; + +pub fn preprocess_audio(samples: &mut Vec, sample_rate: u32) { + if samples.is_empty() || sample_rate == 0 { + return; + } + + let before_len = samples.len(); + let before = audio_stats(samples); + + remove_dc_offset(samples); + apply_highpass(samples, sample_rate, HIGHPASS_CUTOFF_HZ); + trim_silence(samples, sample_rate); + let gain = normalize_peak(samples); + + let after = audio_stats(samples); + tracing::debug!( + "audio preprocessing: len {} -> {}, rms {:.4} -> {:.4}, peak {:.4} -> {:.4}, gain {:.2}x", + before_len, + samples.len(), + before.rms, + after.rms, + before.peak, + after.peak, + gain + ); +} + +#[derive(Clone, Copy)] +struct AudioStats { + rms: f32, + peak: f32, +} + +fn audio_stats(samples: &[f32]) -> AudioStats { + if samples.is_empty() { + return AudioStats { + rms: 0.0, + peak: 0.0, + }; + } + + let mut peak = 0.0f32; + let mut energy = 0.0f32; + for sample in samples { + peak = peak.max(sample.abs()); + energy += sample * sample; + } + + AudioStats { + rms: (energy / samples.len() as f32).sqrt(), + peak, + } +} + +pub(super) fn 
remove_dc_offset(samples: &mut [f32]) { + if samples.is_empty() { + return; + } + + let mean = samples.iter().copied().sum::() / samples.len() as f32; + if mean.abs() < 1e-6 { + return; + } + + for sample in samples { + *sample -= mean; + } +} + +fn apply_highpass(samples: &mut [f32], sample_rate: u32, cutoff_hz: f32) { + if samples.len() < 2 || sample_rate == 0 || cutoff_hz <= 0.0 { + return; + } + + let dt = 1.0 / sample_rate as f32; + let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff_hz); + let alpha = rc / (rc + dt); + + let mut previous_input = samples[0]; + let mut previous_output = 0.0f32; + samples[0] = 0.0; + + for sample in samples.iter_mut().skip(1) { + let input = *sample; + let output = alpha * (previous_output + input - previous_input); + *sample = output; + previous_input = input; + previous_output = output; + } +} + +pub(super) fn trim_silence(samples: &mut Vec, sample_rate: u32) { + if samples.is_empty() || sample_rate == 0 { + return; + } + + let frame_len = ((sample_rate as usize * TRIM_FRAME_MS) / 1000).max(1); + if samples.len() <= frame_len * 2 { + return; + } + + let frame_rms: Vec = samples.chunks(frame_len).map(frame_rms).collect(); + let peak_rms = frame_rms.iter().copied().fold(0.0f32, f32::max); + if peak_rms <= 0.0 { + return; + } + + let threshold = (peak_rms * TRIM_RELATIVE_RMS).max(TRIM_MIN_RMS); + let Some(first_active) = frame_rms.iter().position(|rms| *rms >= threshold) else { + return; + }; + let Some(last_active) = frame_rms.iter().rposition(|rms| *rms >= threshold) else { + return; + }; + + let padding_samples = (sample_rate as usize * TRIM_PADDING_MS) / 1000; + let padding_frames = padding_samples.div_ceil(frame_len); + let start_frame = first_active.saturating_sub(padding_frames); + let end_frame = (last_active + 1 + padding_frames).min(frame_rms.len()); + + let start = start_frame.saturating_mul(frame_len); + let end = (end_frame.saturating_mul(frame_len)).min(samples.len()); + if start == 0 && end == samples.len() { + 
return; + } + if start >= end { + return; + } + + *samples = samples[start..end].to_vec(); +} + +fn frame_rms(frame: &[f32]) -> f32 { + if frame.is_empty() { + return 0.0; + } + + let energy = frame.iter().map(|sample| sample * sample).sum::(); + (energy / frame.len() as f32).sqrt() +} + +pub(super) fn normalize_peak(samples: &mut [f32]) -> f32 { + let peak = samples.iter().copied().map(f32::abs).fold(0.0f32, f32::max); + if !(NORMALIZE_MIN_PEAK..NORMALIZE_TARGET_PEAK).contains(&peak) { + return 1.0; + } + + let gain = (NORMALIZE_TARGET_PEAK / peak).min(NORMALIZE_MAX_GAIN); + if gain <= 1.0 { + return 1.0; + } + + for sample in samples { + *sample = (*sample * gain).clamp(-1.0, 1.0); + } + + gain +} diff --git a/src/audio/mod.rs b/src/audio/mod.rs new file mode 100644 index 0000000..bf4207e --- /dev/null +++ b/src/audio/mod.rs @@ -0,0 +1,8 @@ +mod dsp; +mod recorder; + +#[cfg(test)] +mod tests; + +pub use dsp::preprocess_audio; +pub use recorder::AudioRecorder; diff --git a/src/audio.rs b/src/audio/recorder.rs similarity index 55% rename from src/audio.rs rename to src/audio/recorder.rs index fd0f940..8e045bc 100644 --- a/src/audio.rs +++ b/src/audio/recorder.rs @@ -7,14 +7,6 @@ use crate::config::AudioConfig; use crate::error::{Result, WhsprError}; const PREALLOC_SECONDS: usize = 120; -const HIGHPASS_CUTOFF_HZ: f32 = 80.0; -const TRIM_FRAME_MS: usize = 10; -const TRIM_PADDING_MS: usize = 40; -const TRIM_MIN_RMS: f32 = 0.002; -const TRIM_RELATIVE_RMS: f32 = 0.08; -const NORMALIZE_TARGET_PEAK: f32 = 0.85; -const NORMALIZE_MAX_GAIN: f32 = 2.5; -const NORMALIZE_MIN_PEAK: f32 = 0.005; pub struct AudioRecorder { config: AudioConfig, @@ -154,7 +146,7 @@ impl AudioRecorder { pub fn stop(&mut self) -> Result> { // Take and leak the stream — cpal's ALSA backend calls snd_pcm_close() // on drop without draining first, which causes an audible click on - // PipeWire when the stream is still "warm". The OS reclaims file + // PipeWire when the stream is still "warm". 
The OS reclaims file // descriptors on process exit. if let Some(stream) = self.stream.take() { let _ = stream.pause(); @@ -182,169 +174,11 @@ impl AudioRecorder { buffer[start + i] *= gain; } - preprocess_audio(&mut buffer, self.config.sample_rate); + super::dsp::preprocess_audio(&mut buffer, self.config.sample_rate); Ok(buffer) } } -pub fn preprocess_audio(samples: &mut Vec, sample_rate: u32) { - if samples.is_empty() || sample_rate == 0 { - return; - } - - let before_len = samples.len(); - let before = audio_stats(samples); - - remove_dc_offset(samples); - apply_highpass(samples, sample_rate, HIGHPASS_CUTOFF_HZ); - trim_silence(samples, sample_rate); - let gain = normalize_peak(samples); - - let after = audio_stats(samples); - tracing::debug!( - "audio preprocessing: len {} -> {}, rms {:.4} -> {:.4}, peak {:.4} -> {:.4}, gain {:.2}x", - before_len, - samples.len(), - before.rms, - after.rms, - before.peak, - after.peak, - gain - ); -} - -#[derive(Clone, Copy)] -struct AudioStats { - rms: f32, - peak: f32, -} - -fn audio_stats(samples: &[f32]) -> AudioStats { - if samples.is_empty() { - return AudioStats { - rms: 0.0, - peak: 0.0, - }; - } - - let mut peak = 0.0f32; - let mut energy = 0.0f32; - for sample in samples { - peak = peak.max(sample.abs()); - energy += sample * sample; - } - - AudioStats { - rms: (energy / samples.len() as f32).sqrt(), - peak, - } -} - -fn remove_dc_offset(samples: &mut [f32]) { - if samples.is_empty() { - return; - } - - let mean = samples.iter().copied().sum::() / samples.len() as f32; - if mean.abs() < 1e-6 { - return; - } - - for sample in samples { - *sample -= mean; - } -} - -fn apply_highpass(samples: &mut [f32], sample_rate: u32, cutoff_hz: f32) { - if samples.len() < 2 || sample_rate == 0 || cutoff_hz <= 0.0 { - return; - } - - let dt = 1.0 / sample_rate as f32; - let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff_hz); - let alpha = rc / (rc + dt); - - let mut previous_input = samples[0]; - let mut previous_output = 0.0f32; - 
samples[0] = 0.0; - - for sample in samples.iter_mut().skip(1) { - let input = *sample; - let output = alpha * (previous_output + input - previous_input); - *sample = output; - previous_input = input; - previous_output = output; - } -} - -fn trim_silence(samples: &mut Vec, sample_rate: u32) { - if samples.is_empty() || sample_rate == 0 { - return; - } - - let frame_len = ((sample_rate as usize * TRIM_FRAME_MS) / 1000).max(1); - if samples.len() <= frame_len * 2 { - return; - } - - let frame_rms: Vec = samples.chunks(frame_len).map(frame_rms).collect(); - let peak_rms = frame_rms.iter().copied().fold(0.0f32, f32::max); - if peak_rms <= 0.0 { - return; - } - - let threshold = (peak_rms * TRIM_RELATIVE_RMS).max(TRIM_MIN_RMS); - let Some(first_active) = frame_rms.iter().position(|rms| *rms >= threshold) else { - return; - }; - let Some(last_active) = frame_rms.iter().rposition(|rms| *rms >= threshold) else { - return; - }; - - let padding_samples = (sample_rate as usize * TRIM_PADDING_MS) / 1000; - let padding_frames = padding_samples.div_ceil(frame_len); - let start_frame = first_active.saturating_sub(padding_frames); - let end_frame = (last_active + 1 + padding_frames).min(frame_rms.len()); - - let start = start_frame.saturating_mul(frame_len); - let end = (end_frame.saturating_mul(frame_len)).min(samples.len()); - if start == 0 && end == samples.len() { - return; - } - if start >= end { - return; - } - - *samples = samples[start..end].to_vec(); -} - -fn frame_rms(frame: &[f32]) -> f32 { - if frame.is_empty() { - return 0.0; - } - - let energy = frame.iter().map(|sample| sample * sample).sum::(); - (energy / frame.len() as f32).sqrt() -} - -fn normalize_peak(samples: &mut [f32]) -> f32 { - let peak = samples.iter().copied().map(f32::abs).fold(0.0f32, f32::max); - if !(NORMALIZE_MIN_PEAK..NORMALIZE_TARGET_PEAK).contains(&peak) { - return 1.0; - } - - let gain = (NORMALIZE_TARGET_PEAK / peak).min(NORMALIZE_MAX_GAIN); - if gain <= 1.0 { - return 1.0; - } - - for sample 
in samples { - *sample = (*sample * gain).clamp(-1.0, 1.0); - } - - gain -} - fn choose_input_config( device: &cpal::Device, sample_rate: u32, @@ -368,7 +202,6 @@ fn choose_input_config( if format_score == 0 { continue; } - // Prefer mono (20), then fewer channels over more (penalty scales with count) let channel_score: u8 = if cfg.channels() == 1 { 20 } else { @@ -400,7 +233,7 @@ fn choose_input_config( }) } -fn append_mono_f32(data: &[f32], channels: usize, out: &mut Vec) { +pub(super) fn append_mono_f32(data: &[f32], channels: usize, out: &mut Vec) { if channels <= 1 { out.extend_from_slice(data); return; @@ -412,7 +245,7 @@ fn append_mono_f32(data: &[f32], channels: usize, out: &mut Vec) { } } -fn append_mono_i16(data: &[i16], channels: usize, out: &mut Vec) { +pub(super) fn append_mono_i16(data: &[i16], channels: usize, out: &mut Vec) { const I16_SCALE: f32 = 32768.0; if channels <= 1 { out.extend(data.iter().map(|s| *s as f32 / I16_SCALE)); @@ -425,7 +258,7 @@ fn append_mono_i16(data: &[i16], channels: usize, out: &mut Vec) { } } -fn append_mono_u16(data: &[u16], channels: usize, out: &mut Vec) { +pub(super) fn append_mono_u16(data: &[u16], channels: usize, out: &mut Vec) { if channels <= 1 { out.extend( data.iter() @@ -442,88 +275,3 @@ fn append_mono_u16(data: &[u16], channels: usize, out: &mut Vec) { out.push(sum / frame.len() as f32); } } - -#[cfg(test)] -mod tests { - use super::*; - - fn approx_eq(a: f32, b: f32, eps: f32) -> bool { - (a - b).abs() <= eps - } - - #[test] - fn append_mono_f32_passthrough_for_single_channel() { - let mut out = Vec::new(); - append_mono_f32(&[0.1, -0.2, 0.3], 1, &mut out); - assert_eq!(out, vec![0.1, -0.2, 0.3]); - } - - #[test] - fn append_mono_f32_downmixes_stereo() { - let mut out = Vec::new(); - append_mono_f32(&[1.0, -1.0, 0.5, 0.5], 2, &mut out); - assert!(approx_eq(out[0], 0.0, 1e-6)); - assert!(approx_eq(out[1], 0.5, 1e-6)); - } - - #[test] - fn append_mono_i16_converts_to_f32() { - let mut out = Vec::new(); - 
append_mono_i16(&[i16::MAX, i16::MIN], 1, &mut out); - assert!(approx_eq(out[0], 1.0, 1e-4)); - assert!(out[1] < -0.99); - } - - #[test] - fn append_mono_u16_downmixes_and_converts() { - let mut out = Vec::new(); - append_mono_u16(&[0, u16::MAX], 2, &mut out); - assert!(approx_eq(out[0], 0.0, 0.01)); - } - - #[test] - fn remove_dc_offset_centers_signal() { - let mut samples = vec![0.3, 0.5, 0.7]; - remove_dc_offset(&mut samples); - let mean = samples.iter().copied().sum::() / samples.len() as f32; - assert!(mean.abs() < 1e-6); - } - - #[test] - fn trim_silence_removes_outer_quiet_sections() { - let sample_rate = 1000; - let mut samples = vec![0.0; 120]; - samples.extend(std::iter::repeat_n(0.2, 200)); - samples.extend(vec![0.0; 120]); - - trim_silence(&mut samples, sample_rate); - - assert!(samples.len() < 440); - assert!(samples.len() >= 200); - assert!(samples.iter().any(|sample| sample.abs() >= 0.19)); - } - - #[test] - fn normalize_peak_boosts_quiet_audio_without_clipping() { - let mut samples = vec![0.2, -0.3, 0.4]; - let gain = normalize_peak(&mut samples); - let peak = samples.iter().copied().map(f32::abs).fold(0.0f32, f32::max); - - assert!(gain > 1.0); - assert!(approx_eq(peak, NORMALIZE_TARGET_PEAK, 1e-4)); - assert!(samples.iter().all(|sample| sample.abs() <= 1.0)); - } - - #[test] - fn preprocess_audio_reduces_leading_and_trailing_silence() { - let sample_rate = 16000; - let mut samples = vec![0.0; 1600]; - samples.extend((0..3200).map(|idx| if idx % 2 == 0 { 0.08 } else { -0.08 })); - samples.extend(vec![0.0; 1600]); - - preprocess_audio(&mut samples, sample_rate); - - assert!(samples.len() < 6400); - assert!(samples.iter().any(|sample| sample.abs() > 0.1)); - } -} diff --git a/src/audio/tests.rs b/src/audio/tests.rs new file mode 100644 index 0000000..58b296d --- /dev/null +++ b/src/audio/tests.rs @@ -0,0 +1,81 @@ +use super::{dsp, preprocess_audio, recorder}; + +fn approx_eq(a: f32, b: f32, eps: f32) -> bool { + (a - b).abs() <= eps +} + +#[test] +fn 
append_mono_f32_passthrough_for_single_channel() { + let mut out = Vec::new(); + recorder::append_mono_f32(&[0.1, -0.2, 0.3], 1, &mut out); + assert_eq!(out, vec![0.1, -0.2, 0.3]); +} + +#[test] +fn append_mono_f32_downmixes_stereo() { + let mut out = Vec::new(); + recorder::append_mono_f32(&[1.0, -1.0, 0.5, 0.5], 2, &mut out); + assert!(approx_eq(out[0], 0.0, 1e-6)); + assert!(approx_eq(out[1], 0.5, 1e-6)); +} + +#[test] +fn append_mono_i16_converts_to_f32() { + let mut out = Vec::new(); + recorder::append_mono_i16(&[i16::MAX, i16::MIN], 1, &mut out); + assert!(approx_eq(out[0], 1.0, 1e-4)); + assert!(out[1] < -0.99); +} + +#[test] +fn append_mono_u16_downmixes_and_converts() { + let mut out = Vec::new(); + recorder::append_mono_u16(&[0, u16::MAX], 2, &mut out); + assert!(approx_eq(out[0], 0.0, 0.01)); +} + +#[test] +fn remove_dc_offset_centers_signal() { + let mut samples = vec![0.3, 0.5, 0.7]; + dsp::remove_dc_offset(&mut samples); + let mean = samples.iter().copied().sum::() / samples.len() as f32; + assert!(mean.abs() < 1e-6); +} + +#[test] +fn trim_silence_removes_outer_quiet_sections() { + let sample_rate = 1000; + let mut samples = vec![0.0; 120]; + samples.extend(std::iter::repeat_n(0.2, 200)); + samples.extend(vec![0.0; 120]); + + dsp::trim_silence(&mut samples, sample_rate); + + assert!(samples.len() < 440); + assert!(samples.len() >= 200); + assert!(samples.iter().any(|sample| sample.abs() >= 0.19)); +} + +#[test] +fn normalize_peak_boosts_quiet_audio_without_clipping() { + let mut samples = vec![0.2, -0.3, 0.4]; + let gain = dsp::normalize_peak(&mut samples); + let peak = samples.iter().copied().map(f32::abs).fold(0.0f32, f32::max); + + assert!(gain > 1.0); + assert!(approx_eq(peak, dsp::NORMALIZE_TARGET_PEAK, 1e-4)); + assert!(samples.iter().all(|sample| sample.abs() <= 1.0)); +} + +#[test] +fn preprocess_audio_reduces_leading_and_trailing_silence() { + let sample_rate = 16000; + let mut samples = vec![0.0; 1600]; + samples.extend((0..3200).map(|idx| 
if idx % 2 == 0 { 0.08 } else { -0.08 })); + samples.extend(vec![0.0; 1600]); + + preprocess_audio(&mut samples, sample_rate); + + assert!(samples.len() < 6400); + assert!(samples.iter().any(|sample| sample.abs() > 0.1)); +} diff --git a/src/bin/whispers-osd.rs b/src/bin/whispers-osd.rs index ed47233..9fe2d92 100644 --- a/src/bin/whispers-osd.rs +++ b/src/bin/whispers-osd.rs @@ -1,63 +1,93 @@ #[path = "../branding.rs"] mod branding; +#[path = "../osd_protocol.rs"] +mod osd_protocol; use std::ffi::CString; +use std::io::{BufRead, BufReader}; use std::os::fd::AsRawFd; use std::os::unix::io::{AsFd, FromRawFd}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::sync::OnceLock; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::time::Instant; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; +use font8x8::{BASIC_FONTS, UnicodeFonts}; +use fontdue::Font; +use osd_protocol::{OsdEvent, VoiceOsdStatus, VoiceOsdUpdate}; use wayland_client::protocol::{ wl_buffer, wl_compositor, wl_registry, wl_shm, wl_shm_pool, wl_surface, }; use wayland_client::{Connection, Dispatch, QueueHandle, delegate_noop}; use wayland_protocols_wlr::layer_shell::v1::client::{zwlr_layer_shell_v1, zwlr_layer_surface_v1}; -// --- Layout --- -const NUM_BARS: usize = 28; +const NUM_BARS: usize = 12; const BAR_WIDTH: u32 = 3; const BAR_GAP: u32 = 2; -const PAD_X: u32 = 10; -const PAD_Y: u32 = 8; const BAR_MIN_HEIGHT: f32 = 2.0; -const BAR_MAX_HEIGHT: f32 = 30.0; -const OSD_WIDTH: u32 = PAD_X * 2 + NUM_BARS as u32 * BAR_WIDTH + (NUM_BARS as u32 - 1) * BAR_GAP; -const OSD_HEIGHT: u32 = BAR_MAX_HEIGHT as u32 + PAD_Y * 2; +const BAR_MAX_HEIGHT: f32 = 20.0; +const METER_WIDTH: u32 = 128; +const METER_HEIGHT: u32 = 72; +const VOICE_WIDTH: u32 = 760; +const VOICE_HEIGHT: u32 = 248; const MARGIN_BOTTOM: i32 = 40; -const CORNER_RADIUS: u32 = 12; const BORDER_WIDTH: u32 = 1; -const RISE_RATE: f32 = 0.55; -const DECAY_RATE: f32 = 0.88; - -// --- Animation --- 
+const RISE_RATE: f32 = 0.40; +const DECAY_RATE: f32 = 0.92; const FPS: i32 = 30; const FRAME_MS: i32 = 1000 / FPS; -// --- Colors --- -const BG_R: u8 = 18; -const BG_G: u8 = 18; -const BG_B: u8 = 30; -const BG_A: u8 = 185; - -const BORDER_R: u8 = 140; -const BORDER_G: u8 = 180; +const BG_R: u8 = 15; +const BG_G: u8 = 16; +const BG_B: u8 = 24; +const BG_A: u8 = 212; +const BORDER_R: u8 = 200; +const BORDER_G: u8 = 215; const BORDER_B: u8 = 255; -const BORDER_A: u8 = 40; - -// Bar gradient: teal → violet -const BAR_LEFT_R: f32 = 0.0; -const BAR_LEFT_G: f32 = 0.82; -const BAR_LEFT_B: f32 = 0.75; -const BAR_RIGHT_R: f32 = 0.65; -const BAR_RIGHT_G: f32 = 0.35; -const BAR_RIGHT_B: f32 = 1.0; +const BORDER_A: u8 = 28; + +const PANEL_R: u8 = 22; +const PANEL_G: u8 = 28; +const PANEL_B: u8 = 39; +const PANEL_A: u8 = 192; +const PANEL_BORDER_A: u8 = 18; +const TRACK_R: u8 = 28; +const TRACK_G: u8 = 36; +const TRACK_B: u8 = 50; +const TRACK_A: u8 = 218; +const HIGHLIGHT_A: u8 = 14; +const SHADOW_R: u8 = 3; +const SHADOW_G: u8 = 6; +const SHADOW_B: u8 = 10; + +const BAR_REST_R: f32 = 0.863; +const BAR_REST_G: f32 = 0.882; +const BAR_REST_B: f32 = 0.922; +const BAR_REST_A: f32 = 0.706; +const BAR_PEAK_R: f32 = 0.392; +const BAR_PEAK_G: f32 = 0.608; +const BAR_PEAK_B: f32 = 1.0; +const BAR_PEAK_A: f32 = 0.941; + +const TEXT_PRIMARY: (u8, u8, u8, u8) = (242, 246, 255, 230); +const TEXT_MUTED: (u8, u8, u8, u8) = (185, 196, 220, 200); +const TEXT_UNSTABLE: (u8, u8, u8, u8) = (115, 235, 226, 240); +const TEXT_REWRITE: (u8, u8, u8, u8) = (255, 208, 126, 220); +const TEXT_REWRITE_PRIMARY: (u8, u8, u8, u8) = (255, 236, 205, 240); +const TEXT_WARNING: (u8, u8, u8, u8) = (255, 153, 134, 235); +const ACCENT_STATUS: (u8, u8, u8) = (109, 236, 196); static SHOULD_EXIT: AtomicBool = AtomicBool::new(false); +static OSD_FONT: OnceLock> = OnceLock::new(); + +#[derive(Clone, Copy, PartialEq, Eq)] +enum OverlayMode { + Meter, + Voice, +} -// --- Audio state (shared with capture thread) --- struct 
AudioLevel { rms_bits: AtomicU32, } @@ -68,41 +98,43 @@ impl AudioLevel { rms_bits: AtomicU32::new(0), } } + fn set(&self, val: f32) { self.rms_bits.store(val.to_bits(), Ordering::Relaxed); } + fn get(&self) -> f32 { f32::from_bits(self.rms_bits.load(Ordering::Relaxed)) } } -// --- Bar animation state --- struct BarState { heights: [f32; NUM_BARS], + smooth_rms: f32, } impl BarState { fn new() -> Self { Self { heights: [BAR_MIN_HEIGHT; NUM_BARS], + smooth_rms: 0.0, } } fn update(&mut self, rms: f32, time: f32) { - // Amplify RMS for visual impact let level = (rms * 5.0).min(1.0); + let rms_target = (rms * 4.0).min(1.0); + self.smooth_rms = self.smooth_rms * 0.85 + rms_target * 0.15; for i in 0..NUM_BARS { let t = i as f32 / NUM_BARS as f32; - // Create wave pattern across bars, driven by audio level - let wave1 = (t * std::f32::consts::PI * 2.5 + time * 3.0).sin() * 0.5 + 0.5; - let wave2 = (t * std::f32::consts::PI * 1.3 - time * 1.8).sin() * 0.3 + 0.5; - let wave3 = (t * std::f32::consts::PI * 4.0 + time * 5.5).sin() * 0.2 + 0.5; + let wave1 = (t * std::f32::consts::PI * 2.0 + time * 2.0).sin() * 0.5 + 0.5; + let wave2 = (t * std::f32::consts::PI * 1.5 - time * 1.2).sin() * 0.3 + 0.5; + let wave3 = (t * std::f32::consts::PI * 3.5 + time * 4.0).sin() * 0.15 + 0.5; - let combined = (wave1 * 0.5 + wave2 * 0.3 + wave3 * 0.2) * level; + let combined = (wave1 * 0.55 + wave2 * 0.30 + wave3 * 0.15) * level; let target = BAR_MIN_HEIGHT + combined * (BAR_MAX_HEIGHT - BAR_MIN_HEIGHT); - // Smooth: fast rise, slow decay if target > self.heights[i] { self.heights[i] += (target - self.heights[i]) * RISE_RATE; } else { @@ -113,7 +145,42 @@ impl BarState { } } -// --- Wayland state --- +#[derive(Debug, Clone)] +struct VoiceOverlayState { + status: VoiceOsdStatus, + stable_text: String, + unstable_text: String, + rewrite_preview: Option, + live_inject: bool, + frozen: bool, +} + +impl Default for VoiceOverlayState { + fn default() -> Self { + Self { + status: 
VoiceOsdStatus::Listening, + stable_text: String::new(), + unstable_text: String::new(), + rewrite_preview: None, + live_inject: false, + frozen: false, + } + } +} + +impl From for VoiceOverlayState { + fn from(update: VoiceOsdUpdate) -> Self { + Self { + status: update.status, + stable_text: update.stable_text, + unstable_text: update.unstable_text, + rewrite_preview: update.rewrite_preview, + live_inject: update.live_inject, + frozen: update.frozen, + } + } +} + struct OsdState { running: bool, width: u32, @@ -144,23 +211,36 @@ fn main() -> Result<(), Box> { ); } + let mode = if std::env::args().any(|arg| arg == "--voice") { + OverlayMode::Voice + } else { + OverlayMode::Meter + }; + let (width, height) = match mode { + OverlayMode::Meter => (METER_WIDTH, METER_HEIGHT), + OverlayMode::Voice => (VOICE_WIDTH, VOICE_HEIGHT), + }; + let _ = std::fs::write(pid_file_path(), std::process::id().to_string()); - // Start audio capture for visualization let audio_level = Arc::new(AudioLevel::new()); let _audio_stream = start_audio_capture(Arc::clone(&audio_level)); - // Wayland setup + let voice_state = matches!(mode, OverlayMode::Voice) + .then(|| Arc::new(std::sync::Mutex::new(VoiceOverlayState::default()))); + let _voice_reader = voice_state + .as_ref() + .map(|state| start_voice_event_reader(Arc::clone(state))); + let conn = Connection::connect_to_env()?; let mut event_queue = conn.new_event_queue(); let qh = event_queue.handle(); - conn.display().get_registry(&qh, ()); let mut state = OsdState { running: true, - width: OSD_WIDTH, - height: OSD_HEIGHT, + width, + height, compositor: None, shm: None, layer_shell: None, @@ -172,7 +252,6 @@ fn main() -> Result<(), Box> { event_queue.roundtrip(&mut state)?; - // Create layer surface let compositor = state .compositor .as_ref() @@ -192,7 +271,7 @@ fn main() -> Result<(), Box> { (), ); - layer_surface.set_size(OSD_WIDTH, OSD_HEIGHT); + layer_surface.set_size(width, height); 
layer_surface.set_anchor(zwlr_layer_surface_v1::Anchor::Bottom); layer_surface.set_margin(0, 0, MARGIN_BOTTOM, 0); layer_surface.set_exclusive_zone(-1); @@ -201,19 +280,14 @@ fn main() -> Result<(), Box> { state.surface = Some(surface); state.layer_surface = Some(layer_surface); - event_queue.roundtrip(&mut state)?; - // Animation state let mut bars = BarState::new(); let start_time = Instant::now(); + let mut pixels = vec![0u8; (width * height * 4) as usize]; - // Reusable pixel buffer (avoids alloc/dealloc per frame) - let mut pixels = vec![0u8; (OSD_WIDTH * OSD_HEIGHT * 4) as usize]; - - // Persistent shm pool: create memfd + pool once, reuse each frame - let stride = OSD_WIDTH * 4; - let shm_size = (stride * OSD_HEIGHT) as i32; + let stride = width * 4; + let shm_size = (stride * height) as i32; let shm_name = CString::new(branding::OSD_BINARY).expect("valid OSD memfd name"); let shm_fd = unsafe { libc::memfd_create(shm_name.as_ptr(), libc::MFD_CLOEXEC) }; if shm_fd < 0 { @@ -227,7 +301,6 @@ fn main() -> Result<(), Box> { .ok_or("wl_shm not advertised by wayland server")?; let pool = shm.create_pool(shm_file.as_fd(), shm_size, &qh, ()); - // Main animation loop while state.running && !SHOULD_EXIT.load(Ordering::Relaxed) { conn.flush()?; @@ -253,24 +326,29 @@ fn main() -> Result<(), Box> { continue; } - // Update animation let time = start_time.elapsed().as_secs_f32(); let rms = audio_level.get(); bars.update(rms, time); - // Render frame into reusable buffer - let w = state.width; - let h = state.height; + let voice_snapshot = voice_state + .as_ref() + .and_then(|shared| shared.lock().ok().map(|state| state.clone())); + pixels.fill(0); - render_frame(&mut pixels, w, h, &bars, time); + render_frame( + &mut pixels, + state.width, + state.height, + &bars, + mode, + voice_snapshot.as_ref(), + ); - // Present frame using persistent shm pool - if let Err(e) = present_frame(&mut state, &qh, &pool, &shm_file, &pixels, w, h) { - eprintln!("frame dropped: {e}"); + if let 
Err(err) = present_frame(&mut state, &qh, &pool, &shm_file, &pixels, width, height) { + eprintln!("frame dropped: {err}"); } } - // Cleanup pool.destroy(); if let Some(ls) = state.layer_surface.take() { ls.destroy(); @@ -289,13 +367,33 @@ unsafe extern "C" fn handle_signal(_sig: libc::c_int) { SHOULD_EXIT.store(true, Ordering::Relaxed); } -// --- Audio capture --- +fn start_voice_event_reader( + state: Arc>, +) -> std::thread::JoinHandle<()> { + std::thread::spawn(move || { + let stdin = std::io::stdin(); + let reader = BufReader::new(stdin.lock()); + for line in reader.lines() { + let Ok(line) = line else { + break; + }; + if line.trim().is_empty() { + continue; + } + let Ok(event) = serde_json::from_str::(&line) else { + continue; + }; + let OsdEvent::VoiceUpdate(update) = event; + if let Ok(mut guard) = state.lock() { + *guard = update.into(); + } + } + }) +} fn start_audio_capture(level: Arc) -> Option { let host = cpal::default_host(); let device = host.default_input_device()?; - - // Try to find a supported config at 16kHz, preferring mono then fewer channels let config = device .supported_input_configs() .ok() @@ -324,7 +422,6 @@ fn start_audio_capture(level: Arc) -> Option { if data.is_empty() { return; } - // Downmix to mono if needed, then compute RMS let sample_count = data.len() / channels.max(1); if sample_count == 0 { return; @@ -351,65 +448,396 @@ fn start_audio_capture(level: Arc) -> Option { Some(stream) } -// --- Rendering --- +fn render_frame( + pixels: &mut [u8], + w: u32, + h: u32, + bars: &BarState, + mode: OverlayMode, + voice_state: Option<&VoiceOverlayState>, +) { + match mode { + OverlayMode::Meter => render_meter_overlay(pixels, w, h, bars), + OverlayMode::Voice => render_voice_overlay( + pixels, + w, + h, + bars, + voice_state.unwrap_or(&VoiceOverlayState::default()), + ), + } +} + +fn render_meter_overlay(pixels: &mut [u8], w: u32, h: u32, bars: &BarState) { + let shell_x = 8; + let shell_y = 10; + let shell_w = w.saturating_sub(shell_x 
* 2); + let shell_h = h.saturating_sub(shell_y * 2 + 2); + let shell_radius = shell_h / 2; -fn render_frame(pixels: &mut [u8], w: u32, h: u32, bars: &BarState, _time: f32) { - // Glassmorphic background - draw_rounded_rect( + draw_surface_shell( pixels, w, h, - 0, - 0, + shell_x, + shell_y, + shell_w, + shell_h, + shell_radius, + ); + + let track_x = shell_x + 14; + let track_y = shell_y + 10; + let track_w = shell_w.saturating_sub(28); + let track_h = shell_h.saturating_sub(20); + draw_track_shell( + pixels, w, h, - CORNER_RADIUS, - BG_R, - BG_G, - BG_B, - BG_A, + track_x, + track_y, + track_w, + track_h, + track_h / 2, ); - draw_rounded_border( + + let total_width = NUM_BARS as u32 * BAR_WIDTH + (NUM_BARS as u32 - 1) * BAR_GAP; + let start_x = track_x + track_w.saturating_sub(total_width) / 2; + render_meter_bars(pixels, w, h, bars, track_y + track_h / 2, start_x); +} + +fn render_voice_overlay( + pixels: &mut [u8], + w: u32, + h: u32, + bars: &BarState, + voice: &VoiceOverlayState, +) { + let shell_x = 12u32; + let shell_y = 10u32; + let shell_w = w.saturating_sub(shell_x * 2); + let shell_h = h.saturating_sub(shell_y * 2 + 2); + let shell_radius = 24u32; + draw_surface_shell( pixels, w, h, - CORNER_RADIUS, - BORDER_WIDTH, - BORDER_R, - BORDER_G, - BORDER_B, - BORDER_A, + shell_x, + shell_y, + shell_w, + shell_h, + shell_radius, + ); + + let pad = (shell_x + 18) as i32; + let header_y = shell_y as i32 + 16; + let transcript_panel_y = shell_y + 42; + let transcript_panel_h = 110u32; + let panel_radius = 18u32; + let transcript_y = transcript_panel_y as i32 + 18; + let status_font = 16.0; + let badge_font = 13.0; + let transcript_font = 20.0; + let footer_font = 14.0; + let transcript_line_height = line_height(transcript_font) + 4; + let footer_line_height = line_height(footer_font) + 2; + let transcript_width = shell_w.saturating_sub(72); + let raw_live_text = combined_voice_text(&voice.stable_text, &voice.unstable_text); + let rewrite_available = voice + 
.rewrite_preview + .as_deref() + .map(str::trim) + .filter(|text| !text.is_empty()); + + let status_label = status_label(voice.status); + draw_text( + pixels, + w, + h, + pad, + header_y, + status_font, + status_label, + TEXT_PRIMARY, + ); + + let badge_text = if voice.live_inject { + "LIVE INJECT" + } else { + "PREVIEW ONLY" + }; + let badge_width = text_width(badge_text, badge_font) + 22; + let badge_height = 24u32; + let badge_x = shell_x + shell_w.saturating_sub(badge_width + 18); + let badge_y = shell_y + 13; + let badge_rgb = if voice.live_inject { + (TEXT_UNSTABLE.0, TEXT_UNSTABLE.1, TEXT_UNSTABLE.2) + } else { + (TEXT_MUTED.0, TEXT_MUTED.1, TEXT_MUTED.2) + }; + draw_chip( + pixels, + w, + h, + badge_x, + badge_y, + badge_width, + badge_height, + badge_rgb.0, + badge_rgb.1, + badge_rgb.2, + ); + draw_text( + pixels, + w, + h, + badge_x as i32 + 11, + badge_y as i32 + 5, + badge_font, + badge_text, + if voice.live_inject { + TEXT_UNSTABLE + } else { + TEXT_MUTED + }, + ); + + draw_panel_shell( + pixels, + w, + h, + shell_x + 18, + transcript_panel_y, + shell_w.saturating_sub(36), + transcript_panel_h, + panel_radius, + ); + + if let Some(rewrite_text) = rewrite_available { + let preview_label = "Live rewrite preview"; + draw_text( + pixels, + w, + h, + pad, + transcript_y - footer_line_height - 4, + footer_font, + preview_label, + TEXT_REWRITE, + ); + + let rewrite_lines = wrap_text(rewrite_text, transcript_width, 4, transcript_font); + for (idx, line) in rewrite_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + pad, + transcript_y + idx as i32 * transcript_line_height, + transcript_font, + line, + TEXT_REWRITE_PRIMARY, + ); + } + } else { + let stable_lines = wrap_text(&voice.stable_text, transcript_width, 3, transcript_font); + let unstable_lines = wrap_text(&voice.unstable_text, transcript_width, 2, transcript_font); + + if stable_lines.is_empty() && unstable_lines.is_empty() { + draw_text( + pixels, + w, + h, + pad, + transcript_y, + 
transcript_font, + "Listening for speech...", + TEXT_MUTED, + ); + } else { + for (idx, line) in stable_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + pad, + transcript_y + idx as i32 * transcript_line_height, + transcript_font, + line, + TEXT_PRIMARY, + ); + } + + let unstable_y = transcript_y + stable_lines.len() as i32 * transcript_line_height; + for (idx, line) in unstable_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + pad, + unstable_y + idx as i32 * transcript_line_height, + transcript_font, + line, + TEXT_UNSTABLE, + ); + } + } + } + + let footer_panel_x = shell_x + 18; + let footer_panel_y = shell_y + shell_h.saturating_sub(74); + let footer_panel_w = shell_w.saturating_sub(36); + let footer_panel_h = 26u32; + draw_panel_shell( + pixels, + w, + h, + footer_panel_x, + footer_panel_y, + footer_panel_w, + footer_panel_h, + 13, ); - // Top highlight (glass reflection) - for x in (CORNER_RADIUS + 2)..(w.saturating_sub(CORNER_RADIUS + 2)) { - set_pixel_blend(pixels, w, h, x, 1, 255, 255, 255, 18); + if voice.frozen { + let warning_lines = wrap_text( + "Focus changed. 
Live injection is frozen for this take.", + footer_panel_w.saturating_sub(24), + 1, + footer_font, + ); + for (idx, line) in warning_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + footer_panel_x as i32 + 12, + footer_panel_y as i32 + 6 + idx as i32 * footer_line_height, + footer_font, + line, + TEXT_WARNING, + ); + } + } else if rewrite_available.is_some() { + draw_text( + pixels, + w, + h, + footer_panel_x as i32 + 12, + footer_panel_y as i32 + 6, + footer_font, + "Raw live hypothesis", + TEXT_MUTED, + ); + let raw_lines = wrap_text( + &raw_live_text, + footer_panel_w.saturating_sub(150), + 1, + footer_font, + ); + for (idx, line) in raw_lines.iter().enumerate() { + draw_text( + pixels, + w, + h, + footer_panel_x as i32 + 132, + footer_panel_y as i32 + 6 + idx as i32 * footer_line_height, + footer_font, + line, + TEXT_MUTED, + ); + } } - // Visualizer bars - let center_y = h / 2; + let track_x = shell_x + 18; + let track_y = shell_y + shell_h.saturating_sub(38); + let track_w = shell_w.saturating_sub(36); + let track_h = 14u32; + draw_track_shell( + pixels, + w, + h, + track_x, + track_y, + track_w, + track_h, + track_h / 2, + ); + let total_width = NUM_BARS as u32 * BAR_WIDTH + (NUM_BARS as u32 - 1) * BAR_GAP; + let start_x = track_x + track_w.saturating_sub(total_width) / 2; + render_voice_bars(pixels, w, h, bars, track_y + track_h / 2, start_x); +} + +fn render_meter_bars( + pixels: &mut [u8], + w: u32, + h: u32, + bars: &BarState, + center_y: u32, + start_x: u32, +) { + let color_t = bars.smooth_rms.clamp(0.0, 1.0); + let cr = (lerp(BAR_REST_R, BAR_PEAK_R, color_t) * 255.0) as u8; + let cg = (lerp(BAR_REST_G, BAR_PEAK_G, color_t) * 255.0) as u8; + let cb = (lerp(BAR_REST_B, BAR_PEAK_B, color_t) * 255.0) as u8; + let base_alpha = lerp(BAR_REST_A, BAR_PEAK_A, color_t); + + let glow_expand = 1 + (color_t * 2.0) as u32; + let glow_alpha = (15.0 + color_t * 25.0) as u8; + for i in 0..NUM_BARS { - let bx = PAD_X + i as u32 * (BAR_WIDTH + BAR_GAP); + 
let bx = start_x + i as u32 * (BAR_WIDTH + BAR_GAP); + let bar_h = bars.heights[i] as u32; + let half_h = bar_h / 2; + let top_y = center_y.saturating_sub(half_h); + + for gy in top_y.saturating_sub(glow_expand) + ..=(top_y + bar_h + glow_expand).min(h.saturating_sub(1)) + { + for gx in bx.saturating_sub(glow_expand) + ..=(bx + BAR_WIDTH + glow_expand).min(w.saturating_sub(1)) + { + set_pixel_blend(pixels, w, h, gx, gy, cr, cg, cb, glow_alpha); + } + } + + for y in top_y..(top_y + bar_h).min(h) { + let vy = (y as f32 - top_y as f32) / bar_h.max(1) as f32; + let brightness = 1.0 - (vy - 0.5).abs() * 0.5; + let a = (brightness * base_alpha * 255.0) as u8; + for x in bx..(bx + BAR_WIDTH).min(w) { + set_pixel_blend(pixels, w, h, x, y, cr, cg, cb, a); + } + } + } +} + +fn render_voice_bars( + pixels: &mut [u8], + w: u32, + h: u32, + bars: &BarState, + center_y: u32, + start_x: u32, +) { + for i in 0..NUM_BARS { + let bx = start_x + i as u32 * (BAR_WIDTH + BAR_GAP); let bar_h = bars.heights[i] as u32; let half_h = bar_h / 2; let top_y = center_y.saturating_sub(half_h); let t = i as f32 / (NUM_BARS - 1) as f32; - let r = lerp(BAR_LEFT_R, BAR_RIGHT_R, t); - let g = lerp(BAR_LEFT_G, BAR_RIGHT_G, t); - let b = lerp(BAR_LEFT_B, BAR_RIGHT_B, t); - let cr = (r * 255.0) as u8; - let cg = (g * 255.0) as u8; - let cb = (b * 255.0) as u8; - - // Glow - for gy in top_y.saturating_sub(2)..=(top_y + bar_h + 2).min(h - 1) { - for gx in bx.saturating_sub(1)..=(bx + BAR_WIDTH).min(w - 1) { - set_pixel_blend(pixels, w, h, gx, gy, cr, cg, cb, 25); + let focus = 1.0 - ((t - 0.5).abs() * 1.6).clamp(0.0, 1.0); + let cr = lerp(130.0, ACCENT_STATUS.0 as f32, focus) as u8; + let cg = lerp(176.0, ACCENT_STATUS.1 as f32, focus) as u8; + let cb = lerp(224.0, 255.0, focus) as u8; + + for gy in top_y.saturating_sub(2)..=(top_y + bar_h + 2).min(h.saturating_sub(1)) { + for gx in bx.saturating_sub(1)..=(bx + BAR_WIDTH).min(w.saturating_sub(1)) { + set_pixel_blend(pixels, w, h, gx, gy, cr, cg, cb, 22); 
} } - // Bar body with vertical brightness gradient for y in top_y..(top_y + bar_h).min(h) { let vy = (y as f32 - top_y as f32) / bar_h.max(1) as f32; let brightness = 1.0 - (vy - 0.5).abs() * 0.6; @@ -421,6 +849,470 @@ fn render_frame(pixels: &mut [u8], w: u32, h: u32, bars: &BarState, _time: f32) } } +#[allow(clippy::too_many_arguments)] +fn draw_surface_shell( + pixels: &mut [u8], + w: u32, + h: u32, + x0: u32, + y0: u32, + rect_w: u32, + rect_h: u32, + radius: u32, +) { + draw_shadow(pixels, w, h, x0, y0, rect_w, rect_h, radius); + draw_rounded_rect( + pixels, w, h, x0, y0, rect_w, rect_h, radius, BG_R, BG_G, BG_B, BG_A, + ); + draw_rounded_border_rect( + pixels, + w, + h, + x0, + y0, + rect_w, + rect_h, + radius, + BORDER_WIDTH, + BORDER_R, + BORDER_G, + BORDER_B, + BORDER_A, + ); + let highlight_y = y0 + 1; + for x in (x0 + radius / 2)..(x0 + rect_w).saturating_sub(radius / 2) { + set_pixel_blend(pixels, w, h, x, highlight_y, 255, 255, 255, HIGHLIGHT_A); + } +} + +#[allow(clippy::too_many_arguments)] +fn draw_panel_shell( + pixels: &mut [u8], + w: u32, + h: u32, + x0: u32, + y0: u32, + rect_w: u32, + rect_h: u32, + radius: u32, +) { + draw_rounded_rect( + pixels, w, h, x0, y0, rect_w, rect_h, radius, PANEL_R, PANEL_G, PANEL_B, PANEL_A, + ); + draw_rounded_border_rect( + pixels, + w, + h, + x0, + y0, + rect_w, + rect_h, + radius, + BORDER_WIDTH, + BORDER_R, + BORDER_G, + BORDER_B, + PANEL_BORDER_A, + ); +} + +#[allow(clippy::too_many_arguments)] +fn draw_track_shell( + pixels: &mut [u8], + w: u32, + h: u32, + x0: u32, + y0: u32, + rect_w: u32, + rect_h: u32, + radius: u32, +) { + draw_rounded_rect( + pixels, w, h, x0, y0, rect_w, rect_h, radius, TRACK_R, TRACK_G, TRACK_B, TRACK_A, + ); + draw_rounded_border_rect( + pixels, + w, + h, + x0, + y0, + rect_w, + rect_h, + radius, + BORDER_WIDTH, + BORDER_R, + BORDER_G, + BORDER_B, + 14, + ); +} + +#[allow(clippy::too_many_arguments)] +fn draw_chip( + pixels: &mut [u8], + w: u32, + h: u32, + x0: u32, + y0: u32, + 
rect_w: u32, + rect_h: u32, + r: u8, + g: u8, + b: u8, +) { + draw_rounded_rect( + pixels, + w, + h, + x0, + y0, + rect_w, + rect_h, + rect_h / 2, + r, + g, + b, + 26, + ); + draw_rounded_border_rect( + pixels, + w, + h, + x0, + y0, + rect_w, + rect_h, + rect_h / 2, + BORDER_WIDTH, + r, + g, + b, + 44, + ); +} + +#[allow(clippy::too_many_arguments)] +fn draw_shadow( + pixels: &mut [u8], + w: u32, + h: u32, + x0: u32, + y0: u32, + rect_w: u32, + rect_h: u32, + radius: u32, +) { + for spread in (1..=7).rev() { + let alpha = 4 + (7 - spread) as u8 * 3; + draw_rounded_rect( + pixels, + w, + h, + x0.saturating_sub(spread), + y0 + spread / 2, + rect_w + spread * 2, + rect_h + spread, + radius + spread, + SHADOW_R, + SHADOW_G, + SHADOW_B, + alpha, + ); + } +} + +fn status_label(status: VoiceOsdStatus) -> &'static str { + match status { + VoiceOsdStatus::Listening => "Listening", + VoiceOsdStatus::Transcribing => "Transcribing", + VoiceOsdStatus::Rewriting => "Rewriting", + VoiceOsdStatus::Finalizing => "Finalizing", + VoiceOsdStatus::Frozen => "Frozen", + } +} + +fn load_osd_font() -> Option { + const FONT_CANDIDATES: &[&str] = &[ + "/usr/share/fonts/noto/NotoSans-Regular.ttf", + "/usr/share/fonts/noto/NotoSans-Medium.ttf", + "/usr/share/fonts/TTF/NotoSansMNerdFont-Regular.ttf", + "/usr/share/fonts/liberation/LiberationSans-Regular.ttf", + "/usr/share/fonts/TTF/ArimoNerdFont-Regular.ttf", + "/usr/share/fonts/TTF/DejaVuSansMNerdFont-Regular.ttf", + ]; + + FONT_CANDIDATES + .iter() + .find_map(|path| load_font_from_path(path)) +} + +fn load_font_from_path(path: &str) -> Option { + let bytes = std::fs::read(Path::new(path)).ok()?; + Font::from_bytes(bytes, fontdue::FontSettings::default()).ok() +} + +fn osd_font() -> Option<&'static Font> { + OSD_FONT.get_or_init(load_osd_font).as_ref() +} + +fn line_height(px_size: f32) -> i32 { + if let Some(font) = osd_font() { + if let Some(metrics) = font.horizontal_line_metrics(px_size) { + return 
metrics.new_line_size.ceil().max(px_size.ceil()) as i32; + } + } + let scale = ((px_size / 8.0).round() as i32).max(1); + 8 * scale + scale + 2 +} + +fn wrap_text(text: &str, max_width_px: u32, max_lines: usize, px_size: f32) -> Vec { + if max_width_px == 0 || max_lines == 0 { + return Vec::new(); + } + + let text = sanitize_text(text); + if text.is_empty() { + return Vec::new(); + } + + let mut lines = Vec::new(); + let mut current = String::new(); + + for word in text.split_whitespace() { + if word.is_empty() { + continue; + } + let candidate = if current.is_empty() { + word.to_string() + } else { + format!("{current} {word}") + }; + + if text_width(&candidate, px_size) <= max_width_px { + current = candidate; + continue; + } + + if !current.is_empty() { + lines.push(current); + } + + current = truncate_word_to_width(word, max_width_px, px_size); + } + + if !current.is_empty() { + lines.push(current); + } + + if lines.len() > max_lines { + let mut tail = lines.split_off(lines.len() - max_lines); + if let Some(first) = tail.first_mut() { + *first = fit_text_to_width(&format!("…{first}"), max_width_px, px_size, true); + } + return tail; + } + + lines +} + +fn combined_voice_text(stable_text: &str, unstable_text: &str) -> String { + match (stable_text.trim(), unstable_text.trim()) { + ("", "") => String::new(), + ("", unstable) => unstable.to_string(), + (stable, "") => stable.to_string(), + (stable, unstable) => format!("{stable} {unstable}"), + } +} + +fn sanitize_text(text: &str) -> String { + let mut normalized = String::new(); + let mut pending_space = false; + + for ch in text.chars() { + if ch.is_control() && !ch.is_whitespace() { + continue; + } + if ch.is_whitespace() { + if !normalized.is_empty() { + pending_space = true; + } + continue; + } + if pending_space { + normalized.push(' '); + pending_space = false; + } + normalized.push(ch); + } + + normalized +} + +fn truncate_word_to_width(word: &str, max_width_px: u32, px_size: f32) -> String { + 
fit_text_to_width(word, max_width_px, px_size, false) +} + +fn fit_text_to_width(text: &str, max_width_px: u32, px_size: f32, keep_tail: bool) -> String { + let sanitized = sanitize_text(text); + if sanitized.is_empty() { + return String::new(); + } + if text_width(&sanitized, px_size) <= max_width_px { + return sanitized; + } + + let ellipsis = "…"; + let chars: Vec = sanitized.chars().collect(); + if keep_tail { + for start in 0..chars.len() { + let candidate: String = std::iter::once('…') + .chain(chars[start..].iter().copied()) + .collect(); + if text_width(&candidate, px_size) <= max_width_px { + return candidate; + } + } + } else { + let mut candidate = String::new(); + for ch in chars { + let next = if candidate.is_empty() { + format!("{ch}{ellipsis}") + } else { + format!("{candidate}{ch}{ellipsis}") + }; + if text_width(&next, px_size) > max_width_px { + break; + } + candidate.push(ch); + } + if !candidate.is_empty() { + candidate.push('…'); + return candidate; + } + } + + ellipsis.to_string() +} + +fn text_width(text: &str, px_size: f32) -> u32 { + let sanitized = sanitize_text(text); + if sanitized.is_empty() { + return 0; + } + if let Some(font) = osd_font() { + let width: f32 = sanitized + .chars() + .map(|ch| font.metrics(ch, px_size).advance_width) + .sum(); + return width.ceil().max(0.0) as u32; + } + + let scale = ((px_size / 8.0).round() as i32).max(1); + let glyph_w = (8 * scale + scale).max(1) as u32; + sanitized.chars().count() as u32 * glyph_w +} + +#[allow(clippy::too_many_arguments)] +fn draw_text( + pixels: &mut [u8], + w: u32, + h: u32, + x: i32, + y: i32, + px_size: f32, + text: &str, + color: (u8, u8, u8, u8), +) { + let sanitized = sanitize_text(text); + if sanitized.is_empty() { + return; + } + + if let Some(font) = osd_font() { + let baseline = y as f32 + + font + .horizontal_line_metrics(px_size) + .map(|metrics| metrics.ascent) + .unwrap_or(px_size * 0.92); + let mut pen_x = x as f32; + for ch in sanitized.chars() { + let (metrics, 
bitmap) = font.rasterize(ch, px_size); + let glyph_x = pen_x.round() as i32 + metrics.xmin; + let glyph_y = (baseline - metrics.height as f32 - metrics.ymin as f32).round() as i32; + for row in 0..metrics.height { + for col in 0..metrics.width { + let alpha = bitmap[row * metrics.width + col]; + if alpha == 0 { + continue; + } + let px = glyph_x + col as i32; + let py = glyph_y + row as i32; + if px >= 0 && py >= 0 { + let blended_alpha = ((alpha as u16 * color.3 as u16) / 255) as u8; + set_pixel_blend( + pixels, + w, + h, + px as u32, + py as u32, + color.0, + color.1, + color.2, + blended_alpha, + ); + } + } + } + pen_x += metrics.advance_width; + } + return; + } + + let scale = ((px_size / 8.0).round() as i32).max(1); + let mut cursor_x = x; + let glyph_advance = (8 * scale + scale).max(1); + for ch in sanitized.chars() { + draw_bitmap_char(pixels, w, h, cursor_x, y, scale, ch, color); + cursor_x += glyph_advance; + } +} + +#[allow(clippy::too_many_arguments)] +fn draw_bitmap_char( + pixels: &mut [u8], + w: u32, + h: u32, + x: i32, + y: i32, + scale: i32, + ch: char, + color: (u8, u8, u8, u8), +) { + let glyph = BASIC_FONTS.get(ch).or_else(|| BASIC_FONTS.get('?')); + let Some(glyph) = glyph else { + return; + }; + + for (row_idx, row_bits) in glyph.iter().enumerate() { + for col_idx in 0..8 { + if (row_bits >> col_idx) & 1 == 0 { + continue; + } + for sy in 0..scale.max(1) { + for sx in 0..scale.max(1) { + let px = x + (col_idx * scale as usize + sx as usize) as i32; + let py = y + (row_idx * scale as usize + sy as usize) as i32; + if px >= 0 && py >= 0 { + set_pixel_blend( + pixels, w, h, px as u32, py as u32, color.0, color.1, color.2, color.3, + ); + } + } + } + } + } +} + fn present_frame( state: &mut OsdState, qh: &QueueHandle, @@ -437,7 +1329,6 @@ fn present_frame( writer.seek(std::io::SeekFrom::Start(0))?; writer.write_all(pixels)?; - // Destroy previous buffer if let Some(old) = state.buffer.take() { old.destroy(); } @@ -464,8 +1355,6 @@ fn 
present_frame( Ok(()) } -// --- Drawing primitives --- - #[allow(clippy::too_many_arguments)] #[inline] fn set_pixel_blend(pixels: &mut [u8], w: u32, h: u32, x: u32, y: u32, r: u8, g: u8, b: u8, a: u8) { @@ -474,7 +1363,6 @@ fn set_pixel_blend(pixels: &mut [u8], w: u32, h: u32, x: u32, y: u32, r: u8, g: } let idx = ((y * w + x) * 4) as usize; if a == 255 { - // Premultiplied: BGRA pixels[idx] = b; pixels[idx + 1] = g; pixels[idx + 2] = r; @@ -483,7 +1371,6 @@ fn set_pixel_blend(pixels: &mut [u8], w: u32, h: u32, x: u32, y: u32, r: u8, g: } let sa = a as u32; let inv = 255 - sa; - // Premultiply source, blend with existing premultiplied dest pixels[idx] = ((sa * b as u32 + inv * pixels[idx] as u32) / 255) as u8; pixels[idx + 1] = ((sa * g as u32 + inv * pixels[idx + 1] as u32) / 255) as u8; pixels[idx + 2] = ((sa * r as u32 + inv * pixels[idx + 2] as u32) / 255) as u8; @@ -517,10 +1404,14 @@ fn draw_rounded_rect( } #[allow(clippy::too_many_arguments)] -fn draw_rounded_border( +fn draw_rounded_border_rect( pixels: &mut [u8], w: u32, h: u32, + x0: u32, + y0: u32, + rect_w: u32, + rect_h: u32, radius: u32, thickness: u32, r: u8, @@ -528,18 +1419,18 @@ fn draw_rounded_border( b: u8, a: u8, ) { - for y in 0..h { - for x in 0..w { - let inside_outer = is_inside_rounded_rect(x, y, w, h, radius); - let inside_inner = x >= thickness - && y >= thickness - && x < w - thickness - && y < h - thickness + for y in y0..y0 + rect_h { + for x in x0..x0 + rect_w { + let inside_outer = is_inside_rounded_rect(x - x0, y - y0, rect_w, rect_h, radius); + let inside_inner = x >= x0 + thickness + && y >= y0 + thickness + && x < x0 + rect_w - thickness + && y < y0 + rect_h - thickness && is_inside_rounded_rect( - x - thickness, - y - thickness, - w - 2 * thickness, - h - 2 * thickness, + x - x0 - thickness, + y - y0 - thickness, + rect_w - 2 * thickness, + rect_h - 2 * thickness, radius.saturating_sub(thickness), ); if inside_outer && !inside_inner { @@ -553,7 +1444,6 @@ fn 
is_inside_rounded_rect(x: u32, y: u32, w: u32, h: u32, r: u32) -> bool { if r == 0 || w == 0 || h == 0 { return x < w && y < h; } - // Check only corner regions let in_left = x < r; let in_right = x >= w - r; let in_top = y < r; @@ -587,8 +1477,6 @@ fn lerp(a: f32, b: f32, t: f32) -> f32 { a + (b - a) * t } -// --- Dispatch implementations --- - impl Dispatch for OsdState { fn event( state: &mut Self, diff --git a/src/bin/whispers-rewrite-worker/branding.rs b/src/bin/whispers-rewrite-worker/branding.rs new file mode 100644 index 0000000..056524e --- /dev/null +++ b/src/bin/whispers-rewrite-worker/branding.rs @@ -0,0 +1,52 @@ +#![allow(dead_code)] + +use std::path::PathBuf; + +pub const APP_NAME: &str = "whispers"; +pub const MAIN_BINARY: &str = APP_NAME; + +#[cfg(feature = "osd")] +pub const OSD_BINARY: &str = "whispers-osd"; + +pub const REWRITE_WORKER_BINARY: &str = "whispers-rewrite-worker"; + +pub const REWRITE_WORKER_ENV: &str = "WHISPERS_REWRITE_WORKER"; + +pub const MAIN_PID_FILE: &str = "whispers.pid"; +#[cfg(feature = "osd")] +pub const OSD_PID_FILE: &str = "whispers-osd.pid"; + +pub const LOG_TARGET: &str = "whispers"; +pub const UINPUT_KEYBOARD_NAME: &str = "whispers-keyboard"; + +pub fn resolve_sidecar_executable(candidates: &[&str]) -> PathBuf { + if let Ok(current_exe) = std::env::current_exe() { + if let Some(dir) = current_exe.parent() { + for candidate in candidates { + let path = dir.join(candidate); + if path.exists() { + return path; + } + } + } + } + + for candidate in candidates { + if let Some(path) = executable_in_path(candidate) { + return path; + } + } + + PathBuf::from(candidates[0]) +} + +fn executable_in_path(name: &str) -> Option { + let path = std::env::var_os("PATH")?; + for dir in std::env::split_paths(&path) { + let candidate = dir.join(name); + if candidate.is_file() { + return Some(candidate); + } + } + None +} diff --git a/src/bin/whispers-rewrite-worker/local.rs b/src/bin/whispers-rewrite-worker/local.rs new file mode 100644 
index 0000000..28db697 --- /dev/null +++ b/src/bin/whispers-rewrite-worker/local.rs @@ -0,0 +1,286 @@ +use std::num::NonZeroU32; +use std::path::Path; +use std::sync::OnceLock; + +use encoding_rs::UTF_8; +use llama_cpp_2::context::params::LlamaContextParams; +use llama_cpp_2::llama_backend::LlamaBackend; +use llama_cpp_2::llama_batch::LlamaBatch; +use llama_cpp_2::model::params::LlamaModelParams; +use llama_cpp_2::model::{AddBos, LlamaChatMessage, LlamaChatTemplate, LlamaModel}; +use llama_cpp_2::openai::OpenAIChatTemplateParams; +use llama_cpp_2::sampling::LlamaSampler; + +use crate::output::sanitize_rewrite_output; +use crate::prompt::{ + RewritePrompt, build_oaicompat_messages_json, build_prompt, effective_max_tokens, +}; +use crate::rewrite_profile::ResolvedRewriteProfile; +use crate::rewrite_protocol::RewriteTranscript; + +pub(crate) struct LocalRewriter { + model: LlamaModel, + chat_template: LlamaChatTemplate, + profile: ResolvedRewriteProfile, + max_tokens: usize, + max_output_chars: usize, +} + +static LLAMA_BACKEND: OnceLock<&'static LlamaBackend> = OnceLock::new(); +static EXTERNAL_LLAMA_BACKEND: LlamaBackend = LlamaBackend {}; + +impl LocalRewriter { + pub fn new( + model_path: &Path, + profile: ResolvedRewriteProfile, + max_tokens: usize, + max_output_chars: usize, + ) -> std::result::Result { + if !model_path.exists() { + return Err(format!( + "rewrite model file not found: {}", + model_path.display() + )); + } + + let backend = llama_backend()?; + + let mut model_params = LlamaModelParams::default(); + if cfg!(feature = "cuda") { + model_params = model_params.with_n_gpu_layers(1000); + } + + let model = LlamaModel::load_from_file(backend, model_path, &model_params) + .map_err(|e| format!("failed to load rewrite model: {e}"))?; + let chat_template = model + .chat_template(None) + .map_err(|e| format!("rewrite model does not expose a usable chat template: {e}"))?; + + Ok(Self { + model, + chat_template, + profile, + max_tokens, + max_output_chars, + }) 
+ } + + pub fn rewrite_with_instructions( + &self, + transcript: &RewriteTranscript, + custom_instructions: Option<&str>, + ) -> std::result::Result { + if transcript.raw_text.trim().is_empty() { + return Ok(String::new()); + } + + let prompt = build_rewrite_prompt( + &self.model, + &self.chat_template, + transcript, + self.profile, + custom_instructions, + )?; + let effective_max_tokens = effective_max_tokens(self.max_tokens, transcript); + let prompt_tokens = self + .model + .str_to_token(&prompt, AddBos::Never) + .map_err(|e| format!("failed to tokenize rewrite prompt: {e}"))?; + let behavior = rewrite_behavior(self.profile); + + let n_ctx_tokens = prompt_tokens + .len() + .saturating_add(effective_max_tokens) + .saturating_add(64) + .max(2048) + .min(u32::MAX as usize) as u32; + let n_batch = prompt_tokens + .len() + .max(512) + .min(n_ctx_tokens as usize) + .min(u32::MAX as usize) as u32; + let threads = std::thread::available_parallelism() + .map(|threads| threads.get()) + .unwrap_or(4) + .clamp(1, i32::MAX as usize) as i32; + + let ctx_params = LlamaContextParams::default() + .with_n_ctx(NonZeroU32::new(n_ctx_tokens)) + .with_n_batch(n_batch) + .with_n_ubatch(n_batch) + .with_n_threads(threads) + .with_n_threads_batch(threads); + let backend = llama_backend()?; + let mut ctx = self + .model + .new_context(backend, ctx_params) + .map_err(|e| format!("failed to create rewrite context: {e}"))?; + + let mut prompt_batch = LlamaBatch::new(prompt_tokens.len(), 1); + prompt_batch + .add_sequence(&prompt_tokens, 0, false) + .map_err(|e| format!("failed to enqueue rewrite prompt: {e}"))?; + ctx.decode(&mut prompt_batch) + .map_err(|e| format!("failed to decode rewrite prompt: {e}"))?; + + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::top_k(behavior.top_k), + LlamaSampler::top_p(behavior.top_p, 1), + LlamaSampler::temp(behavior.temperature), + LlamaSampler::greedy(), + ]); + sampler.accept_many(prompt_tokens.iter()); + + let mut decoder = 
UTF_8.new_decoder_without_bom_handling(); + let mut output = String::new(); + let start_pos = i32::try_from(prompt_tokens.len()).unwrap_or(i32::MAX); + + for i in 0..effective_max_tokens { + let mut candidates = ctx.token_data_array(); + candidates.apply_sampler(&sampler); + let token = candidates + .selected_token() + .ok_or_else(|| "rewrite sampler did not select a token".to_string())?; + + if token == self.model.token_eos() { + break; + } + + sampler.accept(token); + + let piece = self + .model + .token_to_piece(token, &mut decoder, true, None) + .map_err(|e| format!("failed to decode rewrite token: {e}"))?; + output.push_str(&piece); + + if output.contains("") || output.len() >= self.max_output_chars { + break; + } + + let mut batch = LlamaBatch::new(1, 1); + batch + .add( + token, + start_pos.saturating_add(i32::try_from(i).unwrap_or(i32::MAX)), + &[0], + true, + ) + .map_err(|e| format!("failed to enqueue rewrite token: {e}"))?; + ctx.decode(&mut batch) + .map_err(|e| format!("failed to decode rewrite token: {e}"))?; + } + + let rewritten = sanitize_rewrite_output(&output); + if rewritten.is_empty() { + return Err("rewrite model returned empty output".into()); + } + + Ok(rewritten) + } +} + +fn llama_backend() -> std::result::Result<&'static LlamaBackend, String> { + if let Some(backend) = LLAMA_BACKEND.get().copied() { + return Ok(backend); + } + + match LlamaBackend::init() { + Ok(backend) => { + let backend = Box::leak(Box::new(backend)); + let _ = LLAMA_BACKEND.set(backend); + Ok(LLAMA_BACKEND + .get() + .copied() + .expect("llama backend initialized")) + } + Err(llama_cpp_2::LlamaCppError::BackendAlreadyInitialized) => { + let _ = LLAMA_BACKEND.set(&EXTERNAL_LLAMA_BACKEND); + Ok(LLAMA_BACKEND + .get() + .copied() + .expect("external llama backend cached")) + } + Err(err) => Err(format!("failed to initialize llama backend: {err}")), + } +} + +fn build_rewrite_prompt( + model: &LlamaModel, + chat_template: &LlamaChatTemplate, + transcript: 
&RewriteTranscript, + profile: ResolvedRewriteProfile, + custom_instructions: Option<&str>, +) -> std::result::Result { + let prompt = build_prompt(transcript, profile, custom_instructions)?; + if matches!(profile, ResolvedRewriteProfile::Qwen) { + return build_qwen_rewrite_prompt(model, chat_template, &prompt); + } + + let messages = vec![ + LlamaChatMessage::new("system".into(), prompt.system) + .map_err(|e| format!("failed to build rewrite system message: {e}"))?, + LlamaChatMessage::new("user".into(), prompt.user) + .map_err(|e| format!("failed to build rewrite user message: {e}"))?, + ]; + + model + .apply_chat_template(chat_template, &messages, true) + .map_err(|e| format!("failed to apply rewrite chat template: {e}")) +} + +fn build_qwen_rewrite_prompt( + model: &LlamaModel, + chat_template: &LlamaChatTemplate, + prompt: &RewritePrompt, +) -> std::result::Result { + let messages_json = build_oaicompat_messages_json(prompt)?; + let result = model + .apply_chat_template_oaicompat( + chat_template, + &OpenAIChatTemplateParams { + messages_json: &messages_json, + tools_json: None, + tool_choice: None, + json_schema: None, + grammar: None, + reasoning_format: None, + chat_template_kwargs: None, + add_generation_prompt: true, + use_jinja: true, + parallel_tool_calls: false, + enable_thinking: false, + add_bos: false, + add_eos: false, + parse_tool_calls: false, + }, + ) + .map_err(|e| format!("failed to apply Qwen rewrite chat template: {e}"))?; + Ok(result.prompt) +} + +struct RewriteBehavior { + top_k: i32, + top_p: f32, + temperature: f32, +} + +fn rewrite_behavior(profile: ResolvedRewriteProfile) -> RewriteBehavior { + match profile { + ResolvedRewriteProfile::Qwen => RewriteBehavior { + top_k: 24, + top_p: 0.9, + temperature: 0.1, + }, + ResolvedRewriteProfile::Generic => RewriteBehavior { + top_k: 32, + top_p: 0.92, + temperature: 0.12, + }, + ResolvedRewriteProfile::LlamaCompat => RewriteBehavior { + top_k: 40, + top_p: 0.95, + temperature: 0.15, + }, + } 
+} diff --git a/src/bin/whispers-rewrite-worker.rs b/src/bin/whispers-rewrite-worker/main.rs similarity index 95% rename from src/bin/whispers-rewrite-worker.rs rename to src/bin/whispers-rewrite-worker/main.rs index 8a32989..ab1e928 100644 --- a/src/bin/whispers-rewrite-worker.rs +++ b/src/bin/whispers-rewrite-worker/main.rs @@ -1,11 +1,10 @@ -#[path = "../branding.rs"] mod branding; -#[path = "../rewrite.rs"] -mod rewrite; -#[path = "../rewrite_profile.rs"] +mod local; +mod output; +mod prompt; mod rewrite_profile; -#[path = "../rewrite_protocol.rs"] mod rewrite_protocol; +mod routing; use std::path::PathBuf; use std::time::Duration; @@ -14,7 +13,7 @@ use clap::Parser; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::net::{UnixListener, UnixStream}; -use crate::rewrite::LocalRewriter; +use crate::local::LocalRewriter; use crate::rewrite_profile::ResolvedRewriteProfile; use crate::rewrite_protocol::{WorkerRequest, WorkerResponse}; diff --git a/src/bin/whispers-rewrite-worker/output.rs b/src/bin/whispers-rewrite-worker/output.rs new file mode 100644 index 0000000..708ae9f --- /dev/null +++ b/src/bin/whispers-rewrite-worker/output.rs @@ -0,0 +1,78 @@ +pub fn sanitize_rewrite_output(raw: &str) -> String { + let mut text = raw.replace("\r\n", "\n"); + + for stop in ["<|eot_id|>", "<|end_of_text|>", ""] { + if let Some(index) = text.find(stop) { + text.truncate(index); + } + } + + if let Some(index) = text.find("") { + text.truncate(index); + } + + text = strip_tagged_section(&text, "", ""); + + let mut text = text.trim().to_string(); + + if let Some(stripped) = text.strip_prefix("") { + text = stripped.trim().to_string(); + } + + for prefix in ["Final text:", "Output:", "Rewritten text:"] { + if text + .get(..prefix.len()) + .map(|candidate| candidate.eq_ignore_ascii_case(prefix)) + .unwrap_or(false) + { + text = text[prefix.len()..].trim().to_string(); + break; + } + } + + if text.starts_with('"') && text.ends_with('"') && text.len() >= 2 { + 
text = text[1..text.len() - 1].trim().to_string(); + } + + text +} + +fn strip_tagged_section(input: &str, open: &str, close: &str) -> String { + let mut output = input.to_string(); + + while let Some(start) = output.find(open) { + let close_start = match output[start + open.len()..].find(close) { + Some(offset) => start + open.len() + offset, + None => { + output.truncate(start); + break; + } + }; + output.replace_range(start..close_start + close.len(), ""); + } + + output +} + +#[cfg(test)] +mod tests { + use super::sanitize_rewrite_output; + + #[test] + fn sanitize_rewrite_output_strips_wrapper_and_label() { + let cleaned = sanitize_rewrite_output("\nFinal text: Hi there.\n"); + assert_eq!(cleaned, "Hi there."); + } + + #[test] + fn sanitize_rewrite_output_strips_llama_stop_tokens() { + let cleaned = sanitize_rewrite_output("Hi there.<|eot_id|>ignored"); + assert_eq!(cleaned, "Hi there."); + } + + #[test] + fn sanitize_rewrite_output_strips_think_blocks() { + let cleaned = sanitize_rewrite_output("reasoning\nHi there."); + assert_eq!(cleaned, "Hi there."); + } +} diff --git a/src/bin/whispers-rewrite-worker/prompt.rs b/src/bin/whispers-rewrite-worker/prompt.rs new file mode 100644 index 0000000..3ef6d54 --- /dev/null +++ b/src/bin/whispers-rewrite-worker/prompt.rs @@ -0,0 +1,756 @@ +use super::routing::{ + RewriteRoute, has_policy_context, has_strong_explicit_edit_cue, + requires_candidate_adjudication, rewrite_route, +}; +use crate::rewrite_profile::ResolvedRewriteProfile; +use crate::rewrite_protocol::RewriteTranscript; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RewritePrompt { + pub system: String, + pub user: String, +} + +pub fn build_prompt( + transcript: &RewriteTranscript, + profile: ResolvedRewriteProfile, + custom_instructions: Option<&str>, +) -> std::result::Result { + Ok(RewritePrompt { + system: build_system_instructions(transcript, profile, custom_instructions), + user: build_user_message(transcript), + }) +} + +pub fn 
build_oaicompat_messages_json( + prompt: &RewritePrompt, +) -> std::result::Result { + serde_json::to_string(&[ + serde_json::json!({ + "role": "system", + "content": prompt.system, + }), + serde_json::json!({ + "role": "user", + "content": prompt.user, + }), + ]) + .map_err(|e| format!("failed to encode rewrite chat messages: {e}")) +} + +pub fn effective_max_tokens(max_tokens: usize, transcript: &RewriteTranscript) -> usize { + let word_count = transcript + .correction_aware_text + .split_whitespace() + .filter(|word| !word.is_empty()) + .count(); + let extra_budget = if requires_candidate_adjudication(transcript) { + 24 + } else { + 0 + }; + let minimum = if requires_candidate_adjudication(transcript) { + 64 + } else { + 48 + }; + let derived = word_count + .saturating_mul(2) + .saturating_add(24) + .saturating_add(extra_budget); + derived.clamp(minimum, max_tokens) +} + +pub(crate) fn build_system_instructions( + transcript: &RewriteTranscript, + profile: ResolvedRewriteProfile, + custom_instructions: Option<&str>, +) -> String { + let mut instructions = rewrite_instructions(profile).to_string(); + if has_policy_context(transcript) { + let policy_context = &transcript.policy_context; + instructions.push_str("\n\nCorrection policy contract:\n"); + instructions.push_str(correction_policy_contract(policy_context.correction_policy)); + instructions.push_str("\n\nAgentic latitude contract:\n"); + instructions.push_str(agentic_latitude_contract(policy_context.correction_policy)); + if !policy_context.effective_rule_instructions.is_empty() { + instructions.push_str("\n\nMatched app rewrite policy instructions:\n"); + for instruction in &policy_context.effective_rule_instructions { + instructions.push_str("- "); + instructions.push_str(instruction.trim()); + instructions.push('\n'); + } + } + } + if let Some(custom) = custom_instructions + .map(str::trim) + .filter(|text| !text.is_empty()) + { + instructions.push_str("\n\nAdditional user rewrite instructions:\n"); + 
instructions.push_str(custom); + } + instructions +} + +fn correction_policy_contract( + policy: crate::rewrite_protocol::RewriteCorrectionPolicy, +) -> &'static str { + match policy { + crate::rewrite_protocol::RewriteCorrectionPolicy::Conservative => { + "Conservative: stay close to explicit rewrite candidates and glossary evidence. If uncertain, prefer candidate-preserving output over freer rewriting." + } + crate::rewrite_protocol::RewriteCorrectionPolicy::Balanced => { + "Balanced: allow stronger technical correction when the glossary, app context, or utterance semantics support it. Prefer candidate-backed output when it is competitive, but do not keep an obviously wrong technical spelling just because it appears in the candidate list." + } + crate::rewrite_protocol::RewriteCorrectionPolicy::Aggressive => { + "Aggressive: allow freer technical correction and contextual cleanup when the utterance strongly points to a technical term or proper name. Candidates are useful evidence, not hard limits, as long as you still return only final text within the provided bounds." + } + } +} + +fn agentic_latitude_contract( + policy: crate::rewrite_protocol::RewriteCorrectionPolicy, +) -> &'static str { + match policy { + crate::rewrite_protocol::RewriteCorrectionPolicy::Conservative => { + "In conservative mode, treat the candidate list and glossary as the main evidence. Only make a freer technical normalization when the utterance itself makes the intended term unusually clear." + } + crate::rewrite_protocol::RewriteCorrectionPolicy::Balanced => { + "In balanced mode, you may normalize likely technical terms, product names, commands, libraries, languages, editors, or Linux components even when the literal transcript spelling is noisy or the exact canonical form is not already present in the candidate list, as long as the utterance strongly supports that normalization." 
+ } + crate::rewrite_protocol::RewriteCorrectionPolicy::Aggressive => { + "In aggressive mode, you may confidently rewrite phonetically similar words into the most plausible technical term or proper name when the utterance semantics, app context, or nearby category cues make that interpretation clearly better than the literal transcript." + } + } +} + +pub(crate) fn rewrite_instructions(profile: ResolvedRewriteProfile) -> &'static str { + let base = "You clean up dictated speech into the final text the user meant to type. \ +Return only the finished text. Do not explain anything. Remove obvious disfluencies when natural. \ +Use the correction-aware transcript as the primary source of truth unless structured edit signals say the \ +utterance may still be ambiguous. The raw transcript may still contain spoken editing phrases or canceled wording. \ +Never reintroduce text that was removed by an explicit spoken correction cue. Respect any structured edit intents \ +provided alongside the transcript. If structured edit signals or edit hypotheses are present, use the candidate \ +interpretations as bounded options, choose the best interpretation, and lightly refine it only when needed for natural \ +final text. Prefer transcript spellings for names, brands, and uncommon proper nouns unless a user dictionary or \ +explicit correction says otherwise. Do not normalize names into more common spellings just because they look familiar. \ +When the utterance clearly refers to software, tools, APIs, libraries, Linux components, product names, or other \ +technical concepts, prefer the most plausible intended technical term or proper name over a phonetically similar common \ +word. Use nearby category words like window manager, editor, language, library, package manager, shell, or terminal \ +tool to disambiguate technical names. If the utterance remains genuinely ambiguous, stay close to the transcript rather \ +than inventing a niche term. 
\ +If an edit intent says to replace or cancel previous wording, preserve that edit and do not keep the spoken correction \ +phrase itself unless the transcript clearly still intends it. Examples:\n\ +- raw: Hello there. Scratch that. Hi.\n correction-aware: Hi.\n final: Hi.\n\ +- raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ +- raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ +- raw: Never mind. Hi, how are you today?\n correction-aware: Hi, how are you today?\n final: Hi, how are you today?\n\ +- raw: Wait, no, it actually works.\n correction-aware: Wait, no, it actually works.\n final: Wait, no, it actually works.\n\ +- raw: Let's meet tomorrow, or rather Friday.\n correction-aware: Let's meet Friday.\n final: Let's meet Friday.\n\ +- raw: I'm currently using the window manager Hyperland.\n correction-aware: I'm currently using the window manager Hyperland.\n final: I'm currently using the window manager Hyprland.\n\ +- raw: I'm switching from Sui to Hyperland.\n correction-aware: I'm switching from Sui to Hyperland.\n final: I'm switching from Sway to Hyprland.\n\ +- raw: I use type script for backend tooling.\n correction-aware: I use type script for backend tooling.\n final: I use TypeScript for backend tooling.\n\ +- raw: I edit the config in neo vim.\n correction-aware: I edit the config in neo vim.\n final: I edit the config in Neovim."; + + match profile { + ResolvedRewriteProfile::Qwen => { + "You clean up dictated speech into the final text the user meant to type. \ +Return only the finished text. Do not explain anything. Do not emit reasoning, think tags, or XML wrappers. \ +Remove obvious disfluencies when natural. 
Use the correction-aware transcript as the primary source of truth unless \ +structured edit signals say the utterance may still be ambiguous. The raw transcript may still contain spoken editing \ +phrases or canceled wording. Never reintroduce text that was removed by an explicit spoken correction cue. Respect \ +any structured edit intents provided alongside the transcript. If structured edit signals or edit hypotheses are \ +present, use the candidate interpretations as bounded options, choose the best interpretation, and lightly refine it \ +only when needed for natural final text. Prefer transcript spellings for names, brands, and uncommon proper nouns \ +unless a user dictionary or explicit correction says otherwise. Do not normalize names into more common spellings just \ +because they look familiar. When the utterance clearly refers to software, tools, APIs, libraries, Linux components, \ +product names, or other technical concepts, prefer the most plausible intended technical term or proper name over a \ +phonetically similar common word. Use nearby category words like window manager, editor, language, library, package \ +manager, shell, or terminal tool to disambiguate technical names. If the utterance remains genuinely ambiguous, stay \ +close to the transcript rather than inventing a niche term. If an edit intent says to replace or cancel previous wording, preserve that edit and do \ +not keep the spoken correction phrase itself unless the transcript clearly still intends it. Examples:\n\ +- raw: Hello there. Scratch that. Hi.\n correction-aware: Hi.\n final: Hi.\n\ +- raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ +- raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ +- raw: Never mind. 
Hi, how are you today?\n correction-aware: Hi, how are you today?\n final: Hi, how are you today?\n\ +- raw: Wait, no, it actually works.\n correction-aware: Wait, no, it actually works.\n final: Wait, no, it actually works.\n\ +- raw: Let's meet tomorrow, or rather Friday.\n correction-aware: Let's meet Friday.\n final: Let's meet Friday.\n\ +- raw: I'm currently using the window manager Hyperland.\n correction-aware: I'm currently using the window manager Hyperland.\n final: I'm currently using the window manager Hyprland.\n\ +- raw: I'm switching from Sui to Hyperland.\n correction-aware: I'm switching from Sui to Hyperland.\n final: I'm switching from Sway to Hyprland.\n\ +- raw: I use type script for backend tooling.\n correction-aware: I use type script for backend tooling.\n final: I use TypeScript for backend tooling.\n\ +- raw: I edit the config in neo vim.\n correction-aware: I edit the config in neo vim.\n final: I edit the config in Neovim." + } + ResolvedRewriteProfile::Generic | ResolvedRewriteProfile::LlamaCompat => base, + } +} + +pub(crate) fn build_user_message(transcript: &RewriteTranscript) -> String { + let language = transcript.detected_language.as_deref().unwrap_or("unknown"); + let correction_aware = transcript.correction_aware_text.trim(); + let raw = transcript.raw_text.trim(); + let edit_intents = render_edit_intents(transcript); + let edit_signals = render_edit_signals(transcript); + let agentic_context = render_agentic_context(transcript); + let route = rewrite_route(transcript); + tracing::debug!( + route = ?route, + edit_signals = transcript.edit_signals.len(), + edit_hypotheses = transcript.edit_hypotheses.len(), + rewrite_candidates = transcript.rewrite_candidates.len(), + "rewrite prompt route selected" + ); + + match route { + RewriteRoute::SessionCandidateAdjudication => { + let typing_context = render_typing_context(transcript); + let recent_session_entries = render_recent_session_entries(transcript); + let 
agentic_policy_context = render_agentic_policy_context(transcript); + let session_candidates = render_session_backtrack_candidates(transcript); + let recommended_session_candidate = render_recommended_session_candidate(transcript); + let rewrite_candidates = render_rewrite_candidates(transcript); + let surface_guidance = transcript + .typing_context + .as_ref() + .filter(|context| { + matches!( + context.surface_kind, + crate::rewrite_protocol::RewriteSurfaceKind::Terminal + ) + }) + .map(|_| { + "The active surface looks like a terminal. Stay conservative unless an explicit correction cue clearly indicates replacing the most recent prior dictation.\n" + }) + .unwrap_or(""); + format!( + "Language: {language}\n\ +Active typing context:\n\ +{typing_context}\ +Recent dictation session:\n\ +{recent_session_entries}\ +{agentic_policy_context}\ +Session backtrack candidates:\n\ +{session_candidates}\ +{recommended_session_candidate}\ +The user may be correcting the most recent prior dictation entry rather than appending new text.\n\ +If the recommended session candidate says replace_last_entry, treat your final text as the replacement text for that previous dictation entry, not as newly appended text.\n\ +Prefer the recommended session candidate unless another listed session candidate is clearly better.\n\ +{surface_guidance}\ +Current utterance correction candidate:\n\ +{correction_aware}\n\ +Raw current utterance:\n\ +{raw}\n\ +Current utterance bounded candidates:\n\ +{rewrite_candidates}\ +Final text:" + ) + } + RewriteRoute::CandidateAdjudication => { + let edit_hypotheses = render_edit_hypotheses(transcript); + let rewrite_candidates = render_rewrite_candidates(transcript); + let recommended_candidate = render_recommended_candidate(transcript); + let recent_segments = render_recent_segments(transcript, 4); + let aggressive_candidate = render_aggressive_candidate(transcript); + let exact_cue_guidance = if has_strong_explicit_edit_cue(transcript) { + "A strong 
explicit spoken edit cue was detected. The literal raw transcript probably contains canceled wording. \ +Prefer a candidate interpretation that removes the cue and canceled wording unless doing so would clearly lose intended meaning. \ +If the cue is an exact strong match for phrases like scratch that, never mind, or wait no, do not keep the literal cue text in the final output.\n" + } else { + "" + }; + tracing::trace!("rewrite hypotheses:\n{edit_hypotheses}"); + tracing::trace!("rewrite candidates:\n{rewrite_candidates}"); + format!( + "Language: {language}\n\ +{agentic_context}\ +Structured edit hypotheses:\n\ +{edit_hypotheses}\ +Structured edit signals:\n\ +{edit_signals}\ +Structured edit intents:\n\ +{edit_intents}\ +This utterance likely contains spoken self-corrections or restatements.\n\ +Choose the best candidate interpretation and lightly refine it only when needed.\n\ +{exact_cue_guidance}\ +When an exact strong edit cue is present, treat the non-literal candidates as more trustworthy than the literal transcript.\n\ +The candidate list is ordered from most likely to least likely by heuristics.\n\ +For exact strong edit cues, the first candidate is the heuristic best guess and should usually win unless another candidate is clearly better.\n\ +Prefer the smallest replacement scope that yields a coherent result.\n\ +Use span-level replacements when only a key phrase was corrected, clause-level replacements when the correction replaces the surrounding thought, and sentence-level replacements only when the whole sentence was canceled.\n\ +Preserve literal wording when the cue is not actually an edit.\n\ +Do not over-normalize names or brands.\n\ +Do not keep spoken edit cues in the final text when they act as edits.\n\ +{recommended_candidate}\ +Candidate interpretations:\n\ +{rewrite_candidates}\ +Correction candidate:\n\ +{correction_aware}\n\ +{aggressive_candidate}\ +Raw transcript:\n\ +{raw}\n\ +Recent segments:\n\ +{recent_segments}\n\ +Final text:" + 
) + } + RewriteRoute::ResolvedCorrection => format!( + "Language: {language}\n\ +{agentic_context}\ +Structured edit signals:\n\ +{edit_signals}\ +Structured edit intents:\n\ +{edit_intents}\ +Self-corrections were already resolved before rewriting.\n\ +Use this correction-aware transcript as the main source text. In agentic mode, you may still normalize likely \ +technical terms or proper names when the utterance strongly supports them, even if the exact canonical spelling is not \ +already present in the candidate list:\n\ +{correction_aware}\n\ +{agentic_candidates}\ +Do not restore any canceled wording from earlier in the utterance.\n\ +Final text:", + agentic_candidates = render_agentic_candidates(transcript), + ), + RewriteRoute::Fast => { + let recent_segments = render_recent_segments(transcript, 4); + format!( + "Language: {language}\n\ +{agentic_context}\ +Structured edit signals:\n\ +{edit_signals}\ +Structured edit intents:\n\ +{edit_intents}\ +Correction-aware transcript:\n\ +{correction_aware}\n\ +Treat the correction-aware transcript as authoritative for explicit spoken edits and overall meaning, but in agentic \ +mode you may normalize likely technical terms or proper names when category cues in the utterance make the intended \ +technical meaning clearly better than the literal transcript.\n\ +{agentic_candidates}\ +\ +Recent segments:\n\ +{recent_segments}\n\ +Final text:", + agentic_candidates = render_agentic_candidates(transcript), + ) + } + } +} + +fn render_agentic_context(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return String::new(); + } + format!( + "{}{}", + render_agentic_runtime_context(transcript), + render_agentic_policy_context(transcript) + ) +} + +fn render_agentic_policy_context(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return String::new(); + } + let policy_context = &transcript.policy_context; + + format!( + "Agentic correction policy:\n\ +- 
mode: {}\n\ +Matched app rewrite rules:\n\ +{matched_rules}\ +Matched app policy instructions:\n\ +{effective_instructions}\ +Active glossary terms:\n\ +{glossary_terms}\ +", + policy_context.correction_policy.as_str(), + matched_rules = render_matched_rule_names(transcript), + effective_instructions = render_effective_rule_instructions(transcript), + glossary_terms = render_active_glossary_terms(transcript), + ) +} + +fn render_agentic_runtime_context(transcript: &RewriteTranscript) -> String { + has_policy_context(transcript) + .then(|| { + format!( + "Active typing context:\n\ +{}\ +Recent dictation session:\n\ +{}", + render_typing_context(transcript), + render_recent_session_entries(transcript), + ) + }) + .unwrap_or_default() +} + +fn render_agentic_candidates(transcript: &RewriteTranscript) -> String { + has_policy_context(transcript) + .then(|| { + format!( + "Available rewrite candidates (advisory, not exhaustive in agentic mode):\n\ +{}\ +Glossary-backed candidates:\n\ +{}", + render_rewrite_candidates(transcript), + render_glossary_candidates(transcript) + ) + }) + .unwrap_or_default() +} + +fn render_edit_intents(transcript: &RewriteTranscript) -> String { + if transcript.edit_intents.is_empty() { + return "- none detected\n".to_string(); + } + + let mut rendered = String::new(); + for intent in &transcript.edit_intents { + let action = match intent.action { + crate::rewrite_protocol::RewriteEditAction::ReplacePreviousPhrase => { + "replace_previous_phrase" + } + crate::rewrite_protocol::RewriteEditAction::ReplacePreviousClause => { + "replace_previous_clause" + } + crate::rewrite_protocol::RewriteEditAction::ReplacePreviousSentence => { + "replace_previous_sentence" + } + crate::rewrite_protocol::RewriteEditAction::DropEditCue => "drop_edit_cue", + }; + let confidence = match intent.confidence { + crate::rewrite_protocol::RewriteIntentConfidence::High => "high", + }; + rendered.push_str(&format!( + "- action: {action}, trigger: \"{}\", confidence: 
{confidence}\n", + intent.trigger + )); + } + + rendered +} + +fn render_edit_signals(transcript: &RewriteTranscript) -> String { + if transcript.edit_signals.is_empty() { + return "- none detected\n".to_string(); + } + + let mut rendered = String::new(); + for signal in &transcript.edit_signals { + let kind = match signal.kind { + crate::rewrite_protocol::RewriteEditSignalKind::Cancel => "cancel", + crate::rewrite_protocol::RewriteEditSignalKind::Replace => "replace", + crate::rewrite_protocol::RewriteEditSignalKind::Restatement => "restatement", + }; + let scope_hint = match signal.scope_hint { + crate::rewrite_protocol::RewriteEditSignalScope::Phrase => "phrase", + crate::rewrite_protocol::RewriteEditSignalScope::Clause => "clause", + crate::rewrite_protocol::RewriteEditSignalScope::Sentence => "sentence", + crate::rewrite_protocol::RewriteEditSignalScope::Unknown => "unknown", + }; + let strength = match signal.strength { + crate::rewrite_protocol::RewriteEditSignalStrength::Possible => "possible", + crate::rewrite_protocol::RewriteEditSignalStrength::Strong => "strong", + }; + rendered.push_str(&format!( + "- trigger: \"{}\", kind: {kind}, scope_hint: {scope_hint}, strength: {strength}\n", + signal.trigger + )); + } + + rendered +} + +fn render_edit_hypotheses(transcript: &RewriteTranscript) -> String { + if transcript.edit_hypotheses.is_empty() { + return "- none detected\n".to_string(); + } + + let mut rendered = String::new(); + for hypothesis in &transcript.edit_hypotheses { + let match_source = match hypothesis.match_source { + crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Exact => "exact", + crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Alias => "alias", + crate::rewrite_protocol::RewriteEditHypothesisMatchSource::NearMiss => "near_miss", + }; + let kind = match hypothesis.kind { + crate::rewrite_protocol::RewriteEditSignalKind::Cancel => "cancel", + crate::rewrite_protocol::RewriteEditSignalKind::Replace => "replace", + 
crate::rewrite_protocol::RewriteEditSignalKind::Restatement => "restatement", + }; + let scope_hint = match hypothesis.scope_hint { + crate::rewrite_protocol::RewriteEditSignalScope::Phrase => "phrase", + crate::rewrite_protocol::RewriteEditSignalScope::Clause => "clause", + crate::rewrite_protocol::RewriteEditSignalScope::Sentence => "sentence", + crate::rewrite_protocol::RewriteEditSignalScope::Unknown => "unknown", + }; + let strength = match hypothesis.strength { + crate::rewrite_protocol::RewriteEditSignalStrength::Possible => "possible", + crate::rewrite_protocol::RewriteEditSignalStrength::Strong => "strong", + }; + let replacement_scope = match hypothesis.replacement_scope { + crate::rewrite_protocol::RewriteReplacementScope::Span => "span", + crate::rewrite_protocol::RewriteReplacementScope::Clause => "clause", + crate::rewrite_protocol::RewriteReplacementScope::Sentence => "sentence", + }; + let tail_shape = match hypothesis.tail_shape { + crate::rewrite_protocol::RewriteTailShape::Empty => "empty", + crate::rewrite_protocol::RewriteTailShape::Phrase => "phrase", + crate::rewrite_protocol::RewriteTailShape::Clause => "clause", + }; + rendered.push_str(&format!( + "- cue_family: {}, matched_text: \"{}\", match_source: {match_source}, kind: {kind}, scope_hint: {scope_hint}, replacement_scope: {replacement_scope}, tail_shape: {tail_shape}, strength: {strength}\n", + hypothesis.cue_family, hypothesis.matched_text + )); + } + + rendered +} + +fn render_rewrite_candidates(transcript: &RewriteTranscript) -> String { + if transcript.rewrite_candidates.is_empty() { + return "- no candidates available\n".to_string(); + } + + let mut rendered = String::new(); + let highlight_first = has_strong_explicit_edit_cue(transcript); + for (index, candidate) in transcript.rewrite_candidates.iter().enumerate() { + let prefix = if highlight_first && index == 0 { + "- preferred_candidate" + } else { + "-" + }; + let kind = match candidate.kind { + 
crate::rewrite_protocol::RewriteCandidateKind::Literal => { + "literal (keep only if the cue was not actually an edit)" + } + crate::rewrite_protocol::RewriteCandidateKind::ConservativeCorrection => { + "conservative_correction (balanced cleanup)" + } + crate::rewrite_protocol::RewriteCandidateKind::AggressiveCorrection => { + "aggressive_correction (use when canceled wording should be removed more fully)" + } + crate::rewrite_protocol::RewriteCandidateKind::GlossaryCorrection => { + "glossary_correction (supported by active glossary aliases)" + } + crate::rewrite_protocol::RewriteCandidateKind::SpanReplacement => { + "span_replacement (replace only the corrected phrase)" + } + crate::rewrite_protocol::RewriteCandidateKind::ClauseReplacement => { + "clause_replacement (replace the corrected clause while keeping surrounding context)" + } + crate::rewrite_protocol::RewriteCandidateKind::SentenceReplacement => { + "sentence_replacement (replace the whole corrected sentence)" + } + crate::rewrite_protocol::RewriteCandidateKind::ContextualReplacement => { + "contextual_replacement (replace the corrected span while keeping earlier context)" + } + crate::rewrite_protocol::RewriteCandidateKind::DropCueOnly => { + "drop_cue_only (remove just the spoken edit cue)" + } + crate::rewrite_protocol::RewriteCandidateKind::FollowingReplacement => { + "following_replacement (keep only the wording after the cue)" + } + crate::rewrite_protocol::RewriteCandidateKind::CancelPreviousClause => { + "cancel_previous_clause (treat the cue as canceling the prior clause)" + } + crate::rewrite_protocol::RewriteCandidateKind::CancelPreviousSentence => { + "cancel_previous_sentence (treat the cue as canceling the prior sentence)" + } + }; + rendered.push_str(&format!("{prefix} {kind}: {}\n", candidate.text)); + } + + rendered +} + +fn render_recommended_candidate(transcript: &RewriteTranscript) -> String { + transcript + .recommended_candidate + .as_ref() + .map(|candidate| { + format!( + 
"Recommended interpretation:\n{}\nUse this as the default final text unless another candidate is clearly better.\n", + candidate.text + ) + }) + .unwrap_or_default() +} + +fn render_typing_context(transcript: &RewriteTranscript) -> String { + transcript + .typing_context + .as_ref() + .map(|context| { + format!( + "- focus_fingerprint: {}\n- app_id: {}\n- window_title: {}\n- surface_kind: {}\n- browser_domain: {}\n", + context.focus_fingerprint, + context.app_id.as_deref().unwrap_or("unknown"), + context.window_title.as_deref().unwrap_or("unknown"), + context.surface_kind.as_str(), + context.browser_domain.as_deref().unwrap_or("unknown"), + ) + }) + .unwrap_or_else(|| "- none available\n".to_string()) +} + +fn render_recent_session_entries(transcript: &RewriteTranscript) -> String { + if transcript.recent_session_entries.is_empty() { + return "- none available\n".to_string(); + } + + let mut rendered = String::new(); + for entry in &transcript.recent_session_entries { + rendered.push_str(&format!( + "- id: {}, text: {}, grapheme_len: {}, surface_kind: {}\n", + entry.id, + entry.final_text, + entry.grapheme_len, + entry.surface_kind.as_str() + )); + } + rendered +} + +fn render_session_backtrack_candidates(transcript: &RewriteTranscript) -> String { + if transcript.session_backtrack_candidates.is_empty() { + return "- no session backtrack candidates\n".to_string(); + } + + let mut rendered = String::new(); + for candidate in &transcript.session_backtrack_candidates { + let kind = match candidate.kind { + crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::AppendCurrent => { + "append_current" + } + crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::ReplaceLastEntry => { + "replace_last_entry" + } + }; + rendered.push_str(&format!( + "- kind: {kind}, entry_id: {}, delete_graphemes: {}, text: {}\n", + candidate + .entry_id + .map(|entry_id| entry_id.to_string()) + .unwrap_or_else(|| "none".to_string()), + candidate.delete_graphemes, + 
candidate.text + )); + } + rendered +} + +fn render_recommended_session_candidate(transcript: &RewriteTranscript) -> String { + transcript + .recommended_session_candidate + .as_ref() + .map(|candidate| { + let mode = match candidate.kind { + crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::AppendCurrent => { + "append_current" + } + crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::ReplaceLastEntry => { + "replace_last_entry" + } + }; + format!( + "Recommended session action:\nmode: {mode}\nentry_id: {}\ndelete_graphemes: {}\ntext: {}\n", + candidate + .entry_id + .map(|entry_id| entry_id.to_string()) + .unwrap_or_else(|| "none".to_string()), + candidate.delete_graphemes, + candidate.text + ) + }) + .unwrap_or_default() +} + +fn render_recent_segments(transcript: &RewriteTranscript, limit: usize) -> String { + let total_segments = transcript.segments.len(); + let start = total_segments.saturating_sub(limit); + let mut rendered = String::new(); + + for segment in &transcript.segments[start..] 
{ + let line = format!( + "- {}-{} ms: {}\n", + segment.start_ms, segment.end_ms, segment.text + ); + rendered.push_str(&line); + } + + if rendered.is_empty() { + rendered.push_str("- no segments available\n"); + } + + rendered +} + +fn render_aggressive_candidate(transcript: &RewriteTranscript) -> String { + transcript + .aggressive_correction_text + .as_deref() + .map(str::trim) + .filter(|text| !text.is_empty()) + .map(|text| format!("Aggressive correction candidate:\n{text}\n")) + .unwrap_or_default() +} + +fn render_matched_rule_names(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return "- none\n".to_string(); + } + let policy_context = &transcript.policy_context; + if policy_context.matched_rule_names.is_empty() { + return "- none\n".to_string(); + } + policy_context + .matched_rule_names + .iter() + .map(|name| format!("- {name}\n")) + .collect() +} + +fn render_effective_rule_instructions(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return "- none\n".to_string(); + } + let policy_context = &transcript.policy_context; + if policy_context.effective_rule_instructions.is_empty() { + return "- none\n".to_string(); + } + policy_context + .effective_rule_instructions + .iter() + .map(|instruction| format!("- {}\n", instruction.trim())) + .collect() +} + +fn render_active_glossary_terms(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return "- none\n".to_string(); + } + let policy_context = &transcript.policy_context; + if policy_context.active_glossary_terms.is_empty() { + return "- none\n".to_string(); + } + policy_context + .active_glossary_terms + .iter() + .map(|entry| format!("- {} <- [{}]\n", entry.term, entry.aliases.join(", "))) + .collect() +} + +fn render_glossary_candidates(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return "- none\n".to_string(); + } + let policy_context = 
&transcript.policy_context; + if policy_context.glossary_candidates.is_empty() { + return "- none\n".to_string(); + } + policy_context + .glossary_candidates + .iter() + .map(|candidate| format!("- {}\n", candidate.text)) + .collect() +} diff --git a/src/bin/whispers-rewrite-worker/rewrite_profile.rs b/src/bin/whispers-rewrite-worker/rewrite_profile.rs new file mode 100644 index 0000000..be14e3a --- /dev/null +++ b/src/bin/whispers-rewrite-worker/rewrite_profile.rs @@ -0,0 +1,104 @@ +use std::path::Path; + +use clap::ValueEnum; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, PartialEq, Eq, ValueEnum)] +#[serde(rename_all = "snake_case")] +pub enum RewriteProfile { + #[default] + Auto, + Generic, + Qwen, + LlamaCompat, +} + +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, ValueEnum)] +#[serde(rename_all = "snake_case")] +pub enum ResolvedRewriteProfile { + Generic, + Qwen, + LlamaCompat, +} + +impl RewriteProfile { + #[allow(dead_code)] + pub fn as_str(self) -> &'static str { + match self { + Self::Auto => "auto", + Self::Generic => "generic", + Self::Qwen => "qwen", + Self::LlamaCompat => "llama_compat", + } + } + + #[allow(dead_code)] + pub fn resolve(self, managed_model: Option<&str>, model_path: &Path) -> ResolvedRewriteProfile { + match self { + Self::Auto => managed_model + .and_then(resolve_identifier) + .or_else(|| { + model_path + .file_name() + .and_then(|name| resolve_identifier(&name.to_string_lossy())) + }) + .unwrap_or(ResolvedRewriteProfile::Generic), + Self::Generic => ResolvedRewriteProfile::Generic, + Self::Qwen => ResolvedRewriteProfile::Qwen, + Self::LlamaCompat => ResolvedRewriteProfile::LlamaCompat, + } + } +} + +impl ResolvedRewriteProfile { + #[allow(dead_code)] + pub fn as_str(self) -> &'static str { + match self { + Self::Generic => "generic", + Self::Qwen => "qwen", + Self::LlamaCompat => "llama_compat", + } + } +} + +#[allow(dead_code)] +fn 
resolve_identifier(identifier: &str) -> Option { + let normalized = identifier.to_ascii_lowercase(); + if normalized.contains("qwen") { + return Some(ResolvedRewriteProfile::Qwen); + } + + if normalized.contains("llama") { + return Some(ResolvedRewriteProfile::LlamaCompat); + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn auto_profile_prefers_managed_model_name() { + let resolved = RewriteProfile::Auto.resolve( + Some("qwen-3.5-4b-q4_k_m"), + Path::new("/tmp/Llama-3.2-3B-Instruct-Q4_K_M.gguf"), + ); + assert_eq!(resolved, ResolvedRewriteProfile::Qwen); + } + + #[test] + fn auto_profile_falls_back_to_model_filename() { + let resolved = + RewriteProfile::Auto.resolve(None, Path::new("/tmp/Llama-3.2-3B-Instruct-Q4_K_M.gguf")); + assert_eq!(resolved, ResolvedRewriteProfile::LlamaCompat); + } + + #[test] + fn auto_profile_uses_generic_for_unknown_models() { + let resolved = + RewriteProfile::Auto.resolve(None, Path::new("/tmp/CustomDictationModel.gguf")); + assert_eq!(resolved, ResolvedRewriteProfile::Generic); + } +} diff --git a/src/bin/whispers-rewrite-worker/rewrite_protocol.rs b/src/bin/whispers-rewrite-worker/rewrite_protocol.rs new file mode 100644 index 0000000..edacdc4 --- /dev/null +++ b/src/bin/whispers-rewrite-worker/rewrite_protocol.rs @@ -0,0 +1,259 @@ +use clap::ValueEnum; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteTranscript { + pub raw_text: String, + pub correction_aware_text: String, + pub aggressive_correction_text: Option, + pub detected_language: Option, + pub typing_context: Option, + pub recent_session_entries: Vec, + pub session_backtrack_candidates: Vec, + pub recommended_session_candidate: Option, + pub segments: Vec, + pub edit_intents: Vec, + pub edit_signals: Vec, + pub edit_hypotheses: Vec, + pub rewrite_candidates: Vec, + pub recommended_candidate: Option, + #[serde(default)] + pub policy_context: RewritePolicyContext, +} + 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteTranscriptSegment { + pub text: String, + pub start_ms: u32, + pub end_ms: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteTypingContext { + pub focus_fingerprint: String, + pub app_id: Option, + pub window_title: Option, + pub surface_kind: RewriteSurfaceKind, + pub browser_domain: Option, + pub captured_at_ms: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteSessionEntry { + pub id: u64, + pub final_text: String, + pub grapheme_len: usize, + pub focus_fingerprint: String, + pub surface_kind: RewriteSurfaceKind, + pub app_id: Option, + pub window_title: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteSessionBacktrackCandidate { + pub kind: RewriteSessionBacktrackCandidateKind, + pub entry_id: Option, + pub delete_graphemes: usize, + pub text: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteEditIntent { + pub action: RewriteEditAction, + pub trigger: String, + pub confidence: RewriteIntentConfidence, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteEditSignal { + pub trigger: String, + pub kind: RewriteEditSignalKind, + pub scope_hint: RewriteEditSignalScope, + pub strength: RewriteEditSignalStrength, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteEditHypothesis { + pub cue_family: String, + pub matched_text: String, + pub match_source: RewriteEditHypothesisMatchSource, + pub kind: RewriteEditSignalKind, + pub scope_hint: RewriteEditSignalScope, + pub replacement_scope: RewriteReplacementScope, + pub tail_shape: RewriteTailShape, + pub strength: RewriteEditSignalStrength, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewriteCandidate { + pub kind: RewriteCandidateKind, + pub text: String, 
+} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewritePolicyContext { + pub correction_policy: RewriteCorrectionPolicy, + pub matched_rule_names: Vec, + pub effective_rule_instructions: Vec, + pub active_glossary_terms: Vec, + pub glossary_candidates: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewritePolicyGlossaryTerm { + pub term: String, + pub aliases: Vec, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteEditAction { + ReplacePreviousPhrase, + ReplacePreviousClause, + ReplacePreviousSentence, + DropEditCue, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteIntentConfidence { + High, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteEditSignalKind { + Cancel, + Replace, + Restatement, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteEditSignalScope { + Phrase, + Clause, + Sentence, + Unknown, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteEditSignalStrength { + Possible, + Strong, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteEditHypothesisMatchSource { + Exact, + Alias, + NearMiss, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteReplacementScope { + Span, + Clause, + Sentence, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteTailShape { + Empty, + Phrase, + Clause, +} + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, 
Eq, ValueEnum)] +#[serde(rename_all = "snake_case")] +pub enum RewriteCorrectionPolicy { + Conservative, + #[default] + Balanced, + Aggressive, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteCandidateKind { + Literal, + ConservativeCorrection, + AggressiveCorrection, + GlossaryCorrection, + SpanReplacement, + ClauseReplacement, + SentenceReplacement, + ContextualReplacement, + DropCueOnly, + FollowingReplacement, + CancelPreviousClause, + CancelPreviousSentence, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, ValueEnum)] +#[serde(rename_all = "snake_case")] +pub enum RewriteSurfaceKind { + Browser, + Terminal, + Editor, + GenericText, + Unknown, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteSessionBacktrackCandidateKind { + AppendCurrent, + ReplaceLastEntry, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum WorkerRequest { + Rewrite { + transcript: RewriteTranscript, + custom_instructions: Option, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum WorkerResponse { + Result { text: String }, + Error { message: String }, +} + +impl RewriteCorrectionPolicy { + pub fn as_str(self) -> &'static str { + match self { + Self::Conservative => "conservative", + Self::Balanced => "balanced", + Self::Aggressive => "aggressive", + } + } +} + +impl RewriteSurfaceKind { + pub fn as_str(self) -> &'static str { + match self { + Self::Browser => "browser", + Self::Terminal => "terminal", + Self::Editor => "editor", + Self::GenericText => "generic_text", + Self::Unknown => "unknown", + } + } +} + +impl RewritePolicyContext { + pub fn is_active(&self) -> bool { + !self.matched_rule_names.is_empty() + || 
!self.effective_rule_instructions.is_empty() + || !self.active_glossary_terms.is_empty() + || !self.glossary_candidates.is_empty() + } +} diff --git a/src/bin/whispers-rewrite-worker/routing.rs b/src/bin/whispers-rewrite-worker/routing.rs new file mode 100644 index 0000000..13b8c45 --- /dev/null +++ b/src/bin/whispers-rewrite-worker/routing.rs @@ -0,0 +1,46 @@ +use crate::rewrite_protocol::{ + RewriteEditHypothesisMatchSource, RewriteEditSignalStrength, RewriteTranscript, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum RewriteRoute { + Fast, + ResolvedCorrection, + SessionCandidateAdjudication, + CandidateAdjudication, +} + +pub(super) fn rewrite_route(transcript: &RewriteTranscript) -> RewriteRoute { + if has_session_backtrack_candidate(transcript) { + RewriteRoute::SessionCandidateAdjudication + } else if requires_candidate_adjudication(transcript) { + RewriteRoute::CandidateAdjudication + } else if transcript.correction_aware_text.trim() != transcript.raw_text.trim() { + RewriteRoute::ResolvedCorrection + } else { + RewriteRoute::Fast + } +} + +pub(super) fn requires_candidate_adjudication(transcript: &RewriteTranscript) -> bool { + !transcript.edit_signals.is_empty() || !transcript.edit_hypotheses.is_empty() +} + +pub(super) fn has_strong_explicit_edit_cue(transcript: &RewriteTranscript) -> bool { + transcript.edit_hypotheses.iter().any(|hypothesis| { + hypothesis.strength == RewriteEditSignalStrength::Strong + && matches!( + hypothesis.match_source, + RewriteEditHypothesisMatchSource::Exact | RewriteEditHypothesisMatchSource::Alias + ) + }) +} + +pub(super) fn has_session_backtrack_candidate(transcript: &RewriteTranscript) -> bool { + transcript.recommended_session_candidate.is_some() + || !transcript.session_backtrack_candidates.is_empty() +} + +pub(super) fn has_policy_context(transcript: &RewriteTranscript) -> bool { + transcript.policy_context.is_active() +} diff --git a/src/cleanup.rs b/src/cleanup.rs index 6275614..1bf4992 100644 --- 
a/src/cleanup.rs +++ b/src/cleanup.rs @@ -1,6 +1,16 @@ use crate::config::{CleanupConfig, CleanupProfile}; use crate::transcribe::Transcript; +mod analysis; +mod lexicon; +mod render; +#[cfg(test)] +mod tests; + +pub use analysis::{ + clean_transcript, correction_analysis, correction_aware_text, explicit_followup_replacement, +}; + #[derive(Debug, Clone, PartialEq, Eq)] enum Piece { Word(String), @@ -147,491 +157,6 @@ struct CueFamilySpec { session_followup_aliases: &'static [&'static [&'static str]], } -pub fn clean_transcript(transcript: &Transcript, config: &CleanupConfig) -> String { - let raw = transcript.raw_text.trim(); - if raw.is_empty() - || !config.enabled - || !supports_cleanup_language(transcript.detected_language.as_deref()) - { - return raw.to_string(); - } - - let mut pieces = tokenize(raw); - if config.spoken_formatting { - pieces = apply_spoken_formatting(pieces); - } - if config.backtrack { - pieces = apply_backtrack(pieces, config.profile).pieces; - } - if config.remove_fillers { - pieces = remove_fillers(pieces); - } - - render_pieces(&pieces) -} - -pub fn correction_analysis(transcript: &Transcript) -> CorrectionAnalysis { - let raw = transcript.raw_text.trim(); - if raw.is_empty() || !supports_cleanup_language(transcript.detected_language.as_deref()) { - return CorrectionAnalysis { - text: raw.to_string(), - aggressive_text: None, - edit_intents: Vec::new(), - edit_signals: Vec::new(), - edit_hypotheses: Vec::new(), - }; - } - - let mut pieces = tokenize(raw); - pieces = apply_spoken_formatting(pieces); - let aggressive_outcome = apply_backtrack(pieces.clone(), CleanupProfile::Aggressive); - let outcome = apply_backtrack(pieces, CleanupProfile::Basic); - let text = render_pieces(&outcome.pieces); - let aggressive_text = render_pieces(&aggressive_outcome.pieces); - let edit_hypotheses = collect_edit_hypotheses(raw, &outcome.edit_signals); - - CorrectionAnalysis { - aggressive_text: (!outcome.edit_signals.is_empty() - && aggressive_text != text 
- && !aggressive_text.is_empty()) - .then_some(aggressive_text), - text, - edit_intents: outcome.edit_intents, - edit_signals: outcome.edit_signals, - edit_hypotheses, - } -} - -pub fn correction_aware_text(transcript: &Transcript) -> String { - correction_analysis(transcript).text -} - -pub fn explicit_followup_replacement(raw: &str) -> Option { - let raw = raw.trim(); - if raw.is_empty() { - return None; - } - - let pieces = apply_spoken_formatting(tokenize(raw)); - let lookahead = if let Some(trigger) = match_correction_trigger(&pieces, 0) { - if !matches!(trigger.kind, CorrectionKind::Sentence) - || !matches!(trigger.signal_kind, EditSignalKind::Cancel) - || !trigger.drop_cue_without_context - { - return None; - } - skip_correction_gap(&pieces, trigger.trigger_end) - } else { - explicit_followup_cue_lookahead(&pieces)? - }; - if count_upcoming_words(&pieces, lookahead) == 0 { - return None; - } - - let rendered = render_pieces(&pieces[lookahead..]); - (!rendered.is_empty()).then_some(rendered) -} - -fn explicit_followup_cue_lookahead(pieces: &[Piece]) -> Option { - for spec in cue_family_specs() { - if !matches!(spec.kind, EditSignalKind::Cancel) { - continue; - } - - if let Some(end) = match_words_with_soft_punctuation(pieces, 0, spec.canonical) { - return Some(skip_correction_gap(pieces, end)); - } - - for alias in spec.session_followup_aliases { - if let Some(end) = match_words_with_soft_punctuation(pieces, 0, alias) { - return Some(skip_correction_gap(pieces, end)); - } - } - } - - None -} - -fn cue_family_specs() -> &'static [CueFamilySpec] { - const SCRATCH_THAT_ALIASES: &[&[&str]] = &[&["scratchthat"]]; - const SCRATCH_THAT_SESSION_ALIASES: &[&[&str]] = &[ - &["scratchthat"], - &["scratchvat"], - &["scratchfat"], - &["scratchfart"], - &["scratchfarts"], - &["scratchfot"], - &["scratchbot"], - &["scratchvatnotes"], - &["rajvat"], - &["srajvat"], - &["srajvatnotes"], - ]; - const NEVER_MIND_ALIASES: &[&[&str]] = &[&["nevermind"]]; - const 
NEVER_MIND_SESSION_ALIASES: &[&[&str]] = &[&["nevermind"], &["nevarmind"]]; - const WAIT_NO_ALIASES: &[&[&str]] = &[&["waitno"]]; - const I_MEANT_ALIASES: &[&[&str]] = &[&["imeant"]]; - const I_MEAN_ALIASES: &[&[&str]] = &[&["imean"]]; - const OR_RATHER_ALIASES: &[&[&str]] = &[&["orrather"]]; - const SPECS: &[CueFamilySpec] = &[ - CueFamilySpec { - cue_family: "scratch_that", - kind: EditSignalKind::Cancel, - scope_hint: EditSignalScope::Sentence, - canonical: &["scratch", "that"], - aliases: SCRATCH_THAT_ALIASES, - session_followup_aliases: SCRATCH_THAT_SESSION_ALIASES, - }, - CueFamilySpec { - cue_family: "never_mind", - kind: EditSignalKind::Cancel, - scope_hint: EditSignalScope::Sentence, - canonical: &["never", "mind"], - aliases: NEVER_MIND_ALIASES, - session_followup_aliases: NEVER_MIND_SESSION_ALIASES, - }, - CueFamilySpec { - cue_family: "wait_no", - kind: EditSignalKind::Replace, - scope_hint: EditSignalScope::Sentence, - canonical: &["wait", "no"], - aliases: WAIT_NO_ALIASES, - session_followup_aliases: &[], - }, - CueFamilySpec { - cue_family: "i_meant", - kind: EditSignalKind::Replace, - scope_hint: EditSignalScope::Phrase, - canonical: &["i", "meant"], - aliases: I_MEANT_ALIASES, - session_followup_aliases: &[], - }, - CueFamilySpec { - cue_family: "i_mean", - kind: EditSignalKind::Replace, - scope_hint: EditSignalScope::Phrase, - canonical: &["i", "mean"], - aliases: I_MEAN_ALIASES, - session_followup_aliases: &[], - }, - CueFamilySpec { - cue_family: "or_rather", - kind: EditSignalKind::Restatement, - scope_hint: EditSignalScope::Phrase, - canonical: &["or", "rather"], - aliases: OR_RATHER_ALIASES, - session_followup_aliases: &[], - }, - ]; - - SPECS -} - -fn collect_edit_hypotheses(raw: &str, edit_signals: &[EditSignal]) -> Vec { - let observed_words = collect_observed_words(raw); - if observed_words.is_empty() { - return Vec::new(); - } - - let mut hypotheses = Vec::new(); - - for spec in cue_family_specs() { - let mut index = 0usize; - while 
index < observed_words.len() { - let Some((match_source, matched_len)) = - match_cue_family_at(&observed_words, index, spec) - else { - index += 1; - continue; - }; - - let signal = signal_for_cue_family(edit_signals, spec.cue_family); - let strength = signal - .map(|signal| signal.strength) - .unwrap_or(match match_source { - EditHypothesisMatchSource::NearMiss => EditSignalStrength::Possible, - EditHypothesisMatchSource::Exact | EditHypothesisMatchSource::Alias => { - EditSignalStrength::Strong - } - }); - let scope_hint = signal - .map(|signal| signal.scope_hint) - .unwrap_or(spec.scope_hint); - let kind = signal.map(|signal| signal.kind).unwrap_or(spec.kind); - let (replacement_scope, tail_shape) = - classify_replacement_scope(&observed_words, index + matched_len, scope_hint); - - hypotheses.push(EditHypothesis { - cue_family: spec.cue_family, - matched_text: observed_words[index..index + matched_len] - .iter() - .map(|word| word.text.as_str()) - .collect::>() - .join(" "), - match_source, - kind, - scope_hint, - replacement_scope, - tail_shape, - strength, - word_start: index, - word_end: index + matched_len, - }); - - index += matched_len.max(1); - } - } - - hypotheses.sort_by_key(|hypothesis| (hypothesis.word_start, hypothesis.word_end)); - hypotheses.dedup_by(|right, left| { - right.cue_family == left.cue_family - && right.word_start == left.word_start - && right.word_end == left.word_end - }); - hypotheses -} - -fn collect_observed_words(raw: &str) -> Vec { - tokenize(raw) - .into_iter() - .filter_map(|piece| match piece { - Piece::Word(text) => { - let normalized = normalized_word_str(&text); - (!normalized.is_empty()).then_some(ObservedWord { text, normalized }) - } - Piece::Punctuation(_) | Piece::Break(_) => None, - }) - .collect() -} - -fn match_cue_family_at( - observed_words: &[ObservedWord], - start: usize, - spec: &CueFamilySpec, -) -> Option<(EditHypothesisMatchSource, usize)> { - if matches_observed_words(observed_words, start, spec.canonical) { 
- return Some((EditHypothesisMatchSource::Exact, spec.canonical.len())); - } - - for alias in spec.aliases { - if matches_observed_words(observed_words, start, alias) { - return Some((EditHypothesisMatchSource::Alias, alias.len())); - } - } - - let mut candidate_lengths = vec![1usize, spec.canonical.len()]; - candidate_lengths.extend(spec.aliases.iter().map(|alias| alias.len())); - candidate_lengths.sort_unstable(); - candidate_lengths.dedup(); - - for len in candidate_lengths { - if start + len > observed_words.len() { - continue; - } - - let observed = compact_observed_words(&observed_words[start..start + len]); - if observed.is_empty() { - continue; - } - - if is_limited_near_miss(&observed, &compact_phrase(spec.canonical)) - || spec - .aliases - .iter() - .any(|alias| is_limited_near_miss(&observed, &compact_phrase(alias))) - { - return Some((EditHypothesisMatchSource::NearMiss, len)); - } - } - - None -} - -fn matches_observed_words( - observed_words: &[ObservedWord], - start: usize, - expected: &[&str], -) -> bool { - expected.iter().enumerate().all(|(offset, expected_word)| { - observed_words - .get(start + offset) - .map(|word| word.normalized.as_str()) - == Some(*expected_word) - }) -} - -fn compact_observed_words(observed_words: &[ObservedWord]) -> String { - observed_words - .iter() - .map(|word| word.normalized.as_str()) - .collect::>() - .join("") -} - -fn compact_phrase(words: &[&str]) -> String { - words.join("") -} - -fn signal_for_cue_family<'a>( - edit_signals: &'a [EditSignal], - cue_family: &str, -) -> Option<&'a EditSignal> { - edit_signals - .iter() - .find(|signal| cue_family_for_phrase(signal.trigger) == Some(cue_family)) -} - -fn cue_family_for_phrase(phrase: &str) -> Option<&'static str> { - match phrase { - "scratch that" | "actually scratch that" => Some("scratch_that"), - "never mind" - | "nevermind" - | "actually never mind" - | "actually nevermind" - | "oh wait never mind" - | "oh wait nevermind" - | "forget that" => 
Some("never_mind"), - "wait no" | "actually wait no" => Some("wait_no"), - "i meant" | "actually i meant" => Some("i_meant"), - "i mean" | "actually i mean" => Some("i_mean"), - "or rather" => Some("or_rather"), - _ => None, - } -} - -fn classify_replacement_scope( - observed_words: &[ObservedWord], - tail_start: usize, - scope_hint: EditSignalScope, -) -> (ReplacementScope, TailShape) { - let tail_words = observed_words - .get(tail_start..) - .unwrap_or(&[]) - .iter() - .map(|word| word.normalized.as_str()) - .take(8) - .collect::>(); - - if tail_words.is_empty() { - let replacement_scope = match scope_hint { - EditSignalScope::Sentence => ReplacementScope::Sentence, - EditSignalScope::Clause => ReplacementScope::Clause, - EditSignalScope::Phrase | EditSignalScope::Unknown => ReplacementScope::Span, - }; - return (replacement_scope, TailShape::Empty); - } - - let tail_shape = if looks_like_clause_tail(&tail_words) { - TailShape::Clause - } else { - TailShape::Phrase - }; - - let replacement_scope = match scope_hint { - EditSignalScope::Sentence => { - if matches!(tail_shape, TailShape::Phrase) && tail_words.len() > 3 { - ReplacementScope::Clause - } else { - ReplacementScope::Sentence - } - } - EditSignalScope::Clause => { - if matches!(tail_shape, TailShape::Phrase) { - ReplacementScope::Span - } else { - ReplacementScope::Clause - } - } - EditSignalScope::Phrase => ReplacementScope::Span, - EditSignalScope::Unknown => { - if matches!(tail_shape, TailShape::Clause) { - ReplacementScope::Clause - } else { - ReplacementScope::Span - } - } - }; - - (replacement_scope, tail_shape) -} - -fn looks_like_clause_tail(tail_words: &[&str]) -> bool { - const CLAUSE_WORDS: &[&str] = &[ - "am", "are", "be", "been", "being", "can", "could", "did", "do", "does", "had", "has", - "have", "is", "must", "need", "needs", "required", "requires", "should", "was", "were", - "will", "would", - ]; - const SUBJECT_WORDS: &[&str] = &[ - "i", "it", "he", "she", "they", "we", "you", "this", 
"that", "there", "my", "our", "their", - "your", - ]; - - tail_words.iter().any(|word| CLAUSE_WORDS.contains(word)) - || (tail_words.len() >= 2 - && SUBJECT_WORDS.contains(&tail_words[0]) - && !matches!(tail_words[1], "and" | "or" | "but")) -} - -fn is_limited_near_miss(observed: &str, target: &str) -> bool { - if observed.is_empty() || target.is_empty() || observed == target { - return false; - } - - let common_prefix = observed - .chars() - .zip(target.chars()) - .take_while(|(left, right)| left == right) - .count(); - if common_prefix < observed.len().min(target.len()).min(4) { - return false; - } - - let observed_prefix = if observed.chars().count() > target.chars().count() { - observed - .chars() - .take(target.chars().count()) - .collect::() - } else { - observed.to_string() - }; - let distance = bounded_levenshtein(&observed_prefix, target, 3); - distance <= 3 -} - -fn bounded_levenshtein(left: &str, right: &str, max_distance: usize) -> usize { - let left_chars: Vec = left.chars().collect(); - let right_chars: Vec = right.chars().collect(); - - if left_chars.is_empty() { - return right_chars.len(); - } - if right_chars.is_empty() { - return left_chars.len(); - } - if left_chars.len().abs_diff(right_chars.len()) > max_distance { - return max_distance + 1; - } - - let mut prev: Vec = (0..=right_chars.len()).collect(); - let mut curr = vec![0usize; right_chars.len() + 1]; - - for (i, left_char) in left_chars.iter().enumerate() { - curr[0] = i + 1; - let mut row_min = curr[0]; - for (j, right_char) in right_chars.iter().enumerate() { - let cost = usize::from(left_char != right_char); - curr[j + 1] = (prev[j + 1] + 1).min(curr[j] + 1).min(prev[j] + cost); - row_min = row_min.min(curr[j + 1]); - } - if row_min > max_distance { - return max_distance + 1; - } - std::mem::swap(&mut prev, &mut curr); - } - - prev[right_chars.len()] -} - fn supports_cleanup_language(language: Option<&str>) -> bool { matches!(language, Some("en")) } @@ -696,506 +221,10 @@ fn 
flush_newlines(pieces: &mut Vec, streak: &mut u8) { *streak = 0; } -fn apply_spoken_formatting(pieces: Vec) -> Vec { - let mut out = Vec::with_capacity(pieces.len()); - let mut i = 0; - - while i < pieces.len() { - if matches_words(&pieces, i, &["new", "paragraph"]) { - out.push(Piece::Break(BreakKind::Paragraph)); - i += 2; - continue; - } - if matches_words(&pieces, i, &["new", "line"]) { - out.push(Piece::Break(BreakKind::Line)); - i += 2; - continue; - } - if matches_words(&pieces, i, &["question", "mark"]) { - out.push(Piece::Punctuation('?')); - i += 2; - continue; - } - if matches_words(&pieces, i, &["exclamation", "point"]) - || matches_words(&pieces, i, &["exclamation", "mark"]) - { - out.push(Piece::Punctuation('!')); - i += 2; - continue; - } - if matches_words(&pieces, i, &["full", "stop"]) { - out.push(Piece::Punctuation('.')); - i += 2; - continue; - } - - match normalized_word(pieces.get(i)).as_deref() { - Some("comma") => { - out.push(Piece::Punctuation(',')); - i += 1; - } - Some("period") => { - out.push(Piece::Punctuation('.')); - i += 1; - } - Some("colon") => { - out.push(Piece::Punctuation(':')); - i += 1; - } - Some("semicolon") => { - out.push(Piece::Punctuation(';')); - i += 1; - } - _ => { - out.push(pieces[i].clone()); - i += 1; - } - } - } - - out -} - -fn apply_backtrack(pieces: Vec, profile: CleanupProfile) -> BacktrackOutcome { - let mut out = Vec::with_capacity(pieces.len()); - let mut edit_intents = Vec::new(); - let mut edit_signals = Vec::new(); - let mut i = 0; - - while i < pieces.len() { - let Some(trigger) = match_correction_trigger(&pieces, i) else { - out.push(pieces[i].clone()); - i += 1; - continue; - }; - - let lookahead = skip_correction_gap(&pieces, trigger.trigger_end); - let replacement_words = count_upcoming_words(&pieces, lookahead); - let prior_context_words = output_word_count(&out); - let signal_strength = if prior_context_words >= trigger.min_context_words - || trigger.drop_cue_without_context - || 
(replacement_words == 0 && trigger.allow_terminal_cancel) - { - EditSignalStrength::Strong - } else { - EditSignalStrength::Possible - }; - let default_scope = match trigger.kind { - CorrectionKind::Phrase => EditSignalScope::Phrase, - CorrectionKind::Sentence => EditSignalScope::Sentence, - }; - - if replacement_words == 0 { - if prior_context_words >= trigger.min_context_words && trigger.allow_terminal_cancel { - let action = trim_terminal_cancel_scope(&mut out, trigger.kind); - edit_signals.push(EditSignal { - trigger: trigger.phrase, - kind: trigger.signal_kind, - scope_hint: scope_hint_for_action(action), - strength: signal_strength, - }); - edit_intents.push(EditIntent { - action, - trigger: trigger.phrase, - confidence: EditIntentConfidence::High, - }); - i = lookahead; - continue; - } - if trigger.drop_cue_without_context { - edit_signals.push(EditSignal { - trigger: trigger.phrase, - kind: trigger.signal_kind, - scope_hint: EditSignalScope::Unknown, - strength: EditSignalStrength::Strong, - }); - edit_intents.push(EditIntent { - action: EditIntentAction::DropEditCue, - trigger: trigger.phrase, - confidence: EditIntentConfidence::High, - }); - i = lookahead; - continue; - } - edit_signals.push(EditSignal { - trigger: trigger.phrase, - kind: trigger.signal_kind, - scope_hint: default_scope, - strength: signal_strength, - }); - out.push(pieces[i].clone()); - i += 1; - continue; - } - - if prior_context_words < trigger.min_context_words { - if trigger.drop_cue_without_context { - edit_signals.push(EditSignal { - trigger: trigger.phrase, - kind: trigger.signal_kind, - scope_hint: EditSignalScope::Unknown, - strength: EditSignalStrength::Strong, - }); - edit_intents.push(EditIntent { - action: EditIntentAction::DropEditCue, - trigger: trigger.phrase, - confidence: EditIntentConfidence::High, - }); - i = lookahead; - continue; - } - edit_signals.push(EditSignal { - trigger: trigger.phrase, - kind: trigger.signal_kind, - scope_hint: default_scope, - strength: 
signal_strength, - }); - out.push(pieces[i].clone()); - i += 1; - continue; - } - - let action = match trigger.kind { - CorrectionKind::Phrase => { - trim_recent_phrase(&mut out, profile, replacement_words); - EditIntentAction::ReplacePreviousPhrase - } - CorrectionKind::Sentence => { - if ends_with_sentence_boundary(&out) { - trim_last_sentence(&mut out); - EditIntentAction::ReplacePreviousSentence - } else { - trim_recent_phrase(&mut out, profile, replacement_words); - EditIntentAction::ReplacePreviousClause - } - } - }; - edit_signals.push(EditSignal { - trigger: trigger.phrase, - kind: trigger.signal_kind, - scope_hint: scope_hint_for_action(action), - strength: signal_strength, - }); - edit_intents.push(EditIntent { - action, - trigger: trigger.phrase, - confidence: EditIntentConfidence::High, - }); - i = lookahead; - } - - BacktrackOutcome { - pieces: out, - edit_intents, - edit_signals, - } -} - -fn match_correction_trigger(pieces: &[Piece], i: usize) -> Option { - if let Some(end) = - match_words_with_soft_punctuation(pieces, i, &["oh", "wait", "never", "mind"]) - { - return Some(cancel_sentence_trigger( - "oh wait never mind", - end, - true, - true, - )); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["oh", "wait", "nevermind"]) { - return Some(cancel_sentence_trigger( - "oh wait nevermind", - end, - true, - true, - )); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["forget", "that"]) { - return Some(cancel_sentence_trigger("forget that", end, true, true)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["wait", "no"]) { - return Some(replace_sentence_trigger("wait no", end, true, false)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["scratch", "that"]) { - return Some(cancel_sentence_trigger("scratch that", end, true, true)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["never", "mind"]) { - return 
Some(cancel_sentence_trigger("never mind", end, true, true)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["nevermind"]) { - return Some(cancel_sentence_trigger("nevermind", end, true, true)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "wait", "no"]) { - return Some(replace_sentence_trigger( - "actually wait no", - end, - true, - false, - )); - } - if let Some(end) = - match_words_with_soft_punctuation(pieces, i, &["actually", "scratch", "that"]) - { - return Some(cancel_sentence_trigger( - "actually scratch that", - end, - true, - true, - )); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "never", "mind"]) - { - return Some(cancel_sentence_trigger( - "actually never mind", - end, - true, - true, - )); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "nevermind"]) { - return Some(cancel_sentence_trigger( - "actually nevermind", - end, - true, - true, - )); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "no"]) { - return Some(replace_phrase_trigger("actually no", end)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "i", "meant"]) { - return Some(replace_phrase_trigger("actually i meant", end)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "i", "mean"]) { - return Some(restatement_phrase_trigger("actually i mean", end)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["i", "meant"]) { - return Some(replace_phrase_trigger("i meant", end)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["i", "mean"]) { - return Some(restatement_phrase_trigger("i mean", end)); - } - if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["or", "rather"]) { - return Some(restatement_phrase_trigger("or rather", end)); - } - if let Some(end) = 
match_words_with_soft_punctuation(pieces, i, &["no"]) { - let previous_word = previous_word_before(pieces, i); - if matches!(pieces.get(end), Some(Piece::Punctuation(','))) - && !matches!(previous_word.as_deref(), Some("wait" | "actually")) - { - return Some(CorrectionTrigger { - kind: CorrectionKind::Phrase, - signal_kind: EditSignalKind::Replace, - trigger_end: end + 1, - min_context_words: 2, - phrase: "no", - allow_terminal_cancel: false, - drop_cue_without_context: false, - }); - } - } - None -} - -fn cancel_sentence_trigger( - phrase: &'static str, - trigger_end: usize, - allow_terminal_cancel: bool, - drop_cue_without_context: bool, -) -> CorrectionTrigger { - CorrectionTrigger { - kind: CorrectionKind::Sentence, - signal_kind: EditSignalKind::Cancel, - trigger_end, - min_context_words: 1, - phrase, - allow_terminal_cancel, - drop_cue_without_context, - } -} - -fn replace_sentence_trigger( - phrase: &'static str, - trigger_end: usize, - allow_terminal_cancel: bool, - drop_cue_without_context: bool, -) -> CorrectionTrigger { - CorrectionTrigger { - kind: CorrectionKind::Sentence, - signal_kind: EditSignalKind::Replace, - trigger_end, - min_context_words: 1, - phrase, - allow_terminal_cancel, - drop_cue_without_context, - } -} - -fn replace_phrase_trigger(phrase: &'static str, trigger_end: usize) -> CorrectionTrigger { - CorrectionTrigger { - kind: CorrectionKind::Phrase, - signal_kind: EditSignalKind::Replace, - trigger_end, - min_context_words: 1, - phrase, - allow_terminal_cancel: false, - drop_cue_without_context: false, - } -} - -fn restatement_phrase_trigger(phrase: &'static str, trigger_end: usize) -> CorrectionTrigger { - CorrectionTrigger { - kind: CorrectionKind::Phrase, - signal_kind: EditSignalKind::Restatement, - trigger_end, - min_context_words: 1, - phrase, - allow_terminal_cancel: false, - drop_cue_without_context: false, - } -} - -fn scope_hint_for_action(action: EditIntentAction) -> EditSignalScope { - match action { - 
EditIntentAction::ReplacePreviousPhrase => EditSignalScope::Phrase, - EditIntentAction::ReplacePreviousClause => EditSignalScope::Clause, - EditIntentAction::ReplacePreviousSentence => EditSignalScope::Sentence, - EditIntentAction::DropEditCue => EditSignalScope::Unknown, - } -} - -fn trim_terminal_cancel_scope(out: &mut Vec, kind: CorrectionKind) -> EditIntentAction { - match kind { - CorrectionKind::Phrase => { - trim_last_clause(out); - EditIntentAction::ReplacePreviousClause - } - CorrectionKind::Sentence => { - if ends_with_sentence_boundary(out) { - trim_last_sentence(out); - EditIntentAction::ReplacePreviousSentence - } else { - trim_last_clause(out); - EditIntentAction::ReplacePreviousClause - } - } - } -} - -fn count_upcoming_words(pieces: &[Piece], start: usize) -> usize { - let mut count = 0; - for piece in pieces.iter().skip(start) { - match piece { - Piece::Word(_) => count += 1, - Piece::Break(_) => break, - Piece::Punctuation(ch) if is_strong_boundary(*ch) => break, - _ => {} - } - } - count -} - -fn trim_recent_phrase(out: &mut Vec, profile: CleanupProfile, replacement_words: usize) { - trim_soft_suffix(out); - - let max_words = match profile { - CleanupProfile::Basic => replacement_words.clamp(1, 3), - CleanupProfile::Aggressive => replacement_words.clamp(2, 6), - }; - - let mut removed_words = 0usize; - while removed_words < max_words { - match out.pop() { - Some(Piece::Word(_)) => removed_words += 1, - Some(Piece::Punctuation(ch)) if !is_strong_boundary(ch) => {} - Some(piece @ Piece::Punctuation(_)) | Some(piece @ Piece::Break(_)) => { - out.push(piece); - break; - } - None => break, - } - } - - if profile == CleanupProfile::Aggressive { - while let Some(piece) = out.pop() { - match piece { - Piece::Word(_) => continue, - Piece::Punctuation(ch) if !is_strong_boundary(ch) => continue, - other => { - out.push(other); - break; - } - } - } - } - - trim_soft_suffix(out); -} - -fn trim_last_sentence(out: &mut Vec) { - trim_trailing_boundaries(out); - 
- let mut removed_word = false; - while let Some(piece) = out.pop() { - match piece { - Piece::Word(_) => removed_word = true, - Piece::Punctuation(ch) if is_strong_boundary(ch) => { - if removed_word { - break; - } - } - Piece::Break(_) => { - if removed_word { - break; - } - } - Piece::Punctuation(_) => {} - } - } - - trim_soft_suffix(out); -} - -fn trim_last_clause(out: &mut Vec) { - trim_trailing_boundaries(out); - - let mut removed_word = false; - while let Some(piece) = out.pop() { - match piece { - Piece::Word(_) => removed_word = true, - Piece::Punctuation(ch) if is_clause_boundary(ch) => { - if removed_word { - break; - } - } - Piece::Break(_) => { - if removed_word { - break; - } - } - Piece::Punctuation(_) => {} - } - } - - trim_soft_suffix(out); -} - -fn trim_soft_suffix(out: &mut Vec) { - while let Some(last) = out.last() { - match last { - Piece::Punctuation(ch) if !is_strong_boundary(*ch) => { - out.pop(); - } - _ => break, - } - } -} - -fn output_word_count(pieces: &[Piece]) -> usize { - pieces - .iter() - .filter(|piece| matches!(piece, Piece::Word(_))) - .count() +fn matches_words(pieces: &[Piece], start: usize, words: &[&str]) -> bool { + words.iter().enumerate().all(|(offset, expected)| { + normalized_word(pieces.get(start + offset)).as_deref() == Some(*expected) + }) } fn skip_soft_punctuation(pieces: &[Piece], mut index: usize) -> usize { @@ -1214,116 +243,6 @@ fn skip_correction_gap(pieces: &[Piece], mut index: usize) -> usize { } } -fn trim_trailing_boundaries(out: &mut Vec) { - while let Some(last) = out.last() { - match last { - Piece::Punctuation(_) | Piece::Break(_) => { - out.pop(); - } - Piece::Word(_) => break, - } - } -} - -fn ends_with_sentence_boundary(out: &[Piece]) -> bool { - out.iter().rev().find_map(|piece| match piece { - Piece::Punctuation(ch) => Some(is_strong_boundary(*ch)), - Piece::Break(_) => Some(true), - Piece::Word(_) => None, - }) == Some(true) -} - -fn remove_fillers(pieces: Vec) -> Vec { - let mut out = 
Vec::with_capacity(pieces.len()); - for piece in pieces { - match piece { - Piece::Word(word) if is_filler(&word) => continue, - other => out.push(other), - } - } - out -} - -fn is_filler(word: &str) -> bool { - matches!( - normalized_word_str(word).as_str(), - "um" | "umm" | "uh" | "uhh" | "er" | "erm" | "ah" - ) -} - -fn render_pieces(pieces: &[Piece]) -> String { - let mut rendered = String::new(); - let mut capitalize_next = true; - - for piece in pieces { - match piece { - Piece::Word(word) => { - if !rendered.is_empty() && !rendered.ends_with([' ', '\n']) { - rendered.push(' '); - } - if capitalize_next { - rendered.push_str(&capitalize_first(word)); - } else { - rendered.push_str(word); - } - capitalize_next = false; - } - Piece::Punctuation(ch) => { - trim_trailing_spaces(&mut rendered); - rendered.push(*ch); - capitalize_next = matches!(ch, '.' | '?' | '!'); - } - Piece::Break(BreakKind::Line) => { - trim_trailing_spaces(&mut rendered); - if !rendered.is_empty() { - if !rendered.ends_with('\n') { - rendered.push('\n'); - } - capitalize_next = true; - } - } - Piece::Break(BreakKind::Paragraph) => { - trim_trailing_spaces(&mut rendered); - if !rendered.is_empty() { - while rendered.ends_with('\n') { - rendered.pop(); - } - rendered.push('\n'); - rendered.push('\n'); - capitalize_next = true; - } - } - } - } - - trim_trailing_spaces(&mut rendered); - rendered -} - -fn trim_trailing_spaces(text: &mut String) { - while text.ends_with(' ') { - text.pop(); - } -} - -fn capitalize_first(word: &str) -> String { - let mut chars = word.chars(); - let Some(first) = chars.next() else { - return String::new(); - }; - - let mut result = String::new(); - result.extend(first.to_uppercase()); - result.extend(chars); - result -} - -fn matches_words(pieces: &[Piece], start: usize, words: &[&str]) -> bool { - words.iter().enumerate().all(|(offset, expected)| { - normalized_word(pieces.get(start + offset)).as_deref() == Some(*expected) - }) -} - fn 
match_words_with_soft_punctuation( pieces: &[Piece], start: usize, @@ -1376,253 +295,3 @@ fn is_strong_boundary(ch: char) -> bool { fn is_clause_boundary(ch: char) -> bool { is_strong_boundary(ch) || matches!(ch, ',' | ':' | ';') } - -#[cfg(test)] -mod tests { - use super::*; - use crate::config::{CleanupConfig, CleanupProfile}; - use crate::transcribe::Transcript; - - fn transcript(text: &str) -> Transcript { - Transcript { - raw_text: text.to_string(), - detected_language: Some("en".to_string()), - segments: Vec::new(), - } - } - - #[test] - fn removes_common_fillers() { - let cleaned = clean_transcript( - &transcript("um i think we should go"), - &CleanupConfig::default(), - ); - assert_eq!(cleaned, "I think we should go"); - } - - #[test] - fn preserves_non_filler_words() { - let cleaned = clean_transcript(&transcript("i like apples"), &CleanupConfig::default()); - assert_eq!(cleaned, "I like apples"); - } - - #[test] - fn converts_spoken_punctuation_commands() { - let cleaned = clean_transcript( - &transcript("hello comma world question mark"), - &CleanupConfig::default(), - ); - assert_eq!(cleaned, "Hello, world?"); - } - - #[test] - fn converts_spoken_line_and_paragraph_commands() { - let cleaned = clean_transcript( - &transcript("first line new line second line new paragraph third line"), - &CleanupConfig::default(), - ); - assert_eq!(cleaned, "First line\nSecond line\n\nThird line"); - } - - #[test] - fn basic_backtrack_replaces_recent_phrase() { - let cleaned = clean_transcript( - &transcript("let's meet at 4 actually no 3"), - &CleanupConfig::default(), - ); - assert_eq!(cleaned, "Let's meet at 3"); - } - - #[test] - fn standalone_actually_is_preserved() { - let cleaned = correction_aware_text(&transcript("it actually works")); - assert_eq!(cleaned, "It actually works"); - } - - #[test] - fn rather_is_preserved_in_normal_phrasing() { - let cleaned = correction_aware_text(&transcript("i would rather stay home")); - assert_eq!(cleaned, "I would rather stay 
home"); - } - - #[test] - fn actually_rather_is_preserved_in_normal_phrasing() { - let cleaned = correction_aware_text(&transcript("i would actually rather stay home")); - assert_eq!(cleaned, "I would actually rather stay home"); - } - - #[test] - fn punctuated_wait_no_replaces_last_sentence() { - let cleaned = clean_transcript( - &transcript("hi there, this is a test of whisper osd. wait, no. hi there."), - &CleanupConfig::default(), - ); - assert_eq!(cleaned, "Hi there."); - } - - #[test] - fn punctuated_wait_no_still_replaces_inline_phrase() { - let cleaned = clean_transcript( - &transcript("let's meet at 4 wait, no, 3"), - &CleanupConfig::default(), - ); - assert_eq!(cleaned, "Let's meet at 3"); - } - - #[test] - fn scratch_that_replaces_recent_word() { - let cleaned = clean_transcript( - &transcript("i'll bring cookies scratch that brownies"), - &CleanupConfig::default(), - ); - assert_eq!(cleaned, "I'll bring brownies"); - } - - #[test] - fn correction_aware_text_drops_previous_sentence_for_scratch_that() { - let cleaned = correction_aware_text(&transcript( - "hello there, this is a test of whisper rs. scratch that. hi.", - )); - assert_eq!(cleaned, "Hi."); - } - - #[test] - fn terminal_never_mind_cancels_previous_clause() { - let cleaned = correction_aware_text(&transcript("hello there oh wait never mind")); - assert_eq!(cleaned, ""); - } - - #[test] - fn utterance_initial_never_mind_is_dropped_when_content_follows() { - let cleaned = correction_aware_text(&transcript("never mind. 
hi, how are you today?")); - assert_eq!(cleaned, "Hi, how are you today?"); - } - - #[test] - fn explicit_followup_replacement_handles_session_aliases() { - assert_eq!( - explicit_followup_replacement("srajvat, hi").as_deref(), - Some("Hi") - ); - assert_eq!( - explicit_followup_replacement("scratchfarts, hi").as_deref(), - Some("Hi") - ); - } - - #[test] - fn correction_analysis_reports_terminal_cancel_intent() { - let analysis = correction_analysis(&transcript("hello there never mind")); - assert_eq!(analysis.text, ""); - assert_eq!(analysis.edit_intents.len(), 1); - assert_eq!( - analysis.edit_intents[0].action, - EditIntentAction::ReplacePreviousClause - ); - assert_eq!(analysis.edit_intents[0].trigger, "never mind"); - assert_eq!(analysis.edit_signals.len(), 1); - assert_eq!(analysis.edit_signals[0].kind, EditSignalKind::Cancel); - assert_eq!(analysis.edit_signals[0].scope_hint, EditSignalScope::Clause); - assert_eq!( - analysis.edit_signals[0].strength, - EditSignalStrength::Strong - ); - assert_eq!(analysis.edit_hypotheses.len(), 1); - assert_eq!(analysis.edit_hypotheses[0].cue_family, "never_mind"); - assert_eq!( - analysis.edit_hypotheses[0].match_source, - EditHypothesisMatchSource::Exact - ); - } - - #[test] - fn utterance_initial_wait_no_is_not_treated_as_backtrack() { - let analysis = correction_analysis(&transcript("wait, no, it actually works")); - assert_eq!(analysis.text, "Wait, no, it actually works"); - assert_eq!(analysis.edit_intents.len(), 0); - assert_eq!(analysis.edit_signals.len(), 1); - assert_eq!(analysis.edit_signals[0].kind, EditSignalKind::Replace); - assert_eq!( - analysis.edit_signals[0].scope_hint, - EditSignalScope::Sentence - ); - assert_eq!( - analysis.edit_signals[0].strength, - EditSignalStrength::Possible - ); - assert!(analysis.edit_hypotheses.iter().any(|hypothesis| { - hypothesis.cue_family == "wait_no" - && hypothesis.match_source == EditHypothesisMatchSource::Exact - })); - } - - #[test] - fn 
correction_analysis_reports_restatement_signal_for_or_rather() { - let analysis = correction_analysis(&transcript("let's meet tomorrow or rather friday")); - assert_eq!(analysis.text, "Let's meet friday"); - assert_eq!(analysis.edit_signals.len(), 1); - assert_eq!(analysis.edit_signals[0].kind, EditSignalKind::Restatement); - assert_eq!(analysis.edit_signals[0].scope_hint, EditSignalScope::Phrase); - } - - #[test] - fn correction_analysis_exposes_aggressive_candidate_for_ambiguous_replacement() { - let analysis = correction_analysis(&transcript( - "my name is notes, scratch that my name is jonatan", - )); - assert_eq!(analysis.text, "My my name is jonatan"); - assert_eq!( - analysis.aggressive_text.as_deref(), - Some("My name is jonatan") - ); - assert_eq!( - analysis.edit_hypotheses[0].replacement_scope, - ReplacementScope::Clause - ); - assert_eq!(analysis.edit_hypotheses[0].tail_shape, TailShape::Clause); - } - - #[test] - fn correction_analysis_collects_near_miss_hypothesis_for_scratch_that_family() { - let analysis = correction_analysis(&transcript("hello there scratch vat hi")); - assert!(analysis.edit_hypotheses.iter().any(|hypothesis| { - hypothesis.cue_family == "scratch_that" - && hypothesis.match_source == EditHypothesisMatchSource::NearMiss - && hypothesis.matched_text == "scratch vat" - })); - } - - #[test] - fn correction_analysis_marks_phrase_tail_as_span_scope() { - let analysis = correction_analysis(&transcript( - "mobile apps or sms codes scratch that just sms codes", - )); - assert!(analysis.edit_hypotheses.iter().any(|hypothesis| { - hypothesis.cue_family == "scratch_that" - && hypothesis.replacement_scope == ReplacementScope::Span - && hypothesis.tail_shape == TailShape::Phrase - })); - } - - #[test] - fn aggressive_profile_trims_more_context() { - let config = CleanupConfig { - profile: CleanupProfile::Aggressive, - ..CleanupConfig::default() - }; - let cleaned = - clean_transcript(&transcript("alpha beta gamma delta wait no omega"), &config); 
- assert_eq!(cleaned, "Omega"); - } - - #[test] - fn skips_advanced_cleanup_for_non_english_transcripts() { - let transcript = Transcript { - raw_text: "um hola comma mundo".to_string(), - detected_language: Some("es".to_string()), - segments: Vec::new(), - }; - let cleaned = clean_transcript(&transcript, &CleanupConfig::default()); - assert_eq!(cleaned, "um hola comma mundo"); - } -} diff --git a/src/cleanup/analysis.rs b/src/cleanup/analysis.rs new file mode 100644 index 0000000..41135c0 --- /dev/null +++ b/src/cleanup/analysis.rs @@ -0,0 +1,633 @@ +use super::lexicon::{collect_edit_hypotheses, explicit_followup_cue_lookahead}; +use super::render::render_pieces; +use super::{ + BacktrackOutcome, CleanupConfig, CleanupProfile, CorrectionAnalysis, CorrectionKind, + CorrectionTrigger, EditIntent, EditIntentAction, EditIntentConfidence, EditSignal, + EditSignalKind, EditSignalScope, EditSignalStrength, Piece, Transcript, is_clause_boundary, + is_strong_boundary, match_words_with_soft_punctuation, matches_words, previous_word_before, + skip_correction_gap, supports_cleanup_language, +}; + +pub fn clean_transcript(transcript: &Transcript, config: &CleanupConfig) -> String { + let raw = transcript.raw_text.trim(); + if raw.is_empty() + || !config.enabled + || !supports_cleanup_language(transcript.detected_language.as_deref()) + { + return raw.to_string(); + } + + let mut pieces = super::tokenize(raw); + if config.spoken_formatting { + pieces = apply_spoken_formatting(pieces); + } + if config.backtrack { + pieces = apply_backtrack(pieces, config.profile).pieces; + } + if config.remove_fillers { + pieces = remove_fillers(pieces); + } + + render_pieces(&pieces) +} + +pub fn correction_analysis(transcript: &Transcript) -> CorrectionAnalysis { + let raw = transcript.raw_text.trim(); + if raw.is_empty() || !supports_cleanup_language(transcript.detected_language.as_deref()) { + return CorrectionAnalysis { + text: raw.to_string(), + aggressive_text: None, + edit_intents: 
Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + }; + } + + let mut pieces = super::tokenize(raw); + pieces = apply_spoken_formatting(pieces); + let aggressive_outcome = apply_backtrack(pieces.clone(), CleanupProfile::Aggressive); + let outcome = apply_backtrack(pieces, CleanupProfile::Basic); + let text = render_pieces(&outcome.pieces); + let aggressive_text = render_pieces(&aggressive_outcome.pieces); + let edit_hypotheses = collect_edit_hypotheses(raw, &outcome.edit_signals); + + CorrectionAnalysis { + aggressive_text: (!outcome.edit_signals.is_empty() + && aggressive_text != text + && !aggressive_text.is_empty()) + .then_some(aggressive_text), + text, + edit_intents: outcome.edit_intents, + edit_signals: outcome.edit_signals, + edit_hypotheses, + } +} + +pub fn correction_aware_text(transcript: &Transcript) -> String { + correction_analysis(transcript).text +} + +pub fn explicit_followup_replacement(raw: &str) -> Option { + let raw = raw.trim(); + if raw.is_empty() { + return None; + } + + let pieces = apply_spoken_formatting(super::tokenize(raw)); + let lookahead = if let Some(trigger) = match_correction_trigger(&pieces, 0) { + if !matches!(trigger.kind, CorrectionKind::Sentence) + || !matches!(trigger.signal_kind, EditSignalKind::Cancel) + || !trigger.drop_cue_without_context + { + return None; + } + skip_correction_gap(&pieces, trigger.trigger_end) + } else { + explicit_followup_cue_lookahead(&pieces)? 
+ }; + if count_upcoming_words(&pieces, lookahead) == 0 { + return None; + } + + let rendered = render_pieces(&pieces[lookahead..]); + (!rendered.is_empty()).then_some(rendered) +} + +fn apply_spoken_formatting(pieces: Vec) -> Vec { + let mut out = Vec::with_capacity(pieces.len()); + let mut i = 0; + + while i < pieces.len() { + if matches_words(&pieces, i, &["new", "paragraph"]) { + out.push(Piece::Break(super::BreakKind::Paragraph)); + i += 2; + continue; + } + if matches_words(&pieces, i, &["new", "line"]) { + out.push(Piece::Break(super::BreakKind::Line)); + i += 2; + continue; + } + if matches_words(&pieces, i, &["question", "mark"]) { + out.push(Piece::Punctuation('?')); + i += 2; + continue; + } + if matches_words(&pieces, i, &["exclamation", "point"]) + || matches_words(&pieces, i, &["exclamation", "mark"]) + { + out.push(Piece::Punctuation('!')); + i += 2; + continue; + } + if matches_words(&pieces, i, &["full", "stop"]) { + out.push(Piece::Punctuation('.')); + i += 2; + continue; + } + + match super::normalized_word(pieces.get(i)).as_deref() { + Some("comma") => { + out.push(Piece::Punctuation(',')); + i += 1; + } + Some("period") => { + out.push(Piece::Punctuation('.')); + i += 1; + } + Some("colon") => { + out.push(Piece::Punctuation(':')); + i += 1; + } + Some("semicolon") => { + out.push(Piece::Punctuation(';')); + i += 1; + } + _ => { + out.push(pieces[i].clone()); + i += 1; + } + } + } + + out +} + +fn apply_backtrack(pieces: Vec, profile: CleanupProfile) -> BacktrackOutcome { + let mut out = Vec::with_capacity(pieces.len()); + let mut edit_intents = Vec::new(); + let mut edit_signals = Vec::new(); + let mut i = 0; + + while i < pieces.len() { + let Some(trigger) = match_correction_trigger(&pieces, i) else { + out.push(pieces[i].clone()); + i += 1; + continue; + }; + + let lookahead = skip_correction_gap(&pieces, trigger.trigger_end); + let replacement_words = count_upcoming_words(&pieces, lookahead); + let prior_context_words = 
output_word_count(&out); + let signal_strength = if prior_context_words >= trigger.min_context_words + || trigger.drop_cue_without_context + || (replacement_words == 0 && trigger.allow_terminal_cancel) + { + EditSignalStrength::Strong + } else { + EditSignalStrength::Possible + }; + let default_scope = match trigger.kind { + CorrectionKind::Phrase => EditSignalScope::Phrase, + CorrectionKind::Sentence => EditSignalScope::Sentence, + }; + + if replacement_words == 0 { + if prior_context_words >= trigger.min_context_words && trigger.allow_terminal_cancel { + let action = trim_terminal_cancel_scope(&mut out, trigger.kind); + edit_signals.push(EditSignal { + trigger: trigger.phrase, + kind: trigger.signal_kind, + scope_hint: scope_hint_for_action(action), + strength: signal_strength, + }); + edit_intents.push(EditIntent { + action, + trigger: trigger.phrase, + confidence: EditIntentConfidence::High, + }); + i = lookahead; + continue; + } + if trigger.drop_cue_without_context { + edit_signals.push(EditSignal { + trigger: trigger.phrase, + kind: trigger.signal_kind, + scope_hint: EditSignalScope::Unknown, + strength: EditSignalStrength::Strong, + }); + edit_intents.push(EditIntent { + action: EditIntentAction::DropEditCue, + trigger: trigger.phrase, + confidence: EditIntentConfidence::High, + }); + i = lookahead; + continue; + } + edit_signals.push(EditSignal { + trigger: trigger.phrase, + kind: trigger.signal_kind, + scope_hint: default_scope, + strength: signal_strength, + }); + out.push(pieces[i].clone()); + i += 1; + continue; + } + + if prior_context_words < trigger.min_context_words { + if trigger.drop_cue_without_context { + edit_signals.push(EditSignal { + trigger: trigger.phrase, + kind: trigger.signal_kind, + scope_hint: EditSignalScope::Unknown, + strength: EditSignalStrength::Strong, + }); + edit_intents.push(EditIntent { + action: EditIntentAction::DropEditCue, + trigger: trigger.phrase, + confidence: EditIntentConfidence::High, + }); + i = lookahead; + 
continue; + } + edit_signals.push(EditSignal { + trigger: trigger.phrase, + kind: trigger.signal_kind, + scope_hint: default_scope, + strength: signal_strength, + }); + out.push(pieces[i].clone()); + i += 1; + continue; + } + + let action = match trigger.kind { + CorrectionKind::Phrase => { + trim_recent_phrase(&mut out, profile, replacement_words); + EditIntentAction::ReplacePreviousPhrase + } + CorrectionKind::Sentence => { + if ends_with_sentence_boundary(&out) { + trim_last_sentence(&mut out); + EditIntentAction::ReplacePreviousSentence + } else { + trim_recent_phrase(&mut out, profile, replacement_words); + EditIntentAction::ReplacePreviousClause + } + } + }; + edit_signals.push(EditSignal { + trigger: trigger.phrase, + kind: trigger.signal_kind, + scope_hint: scope_hint_for_action(action), + strength: signal_strength, + }); + edit_intents.push(EditIntent { + action, + trigger: trigger.phrase, + confidence: EditIntentConfidence::High, + }); + i = lookahead; + } + + BacktrackOutcome { + pieces: out, + edit_intents, + edit_signals, + } +} + +fn match_correction_trigger(pieces: &[Piece], i: usize) -> Option { + if let Some(end) = + match_words_with_soft_punctuation(pieces, i, &["oh", "wait", "never", "mind"]) + { + return Some(cancel_sentence_trigger( + "oh wait never mind", + end, + true, + true, + )); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["oh", "wait", "nevermind"]) { + return Some(cancel_sentence_trigger( + "oh wait nevermind", + end, + true, + true, + )); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["forget", "that"]) { + return Some(cancel_sentence_trigger("forget that", end, true, true)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["wait", "no"]) { + return Some(replace_sentence_trigger("wait no", end, true, false)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["scratch", "that"]) { + return Some(cancel_sentence_trigger("scratch that", end, 
true, true)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["never", "mind"]) { + return Some(cancel_sentence_trigger("never mind", end, true, true)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["nevermind"]) { + return Some(cancel_sentence_trigger("nevermind", end, true, true)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "wait", "no"]) { + return Some(replace_sentence_trigger( + "actually wait no", + end, + true, + false, + )); + } + if let Some(end) = + match_words_with_soft_punctuation(pieces, i, &["actually", "scratch", "that"]) + { + return Some(cancel_sentence_trigger( + "actually scratch that", + end, + true, + true, + )); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "never", "mind"]) + { + return Some(cancel_sentence_trigger( + "actually never mind", + end, + true, + true, + )); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "nevermind"]) { + return Some(cancel_sentence_trigger( + "actually nevermind", + end, + true, + true, + )); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "no"]) { + return Some(replace_phrase_trigger("actually no", end)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "i", "meant"]) { + return Some(replace_phrase_trigger("actually i meant", end)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["actually", "i", "mean"]) { + return Some(restatement_phrase_trigger("actually i mean", end)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["i", "meant"]) { + return Some(replace_phrase_trigger("i meant", end)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["i", "mean"]) { + return Some(restatement_phrase_trigger("i mean", end)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["or", "rather"]) { + return 
Some(restatement_phrase_trigger("or rather", end)); + } + if let Some(end) = match_words_with_soft_punctuation(pieces, i, &["no"]) { + let previous_word = previous_word_before(pieces, i); + if matches!(pieces.get(end), Some(Piece::Punctuation(','))) + && !matches!(previous_word.as_deref(), Some("wait" | "actually")) + { + return Some(CorrectionTrigger { + kind: CorrectionKind::Phrase, + signal_kind: EditSignalKind::Replace, + trigger_end: end + 1, + min_context_words: 2, + phrase: "no", + allow_terminal_cancel: false, + drop_cue_without_context: false, + }); + } + } + None +} + +fn cancel_sentence_trigger( + phrase: &'static str, + trigger_end: usize, + allow_terminal_cancel: bool, + drop_cue_without_context: bool, +) -> CorrectionTrigger { + CorrectionTrigger { + kind: CorrectionKind::Sentence, + signal_kind: EditSignalKind::Cancel, + trigger_end, + min_context_words: 1, + phrase, + allow_terminal_cancel, + drop_cue_without_context, + } +} + +fn replace_sentence_trigger( + phrase: &'static str, + trigger_end: usize, + allow_terminal_cancel: bool, + drop_cue_without_context: bool, +) -> CorrectionTrigger { + CorrectionTrigger { + kind: CorrectionKind::Sentence, + signal_kind: EditSignalKind::Replace, + trigger_end, + min_context_words: 1, + phrase, + allow_terminal_cancel, + drop_cue_without_context, + } +} + +fn replace_phrase_trigger(phrase: &'static str, trigger_end: usize) -> CorrectionTrigger { + CorrectionTrigger { + kind: CorrectionKind::Phrase, + signal_kind: EditSignalKind::Replace, + trigger_end, + min_context_words: 1, + phrase, + allow_terminal_cancel: false, + drop_cue_without_context: false, + } +} + +fn restatement_phrase_trigger(phrase: &'static str, trigger_end: usize) -> CorrectionTrigger { + CorrectionTrigger { + kind: CorrectionKind::Phrase, + signal_kind: EditSignalKind::Restatement, + trigger_end, + min_context_words: 1, + phrase, + allow_terminal_cancel: false, + drop_cue_without_context: false, + } +} + +fn scope_hint_for_action(action: 
EditIntentAction) -> EditSignalScope { + match action { + EditIntentAction::ReplacePreviousPhrase => EditSignalScope::Phrase, + EditIntentAction::ReplacePreviousClause => EditSignalScope::Clause, + EditIntentAction::ReplacePreviousSentence => EditSignalScope::Sentence, + EditIntentAction::DropEditCue => EditSignalScope::Unknown, + } +} + +fn trim_terminal_cancel_scope(out: &mut Vec, kind: CorrectionKind) -> EditIntentAction { + match kind { + CorrectionKind::Phrase => { + trim_last_clause(out); + EditIntentAction::ReplacePreviousClause + } + CorrectionKind::Sentence => { + if ends_with_sentence_boundary(out) { + trim_last_sentence(out); + EditIntentAction::ReplacePreviousSentence + } else { + trim_last_clause(out); + EditIntentAction::ReplacePreviousClause + } + } + } +} + +fn count_upcoming_words(pieces: &[Piece], start: usize) -> usize { + let mut count = 0; + for piece in pieces.iter().skip(start) { + match piece { + Piece::Word(_) => count += 1, + Piece::Break(_) => break, + Piece::Punctuation(ch) if is_strong_boundary(*ch) => break, + _ => {} + } + } + count +} + +fn trim_recent_phrase(out: &mut Vec, profile: CleanupProfile, replacement_words: usize) { + trim_soft_suffix(out); + + let max_words = match profile { + CleanupProfile::Basic => replacement_words.clamp(1, 3), + CleanupProfile::Aggressive => replacement_words.clamp(2, 6), + }; + + let mut removed_words = 0usize; + while removed_words < max_words { + match out.pop() { + Some(Piece::Word(_)) => removed_words += 1, + Some(Piece::Punctuation(ch)) if !is_strong_boundary(ch) => {} + Some(piece @ Piece::Punctuation(_)) | Some(piece @ Piece::Break(_)) => { + out.push(piece); + break; + } + None => break, + } + } + + if profile == CleanupProfile::Aggressive { + while let Some(piece) = out.pop() { + match piece { + Piece::Word(_) => continue, + Piece::Punctuation(ch) if !is_strong_boundary(ch) => continue, + other => { + out.push(other); + break; + } + } + } + } + + trim_soft_suffix(out); +} + +fn 
trim_last_sentence(out: &mut Vec) { + trim_trailing_boundaries(out); + + let mut removed_word = false; + while let Some(piece) = out.pop() { + match piece { + Piece::Word(_) => removed_word = true, + Piece::Punctuation(ch) if is_strong_boundary(ch) => { + if removed_word { + break; + } + } + Piece::Break(_) => { + if removed_word { + break; + } + } + Piece::Punctuation(_) => {} + } + } + + trim_soft_suffix(out); +} + +fn trim_last_clause(out: &mut Vec) { + trim_trailing_boundaries(out); + + let mut removed_word = false; + while let Some(piece) = out.pop() { + match piece { + Piece::Word(_) => removed_word = true, + Piece::Punctuation(ch) if is_clause_boundary(ch) => { + if removed_word { + break; + } + } + Piece::Break(_) => { + if removed_word { + break; + } + } + Piece::Punctuation(_) => {} + } + } + + trim_soft_suffix(out); +} + +fn trim_soft_suffix(out: &mut Vec) { + while let Some(last) = out.last() { + match last { + Piece::Punctuation(ch) if !is_strong_boundary(*ch) => { + out.pop(); + } + _ => break, + } + } +} + +fn output_word_count(pieces: &[Piece]) -> usize { + pieces + .iter() + .filter(|piece| matches!(piece, Piece::Word(_))) + .count() +} + +fn trim_trailing_boundaries(out: &mut Vec) { + while let Some(last) = out.last() { + match last { + Piece::Punctuation(_) | Piece::Break(_) => { + out.pop(); + } + Piece::Word(_) => break, + } + } +} + +fn ends_with_sentence_boundary(out: &[Piece]) -> bool { + out.iter().rev().find_map(|piece| match piece { + Piece::Punctuation(ch) => Some(is_strong_boundary(*ch)), + Piece::Break(_) => Some(true), + Piece::Word(_) => None, + }) == Some(true) +} + +fn remove_fillers(pieces: Vec) -> Vec { + let mut out = Vec::with_capacity(pieces.len()); + for piece in pieces { + match piece { + Piece::Word(word) if is_filler(&word) => continue, + other => out.push(other), + } + } + out +} + +fn is_filler(word: &str) -> bool { + matches!( + super::normalized_word_str(word).as_str(), + "um" | "umm" | "uh" | "uhh" | "er" | "erm" | 
"ah" + ) +} diff --git a/src/cleanup/lexicon.rs b/src/cleanup/lexicon.rs new file mode 100644 index 0000000..47f4761 --- /dev/null +++ b/src/cleanup/lexicon.rs @@ -0,0 +1,408 @@ +use super::{ + CueFamilySpec, EditHypothesis, EditHypothesisMatchSource, EditSignal, EditSignalKind, + EditSignalScope, EditSignalStrength, ObservedWord, Piece, ReplacementScope, TailShape, + match_words_with_soft_punctuation, normalized_word_str, skip_correction_gap, tokenize, +}; + +pub(super) fn explicit_followup_cue_lookahead(pieces: &[Piece]) -> Option { + for spec in cue_family_specs() { + if !matches!(spec.kind, EditSignalKind::Cancel) { + continue; + } + + if let Some(end) = match_words_with_soft_punctuation(pieces, 0, spec.canonical) { + return Some(skip_correction_gap(pieces, end)); + } + + for alias in spec.session_followup_aliases { + if let Some(end) = match_words_with_soft_punctuation(pieces, 0, alias) { + return Some(skip_correction_gap(pieces, end)); + } + } + } + + None +} + +pub(super) fn cue_family_specs() -> &'static [CueFamilySpec] { + const SCRATCH_THAT_ALIASES: &[&[&str]] = &[&["scratchthat"]]; + const SCRATCH_THAT_SESSION_ALIASES: &[&[&str]] = &[ + &["scratchthat"], + &["scratchvat"], + &["scratchfat"], + &["scratchfart"], + &["scratchfarts"], + &["scratchfot"], + &["scratchbot"], + &["scratchvatnotes"], + &["rajvat"], + &["srajvat"], + &["srajvatnotes"], + ]; + const NEVER_MIND_ALIASES: &[&[&str]] = &[&["nevermind"]]; + const NEVER_MIND_SESSION_ALIASES: &[&[&str]] = &[&["nevermind"], &["nevarmind"]]; + const WAIT_NO_ALIASES: &[&[&str]] = &[&["waitno"]]; + const I_MEANT_ALIASES: &[&[&str]] = &[&["imeant"]]; + const I_MEAN_ALIASES: &[&[&str]] = &[&["imean"]]; + const OR_RATHER_ALIASES: &[&[&str]] = &[&["orrather"]]; + const SPECS: &[CueFamilySpec] = &[ + CueFamilySpec { + cue_family: "scratch_that", + kind: EditSignalKind::Cancel, + scope_hint: EditSignalScope::Sentence, + canonical: &["scratch", "that"], + aliases: SCRATCH_THAT_ALIASES, + session_followup_aliases: 
SCRATCH_THAT_SESSION_ALIASES, + }, + CueFamilySpec { + cue_family: "never_mind", + kind: EditSignalKind::Cancel, + scope_hint: EditSignalScope::Sentence, + canonical: &["never", "mind"], + aliases: NEVER_MIND_ALIASES, + session_followup_aliases: NEVER_MIND_SESSION_ALIASES, + }, + CueFamilySpec { + cue_family: "wait_no", + kind: EditSignalKind::Replace, + scope_hint: EditSignalScope::Sentence, + canonical: &["wait", "no"], + aliases: WAIT_NO_ALIASES, + session_followup_aliases: &[], + }, + CueFamilySpec { + cue_family: "i_meant", + kind: EditSignalKind::Replace, + scope_hint: EditSignalScope::Phrase, + canonical: &["i", "meant"], + aliases: I_MEANT_ALIASES, + session_followup_aliases: &[], + }, + CueFamilySpec { + cue_family: "i_mean", + kind: EditSignalKind::Replace, + scope_hint: EditSignalScope::Phrase, + canonical: &["i", "mean"], + aliases: I_MEAN_ALIASES, + session_followup_aliases: &[], + }, + CueFamilySpec { + cue_family: "or_rather", + kind: EditSignalKind::Restatement, + scope_hint: EditSignalScope::Phrase, + canonical: &["or", "rather"], + aliases: OR_RATHER_ALIASES, + session_followup_aliases: &[], + }, + ]; + + SPECS +} + +pub(super) fn collect_edit_hypotheses( + raw: &str, + edit_signals: &[EditSignal], +) -> Vec { + let observed_words = collect_observed_words(raw); + if observed_words.is_empty() { + return Vec::new(); + } + + let mut hypotheses = Vec::new(); + + for spec in cue_family_specs() { + let mut index = 0usize; + while index < observed_words.len() { + let Some((match_source, matched_len)) = + match_cue_family_at(&observed_words, index, spec) + else { + index += 1; + continue; + }; + + let signal = signal_for_cue_family(edit_signals, spec.cue_family); + let strength = signal + .map(|signal| signal.strength) + .unwrap_or(match match_source { + EditHypothesisMatchSource::NearMiss => EditSignalStrength::Possible, + EditHypothesisMatchSource::Exact | EditHypothesisMatchSource::Alias => { + EditSignalStrength::Strong + } + }); + let scope_hint = 
signal + .map(|signal| signal.scope_hint) + .unwrap_or(spec.scope_hint); + let kind = signal.map(|signal| signal.kind).unwrap_or(spec.kind); + let (replacement_scope, tail_shape) = + classify_replacement_scope(&observed_words, index + matched_len, scope_hint); + + hypotheses.push(EditHypothesis { + cue_family: spec.cue_family, + matched_text: observed_words[index..index + matched_len] + .iter() + .map(|word| word.text.as_str()) + .collect::>() + .join(" "), + match_source, + kind, + scope_hint, + replacement_scope, + tail_shape, + strength, + word_start: index, + word_end: index + matched_len, + }); + + index += matched_len.max(1); + } + } + + hypotheses.sort_by_key(|hypothesis| (hypothesis.word_start, hypothesis.word_end)); + hypotheses.dedup_by(|right, left| { + right.cue_family == left.cue_family + && right.word_start == left.word_start + && right.word_end == left.word_end + }); + hypotheses +} + +fn collect_observed_words(raw: &str) -> Vec { + tokenize(raw) + .into_iter() + .filter_map(|piece| match piece { + Piece::Word(text) => { + let normalized = normalized_word_str(&text); + (!normalized.is_empty()).then_some(ObservedWord { text, normalized }) + } + Piece::Punctuation(_) | Piece::Break(_) => None, + }) + .collect() +} + +fn match_cue_family_at( + observed_words: &[ObservedWord], + start: usize, + spec: &CueFamilySpec, +) -> Option<(EditHypothesisMatchSource, usize)> { + if matches_observed_words(observed_words, start, spec.canonical) { + return Some((EditHypothesisMatchSource::Exact, spec.canonical.len())); + } + + for alias in spec.aliases { + if matches_observed_words(observed_words, start, alias) { + return Some((EditHypothesisMatchSource::Alias, alias.len())); + } + } + + let mut candidate_lengths = vec![1usize, spec.canonical.len()]; + candidate_lengths.extend(spec.aliases.iter().map(|alias| alias.len())); + candidate_lengths.sort_unstable(); + candidate_lengths.dedup(); + + for len in candidate_lengths { + if start + len > observed_words.len() { + 
continue; + } + + let observed = compact_observed_words(&observed_words[start..start + len]); + if observed.is_empty() { + continue; + } + + if is_limited_near_miss(&observed, &compact_phrase(spec.canonical)) + || spec + .aliases + .iter() + .any(|alias| is_limited_near_miss(&observed, &compact_phrase(alias))) + { + return Some((EditHypothesisMatchSource::NearMiss, len)); + } + } + + None +} + +fn matches_observed_words( + observed_words: &[ObservedWord], + start: usize, + expected: &[&str], +) -> bool { + expected.iter().enumerate().all(|(offset, expected_word)| { + observed_words + .get(start + offset) + .map(|word| word.normalized.as_str()) + == Some(*expected_word) + }) +} + +fn compact_observed_words(observed_words: &[ObservedWord]) -> String { + observed_words + .iter() + .map(|word| word.normalized.as_str()) + .collect::>() + .join("") +} + +fn compact_phrase(words: &[&str]) -> String { + words.join("") +} + +fn signal_for_cue_family<'a>( + edit_signals: &'a [EditSignal], + cue_family: &str, +) -> Option<&'a EditSignal> { + edit_signals + .iter() + .find(|signal| cue_family_for_phrase(signal.trigger) == Some(cue_family)) +} + +fn cue_family_for_phrase(phrase: &str) -> Option<&'static str> { + match phrase { + "scratch that" | "actually scratch that" => Some("scratch_that"), + "never mind" + | "nevermind" + | "actually never mind" + | "actually nevermind" + | "oh wait never mind" + | "oh wait nevermind" + | "forget that" => Some("never_mind"), + "wait no" | "actually wait no" => Some("wait_no"), + "i meant" | "actually i meant" => Some("i_meant"), + "i mean" | "actually i mean" => Some("i_mean"), + "or rather" => Some("or_rather"), + _ => None, + } +} + +fn classify_replacement_scope( + observed_words: &[ObservedWord], + tail_start: usize, + scope_hint: EditSignalScope, +) -> (ReplacementScope, TailShape) { + let tail_words = observed_words + .get(tail_start..) 
+ .unwrap_or(&[]) + .iter() + .map(|word| word.normalized.as_str()) + .take(8) + .collect::>(); + + if tail_words.is_empty() { + let replacement_scope = match scope_hint { + EditSignalScope::Sentence => ReplacementScope::Sentence, + EditSignalScope::Clause => ReplacementScope::Clause, + EditSignalScope::Phrase | EditSignalScope::Unknown => ReplacementScope::Span, + }; + return (replacement_scope, TailShape::Empty); + } + + let tail_shape = if looks_like_clause_tail(&tail_words) { + TailShape::Clause + } else { + TailShape::Phrase + }; + + let replacement_scope = match scope_hint { + EditSignalScope::Sentence => { + if matches!(tail_shape, TailShape::Phrase) && tail_words.len() > 3 { + ReplacementScope::Clause + } else { + ReplacementScope::Sentence + } + } + EditSignalScope::Clause => { + if matches!(tail_shape, TailShape::Phrase) { + ReplacementScope::Span + } else { + ReplacementScope::Clause + } + } + EditSignalScope::Phrase => ReplacementScope::Span, + EditSignalScope::Unknown => { + if matches!(tail_shape, TailShape::Clause) { + ReplacementScope::Clause + } else { + ReplacementScope::Span + } + } + }; + + (replacement_scope, tail_shape) +} + +fn looks_like_clause_tail(tail_words: &[&str]) -> bool { + const CLAUSE_WORDS: &[&str] = &[ + "am", "are", "be", "been", "being", "can", "could", "did", "do", "does", "had", "has", + "have", "is", "must", "need", "needs", "required", "requires", "should", "was", "were", + "will", "would", + ]; + const SUBJECT_WORDS: &[&str] = &[ + "i", "it", "he", "she", "they", "we", "you", "this", "that", "there", "my", "our", "their", + "your", + ]; + + tail_words.iter().any(|word| CLAUSE_WORDS.contains(word)) + || (tail_words.len() >= 2 + && SUBJECT_WORDS.contains(&tail_words[0]) + && !matches!(tail_words[1], "and" | "or" | "but")) +} + +fn is_limited_near_miss(observed: &str, target: &str) -> bool { + if observed.is_empty() || target.is_empty() || observed == target { + return false; + } + + let common_prefix = observed + .chars() + 
.zip(target.chars()) + .take_while(|(left, right)| left == right) + .count(); + if common_prefix < observed.len().min(target.len()).min(4) { + return false; + } + + let observed_prefix = if observed.chars().count() > target.chars().count() { + observed + .chars() + .take(target.chars().count()) + .collect::() + } else { + observed.to_string() + }; + let distance = bounded_levenshtein(&observed_prefix, target, 3); + distance <= 3 +} + +fn bounded_levenshtein(left: &str, right: &str, max_distance: usize) -> usize { + let left_chars: Vec = left.chars().collect(); + let right_chars: Vec = right.chars().collect(); + + if left_chars.is_empty() { + return right_chars.len(); + } + if right_chars.is_empty() { + return left_chars.len(); + } + if left_chars.len().abs_diff(right_chars.len()) > max_distance { + return max_distance + 1; + } + + let mut prev: Vec = (0..=right_chars.len()).collect(); + let mut curr = vec![0usize; right_chars.len() + 1]; + + for (i, left_char) in left_chars.iter().enumerate() { + curr[0] = i + 1; + let mut row_min = curr[0]; + for (j, right_char) in right_chars.iter().enumerate() { + let cost = usize::from(left_char != right_char); + curr[j + 1] = (prev[j + 1] + 1).min(curr[j] + 1).min(prev[j] + cost); + row_min = row_min.min(curr[j + 1]); + } + if row_min > max_distance { + return max_distance + 1; + } + std::mem::swap(&mut prev, &mut curr); + } + + prev[right_chars.len()] +} diff --git a/src/cleanup/render.rs b/src/cleanup/render.rs new file mode 100644 index 0000000..c3c3d2d --- /dev/null +++ b/src/cleanup/render.rs @@ -0,0 +1,68 @@ +use super::{BreakKind, Piece}; + +pub(super) fn render_pieces(pieces: &[Piece]) -> String { + let mut rendered = String::new(); + let mut capitalize_next = true; + + for piece in pieces { + match piece { + Piece::Word(word) => { + if !rendered.is_empty() && !rendered.ends_with([' ', '\n']) { + rendered.push(' '); + } + if capitalize_next { + rendered.push_str(&capitalize_first(word)); + } else { + 
rendered.push_str(word); + } + capitalize_next = false; + } + Piece::Punctuation(ch) => { + trim_trailing_spaces(&mut rendered); + rendered.push(*ch); + capitalize_next = matches!(ch, '.' | '?' | '!'); + } + Piece::Break(BreakKind::Line) => { + trim_trailing_spaces(&mut rendered); + if !rendered.is_empty() { + if !rendered.ends_with('\n') { + rendered.push('\n'); + } + capitalize_next = true; + } + } + Piece::Break(BreakKind::Paragraph) => { + trim_trailing_spaces(&mut rendered); + if !rendered.is_empty() { + while rendered.ends_with('\n') { + rendered.pop(); + } + rendered.push('\n'); + rendered.push('\n'); + capitalize_next = true; + } + } + } + } + + trim_trailing_spaces(&mut rendered); + rendered +} + +fn trim_trailing_spaces(text: &mut String) { + while text.ends_with(' ') { + text.pop(); + } +} + +fn capitalize_first(word: &str) -> String { + let mut chars = word.chars(); + let Some(first) = chars.next() else { + return String::new(); + }; + + let mut result = String::new(); + result.extend(first.to_uppercase()); + result.extend(chars); + result +} diff --git a/src/cleanup/tests.rs b/src/cleanup/tests.rs new file mode 100644 index 0000000..fd237cc --- /dev/null +++ b/src/cleanup/tests.rs @@ -0,0 +1,245 @@ +use super::*; +use crate::config::{CleanupConfig, CleanupProfile}; +use crate::transcribe::Transcript; + +fn transcript(text: &str) -> Transcript { + Transcript { + raw_text: text.to_string(), + detected_language: Some("en".to_string()), + segments: Vec::new(), + } +} + +#[test] +fn removes_common_fillers() { + let cleaned = clean_transcript( + &transcript("um i think we should go"), + &CleanupConfig::default(), + ); + assert_eq!(cleaned, "I think we should go"); +} + +#[test] +fn preserves_non_filler_words() { + let cleaned = clean_transcript(&transcript("i like apples"), &CleanupConfig::default()); + assert_eq!(cleaned, "I like apples"); +} + +#[test] +fn converts_spoken_punctuation_commands() { + let cleaned = clean_transcript( + &transcript("hello comma 
world question mark"), + &CleanupConfig::default(), + ); + assert_eq!(cleaned, "Hello, world?"); +} + +#[test] +fn converts_spoken_line_and_paragraph_commands() { + let cleaned = clean_transcript( + &transcript("first line new line second line new paragraph third line"), + &CleanupConfig::default(), + ); + assert_eq!(cleaned, "First line\nSecond line\n\nThird line"); +} + +#[test] +fn basic_backtrack_replaces_recent_phrase() { + let cleaned = clean_transcript( + &transcript("let's meet at 4 actually no 3"), + &CleanupConfig::default(), + ); + assert_eq!(cleaned, "Let's meet at 3"); +} + +#[test] +fn standalone_actually_is_preserved() { + let cleaned = correction_aware_text(&transcript("it actually works")); + assert_eq!(cleaned, "It actually works"); +} + +#[test] +fn rather_is_preserved_in_normal_phrasing() { + let cleaned = correction_aware_text(&transcript("i would rather stay home")); + assert_eq!(cleaned, "I would rather stay home"); +} + +#[test] +fn actually_rather_is_preserved_in_normal_phrasing() { + let cleaned = correction_aware_text(&transcript("i would actually rather stay home")); + assert_eq!(cleaned, "I would actually rather stay home"); +} + +#[test] +fn punctuated_wait_no_replaces_last_sentence() { + let cleaned = clean_transcript( + &transcript("hi there, this is a test of whisper osd. wait, no. 
hi there."), + &CleanupConfig::default(), + ); + assert_eq!(cleaned, "Hi there."); +} + +#[test] +fn punctuated_wait_no_still_replaces_inline_phrase() { + let cleaned = clean_transcript( + &transcript("let's meet at 4 wait, no, 3"), + &CleanupConfig::default(), + ); + assert_eq!(cleaned, "Let's meet at 3"); +} + +#[test] +fn scratch_that_replaces_recent_word() { + let cleaned = clean_transcript( + &transcript("i'll bring cookies scratch that brownies"), + &CleanupConfig::default(), + ); + assert_eq!(cleaned, "I'll bring brownies"); +} + +#[test] +fn correction_aware_text_drops_previous_sentence_for_scratch_that() { + let cleaned = correction_aware_text(&transcript( + "hello there, this is a test of whisper rs. scratch that. hi.", + )); + assert_eq!(cleaned, "Hi."); +} + +#[test] +fn terminal_never_mind_cancels_previous_clause() { + let cleaned = correction_aware_text(&transcript("hello there oh wait never mind")); + assert_eq!(cleaned, ""); +} + +#[test] +fn utterance_initial_never_mind_is_dropped_when_content_follows() { + let cleaned = correction_aware_text(&transcript("never mind. 
hi, how are you today?")); + assert_eq!(cleaned, "Hi, how are you today?"); +} + +#[test] +fn explicit_followup_replacement_handles_session_aliases() { + assert_eq!( + explicit_followup_replacement("srajvat, hi").as_deref(), + Some("Hi") + ); + assert_eq!( + explicit_followup_replacement("scratchfarts, hi").as_deref(), + Some("Hi") + ); +} + +#[test] +fn correction_analysis_reports_terminal_cancel_intent() { + let analysis = correction_analysis(&transcript("hello there never mind")); + assert_eq!(analysis.text, ""); + assert_eq!(analysis.edit_intents.len(), 1); + assert_eq!( + analysis.edit_intents[0].action, + EditIntentAction::ReplacePreviousClause + ); + assert_eq!(analysis.edit_intents[0].trigger, "never mind"); + assert_eq!(analysis.edit_signals.len(), 1); + assert_eq!(analysis.edit_signals[0].kind, EditSignalKind::Cancel); + assert_eq!(analysis.edit_signals[0].scope_hint, EditSignalScope::Clause); + assert_eq!( + analysis.edit_signals[0].strength, + EditSignalStrength::Strong + ); + assert_eq!(analysis.edit_hypotheses.len(), 1); + assert_eq!(analysis.edit_hypotheses[0].cue_family, "never_mind"); + assert_eq!( + analysis.edit_hypotheses[0].match_source, + EditHypothesisMatchSource::Exact + ); +} + +#[test] +fn utterance_initial_wait_no_is_not_treated_as_backtrack() { + let analysis = correction_analysis(&transcript("wait, no, it actually works")); + assert_eq!(analysis.text, "Wait, no, it actually works"); + assert_eq!(analysis.edit_intents.len(), 0); + assert_eq!(analysis.edit_signals.len(), 1); + assert_eq!(analysis.edit_signals[0].kind, EditSignalKind::Replace); + assert_eq!( + analysis.edit_signals[0].scope_hint, + EditSignalScope::Sentence + ); + assert_eq!( + analysis.edit_signals[0].strength, + EditSignalStrength::Possible + ); + assert!(analysis.edit_hypotheses.iter().any(|hypothesis| { + hypothesis.cue_family == "wait_no" + && hypothesis.match_source == EditHypothesisMatchSource::Exact + })); +} + +#[test] +fn 
correction_analysis_reports_restatement_signal_for_or_rather() { + let analysis = correction_analysis(&transcript("let's meet tomorrow or rather friday")); + assert_eq!(analysis.text, "Let's meet friday"); + assert_eq!(analysis.edit_signals.len(), 1); + assert_eq!(analysis.edit_signals[0].kind, EditSignalKind::Restatement); + assert_eq!(analysis.edit_signals[0].scope_hint, EditSignalScope::Phrase); +} + +#[test] +fn correction_analysis_exposes_aggressive_candidate_for_ambiguous_replacement() { + let analysis = correction_analysis(&transcript( + "my name is notes, scratch that my name is jonatan", + )); + assert_eq!(analysis.text, "My my name is jonatan"); + assert_eq!( + analysis.aggressive_text.as_deref(), + Some("My name is jonatan") + ); + assert_eq!( + analysis.edit_hypotheses[0].replacement_scope, + ReplacementScope::Clause + ); + assert_eq!(analysis.edit_hypotheses[0].tail_shape, TailShape::Clause); +} + +#[test] +fn correction_analysis_collects_near_miss_hypothesis_for_scratch_that_family() { + let analysis = correction_analysis(&transcript("hello there scratch vat hi")); + assert!(analysis.edit_hypotheses.iter().any(|hypothesis| { + hypothesis.cue_family == "scratch_that" + && hypothesis.match_source == EditHypothesisMatchSource::NearMiss + && hypothesis.matched_text == "scratch vat" + })); +} + +#[test] +fn correction_analysis_marks_phrase_tail_as_span_scope() { + let analysis = correction_analysis(&transcript( + "mobile apps or sms codes scratch that just sms codes", + )); + assert!(analysis.edit_hypotheses.iter().any(|hypothesis| { + hypothesis.cue_family == "scratch_that" + && hypothesis.replacement_scope == ReplacementScope::Span + && hypothesis.tail_shape == TailShape::Phrase + })); +} + +#[test] +fn aggressive_profile_trims_more_context() { + let config = CleanupConfig { + profile: CleanupProfile::Aggressive, + ..CleanupConfig::default() + }; + let cleaned = clean_transcript(&transcript("alpha beta gamma delta wait no omega"), &config); + 
assert_eq!(cleaned, "Omega"); +} + +#[test] +fn skips_advanced_cleanup_for_non_english_transcripts() { + let transcript = Transcript { + raw_text: "um hola comma mundo".to_string(), + detected_language: Some("es".to_string()), + segments: Vec::new(), + }; + let cleaned = clean_transcript(&transcript, &CleanupConfig::default()); + assert_eq!(cleaned, "um hola comma mundo"); +} diff --git a/src/cli.rs b/src/cli.rs index 4030a87..d431781 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -2,6 +2,8 @@ use std::path::PathBuf; use clap::{Parser, Subcommand, ValueEnum}; +use crate::rewrite_protocol::{RewriteCorrectionPolicy, RewriteSurfaceKind}; + #[derive(Parser, Debug)] #[command( name = "whispers", @@ -72,6 +74,18 @@ pub enum Command { action: DictionaryAction, }, + /// Manage app-aware rewrite policy rules + AppRule { + #[command(subcommand)] + action: AppRuleAction, + }, + + /// Manage technical glossary entries for agentic rewrite + Glossary { + #[command(subcommand)] + action: GlossaryAction, + }, + /// Check optional cloud ASR/rewrite configuration and connectivity Cloud { #[command(subcommand)] @@ -169,6 +183,91 @@ pub enum DictionaryAction { }, } +#[derive(Subcommand, Debug)] +pub enum AppRuleAction { + /// Print the configured app rule file path + Path, + + /// List configured app rules + List, + + /// Add or update an app rule + Add { + /// Stable rule name used for updates and removals + name: String, + + /// Instructions appended to the effective rewrite prompt + instructions: String, + + /// Match on the active surface kind + #[arg(long)] + surface_kind: Option, + + /// Match on the exact app ID + #[arg(long)] + app_id: Option, + + /// Case-insensitive substring match on the window title + #[arg(long)] + window_title_contains: Option, + + /// Case-insensitive substring match on the browser domain + #[arg(long)] + browser_domain_contains: Option, + + /// Override the effective correction policy + #[arg(long)] + correction_policy: Option, + }, + + /// Remove an app 
rule by name + Remove { + /// Rule name to remove + name: String, + }, +} + +#[derive(Subcommand, Debug)] +pub enum GlossaryAction { + /// Print the configured glossary file path + Path, + + /// List configured glossary entries + List, + + /// Add or update a glossary entry + Add { + /// Canonical output term + term: String, + + /// Alias that should map to the canonical term + #[arg(long = "alias", required = true)] + aliases: Vec, + + /// Match on the active surface kind + #[arg(long)] + surface_kind: Option, + + /// Match on the exact app ID + #[arg(long)] + app_id: Option, + + /// Case-insensitive substring match on the window title + #[arg(long)] + window_title_contains: Option, + + /// Case-insensitive substring match on the browser domain + #[arg(long)] + browser_domain_contains: Option, + }, + + /// Remove a glossary entry by canonical term + Remove { + /// Canonical term to remove + term: String, + }, +} + #[derive(Subcommand, Debug)] pub enum SnippetAction { /// List snippets @@ -268,6 +367,49 @@ mod tests { )); } + #[test] + fn parses_app_rule_add_subcommand() { + let cli = Cli::try_parse_from([ + "whispers", + "app-rule", + "add", + "zed", + "Preserve identifiers.", + "--app-id", + "dev.zed.Zed", + "--correction-policy", + "balanced", + ]) + .unwrap(); + assert!(matches!( + cli.command, + Some(Command::AppRule { + action: AppRuleAction::Add { .. } + }) + )); + } + + #[test] + fn parses_glossary_add_subcommand() { + let cli = Cli::try_parse_from([ + "whispers", + "glossary", + "add", + "TypeScript", + "--alias", + "type script", + "--surface-kind", + "editor", + ]) + .unwrap(); + assert!(matches!( + cli.command, + Some(Command::Glossary { + action: GlossaryAction::Add { .. 
} + }) + )); + } + #[test] fn parses_snippet_add_subcommand() { let cli = Cli::try_parse_from(["whispers", "snippets", "add", "sig", "Best"]).unwrap(); diff --git a/src/cloud.rs b/src/cloud.rs index 9e64e0a..fa04ea8 100644 --- a/src/cloud.rs +++ b/src/cloud.rs @@ -13,7 +13,10 @@ use crate::config::{ }; use crate::error::{Result, WhsprError}; use crate::personalization; -use crate::rewrite_profile::{ResolvedRewriteProfile, RewriteProfile}; +use crate::rewrite::{ + RewritePrompt, build_prompt as build_rewrite_prompt, resolved_profile_for_cloud, + sanitize_rewrite_output, +}; use crate::rewrite_protocol::RewriteTranscript; use crate::transcribe::{Transcript, TranscriptSegment}; @@ -160,7 +163,9 @@ impl CloudService { custom_instructions: Option<&str>, ) -> Result { let started = Instant::now(); - let prompt = build_cloud_rewrite_prompt(config, transcript, custom_instructions); + let profile = resolved_profile_for_cloud(config.rewrite.profile); + let prompt = build_rewrite_prompt(transcript, profile, custom_instructions) + .map_err(|e| WhsprError::Rewrite(format!("failed to build rewrite prompt: {e}")))?; let url = format!("{}/chat/completions", self.base_url); let request = ChatCompletionsRequest::from_prompt(config, &prompt); let request_started = Instant::now(); @@ -211,9 +216,8 @@ impl CloudService { pub fn validate_config(config: &Config) -> Result<()> { let uses_cloud_asr = config.transcription.backend == TranscriptionBackend::Cloud; - let uses_cloud_rewrite = config.postprocess.mode - == crate::config::PostprocessMode::AdvancedLocal - && config.rewrite.backend == RewriteBackend::Cloud; + let uses_cloud_rewrite = + config.postprocess.mode.uses_rewrite() && config.rewrite.backend == RewriteBackend::Cloud; if !uses_cloud_asr && !uses_cloud_rewrite { return Ok(()); } @@ -453,12 +457,6 @@ fn samples_to_ms(sample_count: usize, sample_rate: u32) -> u32 { ((sample_count as f64 / sample_rate as f64) * 1000.0).round() as u32 } -#[derive(Debug, Clone, PartialEq, Eq)] 
-struct RewritePrompt { - system: String, - user: String, -} - #[derive(serde::Serialize)] struct ChatCompletionsRequest { model: String, @@ -467,132 +465,6 @@ struct ChatCompletionsRequest { max_tokens: usize, } -fn build_cloud_rewrite_prompt( - config: &Config, - transcript: &RewriteTranscript, - custom_instructions: Option<&str>, -) -> RewritePrompt { - let profile = resolved_profile_for_cloud(config.rewrite.profile); - let system = build_system_instructions(profile, custom_instructions); - let language = transcript.detected_language.as_deref().unwrap_or("unknown"); - let correction_aware = transcript.correction_aware_text.trim(); - let raw = transcript.raw_text.trim(); - let recommended = transcript - .recommended_session_candidate - .as_ref() - .map(|candidate| candidate.text.as_str()) - .or_else(|| { - transcript - .recommended_candidate - .as_ref() - .map(|candidate| candidate.text.as_str()) - }) - .unwrap_or(""); - let candidates = if transcript.rewrite_candidates.is_empty() { - "- none\n".to_string() - } else { - transcript - .rewrite_candidates - .iter() - .map(|candidate| format!("- {}\n", candidate.text)) - .collect::() - }; - let user = format!( - "Language: {language}\n\ -Correction-aware transcript:\n{correction_aware}\n\ -Raw transcript:\n{raw}\n\ -Recommended interpretation:\n{recommended}\n\ -Candidate interpretations:\n{candidates}\ -Final text:" - ); - RewritePrompt { system, user } -} - -fn build_system_instructions( - profile: ResolvedRewriteProfile, - custom_instructions: Option<&str>, -) -> String { - let mut instructions = rewrite_instructions(profile).to_string(); - if let Some(custom) = custom_instructions - .map(str::trim) - .filter(|text| !text.is_empty()) - { - instructions.push_str("\n\nAdditional user rewrite instructions:\n"); - instructions.push_str(custom); - } - instructions -} - -fn rewrite_instructions(profile: ResolvedRewriteProfile) -> &'static str { - let base = "You clean up dictated speech into the final text the user 
meant to type. Return only the finished text. Do not explain anything. Remove obvious disfluencies when natural. Use the correction-aware transcript as the primary source of truth unless structured edit cues still make the utterance ambiguous. Never reintroduce text that was removed by an explicit spoken correction cue. Respect any recommended interpretation and candidate interpretations provided alongside the transcript. Prefer transcript spellings for names, brands, and uncommon proper nouns unless a user dictionary or explicit correction says otherwise. Do not normalize names into more common spellings just because they look familiar."; - match profile { - ResolvedRewriteProfile::Qwen => { - "You clean up dictated speech into the final text the user meant to type. Return only the finished text. Do not explain anything. Do not emit reasoning, think tags, or XML wrappers. Remove obvious disfluencies when natural. Use the correction-aware transcript as the primary source of truth unless structured edit cues still make the utterance ambiguous. Never reintroduce text that was removed by an explicit spoken correction cue. Respect any recommended interpretation and candidate interpretations provided alongside the transcript. Prefer transcript spellings for names, brands, and uncommon proper nouns unless a user dictionary or explicit correction says otherwise. Do not normalize names into more common spellings just because they look familiar." 
- } - ResolvedRewriteProfile::Generic | ResolvedRewriteProfile::LlamaCompat => base, - } -} - -fn resolved_profile_for_cloud(profile: RewriteProfile) -> ResolvedRewriteProfile { - match profile { - RewriteProfile::Auto => ResolvedRewriteProfile::Generic, - RewriteProfile::Generic => ResolvedRewriteProfile::Generic, - RewriteProfile::Qwen => ResolvedRewriteProfile::Qwen, - RewriteProfile::LlamaCompat => ResolvedRewriteProfile::LlamaCompat, - } -} - -fn sanitize_rewrite_output(raw: &str) -> String { - let mut text = raw.replace("\r\n", "\n"); - - for stop in ["<|eot_id|>", "<|end_of_text|>", ""] { - if let Some(index) = text.find(stop) { - text.truncate(index); - } - } - if let Some(index) = text.find("") { - text.truncate(index); - } - - text = strip_tagged_section(&text, "", ""); - let mut text = text.trim().to_string(); - - if let Some(stripped) = text.strip_prefix("") { - text = stripped.trim().to_string(); - } - for prefix in ["Final text:", "Output:", "Rewritten text:"] { - if text - .get(..prefix.len()) - .map(|candidate| candidate.eq_ignore_ascii_case(prefix)) - .unwrap_or(false) - { - text = text[prefix.len()..].trim().to_string(); - break; - } - } - if text.starts_with('"') && text.ends_with('"') && text.len() >= 2 { - text = text[1..text.len() - 1].trim().to_string(); - } - text -} - -fn strip_tagged_section(input: &str, open: &str, close: &str) -> String { - let mut result = String::with_capacity(input.len()); - let mut remainder = input; - while let Some(start) = remainder.find(open) { - result.push_str(&remainder[..start]); - let after_open = &remainder[start + open.len()..]; - if let Some(end) = after_open.find(close) { - remainder = &after_open[end + close.len()..]; - } else { - remainder = ""; - break; - } - } - result.push_str(remainder); - result -} - impl ChatCompletionsRequest { fn from_prompt(config: &Config, prompt: &RewritePrompt) -> Self { Self { @@ -829,6 +701,7 @@ mod tests { edit_hypotheses: Vec::new(), rewrite_candidates: Vec::new(), 
recommended_candidate: None, + policy_context: crate::rewrite_protocol::RewritePolicyContext::default(), }, None, ) diff --git a/src/completions.rs b/src/completions.rs index b632c20..4d03a77 100644 --- a/src/completions.rs +++ b/src/completions.rs @@ -1,168 +1,22 @@ -use std::io::Write; -use std::path::Path; +mod detect; +mod render; -use clap::CommandFactory; -use clap_complete::{generate, shells}; -use clap_complete_nushell::Nushell; +#[cfg(test)] +mod tests; -use crate::cli::{Cli, CompletionShell}; +use crate::cli::CompletionShell; use crate::error::{Result, WhsprError}; const SUPPORTED_SHELLS: &str = "bash|zsh|fish|nushell"; pub fn run_completions(shell_arg: Option) -> Result<()> { - let shell = shell_arg.or_else(detect_shell).ok_or_else(|| { + let shell = shell_arg.or_else(detect::detect_shell).ok_or_else(|| { WhsprError::Config(format!( "could not detect shell automatically. Specify one manually: whispers completions <{SUPPORTED_SHELLS}>" )) })?; let mut stdout = std::io::stdout(); - write_completions(shell, &mut stdout); + render::write_completions(shell, &mut stdout); Ok(()) } - -fn detect_shell() -> Option { - detect_shell_from_env().or_else(detect_shell_from_parent_process) -} - -fn detect_shell_from_env() -> Option { - std::env::var("SHELL") - .ok() - .and_then(|value| shell_from_path_like(&value)) -} - -fn detect_shell_from_parent_process() -> Option { - let ppid = unsafe { libc::getppid() }; - if ppid <= 0 { - return None; - } - - detect_shell_from_pid(ppid) -} - -fn detect_shell_from_pid(pid: libc::pid_t) -> Option { - let cmdline_path = format!("/proc/{pid}/cmdline"); - if let Ok(cmdline) = std::fs::read(cmdline_path) { - if let Some(shell) = shell_from_cmdline_bytes(&cmdline) { - return Some(shell); - } - } - - let comm_path = format!("/proc/{pid}/comm"); - std::fs::read_to_string(comm_path) - .ok() - .and_then(|comm| shell_from_token(comm.trim())) -} - -fn shell_from_cmdline_bytes(bytes: &[u8]) -> Option { - let first = bytes.split(|b| *b == 
0).next()?; - if first.is_empty() { - return None; - } - - let token = String::from_utf8_lossy(first); - shell_from_path_like(&token) -} - -fn shell_from_path_like(value: &str) -> Option { - let name = Path::new(value).file_name()?.to_string_lossy(); - shell_from_token(&name) -} - -fn shell_from_token(value: &str) -> Option { - let normalized = value.trim().trim_start_matches('-').to_ascii_lowercase(); - - match normalized.as_str() { - "bash" => Some(CompletionShell::Bash), - "zsh" => Some(CompletionShell::Zsh), - "fish" => Some(CompletionShell::Fish), - "nu" | "nushell" => Some(CompletionShell::Nushell), - _ => None, - } -} - -fn write_completions(shell: CompletionShell, out: &mut dyn Write) { - let mut cmd = Cli::command(); - let bin_name = cmd.get_name().to_string(); - - match shell { - CompletionShell::Bash => generate(shells::Bash, &mut cmd, &bin_name, out), - CompletionShell::Zsh => generate(shells::Zsh, &mut cmd, &bin_name, out), - CompletionShell::Fish => generate(shells::Fish, &mut cmd, &bin_name, out), - CompletionShell::Nushell => generate(Nushell, &mut cmd, &bin_name, out), - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn shell_detection_from_env_value_supports_paths() { - assert_eq!( - shell_from_path_like("/usr/bin/zsh"), - Some(CompletionShell::Zsh) - ); - assert_eq!( - shell_from_path_like("/bin/bash"), - Some(CompletionShell::Bash) - ); - } - - #[test] - fn shell_detection_accepts_login_shell_prefix_and_nu_alias() { - assert_eq!(shell_from_token("-fish"), Some(CompletionShell::Fish)); - assert_eq!(shell_from_token("nu"), Some(CompletionShell::Nushell)); - } - - #[test] - fn shell_detection_from_cmdline_uses_first_argv_entry() { - let cmdline = b"/usr/bin/fish\0-l\0"; - assert_eq!( - shell_from_cmdline_bytes(cmdline), - Some(CompletionShell::Fish) - ); - } - - #[test] - fn shell_detection_returns_none_for_unknown_values() { - assert_eq!(shell_from_path_like("/bin/tcsh"), None); - assert_eq!(shell_from_token("xonsh"), None); - 
assert_eq!(shell_from_cmdline_bytes(b""), None); - } - - fn generate_to_string(shell: CompletionShell) -> String { - let mut output = Vec::new(); - write_completions(shell, &mut output); - String::from_utf8(output).unwrap() - } - - #[test] - fn generates_bash_completion_script() { - let script = generate_to_string(CompletionShell::Bash); - assert!(script.contains("whispers")); - assert!(script.contains("complete")); - } - - #[test] - fn generates_zsh_completion_script() { - let script = generate_to_string(CompletionShell::Zsh); - assert!(script.contains("whispers")); - assert!(script.contains("compdef")); - } - - #[test] - fn generates_fish_completion_script() { - let script = generate_to_string(CompletionShell::Fish); - assert!(script.contains("whispers")); - assert!(script.contains("complete -c")); - } - - #[test] - fn generates_nushell_completion_script() { - let script = generate_to_string(CompletionShell::Nushell); - assert!(script.contains("whispers")); - assert!(script.contains("export extern")); - } -} diff --git a/src/completions/detect.rs b/src/completions/detect.rs new file mode 100644 index 0000000..9b5c8c6 --- /dev/null +++ b/src/completions/detect.rs @@ -0,0 +1,63 @@ +use std::path::Path; + +use crate::cli::CompletionShell; + +pub(super) fn detect_shell() -> Option { + detect_shell_from_env().or_else(detect_shell_from_parent_process) +} + +fn detect_shell_from_env() -> Option { + std::env::var("SHELL") + .ok() + .and_then(|value| shell_from_path_like(&value)) +} + +fn detect_shell_from_parent_process() -> Option { + let ppid = unsafe { libc::getppid() }; + if ppid <= 0 { + return None; + } + + detect_shell_from_pid(ppid) +} + +fn detect_shell_from_pid(pid: libc::pid_t) -> Option { + let cmdline_path = format!("/proc/{pid}/cmdline"); + if let Ok(cmdline) = std::fs::read(cmdline_path) { + if let Some(shell) = shell_from_cmdline_bytes(&cmdline) { + return Some(shell); + } + } + + let comm_path = format!("/proc/{pid}/comm"); + 
std::fs::read_to_string(comm_path) + .ok() + .and_then(|comm| shell_from_token(comm.trim())) +} + +pub(super) fn shell_from_cmdline_bytes(bytes: &[u8]) -> Option { + let first = bytes.split(|b| *b == 0).next()?; + if first.is_empty() { + return None; + } + + let token = String::from_utf8_lossy(first); + shell_from_path_like(&token) +} + +pub(super) fn shell_from_path_like(value: &str) -> Option { + let name = Path::new(value).file_name()?.to_string_lossy(); + shell_from_token(&name) +} + +pub(super) fn shell_from_token(value: &str) -> Option { + let normalized = value.trim().trim_start_matches('-').to_ascii_lowercase(); + + match normalized.as_str() { + "bash" => Some(CompletionShell::Bash), + "zsh" => Some(CompletionShell::Zsh), + "fish" => Some(CompletionShell::Fish), + "nu" | "nushell" => Some(CompletionShell::Nushell), + _ => None, + } +} diff --git a/src/completions/render.rs b/src/completions/render.rs new file mode 100644 index 0000000..3e59c7c --- /dev/null +++ b/src/completions/render.rs @@ -0,0 +1,19 @@ +use std::io::Write; + +use clap::CommandFactory; +use clap_complete::{generate, shells}; +use clap_complete_nushell::Nushell; + +use crate::cli::{Cli, CompletionShell}; + +pub(super) fn write_completions(shell: CompletionShell, out: &mut dyn Write) { + let mut cmd = Cli::command(); + let bin_name = cmd.get_name().to_string(); + + match shell { + CompletionShell::Bash => generate(shells::Bash, &mut cmd, &bin_name, out), + CompletionShell::Zsh => generate(shells::Zsh, &mut cmd, &bin_name, out), + CompletionShell::Fish => generate(shells::Fish, &mut cmd, &bin_name, out), + CompletionShell::Nushell => generate(Nushell, &mut cmd, &bin_name, out), + } +} diff --git a/src/completions/tests.rs b/src/completions/tests.rs new file mode 100644 index 0000000..255f862 --- /dev/null +++ b/src/completions/tests.rs @@ -0,0 +1,77 @@ +use crate::cli::CompletionShell; + +use super::{detect, render}; + +#[test] +fn shell_detection_from_env_value_supports_paths() { + 
assert_eq!( + detect::shell_from_path_like("/usr/bin/zsh"), + Some(CompletionShell::Zsh) + ); + assert_eq!( + detect::shell_from_path_like("/bin/bash"), + Some(CompletionShell::Bash) + ); +} + +#[test] +fn shell_detection_accepts_login_shell_prefix_and_nu_alias() { + assert_eq!( + detect::shell_from_token("-fish"), + Some(CompletionShell::Fish) + ); + assert_eq!( + detect::shell_from_token("nu"), + Some(CompletionShell::Nushell) + ); +} + +#[test] +fn shell_detection_from_cmdline_uses_first_argv_entry() { + let cmdline = b"/usr/bin/fish\0-l\0"; + assert_eq!( + detect::shell_from_cmdline_bytes(cmdline), + Some(CompletionShell::Fish) + ); +} + +#[test] +fn shell_detection_returns_none_for_unknown_values() { + assert_eq!(detect::shell_from_path_like("/bin/tcsh"), None); + assert_eq!(detect::shell_from_token("xonsh"), None); + assert_eq!(detect::shell_from_cmdline_bytes(b""), None); +} + +fn generate_to_string(shell: CompletionShell) -> String { + let mut output = Vec::new(); + render::write_completions(shell, &mut output); + String::from_utf8(output).unwrap() +} + +#[test] +fn generates_bash_completion_script() { + let script = generate_to_string(CompletionShell::Bash); + assert!(script.contains("whispers")); + assert!(script.contains("complete")); +} + +#[test] +fn generates_zsh_completion_script() { + let script = generate_to_string(CompletionShell::Zsh); + assert!(script.contains("whispers")); + assert!(script.contains("compdef")); +} + +#[test] +fn generates_fish_completion_script() { + let script = generate_to_string(CompletionShell::Fish); + assert!(script.contains("whispers")); + assert!(script.contains("complete -c")); +} + +#[test] +fn generates_nushell_completion_script() { + let script = generate_to_string(CompletionShell::Nushell); + assert!(script.contains("whispers")); + assert!(script.contains("export extern")); +} diff --git a/src/config.rs b/src/config.rs index 81b0994..1b8354a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,1206 +1,21 @@ -use 
serde::Deserialize; -use std::path::{Path, PathBuf}; - -use crate::error::{Result, WhsprError}; -use crate::rewrite_profile::RewriteProfile; - -#[derive(Debug, Clone, Default, Deserialize)] -#[serde(default)] -pub struct Config { - pub audio: AudioConfig, - pub transcription: TranscriptionConfig, - #[serde(default, rename = "whisper")] - legacy_whisper: LegacyWhisperConfig, - pub postprocess: PostprocessConfig, - pub session: SessionConfig, - pub personalization: PersonalizationConfig, - pub rewrite: RewriteConfig, - pub cloud: CloudConfig, - pub cleanup: CleanupConfig, - pub inject: InjectConfig, - pub feedback: FeedbackConfig, -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(default)] -pub struct AudioConfig { - pub device: String, - pub sample_rate: u32, -} - -#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum TranscriptionBackend { - #[default] - WhisperCpp, - FasterWhisper, - Nemo, - Cloud, -} - -#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum TranscriptionFallback { - None, - #[default] - ConfiguredLocal, -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(default)] -pub struct TranscriptionConfig { - pub backend: TranscriptionBackend, - pub fallback: TranscriptionFallback, - pub local_backend: TranscriptionBackend, - pub selected_model: String, - pub model_path: String, - pub language: String, - pub use_gpu: bool, - pub flash_attn: bool, - pub idle_timeout_ms: u64, -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(default)] -struct LegacyWhisperConfig { - model_path: String, - language: String, - use_gpu: bool, - flash_attn: bool, -} - -#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] -#[serde(default)] -pub struct PostprocessConfig { - pub mode: PostprocessMode, -} - -#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum PostprocessMode { - #[default] - Raw, - 
AdvancedLocal, - LegacyBasic, -} - -#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum RewriteBackend { - #[default] - Local, - Cloud, -} - -#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum RewriteFallback { - None, - #[default] - Local, -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(default)] -pub struct RewriteConfig { - pub backend: RewriteBackend, - pub fallback: RewriteFallback, - pub selected_model: String, - pub model_path: String, - pub instructions_path: String, - pub profile: RewriteProfile, - pub timeout_ms: u64, - pub idle_timeout_ms: u64, - pub max_output_chars: usize, - pub max_tokens: usize, -} - -#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum CloudProvider { - #[default] - #[serde(rename = "openai")] - OpenAi, - #[serde(rename = "openai_compatible")] - OpenAiCompatible, -} - -#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum CloudLanguageMode { - #[default] - InheritLocal, - Force, -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(default)] -pub struct CloudConfig { - pub provider: CloudProvider, - pub base_url: String, - pub api_key: String, - pub api_key_env: String, - pub connect_timeout_ms: u64, - pub request_timeout_ms: u64, - pub transcription: CloudTranscriptionConfig, - pub rewrite: CloudRewriteConfig, -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(default)] -pub struct CloudTranscriptionConfig { - pub model: String, - pub language_mode: CloudLanguageMode, - pub language: String, -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(default)] -pub struct CloudRewriteConfig { - pub model: String, - pub temperature: f32, - pub max_output_tokens: usize, -} - -#[derive(Debug, Clone)] -pub struct CloudSettingsUpdate<'a> { - pub provider: CloudProvider, - pub base_url: &'a str, - pub 
api_key: &'a str, - pub api_key_env: &'a str, - pub connect_timeout_ms: u64, - pub request_timeout_ms: u64, - pub transcription_model: &'a str, - pub transcription_language_mode: CloudLanguageMode, - pub transcription_language: &'a str, - pub rewrite_model: &'a str, - pub rewrite_temperature: f32, - pub rewrite_max_output_tokens: usize, -} - -#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] -#[serde(default)] -pub struct SessionConfig { - pub enabled: bool, - pub max_entries: usize, - pub max_age_ms: u64, - pub max_replace_graphemes: usize, -} - -#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] -#[serde(default)] -pub struct PersonalizationConfig { - pub dictionary_path: String, - pub snippets_path: String, - pub snippet_trigger: String, -} - -#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] -#[serde(default)] -pub struct CleanupConfig { - pub enabled: bool, - pub profile: CleanupProfile, - pub spoken_formatting: bool, - pub backtrack: bool, - pub remove_fillers: bool, -} - -#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "lowercase")] -pub enum CleanupProfile { - #[default] - Basic, - Aggressive, -} - -#[derive(Debug, Clone, Default, Deserialize)] -#[serde(default)] -pub struct InjectConfig {} - -#[derive(Debug, Clone, Deserialize)] -#[serde(default)] -pub struct FeedbackConfig { - pub enabled: bool, - pub start_sound: String, - pub stop_sound: String, -} - -impl Default for AudioConfig { - fn default() -> Self { - Self { - device: String::new(), - sample_rate: 16000, - } - } -} - -impl Default for TranscriptionConfig { - fn default() -> Self { - Self { - backend: TranscriptionBackend::WhisperCpp, - fallback: TranscriptionFallback::ConfiguredLocal, - local_backend: TranscriptionBackend::WhisperCpp, - selected_model: "large-v3-turbo".into(), - model_path: "~/.local/share/whispers/ggml-large-v3-turbo.bin".into(), - language: "auto".into(), - use_gpu: true, - flash_attn: true, - idle_timeout_ms: 120000, - } - } -} - 
-impl Default for LegacyWhisperConfig { - fn default() -> Self { - let default = TranscriptionConfig::default(); - Self { - model_path: default.model_path, - language: default.language, - use_gpu: default.use_gpu, - flash_attn: default.flash_attn, - } - } -} - -impl Default for PostprocessConfig { - fn default() -> Self { - Self { - mode: PostprocessMode::Raw, - } - } -} - -impl PostprocessMode { - pub fn as_str(self) -> &'static str { - match self { - Self::Raw => "raw", - Self::AdvancedLocal => "advanced_local", - Self::LegacyBasic => "legacy_basic", - } - } -} - -impl TranscriptionBackend { - pub fn as_str(self) -> &'static str { - match self { - Self::WhisperCpp => "whisper_cpp", - Self::FasterWhisper => "faster_whisper", - Self::Nemo => "nemo", - Self::Cloud => "cloud", - } - } -} - -impl TranscriptionFallback { - pub fn as_str(self) -> &'static str { - match self { - Self::None => "none", - Self::ConfiguredLocal => "configured_local", - } - } -} - -impl RewriteBackend { - pub fn as_str(self) -> &'static str { - match self { - Self::Local => "local", - Self::Cloud => "cloud", - } - } -} - -impl RewriteFallback { - pub fn as_str(self) -> &'static str { - match self { - Self::None => "none", - Self::Local => "local", - } - } -} - -impl CloudProvider { - pub fn as_str(self) -> &'static str { - match self { - Self::OpenAi => "openai", - Self::OpenAiCompatible => "openai_compatible", - } - } -} - -impl CloudLanguageMode { - pub fn as_str(self) -> &'static str { - match self { - Self::InheritLocal => "inherit_local", - Self::Force => "force", - } - } -} - -impl Default for RewriteConfig { - fn default() -> Self { - Self { - backend: RewriteBackend::Local, - fallback: RewriteFallback::Local, - selected_model: "qwen-3.5-4b-q4_k_m".into(), - model_path: String::new(), - instructions_path: "~/.local/share/whispers/rewrite-instructions.txt".into(), - profile: RewriteProfile::Auto, - timeout_ms: 30000, - idle_timeout_ms: 120000, - max_output_chars: 1200, - max_tokens: 
256, - } - } -} - -impl Default for CloudConfig { - fn default() -> Self { - Self { - provider: CloudProvider::OpenAi, - base_url: String::new(), - api_key: String::new(), - api_key_env: "OPENAI_API_KEY".into(), - connect_timeout_ms: 3000, - request_timeout_ms: 15000, - transcription: CloudTranscriptionConfig::default(), - rewrite: CloudRewriteConfig::default(), - } - } -} - -impl Default for CloudTranscriptionConfig { - fn default() -> Self { - Self { - model: "gpt-4o-mini-transcribe".into(), - language_mode: CloudLanguageMode::InheritLocal, - language: String::new(), - } - } -} - -impl Default for CloudRewriteConfig { - fn default() -> Self { - Self { - model: "gpt-4.1-mini".into(), - temperature: 0.1, - max_output_tokens: 256, - } - } -} - -impl Default for PersonalizationConfig { - fn default() -> Self { - Self { - dictionary_path: "~/.local/share/whispers/dictionary.toml".into(), - snippets_path: "~/.local/share/whispers/snippets.toml".into(), - snippet_trigger: "insert".into(), - } - } -} - -impl Default for SessionConfig { - fn default() -> Self { - Self { - enabled: true, - max_entries: 3, - max_age_ms: 8000, - max_replace_graphemes: 400, - } - } -} - -impl Default for CleanupConfig { - fn default() -> Self { - Self { - enabled: true, - profile: CleanupProfile::Basic, - spoken_formatting: true, - backtrack: true, - remove_fillers: true, - } - } -} - -impl Default for FeedbackConfig { - fn default() -> Self { - Self { - enabled: true, - start_sound: String::new(), - stop_sound: String::new(), - } - } -} - -impl TranscriptionConfig { - pub fn resolved_local_backend(&self) -> TranscriptionBackend { - match self.local_backend { - TranscriptionBackend::WhisperCpp - | TranscriptionBackend::FasterWhisper - | TranscriptionBackend::Nemo => self.local_backend, - TranscriptionBackend::Cloud => TranscriptionBackend::WhisperCpp, - } - } -} - -impl Config { - pub fn load(path: Option<&Path>) -> Result { - let config_path = resolve_config_path(path); - - if 
!config_path.exists() { - tracing::info!( - "no config file found at {}, using defaults", - config_path.display() - ); - return Ok(Config::default()); - } - - let contents = std::fs::read_to_string(&config_path).map_err(|e| { - WhsprError::Config(format!("failed to read {}: {e}", config_path.display())) - })?; - - let mut config: Config = toml::from_str(&contents).map_err(|e| { - WhsprError::Config(format!("failed to parse {}: {e}", config_path.display())) - })?; - - config.apply_legacy_transcription_migration(&contents, &config_path); - config.apply_legacy_cleanup_migration(&contents, &config_path); - config.apply_cloud_sanitization(); - Ok(config) - } - - pub fn resolved_model_path(&self) -> PathBuf { - PathBuf::from(expand_tilde(&self.transcription.model_path)) - } - - pub fn resolved_rewrite_model_path(&self) -> Option { - (!self.rewrite.model_path.trim().is_empty()) - .then(|| PathBuf::from(expand_tilde(&self.rewrite.model_path))) - } - - pub fn resolved_rewrite_instructions_path(&self) -> Option { - (!self.rewrite.instructions_path.trim().is_empty()) - .then(|| PathBuf::from(expand_tilde(&self.rewrite.instructions_path))) - } - - pub fn resolved_dictionary_path(&self) -> PathBuf { - PathBuf::from(expand_tilde(&self.personalization.dictionary_path)) - } - - pub fn resolved_snippets_path(&self) -> PathBuf { - PathBuf::from(expand_tilde(&self.personalization.snippets_path)) - } - - fn apply_legacy_transcription_migration(&mut self, contents: &str, config_path: &Path) { - let transcription_present = section_present(contents, "transcription"); - let whisper_present = section_present(contents, "whisper"); - - if !transcription_present && whisper_present { - tracing::warn!( - "config {} uses deprecated [whisper]; mapping to [transcription]", - config_path.display() - ); - self.transcription.backend = TranscriptionBackend::WhisperCpp; - self.transcription.model_path = self.legacy_whisper.model_path.clone(); - self.transcription.language = 
self.legacy_whisper.language.clone(); - self.transcription.use_gpu = self.legacy_whisper.use_gpu; - self.transcription.flash_attn = self.legacy_whisper.flash_attn; - } else if whisper_present { - tracing::warn!( - "config {} contains deprecated [whisper]; [transcription] takes precedence", - config_path.display() - ); - } - } - - fn apply_legacy_cleanup_migration(&mut self, contents: &str, config_path: &Path) { - let cleanup_present = cleanup_section_present(contents); - - if cleanup_present && self.postprocess.mode == PostprocessMode::Raw { - if self.cleanup.enabled { - tracing::warn!( - "config {} uses deprecated [cleanup]; mapping to postprocess.mode = \"legacy_basic\"", - config_path.display() - ); - self.postprocess.mode = PostprocessMode::LegacyBasic; - } else { - tracing::warn!( - "config {} disables deprecated [cleanup]; keeping postprocess.mode = \"raw\"", - config_path.display() - ); - } - } else if cleanup_present && self.postprocess.mode != PostprocessMode::LegacyBasic { - tracing::warn!( - "config {} contains deprecated [cleanup]; [postprocess] takes precedence", - config_path.display() - ); - } - } - - fn apply_cloud_sanitization(&mut self) { - if self.transcription.local_backend == TranscriptionBackend::Cloud { - tracing::warn!( - "transcription.local_backend cannot be cloud; falling back to whisper_cpp" - ); - self.transcription.local_backend = TranscriptionBackend::WhisperCpp; - } - if self.cloud.api_key.trim().is_empty() && looks_like_cloud_api_key(&self.cloud.api_key_env) - { - tracing::warn!( - "cloud.api_key_env looks like a literal API key; treating it as cloud.api_key" - ); - self.cloud.api_key = self.cloud.api_key_env.trim().to_string(); - self.cloud.api_key_env = "OPENAI_API_KEY".into(); - } - } -} - -pub fn default_config_path() -> PathBuf { - xdg_dir("config").join("whispers").join("config.toml") -} - -pub fn resolve_config_path(path: Option<&Path>) -> PathBuf { - match path { - Some(p) => p.to_path_buf(), - None => default_config_path(), 
- } -} - -pub fn data_dir() -> PathBuf { - xdg_dir("data").join("whispers") -} - -fn xdg_dir(kind: &str) -> PathBuf { - match kind { - "config" => { - if let Ok(dir) = std::env::var("XDG_CONFIG_HOME") { - PathBuf::from(dir) - } else if let Ok(home) = std::env::var("HOME") { - PathBuf::from(home).join(".config") - } else { - tracing::warn!("neither XDG_CONFIG_HOME nor HOME is set, falling back to /tmp"); - PathBuf::from("/tmp") - } - } - "data" => { - if let Ok(dir) = std::env::var("XDG_DATA_HOME") { - PathBuf::from(dir) - } else if let Ok(home) = std::env::var("HOME") { - PathBuf::from(home).join(".local").join("share") - } else { - tracing::warn!("neither XDG_DATA_HOME nor HOME is set, falling back to /tmp"); - PathBuf::from("/tmp") - } - } - _ => { - tracing::warn!("unknown XDG directory kind '{kind}', falling back to /tmp"); - PathBuf::from("/tmp") - } - } -} - -pub fn expand_tilde(path: &str) -> String { - match path.strip_prefix("~/") { - Some(rest) => { - if let Ok(home) = std::env::var("HOME") { - return format!("{home}/{rest}"); - } - tracing::warn!("HOME is not set, cannot expand tilde in path: {path}"); - } - None if path == "~" => { - if let Ok(home) = std::env::var("HOME") { - return home; - } - tracing::warn!("HOME is not set, cannot expand tilde in path: {path}"); - } - _ => {} - } - path.to_string() -} - -pub fn write_default_config(path: &Path, model_path: &str) -> Result<()> { - let contents = format!( - r#"# whispers configuration -# -# Keybinding is handled by your compositor. Example for Hyprland: -# bind = SUPER ALT, D, exec, whispers -# -# First invocation starts recording, second invocation stops + transcribes + pastes. 
- -[audio] -# Input device name (empty = system default) -device = "" -# Sample rate in Hz (ASR requires 16000) -sample_rate = 16000 - -[transcription] -# Active transcription backend ("whisper_cpp", "faster_whisper", "nemo", or "cloud") -backend = "whisper_cpp" -# Cloud fallback behavior ("configured_local" or "none") -fallback = "configured_local" -# Local backend used directly in local mode and as the cloud fallback backend -local_backend = "whisper_cpp" -# Managed ASR model name for the selected backend -selected_model = "large-v3-turbo" -# Path to the local backend-specific model or empty to use the selected managed model -# Manage models with: whispers asr-model list / download / select -model_path = "{model_path}" -# Language code ("en", "fr", "de", etc.) or "auto" for auto-detect -language = "auto" -# Enable GPU acceleration (set false to force CPU) -use_gpu = true -# Enable flash attention when GPU is enabled -flash_attn = true -# How long the hidden ASR worker stays warm without requests (0 = never expire) -idle_timeout_ms = 120000 - -[postprocess] -# "raw" (default), "advanced_local", or "legacy_basic" for deprecated cleanup configs -mode = "raw" - -[session] -# Enable short-lived session backtracking in advanced_local mode -enabled = true -# How many recent dictation entries to keep in the runtime session ledger -max_entries = 3 -# How long a recent dictation entry stays eligible for revision -max_age_ms = 8000 -# Maximum graphemes that may be deleted when revising the latest entry -max_replace_graphemes = 400 - -[personalization] -# Dictionary replacements applied in all modes -dictionary_path = "~/.local/share/whispers/dictionary.toml" -# Snippets expanded via an explicit spoken trigger -snippets_path = "~/.local/share/whispers/snippets.toml" -# Spoken trigger phrase used before snippet names -snippet_trigger = "insert" - -[rewrite] -# Rewrite backend ("local" or "cloud") -backend = "local" -# Cloud fallback behavior ("local" or "none") -fallback = 
"local" -# Managed rewrite model name for advanced_local mode -selected_model = "qwen-3.5-4b-q4_k_m" -# Manual GGUF path override (empty = use selected managed model) -# Custom rewrite models should be chat-capable GGUFs with an embedded -# chat template that llama.cpp can apply at runtime. -model_path = "" -# Append-only custom rewrite instructions file (empty = disabled) -instructions_path = "~/.local/share/whispers/rewrite-instructions.txt" -# Rewrite profile selection ("auto", "qwen", "generic", or "llama_compat") -profile = "auto" -# Timeout for local rewrite inference in milliseconds -timeout_ms = 30000 -# How long the hidden rewrite worker stays warm without requests -idle_timeout_ms = 120000 -# Maximum characters accepted from the rewrite model -max_output_chars = 1200 -# Maximum tokens to generate for rewritten output -max_tokens = 256 - -[cloud] -# Cloud provider ("openai" or "openai_compatible") -provider = "openai" -# Custom base URL for openai_compatible providers (empty uses the OpenAI default) -base_url = "" -# Optional API key stored directly in the config (empty = use api_key_env instead) -api_key = "" -# Environment variable holding the API key -api_key_env = "OPENAI_API_KEY" -# Network connect timeout in milliseconds -connect_timeout_ms = 3000 -# End-to-end request timeout in milliseconds -request_timeout_ms = 15000 - -[cloud.transcription] -# Cloud transcription model -model = "gpt-4o-mini-transcribe" -# "inherit_local" uses [transcription].language when it is not "auto"; "force" uses the value below -language_mode = "inherit_local" -# Language code used when language_mode = "force" -language = "" - -[cloud.rewrite] -# Cloud rewrite model -model = "gpt-4.1-mini" -# Sampling temperature for cloud rewrite -temperature = 0.1 -# Maximum tokens requested from the cloud rewrite model -max_output_tokens = 256 - -[feedback] -# Play sound feedback on start/stop -enabled = true -# Custom sound file paths (empty = use bundled sounds) -start_sound = "" 
-stop_sound = "" -"# - ); - - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent) - .map_err(|e| WhsprError::Config(format!("failed to create config directory: {e}")))?; - } - - std::fs::write(path, contents) - .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; - - Ok(()) -} - -pub fn update_config_transcription_selection( - config_path: &Path, - backend: TranscriptionBackend, - selected_model: &str, - model_path: &str, - set_active_backend: bool, -) -> Result<()> { - let contents = std::fs::read_to_string(config_path) - .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; - - let mut doc = contents - .parse::() - .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; - - ensure_standard_postprocess_tables(&mut doc); - doc.as_table_mut().remove("whisper"); - if set_active_backend { - doc["transcription"]["backend"] = toml_edit::value(backend.as_str()); - } - doc["transcription"]["local_backend"] = toml_edit::value(backend.as_str()); - doc["transcription"]["fallback"] = - toml_edit::value(TranscriptionFallback::ConfiguredLocal.as_str()); - doc["transcription"]["selected_model"] = toml_edit::value(selected_model); - doc["transcription"]["model_path"] = toml_edit::value(model_path); - let idle_timeout_ms = if backend == TranscriptionBackend::Nemo { - 0 - } else { - 120000 - }; - doc["transcription"]["idle_timeout_ms"] = toml_edit::value(idle_timeout_ms); - - std::fs::write(config_path, doc.to_string()) - .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; - - Ok(()) -} - -pub fn update_config_postprocess_mode(config_path: &Path, mode: PostprocessMode) -> Result<()> { - let contents = std::fs::read_to_string(config_path) - .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; - - let mut doc = contents - .parse::() - .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; - - ensure_standard_postprocess_tables(&mut doc); - 
doc["postprocess"]["mode"] = toml_edit::value(mode.as_str()); - - std::fs::write(config_path, doc.to_string()) - .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; - - Ok(()) -} - -pub fn update_config_rewrite_selection(config_path: &Path, selected_model: &str) -> Result<()> { - let contents = std::fs::read_to_string(config_path) - .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; - - let mut doc = contents - .parse::() - .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; - - ensure_standard_postprocess_tables(&mut doc); - doc["postprocess"]["mode"] = toml_edit::value(PostprocessMode::AdvancedLocal.as_str()); - let rewrite_backend = doc["rewrite"] - .as_table_like() - .and_then(|table| table.get("backend")) - .and_then(|item| item.as_str()); - if !matches!(rewrite_backend, Some("cloud")) { - doc["rewrite"]["backend"] = toml_edit::value(RewriteBackend::Local.as_str()); - } - doc["rewrite"]["fallback"] = toml_edit::value(RewriteFallback::Local.as_str()); - doc["rewrite"]["selected_model"] = toml_edit::value(selected_model); - doc["rewrite"]["model_path"] = toml_edit::value(""); - doc["rewrite"]["instructions_path"] = - toml_edit::value("~/.local/share/whispers/rewrite-instructions.txt"); - doc["rewrite"]["profile"] = toml_edit::value(RewriteProfile::Auto.as_str()); - doc["rewrite"]["timeout_ms"] = toml_edit::value(30000); - doc["rewrite"]["idle_timeout_ms"] = toml_edit::value(120000); - doc["rewrite"]["max_output_chars"] = toml_edit::value(1200); - doc["rewrite"]["max_tokens"] = toml_edit::value(256); - - std::fs::write(config_path, doc.to_string()) - .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; - - Ok(()) -} - -pub fn update_config_transcription_runtime( - config_path: &Path, - backend: TranscriptionBackend, - fallback: TranscriptionFallback, -) -> Result<()> { - let contents = std::fs::read_to_string(config_path) - .map_err(|e| WhsprError::Config(format!("failed to 
read config: {e}")))?; - let mut doc = contents - .parse::() - .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; - - ensure_standard_postprocess_tables(&mut doc); - doc["transcription"]["backend"] = toml_edit::value(backend.as_str()); - doc["transcription"]["fallback"] = toml_edit::value(fallback.as_str()); - - std::fs::write(config_path, doc.to_string()) - .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; - Ok(()) -} - -pub fn update_config_rewrite_runtime( - config_path: &Path, - backend: RewriteBackend, - fallback: RewriteFallback, -) -> Result<()> { - let contents = std::fs::read_to_string(config_path) - .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; - let mut doc = contents - .parse::() - .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; - - ensure_standard_postprocess_tables(&mut doc); - doc["postprocess"]["mode"] = toml_edit::value(PostprocessMode::AdvancedLocal.as_str()); - doc["rewrite"]["backend"] = toml_edit::value(backend.as_str()); - doc["rewrite"]["fallback"] = toml_edit::value(fallback.as_str()); - - std::fs::write(config_path, doc.to_string()) - .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; - Ok(()) -} - -pub fn update_config_cloud_settings( - config_path: &Path, - settings: &CloudSettingsUpdate<'_>, -) -> Result<()> { - let contents = std::fs::read_to_string(config_path) - .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; - let mut doc = contents - .parse::() - .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; - - ensure_standard_postprocess_tables(&mut doc); - ensure_root_table(&mut doc, "cloud"); - ensure_nested_table(&mut doc, "cloud", "transcription"); - ensure_nested_table(&mut doc, "cloud", "rewrite"); - doc["cloud"]["provider"] = toml_edit::value(settings.provider.as_str()); - doc["cloud"]["base_url"] = toml_edit::value(settings.base_url); - doc["cloud"]["api_key"] 
= toml_edit::value(settings.api_key); - doc["cloud"]["api_key_env"] = toml_edit::value(settings.api_key_env); - doc["cloud"]["connect_timeout_ms"] = toml_edit::value(settings.connect_timeout_ms as i64); - doc["cloud"]["request_timeout_ms"] = toml_edit::value(settings.request_timeout_ms as i64); - doc["cloud"]["transcription"]["model"] = toml_edit::value(settings.transcription_model); - doc["cloud"]["transcription"]["language_mode"] = - toml_edit::value(settings.transcription_language_mode.as_str()); - doc["cloud"]["transcription"]["language"] = toml_edit::value(settings.transcription_language); - doc["cloud"]["rewrite"]["model"] = toml_edit::value(settings.rewrite_model); - doc["cloud"]["rewrite"]["temperature"] = toml_edit::value(settings.rewrite_temperature as f64); - doc["cloud"]["rewrite"]["max_output_tokens"] = - toml_edit::value(settings.rewrite_max_output_tokens as i64); - - std::fs::write(config_path, doc.to_string()) - .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; - Ok(()) -} - -fn ensure_standard_postprocess_tables(doc: &mut toml_edit::DocumentMut) { - ensure_root_table(doc, "transcription"); - ensure_root_table(doc, "postprocess"); - ensure_root_table(doc, "session"); - ensure_root_table(doc, "rewrite"); - ensure_root_table(doc, "cloud"); - ensure_nested_table(doc, "cloud", "transcription"); - ensure_nested_table(doc, "cloud", "rewrite"); - ensure_root_table(doc, "personalization"); -} - -fn ensure_root_table(doc: &mut toml_edit::DocumentMut, key: &str) { - let root = doc.as_table_mut(); - let needs_insert = !root.contains_key(key) || !root[key].is_table(); - if needs_insert { - root.insert(key, toml_edit::Item::Table(toml_edit::Table::new())); - } -} - -fn ensure_nested_table(doc: &mut toml_edit::DocumentMut, parent: &str, child: &str) { - ensure_root_table(doc, parent); - let root = doc.as_table_mut(); - let Some(parent_item) = root.get_mut(parent) else { - return; - }; - let Some(parent_table) = 
parent_item.as_table_like_mut() else { - return; - }; - let needs_insert = !parent_table.contains_key(child) - || parent_table - .get(child) - .map(|item| !item.is_table()) - .unwrap_or(true); - if needs_insert { - parent_table.insert(child, toml_edit::Item::Table(toml_edit::Table::new())); - } -} - -fn cleanup_section_present(contents: &str) -> bool { - section_present(contents, "cleanup") -} - -fn section_present(contents: &str, name: &str) -> bool { - toml::from_str::(contents) - .ok() - .and_then(|value| value.get(name).cloned()) - .is_some() -} - -fn looks_like_cloud_api_key(value: &str) -> bool { - let trimmed = value.trim(); - trimmed.starts_with("sk-") -} +mod edit; +mod load; +mod paths; +mod schema; #[cfg(test)] -mod tests { - use super::*; - use crate::error::WhsprError; - - #[test] - fn load_missing_file_uses_defaults() { - let path = crate::test_support::unique_temp_path("config-missing", "toml"); - let config = Config::load(Some(&path)).expect("missing config should load defaults"); - assert_eq!(config.audio.sample_rate, 16000); - assert_eq!(config.transcription.language, "auto"); - assert_eq!( - config.transcription.backend, - TranscriptionBackend::WhisperCpp - ); - assert_eq!(config.postprocess.mode, PostprocessMode::Raw); - assert_eq!(config.personalization.snippet_trigger, "insert"); - assert_eq!(config.rewrite.selected_model, "qwen-3.5-4b-q4_k_m"); - } - - #[test] - fn load_invalid_toml_returns_parse_error() { - let path = crate::test_support::unique_temp_path("config-invalid", "toml"); - std::fs::write(&path, "not = [valid = toml").expect("write invalid config"); - let err = Config::load(Some(&path)).expect_err("invalid config should fail"); - match err { - WhsprError::Config(msg) => { - assert!(msg.contains("failed to parse"), "unexpected message: {msg}"); - } - other => panic!("unexpected error variant: {other:?}"), - } - } - - #[test] - fn expand_tilde_uses_home_when_present() { - let _env_lock = crate::test_support::env_lock(); - let _guard 
= crate::test_support::EnvVarGuard::capture(&["HOME"]); - crate::test_support::set_env("HOME", "/tmp/whispers-home"); - assert_eq!( - expand_tilde("~/models/ggml.bin"), - "/tmp/whispers-home/models/ggml.bin" - ); - assert_eq!(expand_tilde("~"), "/tmp/whispers-home"); - } - - #[test] - fn expand_tilde_without_home_returns_original_path() { - let _env_lock = crate::test_support::env_lock(); - let _guard = crate::test_support::EnvVarGuard::capture(&["HOME"]); - crate::test_support::remove_env("HOME"); - assert_eq!(expand_tilde("~/models/ggml.bin"), "~/models/ggml.bin"); - assert_eq!(expand_tilde("~"), "~"); - } - - #[test] - fn write_default_and_update_model_path_roundtrip() { - let dir = crate::test_support::unique_temp_dir("config-roundtrip"); - let config_path = dir.join("nested").join("config.toml"); - - write_default_config(&config_path, "~/old-model.bin").expect("write config"); - assert!(config_path.exists(), "config file should exist"); - - update_config_transcription_selection( - &config_path, - TranscriptionBackend::WhisperCpp, - "large-v3-turbo", - "~/new-model.bin", - true, - ) - .expect("update config"); - let loaded = Config::load(Some(&config_path)).expect("load config"); - assert_eq!(loaded.transcription.model_path, "~/new-model.bin"); - assert_eq!( - loaded.transcription.backend, - TranscriptionBackend::WhisperCpp - ); - assert_eq!(loaded.audio.sample_rate, 16000); - assert_eq!(loaded.postprocess.mode, PostprocessMode::Raw); - assert_eq!( - loaded.personalization.dictionary_path, - "~/.local/share/whispers/dictionary.toml" - ); - assert!(loaded.session.enabled); - assert_eq!(loaded.session.max_entries, 3); - assert_eq!(loaded.rewrite.timeout_ms, 30000); - assert!(loaded.feedback.enabled); - - let raw = std::fs::read_to_string(&config_path).expect("read config"); - assert!(raw.contains("[audio]")); - assert!(raw.contains("[transcription]")); - assert!(raw.contains("[postprocess]")); - assert!(raw.contains("[session]")); - 
assert!(raw.contains("[rewrite]")); - assert!(!raw.contains("[whisper]")); - } - - #[test] - fn selecting_nemo_model_sets_non_expiring_asr_worker_timeout() { - let config_path = crate::test_support::unique_temp_path("config-nemo-timeout", "toml"); - write_default_config(&config_path, "~/old-model.bin").expect("write config"); - - update_config_transcription_selection( - &config_path, - TranscriptionBackend::Nemo, - "parakeet-tdt_ctc-1.1b", - "~/.local/share/whispers/nemo/models/parakeet-tdt_ctc-1.1b", - true, - ) - .expect("select nemo model"); - - let loaded = Config::load(Some(&config_path)).expect("load config"); - assert_eq!(loaded.transcription.backend, TranscriptionBackend::Nemo); - assert_eq!(loaded.transcription.idle_timeout_ms, 0); - } - - #[test] - fn load_legacy_whisper_section_maps_to_transcription() { - let path = crate::test_support::unique_temp_path("config-whisper-legacy", "toml"); - std::fs::write( - &path, - r#"[whisper] -model_path = "~/legacy-model.bin" -language = "en" -use_gpu = false -flash_attn = false -"#, - ) - .expect("write config"); - - let loaded = Config::load(Some(&path)).expect("load config"); - assert_eq!( - loaded.transcription.backend, - TranscriptionBackend::WhisperCpp - ); - assert_eq!(loaded.transcription.model_path, "~/legacy-model.bin"); - assert_eq!(loaded.transcription.language, "en"); - assert!(!loaded.transcription.use_gpu); - assert!(!loaded.transcription.flash_attn); - } - - #[test] - fn load_legacy_cleanup_section_maps_to_legacy_basic() { - let path = crate::test_support::unique_temp_path("config-cleanup", "toml"); - std::fs::write( - &path, - r#"[cleanup] -profile = "aggressive" -spoken_formatting = false -remove_fillers = false -"#, - ) - .expect("write config"); - - let config = Config::load(Some(&path)).expect("load config"); - assert_eq!(config.postprocess.mode, PostprocessMode::LegacyBasic); - assert_eq!(config.cleanup.profile, CleanupProfile::Aggressive); - assert!(!config.cleanup.spoken_formatting); - 
assert!(config.cleanup.backtrack); - assert!(!config.cleanup.remove_fillers); - } - - #[test] - fn update_rewrite_selection_enables_advanced_mode() { - let dir = crate::test_support::unique_temp_dir("config-rewrite-select"); - let config_path = dir.join("config.toml"); - write_default_config(&config_path, "~/model.bin").expect("write config"); - - update_config_rewrite_selection(&config_path, "qwen-3.5-2b-q4_k_m") - .expect("select rewrite model"); - - let loaded = Config::load(Some(&config_path)).expect("load config"); - assert_eq!(loaded.postprocess.mode, PostprocessMode::AdvancedLocal); - assert_eq!(loaded.rewrite.selected_model, "qwen-3.5-2b-q4_k_m"); - assert!(loaded.rewrite.model_path.is_empty()); - assert_eq!( - loaded.rewrite.instructions_path, - "~/.local/share/whispers/rewrite-instructions.txt" - ); - assert_eq!(loaded.rewrite.profile, RewriteProfile::Auto); - assert_eq!(loaded.rewrite.timeout_ms, 30000); - assert_eq!(loaded.rewrite.idle_timeout_ms, 120000); - } - - #[test] - fn update_helpers_upgrade_legacy_configs_without_panicking() { - let config_path = crate::test_support::unique_temp_path("config-legacy-upgrade", "toml"); - std::fs::write( - &config_path, - r#"[audio] -sample_rate = 16000 - -[whisper] -model_path = "~/.local/share/whispers/ggml-large-v3-turbo.bin" -language = "auto" -"#, - ) - .expect("write legacy config"); - - update_config_transcription_selection( - &config_path, - TranscriptionBackend::WhisperCpp, - "large-v3-turbo", - "~/.local/share/whispers/ggml-large-v3-turbo.bin", - true, - ) - .expect("update transcription selection"); - update_config_rewrite_selection(&config_path, "qwen-3.5-4b-q4_k_m") - .expect("update rewrite selection"); - - let loaded = Config::load(Some(&config_path)).expect("load upgraded config"); - assert_eq!( - loaded.transcription.backend, - TranscriptionBackend::WhisperCpp - ); - assert_eq!(loaded.transcription.selected_model, "large-v3-turbo"); - assert_eq!(loaded.postprocess.mode, 
PostprocessMode::AdvancedLocal); - assert_eq!(loaded.rewrite.selected_model, "qwen-3.5-4b-q4_k_m"); - - let raw = std::fs::read_to_string(&config_path).expect("read upgraded config"); - assert!(!raw.contains("[whisper]")); - } - - #[test] - fn load_cloud_literal_key_from_legacy_api_key_env() { - let path = crate::test_support::unique_temp_path("config-cloud-literal-key", "toml"); - std::fs::write( - &path, - r#"[cloud] -api_key_env = "sk-test-inline" -"#, - ) - .expect("write config"); - - let loaded = Config::load(Some(&path)).expect("load config"); - assert_eq!(loaded.cloud.api_key, "sk-test-inline"); - assert_eq!(loaded.cloud.api_key_env, "OPENAI_API_KEY"); - } -} +mod tests; + +pub use edit::{ + update_config_cloud_settings, update_config_postprocess_mode, update_config_rewrite_runtime, + update_config_rewrite_selection, update_config_transcription_runtime, + update_config_transcription_selection, write_default_config, +}; +pub use paths::{data_dir, default_config_path, expand_tilde, resolve_config_path}; +pub use schema::{ + AgenticRewriteConfig, AudioConfig, CleanupConfig, CleanupProfile, CloudConfig, + CloudLanguageMode, CloudProvider, CloudRewriteConfig, CloudSettingsUpdate, + CloudTranscriptionConfig, Config, FeedbackConfig, InjectConfig, PersonalizationConfig, + PostprocessConfig, PostprocessMode, RewriteBackend, RewriteConfig, RewriteFallback, + SessionConfig, TranscriptionBackend, TranscriptionConfig, TranscriptionFallback, +}; diff --git a/src/config/edit.rs b/src/config/edit.rs new file mode 100644 index 0000000..5ac4f03 --- /dev/null +++ b/src/config/edit.rs @@ -0,0 +1,380 @@ +use std::path::Path; + +use crate::error::{Result, WhsprError}; +use crate::rewrite_profile::RewriteProfile; +use crate::rewrite_protocol::RewriteCorrectionPolicy; + +use super::{ + CloudSettingsUpdate, PostprocessMode, RewriteBackend, RewriteFallback, TranscriptionBackend, + TranscriptionFallback, +}; + +pub(crate) fn default_config_template(model_path: &str) -> String { + 
format!( + r#"# whispers configuration +# +# Keybinding is handled by your compositor. Example for Hyprland: +# bind = SUPER ALT, D, exec, whispers +# +# First invocation starts recording, second invocation stops + transcribes + pastes. + +[audio] +# Input device name (empty = system default) +device = "" +# Sample rate in Hz (ASR requires 16000) +sample_rate = 16000 + +[transcription] +# Active transcription backend ("whisper_cpp", "faster_whisper", "nemo", or "cloud") +backend = "whisper_cpp" +# Cloud fallback behavior ("configured_local" or "none") +fallback = "configured_local" +# Local backend used directly in local mode and as the cloud fallback backend +local_backend = "whisper_cpp" +# Managed ASR model name for the selected backend +selected_model = "large-v3-turbo" +# Path to the local backend-specific model or empty to use the selected managed model +# Manage models with: whispers asr-model list / download / select +model_path = "{model_path}" +# Language code ("en", "fr", "de", etc.) 
or "auto" for auto-detect +language = "auto" +# Enable GPU acceleration (set false to force CPU) +use_gpu = true +# Enable flash attention when GPU is enabled +flash_attn = true +# How long the hidden ASR worker stays warm without requests (0 = never expire) +idle_timeout_ms = 120000 + +[postprocess] +# "raw" (default), "advanced_local", "agentic_rewrite", or "legacy_basic" for deprecated cleanup configs +mode = "raw" + +[session] +# Enable short-lived session backtracking in rewrite modes +enabled = true +# How many recent dictation entries to keep in the runtime session ledger +max_entries = 3 +# How long a recent dictation entry stays eligible for revision +max_age_ms = 8000 +# Maximum graphemes that may be deleted when revising the latest entry +max_replace_graphemes = 400 + +[personalization] +# Dictionary replacements applied in all modes +dictionary_path = "~/.local/share/whispers/dictionary.toml" +# Snippets expanded via an explicit spoken trigger +snippets_path = "~/.local/share/whispers/snippets.toml" +# Spoken trigger phrase used before snippet names +snippet_trigger = "insert" + +[rewrite] +# Rewrite backend ("local" or "cloud") +backend = "local" +# Cloud fallback behavior ("local" or "none") +fallback = "local" +# Managed rewrite model name for advanced_local mode +selected_model = "qwen-3.5-4b-q4_k_m" +# Manual GGUF path override (empty = use selected managed model) +# Custom rewrite models should be chat-capable GGUFs with an embedded +# chat template that llama.cpp can apply at runtime. 
+model_path = "" +# Append-only custom rewrite instructions file (empty = disabled) +instructions_path = "~/.local/share/whispers/rewrite-instructions.txt" +# Rewrite profile selection ("auto", "qwen", "generic", or "llama_compat") +profile = "auto" +# Timeout for local rewrite inference in milliseconds +timeout_ms = 30000 +# How long the hidden rewrite worker stays warm without requests +idle_timeout_ms = 120000 +# Maximum characters accepted from the rewrite model +max_output_chars = 1200 +# Maximum tokens to generate for rewritten output +max_tokens = 256 + +[agentic_rewrite] +# App-aware rewrite policy rules used by postprocess.mode = "agentic_rewrite" +policy_path = "~/.local/share/whispers/app-rewrite-policy.toml" +# Technical glossary used by postprocess.mode = "agentic_rewrite" +glossary_path = "~/.local/share/whispers/technical-glossary.toml" +# Default correction policy ("conservative", "balanced", or "aggressive") +default_correction_policy = "balanced" + +[cloud] +# Cloud provider ("openai" or "openai_compatible") +provider = "openai" +# Custom base URL for openai_compatible providers (empty uses the OpenAI default) +base_url = "" +# Optional API key stored directly in the config (empty = use api_key_env instead) +api_key = "" +# Environment variable holding the API key +api_key_env = "OPENAI_API_KEY" +# Network connect timeout in milliseconds +connect_timeout_ms = 3000 +# End-to-end request timeout in milliseconds +request_timeout_ms = 15000 + +[cloud.transcription] +# Cloud transcription model +model = "gpt-4o-mini-transcribe" +# "inherit_local" uses [transcription].language when it is not "auto"; "force" uses the value below +language_mode = "inherit_local" +# Language code used when language_mode = "force" +language = "" + +[cloud.rewrite] +# Cloud rewrite model +model = "gpt-4.1-mini" +# Sampling temperature for cloud rewrite +temperature = 0.1 +# Maximum tokens requested from the cloud rewrite model +max_output_tokens = 256 + +[feedback] +# Play 
sound feedback on start/stop +enabled = true +# Custom sound file paths (empty = use bundled sounds) +start_sound = "" +stop_sound = "" +"# + ) +} + +pub fn write_default_config(path: &Path, model_path: &str) -> Result<()> { + let contents = default_config_template(model_path); + + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent) + .map_err(|e| WhsprError::Config(format!("failed to create config directory: {e}")))?; + } + + std::fs::write(path, contents) + .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; + + Ok(()) +} + +pub fn update_config_transcription_selection( + config_path: &Path, + backend: TranscriptionBackend, + selected_model: &str, + model_path: &str, + set_active_backend: bool, +) -> Result<()> { + let contents = std::fs::read_to_string(config_path) + .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; + + let mut doc = contents + .parse::() + .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; + + ensure_standard_postprocess_tables(&mut doc); + doc.as_table_mut().remove("whisper"); + if set_active_backend { + doc["transcription"]["backend"] = toml_edit::value(backend.as_str()); + } + doc["transcription"]["local_backend"] = toml_edit::value(backend.as_str()); + doc["transcription"]["fallback"] = + toml_edit::value(TranscriptionFallback::ConfiguredLocal.as_str()); + doc["transcription"]["selected_model"] = toml_edit::value(selected_model); + doc["transcription"]["model_path"] = toml_edit::value(model_path); + let idle_timeout_ms = if backend == TranscriptionBackend::Nemo { + 0 + } else { + 120000 + }; + doc["transcription"]["idle_timeout_ms"] = toml_edit::value(idle_timeout_ms); + + std::fs::write(config_path, doc.to_string()) + .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; + + Ok(()) +} + +pub fn update_config_postprocess_mode(config_path: &Path, mode: PostprocessMode) -> Result<()> { + let contents = 
std::fs::read_to_string(config_path) + .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; + + let mut doc = contents + .parse::() + .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; + + ensure_standard_postprocess_tables(&mut doc); + doc["postprocess"]["mode"] = toml_edit::value(mode.as_str()); + + std::fs::write(config_path, doc.to_string()) + .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; + + Ok(()) +} + +pub fn update_config_rewrite_selection(config_path: &Path, selected_model: &str) -> Result<()> { + let contents = std::fs::read_to_string(config_path) + .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; + + let mut doc = contents + .parse::() + .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; + + ensure_standard_postprocess_tables(&mut doc); + let mode = match doc["postprocess"] + .as_table_like() + .and_then(|table| table.get("mode")) + .and_then(|item| item.as_str()) + { + Some("agentic_rewrite") => PostprocessMode::AgenticRewrite, + _ => PostprocessMode::AdvancedLocal, + }; + doc["postprocess"]["mode"] = toml_edit::value(mode.as_str()); + let rewrite_backend = doc["rewrite"] + .as_table_like() + .and_then(|table| table.get("backend")) + .and_then(|item| item.as_str()); + if !matches!(rewrite_backend, Some("cloud")) { + doc["rewrite"]["backend"] = toml_edit::value(RewriteBackend::Local.as_str()); + } + doc["rewrite"]["fallback"] = toml_edit::value(RewriteFallback::Local.as_str()); + doc["rewrite"]["selected_model"] = toml_edit::value(selected_model); + doc["rewrite"]["model_path"] = toml_edit::value(""); + doc["rewrite"]["instructions_path"] = + toml_edit::value("~/.local/share/whispers/rewrite-instructions.txt"); + doc["rewrite"]["profile"] = toml_edit::value(RewriteProfile::Auto.as_str()); + doc["rewrite"]["timeout_ms"] = toml_edit::value(30000); + doc["rewrite"]["idle_timeout_ms"] = toml_edit::value(120000); + 
doc["rewrite"]["max_output_chars"] = toml_edit::value(1200); + doc["rewrite"]["max_tokens"] = toml_edit::value(256); + doc["agentic_rewrite"]["policy_path"] = + toml_edit::value(crate::agentic_rewrite::default_policy_path()); + doc["agentic_rewrite"]["glossary_path"] = + toml_edit::value(crate::agentic_rewrite::default_glossary_path()); + doc["agentic_rewrite"]["default_correction_policy"] = + toml_edit::value(RewriteCorrectionPolicy::Balanced.as_str()); + + std::fs::write(config_path, doc.to_string()) + .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; + + Ok(()) +} + +pub fn update_config_transcription_runtime( + config_path: &Path, + backend: TranscriptionBackend, + fallback: TranscriptionFallback, +) -> Result<()> { + let contents = std::fs::read_to_string(config_path) + .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; + let mut doc = contents + .parse::() + .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; + + ensure_standard_postprocess_tables(&mut doc); + doc["transcription"]["backend"] = toml_edit::value(backend.as_str()); + doc["transcription"]["fallback"] = toml_edit::value(fallback.as_str()); + + std::fs::write(config_path, doc.to_string()) + .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; + Ok(()) +} + +pub fn update_config_rewrite_runtime( + config_path: &Path, + backend: RewriteBackend, + fallback: RewriteFallback, +) -> Result<()> { + let contents = std::fs::read_to_string(config_path) + .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; + let mut doc = contents + .parse::() + .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; + + ensure_standard_postprocess_tables(&mut doc); + normalize_postprocess_mode(&mut doc); + doc["rewrite"]["backend"] = toml_edit::value(backend.as_str()); + doc["rewrite"]["fallback"] = toml_edit::value(fallback.as_str()); + + std::fs::write(config_path, doc.to_string()) + 
.map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; + Ok(()) +} + +pub fn update_config_cloud_settings( + config_path: &Path, + settings: &CloudSettingsUpdate<'_>, +) -> Result<()> { + let contents = std::fs::read_to_string(config_path) + .map_err(|e| WhsprError::Config(format!("failed to read config: {e}")))?; + let mut doc = contents + .parse::() + .map_err(|e| WhsprError::Config(format!("failed to parse config: {e}")))?; + + ensure_standard_postprocess_tables(&mut doc); + ensure_root_table(&mut doc, "cloud"); + ensure_nested_table(&mut doc, "cloud", "transcription"); + ensure_nested_table(&mut doc, "cloud", "rewrite"); + doc["cloud"]["provider"] = toml_edit::value(settings.provider.as_str()); + doc["cloud"]["base_url"] = toml_edit::value(settings.base_url); + doc["cloud"]["api_key"] = toml_edit::value(settings.api_key); + doc["cloud"]["api_key_env"] = toml_edit::value(settings.api_key_env); + doc["cloud"]["connect_timeout_ms"] = toml_edit::value(settings.connect_timeout_ms as i64); + doc["cloud"]["request_timeout_ms"] = toml_edit::value(settings.request_timeout_ms as i64); + doc["cloud"]["transcription"]["model"] = toml_edit::value(settings.transcription_model); + doc["cloud"]["transcription"]["language_mode"] = + toml_edit::value(settings.transcription_language_mode.as_str()); + doc["cloud"]["transcription"]["language"] = toml_edit::value(settings.transcription_language); + doc["cloud"]["rewrite"]["model"] = toml_edit::value(settings.rewrite_model); + doc["cloud"]["rewrite"]["temperature"] = toml_edit::value(settings.rewrite_temperature as f64); + doc["cloud"]["rewrite"]["max_output_tokens"] = + toml_edit::value(settings.rewrite_max_output_tokens as i64); + + std::fs::write(config_path, doc.to_string()) + .map_err(|e| WhsprError::Config(format!("failed to write config: {e}")))?; + Ok(()) +} + +fn ensure_standard_postprocess_tables(doc: &mut toml_edit::DocumentMut) { + ensure_root_table(doc, "transcription"); + ensure_root_table(doc, 
"postprocess"); + ensure_root_table(doc, "session"); + ensure_root_table(doc, "rewrite"); + ensure_root_table(doc, "agentic_rewrite"); + ensure_root_table(doc, "cloud"); + ensure_nested_table(doc, "cloud", "transcription"); + ensure_nested_table(doc, "cloud", "rewrite"); + ensure_root_table(doc, "personalization"); +} + +fn normalize_postprocess_mode(doc: &mut toml_edit::DocumentMut) { + let current = doc["postprocess"]["mode"].as_str().unwrap_or_default(); + if !matches!( + current, + "raw" | "advanced_local" | "agentic_rewrite" | "legacy_basic" + ) { + doc["postprocess"]["mode"] = toml_edit::value(PostprocessMode::AdvancedLocal.as_str()); + } +} + +fn ensure_root_table(doc: &mut toml_edit::DocumentMut, key: &str) { + let root = doc.as_table_mut(); + let needs_insert = !root.contains_key(key) || !root[key].is_table(); + if needs_insert { + root.insert(key, toml_edit::Item::Table(toml_edit::Table::new())); + } +} + +fn ensure_nested_table(doc: &mut toml_edit::DocumentMut, parent: &str, child: &str) { + ensure_root_table(doc, parent); + let root = doc.as_table_mut(); + let Some(parent_item) = root.get_mut(parent) else { + return; + }; + let Some(parent_table) = parent_item.as_table_like_mut() else { + return; + }; + let needs_insert = !parent_table.contains_key(child) + || parent_table + .get(child) + .map(|item| !item.is_table()) + .unwrap_or(true); + if needs_insert { + parent_table.insert(child, toml_edit::Item::Table(toml_edit::Table::new())); + } +} diff --git a/src/config/load.rs b/src/config/load.rs new file mode 100644 index 0000000..abc5b79 --- /dev/null +++ b/src/config/load.rs @@ -0,0 +1,111 @@ +use std::path::Path; + +use crate::error::{Result, WhsprError}; + +use super::{Config, PostprocessMode, TranscriptionBackend, resolve_config_path}; + +impl Config { + pub fn load(path: Option<&Path>) -> Result { + let config_path = resolve_config_path(path); + + if !config_path.exists() { + tracing::info!( + "no config file found at {}, using defaults", + 
config_path.display() + ); + return Ok(Config::default()); + } + + let contents = std::fs::read_to_string(&config_path).map_err(|e| { + WhsprError::Config(format!("failed to read {}: {e}", config_path.display())) + })?; + + let mut config: Config = toml::from_str(&contents).map_err(|e| { + WhsprError::Config(format!("failed to parse {}: {e}", config_path.display())) + })?; + + config.apply_legacy_transcription_migration(&contents, &config_path); + config.apply_legacy_cleanup_migration(&contents, &config_path); + config.apply_cloud_sanitization(); + Ok(config) + } + + fn apply_legacy_transcription_migration(&mut self, contents: &str, config_path: &Path) { + let transcription_present = section_present(contents, "transcription"); + let whisper_present = section_present(contents, "whisper"); + + if !transcription_present && whisper_present { + tracing::warn!( + "config {} uses deprecated [whisper]; mapping to [transcription]", + config_path.display() + ); + self.transcription.backend = TranscriptionBackend::WhisperCpp; + self.transcription.model_path = self.legacy_whisper.model_path.clone(); + self.transcription.language = self.legacy_whisper.language.clone(); + self.transcription.use_gpu = self.legacy_whisper.use_gpu; + self.transcription.flash_attn = self.legacy_whisper.flash_attn; + } else if whisper_present { + tracing::warn!( + "config {} contains deprecated [whisper]; [transcription] takes precedence", + config_path.display() + ); + } + } + + fn apply_legacy_cleanup_migration(&mut self, contents: &str, config_path: &Path) { + let cleanup_present = cleanup_section_present(contents); + + if cleanup_present && self.postprocess.mode == PostprocessMode::Raw { + if self.cleanup.enabled { + tracing::warn!( + "config {} uses deprecated [cleanup]; mapping to postprocess.mode = \"legacy_basic\"", + config_path.display() + ); + self.postprocess.mode = PostprocessMode::LegacyBasic; + } else { + tracing::warn!( + "config {} disables deprecated [cleanup]; keeping 
postprocess.mode = \"raw\"", + config_path.display() + ); + } + } else if cleanup_present && self.postprocess.mode != PostprocessMode::LegacyBasic { + tracing::warn!( + "config {} contains deprecated [cleanup]; [postprocess] takes precedence", + config_path.display() + ); + } + } + + fn apply_cloud_sanitization(&mut self) { + if self.transcription.local_backend == TranscriptionBackend::Cloud { + tracing::warn!( + "transcription.local_backend cannot be cloud; falling back to whisper_cpp" + ); + self.transcription.local_backend = TranscriptionBackend::WhisperCpp; + } + if self.cloud.api_key.trim().is_empty() && looks_like_cloud_api_key(&self.cloud.api_key_env) + { + tracing::warn!( + "cloud.api_key_env looks like a literal API key; treating it as cloud.api_key" + ); + self.cloud.api_key = self.cloud.api_key_env.trim().to_string(); + self.cloud.api_key_env = "OPENAI_API_KEY".into(); + } + } +} + +fn cleanup_section_present(contents: &str) -> bool { + section_present(contents, "cleanup") +} + +fn section_present(contents: &str, name: &str) -> bool { + toml::from_str::(contents) + .ok() + .and_then(|value| value.get(name).cloned()) + .is_some() +} + +fn looks_like_cloud_api_key(value: &str) -> bool { + let trimmed = value.trim(); + trimmed.starts_with("sk-") +} diff --git a/src/config/paths.rs b/src/config/paths.rs new file mode 100644 index 0000000..0b07ec4 --- /dev/null +++ b/src/config/paths.rs @@ -0,0 +1,98 @@ +use std::path::{Path, PathBuf}; + +use super::Config; + +pub fn default_config_path() -> PathBuf { + xdg_dir("config").join("whispers").join("config.toml") +} + +pub fn resolve_config_path(path: Option<&Path>) -> PathBuf { + match path { + Some(path) => path.to_path_buf(), + None => default_config_path(), + } +} + +pub fn data_dir() -> PathBuf { + xdg_dir("data").join("whispers") +} + +fn xdg_dir(kind: &str) -> PathBuf { + match kind { + "config" => { + if let Ok(dir) = std::env::var("XDG_CONFIG_HOME") { + PathBuf::from(dir) + } else if let Ok(home) = 
std::env::var("HOME") { + PathBuf::from(home).join(".config") + } else { + tracing::warn!("neither XDG_CONFIG_HOME nor HOME is set, falling back to /tmp"); + PathBuf::from("/tmp") + } + } + "data" => { + if let Ok(dir) = std::env::var("XDG_DATA_HOME") { + PathBuf::from(dir) + } else if let Ok(home) = std::env::var("HOME") { + PathBuf::from(home).join(".local").join("share") + } else { + tracing::warn!("neither XDG_DATA_HOME nor HOME is set, falling back to /tmp"); + PathBuf::from("/tmp") + } + } + _ => { + tracing::warn!("unknown XDG directory kind '{kind}', falling back to /tmp"); + PathBuf::from("/tmp") + } + } +} + +pub fn expand_tilde(path: &str) -> String { + match path.strip_prefix("~/") { + Some(rest) => { + if let Ok(home) = std::env::var("HOME") { + return format!("{home}/{rest}"); + } + tracing::warn!("HOME is not set, cannot expand tilde in path: {path}"); + } + None if path == "~" => { + if let Ok(home) = std::env::var("HOME") { + return home; + } + tracing::warn!("HOME is not set, cannot expand tilde in path: {path}"); + } + _ => {} + } + path.to_string() +} + +impl Config { + pub fn resolved_model_path(&self) -> PathBuf { + PathBuf::from(expand_tilde(&self.transcription.model_path)) + } + + pub fn resolved_rewrite_model_path(&self) -> Option { + (!self.rewrite.model_path.trim().is_empty()) + .then(|| PathBuf::from(expand_tilde(&self.rewrite.model_path))) + } + + pub fn resolved_rewrite_instructions_path(&self) -> Option { + (!self.rewrite.instructions_path.trim().is_empty()) + .then(|| PathBuf::from(expand_tilde(&self.rewrite.instructions_path))) + } + + pub fn resolved_dictionary_path(&self) -> PathBuf { + PathBuf::from(expand_tilde(&self.personalization.dictionary_path)) + } + + pub fn resolved_snippets_path(&self) -> PathBuf { + PathBuf::from(expand_tilde(&self.personalization.snippets_path)) + } + + pub fn resolved_agentic_policy_path(&self) -> PathBuf { + PathBuf::from(expand_tilde(&self.agentic_rewrite.policy_path)) + } + + pub fn 
resolved_agentic_glossary_path(&self) -> PathBuf { + PathBuf::from(expand_tilde(&self.agentic_rewrite.glossary_path)) + } +} diff --git a/src/config/schema.rs b/src/config/schema.rs new file mode 100644 index 0000000..9ddff17 --- /dev/null +++ b/src/config/schema.rs @@ -0,0 +1,467 @@ +use serde::Deserialize; + +use crate::rewrite_profile::RewriteProfile; +use crate::rewrite_protocol::RewriteCorrectionPolicy; + +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default)] +pub struct Config { + pub audio: AudioConfig, + pub transcription: TranscriptionConfig, + #[serde(default, rename = "whisper")] + pub(crate) legacy_whisper: LegacyWhisperConfig, + pub postprocess: PostprocessConfig, + pub session: SessionConfig, + pub personalization: PersonalizationConfig, + pub rewrite: RewriteConfig, + pub agentic_rewrite: AgenticRewriteConfig, + pub cloud: CloudConfig, + pub cleanup: CleanupConfig, + pub inject: InjectConfig, + pub feedback: FeedbackConfig, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct AudioConfig { + pub device: String, + pub sample_rate: u32, +} + +#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum TranscriptionBackend { + #[default] + WhisperCpp, + FasterWhisper, + Nemo, + Cloud, +} + +#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum TranscriptionFallback { + None, + #[default] + ConfiguredLocal, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct TranscriptionConfig { + pub backend: TranscriptionBackend, + pub fallback: TranscriptionFallback, + pub local_backend: TranscriptionBackend, + pub selected_model: String, + pub model_path: String, + pub language: String, + pub use_gpu: bool, + pub flash_attn: bool, + pub idle_timeout_ms: u64, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub(crate) struct LegacyWhisperConfig { + pub(crate) model_path: String, + 
pub(crate) language: String, + pub(crate) use_gpu: bool, + pub(crate) flash_attn: bool, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[serde(default)] +pub struct PostprocessConfig { + pub mode: PostprocessMode, +} + +#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum PostprocessMode { + #[default] + Raw, + AdvancedLocal, + AgenticRewrite, + LegacyBasic, +} + +#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteBackend { + #[default] + Local, + Cloud, +} + +#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RewriteFallback { + None, + #[default] + Local, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct RewriteConfig { + pub backend: RewriteBackend, + pub fallback: RewriteFallback, + pub selected_model: String, + pub model_path: String, + pub instructions_path: String, + pub profile: RewriteProfile, + pub timeout_ms: u64, + pub idle_timeout_ms: u64, + pub max_output_chars: usize, + pub max_tokens: usize, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[serde(default)] +pub struct AgenticRewriteConfig { + pub policy_path: String, + pub glossary_path: String, + pub default_correction_policy: RewriteCorrectionPolicy, +} + +#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum CloudProvider { + #[default] + #[serde(rename = "openai")] + OpenAi, + #[serde(rename = "openai_compatible")] + OpenAiCompatible, +} + +#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum CloudLanguageMode { + #[default] + InheritLocal, + Force, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct CloudConfig { + pub provider: CloudProvider, + pub base_url: String, + pub api_key: String, + pub api_key_env: String, + 
pub connect_timeout_ms: u64, + pub request_timeout_ms: u64, + pub transcription: CloudTranscriptionConfig, + pub rewrite: CloudRewriteConfig, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct CloudTranscriptionConfig { + pub model: String, + pub language_mode: CloudLanguageMode, + pub language: String, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct CloudRewriteConfig { + pub model: String, + pub temperature: f32, + pub max_output_tokens: usize, +} + +#[derive(Debug, Clone)] +pub struct CloudSettingsUpdate<'a> { + pub provider: CloudProvider, + pub base_url: &'a str, + pub api_key: &'a str, + pub api_key_env: &'a str, + pub connect_timeout_ms: u64, + pub request_timeout_ms: u64, + pub transcription_model: &'a str, + pub transcription_language_mode: CloudLanguageMode, + pub transcription_language: &'a str, + pub rewrite_model: &'a str, + pub rewrite_temperature: f32, + pub rewrite_max_output_tokens: usize, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[serde(default)] +pub struct SessionConfig { + pub enabled: bool, + pub max_entries: usize, + pub max_age_ms: u64, + pub max_replace_graphemes: usize, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[serde(default)] +pub struct PersonalizationConfig { + pub dictionary_path: String, + pub snippets_path: String, + pub snippet_trigger: String, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[serde(default)] +pub struct CleanupConfig { + pub enabled: bool, + pub profile: CleanupProfile, + pub spoken_formatting: bool, + pub backtrack: bool, + pub remove_fillers: bool, +} + +#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum CleanupProfile { + #[default] + Basic, + Aggressive, +} + +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default)] +pub struct InjectConfig {} + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct FeedbackConfig { + pub enabled: 
bool, + pub start_sound: String, + pub stop_sound: String, +} + +impl Default for AudioConfig { + fn default() -> Self { + Self { + device: String::new(), + sample_rate: 16000, + } + } +} + +impl Default for TranscriptionConfig { + fn default() -> Self { + Self { + backend: TranscriptionBackend::WhisperCpp, + fallback: TranscriptionFallback::ConfiguredLocal, + local_backend: TranscriptionBackend::WhisperCpp, + selected_model: "large-v3-turbo".into(), + model_path: "~/.local/share/whispers/ggml-large-v3-turbo.bin".into(), + language: "auto".into(), + use_gpu: true, + flash_attn: true, + idle_timeout_ms: 120000, + } + } +} + +impl Default for LegacyWhisperConfig { + fn default() -> Self { + let default = TranscriptionConfig::default(); + Self { + model_path: default.model_path, + language: default.language, + use_gpu: default.use_gpu, + flash_attn: default.flash_attn, + } + } +} + +impl Default for PostprocessConfig { + fn default() -> Self { + Self { + mode: PostprocessMode::Raw, + } + } +} + +impl PostprocessMode { + pub fn as_str(self) -> &'static str { + match self { + Self::Raw => "raw", + Self::AdvancedLocal => "advanced_local", + Self::AgenticRewrite => "agentic_rewrite", + Self::LegacyBasic => "legacy_basic", + } + } + + pub fn uses_rewrite(self) -> bool { + matches!(self, Self::AdvancedLocal | Self::AgenticRewrite) + } +} + +impl TranscriptionBackend { + pub fn as_str(self) -> &'static str { + match self { + Self::WhisperCpp => "whisper_cpp", + Self::FasterWhisper => "faster_whisper", + Self::Nemo => "nemo", + Self::Cloud => "cloud", + } + } +} + +impl TranscriptionFallback { + pub fn as_str(self) -> &'static str { + match self { + Self::None => "none", + Self::ConfiguredLocal => "configured_local", + } + } +} + +impl RewriteBackend { + pub fn as_str(self) -> &'static str { + match self { + Self::Local => "local", + Self::Cloud => "cloud", + } + } +} + +impl RewriteFallback { + pub fn as_str(self) -> &'static str { + match self { + Self::None => "none", + 
Self::Local => "local", + } + } +} + +impl CloudProvider { + pub fn as_str(self) -> &'static str { + match self { + Self::OpenAi => "openai", + Self::OpenAiCompatible => "openai_compatible", + } + } +} + +impl CloudLanguageMode { + pub fn as_str(self) -> &'static str { + match self { + Self::InheritLocal => "inherit_local", + Self::Force => "force", + } + } +} + +impl Default for RewriteConfig { + fn default() -> Self { + Self { + backend: RewriteBackend::Local, + fallback: RewriteFallback::Local, + selected_model: "qwen-3.5-4b-q4_k_m".into(), + model_path: String::new(), + instructions_path: "~/.local/share/whispers/rewrite-instructions.txt".into(), + profile: RewriteProfile::Auto, + timeout_ms: 30000, + idle_timeout_ms: 120000, + max_output_chars: 1200, + max_tokens: 256, + } + } +} + +impl Default for AgenticRewriteConfig { + fn default() -> Self { + Self { + policy_path: crate::agentic_rewrite::default_policy_path().into(), + glossary_path: crate::agentic_rewrite::default_glossary_path().into(), + default_correction_policy: RewriteCorrectionPolicy::Balanced, + } + } +} + +impl Default for CloudConfig { + fn default() -> Self { + Self { + provider: CloudProvider::OpenAi, + base_url: String::new(), + api_key: String::new(), + api_key_env: "OPENAI_API_KEY".into(), + connect_timeout_ms: 3000, + request_timeout_ms: 15000, + transcription: CloudTranscriptionConfig::default(), + rewrite: CloudRewriteConfig::default(), + } + } +} + +impl Default for CloudTranscriptionConfig { + fn default() -> Self { + Self { + model: "gpt-4o-mini-transcribe".into(), + language_mode: CloudLanguageMode::InheritLocal, + language: String::new(), + } + } +} + +impl Default for CloudRewriteConfig { + fn default() -> Self { + Self { + model: "gpt-4.1-mini".into(), + temperature: 0.1, + max_output_tokens: 256, + } + } +} + +impl Default for PersonalizationConfig { + fn default() -> Self { + Self { + dictionary_path: "~/.local/share/whispers/dictionary.toml".into(), + snippets_path: 
"~/.local/share/whispers/snippets.toml".into(), + snippet_trigger: "insert".into(), + } + } +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + enabled: true, + max_entries: 3, + max_age_ms: 8000, + max_replace_graphemes: 400, + } + } +} + +impl Default for CleanupConfig { + fn default() -> Self { + Self { + enabled: true, + profile: CleanupProfile::Basic, + spoken_formatting: true, + backtrack: true, + remove_fillers: true, + } + } +} + +impl Default for FeedbackConfig { + fn default() -> Self { + Self { + enabled: true, + start_sound: String::new(), + stop_sound: String::new(), + } + } +} + +impl TranscriptionConfig { + pub fn resolved_local_backend(&self) -> TranscriptionBackend { + match self.local_backend { + TranscriptionBackend::WhisperCpp + | TranscriptionBackend::FasterWhisper + | TranscriptionBackend::Nemo => self.local_backend, + TranscriptionBackend::Cloud => TranscriptionBackend::WhisperCpp, + } + } +} diff --git a/src/config/tests.rs b/src/config/tests.rs new file mode 100644 index 0000000..0b2a26f --- /dev/null +++ b/src/config/tests.rs @@ -0,0 +1,248 @@ +use crate::error::WhsprError; + +use super::*; + +#[test] +fn load_missing_file_uses_defaults() { + let path = crate::test_support::unique_temp_path("config-missing", "toml"); + let config = Config::load(Some(&path)).expect("missing config should load defaults"); + assert_eq!(config.audio.sample_rate, 16000); + assert_eq!(config.transcription.language, "auto"); + assert_eq!( + config.transcription.backend, + TranscriptionBackend::WhisperCpp + ); + assert_eq!(config.postprocess.mode, PostprocessMode::Raw); + assert_eq!(config.personalization.snippet_trigger, "insert"); + assert_eq!(config.rewrite.selected_model, "qwen-3.5-4b-q4_k_m"); +} + +#[test] +fn load_invalid_toml_returns_parse_error() { + let path = crate::test_support::unique_temp_path("config-invalid", "toml"); + std::fs::write(&path, "not = [valid = toml").expect("write invalid config"); + let err = 
Config::load(Some(&path)).expect_err("invalid config should fail"); + match err { + WhsprError::Config(msg) => { + assert!(msg.contains("failed to parse"), "unexpected message: {msg}"); + } + other => panic!("unexpected error variant: {other:?}"), + } +} + +#[test] +fn expand_tilde_uses_home_when_present() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&["HOME"]); + crate::test_support::set_env("HOME", "/tmp/whispers-home"); + assert_eq!( + expand_tilde("~/models/ggml.bin"), + "/tmp/whispers-home/models/ggml.bin" + ); + assert_eq!(expand_tilde("~"), "/tmp/whispers-home"); +} + +#[test] +fn expand_tilde_without_home_returns_original_path() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&["HOME"]); + crate::test_support::remove_env("HOME"); + assert_eq!(expand_tilde("~/models/ggml.bin"), "~/models/ggml.bin"); + assert_eq!(expand_tilde("~"), "~"); +} + +#[test] +fn write_default_and_update_model_path_roundtrip() { + let dir = crate::test_support::unique_temp_dir("config-roundtrip"); + let config_path = dir.join("nested").join("config.toml"); + + write_default_config(&config_path, "~/old-model.bin").expect("write config"); + assert!(config_path.exists(), "config file should exist"); + + update_config_transcription_selection( + &config_path, + TranscriptionBackend::WhisperCpp, + "large-v3-turbo", + "~/new-model.bin", + true, + ) + .expect("update config"); + let loaded = Config::load(Some(&config_path)).expect("load config"); + assert_eq!(loaded.transcription.model_path, "~/new-model.bin"); + assert_eq!( + loaded.transcription.backend, + TranscriptionBackend::WhisperCpp + ); + assert_eq!(loaded.audio.sample_rate, 16000); + assert_eq!(loaded.postprocess.mode, PostprocessMode::Raw); + assert_eq!( + loaded.personalization.dictionary_path, + "~/.local/share/whispers/dictionary.toml" + ); + assert!(loaded.session.enabled); + 
assert_eq!(loaded.session.max_entries, 3); + assert_eq!(loaded.rewrite.timeout_ms, 30000); + assert!(loaded.feedback.enabled); + + let raw = std::fs::read_to_string(&config_path).expect("read config"); + assert!(raw.contains("[audio]")); + assert!(raw.contains("[transcription]")); + assert!(raw.contains("[postprocess]")); + assert!(raw.contains("[session]")); + assert!(raw.contains("[rewrite]")); + assert!(!raw.contains("[whisper]")); +} + +#[test] +fn selecting_nemo_model_sets_non_expiring_asr_worker_timeout() { + let config_path = crate::test_support::unique_temp_path("config-nemo-timeout", "toml"); + write_default_config(&config_path, "~/old-model.bin").expect("write config"); + + update_config_transcription_selection( + &config_path, + TranscriptionBackend::Nemo, + "parakeet-tdt_ctc-1.1b", + "~/.local/share/whispers/nemo/models/parakeet-tdt_ctc-1.1b", + true, + ) + .expect("select nemo model"); + + let loaded = Config::load(Some(&config_path)).expect("load config"); + assert_eq!(loaded.transcription.backend, TranscriptionBackend::Nemo); + assert_eq!(loaded.transcription.idle_timeout_ms, 0); +} + +#[test] +fn load_legacy_whisper_section_maps_to_transcription() { + let path = crate::test_support::unique_temp_path("config-whisper-legacy", "toml"); + std::fs::write( + &path, + r#"[whisper] +model_path = "~/legacy-model.bin" +language = "en" +use_gpu = false +flash_attn = false +"#, + ) + .expect("write config"); + + let loaded = Config::load(Some(&path)).expect("load config"); + assert_eq!( + loaded.transcription.backend, + TranscriptionBackend::WhisperCpp + ); + assert_eq!(loaded.transcription.model_path, "~/legacy-model.bin"); + assert_eq!(loaded.transcription.language, "en"); + assert!(!loaded.transcription.use_gpu); + assert!(!loaded.transcription.flash_attn); +} + +#[test] +fn load_legacy_cleanup_section_maps_to_legacy_basic() { + let path = crate::test_support::unique_temp_path("config-cleanup", "toml"); + std::fs::write( + &path, + r#"[cleanup] +profile = 
"aggressive" +spoken_formatting = false +remove_fillers = false +"#, + ) + .expect("write config"); + + let config = Config::load(Some(&path)).expect("load config"); + assert_eq!(config.postprocess.mode, PostprocessMode::LegacyBasic); + assert_eq!(config.cleanup.profile, CleanupProfile::Aggressive); + assert!(!config.cleanup.spoken_formatting); + assert!(config.cleanup.backtrack); + assert!(!config.cleanup.remove_fillers); +} + +#[test] +fn update_rewrite_selection_enables_advanced_mode() { + let dir = crate::test_support::unique_temp_dir("config-rewrite-select"); + let config_path = dir.join("config.toml"); + write_default_config(&config_path, "~/model.bin").expect("write config"); + + update_config_rewrite_selection(&config_path, "qwen-3.5-2b-q4_k_m") + .expect("select rewrite model"); + + let loaded = Config::load(Some(&config_path)).expect("load config"); + assert_eq!(loaded.postprocess.mode, PostprocessMode::AdvancedLocal); + assert_eq!(loaded.rewrite.selected_model, "qwen-3.5-2b-q4_k_m"); + assert!(loaded.rewrite.model_path.is_empty()); + assert_eq!( + loaded.rewrite.instructions_path, + "~/.local/share/whispers/rewrite-instructions.txt" + ); + assert_eq!( + loaded.rewrite.profile, + crate::rewrite_profile::RewriteProfile::Auto + ); + assert_eq!(loaded.rewrite.timeout_ms, 30000); + assert_eq!(loaded.rewrite.idle_timeout_ms, 120000); +} + +#[test] +fn update_helpers_upgrade_legacy_configs_without_panicking() { + let config_path = crate::test_support::unique_temp_path("config-legacy-upgrade", "toml"); + std::fs::write( + &config_path, + r#"[audio] +sample_rate = 16000 + +[whisper] +model_path = "~/.local/share/whispers/ggml-large-v3-turbo.bin" +language = "auto" +"#, + ) + .expect("write legacy config"); + + update_config_transcription_selection( + &config_path, + TranscriptionBackend::WhisperCpp, + "large-v3-turbo", + "~/.local/share/whispers/ggml-large-v3-turbo.bin", + true, + ) + .expect("update transcription selection"); + 
update_config_rewrite_selection(&config_path, "qwen-3.5-4b-q4_k_m") + .expect("update rewrite selection"); + + let loaded = Config::load(Some(&config_path)).expect("load upgraded config"); + assert_eq!( + loaded.transcription.backend, + TranscriptionBackend::WhisperCpp + ); + assert_eq!(loaded.transcription.selected_model, "large-v3-turbo"); + assert_eq!(loaded.postprocess.mode, PostprocessMode::AdvancedLocal); + assert_eq!(loaded.rewrite.selected_model, "qwen-3.5-4b-q4_k_m"); + + let raw = std::fs::read_to_string(&config_path).expect("read upgraded config"); + assert!(!raw.contains("[whisper]")); +} + +#[test] +fn load_cloud_literal_key_from_legacy_api_key_env() { + let path = crate::test_support::unique_temp_path("config-cloud-literal-key", "toml"); + std::fs::write( + &path, + r#"[cloud] +api_key_env = "sk-test-inline" +"#, + ) + .expect("write config"); + + let loaded = Config::load(Some(&path)).expect("load config"); + assert_eq!(loaded.cloud.api_key, "sk-test-inline"); + assert_eq!(loaded.cloud.api_key_env, "OPENAI_API_KEY"); +} + +#[test] +fn default_config_template_matches_example_file() { + let example_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("config.example.toml"); + let example = std::fs::read_to_string(&example_path).expect("read config example"); + let expected = + super::edit::default_config_template("~/.local/share/whispers/ggml-large-v3-turbo.bin"); + assert_eq!(example, expected); +} diff --git a/src/context.rs b/src/context.rs index d34172c..5299fe9 100644 --- a/src/context.rs +++ b/src/context.rs @@ -119,13 +119,14 @@ fn parse_niri_focused_window_json(raw: &str, captured_at_ms: u64) -> Option Option Option { value.get(key)?.as_str().map(str::to_string) } +fn infer_browser_domain(surface_kind: SurfaceKind, window_title: Option<&str>) -> Option { + if surface_kind != SurfaceKind::Browser { + return None; + } + + let title = window_title?; + extract_browser_domain_from_title(title) +} + +fn 
extract_browser_domain_from_title(title: &str) -> Option { + let normalized = title + .replace(" — ", "\n") + .replace(" – ", "\n") + .replace(" - ", "\n") + .replace(" | ", "\n") + .replace(" · ", "\n"); + + for segment in normalized.lines() { + let segment = segment.trim(); + if segment.is_empty() { + continue; + } + + if let Some(domain) = extract_domain_candidate(segment) { + return Some(domain); + } + + for token in segment.split_whitespace() { + if let Some(domain) = extract_domain_candidate(token) { + return Some(domain); + } + } + } + + None +} + +fn extract_domain_candidate(candidate: &str) -> Option { + let trimmed = candidate.trim_matches(|ch: char| { + ch.is_whitespace() || matches!(ch, '"' | '\'' | '(' | ')' | '[' | ']' | '{' | '}' | ',') + }); + if trimmed.is_empty() { + return None; + } + + let without_scheme = trimmed.split("://").nth(1).unwrap_or(trimmed); + let host = without_scheme + .split('/') + .next() + .unwrap_or(without_scheme) + .split(':') + .next() + .unwrap_or(without_scheme) + .trim_end_matches('.'); + + looks_like_domain(host).then(|| host.to_ascii_lowercase()) +} + +fn looks_like_domain(host: &str) -> bool { + let mut labels = host.split('.'); + let Some(first) = labels.next() else { + return false; + }; + if first.is_empty() || !is_domain_label(first) { + return false; + } + + let rest = labels.collect::>(); + if rest.is_empty() { + return false; + } + + rest.iter().all(|label| is_domain_label(label)) + && rest + .last() + .is_some_and(|label| label.chars().any(|ch| ch.is_ascii_alphabetic())) +} + +fn is_domain_label(label: &str) -> bool { + !label.is_empty() + && label + .chars() + .all(|ch| ch.is_ascii_alphanumeric() || ch == '-') +} + fn now_ms() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) @@ -288,4 +375,33 @@ mod tests { let raw = r#"{"mapped": true}"#; assert!(parse_hyprland_activewindow_json(raw, 1).is_none()); } + + #[test] + fn parse_niri_focused_window_json_extracts_browser_domain_from_title() { + let raw = r#"{ 
+ "id": 11, + "title": "docs.rs - serde_json", + "app_id": "firefox" + }"#; + + let context = parse_niri_focused_window_json(raw, 42).expect("context"); + assert_eq!(context.surface_kind, SurfaceKind::Browser); + assert_eq!(context.browser_domain.as_deref(), Some("docs.rs")); + } + + #[test] + fn parse_hyprland_activewindow_json_extracts_browser_domain_from_title() { + let raw = r#"{ + "address": "0x5678", + "class": "firefox", + "title": "https://news.ycombinator.com/item?id=1" + }"#; + + let context = parse_hyprland_activewindow_json(raw, 42).expect("context"); + assert_eq!(context.surface_kind, SurfaceKind::Browser); + assert_eq!( + context.browser_domain.as_deref(), + Some("news.ycombinator.com") + ); + } } diff --git a/src/inject.rs b/src/inject.rs deleted file mode 100644 index 4fd91d2..0000000 --- a/src/inject.rs +++ /dev/null @@ -1,278 +0,0 @@ -use std::process::{Command, Stdio}; -use std::time::Duration; - -use evdev::uinput::VirtualDevice; -use evdev::{AttributeSet, EventType, InputEvent, KeyCode}; - -use crate::error::{Result, WhsprError}; - -const DEVICE_READY_DELAY: Duration = Duration::from_millis(120); -const CLIPBOARD_READY_DELAY: Duration = Duration::from_millis(180); -const POST_DELETE_SETTLE_DELAY: Duration = Duration::from_millis(30); - -pub struct TextInjector { - wl_copy_bin: String, - wl_copy_args: Vec, -} - -impl TextInjector { - pub fn new() -> Self { - Self { - wl_copy_bin: "wl-copy".to_string(), - wl_copy_args: Vec::new(), - } - } - - #[cfg(test)] - fn with_wl_copy_command(bin: &str, args: &[&str]) -> Self { - Self { - wl_copy_bin: bin.to_string(), - wl_copy_args: args.iter().map(|arg| (*arg).to_string()).collect(), - } - } - - pub async fn inject(&self, text: &str) -> Result<()> { - if text.is_empty() { - tracing::warn!("empty text, nothing to inject"); - return Ok(()); - } - - let text = text.to_string(); - let text_len = text.len(); - let wl_copy_bin = self.wl_copy_bin.clone(); - let wl_copy_args = self.wl_copy_args.clone(); - 
tokio::task::spawn_blocking(move || inject_sync(&wl_copy_bin, &wl_copy_args, &text)) - .await - .map_err(|e| WhsprError::Injection(format!("injection task panicked: {e}")))??; - - tracing::info!("injected {} chars via wl-copy + Ctrl+Shift+V", text_len); - Ok(()) - } - - pub async fn replace_recent_text(&self, delete_graphemes: usize, text: &str) -> Result<()> { - if delete_graphemes == 0 { - return self.inject(text).await; - } - - let text = text.to_string(); - let wl_copy_bin = self.wl_copy_bin.clone(); - let wl_copy_args = self.wl_copy_args.clone(); - tokio::task::spawn_blocking(move || { - replace_recent_text_sync(&wl_copy_bin, &wl_copy_args, delete_graphemes, &text) - }) - .await - .map_err(|e| WhsprError::Injection(format!("replace task panicked: {e}")))??; - - tracing::info!( - "replaced {} graphemes via backspace + wl-copy paste", - delete_graphemes - ); - Ok(()) - } -} - -fn inject_sync(wl_copy_bin: &str, wl_copy_args: &[String], text: &str) -> Result<()> { - let mut device = build_virtual_device()?; - - run_wl_copy(wl_copy_bin, wl_copy_args, text)?; - - // Wait for compositor to process the clipboard offer. - // The uinput device was created above, so it has already been - // registering during the wl-copy write. - std::thread::sleep(CLIPBOARD_READY_DELAY); - emit_paste_combo(&mut device)?; - - Ok(()) -} - -fn replace_recent_text_sync( - wl_copy_bin: &str, - wl_copy_args: &[String], - delete_graphemes: usize, - text: &str, -) -> Result<()> { - let mut device = build_virtual_device()?; - // Unlike plain injection, replacement can try to backspace immediately - // after creating the uinput device. Give the compositor a moment to - // register it first so the initial backspaces are not dropped. 
- std::thread::sleep(DEVICE_READY_DELAY); - emit_backspaces(&mut device, delete_graphemes)?; - - if !text.is_empty() { - std::thread::sleep(POST_DELETE_SETTLE_DELAY); - run_wl_copy(wl_copy_bin, wl_copy_args, text)?; - std::thread::sleep(CLIPBOARD_READY_DELAY); - emit_paste_combo(&mut device)?; - } - - Ok(()) -} - -fn build_virtual_device() -> Result { - let mut keys = AttributeSet::::new(); - keys.insert(KeyCode::KEY_LEFTCTRL); - keys.insert(KeyCode::KEY_LEFTSHIFT); - keys.insert(KeyCode::KEY_V); - keys.insert(KeyCode::KEY_BACKSPACE); - - VirtualDevice::builder() - .map_err(|e| WhsprError::Injection(format!("uinput: {e}")))? - .name("whispers-keyboard") - .with_keys(&keys) - .map_err(|e| WhsprError::Injection(format!("uinput keys: {e}")))? - .build() - .map_err(|e| WhsprError::Injection(format!("uinput build: {e}"))) -} - -fn run_wl_copy(wl_copy_bin: &str, wl_copy_args: &[String], text: &str) -> Result<()> { - run_wl_copy_with_timeout(wl_copy_bin, wl_copy_args, text, Duration::from_secs(2)) -} - -fn run_wl_copy_with_timeout( - wl_copy_bin: &str, - wl_copy_args: &[String], - text: &str, - timeout: Duration, -) -> Result<()> { - let mut wl_copy = Command::new(wl_copy_bin) - .args(wl_copy_args) - .stdin(Stdio::piped()) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn() - .map_err(|e| WhsprError::Injection(format!("failed to spawn wl-copy: {e}")))?; - - { - use std::io::Write; - let mut stdin = wl_copy - .stdin - .take() - .ok_or_else(|| WhsprError::Injection("wl-copy stdin unavailable".into()))?; - stdin - .write_all(text.as_bytes()) - .map_err(|e| WhsprError::Injection(format!("wl-copy stdin write: {e}")))?; - } - - let deadline = std::time::Instant::now() + timeout; - let status = loop { - if let Some(status) = wl_copy - .try_wait() - .map_err(|e| WhsprError::Injection(format!("wl-copy wait: {e}")))? 
- { - break status; - } - if std::time::Instant::now() >= deadline { - let _ = wl_copy.kill(); - let _ = wl_copy.wait(); - return Err(WhsprError::Injection(format!( - "wl-copy timed out after {}ms", - timeout.as_millis() - ))); - } - std::thread::sleep(Duration::from_millis(10)); - }; - if !status.success() { - return Err(WhsprError::Injection(format!( - "wl-copy exited with {status}" - ))); - } - Ok(()) -} - -fn emit_paste_combo(device: &mut VirtualDevice) -> Result<()> { - device - .emit(&[ - InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTCTRL.0, 1), - InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTSHIFT.0, 1), - ]) - .map_err(|e| WhsprError::Injection(format!("paste modifier press: {e}")))?; - std::thread::sleep(Duration::from_millis(12)); - - device - .emit(&[ - InputEvent::new(EventType::KEY.0, KeyCode::KEY_V.0, 1), - InputEvent::new(EventType::KEY.0, KeyCode::KEY_V.0, 0), - ]) - .map_err(|e| WhsprError::Injection(format!("paste key press: {e}")))?; - std::thread::sleep(Duration::from_millis(12)); - - device - .emit(&[ - InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTSHIFT.0, 0), - InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTCTRL.0, 0), - ]) - .map_err(|e| WhsprError::Injection(format!("paste modifier release: {e}")))?; - - Ok(()) -} - -fn emit_backspaces(device: &mut VirtualDevice, count: usize) -> Result<()> { - for _ in 0..count { - device - .emit(&[ - InputEvent::new(EventType::KEY.0, KeyCode::KEY_BACKSPACE.0, 1), - InputEvent::new(EventType::KEY.0, KeyCode::KEY_BACKSPACE.0, 0), - ]) - .map_err(|e| WhsprError::Injection(format!("backspace key press: {e}")))?; - std::thread::sleep(Duration::from_millis(6)); - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::error::WhsprError; - - #[test] - fn run_wl_copy_reports_spawn_failure() { - let err = run_wl_copy("/definitely/missing/wl-copy", &[], "hello") - .expect_err("missing binary should fail"); - match err { - WhsprError::Injection(msg) => { - 
assert!(msg.contains("failed to spawn wl-copy"), "unexpected: {msg}"); - } - other => panic!("unexpected error variant: {other:?}"), - } - } - - #[test] - fn run_wl_copy_reports_non_zero_exit() { - let err = run_wl_copy( - "/bin/sh", - &[String::from("-c"), String::from("exit 7")], - "hello", - ) - .expect_err("non-zero exit should fail"); - match err { - WhsprError::Injection(msg) => { - assert!(msg.contains("wl-copy exited"), "unexpected: {msg}"); - } - other => panic!("unexpected error variant: {other:?}"), - } - } - - #[test] - fn run_wl_copy_reports_timeout() { - let err = run_wl_copy_with_timeout( - "/bin/sh", - &[String::from("-c"), String::from("sleep 1")], - "hello", - Duration::from_millis(80), - ) - .expect_err("sleep should time out"); - match err { - WhsprError::Injection(msg) => { - assert!(msg.contains("timed out"), "unexpected: {msg}"); - } - other => panic!("unexpected error variant: {other:?}"), - } - } - - #[tokio::test] - async fn inject_empty_text_is_noop() { - let injector = TextInjector::with_wl_copy_command("/bin/true", &[]); - injector.inject("").await.expect("empty text should no-op"); - } -} diff --git a/src/inject/clipboard.rs b/src/inject/clipboard.rs new file mode 100644 index 0000000..dc9b0e8 --- /dev/null +++ b/src/inject/clipboard.rs @@ -0,0 +1,83 @@ +use std::process::{Command, Stdio}; +use std::time::Duration; + +use crate::error::{Result, WhsprError}; + +pub(super) struct ClipboardAdapter<'a> { + wl_copy_bin: &'a str, + wl_copy_args: &'a [String], +} + +impl<'a> ClipboardAdapter<'a> { + pub(super) fn new(wl_copy_bin: &'a str, wl_copy_args: &'a [String]) -> Self { + Self { + wl_copy_bin, + wl_copy_args, + } + } + + pub(super) fn copy(&self, text: &str) -> Result<()> { + run_wl_copy_with_timeout( + self.wl_copy_bin, + self.wl_copy_args, + text, + Duration::from_secs(2), + ) + } +} + +#[cfg(test)] +pub(super) fn run_wl_copy(wl_copy_bin: &str, wl_copy_args: &[String], text: &str) -> Result<()> { + 
run_wl_copy_with_timeout(wl_copy_bin, wl_copy_args, text, Duration::from_secs(2)) +} + +pub(super) fn run_wl_copy_with_timeout( + wl_copy_bin: &str, + wl_copy_args: &[String], + text: &str, + timeout: Duration, +) -> Result<()> { + let mut wl_copy = Command::new(wl_copy_bin) + .args(wl_copy_args) + .stdin(Stdio::piped()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + .map_err(|e| WhsprError::Injection(format!("failed to spawn wl-copy: {e}")))?; + + { + use std::io::Write; + let mut stdin = wl_copy + .stdin + .take() + .ok_or_else(|| WhsprError::Injection("wl-copy stdin unavailable".into()))?; + stdin + .write_all(text.as_bytes()) + .map_err(|e| WhsprError::Injection(format!("wl-copy stdin write: {e}")))?; + } + + let deadline = std::time::Instant::now() + timeout; + let status = loop { + if let Some(status) = wl_copy + .try_wait() + .map_err(|e| WhsprError::Injection(format!("wl-copy wait: {e}")))? + { + break status; + } + if std::time::Instant::now() >= deadline { + let _ = wl_copy.kill(); + let _ = wl_copy.wait(); + return Err(WhsprError::Injection(format!( + "wl-copy timed out after {}ms", + timeout.as_millis() + ))); + } + std::thread::sleep(Duration::from_millis(10)); + }; + if !status.success() { + return Err(WhsprError::Injection(format!( + "wl-copy exited with {status}" + ))); + } + Ok(()) +} diff --git a/src/inject/keyboard.rs b/src/inject/keyboard.rs new file mode 100644 index 0000000..e366dd6 --- /dev/null +++ b/src/inject/keyboard.rs @@ -0,0 +1,73 @@ +use evdev::uinput::VirtualDevice; +use evdev::{AttributeSet, EventType, InputEvent, KeyCode}; + +use crate::error::{Result, WhsprError}; + +pub(super) struct VirtualKeyboardAdapter { + device: VirtualDevice, +} + +impl VirtualKeyboardAdapter { + pub(super) fn new() -> Result { + Ok(Self { + device: build_virtual_device()?, + }) + } + + pub(super) fn emit_paste_combo(&mut self) -> Result<()> { + self.device + .emit(&[ + InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTCTRL.0, 1), + 
InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTSHIFT.0, 1), + ]) + .map_err(|e| WhsprError::Injection(format!("paste modifier press: {e}")))?; + std::thread::sleep(std::time::Duration::from_millis(12)); + + self.device + .emit(&[ + InputEvent::new(EventType::KEY.0, KeyCode::KEY_V.0, 1), + InputEvent::new(EventType::KEY.0, KeyCode::KEY_V.0, 0), + ]) + .map_err(|e| WhsprError::Injection(format!("paste key press: {e}")))?; + std::thread::sleep(std::time::Duration::from_millis(12)); + + self.device + .emit(&[ + InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTSHIFT.0, 0), + InputEvent::new(EventType::KEY.0, KeyCode::KEY_LEFTCTRL.0, 0), + ]) + .map_err(|e| WhsprError::Injection(format!("paste modifier release: {e}")))?; + + Ok(()) + } + + pub(super) fn emit_backspaces(&mut self, count: usize) -> Result<()> { + for _ in 0..count { + self.device + .emit(&[ + InputEvent::new(EventType::KEY.0, KeyCode::KEY_BACKSPACE.0, 1), + InputEvent::new(EventType::KEY.0, KeyCode::KEY_BACKSPACE.0, 0), + ]) + .map_err(|e| WhsprError::Injection(format!("backspace key press: {e}")))?; + std::thread::sleep(std::time::Duration::from_millis(6)); + } + + Ok(()) + } +} + +fn build_virtual_device() -> Result { + let mut keys = AttributeSet::::new(); + keys.insert(KeyCode::KEY_LEFTCTRL); + keys.insert(KeyCode::KEY_LEFTSHIFT); + keys.insert(KeyCode::KEY_V); + keys.insert(KeyCode::KEY_BACKSPACE); + + VirtualDevice::builder() + .map_err(|e| WhsprError::Injection(format!("uinput: {e}")))? + .name("whispers-keyboard") + .with_keys(&keys) + .map_err(|e| WhsprError::Injection(format!("uinput keys: {e}")))? 
+ .build() + .map_err(|e| WhsprError::Injection(format!("uinput build: {e}"))) +} diff --git a/src/inject/mod.rs b/src/inject/mod.rs new file mode 100644 index 0000000..0df955b --- /dev/null +++ b/src/inject/mod.rs @@ -0,0 +1,118 @@ +mod clipboard; +mod keyboard; + +#[cfg(test)] +mod tests; + +use crate::error::{Result, WhsprError}; + +const DEVICE_READY_DELAY: std::time::Duration = std::time::Duration::from_millis(120); +const CLIPBOARD_READY_DELAY: std::time::Duration = std::time::Duration::from_millis(180); +const POST_DELETE_SETTLE_DELAY: std::time::Duration = std::time::Duration::from_millis(30); + +pub struct TextInjector { + wl_copy_bin: String, + wl_copy_args: Vec, +} + +impl Default for TextInjector { + fn default() -> Self { + Self::new() + } +} + +impl TextInjector { + pub fn new() -> Self { + Self { + wl_copy_bin: "wl-copy".to_string(), + wl_copy_args: Vec::new(), + } + } + + #[cfg(test)] + fn with_wl_copy_command(bin: &str, args: &[&str]) -> Self { + Self { + wl_copy_bin: bin.to_string(), + wl_copy_args: args.iter().map(|arg| (*arg).to_string()).collect(), + } + } + + pub async fn inject(&self, text: &str) -> Result<()> { + if text.is_empty() { + tracing::warn!("empty text, nothing to inject"); + return Ok(()); + } + + let text = text.to_string(); + let text_len = text.len(); + let wl_copy_bin = self.wl_copy_bin.clone(); + let wl_copy_args = self.wl_copy_args.clone(); + tokio::task::spawn_blocking(move || inject_sync(&wl_copy_bin, &wl_copy_args, &text)) + .await + .map_err(|e| WhsprError::Injection(format!("injection task panicked: {e}")))??; + + tracing::info!("injected {} chars via wl-copy + Ctrl+Shift+V", text_len); + Ok(()) + } + + pub async fn replace_recent_text(&self, delete_graphemes: usize, text: &str) -> Result<()> { + if delete_graphemes == 0 { + return self.inject(text).await; + } + + let text = text.to_string(); + let wl_copy_bin = self.wl_copy_bin.clone(); + let wl_copy_args = self.wl_copy_args.clone(); + tokio::task::spawn_blocking(move 
|| { + replace_recent_text_sync(&wl_copy_bin, &wl_copy_args, delete_graphemes, &text) + }) + .await + .map_err(|e| WhsprError::Injection(format!("replace task panicked: {e}")))??; + + tracing::info!( + "replaced {} graphemes via backspace + wl-copy paste", + delete_graphemes + ); + Ok(()) + } +} + +fn inject_sync(wl_copy_bin: &str, wl_copy_args: &[String], text: &str) -> Result<()> { + let mut keyboard = keyboard::VirtualKeyboardAdapter::new()?; + let clipboard = clipboard::ClipboardAdapter::new(wl_copy_bin, wl_copy_args); + + clipboard.copy(text)?; + + // Wait for compositor to process the clipboard offer. + // The uinput device was created above, so it has already been + // registering during the wl-copy write. + std::thread::sleep(CLIPBOARD_READY_DELAY); + keyboard.emit_paste_combo()?; + + Ok(()) +} + +fn replace_recent_text_sync( + wl_copy_bin: &str, + wl_copy_args: &[String], + delete_graphemes: usize, + text: &str, +) -> Result<()> { + let mut keyboard = keyboard::VirtualKeyboardAdapter::new()?; + let clipboard = clipboard::ClipboardAdapter::new(wl_copy_bin, wl_copy_args); + + // Unlike plain injection, replacement can try to backspace immediately + // after creating the uinput device. Give the compositor a moment to + // register it first so the initial backspaces are not dropped. 
+ std::thread::sleep(DEVICE_READY_DELAY); + keyboard.emit_backspaces(delete_graphemes)?; + + if !text.is_empty() { + std::thread::sleep(POST_DELETE_SETTLE_DELAY); + clipboard.copy(text)?; + std::thread::sleep(CLIPBOARD_READY_DELAY); + keyboard.emit_paste_combo()?; + } + + Ok(()) +} diff --git a/src/inject/tests.rs b/src/inject/tests.rs new file mode 100644 index 0000000..45db51f --- /dev/null +++ b/src/inject/tests.rs @@ -0,0 +1,54 @@ +use crate::error::WhsprError; + +use super::*; + +#[test] +fn run_wl_copy_reports_spawn_failure() { + let err = clipboard::run_wl_copy("/definitely/missing/wl-copy", &[], "hello") + .expect_err("missing binary should fail"); + match err { + WhsprError::Injection(msg) => { + assert!(msg.contains("failed to spawn wl-copy"), "unexpected: {msg}"); + } + other => panic!("unexpected error variant: {other:?}"), + } +} + +#[test] +fn run_wl_copy_reports_non_zero_exit() { + let err = clipboard::run_wl_copy( + "/bin/sh", + &[String::from("-c"), String::from("exit 7")], + "hello", + ) + .expect_err("non-zero exit should fail"); + match err { + WhsprError::Injection(msg) => { + assert!(msg.contains("wl-copy exited"), "unexpected: {msg}"); + } + other => panic!("unexpected error variant: {other:?}"), + } +} + +#[test] +fn run_wl_copy_reports_timeout() { + let err = clipboard::run_wl_copy_with_timeout( + "/bin/sh", + &[String::from("-c"), String::from("sleep 1")], + "hello", + std::time::Duration::from_millis(80), + ) + .expect_err("sleep should time out"); + match err { + WhsprError::Injection(msg) => { + assert!(msg.contains("timed out"), "unexpected: {msg}"); + } + other => panic!("unexpected error variant: {other:?}"), + } +} + +#[tokio::test] +async fn inject_empty_text_is_noop() { + let injector = TextInjector::with_wl_copy_command("/bin/true", &[]); + injector.inject("").await.expect("empty text should no-op"); +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..64658ec --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,35 @@ 
+pub mod agentic_rewrite; +pub mod app; +pub mod asr; +pub mod asr_model; +pub mod asr_protocol; +pub mod audio; +pub mod branding; +pub mod cleanup; +pub mod cli; +pub mod cloud; +pub mod completions; +pub mod config; +pub mod context; +pub mod error; +pub mod faster_whisper; +pub mod feedback; +pub mod file_audio; +pub mod inject; +pub mod model; +pub(crate) mod model_support; +pub mod nemo_asr; +pub mod personalization; +pub mod postprocess; +pub mod rewrite; +pub mod rewrite_model; +pub mod rewrite_profile; +pub mod rewrite_protocol; +pub mod rewrite_worker; +pub mod runtime_support; +pub mod session; +pub mod setup; +#[cfg(test)] +pub mod test_support; +pub mod transcribe; +pub mod ui; diff --git a/src/main.rs b/src/main.rs index 5b71b62..90cee45 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,223 +1,64 @@ -mod app; -mod asr; -mod asr_model; -mod asr_protocol; -mod audio; -mod cleanup; -mod cli; -mod cloud; -mod completions; -mod config; -mod context; -mod error; -mod faster_whisper; -mod feedback; -mod file_audio; -mod inject; -mod model; -mod nemo_asr; -mod personalization; -mod postprocess; -mod rewrite_model; -mod rewrite_profile; -mod rewrite_protocol; -mod rewrite_worker; -mod session; -mod setup; -#[cfg(test)] -mod test_support; -mod transcribe; -mod ui; - -use std::path::{Path, PathBuf}; +use std::path::Path; use clap::Parser; -use tracing_subscriber::EnvFilter; - -use crate::cli::{ - AsrModelAction, Cli, CloudAction, Command, DictionaryAction, ModelAction, RewriteModelAction, - SnippetAction, +use whispers::cli::{ + AppRuleAction, AsrModelAction, Cli, CloudAction, Command, DictionaryAction, GlossaryAction, + ModelAction, RewriteModelAction, SnippetAction, +}; +use whispers::config::Config; +use whispers::error::Result; +use whispers::rewrite_protocol::RewriteSurfaceKind; +use whispers::{ + agentic_rewrite, app, asr, asr_model, audio, cloud, completions, file_audio, model, + personalization, postprocess, rewrite_model, runtime_support, setup, }; -use 
crate::config::Config; - -struct PidLock { - path: PathBuf, - _file: std::fs::File, -} - -impl Drop for PidLock { - fn drop(&mut self) { - let _ = std::fs::remove_file(&self.path); - } -} - -fn pid_file_path() -> PathBuf { - let runtime_dir = std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/tmp".into()); - PathBuf::from(runtime_dir).join("whispers.pid") -} - -fn read_pid_from_lock(path: &Path) -> Option { - let contents = std::fs::read_to_string(path).ok()?; - contents.trim().parse().ok() -} - -fn process_exists(pid: libc::pid_t) -> bool { - Path::new(&format!("/proc/{pid}")).exists() -} - -fn pid_belongs_to_whspr(pid: libc::pid_t) -> bool { - if !process_exists(pid) { - return false; - } - - let current_exe = std::env::current_exe() - .ok() - .and_then(|p| std::fs::canonicalize(p).ok()); - let target_exe = std::fs::canonicalize(format!("/proc/{pid}/exe")).ok(); - - if let (Some(current), Some(target)) = (current_exe.as_ref(), target_exe.as_ref()) { - if current == target { - return true; - } - } - - let current_name = std::env::current_exe() - .ok() - .and_then(|p| p.file_name().map(|n| n.to_string_lossy().into_owned())) - .unwrap_or_else(|| "whispers".into()); - let cmdline = match std::fs::read(format!("/proc/{pid}/cmdline")) { - Ok(bytes) => bytes, - Err(_) => return false, - }; - let Some(first_arg) = cmdline.split(|b| *b == 0).next() else { - return false; - }; - if first_arg.is_empty() { - return false; - } - let first_arg = String::from_utf8_lossy(first_arg); - Path::new(first_arg.as_ref()) - .file_name() - .map(|name| name.to_string_lossy() == current_name) - .unwrap_or(false) -} - -fn try_acquire_pid_lock(path: &Path) -> std::io::Result { - use std::io::Write; - - let mut file = std::fs::OpenOptions::new() - .write(true) - .create_new(true) - .open(path)?; - writeln!(file, "{}", std::process::id())?; - - Ok(PidLock { - path: path.to_path_buf(), - _file: file, - }) -} - -fn signal_existing_instance(path: &Path) -> crate::error::Result { - let 
Some(pid) = read_pid_from_lock(path) else { - tracing::warn!("stale pid lock at {}, removing", path.display()); - let _ = std::fs::remove_file(path); - return Ok(false); - }; - - if !pid_belongs_to_whspr(pid) { - tracing::warn!( - "pid lock at {} points to non-whspr process ({pid}), removing", - path.display() - ); - let _ = std::fs::remove_file(path); - return Ok(false); - } - - tracing::info!("sending toggle signal to running instance (pid {pid})"); - let ret = unsafe { libc::kill(pid, libc::SIGUSR1) }; - if ret == 0 { - return Ok(true); - } - let err = std::io::Error::last_os_error(); - tracing::warn!("failed to signal pid {pid}: {err}"); - if err.raw_os_error() == Some(libc::ESRCH) { - let _ = std::fs::remove_file(path); - return Ok(false); +fn build_context_matcher( + surface_kind: Option, + app_id: Option<&String>, + window_title_contains: Option<&String>, + browser_domain_contains: Option<&String>, +) -> agentic_rewrite::ContextMatcher { + agentic_rewrite::ContextMatcher { + surface_kind, + app_id: app_id.cloned(), + window_title_contains: window_title_contains.cloned(), + browser_domain_contains: browser_domain_contains.cloned(), } - - Err(err.into()) } -fn acquire_or_signal_lock() -> crate::error::Result> { - let path = pid_file_path(); - - for _ in 0..2 { - match try_acquire_pid_lock(&path) { - Ok(lock) => return Ok(Some(lock)), - Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { - if signal_existing_instance(&path)? 
{ - return Ok(None); - } - } - Err(e) => return Err(e.into()), - } - } - - Err(crate::error::WhsprError::Config(format!( - "failed to acquire pid lock at {}", - path.display() - ))) -} - -fn init_tracing(verbose: u8) { - crate::ui::configure_terminal_colors(); - crate::ui::set_verbosity(verbose); - let filter = match verbose { - 0 => "whispers=warn", - 1 => "whispers=info", - 2 => "whispers=debug", - _ => "whispers=trace", - }; - - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(filter)), - ) - .compact() - .init(); -} - -async fn transcribe_file( - cli: &Cli, - file: &Path, - output: Option<&Path>, - raw: bool, -) -> crate::error::Result<()> { +async fn transcribe_file(cli: &Cli, file: &Path, output: Option<&Path>, raw: bool) -> Result<()> { let config = Config::load(cli.config.as_deref())?; - asr::validate_transcription_config(&config)?; + asr::validation::validate_transcription_config(&config)?; tracing::info!("decoding audio file: {}", file.display()); let mut samples = file_audio::decode_audio_file(file)?; audio::preprocess_audio(&mut samples, file_audio::TARGET_SAMPLE_RATE); let rewrite_service = if raw { None } else { - postprocess::prepare_rewrite_service(&config) + postprocess::execution::prepare_rewrite_service(&config) }; if let Some(service) = rewrite_service.as_ref() { - postprocess::prewarm_rewrite_service(service, "file transcription"); + postprocess::execution::prewarm_rewrite_service(service, "file transcription"); } - let prepared = asr::prepare_transcriber(&config)?; - asr::prewarm_transcriber(&prepared, "file transcription"); + let prepared = asr::prepare::prepare_transcriber(&config)?; + asr::prepare::prewarm_transcriber(&prepared, "file transcription"); let transcript = - asr::transcribe_audio(&config, prepared, samples, file_audio::TARGET_SAMPLE_RATE).await?; + asr::execute::transcribe_audio(&config, prepared, samples, file_audio::TARGET_SAMPLE_RATE) + .await?; let text = if 
raw { - postprocess::raw_text(&transcript) + postprocess::planning::raw_text(&transcript) } else { - postprocess::finalize_transcript(&config, transcript, rewrite_service.as_ref(), None, None) - .await - .text + postprocess::finalize::finalize_transcript( + &config, + transcript, + rewrite_service.as_ref(), + None, + None, + ) + .await + .text }; if let Some(out_path) = output { @@ -230,8 +71,8 @@ async fn transcribe_file( Ok(()) } -async fn run_default(cli: &Cli) -> crate::error::Result<()> { - let Some(_pid_lock) = acquire_or_signal_lock()? else { +async fn run_default(cli: &Cli) -> Result<()> { + let Some(_pid_lock) = runtime_support::acquire_or_signal_lock()? else { return Ok(()); }; @@ -239,17 +80,17 @@ async fn run_default(cli: &Cli) -> crate::error::Result<()> { // Load config let config = Config::load(cli.config.as_deref())?; - asr::validate_transcription_config(&config)?; + asr::validation::validate_transcription_config(&config)?; tracing::debug!("config loaded: {config:?}"); app::run(config).await } #[tokio::main] -async fn main() -> crate::error::Result<()> { +async fn main() -> Result<()> { let cli = Cli::parse(); - init_tracing(cli.verbose); + runtime_support::init_tracing(cli.verbose); match &cli.command { None => run_default(&cli).await, @@ -302,6 +143,58 @@ async fn main() -> crate::error::Result<()> { personalization::remove_dictionary(cli.config.as_deref(), phrase) } }, + Some(Command::AppRule { action }) => match action { + AppRuleAction::Path => agentic_rewrite::print_app_rule_path(cli.config.as_deref()), + AppRuleAction::List => agentic_rewrite::list_app_rules(cli.config.as_deref()), + AppRuleAction::Add { + name, + instructions, + surface_kind, + app_id, + window_title_contains, + browser_domain_contains, + correction_policy, + } => agentic_rewrite::add_app_rule( + cli.config.as_deref(), + name, + instructions, + build_context_matcher( + *surface_kind, + app_id.as_ref(), + window_title_contains.as_ref(), + browser_domain_contains.as_ref(), + 
), + *correction_policy, + ), + AppRuleAction::Remove { name } => { + agentic_rewrite::remove_app_rule(cli.config.as_deref(), name) + } + }, + Some(Command::Glossary { action }) => match action { + GlossaryAction::Path => agentic_rewrite::print_glossary_path(cli.config.as_deref()), + GlossaryAction::List => agentic_rewrite::list_glossary(cli.config.as_deref()), + GlossaryAction::Add { + term, + aliases, + surface_kind, + app_id, + window_title_contains, + browser_domain_contains, + } => agentic_rewrite::add_glossary_entry( + cli.config.as_deref(), + term, + aliases, + build_context_matcher( + *surface_kind, + app_id.as_ref(), + window_title_contains.as_ref(), + browser_domain_contains.as_ref(), + ), + ), + GlossaryAction::Remove { term } => { + agentic_rewrite::remove_glossary_entry(cli.config.as_deref(), term) + } + }, Some(Command::Cloud { action }) => match action { CloudAction::Check => { let config = Config::load(cli.config.as_deref())?; @@ -329,55 +222,3 @@ async fn main() -> crate::error::Result<()> { } } } - -#[cfg(test)] -mod tests { - use super::*; - - fn temp_lock_path(suffix: &str) -> PathBuf { - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_nanos(); - std::env::temp_dir().join(format!( - "whispers-test-{suffix}-{}-{now}.pid", - std::process::id() - )) - } - - #[test] - fn signal_existing_instance_cleans_invalid_pid_file() { - let path = temp_lock_path("invalid"); - std::fs::write(&path, "not-a-pid").unwrap(); - assert!(!signal_existing_instance(&path).unwrap()); - assert!(!path.exists()); - } - - #[test] - fn signal_existing_instance_cleans_missing_process_pid_file() { - let path = temp_lock_path("missing-process"); - std::fs::write(&path, "99999999").unwrap(); - assert!(!signal_existing_instance(&path).unwrap()); - assert!(!path.exists()); - } - - #[test] - fn try_acquire_pid_lock_uses_create_new_semantics() { - let path = temp_lock_path("acquire"); - let lock = try_acquire_pid_lock(&path).unwrap(); 
- assert!(path.exists()); - - let err = match try_acquire_pid_lock(&path) { - Ok(_) => panic!("lock acquisition should fail when file already exists"), - Err(e) => e, - }; - assert_eq!(err.kind(), std::io::ErrorKind::AlreadyExists); - - drop(lock); - assert!(!path.exists()); - - let lock2 = try_acquire_pid_lock(&path).unwrap(); - drop(lock2); - assert!(!path.exists()); - } -} diff --git a/src/model.rs b/src/model.rs index 1d13d2f..555f1b4 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,13 +1,8 @@ use std::path::{Path, PathBuf}; -use futures_util::StreamExt; -use tokio::io::AsyncWriteExt; - -use crate::config::{ - self, TranscriptionBackend, data_dir, resolve_config_path, - update_config_transcription_selection, -}; +use crate::config::{self, TranscriptionBackend, data_dir, update_config_transcription_selection}; use crate::error::{Result, WhsprError}; +use crate::model_support; const MODEL_BASE_URL: &str = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main"; @@ -105,41 +100,21 @@ pub fn selected_model_local_path(name: &str) -> Option { find_model(name).map(|info| model_path(info.filename)) } -fn path_for_config(path: &std::path::Path, home: Option<&std::path::Path>) -> String { - if let Some(home_path) = home { - if let Ok(stripped) = path.strip_prefix(home_path) { - return format!("~/{}", stripped.display()); - } - } - path.display().to_string() -} - pub fn model_path_for_config(filename: &str) -> String { let path = model_path(filename); - let home = std::env::var("HOME").ok().map(PathBuf::from); - path_for_config(&path, home.as_deref()) + model_support::path_for_current_home(&path) } fn active_model_path(config_path_override: Option<&Path>) -> Option { - let config_path = resolve_config_path(config_path_override); - if !config_path.exists() { - return None; - } - let contents = std::fs::read_to_string(&config_path).ok()?; - let config: config::Config = toml::from_str(&contents).ok()?; - Some(config.transcription.model_path) + 
model_support::load_config_if_exists(config_path_override) + .map(|config| config.transcription.model_path) } fn model_status(info: &ModelInfo, active_resolved: Option<&std::path::Path>) -> &'static str { let path = model_path(info.filename); let is_active = active_resolved == Some(path.as_path()); let is_local = path.exists(); - - match (is_active, is_local) { - (true, _) => "active", - (_, true) => "local", - _ => "remote", - } + model_support::configured_file_status(is_active, is_local) } pub fn list_models(config_path_override: Option<&Path>) { @@ -164,26 +139,6 @@ pub fn list_models(config_path_override: Option<&Path>) { } } -fn validated_existing_len(existing_len: u64, status: reqwest::StatusCode) -> Result { - if existing_len > 0 { - match status { - reqwest::StatusCode::PARTIAL_CONTENT => Ok(existing_len), - reqwest::StatusCode::OK => Ok(0), - _ => Err(WhsprError::Download(format!( - "download failed with HTTP {}", - status - ))), - } - } else if status.is_success() { - Ok(0) - } else { - Err(WhsprError::Download(format!( - "download failed with HTTP {}", - status - ))) - } -} - pub async fn download_model(name: &str) -> Result { download_model_from_base(name, MODEL_BASE_URL).await } @@ -201,109 +156,19 @@ pub(crate) async fn download_model_from_base(name: &str, base_url: &str) -> Resu let dest = model_path(info.filename); let part_path = dest.with_extension("bin.part"); - if dest.exists() { - tracing::info!("model '{name}' already downloaded at {}", dest.display()); - println!("{}", crate::ui::ready_message("ASR", name)); - return Ok(dest); - } - - // Ensure data directory exists - if let Some(parent) = dest.parent() { - std::fs::create_dir_all(parent) - .map_err(|e| WhsprError::Download(format!("failed to create data directory: {e}")))?; - } - let url = format!("{}/{}", base_url.trim_end_matches('/'), info.filename); - tracing::info!("downloading model '{}' from {}", info.name, url); - - println!( - "{} Downloading ASR model {} ({})...", - 
crate::ui::info_label(), - crate::ui::value(info.name), - info.size - ); - - let client = reqwest::Client::new(); - - // Check for partial download to support resume - let mut existing_len = if part_path.exists() { - std::fs::metadata(&part_path).map(|m| m.len()).unwrap_or(0) - } else { - 0 - }; - - let mut request = client.get(&url); - if existing_len > 0 { - tracing::info!("resuming model download from {existing_len} bytes"); - if crate::ui::is_verbose() { - println!("Resuming from {} bytes...", existing_len); - } - request = request.header("Range", format!("bytes={}-", existing_len)); - } - - let response = request - .send() - .await - .map_err(|e| WhsprError::Download(format!("failed to start download: {e}")))?; - - let original_len = existing_len; - existing_len = validated_existing_len(existing_len, response.status())?; - if original_len > 0 && existing_len == 0 { - tracing::warn!("server ignored range request, restarting model download from zero"); - if crate::ui::is_verbose() { - println!("Server ignored range request, restarting download from zero"); - } - } - - let total_size = if existing_len > 0 { - // For range requests, content-length is remaining bytes - response - .content_length() - .map(|cl| cl + existing_len) - .unwrap_or(0) - } else { - response.content_length().unwrap_or(0) - }; - - let pb = crate::ui::progress_bar(total_size); - pb.set_position(existing_len); - - let mut open_opts = tokio::fs::OpenOptions::new(); - open_opts.create(true); - if existing_len > 0 { - open_opts.append(true); - } else { - open_opts.write(true).truncate(true); - } - let mut file = open_opts - .open(&part_path) - .await - .map_err(|e| WhsprError::Download(format!("failed to open file: {e}")))?; - - let mut stream = response.bytes_stream(); - while let Some(chunk) = stream.next().await { - let chunk = - chunk.map_err(|e| WhsprError::Download(format!("download interrupted: {e}")))?; - file.write_all(&chunk) - .await - .map_err(|e| WhsprError::Download(format!("failed 
to write: {e}")))?; - pb.inc(chunk.len() as u64); - } - - file.flush() - .await - .map_err(|e| WhsprError::Download(format!("failed to flush: {e}")))?; - drop(file); - - pb.finish_with_message("done"); - - // Atomic rename - std::fs::rename(&part_path, &dest) - .map_err(|e| WhsprError::Download(format!("failed to finalize download: {e}")))?; - - tracing::info!("model '{}' saved to {}", info.name, dest.display()); - println!("{}", crate::ui::ready_message("ASR", info.name)); - Ok(dest) + model_support::download_to_path(model_support::DownloadSpec { + tracing_label: "model", + user_label: "ASR model", + ready_kind: "ASR", + item_name: info.name, + size: info.size, + url: &url, + dest, + part_path, + resume_partial: true, + }) + .await } pub fn select_model(name: &str, config_path_override: Option<&Path>) -> Result<()> { @@ -318,10 +183,11 @@ pub fn select_model(name: &str, config_path_override: Option<&Path>) -> Result<( ))); } - let config_path = resolve_config_path(config_path_override); let model_path_str = model_path_for_config(info.filename); + let (config_path, created) = + model_support::ensure_default_config(config_path_override, &model_path_str)?; - if config_path.exists() { + if !created { tracing::info!( "updating model selection in config {} to {}", config_path.display(), @@ -340,7 +206,6 @@ pub fn select_model(name: &str, config_path_override: Option<&Path>) -> Result<( config_path.display(), model_path_str ); - config::write_default_config(&config_path, &model_path_str)?; update_config_transcription_selection( &config_path, TranscriptionBackend::WhisperCpp, @@ -370,7 +235,7 @@ mod tests { let home = PathBuf::from("/home/alice"); let path = PathBuf::from("/home/alice/.local/share/whispers/ggml.bin"); assert_eq!( - path_for_config(&path, Some(&home)), + crate::model_support::path_for_config(&path, Some(&home)), "~/.local/share/whispers/ggml.bin" ); } @@ -380,29 +245,11 @@ mod tests { let home = PathBuf::from("/home/alice"); let path = 
PathBuf::from("/var/lib/whispers/ggml.bin"); assert_eq!( - path_for_config(&path, Some(&home)), + crate::model_support::path_for_config(&path, Some(&home)), "/var/lib/whispers/ggml.bin" ); } - #[test] - fn validated_existing_len_accepts_partial_content_resume() { - let len = validated_existing_len(100, reqwest::StatusCode::PARTIAL_CONTENT).unwrap(); - assert_eq!(len, 100); - } - - #[test] - fn validated_existing_len_restarts_on_ok_resume_response() { - let len = validated_existing_len(100, reqwest::StatusCode::OK).unwrap(); - assert_eq!(len, 0); - } - - #[test] - fn validated_existing_len_rejects_resume_on_error_status() { - let err = validated_existing_len(100, reqwest::StatusCode::RANGE_NOT_SATISFIABLE); - assert!(err.is_err()); - } - #[test] fn active_model_path_uses_override_config() { let config_path = crate::test_support::unique_temp_path("active-model-config", "toml"); diff --git a/src/model_support.rs b/src/model_support.rs new file mode 100644 index 0000000..8bfe218 --- /dev/null +++ b/src/model_support.rs @@ -0,0 +1,286 @@ +use std::path::{Path, PathBuf}; + +use futures_util::StreamExt; +use tokio::io::AsyncWriteExt; + +use crate::config::{self, Config}; +use crate::error::{Result, WhsprError}; + +pub(crate) struct DownloadSpec<'a> { + pub tracing_label: &'a str, + pub user_label: &'a str, + pub ready_kind: &'a str, + pub item_name: &'a str, + pub size: &'a str, + pub url: &'a str, + pub dest: PathBuf, + pub part_path: PathBuf, + pub resume_partial: bool, +} + +pub(crate) fn path_for_config(path: &Path, home: Option<&Path>) -> String { + if let Some(home_path) = home { + if let Ok(stripped) = path.strip_prefix(home_path) { + return format!("~/{}", stripped.display()); + } + } + path.display().to_string() +} + +pub(crate) fn path_for_current_home(path: &Path) -> String { + let home = std::env::var("HOME").ok().map(PathBuf::from); + path_for_config(path, home.as_deref()) +} + +pub(crate) fn load_config_if_exists(config_path_override: Option<&Path>) -> 
Option { + let config_path = config::resolve_config_path(config_path_override); + load_config_at_if_exists(&config_path) +} + +pub(crate) fn load_config_at_if_exists(config_path: &Path) -> Option { + config_path + .exists() + .then(|| Config::load(Some(config_path)).ok())? +} + +pub(crate) fn ensure_default_config( + config_path_override: Option<&Path>, + default_model_path: &str, +) -> Result<(PathBuf, bool)> { + let config_path = config::resolve_config_path(config_path_override); + let created = !config_path.exists(); + if created { + config::write_default_config(&config_path, default_model_path)?; + } + Ok((config_path, created)) +} + +pub(crate) fn configured_file_status(is_active: bool, is_local: bool) -> &'static str { + match (is_active, is_local) { + (true, _) => "active", + (_, true) => "local", + _ => "remote", + } +} + +pub(crate) fn managed_download_status(is_active: bool, is_local: bool) -> &'static str { + match (is_active, is_local) { + (true, true) => "active", + (true, false) => "active (missing)", + (_, true) => "local", + _ => "remote", + } +} + +pub(crate) async fn download_to_path(spec: DownloadSpec<'_>) -> Result { + if spec.dest.exists() { + tracing::info!( + "{} '{}' already downloaded at {}", + spec.tracing_label, + spec.item_name, + spec.dest.display() + ); + println!( + "{}", + crate::ui::ready_message(spec.ready_kind, spec.item_name) + ); + return Ok(spec.dest); + } + + if let Some(parent) = spec.dest.parent() { + std::fs::create_dir_all(parent) + .map_err(|e| WhsprError::Download(format!("failed to create data directory: {e}")))?; + } + + tracing::info!( + "downloading {} '{}' from {}", + spec.tracing_label, + spec.item_name, + spec.url + ); + println!( + "{} Downloading {} {} ({})...", + crate::ui::info_label(), + spec.user_label, + crate::ui::value(spec.item_name), + spec.size + ); + + let client = reqwest::Client::new(); + let mut existing_len = if spec.resume_partial && spec.part_path.exists() { + std::fs::metadata(&spec.part_path) 
+ .map(|m| m.len()) + .unwrap_or(0) + } else { + 0 + }; + + let mut request = client.get(spec.url); + if existing_len > 0 { + tracing::info!( + "resuming {} download from {existing_len} bytes", + spec.tracing_label + ); + if crate::ui::is_verbose() { + println!("Resuming from {} bytes...", existing_len); + } + request = request.header("Range", format!("bytes={}-", existing_len)); + } + + let response = request + .send() + .await + .map_err(|e| WhsprError::Download(format!("failed to start download: {e}")))?; + + let original_len = existing_len; + existing_len = validated_existing_len(existing_len, response.status())?; + if original_len > 0 && existing_len == 0 { + tracing::warn!( + "server ignored range request, restarting {} download from zero", + spec.tracing_label + ); + if crate::ui::is_verbose() { + println!("Server ignored range request, restarting download from zero"); + } + } + + let total_size = if existing_len > 0 { + response + .content_length() + .map(|content_length| content_length + existing_len) + .unwrap_or(0) + } else { + response.content_length().unwrap_or(0) + }; + + let pb = crate::ui::progress_bar(total_size); + if existing_len > 0 { + pb.set_position(existing_len); + } + + let mut open_opts = tokio::fs::OpenOptions::new(); + open_opts.create(true); + if existing_len > 0 { + open_opts.append(true); + } else { + open_opts.write(true).truncate(true); + } + + let mut file = open_opts + .open(&spec.part_path) + .await + .map_err(|e| WhsprError::Download(format!("failed to open file: {e}")))?; + + let mut stream = response.bytes_stream(); + while let Some(chunk) = stream.next().await { + let chunk = + chunk.map_err(|e| WhsprError::Download(format!("download interrupted: {e}")))?; + file.write_all(&chunk) + .await + .map_err(|e| WhsprError::Download(format!("failed to write: {e}")))?; + pb.inc(chunk.len() as u64); + } + + file.flush() + .await + .map_err(|e| WhsprError::Download(format!("failed to flush: {e}")))?; + drop(file); + + 
pb.finish_with_message("done"); + + std::fs::rename(&spec.part_path, &spec.dest) + .map_err(|e| WhsprError::Download(format!("failed to finalize download: {e}")))?; + + tracing::info!( + "{} '{}' saved to {}", + spec.tracing_label, + spec.item_name, + spec.dest.display() + ); + println!( + "{}", + crate::ui::ready_message(spec.ready_kind, spec.item_name) + ); + Ok(spec.dest) +} + +pub(crate) fn validated_existing_len( + existing_len: u64, + status: reqwest::StatusCode, +) -> Result { + if existing_len > 0 { + match status { + reqwest::StatusCode::PARTIAL_CONTENT => Ok(existing_len), + reqwest::StatusCode::OK => Ok(0), + _ => Err(WhsprError::Download(format!( + "download failed with HTTP {}", + status + ))), + } + } else if status.is_success() { + Ok(0) + } else { + Err(WhsprError::Download(format!( + "download failed with HTTP {}", + status + ))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn path_for_config_uses_tilde_when_under_home() { + let home = PathBuf::from("/home/alice"); + let path = PathBuf::from("/home/alice/.local/share/whispers/ggml.bin"); + assert_eq!( + path_for_config(&path, Some(&home)), + "~/.local/share/whispers/ggml.bin" + ); + } + + #[test] + fn path_for_config_keeps_absolute_when_outside_home() { + let home = PathBuf::from("/home/alice"); + let path = PathBuf::from("/var/lib/whispers/ggml.bin"); + assert_eq!( + path_for_config(&path, Some(&home)), + "/var/lib/whispers/ggml.bin" + ); + } + + #[test] + fn configured_file_status_prioritizes_active() { + assert_eq!(configured_file_status(true, false), "active"); + assert_eq!(configured_file_status(false, true), "local"); + assert_eq!(configured_file_status(false, false), "remote"); + } + + #[test] + fn managed_download_status_reports_missing_active_assets() { + assert_eq!(managed_download_status(true, true), "active"); + assert_eq!(managed_download_status(true, false), "active (missing)"); + assert_eq!(managed_download_status(false, true), "local"); + 
assert_eq!(managed_download_status(false, false), "remote"); + } + + #[test] + fn validated_existing_len_accepts_partial_content_resume() { + let len = validated_existing_len(100, reqwest::StatusCode::PARTIAL_CONTENT).unwrap(); + assert_eq!(len, 100); + } + + #[test] + fn validated_existing_len_restarts_on_ok_resume_response() { + let len = validated_existing_len(100, reqwest::StatusCode::OK).unwrap(); + assert_eq!(len, 0); + } + + #[test] + fn validated_existing_len_rejects_resume_on_error_status() { + let err = validated_existing_len(100, reqwest::StatusCode::RANGE_NOT_SATISFIABLE); + assert!(err.is_err()); + } +} diff --git a/src/nemo_asr.rs b/src/nemo_asr.rs index aa164ab..0ab2827 100644 --- a/src/nemo_asr.rs +++ b/src/nemo_asr.rs @@ -178,7 +178,7 @@ pub fn prepare_service(config: &TranscriptionConfig) -> Option { Some(NemoAsrService::new(config, &resolved)) } -pub fn resolve_model_ref(config: &TranscriptionConfig) -> Option { +fn resolve_model_ref(config: &TranscriptionConfig) -> Option { if let Some(model) = find_managed_model(&config.selected_model) { let model_dir = managed_model_local_path(model.name); if let Some(metadata) = load_model_ready_metadata(&model_dir) { diff --git a/src/osd_protocol.rs b/src/osd_protocol.rs new file mode 100644 index 0000000..140732d --- /dev/null +++ b/src/osd_protocol.rs @@ -0,0 +1,49 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum VoiceOsdStatus { + #[default] + Listening, + Transcribing, + Rewriting, + Finalizing, + Frozen, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +pub struct VoiceOsdUpdate { + pub status: VoiceOsdStatus, + pub stable_text: String, + pub unstable_text: String, + pub rewrite_preview: Option, + pub live_inject: bool, + pub frozen: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "kind", rename_all = "snake_case")] 
+pub enum OsdEvent { + VoiceUpdate(VoiceOsdUpdate), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn voice_update_roundtrips_as_json() { + let event = OsdEvent::VoiceUpdate(VoiceOsdUpdate { + status: VoiceOsdStatus::Rewriting, + stable_text: "hello".into(), + unstable_text: "world".into(), + rewrite_preview: Some("Hello world.".into()), + live_inject: true, + frozen: false, + }); + + let json = serde_json::to_string(&event).expect("serialize event"); + let parsed: OsdEvent = serde_json::from_str(&json).expect("deserialize event"); + assert_eq!(parsed, event); + } +} diff --git a/src/personalization/mod.rs b/src/personalization/mod.rs new file mode 100644 index 0000000..d6105e6 --- /dev/null +++ b/src/personalization/mod.rs @@ -0,0 +1,471 @@ +use crate::config::Config; +use crate::error::Result; + +mod rewrite; +mod store; + +pub use rewrite::build_rewrite_transcript; +pub use store::{ + DictionaryEntry, SnippetEntry, add_dictionary, add_snippet, list_dictionary, list_snippets, + print_rewrite_instructions_path, remove_dictionary, remove_snippet, +}; + +#[derive(Debug, Clone, Default)] +pub struct PersonalizationRules { + dictionary: Vec, + snippets: Vec, + snippet_trigger_words: Vec, + custom_instructions: String, +} + +#[derive(Debug, Clone)] +pub(super) struct PreparedDictionaryEntry { + replace: String, + words: Vec, +} + +#[derive(Debug, Clone)] +pub(super) struct PreparedSnippet { + text: String, + words: Vec, +} + +#[derive(Debug, Clone)] +pub(super) struct WordSpan { + pub(super) start: usize, + pub(super) end: usize, + pub(super) normalized: String, +} + +pub fn load_rules(config: &Config) -> Result { + let dictionary_entries = store::read_dictionary_file(&config.resolved_dictionary_path())?; + let snippet_entries = store::read_snippet_file(&config.resolved_snippets_path())?; + let custom_instructions = store::load_custom_instructions(config)?; + + Ok(PersonalizationRules { + dictionary: dictionary_entries + .into_iter() + 
.filter_map(|entry| PreparedDictionaryEntry::new(entry).ok()) + .collect(), + snippets: snippet_entries + .into_iter() + .filter_map(|entry| PreparedSnippet::new(entry).ok()) + .collect(), + snippet_trigger_words: normalized_words(&config.personalization.snippet_trigger), + custom_instructions, + }) +} + +pub fn finalize_text(text: &str, rules: &PersonalizationRules) -> String { + let corrected = apply_dictionary(text, rules); + let expanded = expand_snippets(&corrected, rules); + normalize_numeric_dot_runs(&expanded) +} + +pub fn custom_instructions(rules: &PersonalizationRules) -> Option<&str> { + (!rules.custom_instructions.trim().is_empty()).then_some(rules.custom_instructions.as_str()) +} + +pub fn transcription_prompt(rules: &PersonalizationRules) -> Option { + const MAX_TERMS: usize = 24; + const MAX_PROMPT_LEN: usize = 480; + + let mut terms = Vec::new(); + for entry in &rules.dictionary { + let replace = entry.replace.trim(); + if replace.is_empty() { + continue; + } + if terms.iter().any(|existing: &String| existing == replace) { + continue; + } + let projected_len = if terms.is_empty() { + replace.len() + } else { + terms.iter().map(String::len).sum::() + (terms.len() * 2) + replace.len() + }; + if terms.len() >= MAX_TERMS || projected_len > MAX_PROMPT_LEN { + break; + } + terms.push(replace.to_string()); + } + + if terms.is_empty() { + return None; + } + + Some(format!( + "This is direct dictation. 
Prefer these exact spellings when heard: {}.", + terms.join(", ") + )) +} + +pub(super) fn apply_dictionary(text: &str, rules: &PersonalizationRules) -> String { + apply_replacements(text, &rules.dictionary) +} + +fn expand_snippets(text: &str, rules: &PersonalizationRules) -> String { + if rules.snippets.is_empty() || rules.snippet_trigger_words.is_empty() { + return text.trim().to_string(); + } + + let spans = collect_word_spans(text); + if spans.is_empty() { + return text.trim().to_string(); + } + + let mut output = String::new(); + let mut cursor = 0usize; + let mut index = 0usize; + + while index < spans.len() { + let Some(best) = + best_snippet_match(&spans, index, &rules.snippet_trigger_words, &rules.snippets) + else { + index += 1; + continue; + }; + + output.push_str(&text[cursor..spans[index].start]); + output.push_str(best.text); + cursor = spans[index + best.total_words - 1].end; + index += best.total_words; + } + + output.push_str(&text[cursor..]); + output.trim().to_string() +} + +fn apply_replacements(text: &str, entries: &[PreparedDictionaryEntry]) -> String { + if entries.is_empty() { + return text.trim().to_string(); + } + + let spans = collect_word_spans(text); + if spans.is_empty() { + return text.trim().to_string(); + } + + let mut output = String::new(); + let mut cursor = 0usize; + let mut index = 0usize; + + while index < spans.len() { + let Some(best) = best_dictionary_match(&spans, index, entries) else { + index += 1; + continue; + }; + + output.push_str(&text[cursor..spans[index].start]); + output.push_str(&best.replace); + cursor = spans[index + best.words.len() - 1].end; + index += best.words.len(); + } + + output.push_str(&text[cursor..]); + output.trim().to_string() +} + +fn best_dictionary_match<'a>( + spans: &[WordSpan], + index: usize, + entries: &'a [PreparedDictionaryEntry], +) -> Option<&'a PreparedDictionaryEntry> { + entries + .iter() + .filter(|entry| entry.matches(spans, index)) + .max_by_key(|entry| entry.words.len()) +} + 
+fn best_snippet_match<'a>( + spans: &[WordSpan], + index: usize, + trigger_words: &[String], + snippets: &'a [PreparedSnippet], +) -> Option> { + if !matches_words(spans, index, trigger_words) { + return None; + } + + let snippet_index = index + trigger_words.len(); + snippets + .iter() + .filter(|snippet| snippet.matches(spans, snippet_index)) + .max_by_key(|snippet| snippet.words.len()) + .map(|snippet| SnippetMatch { + text: snippet.text.as_str(), + total_words: trigger_words.len() + snippet.words.len(), + }) +} + +fn matches_words(spans: &[WordSpan], index: usize, words: &[String]) -> bool { + if words.is_empty() || index + words.len() > spans.len() { + return false; + } + + spans[index..index + words.len()] + .iter() + .zip(words) + .all(|(span, word)| span.normalized == *word) +} + +pub(super) fn collect_word_spans(text: &str) -> Vec { + let mut spans = Vec::new(); + let mut current_start = None; + + for (idx, ch) in text.char_indices() { + if is_word_char(ch) { + current_start.get_or_insert(idx); + continue; + } + + if let Some(start) = current_start.take() { + spans.push(WordSpan { + start, + end: idx, + normalized: normalize_word(&text[start..idx]), + }); + } + } + + if let Some(start) = current_start { + spans.push(WordSpan { + start, + end: text.len(), + normalized: normalize_word(&text[start..]), + }); + } + + spans +} + +fn normalize_word(word: &str) -> String { + word.chars() + .filter(|ch| is_word_char(*ch)) + .flat_map(|ch| ch.to_lowercase()) + .collect() +} + +pub(super) fn normalized_words(text: &str) -> Vec { + collect_word_spans(text) + .into_iter() + .map(|span| span.normalized) + .collect() +} + +fn is_word_char(ch: char) -> bool { + ch.is_alphanumeric() || matches!(ch, '\'' | '-') +} + +impl PreparedDictionaryEntry { + pub(super) fn new(entry: DictionaryEntry) -> std::result::Result { + let words = normalized_words(&entry.phrase); + if words.is_empty() { + return Err(entry); + } + + Ok(Self { + replace: entry.replace, + words, + }) + } + + 
fn matches(&self, spans: &[WordSpan], index: usize) -> bool { + matches_words(spans, index, &self.words) + } +} + +impl PreparedSnippet { + pub(super) fn new(entry: SnippetEntry) -> std::result::Result { + let words = normalized_words(&entry.name); + if words.is_empty() { + return Err(entry); + } + + Ok(Self { + text: entry.text, + words, + }) + } + + fn matches(&self, spans: &[WordSpan], index: usize) -> bool { + matches_words(spans, index, &self.words) + } +} + +struct SnippetMatch<'a> { + text: &'a str, + total_words: usize, +} + +fn normalize_numeric_dot_runs(text: &str) -> String { + let chars: Vec = text.chars().collect(); + let mut output = String::with_capacity(text.len()); + let mut index = 0usize; + + while index < chars.len() { + let ch = chars[index]; + + if ch == ' ' + && previous_non_space_char(&output).is_some_and(|previous| previous.is_ascii_digit()) + { + let mut lookahead = index; + while lookahead < chars.len() && chars[lookahead] == ' ' { + lookahead += 1; + } + + if lookahead < chars.len() + && chars[lookahead] == '.' + && dot_has_numeric_suffix(&chars, lookahead) + { + index = lookahead; + continue; + } + } + + output.push(ch); + + if ch == '.' 
+ && previous_non_space_char(&output[..output.len().saturating_sub(1)]) + .is_some_and(|previous| previous.is_ascii_digit()) + { + let mut lookahead = index + 1; + while lookahead < chars.len() && chars[lookahead] == ' ' { + lookahead += 1; + } + + if lookahead > index + 1 && lookahead < chars.len() && chars[lookahead].is_ascii_digit() + { + index = lookahead; + continue; + } + } + + index += 1; + } + + output +} + +fn previous_non_space_char(text: &str) -> Option { + text.chars().rev().find(|ch| !ch.is_whitespace()) +} + +fn dot_has_numeric_suffix(chars: &[char], dot_index: usize) -> bool { + let mut lookahead = dot_index + 1; + while lookahead < chars.len() && chars[lookahead] == ' ' { + lookahead += 1; + } + + lookahead < chars.len() && chars[lookahead].is_ascii_digit() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{Config, PostprocessMode}; + use crate::rewrite_profile::RewriteProfile; + + fn rules() -> PersonalizationRules { + PersonalizationRules { + dictionary: vec![ + PreparedDictionaryEntry::new(DictionaryEntry { + phrase: "wisper flow".into(), + replace: "Wispr Flow".into(), + }) + .expect("dictionary"), + PreparedDictionaryEntry::new(DictionaryEntry { + phrase: "open ai".into(), + replace: "OpenAI".into(), + }) + .expect("dictionary"), + ], + snippets: vec![ + PreparedSnippet::new(SnippetEntry { + name: "signature".into(), + text: "Best regards,\nNotes".into(), + }) + .expect("snippet"), + PreparedSnippet::new(SnippetEntry { + name: "meeting follow up".into(), + text: "Thanks for the meeting.".into(), + }) + .expect("snippet"), + ], + snippet_trigger_words: normalized_words("insert"), + custom_instructions: "Keep brand names exact.".into(), + } + } + + #[test] + fn dictionary_applies_exact_normalized_replacements() { + let applied = apply_dictionary("I use wisper flow with open, ai.", &rules()); + assert_eq!(applied, "I use Wispr Flow with OpenAI."); + } + + #[test] + fn dictionary_prefers_longest_match() { + let rules = 
PersonalizationRules { + dictionary: vec![ + PreparedDictionaryEntry::new(DictionaryEntry { + phrase: "open".into(), + replace: "X".into(), + }) + .expect("dictionary"), + PreparedDictionaryEntry::new(DictionaryEntry { + phrase: "open ai".into(), + replace: "OpenAI".into(), + }) + .expect("dictionary"), + ], + ..PersonalizationRules::default() + }; + let applied = apply_dictionary("open ai works", &rules); + assert_eq!(applied, "OpenAI works"); + } + + #[test] + fn snippets_expand_after_trigger() { + let expanded = expand_snippets("please insert signature now", &rules()); + assert_eq!(expanded, "please Best regards,\nNotes now"); + } + + #[test] + fn unmatched_snippet_leaves_text_unchanged() { + let expanded = expand_snippets("please insert unknown now", &rules()); + assert_eq!(expanded, "please insert unknown now"); + } + + #[test] + fn finalize_text_applies_dictionary_then_snippets() { + let finalized = finalize_text("insert meeting follow up about wisper flow", &rules()); + assert_eq!(finalized, "Thanks for the meeting. about Wispr Flow"); + } + + #[test] + fn finalize_text_collapses_spaced_numeric_dot_runs() { + let finalized = finalize_text("MPL 2. 0 and TLS 1 . 3 are common references", &rules()); + assert_eq!(finalized, "MPL 2.0 and TLS 1.3 are common references"); + } + + #[test] + fn finalize_text_preserves_sentence_period_before_words() { + let finalized = finalize_text("Section 2. Next step", &rules()); + assert_eq!(finalized, "Section 2. 
Next step"); + } + + #[test] + fn transcription_prompt_includes_dictionary_targets() { + let prompt = transcription_prompt(&rules()).expect("prompt"); + assert!(prompt.contains("Wispr Flow")); + assert!(prompt.contains("OpenAI")); + } + + #[test] + fn default_config_paths_support_personalization_files() { + let config = Config::default(); + assert_eq!(config.postprocess.mode, PostprocessMode::Raw); + assert_eq!(config.rewrite.profile, RewriteProfile::Auto); + assert_eq!(config.personalization.snippet_trigger, "insert"); + } +} diff --git a/src/personalization.rs b/src/personalization/rewrite.rs similarity index 55% rename from src/personalization.rs rename to src/personalization/rewrite.rs index 71f4ae2..c6c5d75 100644 --- a/src/personalization.rs +++ b/src/personalization/rewrite.rs @@ -1,68 +1,14 @@ -use std::path::Path; - -use serde::{Deserialize, Serialize}; - use crate::cleanup; -use crate::config::Config; -use crate::error::{Result, WhsprError}; use crate::rewrite_protocol::{ RewriteCandidate, RewriteCandidateKind, RewriteEditAction, RewriteEditHypothesis, RewriteEditHypothesisMatchSource, RewriteEditIntent, RewriteEditSignal, RewriteEditSignalKind, RewriteEditSignalScope, RewriteEditSignalStrength, RewriteIntentConfidence, - RewriteReplacementScope, RewriteTailShape, RewriteTranscript, RewriteTranscriptSegment, + RewritePolicyContext, RewriteReplacementScope, RewriteTailShape, RewriteTranscript, + RewriteTranscriptSegment, }; use crate::transcribe::Transcript; -#[derive(Debug, Clone, Default)] -pub struct PersonalizationRules { - dictionary: Vec, - snippets: Vec, - snippet_trigger_words: Vec, - custom_instructions: String, -} - -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] -pub struct DictionaryEntry { - pub phrase: String, - pub replace: String, -} - -#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)] -#[serde(default)] -struct DictionaryFile { - entries: Vec, -} - -#[derive(Debug, Clone, Deserialize, Serialize, 
PartialEq, Eq)] -pub struct SnippetEntry { - pub name: String, - pub text: String, -} - -#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)] -#[serde(default)] -struct SnippetFile { - snippets: Vec, -} - -#[derive(Debug, Clone)] -struct PreparedDictionaryEntry { - replace: String, - words: Vec, -} - -#[derive(Debug, Clone)] -struct PreparedSnippet { - text: String, - words: Vec, -} - -#[derive(Debug, Clone)] -struct WordSpan { - start: usize, - end: usize, - normalized: String, -} +use super::{PersonalizationRules, WordSpan, apply_dictionary, collect_word_spans}; #[derive(Debug, Clone, Default)] struct TailParts { @@ -70,25 +16,6 @@ struct TailParts { span_replacement_tail: Option, } -pub fn load_rules(config: &Config) -> Result { - let dictionary_entries = read_dictionary_file(&config.resolved_dictionary_path())?; - let snippet_entries = read_snippet_file(&config.resolved_snippets_path())?; - let custom_instructions = load_custom_instructions(config)?; - - Ok(PersonalizationRules { - dictionary: dictionary_entries - .into_iter() - .filter_map(|entry| PreparedDictionaryEntry::new(entry).ok()) - .collect(), - snippets: snippet_entries - .into_iter() - .filter_map(|entry| PreparedSnippet::new(entry).ok()) - .collect(), - snippet_trigger_words: normalized_words(&config.personalization.snippet_trigger), - custom_instructions, - }) -} - pub fn build_rewrite_transcript( transcript: &Transcript, rules: &PersonalizationRules, @@ -220,6 +147,7 @@ pub fn build_rewrite_transcript( edit_hypotheses, rewrite_candidates, recommended_candidate, + policy_context: RewritePolicyContext::default(), } } @@ -695,551 +623,26 @@ fn candidate_priority(kind: RewriteCandidateKind) -> u8 { RewriteCandidateKind::CancelPreviousSentence => 5, RewriteCandidateKind::CancelPreviousClause => 6, RewriteCandidateKind::FollowingReplacement => 7, - RewriteCandidateKind::ConservativeCorrection => 8, - RewriteCandidateKind::DropCueOnly => 9, - RewriteCandidateKind::Literal => 10, - } 
-} - -fn normalize_numeric_dot_runs(text: &str) -> String { - let chars: Vec = text.chars().collect(); - let mut output = String::with_capacity(text.len()); - let mut index = 0usize; - - while index < chars.len() { - let ch = chars[index]; - - if ch == ' ' - && previous_non_space_char(&output).is_some_and(|previous| previous.is_ascii_digit()) - { - let mut lookahead = index; - while lookahead < chars.len() && chars[lookahead] == ' ' { - lookahead += 1; - } - - if lookahead < chars.len() - && chars[lookahead] == '.' - && dot_has_numeric_suffix(&chars, lookahead) - { - index = lookahead; - continue; - } - } - - output.push(ch); - - if ch == '.' - && previous_non_space_char(&output[..output.len().saturating_sub(1)]) - .is_some_and(|previous| previous.is_ascii_digit()) - { - let mut lookahead = index + 1; - while lookahead < chars.len() && chars[lookahead] == ' ' { - lookahead += 1; - } - - if lookahead > index + 1 && lookahead < chars.len() && chars[lookahead].is_ascii_digit() - { - index = lookahead; - continue; - } - } - - index += 1; - } - - output -} - -fn previous_non_space_char(text: &str) -> Option { - text.chars().rev().find(|ch| !ch.is_whitespace()) -} - -fn dot_has_numeric_suffix(chars: &[char], dot_index: usize) -> bool { - let mut lookahead = dot_index + 1; - while lookahead < chars.len() && chars[lookahead] == ' ' { - lookahead += 1; - } - - lookahead < chars.len() && chars[lookahead].is_ascii_digit() -} - -pub fn finalize_text(text: &str, rules: &PersonalizationRules) -> String { - let corrected = apply_dictionary(text, rules); - let expanded = expand_snippets(&corrected, rules); - normalize_numeric_dot_runs(&expanded) -} - -pub fn custom_instructions(rules: &PersonalizationRules) -> Option<&str> { - (!rules.custom_instructions.trim().is_empty()).then_some(rules.custom_instructions.as_str()) -} - -pub fn transcription_prompt(rules: &PersonalizationRules) -> Option { - const MAX_TERMS: usize = 24; - const MAX_PROMPT_LEN: usize = 480; - - let mut terms = 
Vec::new(); - for entry in &rules.dictionary { - let replace = entry.replace.trim(); - if replace.is_empty() { - continue; - } - if terms.iter().any(|existing: &String| existing == replace) { - continue; - } - let projected_len = if terms.is_empty() { - replace.len() - } else { - terms.iter().map(String::len).sum::() + (terms.len() * 2) + replace.len() - }; - if terms.len() >= MAX_TERMS || projected_len > MAX_PROMPT_LEN { - break; - } - terms.push(replace.to_string()); - } - - if terms.is_empty() { - return None; + RewriteCandidateKind::GlossaryCorrection => 8, + RewriteCandidateKind::ConservativeCorrection => 9, + RewriteCandidateKind::DropCueOnly => 10, + RewriteCandidateKind::Literal => 11, } - - Some(format!( - "This is direct dictation. Prefer these exact spellings when heard: {}.", - terms.join(", ") - )) -} - -pub fn list_dictionary(config_override: Option<&Path>) -> Result<()> { - let config = Config::load(config_override)?; - let entries = read_dictionary_file(&config.resolved_dictionary_path())?; - if entries.is_empty() { - println!("No dictionary entries configured."); - return Ok(()); - } - - for entry in entries { - println!("{} -> {}", entry.phrase, entry.replace); - } - - Ok(()) -} - -pub fn add_dictionary(config_override: Option<&Path>, phrase: &str, replace: &str) -> Result<()> { - let config = Config::load(config_override)?; - let path = config.resolved_dictionary_path(); - let mut entries = read_dictionary_file(&path)?; - upsert_dictionary_entry(&mut entries, phrase, replace); - write_dictionary_file(&path, &entries)?; - println!("Added dictionary entry: {} -> {}", phrase, replace); - println!("Dictionary updated: {}", path.display()); - Ok(()) -} - -pub fn remove_dictionary(config_override: Option<&Path>, phrase: &str) -> Result<()> { - let config = Config::load(config_override)?; - let path = config.resolved_dictionary_path(); - let mut entries = read_dictionary_file(&path)?; - let removed = remove_dictionary_entry(&mut entries, phrase); - 
write_dictionary_file(&path, &entries)?; - if removed { - println!("Removed dictionary entry: {}", phrase); - } else { - println!("No dictionary entry matched: {}", phrase); - } - println!("Dictionary updated: {}", path.display()); - Ok(()) -} - -pub fn list_snippets(config_override: Option<&Path>) -> Result<()> { - let config = Config::load(config_override)?; - let snippets = read_snippet_file(&config.resolved_snippets_path())?; - if snippets.is_empty() { - println!("No snippets configured."); - return Ok(()); - } - - for snippet in snippets { - println!("{} -> {}", snippet.name, snippet.text.replace('\n', "\\n")); - } - - Ok(()) -} - -pub fn add_snippet(config_override: Option<&Path>, name: &str, text: &str) -> Result<()> { - let config = Config::load(config_override)?; - let path = config.resolved_snippets_path(); - let mut snippets = read_snippet_file(&path)?; - upsert_snippet(&mut snippets, name, text); - write_snippet_file(&path, &snippets)?; - println!("Added snippet: {}", name); - println!("Snippets updated: {}", path.display()); - Ok(()) -} - -pub fn remove_snippet(config_override: Option<&Path>, name: &str) -> Result<()> { - let config = Config::load(config_override)?; - let path = config.resolved_snippets_path(); - let mut snippets = read_snippet_file(&path)?; - let removed = remove_snippet_entry(&mut snippets, name); - write_snippet_file(&path, &snippets)?; - if removed { - println!("Removed snippet: {}", name); - } else { - println!("No snippet matched: {}", name); - } - println!("Snippets updated: {}", path.display()); - Ok(()) -} - -pub fn print_rewrite_instructions_path(config_override: Option<&Path>) -> Result<()> { - let config = Config::load(config_override)?; - match config.resolved_rewrite_instructions_path() { - Some(path) => println!("{}", path.display()), - None => println!("No rewrite instructions path configured."), - } - Ok(()) -} - -fn load_custom_instructions(config: &Config) -> Result { - let Some(path) = 
config.resolved_rewrite_instructions_path() else { - return Ok(String::new()); - }; - - match std::fs::read_to_string(&path) { - Ok(contents) => Ok(contents.trim().to_string()), - Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(String::new()), - Err(err) => Err(WhsprError::Config(format!( - "failed to read rewrite instructions {}: {err}", - path.display() - ))), - } -} - -fn apply_dictionary(text: &str, rules: &PersonalizationRules) -> String { - apply_replacements(text, &rules.dictionary) -} - -fn expand_snippets(text: &str, rules: &PersonalizationRules) -> String { - if rules.snippets.is_empty() || rules.snippet_trigger_words.is_empty() { - return text.trim().to_string(); - } - - let spans = collect_word_spans(text); - if spans.is_empty() { - return text.trim().to_string(); - } - - let mut output = String::new(); - let mut cursor = 0usize; - let mut index = 0usize; - - while index < spans.len() { - let Some(best) = - best_snippet_match(&spans, index, &rules.snippet_trigger_words, &rules.snippets) - else { - index += 1; - continue; - }; - - output.push_str(&text[cursor..spans[index].start]); - output.push_str(best.text); - cursor = spans[index + best.total_words - 1].end; - index += best.total_words; - } - - output.push_str(&text[cursor..]); - output.trim().to_string() -} - -fn apply_replacements(text: &str, entries: &[PreparedDictionaryEntry]) -> String { - if entries.is_empty() { - return text.trim().to_string(); - } - - let spans = collect_word_spans(text); - if spans.is_empty() { - return text.trim().to_string(); - } - - let mut output = String::new(); - let mut cursor = 0usize; - let mut index = 0usize; - - while index < spans.len() { - let Some(best) = best_dictionary_match(&spans, index, entries) else { - index += 1; - continue; - }; - - output.push_str(&text[cursor..spans[index].start]); - output.push_str(&best.replace); - cursor = spans[index + best.words.len() - 1].end; - index += best.words.len(); - } - - output.push_str(&text[cursor..]); - 
output.trim().to_string() -} - -fn best_dictionary_match<'a>( - spans: &[WordSpan], - index: usize, - entries: &'a [PreparedDictionaryEntry], -) -> Option<&'a PreparedDictionaryEntry> { - entries - .iter() - .filter(|entry| entry.matches(spans, index)) - .max_by_key(|entry| entry.words.len()) -} - -fn best_snippet_match<'a>( - spans: &[WordSpan], - index: usize, - trigger_words: &[String], - snippets: &'a [PreparedSnippet], -) -> Option> { - if !matches_words(spans, index, trigger_words) { - return None; - } - - let snippet_index = index + trigger_words.len(); - snippets - .iter() - .filter(|snippet| snippet.matches(spans, snippet_index)) - .max_by_key(|snippet| snippet.words.len()) - .map(|snippet| SnippetMatch { - text: snippet.text.as_str(), - total_words: trigger_words.len() + snippet.words.len(), - }) -} - -fn matches_words(spans: &[WordSpan], index: usize, words: &[String]) -> bool { - if words.is_empty() || index + words.len() > spans.len() { - return false; - } - - spans[index..index + words.len()] - .iter() - .zip(words) - .all(|(span, word)| span.normalized == *word) -} - -fn collect_word_spans(text: &str) -> Vec { - let mut spans = Vec::new(); - let mut current_start = None; - - for (idx, ch) in text.char_indices() { - if is_word_char(ch) { - current_start.get_or_insert(idx); - continue; - } - - if let Some(start) = current_start.take() { - spans.push(WordSpan { - start, - end: idx, - normalized: normalize_word(&text[start..idx]), - }); - } - } - - if let Some(start) = current_start { - spans.push(WordSpan { - start, - end: text.len(), - normalized: normalize_word(&text[start..]), - }); - } - - spans -} - -fn normalize_word(word: &str) -> String { - word.chars() - .filter(|ch| is_word_char(*ch)) - .flat_map(|ch| ch.to_lowercase()) - .collect() -} - -fn normalized_words(text: &str) -> Vec { - collect_word_spans(text) - .into_iter() - .map(|span| span.normalized) - .collect() -} - -fn is_word_char(ch: char) -> bool { - ch.is_alphanumeric() || matches!(ch, 
'\'' | '-') -} - -fn read_dictionary_file(path: &Path) -> Result> { - if !path.exists() { - return Ok(Vec::new()); - } - - let contents = std::fs::read_to_string(path).map_err(|e| { - WhsprError::Config(format!("failed to read dictionary {}: {e}", path.display())) - })?; - let file: DictionaryFile = toml::from_str(&contents).map_err(|e| { - WhsprError::Config(format!( - "failed to parse dictionary {}: {e}", - path.display() - )) - })?; - Ok(file.entries) -} - -fn write_dictionary_file(path: &Path, entries: &[DictionaryEntry]) -> Result<()> { - write_parent(path)?; - let file = DictionaryFile { - entries: entries.to_vec(), - }; - let contents = toml::to_string_pretty(&file) - .map_err(|e| WhsprError::Config(format!("failed to encode dictionary: {e}")))?; - std::fs::write(path, contents).map_err(|e| { - WhsprError::Config(format!( - "failed to write dictionary {}: {e}", - path.display() - )) - })?; - Ok(()) -} - -fn read_snippet_file(path: &Path) -> Result> { - if !path.exists() { - return Ok(Vec::new()); - } - - let contents = std::fs::read_to_string(path).map_err(|e| { - WhsprError::Config(format!("failed to read snippets {}: {e}", path.display())) - })?; - let file: SnippetFile = toml::from_str(&contents).map_err(|e| { - WhsprError::Config(format!("failed to parse snippets {}: {e}", path.display())) - })?; - Ok(file.snippets) -} - -fn write_snippet_file(path: &Path, snippets: &[SnippetEntry]) -> Result<()> { - write_parent(path)?; - let file = SnippetFile { - snippets: snippets.to_vec(), - }; - let contents = toml::to_string_pretty(&file) - .map_err(|e| WhsprError::Config(format!("failed to encode snippets: {e}")))?; - std::fs::write(path, contents).map_err(|e| { - WhsprError::Config(format!("failed to write snippets {}: {e}", path.display())) - })?; - Ok(()) -} - -fn write_parent(path: &Path) -> Result<()> { - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - WhsprError::Config(format!( - "failed to create directory {}: {e}", 
- parent.display() - )) - })?; - } - Ok(()) -} - -fn upsert_dictionary_entry(entries: &mut Vec, phrase: &str, replace: &str) { - let target = normalized_words(phrase); - if let Some(existing) = entries - .iter_mut() - .find(|entry| normalized_words(&entry.phrase) == target) - { - existing.phrase = phrase.to_string(); - existing.replace = replace.to_string(); - return; - } - - entries.push(DictionaryEntry { - phrase: phrase.to_string(), - replace: replace.to_string(), - }); -} - -fn remove_dictionary_entry(entries: &mut Vec, phrase: &str) -> bool { - let target = normalized_words(phrase); - let before = entries.len(); - entries.retain(|entry| normalized_words(&entry.phrase) != target); - before != entries.len() -} - -fn upsert_snippet(snippets: &mut Vec, name: &str, text: &str) { - let target = normalized_words(name); - if let Some(existing) = snippets - .iter_mut() - .find(|entry| normalized_words(&entry.name) == target) - { - existing.name = name.to_string(); - existing.text = text.to_string(); - return; - } - - snippets.push(SnippetEntry { - name: name.to_string(), - text: text.to_string(), - }); -} - -fn remove_snippet_entry(snippets: &mut Vec, name: &str) -> bool { - let target = normalized_words(name); - let before = snippets.len(); - snippets.retain(|entry| normalized_words(&entry.name) != target); - before != snippets.len() -} - -impl PreparedDictionaryEntry { - fn new(entry: DictionaryEntry) -> std::result::Result { - let words = normalized_words(&entry.phrase); - if words.is_empty() { - return Err(entry); - } - - Ok(Self { - replace: entry.replace, - words, - }) - } - - fn matches(&self, spans: &[WordSpan], index: usize) -> bool { - matches_words(spans, index, &self.words) - } -} - -impl PreparedSnippet { - fn new(entry: SnippetEntry) -> std::result::Result { - let words = normalized_words(&entry.name); - if words.is_empty() { - return Err(entry); - } - - Ok(Self { - text: entry.text, - words, - }) - } - - fn matches(&self, spans: &[WordSpan], index: 
usize) -> bool { - matches_words(spans, index, &self.words) - } -} - -struct SnippetMatch<'a> { - text: &'a str, - total_words: usize, } #[cfg(test)] mod tests { - use super::*; - use crate::config::{Config, PostprocessMode}; - use crate::rewrite_profile::RewriteProfile; + use super::build_rewrite_transcript; use crate::rewrite_protocol::{ RewriteCandidateKind, RewriteEditHypothesisMatchSource, RewriteEditSignalKind, RewriteEditSignalScope, RewriteEditSignalStrength, }; + use crate::transcribe::Transcript; + + use super::super::store::{DictionaryEntry, SnippetEntry}; + use super::super::{ + PersonalizationRules, PreparedDictionaryEntry, PreparedSnippet, normalized_words, + }; fn rules() -> PersonalizationRules { PersonalizationRules { @@ -1272,63 +675,6 @@ mod tests { } } - #[test] - fn dictionary_applies_exact_normalized_replacements() { - let applied = apply_dictionary("I use wisper flow with open, ai.", &rules()); - assert_eq!(applied, "I use Wispr Flow with OpenAI."); - } - - #[test] - fn dictionary_prefers_longest_match() { - let rules = PersonalizationRules { - dictionary: vec![ - PreparedDictionaryEntry::new(DictionaryEntry { - phrase: "open".into(), - replace: "X".into(), - }) - .expect("dictionary"), - PreparedDictionaryEntry::new(DictionaryEntry { - phrase: "open ai".into(), - replace: "OpenAI".into(), - }) - .expect("dictionary"), - ], - ..PersonalizationRules::default() - }; - let applied = apply_dictionary("open ai works", &rules); - assert_eq!(applied, "OpenAI works"); - } - - #[test] - fn snippets_expand_after_trigger() { - let expanded = expand_snippets("please insert signature now", &rules()); - assert_eq!(expanded, "please Best regards,\nNotes now"); - } - - #[test] - fn unmatched_snippet_leaves_text_unchanged() { - let expanded = expand_snippets("please insert unknown now", &rules()); - assert_eq!(expanded, "please insert unknown now"); - } - - #[test] - fn finalize_text_applies_dictionary_then_snippets() { - let finalized = finalize_text("insert 
meeting follow up about wisper flow", &rules()); - assert_eq!(finalized, "Thanks for the meeting. about Wispr Flow"); - } - - #[test] - fn finalize_text_collapses_spaced_numeric_dot_runs() { - let finalized = finalize_text("MPL 2. 0 and TLS 1 . 3 are common references", &rules()); - assert_eq!(finalized, "MPL 2.0 and TLS 1.3 are common references"); - } - - #[test] - fn finalize_text_preserves_sentence_period_before_words() { - let finalized = finalize_text("Section 2. Next step", &rules()); - assert_eq!(finalized, "Section 2. Next step"); - } - #[test] fn build_rewrite_transcript_applies_dictionary_before_rewrite() { let transcript = Transcript { @@ -1364,10 +710,9 @@ mod tests { let rewrite = build_rewrite_transcript(&transcript, &rules()); assert_eq!(rewrite.correction_aware_text, ""); - assert_eq!(rewrite.edit_intents.len(), 1); assert_eq!( rewrite.edit_intents[0].action, - RewriteEditAction::ReplacePreviousClause + crate::rewrite_protocol::RewriteEditAction::ReplacePreviousClause ); assert_eq!(rewrite.edit_intents[0].trigger, "never mind"); assert_eq!(rewrite.edit_signals.len(), 1); @@ -1514,73 +859,4 @@ mod tests { Some(RewriteCandidateKind::SpanReplacement) ); } - - #[test] - fn load_custom_instructions_tolerates_missing_file() { - let mut config = Config::default(); - config.rewrite.instructions_path = "/tmp/whispers-missing-instructions.txt".into(); - let loaded = load_custom_instructions(&config).expect("load"); - assert!(loaded.is_empty()); - } - - #[test] - fn transcription_prompt_includes_dictionary_targets() { - let prompt = transcription_prompt(&rules()).expect("prompt"); - assert!(prompt.contains("Wispr Flow")); - assert!(prompt.contains("OpenAI")); - } - - #[test] - fn add_and_remove_dictionary_entries_roundtrip() { - let _env_lock = crate::test_support::env_lock(); - let _guard = crate::test_support::EnvVarGuard::capture(&[ - "HOME", - "XDG_CONFIG_HOME", - "XDG_DATA_HOME", - ]); - let home = 
crate::test_support::unique_temp_dir("personalization-dict-home"); - crate::test_support::set_env("HOME", &home.to_string_lossy()); - crate::test_support::remove_env("XDG_CONFIG_HOME"); - crate::test_support::remove_env("XDG_DATA_HOME"); - - add_dictionary(None, "wisper flow", "Wispr Flow").expect("add dictionary"); - let config = Config::load(None).expect("config"); - let entries = read_dictionary_file(&config.resolved_dictionary_path()).expect("read"); - assert_eq!(entries.len(), 1); - - remove_dictionary(None, "wisper flow").expect("remove dictionary"); - let entries = read_dictionary_file(&config.resolved_dictionary_path()).expect("read"); - assert!(entries.is_empty()); - } - - #[test] - fn add_and_remove_snippets_roundtrip() { - let _env_lock = crate::test_support::env_lock(); - let _guard = crate::test_support::EnvVarGuard::capture(&[ - "HOME", - "XDG_CONFIG_HOME", - "XDG_DATA_HOME", - ]); - let home = crate::test_support::unique_temp_dir("personalization-snippet-home"); - crate::test_support::set_env("HOME", &home.to_string_lossy()); - crate::test_support::remove_env("XDG_CONFIG_HOME"); - crate::test_support::remove_env("XDG_DATA_HOME"); - - add_snippet(None, "signature", "Best regards,\nNotes").expect("add snippet"); - let config = Config::load(None).expect("config"); - let entries = read_snippet_file(&config.resolved_snippets_path()).expect("read"); - assert_eq!(entries.len(), 1); - - remove_snippet(None, "signature").expect("remove snippet"); - let entries = read_snippet_file(&config.resolved_snippets_path()).expect("read"); - assert!(entries.is_empty()); - } - - #[test] - fn default_config_paths_support_personalization_files() { - let config = Config::default(); - assert_eq!(config.postprocess.mode, PostprocessMode::Raw); - assert_eq!(config.rewrite.profile, RewriteProfile::Auto); - assert_eq!(config.personalization.snippet_trigger, "insert"); - } } diff --git a/src/personalization/store.rs b/src/personalization/store.rs new file mode 100644 index 
0000000..ffe0af9 --- /dev/null +++ b/src/personalization/store.rs @@ -0,0 +1,321 @@ +use std::path::Path; + +use serde::{Deserialize, Serialize}; + +use crate::config::Config; +use crate::error::{Result, WhsprError}; + +use super::normalized_words; + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct DictionaryEntry { + pub phrase: String, + pub replace: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct SnippetEntry { + pub name: String, + pub text: String, +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)] +#[serde(default)] +struct DictionaryFile { + entries: Vec, +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)] +#[serde(default)] +struct SnippetFile { + snippets: Vec, +} + +pub fn list_dictionary(config_override: Option<&Path>) -> Result<()> { + let config = Config::load(config_override)?; + let entries = read_dictionary_file(&config.resolved_dictionary_path())?; + if entries.is_empty() { + println!("No dictionary entries configured."); + return Ok(()); + } + + for entry in entries { + println!("{} -> {}", entry.phrase, entry.replace); + } + + Ok(()) +} + +pub fn add_dictionary(config_override: Option<&Path>, phrase: &str, replace: &str) -> Result<()> { + let config = Config::load(config_override)?; + let path = config.resolved_dictionary_path(); + let mut entries = read_dictionary_file(&path)?; + upsert_dictionary_entry(&mut entries, phrase, replace); + write_dictionary_file(&path, &entries)?; + println!("Added dictionary entry: {} -> {}", phrase, replace); + println!("Dictionary updated: {}", path.display()); + Ok(()) +} + +pub fn remove_dictionary(config_override: Option<&Path>, phrase: &str) -> Result<()> { + let config = Config::load(config_override)?; + let path = config.resolved_dictionary_path(); + let mut entries = read_dictionary_file(&path)?; + let removed = remove_dictionary_entry(&mut entries, phrase); + 
write_dictionary_file(&path, &entries)?; + if removed { + println!("Removed dictionary entry: {}", phrase); + } else { + println!("No dictionary entry matched: {}", phrase); + } + println!("Dictionary updated: {}", path.display()); + Ok(()) +} + +pub fn list_snippets(config_override: Option<&Path>) -> Result<()> { + let config = Config::load(config_override)?; + let snippets = read_snippet_file(&config.resolved_snippets_path())?; + if snippets.is_empty() { + println!("No snippets configured."); + return Ok(()); + } + + for snippet in snippets { + println!("{} -> {}", snippet.name, snippet.text.replace('\n', "\\n")); + } + + Ok(()) +} + +pub fn add_snippet(config_override: Option<&Path>, name: &str, text: &str) -> Result<()> { + let config = Config::load(config_override)?; + let path = config.resolved_snippets_path(); + let mut snippets = read_snippet_file(&path)?; + upsert_snippet(&mut snippets, name, text); + write_snippet_file(&path, &snippets)?; + println!("Added snippet: {}", name); + println!("Snippets updated: {}", path.display()); + Ok(()) +} + +pub fn remove_snippet(config_override: Option<&Path>, name: &str) -> Result<()> { + let config = Config::load(config_override)?; + let path = config.resolved_snippets_path(); + let mut snippets = read_snippet_file(&path)?; + let removed = remove_snippet_entry(&mut snippets, name); + write_snippet_file(&path, &snippets)?; + if removed { + println!("Removed snippet: {}", name); + } else { + println!("No snippet matched: {}", name); + } + println!("Snippets updated: {}", path.display()); + Ok(()) +} + +pub fn print_rewrite_instructions_path(config_override: Option<&Path>) -> Result<()> { + let config = Config::load(config_override)?; + match config.resolved_rewrite_instructions_path() { + Some(path) => println!("{}", path.display()), + None => println!("No rewrite instructions path configured."), + } + Ok(()) +} + +pub(super) fn load_custom_instructions(config: &Config) -> Result { + let Some(path) = 
config.resolved_rewrite_instructions_path() else { + return Ok(String::new()); + }; + + match std::fs::read_to_string(&path) { + Ok(contents) => Ok(contents.trim().to_string()), + Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(String::new()), + Err(err) => Err(WhsprError::Config(format!( + "failed to read rewrite instructions {}: {err}", + path.display() + ))), + } +} + +pub(super) fn read_dictionary_file(path: &Path) -> Result> { + if !path.exists() { + return Ok(Vec::new()); + } + + let contents = std::fs::read_to_string(path).map_err(|e| { + WhsprError::Config(format!("failed to read dictionary {}: {e}", path.display())) + })?; + let file: DictionaryFile = toml::from_str(&contents).map_err(|e| { + WhsprError::Config(format!( + "failed to parse dictionary {}: {e}", + path.display() + )) + })?; + Ok(file.entries) +} + +pub(super) fn write_dictionary_file(path: &Path, entries: &[DictionaryEntry]) -> Result<()> { + write_parent(path)?; + let file = DictionaryFile { + entries: entries.to_vec(), + }; + let contents = toml::to_string_pretty(&file) + .map_err(|e| WhsprError::Config(format!("failed to encode dictionary: {e}")))?; + std::fs::write(path, contents).map_err(|e| { + WhsprError::Config(format!( + "failed to write dictionary {}: {e}", + path.display() + )) + })?; + Ok(()) +} + +pub(super) fn read_snippet_file(path: &Path) -> Result> { + if !path.exists() { + return Ok(Vec::new()); + } + + let contents = std::fs::read_to_string(path).map_err(|e| { + WhsprError::Config(format!("failed to read snippets {}: {e}", path.display())) + })?; + let file: SnippetFile = toml::from_str(&contents).map_err(|e| { + WhsprError::Config(format!("failed to parse snippets {}: {e}", path.display())) + })?; + Ok(file.snippets) +} + +pub(super) fn write_snippet_file(path: &Path, snippets: &[SnippetEntry]) -> Result<()> { + write_parent(path)?; + let file = SnippetFile { + snippets: snippets.to_vec(), + }; + let contents = toml::to_string_pretty(&file) + .map_err(|e| 
WhsprError::Config(format!("failed to encode snippets: {e}")))?; + std::fs::write(path, contents).map_err(|e| { + WhsprError::Config(format!("failed to write snippets {}: {e}", path.display())) + })?; + Ok(()) +} + +fn write_parent(path: &Path) -> Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(|e| { + WhsprError::Config(format!( + "failed to create directory {}: {e}", + parent.display() + )) + })?; + } + Ok(()) +} + +fn upsert_dictionary_entry(entries: &mut Vec, phrase: &str, replace: &str) { + let target = normalized_words(phrase); + if let Some(existing) = entries + .iter_mut() + .find(|entry| normalized_words(&entry.phrase) == target) + { + existing.phrase = phrase.to_string(); + existing.replace = replace.to_string(); + return; + } + + entries.push(DictionaryEntry { + phrase: phrase.to_string(), + replace: replace.to_string(), + }); +} + +fn remove_dictionary_entry(entries: &mut Vec, phrase: &str) -> bool { + let target = normalized_words(phrase); + let before = entries.len(); + entries.retain(|entry| normalized_words(&entry.phrase) != target); + before != entries.len() +} + +fn upsert_snippet(snippets: &mut Vec, name: &str, text: &str) { + let target = normalized_words(name); + if let Some(existing) = snippets + .iter_mut() + .find(|entry| normalized_words(&entry.name) == target) + { + existing.name = name.to_string(); + existing.text = text.to_string(); + return; + } + + snippets.push(SnippetEntry { + name: name.to_string(), + text: text.to_string(), + }); +} + +fn remove_snippet_entry(snippets: &mut Vec, name: &str) -> bool { + let target = normalized_words(name); + let before = snippets.len(); + snippets.retain(|entry| normalized_words(&entry.name) != target); + before != snippets.len() +} + +#[cfg(test)] +mod tests { + use super::{ + add_dictionary, add_snippet, load_custom_instructions, read_dictionary_file, + read_snippet_file, remove_dictionary, remove_snippet, + }; + use crate::config::Config; + + #[test] 
+ fn add_and_remove_dictionary_entries_roundtrip() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&[ + "HOME", + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + ]); + let home = crate::test_support::unique_temp_dir("personalization-dict-home"); + crate::test_support::set_env("HOME", &home.to_string_lossy()); + crate::test_support::remove_env("XDG_CONFIG_HOME"); + crate::test_support::remove_env("XDG_DATA_HOME"); + + add_dictionary(None, "wisper flow", "Wispr Flow").expect("add dictionary"); + let config = Config::load(None).expect("config"); + let entries = read_dictionary_file(&config.resolved_dictionary_path()).expect("read"); + assert_eq!(entries.len(), 1); + + remove_dictionary(None, "wisper flow").expect("remove dictionary"); + let entries = read_dictionary_file(&config.resolved_dictionary_path()).expect("read"); + assert!(entries.is_empty()); + } + + #[test] + fn add_and_remove_snippets_roundtrip() { + let _env_lock = crate::test_support::env_lock(); + let _guard = crate::test_support::EnvVarGuard::capture(&[ + "HOME", + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + ]); + let home = crate::test_support::unique_temp_dir("personalization-snippet-home"); + crate::test_support::set_env("HOME", &home.to_string_lossy()); + crate::test_support::remove_env("XDG_CONFIG_HOME"); + crate::test_support::remove_env("XDG_DATA_HOME"); + + add_snippet(None, "signature", "Best regards,\nNotes").expect("add snippet"); + let config = Config::load(None).expect("config"); + let entries = read_snippet_file(&config.resolved_snippets_path()).expect("read"); + assert_eq!(entries.len(), 1); + + remove_snippet(None, "signature").expect("remove snippet"); + let entries = read_snippet_file(&config.resolved_snippets_path()).expect("read"); + assert!(entries.is_empty()); + } + + #[test] + fn load_custom_instructions_tolerates_missing_file() { + let mut config = Config::default(); + config.rewrite.instructions_path = 
"/tmp/whispers-missing-instructions.txt".into(); + let loaded = load_custom_instructions(&config).expect("load"); + assert!(loaded.is_empty()); + } +} diff --git a/src/postprocess.rs b/src/postprocess.rs deleted file mode 100644 index 2cafbef..0000000 --- a/src/postprocess.rs +++ /dev/null @@ -1,373 +0,0 @@ -use std::path::{Path, PathBuf}; -use std::time::Duration; -use std::time::Instant; - -use crate::cleanup; -use crate::cloud; -use crate::config::{Config, PostprocessMode, RewriteBackend, RewriteFallback}; -use crate::context::TypingContext; -use crate::personalization::{self, PersonalizationRules}; -use crate::rewrite_model; -use crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind; -use crate::rewrite_worker::{self, RewriteService}; -use crate::session::{self, EligibleSessionEntry, SessionRewriteSummary}; -use crate::transcribe::Transcript; - -const FEEDBACK_DRAIN_DELAY: Duration = Duration::from_millis(150); - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum FinalizedOperation { - Append, - ReplaceLastEntry { - entry_id: u64, - delete_graphemes: usize, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct FinalizedTranscript { - pub text: String, - pub operation: FinalizedOperation, - pub rewrite_summary: SessionRewriteSummary, -} - -pub fn raw_text(transcript: &Transcript) -> String { - transcript.raw_text.trim().to_string() -} - -fn base_text(config: &Config, transcript: &Transcript) -> String { - match config.postprocess.mode { - PostprocessMode::LegacyBasic => cleanup::clean_transcript(transcript, &config.cleanup), - PostprocessMode::AdvancedLocal => cleanup::correction_aware_text(transcript), - PostprocessMode::Raw => raw_text(transcript), - } -} - -pub fn resolve_rewrite_model_path(config: &Config) -> Option { - if let Some(path) = config.resolved_rewrite_model_path() { - return Some(path); - } - - rewrite_model::selected_model_path(&config.rewrite.selected_model) -} - -pub async fn finalize_transcript( - config: &Config, - transcript: 
Transcript, - rewrite_service: Option<&RewriteService>, - typing_context: Option<&TypingContext>, - recent_session: Option<&EligibleSessionEntry>, -) -> FinalizedTranscript { - let started = Instant::now(); - let rules = load_runtime_rules(config); - let finalized = match config.postprocess.mode { - PostprocessMode::Raw => finalize_plain_text( - raw_text(&transcript), - SessionRewriteSummary { - had_edit_cues: false, - rewrite_used: false, - recommended_candidate: None, - }, - &rules, - ), - PostprocessMode::LegacyBasic => finalize_plain_text( - cleanup::clean_transcript(&transcript, &config.cleanup), - SessionRewriteSummary { - had_edit_cues: false, - rewrite_used: false, - recommended_candidate: None, - }, - &rules, - ), - PostprocessMode::AdvancedLocal => { - rewrite_transcript_or_fallback( - config, - &transcript, - rewrite_service, - &rules, - typing_context, - recent_session, - ) - .await - } - }; - tracing::info!( - elapsed_ms = started.elapsed().as_millis(), - mode = config.postprocess.mode.as_str(), - rewrite_used = finalized.rewrite_summary.rewrite_used, - output_chars = finalized.text.len(), - "finalize_transcript finished" - ); - finalized -} - -pub async fn wait_for_feedback_drain() { - tokio::time::sleep(FEEDBACK_DRAIN_DELAY).await; -} - -pub fn prepare_rewrite_service(config: &Config) -> Option { - if config.postprocess.mode != PostprocessMode::AdvancedLocal { - return None; - } - - if config.rewrite.backend != RewriteBackend::Local - && config.rewrite.fallback != RewriteFallback::Local - { - return None; - } - - let model_path = resolve_rewrite_model_path(config)?; - Some(rewrite_worker::RewriteService::new( - &config.rewrite, - &model_path, - )) -} - -pub fn prewarm_rewrite_service(service: &RewriteService, phase: &str) { - match service.prewarm() { - Ok(()) => tracing::info!("prewarming rewrite worker via {}", phase,), - Err(err) => tracing::warn!("failed to prewarm rewrite worker: {err}"), - } -} - -async fn rewrite_transcript_or_fallback( - 
config: &Config, - transcript: &Transcript, - rewrite_service: Option<&RewriteService>, - rules: &PersonalizationRules, - typing_context: Option<&TypingContext>, - recent_session: Option<&EligibleSessionEntry>, -) -> FinalizedTranscript { - let fallback = base_text(config, transcript); - let local_model_path = resolve_rewrite_model_path(config); - let local_rewrite_required = config.rewrite.backend == RewriteBackend::Local - || config.rewrite.fallback == RewriteFallback::Local; - if local_rewrite_required && local_model_path.is_none() { - tracing::warn!( - "rewrite backend requires a local model but none is configured; using fallback" - ); - return finalize_plain_text( - fallback, - SessionRewriteSummary { - had_edit_cues: false, - rewrite_used: false, - recommended_candidate: None, - }, - rules, - ); - } - let mut rewrite_transcript = personalization::build_rewrite_transcript(transcript, rules); - rewrite_transcript.typing_context = typing_context.and_then(session::to_rewrite_typing_context); - let session_plan = session::build_backtrack_plan(&rewrite_transcript, recent_session); - rewrite_transcript.recent_session_entries = session_plan.recent_entries.clone(); - rewrite_transcript.session_backtrack_candidates = session_plan.candidates.clone(); - rewrite_transcript.recommended_session_candidate = session_plan.recommended.clone(); - tracing::debug!( - edit_hypotheses = rewrite_transcript.edit_hypotheses.len(), - rewrite_candidates = rewrite_transcript.rewrite_candidates.len(), - session_backtrack_candidates = rewrite_transcript.session_backtrack_candidates.len(), - recommended_candidate = rewrite_transcript - .recommended_candidate - .as_ref() - .map(|candidate| candidate.text.as_str()) - .unwrap_or(""), - "advanced_local prepared rewrite request" - ); - let custom_instructions = personalization::custom_instructions(rules); - let deterministic_session_replacement = session_plan.deterministic_replacement_text.clone(); - - if let Some(text) = 
deterministic_session_replacement { - tracing::debug!( - output_len = text.len(), - "advanced_local using deterministic session replacement" - ); - let operation = rewrite_transcript - .recommended_session_candidate - .as_ref() - .and_then(|candidate| { - matches!( - candidate.kind, - RewriteSessionBacktrackCandidateKind::ReplaceLastEntry - ) - .then_some(FinalizedOperation::ReplaceLastEntry { - entry_id: candidate.entry_id?, - delete_graphemes: candidate.delete_graphemes, - }) - }) - .unwrap_or(FinalizedOperation::Append); - return finalize_plain_text( - text, - SessionRewriteSummary { - had_edit_cues: !rewrite_transcript.edit_signals.is_empty() - || !rewrite_transcript.edit_hypotheses.is_empty(), - rewrite_used: false, - recommended_candidate: rewrite_transcript - .recommended_session_candidate - .as_ref() - .map(|candidate| candidate.text.clone()) - .or_else(|| { - rewrite_transcript - .recommended_candidate - .as_ref() - .map(|candidate| candidate.text.clone()) - }), - }, - rules, - ) - .with_operation(operation); - } - - let rewrite_result = match config.rewrite.backend { - RewriteBackend::Local => { - local_rewrite_result( - config, - rewrite_service, - local_model_path - .as_ref() - .expect("local rewrite requires resolved model path"), - &rewrite_transcript, - custom_instructions, - ) - .await - } - RewriteBackend::Cloud => { - let cloud_service = cloud::CloudService::new(config); - let cloud_result = match cloud_service { - Ok(service) => { - service - .rewrite_transcript(config, &rewrite_transcript, custom_instructions) - .await - } - Err(err) => Err(err), - }; - match cloud_result { - Ok(text) => Ok(text), - Err(err) if config.rewrite.fallback == RewriteFallback::Local => { - tracing::warn!("cloud rewrite failed: {err}; falling back to local rewrite"); - local_rewrite_result( - config, - rewrite_service, - local_model_path - .as_ref() - .expect("local rewrite fallback requires resolved model path"), - &rewrite_transcript, - custom_instructions, - ) - 
.await - } - Err(err) => Err(err), - } - } - }; - - let had_edit_cues = !rewrite_transcript.edit_signals.is_empty() - || !rewrite_transcript.edit_hypotheses.is_empty(); - let recommended_candidate = rewrite_transcript - .recommended_session_candidate - .as_ref() - .map(|candidate| candidate.text.clone()) - .or_else(|| { - rewrite_transcript - .recommended_candidate - .as_ref() - .map(|candidate| candidate.text.clone()) - }); - - let (base, rewrite_used) = match rewrite_result { - Ok(text) if !text.trim().is_empty() => { - tracing::debug!( - output_len = text.len(), - "advanced_local rewrite applied successfully" - ); - (text, true) - } - Ok(_) => { - tracing::warn!("rewrite model returned empty text; using fallback"); - (fallback, false) - } - Err(err) => { - tracing::warn!("rewrite failed: {err}; using fallback"); - (fallback, false) - } - }; - let operation = rewrite_transcript - .recommended_session_candidate - .as_ref() - .and_then(|candidate| { - matches!( - candidate.kind, - RewriteSessionBacktrackCandidateKind::ReplaceLastEntry - ) - .then_some(FinalizedOperation::ReplaceLastEntry { - entry_id: candidate.entry_id?, - delete_graphemes: candidate.delete_graphemes, - }) - }) - .unwrap_or(FinalizedOperation::Append); - - finalize_plain_text( - base, - SessionRewriteSummary { - had_edit_cues, - rewrite_used, - recommended_candidate, - }, - rules, - ) - .with_operation(operation) -} - -async fn local_rewrite_result( - config: &Config, - rewrite_service: Option<&RewriteService>, - model_path: &Path, - rewrite_transcript: &crate::rewrite_protocol::RewriteTranscript, - custom_instructions: Option<&str>, -) -> crate::error::Result { - if let Some(service) = rewrite_service { - rewrite_worker::rewrite_with_service( - service, - &config.rewrite, - rewrite_transcript, - custom_instructions, - ) - .await - } else { - rewrite_worker::rewrite_transcript( - &config.rewrite, - model_path, - rewrite_transcript, - custom_instructions, - ) - .await - } -} - -fn 
finalize_plain_text( - text: String, - rewrite_summary: SessionRewriteSummary, - rules: &PersonalizationRules, -) -> FinalizedTranscript { - FinalizedTranscript { - text: personalization::finalize_text(&text, rules), - operation: FinalizedOperation::Append, - rewrite_summary, - } -} - -fn load_runtime_rules(config: &Config) -> PersonalizationRules { - match personalization::load_rules(config) { - Ok(rules) => rules, - Err(err) => { - tracing::warn!("failed to load personalization rules: {err}"); - PersonalizationRules::default() - } - } -} - -impl FinalizedTranscript { - fn with_operation(mut self, operation: FinalizedOperation) -> Self { - self.operation = operation; - self - } -} diff --git a/src/postprocess/execution.rs b/src/postprocess/execution.rs new file mode 100644 index 0000000..e4f8ae0 --- /dev/null +++ b/src/postprocess/execution.rs @@ -0,0 +1,125 @@ +use std::path::Path; + +use crate::cloud; +use crate::config::{Config, RewriteBackend, RewriteFallback}; +use crate::rewrite_worker::{self, RewriteService}; + +use super::planning::{self, RewritePlan}; + +pub fn prepare_rewrite_service(config: &Config) -> Option { + if !config.postprocess.mode.uses_rewrite() { + return None; + } + + if config.rewrite.backend != RewriteBackend::Local + && config.rewrite.fallback != RewriteFallback::Local + { + return None; + } + + if !crate::rewrite::local_rewrite_available() { + return None; + } + + let model_path = planning::resolve_rewrite_model_path(config)?; + Some(rewrite_worker::RewriteService::new( + &config.rewrite, + &model_path, + )) +} + +pub fn prewarm_rewrite_service(service: &RewriteService, phase: &str) { + match service.prewarm() { + Ok(()) => tracing::info!("prewarming rewrite worker via {}", phase,), + Err(err) => tracing::warn!("failed to prewarm rewrite worker: {err}"), + } +} + +pub(crate) async fn execute_rewrite( + config: &Config, + rewrite_service: Option<&RewriteService>, + plan: &RewritePlan, +) -> crate::error::Result { + match 
config.rewrite.backend { + RewriteBackend::Local => { + local_rewrite_result( + config, + rewrite_service, + plan.local_model_path + .as_ref() + .expect("local rewrite requires resolved model path"), + &plan.rewrite_transcript, + plan.custom_instructions.as_deref(), + ) + .await + } + RewriteBackend::Cloud => { + let cloud_service = cloud::CloudService::new(config); + let cloud_result = match cloud_service { + Ok(service) => { + service + .rewrite_transcript( + config, + &plan.rewrite_transcript, + plan.custom_instructions.as_deref(), + ) + .await + } + Err(err) => Err(err), + }; + match cloud_result { + Ok(text) => Ok(text), + Err(err) + if config.rewrite.fallback == RewriteFallback::Local + && crate::rewrite::local_rewrite_available() => + { + tracing::warn!("cloud rewrite failed: {err}; falling back to local rewrite"); + local_rewrite_result( + config, + rewrite_service, + plan.local_model_path + .as_ref() + .expect("local rewrite fallback requires resolved model path"), + &plan.rewrite_transcript, + plan.custom_instructions.as_deref(), + ) + .await + } + Err(err) => Err(err), + } + } + } +} + +async fn local_rewrite_result( + config: &Config, + rewrite_service: Option<&RewriteService>, + model_path: &Path, + rewrite_transcript: &crate::rewrite_protocol::RewriteTranscript, + custom_instructions: Option<&str>, +) -> crate::error::Result { + if !crate::rewrite::local_rewrite_available() { + return Err(crate::error::WhsprError::Rewrite( + "local rewrite is unavailable in this build; rebuild with --features local-rewrite" + .into(), + )); + } + + if let Some(service) = rewrite_service { + rewrite_worker::rewrite_with_service( + service, + &config.rewrite, + rewrite_transcript, + custom_instructions, + ) + .await + } else { + rewrite_worker::rewrite_transcript( + &config.rewrite, + model_path, + rewrite_transcript, + custom_instructions, + ) + .await + } +} diff --git a/src/postprocess/finalize.rs b/src/postprocess/finalize.rs new file mode 100644 index 
0000000..c0a7747 --- /dev/null +++ b/src/postprocess/finalize.rs @@ -0,0 +1,354 @@ +use std::time::Duration; +use std::time::Instant; + +use crate::agentic_rewrite; +use crate::config::{Config, PostprocessMode, RewriteBackend, RewriteFallback}; +use crate::context::TypingContext; +use crate::personalization::{self, PersonalizationRules}; +use crate::rewrite_protocol::{RewriteCorrectionPolicy, RewriteTranscript}; +use crate::rewrite_worker::RewriteService; +use crate::session::{EligibleSessionEntry, SessionRewriteSummary}; +use crate::transcribe::Transcript; + +use super::{execution, planning}; + +const FEEDBACK_DRAIN_DELAY: Duration = Duration::from_millis(150); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FinalizedOperation { + Append, + ReplaceLastEntry { + entry_id: u64, + delete_graphemes: usize, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FinalizedTranscript { + pub text: String, + pub operation: FinalizedOperation, + pub rewrite_summary: SessionRewriteSummary, +} + +pub async fn finalize_transcript( + config: &Config, + transcript: Transcript, + rewrite_service: Option<&RewriteService>, + typing_context: Option<&TypingContext>, + recent_session: Option<&EligibleSessionEntry>, +) -> FinalizedTranscript { + let started = Instant::now(); + let finalized = match config.postprocess.mode { + PostprocessMode::Raw => { + let rules = planning::load_runtime_rules(config); + finalize_plain_text( + planning::raw_text(&transcript), + SessionRewriteSummary { + had_edit_cues: false, + rewrite_used: false, + recommended_candidate: None, + }, + &rules, + ) + } + PostprocessMode::LegacyBasic => { + let rules = planning::load_runtime_rules(config); + finalize_plain_text( + crate::cleanup::clean_transcript(&transcript, &config.cleanup), + SessionRewriteSummary { + had_edit_cues: false, + rewrite_used: false, + recommended_candidate: None, + }, + &rules, + ) + } + PostprocessMode::AdvancedLocal | PostprocessMode::AgenticRewrite => { + 
finalize_rewrite_plan_or_fallback( + config, + rewrite_service, + planning::build_rewrite_plan(config, &transcript, typing_context, recent_session), + ) + .await + } + }; + tracing::info!( + elapsed_ms = started.elapsed().as_millis(), + mode = config.postprocess.mode.as_str(), + rewrite_used = finalized.rewrite_summary.rewrite_used, + output_chars = finalized.text.len(), + "finalize_transcript finished" + ); + finalized +} + +pub async fn wait_for_feedback_drain() { + tokio::time::sleep(FEEDBACK_DRAIN_DELAY).await; +} + +async fn finalize_rewrite_plan_or_fallback( + config: &Config, + rewrite_service: Option<&RewriteService>, + plan: planning::RewritePlan, +) -> FinalizedTranscript { + if let Some(text) = plan.deterministic_replacement_text.clone() { + tracing::debug!( + output_len = text.len(), + mode = config.postprocess.mode.as_str(), + "using deterministic session replacement" + ); + return finalize_plain_text( + text, + SessionRewriteSummary { + had_edit_cues: plan.had_edit_cues, + rewrite_used: false, + recommended_candidate: plan.recommended_candidate.clone(), + }, + &plan.rules, + ) + .with_operation(plan.operation.clone()); + } + + let local_rewrite_available = crate::rewrite::local_rewrite_available(); + let local_backend_requested = config.rewrite.backend == RewriteBackend::Local; + + if local_backend_requested && !local_rewrite_available { + tracing::warn!( + "local rewrite backend requested, but this build does not include local rewrite support; using fallback" + ); + return finalize_unavailable_rewrite_fallback(plan); + } + + let local_rewrite_required = local_backend_requested + || (config.rewrite.fallback == RewriteFallback::Local && local_rewrite_available); + if local_rewrite_required && plan.local_model_path.is_none() { + tracing::warn!( + "rewrite backend requires a local model but none is configured; using fallback" + ); + return finalize_unavailable_rewrite_fallback(plan); + } + + let rewrite_result = execution::execute_rewrite(config, 
rewrite_service, &plan).await; + finalize_rewrite_attempt(config, plan, rewrite_result) +} + +fn finalize_rewrite_attempt( + config: &Config, + plan: planning::RewritePlan, + rewrite_result: crate::error::Result, +) -> FinalizedTranscript { + let (base, rewrite_used) = match rewrite_result { + Ok(text) if rewrite_output_accepted(config, &plan.rewrite_transcript, &text) => { + tracing::debug!( + output_len = text.len(), + mode = config.postprocess.mode.as_str(), + "rewrite applied successfully" + ); + (text, true) + } + Ok(text) if text.trim().is_empty() => { + tracing::warn!("rewrite model returned empty text; using fallback"); + (plan.fallback_text, false) + } + Ok(text) => { + tracing::warn!( + mode = config.postprocess.mode.as_str(), + output_len = text.len(), + "rewrite output failed acceptance guard; using fallback" + ); + (plan.fallback_text, false) + } + Err(err) => { + tracing::warn!("rewrite failed: {err}; using fallback"); + (plan.fallback_text, false) + } + }; + + finalize_plain_text( + base, + SessionRewriteSummary { + had_edit_cues: plan.had_edit_cues, + rewrite_used, + recommended_candidate: plan.recommended_candidate, + }, + &plan.rules, + ) + .with_operation(plan.operation) +} + +fn finalize_plain_text( + text: String, + rewrite_summary: SessionRewriteSummary, + rules: &PersonalizationRules, +) -> FinalizedTranscript { + FinalizedTranscript { + text: personalization::finalize_text(&text, rules), + operation: FinalizedOperation::Append, + rewrite_summary, + } +} + +fn finalize_unavailable_rewrite_fallback(plan: planning::RewritePlan) -> FinalizedTranscript { + finalize_plain_text( + plan.fallback_text, + SessionRewriteSummary { + had_edit_cues: plan.had_edit_cues, + rewrite_used: false, + recommended_candidate: plan.recommended_candidate, + }, + &plan.rules, + ) + .with_operation(plan.operation) +} + +impl FinalizedTranscript { + fn with_operation(mut self, operation: FinalizedOperation) -> Self { + self.operation = operation; + self + } +} + +fn 
rewrite_output_accepted( + config: &Config, + rewrite_transcript: &RewriteTranscript, + text: &str, +) -> bool { + if text.trim().is_empty() { + return false; + } + + if config.postprocess.mode != PostprocessMode::AgenticRewrite { + return true; + } + + match rewrite_transcript.policy_context.correction_policy { + RewriteCorrectionPolicy::Conservative => { + agentic_rewrite::conservative_output_allowed(rewrite_transcript, text) + } + RewriteCorrectionPolicy::Balanced | RewriteCorrectionPolicy::Aggressive => true, + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use super::*; + use crate::rewrite_protocol::{ + RewriteCandidate, RewriteCandidateKind, RewritePolicyContext, RewriteTranscript, + }; + + fn plan_config(mode: PostprocessMode, backend: RewriteBackend) -> Config { + let mut config = Config::default(); + config.postprocess.mode = mode; + config.rewrite.backend = backend; + config + } + + fn rewrite_plan() -> planning::RewritePlan { + planning::RewritePlan { + rules: PersonalizationRules::default(), + fallback_text: "fallback text".into(), + rewrite_transcript: RewriteTranscript { + raw_text: "raw text".into(), + correction_aware_text: "fallback text".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "allowed rewrite".into(), + }], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + }, + custom_instructions: None, + local_model_path: Some(PathBuf::from("/tmp/model.gguf")), + operation: FinalizedOperation::Append, + had_edit_cues: false, + recommended_candidate: Some("allowed rewrite".into()), + 
deterministic_replacement_text: None, + } + } + + #[tokio::test] + async fn deterministic_session_replacement_bypasses_rewrite_and_preserves_replace_operation() { + let config = plan_config(PostprocessMode::AdvancedLocal, RewriteBackend::Local); + let mut plan = rewrite_plan(); + plan.operation = FinalizedOperation::ReplaceLastEntry { + entry_id: 7, + delete_graphemes: 4, + }; + plan.deterministic_replacement_text = Some("deterministic replacement".into()); + plan.local_model_path = None; + + let finalized = finalize_rewrite_plan_or_fallback(&config, None, plan).await; + + assert_eq!(finalized.text, "deterministic replacement"); + assert_eq!( + finalized.operation, + FinalizedOperation::ReplaceLastEntry { + entry_id: 7, + delete_graphemes: 4, + } + ); + assert!(!finalized.rewrite_summary.rewrite_used); + } + + #[tokio::test] + #[cfg(not(feature = "local-rewrite"))] + async fn local_rewrite_unavailable_build_falls_back_to_plain_text() { + let config = plan_config(PostprocessMode::AdvancedLocal, RewriteBackend::Local); + let mut plan = rewrite_plan(); + plan.operation = FinalizedOperation::ReplaceLastEntry { + entry_id: 11, + delete_graphemes: 6, + }; + + let finalized = finalize_rewrite_plan_or_fallback(&config, None, plan).await; + + assert_eq!(finalized.text, "fallback text"); + assert!(!finalized.rewrite_summary.rewrite_used); + assert_eq!( + finalized.operation, + FinalizedOperation::ReplaceLastEntry { + entry_id: 11, + delete_graphemes: 6, + } + ); + } + + #[tokio::test] + #[cfg(feature = "local-rewrite")] + async fn missing_local_model_falls_back_to_plain_text() { + let config = plan_config(PostprocessMode::AdvancedLocal, RewriteBackend::Local); + let mut plan = rewrite_plan(); + plan.local_model_path = None; + + let finalized = finalize_rewrite_plan_or_fallback(&config, None, plan).await; + + assert_eq!(finalized.text, "fallback text"); + assert!(!finalized.rewrite_summary.rewrite_used); + } + + #[test] + fn 
conservative_agentic_rejection_falls_back_to_precomputed_text() { + let mut config = plan_config(PostprocessMode::AgenticRewrite, RewriteBackend::Cloud); + config.rewrite.fallback = RewriteFallback::None; + let mut plan = rewrite_plan(); + plan.rewrite_transcript.policy_context.correction_policy = + RewriteCorrectionPolicy::Conservative; + + let finalized = finalize_rewrite_attempt(&config, plan, Ok("rejected rewrite".into())); + + assert_eq!(finalized.text, "fallback text"); + assert!(!finalized.rewrite_summary.rewrite_used); + } +} diff --git a/src/postprocess/mod.rs b/src/postprocess/mod.rs new file mode 100644 index 0000000..d10b269 --- /dev/null +++ b/src/postprocess/mod.rs @@ -0,0 +1,3 @@ +pub mod execution; +pub mod finalize; +pub mod planning; diff --git a/src/postprocess/planning.rs b/src/postprocess/planning.rs new file mode 100644 index 0000000..e802352 --- /dev/null +++ b/src/postprocess/planning.rs @@ -0,0 +1,128 @@ +use std::path::PathBuf; + +use crate::agentic_rewrite; +use crate::cleanup; +use crate::config::{Config, PostprocessMode}; +use crate::context::TypingContext; +use crate::personalization::{self, PersonalizationRules}; +use crate::rewrite_model; +use crate::rewrite_protocol::{RewriteSessionBacktrackCandidateKind, RewriteTranscript}; +use crate::session::{self, EligibleSessionEntry}; +use crate::transcribe::Transcript; + +use super::finalize::FinalizedOperation; + +pub(crate) struct RewritePlan { + pub rules: PersonalizationRules, + pub fallback_text: String, + pub rewrite_transcript: RewriteTranscript, + pub custom_instructions: Option, + pub local_model_path: Option, + pub operation: FinalizedOperation, + pub had_edit_cues: bool, + pub recommended_candidate: Option, + pub deterministic_replacement_text: Option, +} + +pub fn raw_text(transcript: &Transcript) -> String { + transcript.raw_text.trim().to_string() +} + +pub(crate) fn resolve_rewrite_model_path(config: &Config) -> Option { + if let Some(path) = 
config.resolved_rewrite_model_path() { + return Some(path); + } + + rewrite_model::selected_model_path(&config.rewrite.selected_model) +} + +pub(crate) fn load_runtime_rules(config: &Config) -> PersonalizationRules { + match personalization::load_rules(config) { + Ok(rules) => rules, + Err(err) => { + tracing::warn!("failed to load personalization rules: {err}"); + PersonalizationRules::default() + } + } +} + +pub(crate) fn build_rewrite_plan( + config: &Config, + transcript: &Transcript, + typing_context: Option<&TypingContext>, + recent_session: Option<&EligibleSessionEntry>, +) -> RewritePlan { + let rules = load_runtime_rules(config); + let fallback_text = base_text(config, transcript); + let local_model_path = resolve_rewrite_model_path(config); + let mut rewrite_transcript = personalization::build_rewrite_transcript(transcript, &rules); + rewrite_transcript.typing_context = typing_context.and_then(session::to_rewrite_typing_context); + if config.postprocess.mode == PostprocessMode::AgenticRewrite { + agentic_rewrite::apply_runtime_policy(config, &mut rewrite_transcript); + } + let session_plan = session::build_backtrack_plan(&rewrite_transcript, recent_session); + rewrite_transcript.recent_session_entries = session_plan.recent_entries.clone(); + rewrite_transcript.session_backtrack_candidates = session_plan.candidates.clone(); + rewrite_transcript.recommended_session_candidate = session_plan.recommended.clone(); + tracing::debug!( + mode = config.postprocess.mode.as_str(), + edit_hypotheses = rewrite_transcript.edit_hypotheses.len(), + rewrite_candidates = rewrite_transcript.rewrite_candidates.len(), + session_backtrack_candidates = rewrite_transcript.session_backtrack_candidates.len(), + recommended_candidate = rewrite_transcript + .recommended_candidate + .as_ref() + .map(|candidate| candidate.text.as_str()) + .unwrap_or(""), + "prepared rewrite request" + ); + + RewritePlan { + custom_instructions: 
personalization::custom_instructions(&rules).map(str::to_string), + local_model_path, + operation: recommended_operation(&rewrite_transcript), + had_edit_cues: !rewrite_transcript.edit_signals.is_empty() + || !rewrite_transcript.edit_hypotheses.is_empty(), + recommended_candidate: rewrite_transcript + .recommended_session_candidate + .as_ref() + .map(|candidate| candidate.text.clone()) + .or_else(|| { + rewrite_transcript + .recommended_candidate + .as_ref() + .map(|candidate| candidate.text.clone()) + }), + deterministic_replacement_text: session_plan.deterministic_replacement_text.clone(), + rules, + fallback_text, + rewrite_transcript, + } +} + +fn base_text(config: &Config, transcript: &Transcript) -> String { + match config.postprocess.mode { + PostprocessMode::LegacyBasic => cleanup::clean_transcript(transcript, &config.cleanup), + PostprocessMode::AdvancedLocal | PostprocessMode::AgenticRewrite => { + cleanup::correction_aware_text(transcript) + } + PostprocessMode::Raw => raw_text(transcript), + } +} + +fn recommended_operation(rewrite_transcript: &RewriteTranscript) -> FinalizedOperation { + rewrite_transcript + .recommended_session_candidate + .as_ref() + .and_then(|candidate| { + matches!( + candidate.kind, + RewriteSessionBacktrackCandidateKind::ReplaceLastEntry + ) + .then_some(FinalizedOperation::ReplaceLastEntry { + entry_id: candidate.entry_id?, + delete_graphemes: candidate.delete_graphemes, + }) + }) + .unwrap_or(FinalizedOperation::Append) +} diff --git a/src/rewrite.rs b/src/rewrite.rs deleted file mode 100644 index 29c3910..0000000 --- a/src/rewrite.rs +++ /dev/null @@ -1,1271 +0,0 @@ -use std::num::NonZeroU32; -use std::path::Path; -use std::sync::OnceLock; - -use encoding_rs::UTF_8; -use llama_cpp_2::context::params::LlamaContextParams; -use llama_cpp_2::llama_backend::LlamaBackend; -use llama_cpp_2::llama_batch::LlamaBatch; -use llama_cpp_2::model::params::LlamaModelParams; -use llama_cpp_2::model::{AddBos, LlamaChatMessage, 
LlamaChatTemplate, LlamaModel}; -use llama_cpp_2::sampling::LlamaSampler; - -use crate::rewrite_profile::ResolvedRewriteProfile; -use crate::rewrite_profile::RewriteProfile; -use crate::rewrite_protocol::RewriteTranscript; - -#[allow(dead_code)] -pub struct LocalRewriter { - model: LlamaModel, - chat_template: LlamaChatTemplate, - profile: ResolvedRewriteProfile, - max_tokens: usize, - max_output_chars: usize, -} - -#[allow(dead_code)] -static LLAMA_BACKEND: OnceLock<&'static LlamaBackend> = OnceLock::new(); -#[allow(dead_code)] -static EXTERNAL_LLAMA_BACKEND: LlamaBackend = LlamaBackend {}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct RewritePrompt { - pub system: String, - pub user: String, -} - -impl LocalRewriter { - #[allow(dead_code)] - pub fn new( - model_path: &Path, - profile: ResolvedRewriteProfile, - max_tokens: usize, - max_output_chars: usize, - ) -> std::result::Result { - if !model_path.exists() { - return Err(format!( - "rewrite model file not found: {}", - model_path.display() - )); - } - - let backend = llama_backend()?; - - let mut model_params = LlamaModelParams::default(); - if cfg!(feature = "cuda") { - model_params = model_params.with_n_gpu_layers(1000); - } - - let model = LlamaModel::load_from_file(backend, model_path, &model_params) - .map_err(|e| format!("failed to load rewrite model: {e}"))?; - let chat_template = model - .chat_template(None) - .map_err(|e| format!("rewrite model does not expose a usable chat template: {e}"))?; - - Ok(Self { - model, - chat_template, - profile, - max_tokens, - max_output_chars, - }) - } - - #[allow(dead_code)] - pub fn rewrite_with_instructions( - &self, - transcript: &RewriteTranscript, - custom_instructions: Option<&str>, - ) -> std::result::Result { - if transcript.raw_text.trim().is_empty() { - return Ok(String::new()); - } - - let prompt = build_rewrite_prompt( - &self.model, - &self.chat_template, - transcript, - self.profile, - custom_instructions, - )?; - let effective_max_tokens = 
effective_max_tokens(self.max_tokens, transcript); - let prompt_tokens = self - .model - .str_to_token(&prompt, AddBos::Never) - .map_err(|e| format!("failed to tokenize rewrite prompt: {e}"))?; - let behavior = rewrite_behavior(self.profile); - - let n_ctx_tokens = prompt_tokens - .len() - .saturating_add(effective_max_tokens) - .saturating_add(64) - .max(2048) - .min(u32::MAX as usize) as u32; - let n_batch = prompt_tokens - .len() - .max(512) - .min(n_ctx_tokens as usize) - .min(u32::MAX as usize) as u32; - let threads = std::thread::available_parallelism() - .map(|threads| threads.get()) - .unwrap_or(4) - .clamp(1, i32::MAX as usize) as i32; - - let ctx_params = LlamaContextParams::default() - .with_n_ctx(NonZeroU32::new(n_ctx_tokens)) - .with_n_batch(n_batch) - .with_n_ubatch(n_batch) - .with_n_threads(threads) - .with_n_threads_batch(threads); - let backend = llama_backend()?; - let mut ctx = self - .model - .new_context(backend, ctx_params) - .map_err(|e| format!("failed to create rewrite context: {e}"))?; - - let mut prompt_batch = LlamaBatch::new(prompt_tokens.len(), 1); - prompt_batch - .add_sequence(&prompt_tokens, 0, false) - .map_err(|e| format!("failed to enqueue rewrite prompt: {e}"))?; - ctx.decode(&mut prompt_batch) - .map_err(|e| format!("failed to decode rewrite prompt: {e}"))?; - - let mut sampler = LlamaSampler::chain_simple([ - LlamaSampler::top_k(behavior.top_k), - LlamaSampler::top_p(behavior.top_p, 1), - LlamaSampler::temp(behavior.temperature), - LlamaSampler::greedy(), - ]); - sampler.accept_many(prompt_tokens.iter()); - - let mut decoder = UTF_8.new_decoder_without_bom_handling(); - let mut output = String::new(); - let start_pos = i32::try_from(prompt_tokens.len()).unwrap_or(i32::MAX); - - for i in 0..effective_max_tokens { - let mut candidates = ctx.token_data_array(); - candidates.apply_sampler(&sampler); - let token = candidates - .selected_token() - .ok_or_else(|| "rewrite sampler did not select a token".to_string())?; - - if token 
== self.model.token_eos() { - break; - } - - sampler.accept(token); - - let piece = self - .model - .token_to_piece(token, &mut decoder, true, None) - .map_err(|e| format!("failed to decode rewrite token: {e}"))?; - output.push_str(&piece); - - if output.contains("") || output.len() >= self.max_output_chars { - break; - } - - let mut batch = LlamaBatch::new(1, 1); - batch - .add( - token, - start_pos.saturating_add(i32::try_from(i).unwrap_or(i32::MAX)), - &[0], - true, - ) - .map_err(|e| format!("failed to enqueue rewrite token: {e}"))?; - ctx.decode(&mut batch) - .map_err(|e| format!("failed to decode rewrite token: {e}"))?; - } - - let rewritten = sanitize_rewrite_output(&output); - if rewritten.is_empty() { - return Err("rewrite model returned empty output".into()); - } - - Ok(rewritten) - } -} - -#[allow(dead_code)] -fn llama_backend() -> std::result::Result<&'static LlamaBackend, String> { - if let Some(backend) = LLAMA_BACKEND.get().copied() { - return Ok(backend); - } - - match LlamaBackend::init() { - Ok(backend) => { - let backend = Box::leak(Box::new(backend)); - let _ = LLAMA_BACKEND.set(backend); - Ok(LLAMA_BACKEND - .get() - .copied() - .expect("llama backend initialized")) - } - // Use a static non-dropping token when another part of the process already - // owns llama.cpp global initialization. The worker never drops this token. 
- Err(llama_cpp_2::LlamaCppError::BackendAlreadyInitialized) => { - let _ = LLAMA_BACKEND.set(&EXTERNAL_LLAMA_BACKEND); - Ok(LLAMA_BACKEND - .get() - .copied() - .expect("external llama backend cached")) - } - Err(err) => Err(format!("failed to initialize llama backend: {err}")), - } -} - -#[allow(dead_code)] -fn build_rewrite_prompt( - model: &LlamaModel, - chat_template: &LlamaChatTemplate, - transcript: &RewriteTranscript, - profile: ResolvedRewriteProfile, - custom_instructions: Option<&str>, -) -> std::result::Result { - let prompt = build_prompt(transcript, profile, custom_instructions)?; - let messages = vec![ - LlamaChatMessage::new("system".into(), prompt.system) - .map_err(|e| format!("failed to build rewrite system message: {e}"))?, - LlamaChatMessage::new("user".into(), prompt.user) - .map_err(|e| format!("failed to build rewrite user message: {e}"))?, - ]; - - model - .apply_chat_template(chat_template, &messages, true) - .map_err(|e| format!("failed to apply rewrite chat template: {e}")) -} - -pub fn build_prompt( - transcript: &RewriteTranscript, - profile: ResolvedRewriteProfile, - custom_instructions: Option<&str>, -) -> std::result::Result { - Ok(RewritePrompt { - system: build_system_instructions(profile, custom_instructions), - user: build_user_message(transcript), - }) -} - -#[allow(dead_code)] -pub fn resolved_profile_for_cloud(profile: RewriteProfile) -> ResolvedRewriteProfile { - match profile { - RewriteProfile::Auto => ResolvedRewriteProfile::Generic, - RewriteProfile::Generic => ResolvedRewriteProfile::Generic, - RewriteProfile::Qwen => ResolvedRewriteProfile::Qwen, - RewriteProfile::LlamaCompat => ResolvedRewriteProfile::LlamaCompat, - } -} - -fn build_system_instructions( - profile: ResolvedRewriteProfile, - custom_instructions: Option<&str>, -) -> String { - let mut instructions = rewrite_instructions(profile).to_string(); - if let Some(custom) = custom_instructions - .map(str::trim) - .filter(|text| !text.is_empty()) - { - 
instructions.push_str("\n\nAdditional user rewrite instructions:\n"); - instructions.push_str(custom); - } - instructions -} - -#[allow(dead_code)] -struct RewriteBehavior { - top_k: i32, - top_p: f32, - temperature: f32, -} - -#[allow(dead_code)] -fn rewrite_behavior(profile: ResolvedRewriteProfile) -> RewriteBehavior { - match profile { - ResolvedRewriteProfile::Qwen => RewriteBehavior { - top_k: 24, - top_p: 0.9, - temperature: 0.1, - }, - ResolvedRewriteProfile::Generic => RewriteBehavior { - top_k: 32, - top_p: 0.92, - temperature: 0.12, - }, - ResolvedRewriteProfile::LlamaCompat => RewriteBehavior { - top_k: 40, - top_p: 0.95, - temperature: 0.15, - }, - } -} - -fn rewrite_instructions(profile: ResolvedRewriteProfile) -> &'static str { - let base = "You clean up dictated speech into the final text the user meant to type. \ -Return only the finished text. Do not explain anything. Remove obvious disfluencies when natural. \ -Use the correction-aware transcript as the primary source of truth unless structured edit signals say the \ -utterance may still be ambiguous. The raw transcript may still contain spoken editing phrases or canceled wording. \ -Never reintroduce text that was removed by an explicit spoken correction cue. Respect any structured edit intents \ -provided alongside the transcript. If structured edit signals or edit hypotheses are present, use the candidate \ -interpretations as bounded options, choose the best interpretation, and lightly refine it only when needed for natural \ -final text. Prefer transcript spellings for names, brands, and uncommon proper nouns unless a user dictionary or \ -explicit correction says otherwise. Do not normalize names into more common spellings just because they look familiar. \ -If an edit intent says to replace or cancel previous wording, preserve that edit and do not keep the spoken correction \ -phrase itself unless the transcript clearly still intends it. Examples:\n\ -- raw: Hello there. Scratch that. 
Hi.\n correction-aware: Hi.\n final: Hi.\n\ -- raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ -- raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ -- raw: Never mind. Hi, how are you today?\n correction-aware: Hi, how are you today?\n final: Hi, how are you today?\n\ -- raw: Wait, no, it actually works.\n correction-aware: Wait, no, it actually works.\n final: Wait, no, it actually works.\n\ -- raw: Let's meet tomorrow, or rather Friday.\n correction-aware: Let's meet Friday.\n final: Let's meet Friday."; - - match profile { - ResolvedRewriteProfile::Qwen => { - "You clean up dictated speech into the final text the user meant to type. \ -Return only the finished text. Do not explain anything. Do not emit reasoning, think tags, or XML wrappers. \ -Remove obvious disfluencies when natural. Use the correction-aware transcript as the primary source of truth unless \ -structured edit signals say the utterance may still be ambiguous. The raw transcript may still contain spoken editing \ -phrases or canceled wording. Never reintroduce text that was removed by an explicit spoken correction cue. Respect \ -any structured edit intents provided alongside the transcript. If structured edit signals or edit hypotheses are \ -present, use the candidate interpretations as bounded options, choose the best interpretation, and lightly refine it \ -only when needed for natural final text. Prefer transcript spellings for names, brands, and uncommon proper nouns \ -unless a user dictionary or explicit correction says otherwise. Do not normalize names into more common spellings just \ -because they look familiar. 
If an edit intent says to replace or cancel previous wording, preserve that edit and do \ -not keep the spoken correction phrase itself unless the transcript clearly still intends it. Examples:\n\ -- raw: Hello there. Scratch that. Hi.\n correction-aware: Hi.\n final: Hi.\n\ -- raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ -- raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ -- raw: Never mind. Hi, how are you today?\n correction-aware: Hi, how are you today?\n final: Hi, how are you today?\n\ -- raw: Wait, no, it actually works.\n correction-aware: Wait, no, it actually works.\n final: Wait, no, it actually works.\n\ -- raw: Let's meet tomorrow, or rather Friday.\n correction-aware: Let's meet Friday.\n final: Let's meet Friday." - } - ResolvedRewriteProfile::Generic | ResolvedRewriteProfile::LlamaCompat => base, - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum RewriteRoute { - Fast, - ResolvedCorrection, - SessionCandidateAdjudication, - CandidateAdjudication, -} - -fn build_user_message(transcript: &RewriteTranscript) -> String { - let language = transcript.detected_language.as_deref().unwrap_or("unknown"); - let correction_aware = transcript.correction_aware_text.trim(); - let raw = transcript.raw_text.trim(); - let edit_intents = render_edit_intents(transcript); - let edit_signals = render_edit_signals(transcript); - let route = rewrite_route(transcript); - tracing::debug!( - route = ?route, - edit_signals = transcript.edit_signals.len(), - edit_hypotheses = transcript.edit_hypotheses.len(), - rewrite_candidates = transcript.rewrite_candidates.len(), - "rewrite prompt route selected" - ); - - match route { - RewriteRoute::SessionCandidateAdjudication => { - let typing_context = render_typing_context(transcript); - let recent_session_entries = 
render_recent_session_entries(transcript); - let session_candidates = render_session_backtrack_candidates(transcript); - let recommended_session_candidate = render_recommended_session_candidate(transcript); - let rewrite_candidates = render_rewrite_candidates(transcript); - let surface_guidance = transcript - .typing_context - .as_ref() - .filter(|context| { - matches!( - context.surface_kind, - crate::rewrite_protocol::RewriteSurfaceKind::Terminal - ) - }) - .map(|_| { - "The active surface looks like a terminal. Stay conservative unless an explicit correction cue clearly indicates replacing the most recent prior dictation.\n" - }) - .unwrap_or(""); - format!( - "Language: {language}\n\ -Active typing context:\n\ -{typing_context}\ -Recent dictation session:\n\ -{recent_session_entries}\ -Session backtrack candidates:\n\ -{session_candidates}\ -{recommended_session_candidate}\ -The user may be correcting the most recent prior dictation entry rather than appending new text.\n\ -If the recommended session candidate says replace_last_entry, treat your final text as the replacement text for that previous dictation entry, not as newly appended text.\n\ -Prefer the recommended session candidate unless another listed session candidate is clearly better.\n\ -{surface_guidance}\ -Current utterance correction candidate:\n\ -{correction_aware}\n\ -Raw current utterance:\n\ -{raw}\n\ -Current utterance bounded candidates:\n\ -{rewrite_candidates}\ -Final text:" - ) - } - RewriteRoute::CandidateAdjudication => { - let edit_hypotheses = render_edit_hypotheses(transcript); - let rewrite_candidates = render_rewrite_candidates(transcript); - let recommended_candidate = render_recommended_candidate(transcript); - let recent_segments = render_recent_segments(transcript, 4); - let aggressive_candidate = render_aggressive_candidate(transcript); - let exact_cue_guidance = if has_strong_explicit_edit_cue(transcript) { - "A strong explicit spoken edit cue was detected. 
The literal raw transcript probably contains canceled wording. \ -Prefer a candidate interpretation that removes the cue and canceled wording unless doing so would clearly lose intended meaning. \ -If the cue is an exact strong match for phrases like scratch that, never mind, or wait no, do not keep the literal cue text in the final output.\n" - } else { - "" - }; - tracing::trace!("rewrite hypotheses:\n{edit_hypotheses}"); - tracing::trace!("rewrite candidates:\n{rewrite_candidates}"); - format!( - "Language: {language}\n\ -Structured edit hypotheses:\n\ -{edit_hypotheses}\ -Structured edit signals:\n\ -{edit_signals}\ -Structured edit intents:\n\ -{edit_intents}\ -This utterance likely contains spoken self-corrections or restatements.\n\ -Choose the best candidate interpretation and lightly refine it only when needed.\n\ -{exact_cue_guidance}\ -When an exact strong edit cue is present, treat the non-literal candidates as more trustworthy than the literal transcript.\n\ -The candidate list is ordered from most likely to least likely by heuristics.\n\ -For exact strong edit cues, the first candidate is the heuristic best guess and should usually win unless another candidate is clearly better.\n\ -Prefer the smallest replacement scope that yields a coherent result.\n\ -Use span-level replacements when only a key phrase was corrected, clause-level replacements when the correction replaces the surrounding thought, and sentence-level replacements only when the whole sentence was canceled.\n\ -Preserve literal wording when the cue is not actually an edit.\n\ -Do not over-normalize names or brands.\n\ -Do not keep spoken edit cues in the final text when they act as edits.\n\ -{recommended_candidate}\ -Candidate interpretations:\n\ -{rewrite_candidates}\ -Correction candidate:\n\ -{correction_aware}\n\ -{aggressive_candidate}\ -Raw transcript:\n\ -{raw}\n\ -Recent segments:\n\ -{recent_segments}\n\ -Final text:" - ) - } - RewriteRoute::ResolvedCorrection => format!( - 
"Language: {language}\n\ -Structured edit signals:\n\ -{edit_signals}\ -Structured edit intents:\n\ -{edit_intents}\ -Self-corrections were already resolved before rewriting.\n\ -Use only this correction-aware transcript as the source text:\n\ -{correction_aware}\n\ -Do not restore any canceled wording from earlier in the utterance.\n\ -Final text:" - ), - RewriteRoute::Fast => { - let recent_segments = render_recent_segments(transcript, 4); - format!( - "Language: {language}\n\ -Structured edit signals:\n\ -{edit_signals}\ -Structured edit intents:\n\ -{edit_intents}\ -Correction-aware transcript:\n\ -{correction_aware}\n\ -Treat the correction-aware transcript as authoritative for explicit spoken edits.\n\ -\ -Recent segments:\n\ -{recent_segments}\n\ -Final text:", - ) - } - } -} - -fn rewrite_route(transcript: &RewriteTranscript) -> RewriteRoute { - if has_session_backtrack_candidate(transcript) { - RewriteRoute::SessionCandidateAdjudication - } else if requires_candidate_adjudication(transcript) { - RewriteRoute::CandidateAdjudication - } else if transcript.correction_aware_text.trim() != transcript.raw_text.trim() { - RewriteRoute::ResolvedCorrection - } else { - RewriteRoute::Fast - } -} - -fn requires_candidate_adjudication(transcript: &RewriteTranscript) -> bool { - !transcript.edit_signals.is_empty() || !transcript.edit_hypotheses.is_empty() -} - -fn has_strong_explicit_edit_cue(transcript: &RewriteTranscript) -> bool { - transcript.edit_hypotheses.iter().any(|hypothesis| { - hypothesis.strength == crate::rewrite_protocol::RewriteEditSignalStrength::Strong - && matches!( - hypothesis.match_source, - crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Exact - | crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Alias - ) - }) -} - -fn has_session_backtrack_candidate(transcript: &RewriteTranscript) -> bool { - transcript.recommended_session_candidate.is_some() - || !transcript.session_backtrack_candidates.is_empty() -} - -fn 
render_edit_intents(transcript: &RewriteTranscript) -> String { - if transcript.edit_intents.is_empty() { - return "- none detected\n".to_string(); - } - - let mut rendered = String::new(); - for intent in &transcript.edit_intents { - let action = match intent.action { - crate::rewrite_protocol::RewriteEditAction::ReplacePreviousPhrase => { - "replace_previous_phrase" - } - crate::rewrite_protocol::RewriteEditAction::ReplacePreviousClause => { - "replace_previous_clause" - } - crate::rewrite_protocol::RewriteEditAction::ReplacePreviousSentence => { - "replace_previous_sentence" - } - crate::rewrite_protocol::RewriteEditAction::DropEditCue => "drop_edit_cue", - }; - let confidence = match intent.confidence { - crate::rewrite_protocol::RewriteIntentConfidence::High => "high", - }; - rendered.push_str(&format!( - "- action: {action}, trigger: \"{}\", confidence: {confidence}\n", - intent.trigger - )); - } - - rendered -} - -fn render_edit_signals(transcript: &RewriteTranscript) -> String { - if transcript.edit_signals.is_empty() { - return "- none detected\n".to_string(); - } - - let mut rendered = String::new(); - for signal in &transcript.edit_signals { - let kind = match signal.kind { - crate::rewrite_protocol::RewriteEditSignalKind::Cancel => "cancel", - crate::rewrite_protocol::RewriteEditSignalKind::Replace => "replace", - crate::rewrite_protocol::RewriteEditSignalKind::Restatement => "restatement", - }; - let scope_hint = match signal.scope_hint { - crate::rewrite_protocol::RewriteEditSignalScope::Phrase => "phrase", - crate::rewrite_protocol::RewriteEditSignalScope::Clause => "clause", - crate::rewrite_protocol::RewriteEditSignalScope::Sentence => "sentence", - crate::rewrite_protocol::RewriteEditSignalScope::Unknown => "unknown", - }; - let strength = match signal.strength { - crate::rewrite_protocol::RewriteEditSignalStrength::Possible => "possible", - crate::rewrite_protocol::RewriteEditSignalStrength::Strong => "strong", - }; - rendered.push_str(&format!( 
- "- trigger: \"{}\", kind: {kind}, scope_hint: {scope_hint}, strength: {strength}\n", - signal.trigger - )); - } - - rendered -} - -fn render_edit_hypotheses(transcript: &RewriteTranscript) -> String { - if transcript.edit_hypotheses.is_empty() { - return "- none detected\n".to_string(); - } - - let mut rendered = String::new(); - for hypothesis in &transcript.edit_hypotheses { - let match_source = match hypothesis.match_source { - crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Exact => "exact", - crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Alias => "alias", - crate::rewrite_protocol::RewriteEditHypothesisMatchSource::NearMiss => "near_miss", - }; - let kind = match hypothesis.kind { - crate::rewrite_protocol::RewriteEditSignalKind::Cancel => "cancel", - crate::rewrite_protocol::RewriteEditSignalKind::Replace => "replace", - crate::rewrite_protocol::RewriteEditSignalKind::Restatement => "restatement", - }; - let scope_hint = match hypothesis.scope_hint { - crate::rewrite_protocol::RewriteEditSignalScope::Phrase => "phrase", - crate::rewrite_protocol::RewriteEditSignalScope::Clause => "clause", - crate::rewrite_protocol::RewriteEditSignalScope::Sentence => "sentence", - crate::rewrite_protocol::RewriteEditSignalScope::Unknown => "unknown", - }; - let strength = match hypothesis.strength { - crate::rewrite_protocol::RewriteEditSignalStrength::Possible => "possible", - crate::rewrite_protocol::RewriteEditSignalStrength::Strong => "strong", - }; - let replacement_scope = match hypothesis.replacement_scope { - crate::rewrite_protocol::RewriteReplacementScope::Span => "span", - crate::rewrite_protocol::RewriteReplacementScope::Clause => "clause", - crate::rewrite_protocol::RewriteReplacementScope::Sentence => "sentence", - }; - let tail_shape = match hypothesis.tail_shape { - crate::rewrite_protocol::RewriteTailShape::Empty => "empty", - crate::rewrite_protocol::RewriteTailShape::Phrase => "phrase", - 
crate::rewrite_protocol::RewriteTailShape::Clause => "clause", - }; - rendered.push_str(&format!( - "- cue_family: {}, matched_text: \"{}\", match_source: {match_source}, kind: {kind}, scope_hint: {scope_hint}, replacement_scope: {replacement_scope}, tail_shape: {tail_shape}, strength: {strength}\n", - hypothesis.cue_family, hypothesis.matched_text - )); - } - - rendered -} - -fn render_rewrite_candidates(transcript: &RewriteTranscript) -> String { - if transcript.rewrite_candidates.is_empty() { - return "- no candidates available\n".to_string(); - } - - let mut rendered = String::new(); - let highlight_first = has_strong_explicit_edit_cue(transcript); - for (index, candidate) in transcript.rewrite_candidates.iter().enumerate() { - let prefix = if highlight_first && index == 0 { - "- preferred_candidate" - } else { - "-" - }; - let kind = match candidate.kind { - crate::rewrite_protocol::RewriteCandidateKind::Literal => { - "literal (keep only if the cue was not actually an edit)" - } - crate::rewrite_protocol::RewriteCandidateKind::ConservativeCorrection => { - "conservative_correction (balanced cleanup)" - } - crate::rewrite_protocol::RewriteCandidateKind::AggressiveCorrection => { - "aggressive_correction (use when canceled wording should be removed more fully)" - } - crate::rewrite_protocol::RewriteCandidateKind::SpanReplacement => { - "span_replacement (replace only the corrected phrase)" - } - crate::rewrite_protocol::RewriteCandidateKind::ClauseReplacement => { - "clause_replacement (replace the corrected clause while keeping surrounding context)" - } - crate::rewrite_protocol::RewriteCandidateKind::SentenceReplacement => { - "sentence_replacement (replace the whole corrected sentence)" - } - crate::rewrite_protocol::RewriteCandidateKind::ContextualReplacement => { - "contextual_replacement (replace the corrected span while keeping earlier context)" - } - crate::rewrite_protocol::RewriteCandidateKind::DropCueOnly => { - "drop_cue_only (remove just the spoken 
edit cue)" - } - crate::rewrite_protocol::RewriteCandidateKind::FollowingReplacement => { - "following_replacement (keep only the wording after the cue)" - } - crate::rewrite_protocol::RewriteCandidateKind::CancelPreviousClause => { - "cancel_previous_clause (treat the cue as canceling the prior clause)" - } - crate::rewrite_protocol::RewriteCandidateKind::CancelPreviousSentence => { - "cancel_previous_sentence (treat the cue as canceling the prior sentence)" - } - }; - rendered.push_str(&format!("{prefix} {kind}: {}\n", candidate.text)); - } - - rendered -} - -fn render_recommended_candidate(transcript: &RewriteTranscript) -> String { - transcript - .recommended_candidate - .as_ref() - .map(|candidate| { - format!( - "Recommended interpretation:\n{}\nUse this as the default final text unless another candidate is clearly better.\n", - candidate.text - ) - }) - .unwrap_or_default() -} - -fn render_typing_context(transcript: &RewriteTranscript) -> String { - transcript - .typing_context - .as_ref() - .map(|context| { - format!( - "- focus_fingerprint: {}\n- app_id: {}\n- window_title: {}\n- surface_kind: {}\n", - context.focus_fingerprint, - context.app_id.as_deref().unwrap_or("unknown"), - context.window_title.as_deref().unwrap_or("unknown"), - match context.surface_kind { - crate::rewrite_protocol::RewriteSurfaceKind::Browser => "browser", - crate::rewrite_protocol::RewriteSurfaceKind::Terminal => "terminal", - crate::rewrite_protocol::RewriteSurfaceKind::Editor => "editor", - crate::rewrite_protocol::RewriteSurfaceKind::GenericText => "generic_text", - crate::rewrite_protocol::RewriteSurfaceKind::Unknown => "unknown", - } - ) - }) - .unwrap_or_else(|| "- none available\n".to_string()) -} - -fn render_recent_session_entries(transcript: &RewriteTranscript) -> String { - if transcript.recent_session_entries.is_empty() { - return "- none available\n".to_string(); - } - - let mut rendered = String::new(); - for entry in &transcript.recent_session_entries { - 
rendered.push_str(&format!( - "- id: {}, text: {}, grapheme_len: {}, surface_kind: {}\n", - entry.id, - entry.final_text, - entry.grapheme_len, - match entry.surface_kind { - crate::rewrite_protocol::RewriteSurfaceKind::Browser => "browser", - crate::rewrite_protocol::RewriteSurfaceKind::Terminal => "terminal", - crate::rewrite_protocol::RewriteSurfaceKind::Editor => "editor", - crate::rewrite_protocol::RewriteSurfaceKind::GenericText => "generic_text", - crate::rewrite_protocol::RewriteSurfaceKind::Unknown => "unknown", - } - )); - } - rendered -} - -fn render_session_backtrack_candidates(transcript: &RewriteTranscript) -> String { - if transcript.session_backtrack_candidates.is_empty() { - return "- no session backtrack candidates\n".to_string(); - } - - let mut rendered = String::new(); - for candidate in &transcript.session_backtrack_candidates { - let kind = match candidate.kind { - crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::AppendCurrent => { - "append_current" - } - crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::ReplaceLastEntry => { - "replace_last_entry" - } - }; - rendered.push_str(&format!( - "- kind: {kind}, entry_id: {}, delete_graphemes: {}, text: {}\n", - candidate - .entry_id - .map(|entry_id| entry_id.to_string()) - .unwrap_or_else(|| "none".to_string()), - candidate.delete_graphemes, - candidate.text - )); - } - rendered -} - -fn render_recommended_session_candidate(transcript: &RewriteTranscript) -> String { - transcript - .recommended_session_candidate - .as_ref() - .map(|candidate| { - let mode = match candidate.kind { - crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::AppendCurrent => { - "append_current" - } - crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::ReplaceLastEntry => { - "replace_last_entry" - } - }; - format!( - "Recommended session action:\nmode: {mode}\nentry_id: {}\ndelete_graphemes: {}\ntext: {}\n", - candidate - .entry_id - .map(|entry_id| entry_id.to_string()) - 
.unwrap_or_else(|| "none".to_string()), - candidate.delete_graphemes, - candidate.text - ) - }) - .unwrap_or_default() -} - -fn render_recent_segments(transcript: &RewriteTranscript, limit: usize) -> String { - let total_segments = transcript.segments.len(); - let start = total_segments.saturating_sub(limit); - let mut rendered = String::new(); - - for segment in &transcript.segments[start..] { - let line = format!( - "- {}-{} ms: {}\n", - segment.start_ms, segment.end_ms, segment.text - ); - rendered.push_str(&line); - } - - if rendered.is_empty() { - rendered.push_str("- no segments available\n"); - } - - rendered -} - -fn render_aggressive_candidate(transcript: &RewriteTranscript) -> String { - transcript - .aggressive_correction_text - .as_deref() - .map(str::trim) - .filter(|text| !text.is_empty()) - .map(|text| format!("Aggressive correction candidate:\n{text}\n")) - .unwrap_or_default() -} - -#[allow(dead_code)] -fn effective_max_tokens(max_tokens: usize, transcript: &RewriteTranscript) -> usize { - let word_count = transcript - .correction_aware_text - .split_whitespace() - .filter(|word| !word.is_empty()) - .count(); - let extra_budget = if requires_candidate_adjudication(transcript) { - 24 - } else { - 0 - }; - let minimum = if requires_candidate_adjudication(transcript) { - 64 - } else { - 48 - }; - let derived = word_count - .saturating_mul(2) - .saturating_add(24) - .saturating_add(extra_budget); - derived.clamp(minimum, max_tokens) -} - -pub(crate) fn sanitize_rewrite_output(raw: &str) -> String { - let mut text = raw.replace("\r\n", "\n"); - - for stop in ["<|eot_id|>", "<|end_of_text|>", ""] { - if let Some(index) = text.find(stop) { - text.truncate(index); - } - } - - if let Some(index) = text.find("") { - text.truncate(index); - } - - text = strip_tagged_section(&text, "", ""); - - let mut text = text.trim().to_string(); - - if let Some(stripped) = text.strip_prefix("") { - text = stripped.trim().to_string(); - } - - for prefix in ["Final text:", 
"Output:", "Rewritten text:"] { - if text - .get(..prefix.len()) - .map(|candidate| candidate.eq_ignore_ascii_case(prefix)) - .unwrap_or(false) - { - text = text[prefix.len()..].trim().to_string(); - break; - } - } - - if text.starts_with('"') && text.ends_with('"') && text.len() >= 2 { - text = text[1..text.len() - 1].trim().to_string(); - } - - text -} - -fn strip_tagged_section(input: &str, open: &str, close: &str) -> String { - let mut output = input.to_string(); - - while let Some(start) = output.find(open) { - let close_start = match output[start + open.len()..].find(close) { - Some(offset) => start + open.len() + offset, - None => { - output.truncate(start); - break; - } - }; - output.replace_range(start..close_start + close.len(), ""); - } - - output -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::rewrite_protocol::{ - RewriteCandidate, RewriteCandidateKind, RewriteEditAction, RewriteEditHypothesis, - RewriteEditHypothesisMatchSource, RewriteEditIntent, RewriteEditSignal, - RewriteEditSignalKind, RewriteEditSignalScope, RewriteEditSignalStrength, - RewriteIntentConfidence, RewriteReplacementScope, RewriteSessionBacktrackCandidate, - RewriteSessionBacktrackCandidateKind, RewriteSessionEntry, RewriteSurfaceKind, - RewriteTailShape, RewriteTranscript, RewriteTranscriptSegment, RewriteTypingContext, - }; - - fn correction_transcript() -> RewriteTranscript { - RewriteTranscript { - raw_text: "Hi there, this is a test. Wait, no. Hi there.".into(), - correction_aware_text: "Hi there.".into(), - aggressive_correction_text: None, - detected_language: Some("en".into()), - typing_context: None, - recent_session_entries: Vec::new(), - session_backtrack_candidates: Vec::new(), - recommended_session_candidate: None, - segments: vec![ - RewriteTranscriptSegment { - text: "Hi there, this is a test.".into(), - start_ms: 0, - end_ms: 1200, - }, - RewriteTranscriptSegment { - text: "Wait, no. 
Hi there.".into(), - start_ms: 1200, - end_ms: 2200, - }, - ], - edit_intents: vec![RewriteEditIntent { - action: RewriteEditAction::ReplacePreviousSentence, - trigger: "wait no".into(), - confidence: RewriteIntentConfidence::High, - }], - edit_signals: vec![RewriteEditSignal { - trigger: "wait no".into(), - kind: RewriteEditSignalKind::Replace, - scope_hint: RewriteEditSignalScope::Sentence, - strength: RewriteEditSignalStrength::Strong, - }], - edit_hypotheses: vec![RewriteEditHypothesis { - cue_family: "wait_no".into(), - matched_text: "wait no".into(), - match_source: RewriteEditHypothesisMatchSource::Exact, - kind: RewriteEditSignalKind::Replace, - scope_hint: RewriteEditSignalScope::Sentence, - replacement_scope: RewriteReplacementScope::Sentence, - tail_shape: RewriteTailShape::Phrase, - strength: RewriteEditSignalStrength::Strong, - }], - rewrite_candidates: vec![ - RewriteCandidate { - kind: RewriteCandidateKind::Literal, - text: "Hi there, this is a test. Wait, no. Hi there.".into(), - }, - RewriteCandidate { - kind: RewriteCandidateKind::ConservativeCorrection, - text: "Hi there.".into(), - }, - ], - recommended_candidate: Some(RewriteCandidate { - kind: RewriteCandidateKind::Literal, - text: "Hi there, this is a test. Wait, no. Hi there.".into(), - }), - } - } - - fn candidate_only_transcript() -> RewriteTranscript { - RewriteTranscript { - raw_text: "Hi there, this is a test. Scratch that. 
Hi there.".into(), - correction_aware_text: "Hi there.".into(), - aggressive_correction_text: None, - detected_language: Some("en".into()), - typing_context: None, - recent_session_entries: Vec::new(), - session_backtrack_candidates: Vec::new(), - recommended_session_candidate: None, - segments: Vec::new(), - edit_intents: vec![RewriteEditIntent { - action: RewriteEditAction::ReplacePreviousSentence, - trigger: "scratch that".into(), - confidence: RewriteIntentConfidence::High, - }], - edit_signals: Vec::new(), - edit_hypotheses: Vec::new(), - rewrite_candidates: vec![ - RewriteCandidate { - kind: RewriteCandidateKind::Literal, - text: "Hi there, this is a test. Scratch that. Hi there.".into(), - }, - RewriteCandidate { - kind: RewriteCandidateKind::ConservativeCorrection, - text: "Hi there.".into(), - }, - ], - recommended_candidate: None, - } - } - - #[test] - fn instructions_cover_self_correction_examples() { - let instructions = rewrite_instructions(ResolvedRewriteProfile::LlamaCompat); - assert!(instructions.contains("Return only the finished text")); - assert!(instructions.contains("Never reintroduce text")); - assert!(instructions.contains("scratch that, brownies")); - } - - #[test] - fn qwen_instructions_forbid_reasoning_tags() { - let instructions = rewrite_instructions(ResolvedRewriteProfile::Qwen); - assert!(instructions.contains("Do not emit reasoning")); - } - - #[test] - fn custom_instructions_append_to_system_prompt() { - let instructions = build_system_instructions( - ResolvedRewriteProfile::Qwen, - Some("Keep product names exact."), - ); - assert!(instructions.contains("Return only the finished text")); - assert!(instructions.contains("Keep product names exact.")); - } - - #[test] - fn cue_prompt_includes_raw_candidate_and_signals() { - let prompt = build_user_message(&correction_transcript()); - assert!(matches!( - rewrite_route(&correction_transcript()), - RewriteRoute::CandidateAdjudication - )); - assert!(prompt.contains("Structured edit 
hypotheses")); - assert!(prompt.contains("cue_family: wait_no")); - assert!(prompt.contains("replacement_scope: sentence")); - assert!(prompt.contains("tail_shape: phrase")); - assert!(prompt.contains("Candidate interpretations")); - assert!(prompt.contains("A strong explicit spoken edit cue was detected")); - assert!(prompt.contains( - "The candidate list is ordered from most likely to least likely by heuristics." - )); - assert!(prompt.contains("the first candidate is the heuristic best guess")); - assert!(prompt.contains("Recommended interpretation:")); - assert!(prompt.contains( - "Use this as the default final text unless another candidate is clearly better." - )); - assert!( - prompt.contains("Prefer the smallest replacement scope that yields a coherent result.") - ); - assert!(prompt.contains("- preferred_candidate")); - assert!(prompt.contains( - "- preferred_candidate literal (keep only if the cue was not actually an edit): Hi there, this is a test. Wait, no. Hi there." - )); - assert!(prompt.contains("Structured edit signals")); - assert!(prompt.contains("trigger: \"wait no\"")); - assert!(prompt.contains("Structured edit intents")); - assert!(prompt.contains("replace_previous_sentence")); - assert!(prompt.contains("Choose the best candidate interpretation")); - assert!(prompt.contains("Candidate interpretations:\n")); - assert!(prompt.contains("Correction candidate:\nHi there.")); - assert!(prompt.contains("Raw transcript:\nHi there, this is a test. Wait, no. 
Hi there.")); - assert!(prompt.contains("Recent segments")); - } - - #[test] - fn cue_prompt_includes_aggressive_candidate_when_available() { - let mut transcript = correction_transcript(); - transcript.aggressive_correction_text = Some("Hi there.".into()); - - let prompt = build_user_message(&transcript); - assert!(prompt.contains("Aggressive correction candidate")); - } - - #[test] - fn user_message_prefers_correction_candidate_without_signals() { - let prompt = build_user_message(&candidate_only_transcript()); - assert!(matches!( - rewrite_route(&candidate_only_transcript()), - RewriteRoute::ResolvedCorrection - )); - assert!(!prompt.contains("Recommended interpretation:")); - assert!(prompt.contains("Structured edit signals")); - assert!(prompt.contains("Structured edit intents")); - assert!(prompt.contains("Self-corrections were already resolved")); - assert!(prompt.contains("Do not restore any canceled wording")); - assert!(!prompt.contains("Recent segments")); - assert!(!prompt.contains("Raw transcript")); - } - - #[test] - fn user_message_includes_recent_segments_when_correction_matches_raw() { - let transcript = RewriteTranscript { - raw_text: "Hi there.".into(), - correction_aware_text: "Hi there.".into(), - aggressive_correction_text: None, - detected_language: Some("en".into()), - typing_context: None, - recent_session_entries: Vec::new(), - session_backtrack_candidates: Vec::new(), - recommended_session_candidate: None, - segments: vec![RewriteTranscriptSegment { - text: "Hi there.".into(), - start_ms: 0, - end_ms: 1200, - }], - edit_intents: Vec::new(), - edit_signals: Vec::new(), - edit_hypotheses: Vec::new(), - rewrite_candidates: vec![RewriteCandidate { - kind: RewriteCandidateKind::Literal, - text: "Hi there.".into(), - }], - recommended_candidate: None, - }; - - let prompt = build_user_message(&transcript); - assert!(matches!(rewrite_route(&transcript), RewriteRoute::Fast)); - assert!(prompt.contains("Correction-aware transcript")); - 
assert!(prompt.contains("Structured edit signals")); - assert!(prompt.contains("Recent segments")); - assert!(prompt.contains("0-1200 ms")); - assert!(prompt.contains("Hi there.")); - } - - #[test] - fn effective_max_tokens_scales_with_transcript_length() { - let short = RewriteTranscript { - raw_text: "hi there".into(), - correction_aware_text: "hi there".into(), - aggressive_correction_text: None, - detected_language: Some("en".into()), - typing_context: None, - recent_session_entries: Vec::new(), - session_backtrack_candidates: Vec::new(), - recommended_session_candidate: None, - segments: Vec::new(), - edit_intents: Vec::new(), - edit_signals: Vec::new(), - edit_hypotheses: Vec::new(), - rewrite_candidates: vec![RewriteCandidate { - kind: RewriteCandidateKind::Literal, - text: "hi there".into(), - }], - recommended_candidate: None, - }; - assert_eq!(effective_max_tokens(256, &short), 48); - - let long = RewriteTranscript { - raw_text: "word ".repeat(80), - correction_aware_text: "word ".repeat(80), - aggressive_correction_text: None, - detected_language: Some("en".into()), - typing_context: None, - recent_session_entries: Vec::new(), - session_backtrack_candidates: Vec::new(), - recommended_session_candidate: None, - segments: Vec::new(), - edit_intents: Vec::new(), - edit_signals: Vec::new(), - edit_hypotheses: Vec::new(), - rewrite_candidates: vec![RewriteCandidate { - kind: RewriteCandidateKind::Literal, - text: "word ".repeat(80), - }], - recommended_candidate: None, - }; - assert_eq!(effective_max_tokens(256, &long), 184); - } - - #[test] - fn effective_max_tokens_gives_edit_heavy_prompts_more_budget() { - let transcript = correction_transcript(); - assert_eq!(effective_max_tokens(256, &transcript), 64); - } - - #[test] - fn session_prompt_includes_recent_entry_and_context() { - let transcript = RewriteTranscript { - raw_text: "scratch that hi".into(), - correction_aware_text: "Hi".into(), - aggressive_correction_text: None, - detected_language: 
Some("en".into()), - typing_context: Some(RewriteTypingContext { - focus_fingerprint: "hyprland:0x123".into(), - app_id: Some("firefox".into()), - window_title: Some("Example".into()), - surface_kind: RewriteSurfaceKind::Browser, - browser_domain: None, - captured_at_ms: 10, - }), - recent_session_entries: vec![RewriteSessionEntry { - id: 7, - final_text: "Hello there".into(), - grapheme_len: 11, - focus_fingerprint: "hyprland:0x123".into(), - surface_kind: RewriteSurfaceKind::Browser, - app_id: Some("firefox".into()), - window_title: Some("Example".into()), - }], - session_backtrack_candidates: vec![ - RewriteSessionBacktrackCandidate { - kind: RewriteSessionBacktrackCandidateKind::ReplaceLastEntry, - entry_id: Some(7), - delete_graphemes: 11, - text: "Hi".into(), - }, - RewriteSessionBacktrackCandidate { - kind: RewriteSessionBacktrackCandidateKind::AppendCurrent, - entry_id: None, - delete_graphemes: 0, - text: "Hi".into(), - }, - ], - recommended_session_candidate: Some(RewriteSessionBacktrackCandidate { - kind: RewriteSessionBacktrackCandidateKind::ReplaceLastEntry, - entry_id: Some(7), - delete_graphemes: 11, - text: "Hi".into(), - }), - segments: Vec::new(), - edit_intents: Vec::new(), - edit_signals: Vec::new(), - edit_hypotheses: Vec::new(), - rewrite_candidates: vec![RewriteCandidate { - kind: RewriteCandidateKind::SentenceReplacement, - text: "Hi".into(), - }], - recommended_candidate: Some(RewriteCandidate { - kind: RewriteCandidateKind::SentenceReplacement, - text: "Hi".into(), - }), - }; - - let prompt = build_user_message(&transcript); - assert!(matches!( - rewrite_route(&transcript), - RewriteRoute::SessionCandidateAdjudication - )); - assert!(prompt.contains("Active typing context")); - assert!(prompt.contains("Recent dictation session")); - assert!(prompt.contains("replace_last_entry")); - assert!(prompt.contains("treat your final text as the replacement text")); - } - - #[test] - fn sanitize_rewrite_output_strips_wrapper_and_label() { - let 
cleaned = sanitize_rewrite_output("\nFinal text: Hi there.\n"); - assert_eq!(cleaned, "Hi there."); - } - - #[test] - fn sanitize_rewrite_output_strips_llama_stop_tokens() { - let cleaned = sanitize_rewrite_output("Hi there.<|eot_id|>ignored"); - assert_eq!(cleaned, "Hi there."); - } - - #[test] - fn sanitize_rewrite_output_strips_think_blocks() { - let cleaned = sanitize_rewrite_output("reasoning\nHi there."); - assert_eq!(cleaned, "Hi there."); - } -} diff --git a/src/rewrite/mod.rs b/src/rewrite/mod.rs new file mode 100644 index 0000000..ad8d63b --- /dev/null +++ b/src/rewrite/mod.rs @@ -0,0 +1,27 @@ +mod output; +mod prompt; +mod routing; + +#[cfg(test)] +mod tests; + +pub use output::sanitize_rewrite_output; +pub use prompt::{ + build_oaicompat_messages_json, build_prompt, effective_max_tokens, resolved_profile_for_cloud, +}; + +#[cfg(feature = "local-rewrite")] +pub const fn local_rewrite_available() -> bool { + true +} + +#[cfg(not(feature = "local-rewrite"))] +pub const fn local_rewrite_available() -> bool { + false +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RewritePrompt { + pub system: String, + pub user: String, +} diff --git a/src/rewrite/output.rs b/src/rewrite/output.rs new file mode 100644 index 0000000..708ae9f --- /dev/null +++ b/src/rewrite/output.rs @@ -0,0 +1,78 @@ +pub fn sanitize_rewrite_output(raw: &str) -> String { + let mut text = raw.replace("\r\n", "\n"); + + for stop in ["<|eot_id|>", "<|end_of_text|>", ""] { + if let Some(index) = text.find(stop) { + text.truncate(index); + } + } + + if let Some(index) = text.find("") { + text.truncate(index); + } + + text = strip_tagged_section(&text, "", ""); + + let mut text = text.trim().to_string(); + + if let Some(stripped) = text.strip_prefix("") { + text = stripped.trim().to_string(); + } + + for prefix in ["Final text:", "Output:", "Rewritten text:"] { + if text + .get(..prefix.len()) + .map(|candidate| candidate.eq_ignore_ascii_case(prefix)) + .unwrap_or(false) + { + text = 
text[prefix.len()..].trim().to_string(); + break; + } + } + + if text.starts_with('"') && text.ends_with('"') && text.len() >= 2 { + text = text[1..text.len() - 1].trim().to_string(); + } + + text +} + +fn strip_tagged_section(input: &str, open: &str, close: &str) -> String { + let mut output = input.to_string(); + + while let Some(start) = output.find(open) { + let close_start = match output[start + open.len()..].find(close) { + Some(offset) => start + open.len() + offset, + None => { + output.truncate(start); + break; + } + }; + output.replace_range(start..close_start + close.len(), ""); + } + + output +} + +#[cfg(test)] +mod tests { + use super::sanitize_rewrite_output; + + #[test] + fn sanitize_rewrite_output_strips_wrapper_and_label() { + let cleaned = sanitize_rewrite_output("\nFinal text: Hi there.\n"); + assert_eq!(cleaned, "Hi there."); + } + + #[test] + fn sanitize_rewrite_output_strips_llama_stop_tokens() { + let cleaned = sanitize_rewrite_output("Hi there.<|eot_id|>ignored"); + assert_eq!(cleaned, "Hi there."); + } + + #[test] + fn sanitize_rewrite_output_strips_think_blocks() { + let cleaned = sanitize_rewrite_output("reasoning\nHi there."); + assert_eq!(cleaned, "Hi there."); + } +} diff --git a/src/rewrite/prompt.rs b/src/rewrite/prompt.rs new file mode 100644 index 0000000..bd9ca2c --- /dev/null +++ b/src/rewrite/prompt.rs @@ -0,0 +1,760 @@ +use super::RewritePrompt; +use super::routing::{ + RewriteRoute, has_policy_context, has_strong_explicit_edit_cue, + requires_candidate_adjudication, rewrite_route, +}; +use crate::rewrite_profile::{ResolvedRewriteProfile, RewriteProfile}; +use crate::rewrite_protocol::RewriteTranscript; + +pub fn build_prompt( + transcript: &RewriteTranscript, + profile: ResolvedRewriteProfile, + custom_instructions: Option<&str>, +) -> std::result::Result { + Ok(RewritePrompt { + system: build_system_instructions(transcript, profile, custom_instructions), + user: build_user_message(transcript), + }) +} + +pub fn 
resolved_profile_for_cloud(profile: RewriteProfile) -> ResolvedRewriteProfile { + match profile { + RewriteProfile::Auto => ResolvedRewriteProfile::Generic, + RewriteProfile::Generic => ResolvedRewriteProfile::Generic, + RewriteProfile::Qwen => ResolvedRewriteProfile::Qwen, + RewriteProfile::LlamaCompat => ResolvedRewriteProfile::LlamaCompat, + } +} + +pub fn build_oaicompat_messages_json( + prompt: &RewritePrompt, +) -> std::result::Result { + serde_json::to_string(&[ + serde_json::json!({ + "role": "system", + "content": prompt.system, + }), + serde_json::json!({ + "role": "user", + "content": prompt.user, + }), + ]) + .map_err(|e| format!("failed to encode rewrite chat messages: {e}")) +} + +pub fn effective_max_tokens(max_tokens: usize, transcript: &RewriteTranscript) -> usize { + let word_count = transcript + .correction_aware_text + .split_whitespace() + .filter(|word| !word.is_empty()) + .count(); + let extra_budget = if requires_candidate_adjudication(transcript) { + 24 + } else { + 0 + }; + let minimum = if requires_candidate_adjudication(transcript) { + 64 + } else { + 48 + }; + let derived = word_count + .saturating_mul(2) + .saturating_add(24) + .saturating_add(extra_budget); + derived.clamp(minimum, max_tokens) +} + +pub(crate) fn build_system_instructions( + transcript: &RewriteTranscript, + profile: ResolvedRewriteProfile, + custom_instructions: Option<&str>, +) -> String { + let mut instructions = rewrite_instructions(profile).to_string(); + if has_policy_context(transcript) { + let policy_context = &transcript.policy_context; + instructions.push_str("\n\nCorrection policy contract:\n"); + instructions.push_str(correction_policy_contract(policy_context.correction_policy)); + instructions.push_str("\n\nAgentic latitude contract:\n"); + instructions.push_str(agentic_latitude_contract(policy_context.correction_policy)); + if !policy_context.effective_rule_instructions.is_empty() { + instructions.push_str("\n\nMatched app rewrite policy 
instructions:\n"); + for instruction in &policy_context.effective_rule_instructions { + instructions.push_str("- "); + instructions.push_str(instruction.trim()); + instructions.push('\n'); + } + } + } + if let Some(custom) = custom_instructions + .map(str::trim) + .filter(|text| !text.is_empty()) + { + instructions.push_str("\n\nAdditional user rewrite instructions:\n"); + instructions.push_str(custom); + } + instructions +} + +fn correction_policy_contract( + policy: crate::rewrite_protocol::RewriteCorrectionPolicy, +) -> &'static str { + match policy { + crate::rewrite_protocol::RewriteCorrectionPolicy::Conservative => { + "Conservative: stay close to explicit rewrite candidates and glossary evidence. If uncertain, prefer candidate-preserving output over freer rewriting." + } + crate::rewrite_protocol::RewriteCorrectionPolicy::Balanced => { + "Balanced: allow stronger technical correction when the glossary, app context, or utterance semantics support it. Prefer candidate-backed output when it is competitive, but do not keep an obviously wrong technical spelling just because it appears in the candidate list." + } + crate::rewrite_protocol::RewriteCorrectionPolicy::Aggressive => { + "Aggressive: allow freer technical correction and contextual cleanup when the utterance strongly points to a technical term or proper name. Candidates are useful evidence, not hard limits, as long as you still return only final text within the provided bounds." + } + } +} + +fn agentic_latitude_contract( + policy: crate::rewrite_protocol::RewriteCorrectionPolicy, +) -> &'static str { + match policy { + crate::rewrite_protocol::RewriteCorrectionPolicy::Conservative => { + "In conservative mode, treat the candidate list and glossary as the main evidence. Only make a freer technical normalization when the utterance itself makes the intended term unusually clear." 
+ } + crate::rewrite_protocol::RewriteCorrectionPolicy::Balanced => { + "In balanced mode, you may normalize likely technical terms, product names, commands, libraries, languages, editors, or Linux components even when the literal transcript spelling is noisy or the exact canonical form is not already present in the candidate list, as long as the utterance strongly supports that normalization." + } + crate::rewrite_protocol::RewriteCorrectionPolicy::Aggressive => { + "In aggressive mode, you may confidently rewrite phonetically similar words into the most plausible technical term or proper name when the utterance semantics, app context, or nearby category cues make that interpretation clearly better than the literal transcript." + } + } +} + +pub(crate) fn rewrite_instructions(profile: ResolvedRewriteProfile) -> &'static str { + let base = "You clean up dictated speech into the final text the user meant to type. \ +Return only the finished text. Do not explain anything. Remove obvious disfluencies when natural. \ +Use the correction-aware transcript as the primary source of truth unless structured edit signals say the \ +utterance may still be ambiguous. The raw transcript may still contain spoken editing phrases or canceled wording. \ +Never reintroduce text that was removed by an explicit spoken correction cue. Respect any structured edit intents \ +provided alongside the transcript. If structured edit signals or edit hypotheses are present, use the candidate \ +interpretations as bounded options, choose the best interpretation, and lightly refine it only when needed for natural \ +final text. Prefer transcript spellings for names, brands, and uncommon proper nouns unless a user dictionary or \ +explicit correction says otherwise. Do not normalize names into more common spellings just because they look familiar. 
\ +When the utterance clearly refers to software, tools, APIs, libraries, Linux components, product names, or other \ +technical concepts, prefer the most plausible intended technical term or proper name over a phonetically similar common \ +word. Use nearby category words like window manager, editor, language, library, package manager, shell, or terminal \ +tool to disambiguate technical names. If the utterance remains genuinely ambiguous, stay close to the transcript rather \ +than inventing a niche term. \ +If an edit intent says to replace or cancel previous wording, preserve that edit and do not keep the spoken correction \ +phrase itself unless the transcript clearly still intends it. Examples:\n\ +- raw: Hello there. Scratch that. Hi.\n correction-aware: Hi.\n final: Hi.\n\ +- raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ +- raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ +- raw: Never mind. 
Hi, how are you today?\n correction-aware: Hi, how are you today?\n final: Hi, how are you today?\n\ +- raw: Wait, no, it actually works.\n correction-aware: Wait, no, it actually works.\n final: Wait, no, it actually works.\n\ +- raw: Let's meet tomorrow, or rather Friday.\n correction-aware: Let's meet Friday.\n final: Let's meet Friday.\n\ +- raw: I'm currently using the window manager Hyperland.\n correction-aware: I'm currently using the window manager Hyperland.\n final: I'm currently using the window manager Hyprland.\n\ +- raw: I'm switching from Sui to Hyperland.\n correction-aware: I'm switching from Sui to Hyperland.\n final: I'm switching from Sway to Hyprland.\n\ +- raw: I use type script for backend tooling.\n correction-aware: I use type script for backend tooling.\n final: I use TypeScript for backend tooling.\n\ +- raw: I edit the config in neo vim.\n correction-aware: I edit the config in neo vim.\n final: I edit the config in Neovim."; + + match profile { + ResolvedRewriteProfile::Qwen => { + "You clean up dictated speech into the final text the user meant to type. \ +Return only the finished text. Do not explain anything. Do not emit reasoning, think tags, or XML wrappers. \ +Remove obvious disfluencies when natural. Use the correction-aware transcript as the primary source of truth unless \ +structured edit signals say the utterance may still be ambiguous. The raw transcript may still contain spoken editing \ +phrases or canceled wording. Never reintroduce text that was removed by an explicit spoken correction cue. Respect \ +any structured edit intents provided alongside the transcript. If structured edit signals or edit hypotheses are \ +present, use the candidate interpretations as bounded options, choose the best interpretation, and lightly refine it \ +only when needed for natural final text. Prefer transcript spellings for names, brands, and uncommon proper nouns \ +unless a user dictionary or explicit correction says otherwise. 
Do not normalize names into more common spellings just \ +because they look familiar. When the utterance clearly refers to software, tools, APIs, libraries, Linux components, \ +product names, or other technical concepts, prefer the most plausible intended technical term or proper name over a \ +phonetically similar common word. Use nearby category words like window manager, editor, language, library, package \ +manager, shell, or terminal tool to disambiguate technical names. If the utterance remains genuinely ambiguous, stay \ +close to the transcript rather than inventing a niche term. If an edit intent says to replace or cancel previous wording, preserve that edit and do \ +not keep the spoken correction phrase itself unless the transcript clearly still intends it. Examples:\n\ +- raw: Hello there. Scratch that. Hi.\n correction-aware: Hi.\n final: Hi.\n\ +- raw: I'll bring cookies, scratch that, brownies.\n correction-aware: I'll bring brownies.\n final: I'll bring brownies.\n\ +- raw: My name is Notes, scratch that my name is Jonatan.\n correction-aware: My my name is Jonatan.\n aggressive correction-aware: My name is Jonatan.\n final: My name is Jonatan.\n\ +- raw: Never mind. 
Hi, how are you today?\n correction-aware: Hi, how are you today?\n final: Hi, how are you today?\n\ +- raw: Wait, no, it actually works.\n correction-aware: Wait, no, it actually works.\n final: Wait, no, it actually works.\n\ +- raw: Let's meet tomorrow, or rather Friday.\n correction-aware: Let's meet Friday.\n final: Let's meet Friday.\n\ +- raw: I'm currently using the window manager Hyperland.\n correction-aware: I'm currently using the window manager Hyperland.\n final: I'm currently using the window manager Hyprland.\n\ +- raw: I'm switching from Sui to Hyperland.\n correction-aware: I'm switching from Sui to Hyperland.\n final: I'm switching from Sway to Hyprland.\n\ +- raw: I use type script for backend tooling.\n correction-aware: I use type script for backend tooling.\n final: I use TypeScript for backend tooling.\n\ +- raw: I edit the config in neo vim.\n correction-aware: I edit the config in neo vim.\n final: I edit the config in Neovim." + } + ResolvedRewriteProfile::Generic | ResolvedRewriteProfile::LlamaCompat => base, + } +} + +pub(crate) fn build_user_message(transcript: &RewriteTranscript) -> String { + let language = transcript.detected_language.as_deref().unwrap_or("unknown"); + let correction_aware = transcript.correction_aware_text.trim(); + let raw = transcript.raw_text.trim(); + let edit_intents = render_edit_intents(transcript); + let edit_signals = render_edit_signals(transcript); + let agentic_context = render_agentic_context(transcript); + let route = rewrite_route(transcript); + tracing::debug!( + route = ?route, + edit_signals = transcript.edit_signals.len(), + edit_hypotheses = transcript.edit_hypotheses.len(), + rewrite_candidates = transcript.rewrite_candidates.len(), + "rewrite prompt route selected" + ); + + match route { + RewriteRoute::SessionCandidateAdjudication => { + let typing_context = render_typing_context(transcript); + let recent_session_entries = render_recent_session_entries(transcript); + let 
agentic_policy_context = render_agentic_policy_context(transcript); + let session_candidates = render_session_backtrack_candidates(transcript); + let recommended_session_candidate = render_recommended_session_candidate(transcript); + let rewrite_candidates = render_rewrite_candidates(transcript); + let surface_guidance = transcript + .typing_context + .as_ref() + .filter(|context| { + matches!( + context.surface_kind, + crate::rewrite_protocol::RewriteSurfaceKind::Terminal + ) + }) + .map(|_| { + "The active surface looks like a terminal. Stay conservative unless an explicit correction cue clearly indicates replacing the most recent prior dictation.\n" + }) + .unwrap_or(""); + format!( + "Language: {language}\n\ +Active typing context:\n\ +{typing_context}\ +Recent dictation session:\n\ +{recent_session_entries}\ +{agentic_policy_context}\ +Session backtrack candidates:\n\ +{session_candidates}\ +{recommended_session_candidate}\ +The user may be correcting the most recent prior dictation entry rather than appending new text.\n\ +If the recommended session candidate says replace_last_entry, treat your final text as the replacement text for that previous dictation entry, not as newly appended text.\n\ +Prefer the recommended session candidate unless another listed session candidate is clearly better.\n\ +{surface_guidance}\ +Current utterance correction candidate:\n\ +{correction_aware}\n\ +Raw current utterance:\n\ +{raw}\n\ +Current utterance bounded candidates:\n\ +{rewrite_candidates}\ +Final text:" + ) + } + RewriteRoute::CandidateAdjudication => { + let edit_hypotheses = render_edit_hypotheses(transcript); + let rewrite_candidates = render_rewrite_candidates(transcript); + let recommended_candidate = render_recommended_candidate(transcript); + let recent_segments = render_recent_segments(transcript, 4); + let aggressive_candidate = render_aggressive_candidate(transcript); + let exact_cue_guidance = if has_strong_explicit_edit_cue(transcript) { + "A strong 
explicit spoken edit cue was detected. The literal raw transcript probably contains canceled wording. \ +Prefer a candidate interpretation that removes the cue and canceled wording unless doing so would clearly lose intended meaning. \ +If the cue is an exact strong match for phrases like scratch that, never mind, or wait no, do not keep the literal cue text in the final output.\n" + } else { + "" + }; + tracing::trace!("rewrite hypotheses:\n{edit_hypotheses}"); + tracing::trace!("rewrite candidates:\n{rewrite_candidates}"); + format!( + "Language: {language}\n\ +{agentic_context}\ +Structured edit hypotheses:\n\ +{edit_hypotheses}\ +Structured edit signals:\n\ +{edit_signals}\ +Structured edit intents:\n\ +{edit_intents}\ +This utterance likely contains spoken self-corrections or restatements.\n\ +Choose the best candidate interpretation and lightly refine it only when needed.\n\ +{exact_cue_guidance}\ +When an exact strong edit cue is present, treat the non-literal candidates as more trustworthy than the literal transcript.\n\ +The candidate list is ordered from most likely to least likely by heuristics.\n\ +For exact strong edit cues, the first candidate is the heuristic best guess and should usually win unless another candidate is clearly better.\n\ +Prefer the smallest replacement scope that yields a coherent result.\n\ +Use span-level replacements when only a key phrase was corrected, clause-level replacements when the correction replaces the surrounding thought, and sentence-level replacements only when the whole sentence was canceled.\n\ +Preserve literal wording when the cue is not actually an edit.\n\ +Do not over-normalize names or brands.\n\ +Do not keep spoken edit cues in the final text when they act as edits.\n\ +{recommended_candidate}\ +Candidate interpretations:\n\ +{rewrite_candidates}\ +Correction candidate:\n\ +{correction_aware}\n\ +{aggressive_candidate}\ +Raw transcript:\n\ +{raw}\n\ +Recent segments:\n\ +{recent_segments}\n\ +Final text:" + 
) + } + RewriteRoute::ResolvedCorrection => format!( + "Language: {language}\n\ +{agentic_context}\ +Structured edit signals:\n\ +{edit_signals}\ +Structured edit intents:\n\ +{edit_intents}\ +Self-corrections were already resolved before rewriting.\n\ +Use this correction-aware transcript as the main source text. In agentic mode, you may still normalize likely \ +technical terms or proper names when the utterance strongly supports them, even if the exact canonical spelling is not \ +already present in the candidate list:\n\ +{correction_aware}\n\ +{agentic_candidates}\ +Do not restore any canceled wording from earlier in the utterance.\n\ +Final text:", + agentic_candidates = render_agentic_candidates(transcript), + ), + RewriteRoute::Fast => { + let recent_segments = render_recent_segments(transcript, 4); + format!( + "Language: {language}\n\ +{agentic_context}\ +Structured edit signals:\n\ +{edit_signals}\ +Structured edit intents:\n\ +{edit_intents}\ +Correction-aware transcript:\n\ +{correction_aware}\n\ +Treat the correction-aware transcript as authoritative for explicit spoken edits and overall meaning, but in agentic \ +mode you may normalize likely technical terms or proper names when category cues in the utterance make the intended \ +technical meaning clearly better than the literal transcript.\n\ +{agentic_candidates}\ +\ +Recent segments:\n\ +{recent_segments}\n\ +Final text:", + agentic_candidates = render_agentic_candidates(transcript), + ) + } + } +} + +fn render_agentic_context(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return String::new(); + } + format!( + "{}{}", + render_agentic_runtime_context(transcript), + render_agentic_policy_context(transcript) + ) +} + +fn render_agentic_policy_context(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return String::new(); + } + let policy_context = &transcript.policy_context; + + format!( + "Agentic correction policy:\n\ +- 
mode: {}\n\ +Matched app rewrite rules:\n\ +{matched_rules}\ +Matched app policy instructions:\n\ +{effective_instructions}\ +Active glossary terms:\n\ +{glossary_terms}\ +", + policy_context.correction_policy.as_str(), + matched_rules = render_matched_rule_names(transcript), + effective_instructions = render_effective_rule_instructions(transcript), + glossary_terms = render_active_glossary_terms(transcript), + ) +} + +fn render_agentic_runtime_context(transcript: &RewriteTranscript) -> String { + if has_policy_context(transcript) { + format!( + "Active typing context:\n\ +{}\ +Recent dictation session:\n\ +{}", + render_typing_context(transcript), + render_recent_session_entries(transcript), + ) + } else { + String::new() + } +} + +fn render_agentic_candidates(transcript: &RewriteTranscript) -> String { + if has_policy_context(transcript) { + format!( + "Available rewrite candidates (advisory, not exhaustive in agentic mode):\n\ +{}\ +Glossary-backed candidates:\n\ +{}", + render_rewrite_candidates(transcript), + render_glossary_candidates(transcript) + ) + } else { + String::new() + } +} + +fn render_edit_intents(transcript: &RewriteTranscript) -> String { + if transcript.edit_intents.is_empty() { + return "- none detected\n".to_string(); + } + + let mut rendered = String::new(); + for intent in &transcript.edit_intents { + let action = match intent.action { + crate::rewrite_protocol::RewriteEditAction::ReplacePreviousPhrase => { + "replace_previous_phrase" + } + crate::rewrite_protocol::RewriteEditAction::ReplacePreviousClause => { + "replace_previous_clause" + } + crate::rewrite_protocol::RewriteEditAction::ReplacePreviousSentence => { + "replace_previous_sentence" + } + crate::rewrite_protocol::RewriteEditAction::DropEditCue => "drop_edit_cue", + }; + let confidence = match intent.confidence { + crate::rewrite_protocol::RewriteIntentConfidence::High => "high", + }; + rendered.push_str(&format!( + "- action: {action}, trigger: \"{}\", confidence: 
{confidence}\n", + intent.trigger + )); + } + + rendered +} + +fn render_edit_signals(transcript: &RewriteTranscript) -> String { + if transcript.edit_signals.is_empty() { + return "- none detected\n".to_string(); + } + + let mut rendered = String::new(); + for signal in &transcript.edit_signals { + let kind = match signal.kind { + crate::rewrite_protocol::RewriteEditSignalKind::Cancel => "cancel", + crate::rewrite_protocol::RewriteEditSignalKind::Replace => "replace", + crate::rewrite_protocol::RewriteEditSignalKind::Restatement => "restatement", + }; + let scope_hint = match signal.scope_hint { + crate::rewrite_protocol::RewriteEditSignalScope::Phrase => "phrase", + crate::rewrite_protocol::RewriteEditSignalScope::Clause => "clause", + crate::rewrite_protocol::RewriteEditSignalScope::Sentence => "sentence", + crate::rewrite_protocol::RewriteEditSignalScope::Unknown => "unknown", + }; + let strength = match signal.strength { + crate::rewrite_protocol::RewriteEditSignalStrength::Possible => "possible", + crate::rewrite_protocol::RewriteEditSignalStrength::Strong => "strong", + }; + rendered.push_str(&format!( + "- trigger: \"{}\", kind: {kind}, scope_hint: {scope_hint}, strength: {strength}\n", + signal.trigger + )); + } + + rendered +} + +fn render_edit_hypotheses(transcript: &RewriteTranscript) -> String { + if transcript.edit_hypotheses.is_empty() { + return "- none detected\n".to_string(); + } + + let mut rendered = String::new(); + for hypothesis in &transcript.edit_hypotheses { + let match_source = match hypothesis.match_source { + crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Exact => "exact", + crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Alias => "alias", + crate::rewrite_protocol::RewriteEditHypothesisMatchSource::NearMiss => "near_miss", + }; + let kind = match hypothesis.kind { + crate::rewrite_protocol::RewriteEditSignalKind::Cancel => "cancel", + crate::rewrite_protocol::RewriteEditSignalKind::Replace => "replace", + 
crate::rewrite_protocol::RewriteEditSignalKind::Restatement => "restatement", + }; + let scope_hint = match hypothesis.scope_hint { + crate::rewrite_protocol::RewriteEditSignalScope::Phrase => "phrase", + crate::rewrite_protocol::RewriteEditSignalScope::Clause => "clause", + crate::rewrite_protocol::RewriteEditSignalScope::Sentence => "sentence", + crate::rewrite_protocol::RewriteEditSignalScope::Unknown => "unknown", + }; + let strength = match hypothesis.strength { + crate::rewrite_protocol::RewriteEditSignalStrength::Possible => "possible", + crate::rewrite_protocol::RewriteEditSignalStrength::Strong => "strong", + }; + let replacement_scope = match hypothesis.replacement_scope { + crate::rewrite_protocol::RewriteReplacementScope::Span => "span", + crate::rewrite_protocol::RewriteReplacementScope::Clause => "clause", + crate::rewrite_protocol::RewriteReplacementScope::Sentence => "sentence", + }; + let tail_shape = match hypothesis.tail_shape { + crate::rewrite_protocol::RewriteTailShape::Empty => "empty", + crate::rewrite_protocol::RewriteTailShape::Phrase => "phrase", + crate::rewrite_protocol::RewriteTailShape::Clause => "clause", + }; + rendered.push_str(&format!( + "- cue_family: {}, matched_text: \"{}\", match_source: {match_source}, kind: {kind}, scope_hint: {scope_hint}, replacement_scope: {replacement_scope}, tail_shape: {tail_shape}, strength: {strength}\n", + hypothesis.cue_family, hypothesis.matched_text + )); + } + + rendered +} + +fn render_rewrite_candidates(transcript: &RewriteTranscript) -> String { + if transcript.rewrite_candidates.is_empty() { + return "- no candidates available\n".to_string(); + } + + let mut rendered = String::new(); + let highlight_first = has_strong_explicit_edit_cue(transcript); + for (index, candidate) in transcript.rewrite_candidates.iter().enumerate() { + let prefix = if highlight_first && index == 0 { + "- preferred_candidate" + } else { + "-" + }; + let kind = match candidate.kind { + 
crate::rewrite_protocol::RewriteCandidateKind::Literal => { + "literal (keep only if the cue was not actually an edit)" + } + crate::rewrite_protocol::RewriteCandidateKind::ConservativeCorrection => { + "conservative_correction (balanced cleanup)" + } + crate::rewrite_protocol::RewriteCandidateKind::AggressiveCorrection => { + "aggressive_correction (use when canceled wording should be removed more fully)" + } + crate::rewrite_protocol::RewriteCandidateKind::GlossaryCorrection => { + "glossary_correction (supported by active glossary aliases)" + } + crate::rewrite_protocol::RewriteCandidateKind::SpanReplacement => { + "span_replacement (replace only the corrected phrase)" + } + crate::rewrite_protocol::RewriteCandidateKind::ClauseReplacement => { + "clause_replacement (replace the corrected clause while keeping surrounding context)" + } + crate::rewrite_protocol::RewriteCandidateKind::SentenceReplacement => { + "sentence_replacement (replace the whole corrected sentence)" + } + crate::rewrite_protocol::RewriteCandidateKind::ContextualReplacement => { + "contextual_replacement (replace the corrected span while keeping earlier context)" + } + crate::rewrite_protocol::RewriteCandidateKind::DropCueOnly => { + "drop_cue_only (remove just the spoken edit cue)" + } + crate::rewrite_protocol::RewriteCandidateKind::FollowingReplacement => { + "following_replacement (keep only the wording after the cue)" + } + crate::rewrite_protocol::RewriteCandidateKind::CancelPreviousClause => { + "cancel_previous_clause (treat the cue as canceling the prior clause)" + } + crate::rewrite_protocol::RewriteCandidateKind::CancelPreviousSentence => { + "cancel_previous_sentence (treat the cue as canceling the prior sentence)" + } + }; + rendered.push_str(&format!("{prefix} {kind}: {}\n", candidate.text)); + } + + rendered +} + +fn render_recommended_candidate(transcript: &RewriteTranscript) -> String { + transcript + .recommended_candidate + .as_ref() + .map(|candidate| { + format!( + 
"Recommended interpretation:\n{}\nUse this as the default final text unless another candidate is clearly better.\n", + candidate.text + ) + }) + .unwrap_or_default() +} + +fn render_typing_context(transcript: &RewriteTranscript) -> String { + transcript + .typing_context + .as_ref() + .map(|context| { + format!( + "- focus_fingerprint: {}\n- app_id: {}\n- window_title: {}\n- surface_kind: {}\n- browser_domain: {}\n", + context.focus_fingerprint, + context.app_id.as_deref().unwrap_or("unknown"), + context.window_title.as_deref().unwrap_or("unknown"), + context.surface_kind.as_str(), + context.browser_domain.as_deref().unwrap_or("unknown"), + ) + }) + .unwrap_or_else(|| "- none available\n".to_string()) +} + +fn render_recent_session_entries(transcript: &RewriteTranscript) -> String { + if transcript.recent_session_entries.is_empty() { + return "- none available\n".to_string(); + } + + let mut rendered = String::new(); + for entry in &transcript.recent_session_entries { + rendered.push_str(&format!( + "- id: {}, text: {}, grapheme_len: {}, surface_kind: {}\n", + entry.id, + entry.final_text, + entry.grapheme_len, + entry.surface_kind.as_str() + )); + } + rendered +} + +fn render_session_backtrack_candidates(transcript: &RewriteTranscript) -> String { + if transcript.session_backtrack_candidates.is_empty() { + return "- no session backtrack candidates\n".to_string(); + } + + let mut rendered = String::new(); + for candidate in &transcript.session_backtrack_candidates { + let kind = match candidate.kind { + crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::AppendCurrent => { + "append_current" + } + crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::ReplaceLastEntry => { + "replace_last_entry" + } + }; + rendered.push_str(&format!( + "- kind: {kind}, entry_id: {}, delete_graphemes: {}, text: {}\n", + candidate + .entry_id + .map(|entry_id| entry_id.to_string()) + .unwrap_or_else(|| "none".to_string()), + candidate.delete_graphemes, + 
candidate.text + )); + } + rendered +} + +fn render_recommended_session_candidate(transcript: &RewriteTranscript) -> String { + transcript + .recommended_session_candidate + .as_ref() + .map(|candidate| { + let mode = match candidate.kind { + crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::AppendCurrent => { + "append_current" + } + crate::rewrite_protocol::RewriteSessionBacktrackCandidateKind::ReplaceLastEntry => { + "replace_last_entry" + } + }; + format!( + "Recommended session action:\nmode: {mode}\nentry_id: {}\ndelete_graphemes: {}\ntext: {}\n", + candidate + .entry_id + .map(|entry_id| entry_id.to_string()) + .unwrap_or_else(|| "none".to_string()), + candidate.delete_graphemes, + candidate.text + ) + }) + .unwrap_or_default() +} + +fn render_recent_segments(transcript: &RewriteTranscript, limit: usize) -> String { + let total_segments = transcript.segments.len(); + let start = total_segments.saturating_sub(limit); + let mut rendered = String::new(); + + for segment in &transcript.segments[start..] 
{ + let line = format!( + "- {}-{} ms: {}\n", + segment.start_ms, segment.end_ms, segment.text + ); + rendered.push_str(&line); + } + + if rendered.is_empty() { + rendered.push_str("- no segments available\n"); + } + + rendered +} + +fn render_aggressive_candidate(transcript: &RewriteTranscript) -> String { + transcript + .aggressive_correction_text + .as_deref() + .map(str::trim) + .filter(|text| !text.is_empty()) + .map(|text| format!("Aggressive correction candidate:\n{text}\n")) + .unwrap_or_default() +} + +fn render_matched_rule_names(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return "- none\n".to_string(); + } + let policy_context = &transcript.policy_context; + if policy_context.matched_rule_names.is_empty() { + return "- none\n".to_string(); + } + policy_context + .matched_rule_names + .iter() + .map(|name| format!("- {name}\n")) + .collect() +} + +fn render_effective_rule_instructions(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return "- none\n".to_string(); + } + let policy_context = &transcript.policy_context; + if policy_context.effective_rule_instructions.is_empty() { + return "- none\n".to_string(); + } + policy_context + .effective_rule_instructions + .iter() + .map(|instruction| format!("- {}\n", instruction.trim())) + .collect() +} + +fn render_active_glossary_terms(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return "- none\n".to_string(); + } + let policy_context = &transcript.policy_context; + if policy_context.active_glossary_terms.is_empty() { + return "- none\n".to_string(); + } + policy_context + .active_glossary_terms + .iter() + .map(|entry| format!("- {} <- [{}]\n", entry.term, entry.aliases.join(", "))) + .collect() +} + +fn render_glossary_candidates(transcript: &RewriteTranscript) -> String { + if !has_policy_context(transcript) { + return "- none\n".to_string(); + } + let policy_context = 
&transcript.policy_context; + if policy_context.glossary_candidates.is_empty() { + return "- none\n".to_string(); + } + policy_context + .glossary_candidates + .iter() + .map(|candidate| format!("- {}\n", candidate.text)) + .collect() +} diff --git a/src/rewrite/routing.rs b/src/rewrite/routing.rs new file mode 100644 index 0000000..13b8c45 --- /dev/null +++ b/src/rewrite/routing.rs @@ -0,0 +1,46 @@ +use crate::rewrite_protocol::{ + RewriteEditHypothesisMatchSource, RewriteEditSignalStrength, RewriteTranscript, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum RewriteRoute { + Fast, + ResolvedCorrection, + SessionCandidateAdjudication, + CandidateAdjudication, +} + +pub(super) fn rewrite_route(transcript: &RewriteTranscript) -> RewriteRoute { + if has_session_backtrack_candidate(transcript) { + RewriteRoute::SessionCandidateAdjudication + } else if requires_candidate_adjudication(transcript) { + RewriteRoute::CandidateAdjudication + } else if transcript.correction_aware_text.trim() != transcript.raw_text.trim() { + RewriteRoute::ResolvedCorrection + } else { + RewriteRoute::Fast + } +} + +pub(super) fn requires_candidate_adjudication(transcript: &RewriteTranscript) -> bool { + !transcript.edit_signals.is_empty() || !transcript.edit_hypotheses.is_empty() +} + +pub(super) fn has_strong_explicit_edit_cue(transcript: &RewriteTranscript) -> bool { + transcript.edit_hypotheses.iter().any(|hypothesis| { + hypothesis.strength == RewriteEditSignalStrength::Strong + && matches!( + hypothesis.match_source, + RewriteEditHypothesisMatchSource::Exact | RewriteEditHypothesisMatchSource::Alias + ) + }) +} + +pub(super) fn has_session_backtrack_candidate(transcript: &RewriteTranscript) -> bool { + transcript.recommended_session_candidate.is_some() + || !transcript.session_backtrack_candidates.is_empty() +} + +pub(super) fn has_policy_context(transcript: &RewriteTranscript) -> bool { + transcript.policy_context.is_active() +} diff --git a/src/rewrite/tests.rs 
b/src/rewrite/tests.rs new file mode 100644 index 0000000..c401459 --- /dev/null +++ b/src/rewrite/tests.rs @@ -0,0 +1,463 @@ +use super::RewritePrompt; +use super::prompt::{ + build_oaicompat_messages_json, build_system_instructions, build_user_message, + effective_max_tokens, rewrite_instructions, +}; +use super::routing::{RewriteRoute, rewrite_route}; +use crate::rewrite_profile::ResolvedRewriteProfile; +use crate::rewrite_protocol::{ + RewriteCandidate, RewriteCandidateKind, RewriteCorrectionPolicy, RewriteEditAction, + RewriteEditHypothesis, RewriteEditHypothesisMatchSource, RewriteEditIntent, RewriteEditSignal, + RewriteEditSignalKind, RewriteEditSignalScope, RewriteEditSignalStrength, + RewriteIntentConfidence, RewritePolicyContext, RewriteReplacementScope, + RewriteSessionBacktrackCandidate, RewriteSessionBacktrackCandidateKind, RewriteSessionEntry, + RewriteSurfaceKind, RewriteTailShape, RewriteTranscript, RewriteTranscriptSegment, + RewriteTypingContext, +}; + +fn correction_transcript() -> RewriteTranscript { + RewriteTranscript { + raw_text: "Hi there, this is a test. Wait, no. Hi there.".into(), + correction_aware_text: "Hi there.".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: vec![ + RewriteTranscriptSegment { + text: "Hi there, this is a test.".into(), + start_ms: 0, + end_ms: 1200, + }, + RewriteTranscriptSegment { + text: "Wait, no. 
Hi there.".into(), + start_ms: 1200, + end_ms: 2200, + }, + ], + edit_intents: vec![RewriteEditIntent { + action: RewriteEditAction::ReplacePreviousSentence, + trigger: "wait no".into(), + confidence: RewriteIntentConfidence::High, + }], + edit_signals: vec![RewriteEditSignal { + trigger: "wait no".into(), + kind: RewriteEditSignalKind::Replace, + scope_hint: RewriteEditSignalScope::Sentence, + strength: RewriteEditSignalStrength::Strong, + }], + edit_hypotheses: vec![RewriteEditHypothesis { + cue_family: "wait_no".into(), + matched_text: "wait no".into(), + match_source: RewriteEditHypothesisMatchSource::Exact, + kind: RewriteEditSignalKind::Replace, + scope_hint: RewriteEditSignalScope::Sentence, + replacement_scope: RewriteReplacementScope::Sentence, + tail_shape: RewriteTailShape::Phrase, + strength: RewriteEditSignalStrength::Strong, + }], + rewrite_candidates: vec![ + RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "Hi there, this is a test. Wait, no. Hi there.".into(), + }, + RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "Hi there.".into(), + }, + ], + recommended_candidate: Some(RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "Hi there, this is a test. Wait, no. Hi there.".into(), + }), + policy_context: RewritePolicyContext::default(), + } +} + +fn candidate_only_transcript() -> RewriteTranscript { + RewriteTranscript { + raw_text: "Hi there, this is a test. Scratch that. 
Hi there.".into(), + correction_aware_text: "Hi there.".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: vec![RewriteEditIntent { + action: RewriteEditAction::ReplacePreviousSentence, + trigger: "scratch that".into(), + confidence: RewriteIntentConfidence::High, + }], + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![ + RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "Hi there, this is a test. Scratch that. Hi there.".into(), + }, + RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "Hi there.".into(), + }, + ], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + } +} + +fn fast_agentic_transcript() -> RewriteTranscript { + RewriteTranscript { + raw_text: "I'm currently using the window manager hyperland.".into(), + correction_aware_text: "I'm currently using the window manager hyperland.".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: Some(RewriteTypingContext { + focus_fingerprint: "focus".into(), + app_id: Some("browser".into()), + window_title: Some("Matrix".into()), + surface_kind: RewriteSurfaceKind::GenericText, + browser_domain: None, + captured_at_ms: 42, + }), + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![ + RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "I'm currently using the window manager hyperland.".into(), + }, + RewriteCandidate { + kind: RewriteCandidateKind::ConservativeCorrection, + text: "I'm currently using the window manager 
hyperland.".into(), + }, + ], + recommended_candidate: None, + policy_context: RewritePolicyContext { + correction_policy: RewriteCorrectionPolicy::Balanced, + matched_rule_names: vec!["baseline/global-default".into()], + effective_rule_instructions: vec![ + "Use category cues like window manager to disambiguate nearby technical names." + .into(), + ], + active_glossary_terms: Vec::new(), + glossary_candidates: Vec::new(), + }, + } +} + +#[test] +fn instructions_cover_self_correction_examples() { + let instructions = rewrite_instructions(ResolvedRewriteProfile::LlamaCompat); + assert!(instructions.contains("Return only the finished text")); + assert!(instructions.contains("Never reintroduce text")); + assert!(instructions.contains("scratch that, brownies")); + assert!(instructions.contains("window manager Hyperland")); + assert!(instructions.contains("switching from Sui to Hyperland")); +} + +#[test] +fn qwen_instructions_forbid_reasoning_tags() { + let instructions = rewrite_instructions(ResolvedRewriteProfile::Qwen); + assert!(instructions.contains("Do not emit reasoning")); + assert!(instructions.contains("phonetically similar common word")); +} + +#[test] +fn base_instructions_allow_technical_term_inference() { + let instructions = rewrite_instructions(ResolvedRewriteProfile::LlamaCompat); + assert!(instructions.contains("technical concepts")); + assert!(instructions.contains("phonetically similar common word")); +} + +#[test] +fn custom_instructions_append_to_system_prompt() { + let instructions = build_system_instructions( + &correction_transcript(), + ResolvedRewriteProfile::Qwen, + Some("Keep product names exact."), + ); + assert!(instructions.contains("Return only the finished text")); + assert!(instructions.contains("Keep product names exact.")); +} + +#[test] +fn oaicompat_messages_json_contains_system_and_user_roles() { + let prompt = RewritePrompt { + system: "system instructions".into(), + user: "user input".into(), + }; + + let messages_json = 
build_oaicompat_messages_json(&prompt).expect("encode oaicompat messages"); + let messages: serde_json::Value = + serde_json::from_str(&messages_json).expect("parse oaicompat messages"); + let messages = messages.as_array().expect("messages array"); + + assert_eq!(messages.len(), 2); + assert_eq!(messages[0]["role"], "system"); + assert_eq!(messages[0]["content"], "system instructions"); + assert_eq!(messages[1]["role"], "user"); + assert_eq!(messages[1]["content"], "user input"); +} + +#[test] +fn agentic_system_prompt_relaxes_candidate_restrictions() { + let instructions = build_system_instructions( + &fast_agentic_transcript(), + ResolvedRewriteProfile::Qwen, + None, + ); + assert!(instructions.contains("Agentic latitude contract")); + assert!(instructions.contains( + "do not keep an obviously wrong technical spelling just because it appears in the candidate list" + )); + assert!(instructions.contains( + "even when the literal transcript spelling is noisy or the exact canonical form is not already present in the candidate list" + )); +} + +#[test] +fn fast_route_prompt_allows_agentic_technical_normalization() { + let transcript = fast_agentic_transcript(); + assert!(matches!(rewrite_route(&transcript), RewriteRoute::Fast)); + let prompt = build_user_message(&transcript); + assert!(prompt.contains( + "you may normalize likely technical terms or proper names when category cues in the utterance make the intended technical meaning clearly better than the literal transcript" + )); + assert!( + prompt.contains("Available rewrite candidates (advisory, not exhaustive in agentic mode)") + ); +} + +#[test] +fn cue_prompt_includes_raw_candidate_and_signals() { + let prompt = build_user_message(&correction_transcript()); + assert!(matches!( + rewrite_route(&correction_transcript()), + RewriteRoute::CandidateAdjudication + )); + assert!(prompt.contains("Structured edit hypotheses")); + assert!(prompt.contains("cue_family: wait_no")); + 
assert!(prompt.contains("replacement_scope: sentence")); + assert!(prompt.contains("tail_shape: phrase")); + assert!(prompt.contains("Candidate interpretations")); + assert!(prompt.contains("A strong explicit spoken edit cue was detected")); + assert!( + prompt.contains( + "The candidate list is ordered from most likely to least likely by heuristics." + ) + ); + assert!(prompt.contains("the first candidate is the heuristic best guess")); + assert!(prompt.contains("Recommended interpretation:")); + assert!(prompt.contains( + "Use this as the default final text unless another candidate is clearly better." + )); + assert!( + prompt.contains("Prefer the smallest replacement scope that yields a coherent result.") + ); + assert!(prompt.contains("- preferred_candidate")); + assert!(prompt.contains( + "- preferred_candidate literal (keep only if the cue was not actually an edit): Hi there, this is a test. Wait, no. Hi there." + )); + assert!(prompt.contains("Structured edit signals")); + assert!(prompt.contains("trigger: \"wait no\"")); + assert!(prompt.contains("Structured edit intents")); + assert!(prompt.contains("replace_previous_sentence")); + assert!(prompt.contains("Choose the best candidate interpretation")); + assert!(prompt.contains("Candidate interpretations:\n")); + assert!(prompt.contains("Correction candidate:\nHi there.")); + assert!(prompt.contains("Raw transcript:\nHi there, this is a test. Wait, no. 
Hi there.")); + assert!(prompt.contains("Recent segments")); +} + +#[test] +fn cue_prompt_includes_aggressive_candidate_when_available() { + let mut transcript = correction_transcript(); + transcript.aggressive_correction_text = Some("Hi there.".into()); + + let prompt = build_user_message(&transcript); + assert!(prompt.contains("Aggressive correction candidate")); +} + +#[test] +fn user_message_prefers_correction_candidate_without_signals() { + let prompt = build_user_message(&candidate_only_transcript()); + assert!(matches!( + rewrite_route(&candidate_only_transcript()), + RewriteRoute::ResolvedCorrection + )); + assert!(!prompt.contains("Recommended interpretation:")); + assert!(prompt.contains("Structured edit signals")); + assert!(prompt.contains("Structured edit intents")); + assert!(prompt.contains("Self-corrections were already resolved")); + assert!(prompt.contains("Do not restore any canceled wording")); + assert!(!prompt.contains("Recent segments")); + assert!(!prompt.contains("Raw transcript")); +} + +#[test] +fn user_message_includes_recent_segments_when_correction_matches_raw() { + let transcript = RewriteTranscript { + raw_text: "Hi there.".into(), + correction_aware_text: "Hi there.".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: vec![RewriteTranscriptSegment { + text: "Hi there.".into(), + start_ms: 0, + end_ms: 1200, + }], + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "Hi there.".into(), + }], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + }; + + let prompt = build_user_message(&transcript); + assert!(matches!(rewrite_route(&transcript), RewriteRoute::Fast)); + 
assert!(prompt.contains("Correction-aware transcript")); + assert!(prompt.contains("Structured edit signals")); + assert!(prompt.contains("Recent segments")); + assert!(prompt.contains("0-1200 ms")); + assert!(prompt.contains("Hi there.")); +} + +#[test] +fn effective_max_tokens_scales_with_transcript_length() { + let short = RewriteTranscript { + raw_text: "hi there".into(), + correction_aware_text: "hi there".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "hi there".into(), + }], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + }; + assert_eq!(effective_max_tokens(256, &short), 48); + + let long = RewriteTranscript { + raw_text: "word ".repeat(80), + correction_aware_text: "word ".repeat(80), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::Literal, + text: "word ".repeat(80), + }], + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + }; + assert_eq!(effective_max_tokens(256, &long), 184); +} + +#[test] +fn effective_max_tokens_gives_edit_heavy_prompts_more_budget() { + let transcript = correction_transcript(); + assert_eq!(effective_max_tokens(256, &transcript), 64); +} + +#[test] +fn session_prompt_includes_recent_entry_and_context() { + let transcript = 
RewriteTranscript { + raw_text: "scratch that hi".into(), + correction_aware_text: "Hi".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: Some(RewriteTypingContext { + focus_fingerprint: "hyprland:0x123".into(), + app_id: Some("firefox".into()), + window_title: Some("Example".into()), + surface_kind: RewriteSurfaceKind::Browser, + browser_domain: None, + captured_at_ms: 10, + }), + recent_session_entries: vec![RewriteSessionEntry { + id: 7, + final_text: "Hello there".into(), + grapheme_len: 11, + focus_fingerprint: "hyprland:0x123".into(), + surface_kind: RewriteSurfaceKind::Browser, + app_id: Some("firefox".into()), + window_title: Some("Example".into()), + }], + session_backtrack_candidates: vec![ + RewriteSessionBacktrackCandidate { + kind: RewriteSessionBacktrackCandidateKind::ReplaceLastEntry, + entry_id: Some(7), + delete_graphemes: 11, + text: "Hi".into(), + }, + RewriteSessionBacktrackCandidate { + kind: RewriteSessionBacktrackCandidateKind::AppendCurrent, + entry_id: None, + delete_graphemes: 0, + text: "Hi".into(), + }, + ], + recommended_session_candidate: Some(RewriteSessionBacktrackCandidate { + kind: RewriteSessionBacktrackCandidateKind::ReplaceLastEntry, + entry_id: Some(7), + delete_graphemes: 11, + text: "Hi".into(), + }), + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: vec![RewriteCandidate { + kind: RewriteCandidateKind::SentenceReplacement, + text: "Hi".into(), + }], + recommended_candidate: Some(RewriteCandidate { + kind: RewriteCandidateKind::SentenceReplacement, + text: "Hi".into(), + }), + policy_context: RewritePolicyContext::default(), + }; + + let prompt = build_user_message(&transcript); + assert!(matches!( + rewrite_route(&transcript), + RewriteRoute::SessionCandidateAdjudication + )); + assert!(prompt.contains("Active typing context")); + assert!(prompt.contains("Recent dictation session")); + 
assert!(prompt.contains("replace_last_entry")); + assert!(prompt.contains("treat your final text as the replacement text")); +} diff --git a/src/rewrite_model.rs b/src/rewrite_model.rs index c7cbb1d..5315f63 100644 --- a/src/rewrite_model.rs +++ b/src/rewrite_model.rs @@ -1,12 +1,8 @@ use std::path::{Path, PathBuf}; -use futures_util::StreamExt; -use tokio::io::AsyncWriteExt; - -use crate::config::{ - self, RewriteBackend, data_dir, resolve_config_path, update_config_rewrite_selection, -}; +use crate::config::{self, RewriteBackend, data_dir, update_config_rewrite_selection}; use crate::error::{Result, WhsprError}; +use crate::model_support; use crate::rewrite_profile::RewriteProfile; pub struct RewriteModelInfo { @@ -71,25 +67,15 @@ pub fn managed_profile(name: &str) -> Option { } fn active_model_name(config_path_override: Option<&Path>) -> Option { - let config_path = resolve_config_path(config_path_override); - if !config_path.exists() { - return None; - } - let config = config::Config::load(Some(&config_path)).ok()?; - Some(config.rewrite.selected_model) + model_support::load_config_if_exists(config_path_override) + .map(|config| config.rewrite.selected_model) } fn model_status(info: &RewriteModelInfo, active_name: Option<&str>) -> &'static str { let path = model_path(info.filename); let is_active = active_name == Some(info.name); let is_local = path.exists(); - - match (is_active, is_local) { - (true, true) => "active", - (true, false) => "active (missing)", - (_, true) => "local", - _ => "remote", - } + model_support::managed_download_status(is_active, is_local) } pub fn list_models(config_path_override: Option<&Path>) { @@ -134,78 +120,18 @@ pub async fn download_model(name: &str) -> Result { pub async fn download_model_with_url(info: &RewriteModelInfo, url: &str) -> Result { let dest = model_path(info.filename); let part_path = dest.with_extension("gguf.part"); - - if dest.exists() { - tracing::info!( - "rewrite model '{}' already downloaded at {}", - 
info.name, - dest.display() - ); - println!("{}", crate::ui::ready_message("Rewrite", info.name)); - return Ok(dest); - } - - if let Some(parent) = dest.parent() { - std::fs::create_dir_all(parent) - .map_err(|e| WhsprError::Download(format!("failed to create data directory: {e}")))?; - } - - tracing::info!("downloading rewrite model '{}' from {}", info.name, url); - println!( - "{} Downloading rewrite model {} ({})...", - crate::ui::info_label(), - crate::ui::value(info.name), - info.size - ); - - let client = reqwest::Client::new(); - let response = client - .get(url) - .send() - .await - .map_err(|e| WhsprError::Download(format!("failed to start download: {e}")))?; - - if !response.status().is_success() { - return Err(WhsprError::Download(format!( - "download failed with HTTP {}", - response.status() - ))); - } - - let total_size = response.content_length().unwrap_or(0); - let pb = crate::ui::progress_bar(total_size); - - let mut file = tokio::fs::OpenOptions::new() - .create(true) - .write(true) - .truncate(true) - .open(&part_path) - .await - .map_err(|e| WhsprError::Download(format!("failed to open file: {e}")))?; - - let mut stream = response.bytes_stream(); - while let Some(chunk) = stream.next().await { - let chunk = - chunk.map_err(|e| WhsprError::Download(format!("download interrupted: {e}")))?; - file.write_all(&chunk) - .await - .map_err(|e| WhsprError::Download(format!("failed to write: {e}")))?; - pb.inc(chunk.len() as u64); - } - - file.flush() - .await - .map_err(|e| WhsprError::Download(format!("failed to flush: {e}")))?; - drop(file); - - pb.finish_with_message("done"); - - std::fs::rename(&part_path, &dest) - .map_err(|e| WhsprError::Download(format!("failed to finalize download: {e}")))?; - - tracing::info!("rewrite model '{}' saved to {}", info.name, dest.display()); - println!("{}", crate::ui::ready_message("Rewrite", info.name)); - Ok(dest) + model_support::download_to_path(model_support::DownloadSpec { + tracing_label: "rewrite model", + 
user_label: "rewrite model", + ready_kind: "Rewrite", + item_name: info.name, + size: info.size, + url, + dest, + part_path, + resume_partial: false, + }) + .await } pub fn select_model(name: &str, config_path_override: Option<&Path>) -> Result<()> { @@ -220,14 +146,12 @@ pub fn select_model(name: &str, config_path_override: Option<&Path>) -> Result<( ))); } - let config_path = resolve_config_path(config_path_override); - if !config_path.exists() { - let whisper_model = config::Config::default().transcription.model_path; - config::write_default_config(&config_path, &whisper_model)?; - } + let whisper_model = config::Config::default().transcription.model_path; + let (config_path, _) = + model_support::ensure_default_config(config_path_override, &whisper_model)?; update_config_rewrite_selection(&config_path, info.name)?; - if config::Config::load(Some(&config_path)) + if model_support::load_config_at_if_exists(&config_path) .map(|config| config.rewrite.backend == RewriteBackend::Cloud) .unwrap_or(false) { diff --git a/src/rewrite_protocol.rs b/src/rewrite_protocol.rs index c4abe8f..edacdc4 100644 --- a/src/rewrite_protocol.rs +++ b/src/rewrite_protocol.rs @@ -1,3 +1,4 @@ +use clap::ValueEnum; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -16,6 +17,8 @@ pub struct RewriteTranscript { pub edit_hypotheses: Vec, pub rewrite_candidates: Vec, pub recommended_candidate: Option, + #[serde(default)] + pub policy_context: RewritePolicyContext, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -87,6 +90,21 @@ pub struct RewriteCandidate { pub text: String, } +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +pub struct RewritePolicyContext { + pub correction_policy: RewriteCorrectionPolicy, + pub matched_rule_names: Vec, + pub effective_rule_instructions: Vec, + pub active_glossary_terms: Vec, + pub glossary_candidates: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, 
PartialEq, Eq)] +pub struct RewritePolicyGlossaryTerm { + pub term: String, + pub aliases: Vec, +} + #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum RewriteEditAction { @@ -150,12 +168,22 @@ pub enum RewriteTailShape { Clause, } +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, ValueEnum)] +#[serde(rename_all = "snake_case")] +pub enum RewriteCorrectionPolicy { + Conservative, + #[default] + Balanced, + Aggressive, +} + #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum RewriteCandidateKind { Literal, ConservativeCorrection, AggressiveCorrection, + GlossaryCorrection, SpanReplacement, ClauseReplacement, SentenceReplacement, @@ -166,7 +194,7 @@ pub enum RewriteCandidateKind { CancelPreviousSentence, } -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, ValueEnum)] #[serde(rename_all = "snake_case")] pub enum RewriteSurfaceKind { Browser, @@ -198,3 +226,34 @@ pub enum WorkerResponse { Result { text: String }, Error { message: String }, } + +impl RewriteCorrectionPolicy { + pub fn as_str(self) -> &'static str { + match self { + Self::Conservative => "conservative", + Self::Balanced => "balanced", + Self::Aggressive => "aggressive", + } + } +} + +impl RewriteSurfaceKind { + pub fn as_str(self) -> &'static str { + match self { + Self::Browser => "browser", + Self::Terminal => "terminal", + Self::Editor => "editor", + Self::GenericText => "generic_text", + Self::Unknown => "unknown", + } + } +} + +impl RewritePolicyContext { + pub fn is_active(&self) -> bool { + !self.matched_rule_names.is_empty() + || !self.effective_rule_instructions.is_empty() + || !self.active_glossary_terms.is_empty() + || !self.glossary_candidates.is_empty() + } +} diff --git a/src/runtime_support.rs b/src/runtime_support.rs new file mode 100644 index 
0000000..da52f5a --- /dev/null +++ b/src/runtime_support.rs @@ -0,0 +1,207 @@ +use std::path::{Path, PathBuf}; + +use tracing_subscriber::EnvFilter; + +use crate::branding; +use crate::error::{Result, WhsprError}; +use crate::ui; + +pub struct PidLock { + path: PathBuf, + _file: std::fs::File, +} + +impl Drop for PidLock { + fn drop(&mut self) { + let _ = std::fs::remove_file(&self.path); + } +} + +fn pid_file_path() -> PathBuf { + let runtime_dir = std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/tmp".into()); + PathBuf::from(runtime_dir).join(branding::MAIN_PID_FILE) +} + +fn read_pid_from_lock(path: &Path) -> Option { + let contents = std::fs::read_to_string(path).ok()?; + contents.trim().parse().ok() +} + +fn process_exists(pid: libc::pid_t) -> bool { + Path::new(&format!("/proc/{pid}")).exists() +} + +fn pid_belongs_to_whspr(pid: libc::pid_t) -> bool { + if !process_exists(pid) { + return false; + } + + let current_exe = std::env::current_exe() + .ok() + .and_then(|p| std::fs::canonicalize(p).ok()); + let target_exe = std::fs::canonicalize(format!("/proc/{pid}/exe")).ok(); + + if let (Some(current), Some(target)) = (current_exe.as_ref(), target_exe.as_ref()) { + if current == target { + return true; + } + } + + let current_name = std::env::current_exe() + .ok() + .and_then(|p| p.file_name().map(|n| n.to_string_lossy().into_owned())) + .unwrap_or_else(|| branding::MAIN_BINARY.into()); + let cmdline = match std::fs::read(format!("/proc/{pid}/cmdline")) { + Ok(bytes) => bytes, + Err(_) => return false, + }; + let Some(first_arg) = cmdline.split(|b| *b == 0).next() else { + return false; + }; + if first_arg.is_empty() { + return false; + } + let first_arg = String::from_utf8_lossy(first_arg); + Path::new(first_arg.as_ref()) + .file_name() + .map(|name| name.to_string_lossy() == current_name) + .unwrap_or(false) +} + +fn try_acquire_pid_lock(path: &Path) -> std::io::Result { + use std::io::Write; + + let mut file = std::fs::OpenOptions::new() + .write(true) + 
.create_new(true) + .open(path)?; + writeln!(file, "{}", std::process::id())?; + + Ok(PidLock { + path: path.to_path_buf(), + _file: file, + }) +} + +fn signal_existing_instance(path: &Path) -> Result { + let Some(pid) = read_pid_from_lock(path) else { + tracing::warn!("stale pid lock at {}, removing", path.display()); + let _ = std::fs::remove_file(path); + return Ok(false); + }; + + if !pid_belongs_to_whspr(pid) { + tracing::warn!( + "pid lock at {} points to non-whspr process ({pid}), removing", + path.display() + ); + let _ = std::fs::remove_file(path); + return Ok(false); + } + + tracing::info!("sending toggle signal to running instance (pid {pid})"); + let ret = unsafe { libc::kill(pid, libc::SIGUSR1) }; + if ret == 0 { + return Ok(true); + } + + let err = std::io::Error::last_os_error(); + tracing::warn!("failed to signal pid {pid}: {err}"); + if err.raw_os_error() == Some(libc::ESRCH) { + let _ = std::fs::remove_file(path); + return Ok(false); + } + + Err(err.into()) +} + +pub fn acquire_or_signal_lock() -> Result> { + let path = pid_file_path(); + + for _ in 0..2 { + match try_acquire_pid_lock(&path) { + Ok(lock) => return Ok(Some(lock)), + Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { + if signal_existing_instance(&path)? 
{ + return Ok(None); + } + } + Err(e) => return Err(e.into()), + } + } + + Err(WhsprError::Config(format!( + "failed to acquire pid lock at {}", + path.display() + ))) +} + +pub fn init_tracing(verbose: u8) { + ui::configure_terminal_colors(); + ui::set_verbosity(verbose); + let filter = match verbose { + 0 => "whispers=warn", + 1 => "whispers=info", + 2 => "whispers=debug", + _ => "whispers=trace", + }; + + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(filter)), + ) + .compact() + .init(); +} + +#[cfg(test)] +mod tests { + use super::*; + + fn temp_lock_path(suffix: &str) -> PathBuf { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + std::env::temp_dir().join(format!( + "whispers-test-{suffix}-{}-{now}.pid", + std::process::id() + )) + } + + #[test] + fn signal_existing_instance_cleans_invalid_pid_file() { + let path = temp_lock_path("invalid"); + std::fs::write(&path, "not-a-pid").unwrap(); + assert!(!signal_existing_instance(&path).unwrap()); + assert!(!path.exists()); + } + + #[test] + fn signal_existing_instance_cleans_missing_process_pid_file() { + let path = temp_lock_path("missing-process"); + std::fs::write(&path, "99999999").unwrap(); + assert!(!signal_existing_instance(&path).unwrap()); + assert!(!path.exists()); + } + + #[test] + fn try_acquire_pid_lock_uses_create_new_semantics() { + let path = temp_lock_path("acquire"); + let lock = try_acquire_pid_lock(&path).unwrap(); + assert!(path.exists()); + + let err = match try_acquire_pid_lock(&path) { + Ok(_) => panic!("lock acquisition should fail when file already exists"), + Err(e) => e, + }; + assert_eq!(err.kind(), std::io::ErrorKind::AlreadyExists); + + drop(lock); + assert!(!path.exists()); + + let lock2 = try_acquire_pid_lock(&path).unwrap(); + drop(lock2); + assert!(!path.exists()); + } +} diff --git a/src/session.rs b/src/session.rs deleted file mode 100644 index 
4732028..0000000 --- a/src/session.rs +++ /dev/null @@ -1,634 +0,0 @@ -use std::path::PathBuf; -use std::time::{SystemTime, UNIX_EPOCH}; - -use serde::{Deserialize, Serialize}; -use unicode_segmentation::UnicodeSegmentation; - -use crate::cleanup; -use crate::config::SessionConfig; -use crate::context::{SurfaceKind, TypingContext}; -use crate::error::{Result, WhsprError}; -use crate::rewrite_protocol::{ - RewriteSessionBacktrackCandidate, RewriteSessionBacktrackCandidateKind, RewriteSessionEntry, - RewriteSurfaceKind, RewriteTranscript, RewriteTypingContext, -}; - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -struct SessionFile { - next_id: u64, - entries: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SessionEntry { - pub id: u64, - pub final_text: String, - pub grapheme_len: usize, - pub injected_at_ms: u64, - pub focus_fingerprint: String, - pub surface_kind: SurfaceKind, - pub app_id: Option, - pub window_title: Option, - pub rewrite_summary: SessionRewriteSummary, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SessionRewriteSummary { - pub had_edit_cues: bool, - pub rewrite_used: bool, - pub recommended_candidate: Option, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct EligibleSessionEntry { - pub entry: SessionEntry, - pub delete_graphemes: usize, -} - -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct SessionBacktrackPlan { - pub recent_entries: Vec, - pub candidates: Vec, - pub recommended: Option, - pub deterministic_replacement_text: Option, -} - -pub fn load_recent_entry( - config: &SessionConfig, - context: &TypingContext, -) -> Result> { - if !config.enabled || !context.is_known_focus() { - return Ok(None); - } - - let mut state = load_session_file()?; - prune_state(&mut state, config); - persist_session_file(&state)?; - - let Some(entry) = state.entries.last().cloned() else { - return Ok(None); - }; - - if entry.focus_fingerprint != 
context.focus_fingerprint { - return Ok(None); - } - - if entry.grapheme_len == 0 || entry.grapheme_len > config.max_replace_graphemes { - return Ok(None); - } - - Ok(Some(EligibleSessionEntry { - delete_graphemes: entry.grapheme_len, - entry, - })) -} - -pub fn record_append( - config: &SessionConfig, - context: &TypingContext, - final_text: &str, - rewrite_summary: SessionRewriteSummary, -) -> Result<()> { - if !config.enabled || final_text.trim().is_empty() || !context.is_known_focus() { - return Ok(()); - } - - let mut state = load_session_file()?; - prune_state(&mut state, config); - let entry = SessionEntry { - id: state.next_id, - final_text: final_text.to_string(), - grapheme_len: grapheme_count(final_text), - injected_at_ms: now_ms(), - focus_fingerprint: context.focus_fingerprint.clone(), - surface_kind: context.surface_kind, - app_id: context.app_id.clone(), - window_title: context.window_title.clone(), - rewrite_summary, - }; - state.next_id = state.next_id.saturating_add(1); - state.entries.push(entry); - trim_state(&mut state, config); - persist_session_file(&state) -} - -pub fn record_replace( - config: &SessionConfig, - context: &TypingContext, - replaced_entry_id: u64, - final_text: &str, - rewrite_summary: SessionRewriteSummary, -) -> Result<()> { - if !config.enabled || final_text.trim().is_empty() || !context.is_known_focus() { - return Ok(()); - } - - let mut state = load_session_file()?; - prune_state(&mut state, config); - if let Some(entry) = state - .entries - .iter_mut() - .find(|entry| entry.id == replaced_entry_id) - { - entry.final_text = final_text.to_string(); - entry.grapheme_len = grapheme_count(final_text); - entry.injected_at_ms = now_ms(); - entry.focus_fingerprint = context.focus_fingerprint.clone(); - entry.surface_kind = context.surface_kind; - entry.app_id = context.app_id.clone(); - entry.window_title = context.window_title.clone(); - entry.rewrite_summary = rewrite_summary; - } else { - return record_append(config, context, 
final_text, rewrite_summary); - } - trim_state(&mut state, config); - persist_session_file(&state) -} - -pub fn build_backtrack_plan( - transcript: &RewriteTranscript, - recent_entry: Option<&EligibleSessionEntry>, -) -> SessionBacktrackPlan { - let Some(recent_entry) = recent_entry else { - return SessionBacktrackPlan::default(); - }; - if !should_offer_session_backtrack(transcript) { - return SessionBacktrackPlan::default(); - } - - let append_text = preferred_current_text(transcript); - if append_text.is_empty() { - return SessionBacktrackPlan::default(); - } - - let append_candidate = RewriteSessionBacktrackCandidate { - kind: RewriteSessionBacktrackCandidateKind::AppendCurrent, - entry_id: None, - delete_graphemes: 0, - text: append_text.clone(), - }; - let replace_candidate = RewriteSessionBacktrackCandidate { - kind: RewriteSessionBacktrackCandidateKind::ReplaceLastEntry, - entry_id: Some(recent_entry.entry.id), - delete_graphemes: recent_entry.delete_graphemes, - text: append_text, - }; - - SessionBacktrackPlan { - recent_entries: vec![to_rewrite_session_entry(&recent_entry.entry)], - candidates: vec![replace_candidate.clone(), append_candidate], - recommended: Some(replace_candidate), - deterministic_replacement_text: preferred_current_text_for_exact_followup(transcript), - } -} - -pub fn to_rewrite_typing_context(context: &TypingContext) -> Option { - context.is_known_focus().then(|| RewriteTypingContext { - focus_fingerprint: context.focus_fingerprint.clone(), - app_id: context.app_id.clone(), - window_title: context.window_title.clone(), - surface_kind: map_surface_kind(context.surface_kind), - browser_domain: context.browser_domain.clone(), - captured_at_ms: context.captured_at_ms, - }) -} - -fn load_session_file() -> Result { - let path = session_file_path(); - if !path.exists() { - return Ok(SessionFile::default()); - } - - let contents = std::fs::read_to_string(&path).map_err(|e| { - WhsprError::Config(format!( - "failed to read session state {}: 
{e}", - path.display() - )) - })?; - serde_json::from_str(&contents).map_err(|e| { - WhsprError::Config(format!( - "failed to parse session state {}: {e}", - path.display() - )) - }) -} - -fn persist_session_file(state: &SessionFile) -> Result<()> { - let path = session_file_path(); - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - WhsprError::Config(format!( - "failed to create session runtime dir {}: {e}", - parent.display() - )) - })?; - } - let encoded = serde_json::to_vec(state) - .map_err(|e| WhsprError::Config(format!("failed to encode session state: {e}")))?; - std::fs::write(&path, encoded).map_err(|e| { - WhsprError::Config(format!( - "failed to write session state {}: {e}", - path.display() - )) - }) -} - -fn prune_state(state: &mut SessionFile, config: &SessionConfig) { - let cutoff = now_ms().saturating_sub(config.max_age_ms); - state.entries.retain(|entry| entry.injected_at_ms >= cutoff); - trim_state(state, config); - if state.next_id == 0 { - state.next_id = 1; - } -} - -fn trim_state(state: &mut SessionFile, config: &SessionConfig) { - if state.entries.len() > config.max_entries { - let remove_count = state.entries.len() - config.max_entries; - state.entries.drain(0..remove_count); - } -} - -fn session_file_path() -> PathBuf { - let runtime_dir = std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/tmp".into()); - PathBuf::from(runtime_dir) - .join("whispers") - .join("session.json") -} - -fn grapheme_count(text: &str) -> usize { - UnicodeSegmentation::graphemes(text, true).count() -} - -fn should_offer_session_backtrack(transcript: &RewriteTranscript) -> bool { - if cleanup::explicit_followup_replacement(&transcript.raw_text).is_some() { - return true; - } - - if transcript.correction_aware_text.trim() == transcript.raw_text.trim() { - return false; - } - - let raw_prefix = normalize_prefix(&transcript.raw_text); - if ![ - "scratch that", - "actually scratch that", - "never mind", - "nevermind", - "actually 
never mind", - "actually nevermind", - "oh wait never mind", - "oh wait nevermind", - "forget that", - "wait no", - "actually wait no", - "i meant", - "actually i meant", - "i mean", - "actually i mean", - ] - .iter() - .any(|cue| raw_prefix.starts_with(cue)) - { - return false; - } - - transcript.edit_hypotheses.iter().any(|hypothesis| { - hypothesis.strength == crate::rewrite_protocol::RewriteEditSignalStrength::Strong - && matches!( - hypothesis.match_source, - crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Exact - | crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Alias - ) - }) -} - -fn preferred_current_text(transcript: &RewriteTranscript) -> String { - transcript - .recommended_candidate - .as_ref() - .map(|candidate| candidate.text.trim()) - .filter(|text| !text.is_empty()) - .or_else(|| Some(transcript.correction_aware_text.trim()).filter(|text| !text.is_empty())) - .or_else(|| Some(transcript.raw_text.trim()).filter(|text| !text.is_empty())) - .unwrap_or_default() - .to_string() -} - -fn preferred_current_text_for_exact_followup(transcript: &RewriteTranscript) -> Option { - if let Some(text) = cleanup::explicit_followup_replacement(&transcript.raw_text) { - return Some(text); - } - - if !has_strong_explicit_followup_cue(transcript) { - return None; - } - - let raw_prefix = normalize_prefix(&transcript.raw_text); - if ![ - "scratch that", - "actually scratch that", - "never mind", - "nevermind", - "actually never mind", - "actually nevermind", - "oh wait never mind", - "oh wait nevermind", - "forget that", - ] - .iter() - .any(|cue| raw_prefix.starts_with(cue)) - { - return None; - } - - let preferred = preferred_current_text(transcript); - (!preferred.is_empty()).then_some(preferred) -} - -fn has_strong_explicit_followup_cue(transcript: &RewriteTranscript) -> bool { - transcript.edit_hypotheses.iter().any(|hypothesis| { - hypothesis.strength == crate::rewrite_protocol::RewriteEditSignalStrength::Strong - && matches!( - 
hypothesis.match_source, - crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Exact - | crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Alias - ) - && matches!( - hypothesis.cue_family.as_str(), - "scratch_that" | "never_mind" - ) - }) -} - -fn normalize_prefix(text: &str) -> String { - text.chars() - .map(|ch| { - if ch.is_ascii_alphanumeric() || ch.is_ascii_whitespace() { - ch.to_ascii_lowercase() - } else { - ' ' - } - }) - .collect::() - .split_whitespace() - .take(4) - .collect::>() - .join(" ") -} - -fn to_rewrite_session_entry(entry: &SessionEntry) -> RewriteSessionEntry { - RewriteSessionEntry { - id: entry.id, - final_text: entry.final_text.clone(), - grapheme_len: entry.grapheme_len, - focus_fingerprint: entry.focus_fingerprint.clone(), - surface_kind: map_surface_kind(entry.surface_kind), - app_id: entry.app_id.clone(), - window_title: entry.window_title.clone(), - } -} - -fn map_surface_kind(kind: SurfaceKind) -> RewriteSurfaceKind { - match kind { - SurfaceKind::Browser => RewriteSurfaceKind::Browser, - SurfaceKind::Terminal => RewriteSurfaceKind::Terminal, - SurfaceKind::Editor => RewriteSurfaceKind::Editor, - SurfaceKind::GenericText => RewriteSurfaceKind::GenericText, - SurfaceKind::Unknown => RewriteSurfaceKind::Unknown, - } -} - -fn now_ms() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_millis() as u64) - .unwrap_or(0) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_support::{EnvVarGuard, env_lock, set_env, unique_temp_dir}; - - fn config() -> SessionConfig { - SessionConfig { - enabled: true, - max_entries: 3, - max_age_ms: 8_000, - max_replace_graphemes: 400, - } - } - - fn context() -> TypingContext { - TypingContext { - focus_fingerprint: "hyprland:0x123".into(), - app_id: Some("kitty".into()), - window_title: Some("shell".into()), - surface_kind: SurfaceKind::Terminal, - browser_domain: None, - captured_at_ms: 10, - } - } - - fn with_runtime_dir(f: impl FnOnce() 
-> T) -> T { - let _env_lock = env_lock(); - let _guard = EnvVarGuard::capture(&["XDG_RUNTIME_DIR"]); - let runtime_dir = unique_temp_dir("session-runtime"); - let runtime_dir = runtime_dir - .to_str() - .expect("temp runtime dir should be valid UTF-8"); - set_env("XDG_RUNTIME_DIR", runtime_dir); - f() - } - - #[test] - fn record_append_then_load_recent_entry() { - with_runtime_dir(|| { - record_append( - &config(), - &context(), - "Hello there", - SessionRewriteSummary { - had_edit_cues: false, - rewrite_used: false, - recommended_candidate: None, - }, - ) - .expect("record"); - - let entry = load_recent_entry(&config(), &context()) - .expect("load") - .expect("entry"); - assert_eq!(entry.entry.final_text, "Hello there"); - assert_eq!(entry.delete_graphemes, 11); - }); - } - - #[test] - fn load_recent_entry_requires_matching_focus() { - with_runtime_dir(|| { - record_append( - &config(), - &context(), - "Hello there", - SessionRewriteSummary { - had_edit_cues: false, - rewrite_used: false, - recommended_candidate: None, - }, - ) - .expect("record"); - - let mut other = context(); - other.focus_fingerprint = "hyprland:0x999".into(); - assert!( - load_recent_entry(&config(), &other) - .expect("load") - .is_none() - ); - }); - } - - #[test] - fn record_replace_updates_existing_entry() { - with_runtime_dir(|| { - let rewrite_summary = SessionRewriteSummary { - had_edit_cues: false, - rewrite_used: false, - recommended_candidate: None, - }; - record_append( - &config(), - &context(), - "Hello there", - rewrite_summary.clone(), - ) - .expect("record"); - let entry = load_recent_entry(&config(), &context()) - .expect("load") - .expect("entry"); - - record_replace(&config(), &context(), entry.entry.id, "Hi", rewrite_summary) - .expect("replace"); - - let replaced = load_recent_entry(&config(), &context()) - .expect("load") - .expect("entry"); - assert_eq!(replaced.entry.final_text, "Hi"); - assert_eq!(replaced.delete_graphemes, 2); - }); - } - - #[test] - fn 
build_backtrack_plan_prefers_replacing_recent_entry_for_follow_up_correction() { - let transcript = RewriteTranscript { - raw_text: "scratch that hi".into(), - correction_aware_text: "Hi".into(), - aggressive_correction_text: None, - detected_language: Some("en".into()), - typing_context: None, - recent_session_entries: Vec::new(), - session_backtrack_candidates: Vec::new(), - recommended_session_candidate: None, - segments: Vec::new(), - edit_intents: Vec::new(), - edit_signals: Vec::new(), - edit_hypotheses: vec![crate::rewrite_protocol::RewriteEditHypothesis { - cue_family: "scratch_that".into(), - matched_text: "scratch that".into(), - match_source: crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Exact, - kind: crate::rewrite_protocol::RewriteEditSignalKind::Cancel, - scope_hint: crate::rewrite_protocol::RewriteEditSignalScope::Sentence, - replacement_scope: crate::rewrite_protocol::RewriteReplacementScope::Sentence, - tail_shape: crate::rewrite_protocol::RewriteTailShape::Phrase, - strength: crate::rewrite_protocol::RewriteEditSignalStrength::Strong, - }], - rewrite_candidates: Vec::new(), - recommended_candidate: Some(crate::rewrite_protocol::RewriteCandidate { - kind: crate::rewrite_protocol::RewriteCandidateKind::SentenceReplacement, - text: "Hi".into(), - }), - }; - - let recent = EligibleSessionEntry { - entry: SessionEntry { - id: 7, - final_text: "Hello there".into(), - grapheme_len: 11, - injected_at_ms: 1, - focus_fingerprint: "hyprland:0x123".into(), - surface_kind: SurfaceKind::GenericText, - app_id: Some("firefox".into()), - window_title: Some("Example".into()), - rewrite_summary: SessionRewriteSummary { - had_edit_cues: false, - rewrite_used: true, - recommended_candidate: Some("Hello there".into()), - }, - }, - delete_graphemes: 11, - }; - - let plan = build_backtrack_plan(&transcript, Some(&recent)); - assert_eq!(plan.recent_entries.len(), 1); - assert_eq!(plan.candidates.len(), 2); - assert_eq!( - 
plan.recommended.as_ref().map(|candidate| candidate.kind), - Some(RewriteSessionBacktrackCandidateKind::ReplaceLastEntry) - ); - assert_eq!( - plan.recommended - .as_ref() - .and_then(|candidate| candidate.entry_id), - Some(7) - ); - assert_eq!(plan.deterministic_replacement_text.as_deref(), Some("Hi")); - } - - #[test] - fn build_backtrack_plan_uses_raw_followup_fallback_without_hypotheses() { - let transcript = RewriteTranscript { - raw_text: "scratch that hi".into(), - correction_aware_text: "scratch that hi".into(), - aggressive_correction_text: None, - detected_language: None, - typing_context: None, - recent_session_entries: Vec::new(), - session_backtrack_candidates: Vec::new(), - recommended_session_candidate: None, - segments: Vec::new(), - edit_intents: Vec::new(), - edit_signals: Vec::new(), - edit_hypotheses: Vec::new(), - rewrite_candidates: Vec::new(), - recommended_candidate: None, - }; - - let recent = EligibleSessionEntry { - entry: SessionEntry { - id: 7, - final_text: "Hello there".into(), - grapheme_len: 11, - injected_at_ms: 1, - focus_fingerprint: "hyprland:0x123".into(), - surface_kind: SurfaceKind::GenericText, - app_id: Some("firefox".into()), - window_title: Some("Example".into()), - rewrite_summary: SessionRewriteSummary { - had_edit_cues: false, - rewrite_used: true, - recommended_candidate: Some("Hello there".into()), - }, - }, - delete_graphemes: 11, - }; - - let plan = build_backtrack_plan(&transcript, Some(&recent)); - assert_eq!( - plan.recommended.as_ref().map(|candidate| candidate.kind), - Some(RewriteSessionBacktrackCandidateKind::ReplaceLastEntry) - ); - assert_eq!(plan.deterministic_replacement_text.as_deref(), Some("Hi")); - } -} diff --git a/src/session/mod.rs b/src/session/mod.rs new file mode 100644 index 0000000..031de20 --- /dev/null +++ b/src/session/mod.rs @@ -0,0 +1,44 @@ +use serde::{Deserialize, Serialize}; + +use crate::context::SurfaceKind; +use crate::rewrite_protocol::{RewriteSessionBacktrackCandidate, 
RewriteSessionEntry}; + +mod persistence; +mod planning; + +pub use persistence::{load_recent_entry, record_append, record_replace}; +pub use planning::{build_backtrack_plan, to_rewrite_typing_context}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionEntry { + pub id: u64, + pub final_text: String, + pub grapheme_len: usize, + pub injected_at_ms: u64, + pub focus_fingerprint: String, + pub surface_kind: SurfaceKind, + pub app_id: Option, + pub window_title: Option, + pub rewrite_summary: SessionRewriteSummary, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionRewriteSummary { + pub had_edit_cues: bool, + pub rewrite_used: bool, + pub recommended_candidate: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EligibleSessionEntry { + pub entry: SessionEntry, + pub delete_graphemes: usize, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct SessionBacktrackPlan { + pub recent_entries: Vec, + pub candidates: Vec, + pub recommended: Option, + pub deterministic_replacement_text: Option, +} diff --git a/src/session/persistence.rs b/src/session/persistence.rs new file mode 100644 index 0000000..51d111f --- /dev/null +++ b/src/session/persistence.rs @@ -0,0 +1,301 @@ +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::config::SessionConfig; +use crate::context::TypingContext; +use crate::error::{Result, WhsprError}; + +use super::{EligibleSessionEntry, SessionEntry, SessionRewriteSummary}; + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +struct SessionFile { + next_id: u64, + entries: Vec, +} + +pub fn load_recent_entry( + config: &SessionConfig, + context: &TypingContext, +) -> Result> { + if !config.enabled || !context.is_known_focus() { + return Ok(None); + } + + let mut state = load_session_file()?; + prune_state(&mut state, config); + 
persist_session_file(&state)?; + + let Some(entry) = state.entries.last().cloned() else { + return Ok(None); + }; + + if entry.focus_fingerprint != context.focus_fingerprint { + return Ok(None); + } + + if entry.grapheme_len == 0 || entry.grapheme_len > config.max_replace_graphemes { + return Ok(None); + } + + Ok(Some(EligibleSessionEntry { + delete_graphemes: entry.grapheme_len, + entry, + })) +} + +pub fn record_append( + config: &SessionConfig, + context: &TypingContext, + final_text: &str, + rewrite_summary: SessionRewriteSummary, +) -> Result<()> { + if !config.enabled || final_text.trim().is_empty() || !context.is_known_focus() { + return Ok(()); + } + + let mut state = load_session_file()?; + prune_state(&mut state, config); + let entry = SessionEntry { + id: state.next_id, + final_text: final_text.to_string(), + grapheme_len: grapheme_count(final_text), + injected_at_ms: now_ms(), + focus_fingerprint: context.focus_fingerprint.clone(), + surface_kind: context.surface_kind, + app_id: context.app_id.clone(), + window_title: context.window_title.clone(), + rewrite_summary, + }; + state.next_id = state.next_id.saturating_add(1); + state.entries.push(entry); + trim_state(&mut state, config); + persist_session_file(&state) +} + +pub fn record_replace( + config: &SessionConfig, + context: &TypingContext, + replaced_entry_id: u64, + final_text: &str, + rewrite_summary: SessionRewriteSummary, +) -> Result<()> { + if !config.enabled || final_text.trim().is_empty() || !context.is_known_focus() { + return Ok(()); + } + + let mut state = load_session_file()?; + prune_state(&mut state, config); + if let Some(entry) = state + .entries + .iter_mut() + .find(|entry| entry.id == replaced_entry_id) + { + entry.final_text = final_text.to_string(); + entry.grapheme_len = grapheme_count(final_text); + entry.injected_at_ms = now_ms(); + entry.focus_fingerprint = context.focus_fingerprint.clone(); + entry.surface_kind = context.surface_kind; + entry.app_id = 
context.app_id.clone(); + entry.window_title = context.window_title.clone(); + entry.rewrite_summary = rewrite_summary; + } else { + return record_append(config, context, final_text, rewrite_summary); + } + trim_state(&mut state, config); + persist_session_file(&state) +} + +fn load_session_file() -> Result { + let path = session_file_path(); + if !path.exists() { + return Ok(SessionFile::default()); + } + + let contents = std::fs::read_to_string(&path).map_err(|e| { + WhsprError::Config(format!( + "failed to read session state {}: {e}", + path.display() + )) + })?; + serde_json::from_str(&contents).map_err(|e| { + WhsprError::Config(format!( + "failed to parse session state {}: {e}", + path.display() + )) + }) +} + +fn persist_session_file(state: &SessionFile) -> Result<()> { + let path = session_file_path(); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(|e| { + WhsprError::Config(format!( + "failed to create session runtime dir {}: {e}", + parent.display() + )) + })?; + } + let encoded = serde_json::to_vec(state) + .map_err(|e| WhsprError::Config(format!("failed to encode session state: {e}")))?; + std::fs::write(&path, encoded).map_err(|e| { + WhsprError::Config(format!( + "failed to write session state {}: {e}", + path.display() + )) + }) +} + +fn prune_state(state: &mut SessionFile, config: &SessionConfig) { + let cutoff = now_ms().saturating_sub(config.max_age_ms); + state.entries.retain(|entry| entry.injected_at_ms >= cutoff); + trim_state(state, config); + if state.next_id == 0 { + state.next_id = 1; + } +} + +fn trim_state(state: &mut SessionFile, config: &SessionConfig) { + if state.entries.len() > config.max_entries { + let remove_count = state.entries.len() - config.max_entries; + state.entries.drain(0..remove_count); + } +} + +fn session_file_path() -> PathBuf { + let runtime_dir = std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/tmp".into()); + PathBuf::from(runtime_dir) + .join("whispers") + 
.join("session.json") +} + +fn grapheme_count(text: &str) -> usize { + UnicodeSegmentation::graphemes(text, true).count() +} + +fn now_ms() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_millis() as u64) + .unwrap_or(0) +} + +#[cfg(test)] +mod tests { + use super::{load_recent_entry, record_append, record_replace}; + use crate::config::SessionConfig; + use crate::context::{SurfaceKind, TypingContext}; + use crate::session::SessionRewriteSummary; + use crate::test_support::{EnvVarGuard, env_lock, set_env, unique_temp_dir}; + + fn config() -> SessionConfig { + SessionConfig { + enabled: true, + max_entries: 3, + max_age_ms: 8_000, + max_replace_graphemes: 400, + } + } + + fn context() -> TypingContext { + TypingContext { + focus_fingerprint: "hyprland:0x123".into(), + app_id: Some("kitty".into()), + window_title: Some("shell".into()), + surface_kind: SurfaceKind::Terminal, + browser_domain: None, + captured_at_ms: 10, + } + } + + fn with_runtime_dir(f: impl FnOnce() -> T) -> T { + let _env_lock = env_lock(); + let _guard = EnvVarGuard::capture(&["XDG_RUNTIME_DIR"]); + let runtime_dir = unique_temp_dir("session-runtime"); + let runtime_dir = runtime_dir + .to_str() + .expect("temp runtime dir should be valid UTF-8"); + set_env("XDG_RUNTIME_DIR", runtime_dir); + f() + } + + #[test] + fn record_append_then_load_recent_entry() { + with_runtime_dir(|| { + record_append( + &config(), + &context(), + "Hello there", + SessionRewriteSummary { + had_edit_cues: false, + rewrite_used: false, + recommended_candidate: None, + }, + ) + .expect("record"); + + let entry = load_recent_entry(&config(), &context()) + .expect("load") + .expect("entry"); + assert_eq!(entry.entry.final_text, "Hello there"); + assert_eq!(entry.delete_graphemes, 11); + }); + } + + #[test] + fn load_recent_entry_requires_matching_focus() { + with_runtime_dir(|| { + record_append( + &config(), + &context(), + "Hello there", + SessionRewriteSummary { + had_edit_cues: 
false, + rewrite_used: false, + recommended_candidate: None, + }, + ) + .expect("record"); + + let mut other = context(); + other.focus_fingerprint = "hyprland:0x999".into(); + assert!( + load_recent_entry(&config(), &other) + .expect("load") + .is_none() + ); + }); + } + + #[test] + fn record_replace_updates_existing_entry() { + with_runtime_dir(|| { + let rewrite_summary = SessionRewriteSummary { + had_edit_cues: false, + rewrite_used: false, + recommended_candidate: None, + }; + record_append( + &config(), + &context(), + "Hello there", + rewrite_summary.clone(), + ) + .expect("record"); + let entry = load_recent_entry(&config(), &context()) + .expect("load") + .expect("entry"); + + record_replace(&config(), &context(), entry.entry.id, "Hi", rewrite_summary) + .expect("replace"); + + let replaced = load_recent_entry(&config(), &context()) + .expect("load") + .expect("entry"); + assert_eq!(replaced.entry.final_text, "Hi"); + assert_eq!(replaced.delete_graphemes, 2); + }); + } +} diff --git a/src/session/planning.rs b/src/session/planning.rs new file mode 100644 index 0000000..1eeedbf --- /dev/null +++ b/src/session/planning.rs @@ -0,0 +1,324 @@ +use crate::cleanup; +use crate::context::{SurfaceKind, TypingContext}; +use crate::rewrite_protocol::{ + RewriteSessionBacktrackCandidate, RewriteSessionBacktrackCandidateKind, RewriteSessionEntry, + RewriteSurfaceKind, RewriteTranscript, RewriteTypingContext, +}; + +use super::{EligibleSessionEntry, SessionBacktrackPlan, SessionEntry}; + +pub fn build_backtrack_plan( + transcript: &RewriteTranscript, + recent_entry: Option<&EligibleSessionEntry>, +) -> SessionBacktrackPlan { + let Some(recent_entry) = recent_entry else { + return SessionBacktrackPlan::default(); + }; + if !should_offer_session_backtrack(transcript) { + return SessionBacktrackPlan::default(); + } + + let append_text = preferred_current_text(transcript); + if append_text.is_empty() { + return SessionBacktrackPlan::default(); + } + + let append_candidate = 
RewriteSessionBacktrackCandidate { + kind: RewriteSessionBacktrackCandidateKind::AppendCurrent, + entry_id: None, + delete_graphemes: 0, + text: append_text.clone(), + }; + let replace_candidate = RewriteSessionBacktrackCandidate { + kind: RewriteSessionBacktrackCandidateKind::ReplaceLastEntry, + entry_id: Some(recent_entry.entry.id), + delete_graphemes: recent_entry.delete_graphemes, + text: append_text, + }; + + SessionBacktrackPlan { + recent_entries: vec![to_rewrite_session_entry(&recent_entry.entry)], + candidates: vec![replace_candidate.clone(), append_candidate], + recommended: Some(replace_candidate), + deterministic_replacement_text: preferred_current_text_for_exact_followup(transcript), + } +} + +pub fn to_rewrite_typing_context(context: &TypingContext) -> Option { + context.is_known_focus().then(|| RewriteTypingContext { + focus_fingerprint: context.focus_fingerprint.clone(), + app_id: context.app_id.clone(), + window_title: context.window_title.clone(), + surface_kind: map_surface_kind(context.surface_kind), + browser_domain: context.browser_domain.clone(), + captured_at_ms: context.captured_at_ms, + }) +} + +fn should_offer_session_backtrack(transcript: &RewriteTranscript) -> bool { + if cleanup::explicit_followup_replacement(&transcript.raw_text).is_some() { + return true; + } + + if transcript.correction_aware_text.trim() == transcript.raw_text.trim() { + return false; + } + + let raw_prefix = normalize_prefix(&transcript.raw_text); + if ![ + "scratch that", + "actually scratch that", + "never mind", + "nevermind", + "actually never mind", + "actually nevermind", + "oh wait never mind", + "oh wait nevermind", + "forget that", + "wait no", + "actually wait no", + "i meant", + "actually i meant", + "i mean", + "actually i mean", + ] + .iter() + .any(|cue| raw_prefix.starts_with(cue)) + { + return false; + } + + transcript.edit_hypotheses.iter().any(|hypothesis| { + hypothesis.strength == crate::rewrite_protocol::RewriteEditSignalStrength::Strong + && 
matches!( + hypothesis.match_source, + crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Exact + | crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Alias + ) + }) +} + +fn preferred_current_text(transcript: &RewriteTranscript) -> String { + transcript + .recommended_candidate + .as_ref() + .map(|candidate| candidate.text.trim()) + .filter(|text: &&str| !text.is_empty()) + .or_else(|| { + Some(transcript.correction_aware_text.trim()).filter(|text: &&str| !text.is_empty()) + }) + .or_else(|| Some(transcript.raw_text.trim()).filter(|text: &&str| !text.is_empty())) + .unwrap_or_default() + .to_string() +} + +fn preferred_current_text_for_exact_followup(transcript: &RewriteTranscript) -> Option { + if let Some(text) = cleanup::explicit_followup_replacement(&transcript.raw_text) { + return Some(text); + } + + if !has_strong_explicit_followup_cue(transcript) { + return None; + } + + let raw_prefix = normalize_prefix(&transcript.raw_text); + if ![ + "scratch that", + "actually scratch that", + "never mind", + "nevermind", + "actually never mind", + "actually nevermind", + "oh wait never mind", + "oh wait nevermind", + "forget that", + ] + .iter() + .any(|cue| raw_prefix.starts_with(cue)) + { + return None; + } + + let preferred = preferred_current_text(transcript); + (!preferred.is_empty()).then_some(preferred) +} + +fn has_strong_explicit_followup_cue(transcript: &RewriteTranscript) -> bool { + transcript.edit_hypotheses.iter().any(|hypothesis| { + hypothesis.strength == crate::rewrite_protocol::RewriteEditSignalStrength::Strong + && matches!( + hypothesis.match_source, + crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Exact + | crate::rewrite_protocol::RewriteEditHypothesisMatchSource::Alias + ) + && matches!( + hypothesis.cue_family.as_str(), + "scratch_that" | "never_mind" + ) + }) +} + +fn normalize_prefix(text: &str) -> String { + text.chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || ch.is_ascii_whitespace() { + 
ch.to_ascii_lowercase() + } else { + ' ' + } + }) + .collect::() + .split_whitespace() + .take(4) + .collect::>() + .join(" ") +} + +fn to_rewrite_session_entry(entry: &SessionEntry) -> RewriteSessionEntry { + RewriteSessionEntry { + id: entry.id, + final_text: entry.final_text.clone(), + grapheme_len: entry.grapheme_len, + focus_fingerprint: entry.focus_fingerprint.clone(), + surface_kind: map_surface_kind(entry.surface_kind), + app_id: entry.app_id.clone(), + window_title: entry.window_title.clone(), + } +} + +fn map_surface_kind(kind: SurfaceKind) -> RewriteSurfaceKind { + match kind { + SurfaceKind::Browser => RewriteSurfaceKind::Browser, + SurfaceKind::Terminal => RewriteSurfaceKind::Terminal, + SurfaceKind::Editor => RewriteSurfaceKind::Editor, + SurfaceKind::GenericText => RewriteSurfaceKind::GenericText, + SurfaceKind::Unknown => RewriteSurfaceKind::Unknown, + } +} + +#[cfg(test)] +mod tests { + use super::build_backtrack_plan; + use crate::context::SurfaceKind; + use crate::rewrite_protocol::{ + RewriteCandidate, RewriteCandidateKind, RewriteEditHypothesis, + RewriteEditHypothesisMatchSource, RewriteEditSignalKind, RewriteEditSignalStrength, + RewritePolicyContext, RewriteReplacementScope, RewriteSessionBacktrackCandidateKind, + RewriteTailShape, RewriteTranscript, + }; + use crate::session::{EligibleSessionEntry, SessionEntry, SessionRewriteSummary}; + + #[test] + fn build_backtrack_plan_prefers_replacing_recent_entry_for_follow_up_correction() { + let transcript = RewriteTranscript { + raw_text: "scratch that hi".into(), + correction_aware_text: "Hi".into(), + aggressive_correction_text: None, + detected_language: Some("en".into()), + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: vec![RewriteEditHypothesis { + cue_family: "scratch_that".into(), + matched_text: 
"scratch that".into(), + match_source: RewriteEditHypothesisMatchSource::Exact, + kind: RewriteEditSignalKind::Cancel, + scope_hint: crate::rewrite_protocol::RewriteEditSignalScope::Sentence, + replacement_scope: RewriteReplacementScope::Sentence, + tail_shape: RewriteTailShape::Phrase, + strength: RewriteEditSignalStrength::Strong, + }], + rewrite_candidates: Vec::new(), + recommended_candidate: Some(RewriteCandidate { + kind: RewriteCandidateKind::SentenceReplacement, + text: "Hi".into(), + }), + policy_context: RewritePolicyContext::default(), + }; + + let recent = EligibleSessionEntry { + entry: SessionEntry { + id: 7, + final_text: "Hello there".into(), + grapheme_len: 11, + injected_at_ms: 1, + focus_fingerprint: "hyprland:0x123".into(), + surface_kind: SurfaceKind::GenericText, + app_id: Some("firefox".into()), + window_title: Some("Example".into()), + rewrite_summary: SessionRewriteSummary { + had_edit_cues: false, + rewrite_used: true, + recommended_candidate: Some("Hello there".into()), + }, + }, + delete_graphemes: 11, + }; + + let plan = build_backtrack_plan(&transcript, Some(&recent)); + assert_eq!(plan.recent_entries.len(), 1); + assert_eq!(plan.candidates.len(), 2); + assert_eq!( + plan.recommended.as_ref().map(|candidate| candidate.kind), + Some(RewriteSessionBacktrackCandidateKind::ReplaceLastEntry) + ); + assert_eq!( + plan.recommended + .as_ref() + .and_then(|candidate| candidate.entry_id), + Some(7) + ); + assert_eq!(plan.deterministic_replacement_text.as_deref(), Some("Hi")); + } + + #[test] + fn build_backtrack_plan_uses_raw_followup_fallback_without_hypotheses() { + let transcript = RewriteTranscript { + raw_text: "scratch that hi".into(), + correction_aware_text: "scratch that hi".into(), + aggressive_correction_text: None, + detected_language: None, + typing_context: None, + recent_session_entries: Vec::new(), + session_backtrack_candidates: Vec::new(), + recommended_session_candidate: None, + segments: Vec::new(), + edit_intents: 
Vec::new(), + edit_signals: Vec::new(), + edit_hypotheses: Vec::new(), + rewrite_candidates: Vec::new(), + recommended_candidate: None, + policy_context: RewritePolicyContext::default(), + }; + + let recent = EligibleSessionEntry { + entry: SessionEntry { + id: 7, + final_text: "Hello there".into(), + grapheme_len: 11, + injected_at_ms: 1, + focus_fingerprint: "hyprland:0x123".into(), + surface_kind: SurfaceKind::GenericText, + app_id: Some("firefox".into()), + window_title: Some("Example".into()), + rewrite_summary: SessionRewriteSummary { + had_edit_cues: false, + rewrite_used: true, + recommended_candidate: Some("Hello there".into()), + }, + }, + delete_graphemes: 11, + }; + + let plan = build_backtrack_plan(&transcript, Some(&recent)); + assert_eq!( + plan.recommended.as_ref().map(|candidate| candidate.kind), + Some(RewriteSessionBacktrackCandidateKind::ReplaceLastEntry) + ); + assert_eq!(plan.deterministic_replacement_text.as_deref(), Some("Hi")); + } +} diff --git a/src/setup.rs b/src/setup.rs index 7ab17a7..87aa2f8 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -1,17 +1,24 @@ +mod apply; +mod report; +mod select; +mod side_effects; + +#[cfg(test)] +mod tests; + use std::path::Path; use crate::asr_model::{self, ASR_MODELS, AsrModelInfo}; use crate::config::{ - self, CloudLanguageMode, CloudProvider, CloudSettingsUpdate, PostprocessMode, RewriteBackend, - RewriteFallback, TranscriptionBackend, TranscriptionFallback, resolve_config_path, + CloudProvider, PostprocessMode, RewriteFallback, TranscriptionFallback, resolve_config_path, }; use crate::error::Result; -use crate::rewrite_model::{self, REWRITE_MODELS}; use crate::ui::SetupUi; struct SetupSelections { asr_model: &'static AsrModelInfo, rewrite_model: Option<&'static str>, + postprocess_mode: PostprocessMode, cloud: CloudSetup, } @@ -51,52 +58,59 @@ pub async fn run_setup(config_path_override: Option<&Path>) -> Result<()> { let ui = SetupUi::new(); ui.print_header("whispers setup"); ui.blank(); - 
print_setup_intro(&ui); + report::print_setup_intro(&ui); let available_asr_models: Vec<_> = ASR_MODELS .iter() .filter(|model| asr_model::is_model_available(model.name)) .collect(); - let asr_model = choose_asr_model(&ui, &available_asr_models)?; - ui.blank(); - tracing::info!("setup selected ASR model: {}", asr_model.name); - asr_model::download_model(asr_model.name).await?; - ui.blank(); + let asr_model = select::choose_asr_model(&ui, &available_asr_models)?; + side_effects::download_asr_model(&ui, asr_model).await?; let mut rewrite_model = None; - if ui.confirm("Enable smarter local rewrite cleanup?", false)? { - rewrite_model = Some(choose_rewrite_model(&ui, "Choose a local rewrite model", 1).await?); + let mut postprocess_mode = PostprocessMode::Raw; + if crate::rewrite::local_rewrite_available() + && ui.confirm("Enable smarter local rewrite cleanup?", false)? + { + let selected_rewrite_model = + select::choose_rewrite_model(&ui, "Choose a local rewrite model", 1)?; + side_effects::download_rewrite_model(&ui, selected_rewrite_model).await?; + rewrite_model = Some(selected_rewrite_model); + postprocess_mode = select::choose_rewrite_mode(&ui)?; + } else if !crate::rewrite::local_rewrite_available() { + ui.print_info( + "This build does not include local rewrite support. You can still enable cloud rewrite.", + ); + ui.blank(); } + let rewrite_model_before_cloud = rewrite_model; let cloud = if ui.confirm("Add optional cloud ASR or rewrite?", false)? { - configure_cloud(&ui, &mut rewrite_model).await? + select::configure_cloud(&ui, &mut rewrite_model)? 
} else { CloudSetup::default() }; + if rewrite_model_before_cloud.is_none() { + if let Some(selected_rewrite_model) = rewrite_model { + side_effects::download_rewrite_model(&ui, selected_rewrite_model).await?; + } + } + if cloud.rewrite_enabled && postprocess_mode == PostprocessMode::Raw { + postprocess_mode = select::choose_rewrite_mode(&ui)?; + } let selections = SetupSelections { asr_model, rewrite_model, + postprocess_mode, cloud, }; let config_path = resolve_config_path(config_path_override); - ensure_config_exists(&ui, &config_path)?; - asr_model::select_model(selections.asr_model.name, config_path_override)?; - - if let Some(rewrite_model) = selections.rewrite_model { - config::update_config_rewrite_selection(&config_path, rewrite_model)?; - } - - apply_postprocess_selection(&ui, &config_path, &selections)?; - apply_runtime_backend_selection( - &config_path, - selections.asr_model.backend, - &selections.cloud, - )?; - apply_cloud_settings(&ui, &config_path, &selections.cloud)?; - cleanup_stale_asr_workers(&ui, &config_path)?; + apply::apply_setup_config(&ui, &config_path, config_path_override, &selections)?; + side_effects::maybe_create_agentic_starter_files(&ui, &config_path, &selections)?; + side_effects::cleanup_stale_asr_workers(&ui, &config_path)?; if let Some(rewrite_model) = selections.rewrite_model { ui.print_ok(format!( @@ -105,408 +119,13 @@ pub async fn run_setup(config_path_override: Option<&Path>) -> Result<()> { )); } - maybe_prewarm_experimental_nemo(&ui, &config_path, &selections)?; + side_effects::maybe_prewarm_experimental_nemo(&ui, &config_path, &selections)?; ui.print_ok("Config saved."); ui.blank(); - print_setup_summary(&ui, &selections); + report::print_setup_summary(&ui, &selections); ui.blank(); - print_setup_complete(&ui); + report::print_setup_complete(&ui); Ok(()) } - -fn print_setup_intro(ui: &SetupUi) { - ui.print_subtle( - "Recommended models are listed first. 
Experimental backends are available, but not the default recommendation.", - ); - ui.blank(); -} - -fn ensure_config_exists(ui: &SetupUi, config_path: &Path) -> Result<()> { - let default_model_path = crate::model::model_path_for_config("ggml-large-v3-turbo.bin"); - if !config_path.exists() { - tracing::info!("writing new config at {}", config_path.display()); - config::write_default_config(config_path, &default_model_path)?; - return Ok(()); - } - - tracing::info!("updating existing config at {}", config_path.display()); - ui.print_info("Updating existing config."); - Ok(()) -} - -fn apply_postprocess_selection( - ui: &SetupUi, - config_path: &Path, - selections: &SetupSelections, -) -> Result<()> { - if selections.cloud.rewrite_enabled || selections.rewrite_model.is_some() { - config::update_config_postprocess_mode(config_path, PostprocessMode::AdvancedLocal)?; - } else { - config::update_config_postprocess_mode(config_path, PostprocessMode::Raw)?; - ui.print_info("Rewrite cleanup: disabled (raw mode)."); - } - Ok(()) -} - -fn apply_cloud_settings(ui: &SetupUi, config_path: &Path, cloud: &CloudSetup) -> Result<()> { - if !cloud.enabled() { - return Ok(()); - } - - config::update_config_cloud_settings( - config_path, - &CloudSettingsUpdate { - provider: cloud.provider, - base_url: &cloud.base_url, - api_key: &cloud.api_key, - api_key_env: &cloud.api_key_env, - connect_timeout_ms: 3000, - request_timeout_ms: 15000, - transcription_model: "gpt-4o-mini-transcribe", - transcription_language_mode: CloudLanguageMode::InheritLocal, - transcription_language: "", - rewrite_model: "gpt-4.1-mini", - rewrite_temperature: 0.1, - rewrite_max_output_tokens: 256, - }, - )?; - - if cloud.api_key.is_empty() { - ui.print_ok(format!( - "Cloud provider: {} (using API key env {}).", - crate::ui::provider_token(cloud.provider.as_str()), - crate::ui::value(&cloud.api_key_env) - )); - } else { - ui.print_ok(format!( - "Cloud provider: {} (using a locally stored API key).", - 
crate::ui::provider_token(cloud.provider.as_str()) - )); - } - - if cloud.api_key.is_empty() && std::env::var(&cloud.api_key_env).is_err() { - ui.print_warn(format!( - "{} is not set in the current environment yet.", - crate::ui::value(&cloud.api_key_env) - )); - } - - Ok(()) -} - -fn maybe_prewarm_experimental_nemo( - ui: &SetupUi, - config_path: &Path, - selections: &SetupSelections, -) -> Result<()> { - if selections.asr_model.backend != TranscriptionBackend::Nemo || selections.cloud.asr_enabled { - return Ok(()); - } - - let spinner = - crate::ui::spinner("Starting background warm-up for the experimental NeMo backend..."); - match config::Config::load(Some(config_path)).and_then(|config| asr_model_prewarm(&config)) { - Ok(()) => { - spinner.finish_and_clear(); - ui.print_info("Background warm-up started for the experimental NeMo backend."); - } - Err(err) => { - spinner.finish_and_clear(); - ui.print_warn(format!( - "Failed to prewarm NeMo ASR backend after setup: {err}" - )); - } - } - - Ok(()) -} - -fn cleanup_stale_asr_workers(ui: &SetupUi, config_path: &Path) -> Result<()> { - match config::Config::load(Some(config_path)) - .and_then(|config| crate::asr::cleanup_stale_transcribers(&config)) - { - Ok(()) => Ok(()), - Err(err) => { - ui.print_warn(format!( - "Failed to retire stale ASR workers after setup: {err}" - )); - Ok(()) - } - } -} - -fn asr_model_prewarm(config: &config::Config) -> Result<()> { - let prepared = crate::asr::prepare_transcriber(config)?; - crate::asr::prewarm_transcriber(&prepared, "setup"); - Ok(()) -} - -fn choose_asr_model( - ui: &SetupUi, - available_asr_models: &[&'static AsrModelInfo], -) -> Result<&'static AsrModelInfo> { - loop { - let items: Vec = available_asr_models - .iter() - .map(|model| asr_model::setup_label(model)) - .collect(); - let selection = ui.select("Choose an ASR model", &items, 0)?; - let chosen = available_asr_models[selection]; - - if let Some(warning) = asr_model::experimental_warning(chosen) { - ui.blank(); 
- tracing::debug!( - "experimental setup warning for {}: {}", - chosen.name, - warning - ); - ui.print_experimental_notice(chosen.name, asr_model::experimental_notice_facts(chosen)); - if !ui.danger_confirm( - crate::ui::danger_text("Continue with this experimental ASR backend?"), - false, - )? { - ui.blank(); - continue; - } - } - - return Ok(chosen); - } -} - -async fn choose_rewrite_model( - ui: &SetupUi, - prompt: &str, - default_index: usize, -) -> Result<&'static str> { - let items: Vec = REWRITE_MODELS - .iter() - .map(rewrite_model::setup_label) - .collect(); - let selection = ui.select(prompt, &items, default_index)?; - let chosen = &REWRITE_MODELS[selection]; - ui.blank(); - tracing::info!("setup selected rewrite model: {}", chosen.name); - rewrite_model::download_model(chosen.name).await?; - ui.blank(); - Ok(chosen.name) -} - -async fn configure_cloud( - ui: &SetupUi, - rewrite_model: &mut Option<&'static str>, -) -> Result { - let mut cloud = CloudSetup::default(); - - let provider_items = vec![ - format!( - "{} {}", - crate::ui::provider_token("OpenAI"), - crate::ui::description_token("Official OpenAI hosted API") - ), - format!( - "{} {}", - crate::ui::provider_token("OpenAI-compatible endpoint"), - crate::ui::description_token("Third-party endpoint that speaks the OpenAI API") - ), - ]; - let provider_selection = ui.select("Choose a cloud provider", &provider_items, 0)?; - if provider_selection == 1 { - cloud.provider = CloudProvider::OpenAiCompatible; - cloud.base_url = ui.input_string("Base URL for the OpenAI-compatible API", None)?; - } - - let cloud_key_input = ui.input_string( - "Cloud API key or environment variable name", - Some("OPENAI_API_KEY"), - )?; - if looks_like_cloud_api_key(&cloud_key_input) { - cloud.api_key = cloud_key_input.trim().to_string(); - cloud.api_key_env = "OPENAI_API_KEY".into(); - } else { - cloud.api_key_env = cloud_key_input.trim().to_string(); - } - - let cloud_mode_items = [ - "Cloud rewrite only", - "Cloud ASR 
only", - "Cloud ASR + rewrite", - ]; - let cloud_mode = ui.select("Choose the cloud mode", &cloud_mode_items, 0)?; - cloud.rewrite_enabled = matches!(cloud_mode, 0 | 2); - cloud.asr_enabled = matches!(cloud_mode, 1 | 2); - - if cloud.asr_enabled { - let fallback_items = ["Use configured local fallback", "Fail if cloud ASR fails"]; - let selection = ui.select("If cloud ASR fails", &fallback_items, 0)?; - cloud.asr_fallback = if selection == 0 { - TranscriptionFallback::ConfiguredLocal - } else { - TranscriptionFallback::None - }; - } - - if cloud.rewrite_enabled { - let fallback_items = [ - "Use local rewrite fallback", - "Fail back to deterministic text", - ]; - let selection = ui.select("If cloud rewrite fails", &fallback_items, 0)?; - cloud.rewrite_fallback = if selection == 0 { - RewriteFallback::Local - } else { - RewriteFallback::None - }; - } - - if cloud.rewrite_enabled - && cloud.rewrite_fallback == RewriteFallback::Local - && rewrite_model.is_none() - { - ui.blank(); - ui.print_info("Cloud rewrite fallback uses a local rewrite model."); - *rewrite_model = Some( - choose_rewrite_model(ui, "Choose a local rewrite fallback model to download", 1) - .await?, - ); - } - - Ok(cloud) -} - -fn looks_like_cloud_api_key(value: &str) -> bool { - value.trim().starts_with("sk-") -} - -fn print_setup_summary(ui: &SetupUi, selections: &SetupSelections) { - ui.print_section("Setup summary"); - println!( - " {}: {} ({}, {}, {})", - crate::ui::summary_key("ASR"), - crate::ui::value(selections.asr_model.name), - crate::ui::backend_token(selections.asr_model.backend.as_str()), - crate::ui::scope_token(selections.asr_model.language_scope.as_str()), - crate::ui::tier_token(selections.asr_model.support_tier.as_str()) - ); - - if let Some(note) = selections.asr_model.setup_note { - println!(" {}: {}", crate::ui::summary_key("ASR note"), note); - } - - if selections.cloud.asr_enabled { - println!( - " {}: enabled via {} (fallback: {})", - crate::ui::summary_key("Cloud ASR"), - 
crate::ui::provider_token(selections.cloud.provider.as_str()), - selections.cloud.asr_fallback.as_str() - ); - } else { - println!(" {}: disabled", crate::ui::summary_key("Cloud ASR")); - } - - match (selections.cloud.rewrite_enabled, selections.rewrite_model) { - (true, Some(model)) => println!( - " {}: cloud with local fallback ({})", - crate::ui::summary_key("Rewrite"), - crate::ui::value(model) - ), - (true, None) => println!( - " {}: cloud only (fallback: {})", - crate::ui::summary_key("Rewrite"), - selections.cloud.rewrite_fallback.as_str() - ), - (false, Some(model)) => println!( - " {}: local ({})", - crate::ui::summary_key("Rewrite"), - crate::ui::value(model) - ), - (false, None) => println!( - " {}: disabled (raw mode)", - crate::ui::summary_key("Rewrite") - ), - } - - if selections.asr_model.backend == TranscriptionBackend::Nemo && !selections.cloud.asr_enabled { - println!( - " {}: first use may be slower than steady-state while the worker warms.", - crate::ui::summary_key("NeMo note") - ); - } -} - -fn print_setup_complete(ui: &SetupUi) { - ui.print_header("Setup complete"); - println!("You can now use whispers."); - ui.print_section("Example keybind"); - ui.print_subtle("Bind it to a key in your compositor, e.g. 
for Hyprland:"); - println!(" bind = SUPER ALT, D, exec, whispers"); -} - -fn apply_runtime_backend_selection( - config_path: &Path, - selected_asr_backend: TranscriptionBackend, - cloud: &CloudSetup, -) -> Result<()> { - let transcription_backend = if cloud.asr_enabled { - TranscriptionBackend::Cloud - } else { - selected_asr_backend - }; - let transcription_fallback = if cloud.asr_enabled { - cloud.asr_fallback - } else { - TranscriptionFallback::ConfiguredLocal - }; - config::update_config_transcription_runtime( - config_path, - transcription_backend, - transcription_fallback, - )?; - - let rewrite_backend = if cloud.rewrite_enabled { - RewriteBackend::Cloud - } else { - RewriteBackend::Local - }; - let rewrite_fallback = if cloud.rewrite_enabled { - cloud.rewrite_fallback - } else { - RewriteFallback::Local - }; - config::update_config_rewrite_runtime(config_path, rewrite_backend, rewrite_fallback)?; - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::config::Config; - - #[test] - fn runtime_selection_resets_cloud_asr_when_disabled() { - let config_path = crate::test_support::unique_temp_path("setup-runtime-reset", "toml"); - config::write_default_config(&config_path, "~/model.bin").expect("write config"); - config::update_config_transcription_runtime( - &config_path, - TranscriptionBackend::Cloud, - TranscriptionFallback::None, - ) - .expect("set cloud runtime"); - - let cloud = CloudSetup::default(); - apply_runtime_backend_selection(&config_path, TranscriptionBackend::WhisperCpp, &cloud) - .expect("reset runtime"); - - let config = Config::load(Some(&config_path)).expect("load config"); - assert_eq!( - config.transcription.backend, - TranscriptionBackend::WhisperCpp - ); - assert_eq!( - config.transcription.fallback, - TranscriptionFallback::ConfiguredLocal - ); - } -} diff --git a/src/setup/apply.rs b/src/setup/apply.rs new file mode 100644 index 0000000..e257f2a --- /dev/null +++ b/src/setup/apply.rs @@ -0,0 +1,142 @@ +use 
std::path::Path; + +use crate::asr_model; +use crate::config::{ + self, CloudLanguageMode, CloudSettingsUpdate, PostprocessMode, RewriteBackend, RewriteFallback, + TranscriptionBackend, TranscriptionFallback, +}; +use crate::error::Result; +use crate::ui::SetupUi; + +use super::{CloudSetup, SetupSelections}; + +pub(super) fn apply_setup_config( + ui: &SetupUi, + config_path: &Path, + config_path_override: Option<&Path>, + selections: &SetupSelections, +) -> Result<()> { + ensure_config_exists(ui, config_path)?; + asr_model::select_model(selections.asr_model.name, config_path_override)?; + + if let Some(rewrite_model) = selections.rewrite_model { + config::update_config_rewrite_selection(config_path, rewrite_model)?; + } + + apply_postprocess_selection(ui, config_path, selections)?; + apply_runtime_backend_selection(config_path, selections.asr_model.backend, &selections.cloud)?; + apply_cloud_settings(ui, config_path, &selections.cloud)?; + Ok(()) +} + +fn ensure_config_exists(ui: &SetupUi, config_path: &Path) -> Result<()> { + let default_model_path = crate::model::model_path_for_config("ggml-large-v3-turbo.bin"); + if !config_path.exists() { + tracing::info!("writing new config at {}", config_path.display()); + config::write_default_config(config_path, &default_model_path)?; + return Ok(()); + } + + tracing::info!("updating existing config at {}", config_path.display()); + ui.print_info("Updating existing config."); + Ok(()) +} + +fn apply_postprocess_selection( + ui: &SetupUi, + config_path: &Path, + selections: &SetupSelections, +) -> Result<()> { + if selections.postprocess_mode == PostprocessMode::Raw { + config::update_config_postprocess_mode(config_path, PostprocessMode::Raw)?; + ui.print_info("Rewrite cleanup: disabled (raw mode)."); + } else { + config::update_config_postprocess_mode(config_path, selections.postprocess_mode)?; + } + Ok(()) +} + +fn apply_cloud_settings(ui: &SetupUi, config_path: &Path, cloud: &CloudSetup) -> Result<()> { + if 
!cloud.enabled() { + return Ok(()); + } + + config::update_config_cloud_settings( + config_path, + &CloudSettingsUpdate { + provider: cloud.provider, + base_url: &cloud.base_url, + api_key: &cloud.api_key, + api_key_env: &cloud.api_key_env, + connect_timeout_ms: 3000, + request_timeout_ms: 15000, + transcription_model: "gpt-4o-mini-transcribe", + transcription_language_mode: CloudLanguageMode::InheritLocal, + transcription_language: "", + rewrite_model: "gpt-4.1-mini", + rewrite_temperature: 0.1, + rewrite_max_output_tokens: 256, + }, + )?; + + if cloud.api_key.is_empty() { + ui.print_ok(format!( + "Cloud provider: {} (using API key env {}).", + crate::ui::provider_token(cloud.provider.as_str()), + crate::ui::value(&cloud.api_key_env) + )); + } else { + ui.print_ok(format!( + "Cloud provider: {} (using a locally stored API key).", + crate::ui::provider_token(cloud.provider.as_str()) + )); + } + + if cloud.api_key.is_empty() && std::env::var(&cloud.api_key_env).is_err() { + ui.print_warn(format!( + "{} is not set in the current environment yet.", + crate::ui::value(&cloud.api_key_env) + )); + } + + Ok(()) +} + +pub(super) fn apply_runtime_backend_selection( + config_path: &Path, + selected_asr_backend: TranscriptionBackend, + cloud: &CloudSetup, +) -> Result<()> { + let transcription_backend = if cloud.asr_enabled { + TranscriptionBackend::Cloud + } else { + selected_asr_backend + }; + let transcription_fallback = if cloud.asr_enabled { + cloud.asr_fallback + } else { + TranscriptionFallback::ConfiguredLocal + }; + config::update_config_transcription_runtime( + config_path, + transcription_backend, + transcription_fallback, + )?; + + let rewrite_backend = if cloud.rewrite_enabled { + RewriteBackend::Cloud + } else { + RewriteBackend::Local + }; + let rewrite_fallback = if cloud.rewrite_enabled { + if crate::rewrite::local_rewrite_available() { + cloud.rewrite_fallback + } else { + RewriteFallback::None + } + } else { + RewriteFallback::Local + }; + 
config::update_config_rewrite_runtime(config_path, rewrite_backend, rewrite_fallback)?; + Ok(()) +} diff --git a/src/setup/report.rs b/src/setup/report.rs new file mode 100644 index 0000000..8d042d0 --- /dev/null +++ b/src/setup/report.rs @@ -0,0 +1,78 @@ +use crate::config::TranscriptionBackend; +use crate::ui::SetupUi; + +use super::SetupSelections; + +pub(super) fn print_setup_intro(ui: &SetupUi) { + ui.print_subtle( + "Recommended models are listed first. Experimental backends are available, but not the default recommendation.", + ); + ui.blank(); +} + +pub(super) fn print_setup_summary(ui: &SetupUi, selections: &SetupSelections) { + ui.print_section("Setup summary"); + println!( + " {}: {} ({}, {}, {})", + crate::ui::summary_key("ASR"), + crate::ui::value(selections.asr_model.name), + crate::ui::backend_token(selections.asr_model.backend.as_str()), + crate::ui::scope_token(selections.asr_model.language_scope.as_str()), + crate::ui::tier_token(selections.asr_model.support_tier.as_str()) + ); + + if let Some(note) = selections.asr_model.setup_note { + println!(" {}: {}", crate::ui::summary_key("ASR note"), note); + } + + if selections.cloud.asr_enabled { + println!( + " {}: enabled via {} (fallback: {})", + crate::ui::summary_key("Cloud ASR"), + crate::ui::provider_token(selections.cloud.provider.as_str()), + selections.cloud.asr_fallback.as_str() + ); + } else { + println!(" {}: disabled", crate::ui::summary_key("Cloud ASR")); + } + + match (selections.cloud.rewrite_enabled, selections.rewrite_model) { + (true, Some(model)) => println!( + " {}: cloud with local fallback ({}, mode: {})", + crate::ui::summary_key("Rewrite"), + crate::ui::value(model), + selections.postprocess_mode.as_str() + ), + (true, None) => println!( + " {}: cloud only (fallback: {}, mode: {})", + crate::ui::summary_key("Rewrite"), + selections.cloud.rewrite_fallback.as_str(), + selections.postprocess_mode.as_str() + ), + (false, Some(model)) => println!( + " {}: local ({}, mode: {})", + 
crate::ui::summary_key("Rewrite"), + crate::ui::value(model), + selections.postprocess_mode.as_str() + ), + (false, None) => println!( + " {}: disabled (raw mode)", + crate::ui::summary_key("Rewrite") + ), + } + + if selections.asr_model.backend == TranscriptionBackend::Nemo && !selections.cloud.asr_enabled { + println!( + " {}: first use may be slower than steady-state while the worker warms.", + crate::ui::summary_key("NeMo note") + ); + } +} + +pub(super) fn print_setup_complete(ui: &SetupUi) { + ui.print_header("Setup complete"); + println!("You can now use whispers."); + ui.print_section("Example keybind"); + ui.print_subtle("Bind it to a key in your compositor, e.g. for Hyprland:"); + println!(" bind = SUPER ALT, D, exec, whispers"); +} diff --git a/src/setup/select.rs b/src/setup/select.rs new file mode 100644 index 0000000..0c27bd2 --- /dev/null +++ b/src/setup/select.rs @@ -0,0 +1,161 @@ +use crate::asr_model::{self, AsrModelInfo}; +use crate::config::{CloudProvider, PostprocessMode, RewriteFallback, TranscriptionFallback}; +use crate::error::Result; +use crate::rewrite_model::{self, REWRITE_MODELS}; +use crate::ui::SetupUi; + +use super::CloudSetup; + +pub(super) fn choose_asr_model( + ui: &SetupUi, + available_asr_models: &[&'static AsrModelInfo], +) -> Result<&'static AsrModelInfo> { + loop { + let items: Vec = available_asr_models + .iter() + .map(|model| asr_model::setup_label(model)) + .collect(); + let selection = ui.select("Choose an ASR model", &items, 0)?; + let chosen = available_asr_models[selection]; + + if let Some(warning) = asr_model::experimental_warning(chosen) { + ui.blank(); + tracing::debug!( + "experimental setup warning for {}: {}", + chosen.name, + warning + ); + ui.print_experimental_notice(chosen.name, asr_model::experimental_notice_facts(chosen)); + if !ui.danger_confirm( + crate::ui::danger_text("Continue with this experimental ASR backend?"), + false, + )? 
{ + ui.blank(); + continue; + } + } + + return Ok(chosen); + } +} + +pub(super) fn choose_rewrite_model( + ui: &SetupUi, + prompt: &str, + default_index: usize, +) -> Result<&'static str> { + let items: Vec = REWRITE_MODELS + .iter() + .map(rewrite_model::setup_label) + .collect(); + let selection = ui.select(prompt, &items, default_index)?; + let chosen = &REWRITE_MODELS[selection]; + Ok(chosen.name) +} + +pub(super) fn choose_rewrite_mode(ui: &SetupUi) -> Result { + let items = [ + "advanced_local: smart rewrite cleanup with current bounded-candidate behavior", + "agentic_rewrite: app-aware rewrite with policy and technical glossary support", + ]; + let selection = ui.select("Choose the rewrite mode", &items, 1)?; + Ok(if selection == 0 { + PostprocessMode::AdvancedLocal + } else { + PostprocessMode::AgenticRewrite + }) +} + +pub(super) fn configure_cloud( + ui: &SetupUi, + rewrite_model: &mut Option<&'static str>, +) -> Result { + let mut cloud = CloudSetup::default(); + + let provider_items = vec![ + format!( + "{} {}", + crate::ui::provider_token("OpenAI"), + crate::ui::description_token("Official OpenAI hosted API") + ), + format!( + "{} {}", + crate::ui::provider_token("OpenAI-compatible endpoint"), + crate::ui::description_token("Third-party endpoint that speaks the OpenAI API") + ), + ]; + let provider_selection = ui.select("Choose a cloud provider", &provider_items, 0)?; + if provider_selection == 1 { + cloud.provider = CloudProvider::OpenAiCompatible; + cloud.base_url = ui.input_string("Base URL for the OpenAI-compatible API", None)?; + } + + let cloud_key_input = ui.input_string( + "Cloud API key or environment variable name", + Some("OPENAI_API_KEY"), + )?; + if looks_like_cloud_api_key(&cloud_key_input) { + cloud.api_key = cloud_key_input.trim().to_string(); + cloud.api_key_env = "OPENAI_API_KEY".into(); + } else { + cloud.api_key_env = cloud_key_input.trim().to_string(); + } + + let cloud_mode_items = [ + "Cloud rewrite only", + "Cloud ASR only", + 
"Cloud ASR + rewrite", + ]; + let cloud_mode = ui.select("Choose the cloud mode", &cloud_mode_items, 0)?; + cloud.rewrite_enabled = matches!(cloud_mode, 0 | 2); + cloud.asr_enabled = matches!(cloud_mode, 1 | 2); + + if cloud.asr_enabled { + let fallback_items = ["Use configured local fallback", "Fail if cloud ASR fails"]; + let selection = ui.select("If cloud ASR fails", &fallback_items, 0)?; + cloud.asr_fallback = if selection == 0 { + TranscriptionFallback::ConfiguredLocal + } else { + TranscriptionFallback::None + }; + } + + if cloud.rewrite_enabled { + if crate::rewrite::local_rewrite_available() { + let fallback_items = [ + "Use local rewrite fallback", + "Fail back to deterministic text", + ]; + let selection = ui.select("If cloud rewrite fails", &fallback_items, 0)?; + cloud.rewrite_fallback = if selection == 0 { + RewriteFallback::Local + } else { + RewriteFallback::None + }; + } else { + cloud.rewrite_fallback = RewriteFallback::None; + ui.print_info( + "Local rewrite fallback is unavailable in this build; cloud rewrite will fall back to deterministic text.", + ); + } + } + + if cloud.rewrite_enabled + && cloud.rewrite_fallback == RewriteFallback::Local + && rewrite_model.is_none() + { + ui.blank(); + ui.print_info("Cloud rewrite fallback uses a local rewrite model."); + *rewrite_model = Some(choose_rewrite_model( + ui, + "Choose a local rewrite fallback model to download", + 1, + )?); + } + + Ok(cloud) +} + +fn looks_like_cloud_api_key(value: &str) -> bool { + value.trim().starts_with("sk-") +} diff --git a/src/setup/side_effects.rs b/src/setup/side_effects.rs new file mode 100644 index 0000000..fbeb369 --- /dev/null +++ b/src/setup/side_effects.rs @@ -0,0 +1,93 @@ +use std::path::Path; + +use crate::config::{self, TranscriptionBackend}; +use crate::error::Result; +use crate::ui::SetupUi; + +use super::SetupSelections; + +pub(super) async fn download_asr_model( + ui: &SetupUi, + asr_model: &'static crate::asr_model::AsrModelInfo, +) -> Result<()> { + 
ui.blank(); + tracing::info!("setup selected ASR model: {}", asr_model.name); + crate::asr_model::download_model(asr_model.name).await?; + ui.blank(); + Ok(()) +} + +pub(super) async fn download_rewrite_model( + ui: &SetupUi, + rewrite_model: &'static str, +) -> Result<()> { + ui.blank(); + tracing::info!("setup selected rewrite model: {}", rewrite_model); + crate::rewrite_model::download_model(rewrite_model).await?; + ui.blank(); + Ok(()) +} + +pub(super) fn maybe_create_agentic_starter_files( + ui: &SetupUi, + config_path: &Path, + selections: &SetupSelections, +) -> Result<()> { + if selections.postprocess_mode != crate::config::PostprocessMode::AgenticRewrite { + return Ok(()); + } + + let config = config::Config::load(Some(config_path))?; + let created = crate::agentic_rewrite::ensure_starter_files(&config)?; + for path in created { + ui.print_info(format!("Created agentic rewrite starter file: {}", path)); + } + Ok(()) +} + +pub(super) fn cleanup_stale_asr_workers(ui: &SetupUi, config_path: &Path) -> Result<()> { + match config::Config::load(Some(config_path)) + .and_then(|config| crate::asr::cleanup::cleanup_stale_transcribers(&config)) + { + Ok(()) => Ok(()), + Err(err) => { + ui.print_warn(format!( + "Failed to retire stale ASR workers after setup: {err}" + )); + Ok(()) + } + } +} + +pub(super) fn maybe_prewarm_experimental_nemo( + ui: &SetupUi, + config_path: &Path, + selections: &SetupSelections, +) -> Result<()> { + if selections.asr_model.backend != TranscriptionBackend::Nemo || selections.cloud.asr_enabled { + return Ok(()); + } + + let spinner = + crate::ui::spinner("Starting background warm-up for the experimental NeMo backend..."); + match config::Config::load(Some(config_path)).and_then(|config| asr_model_prewarm(&config)) { + Ok(()) => { + spinner.finish_and_clear(); + ui.print_info("Background warm-up started for the experimental NeMo backend."); + } + Err(err) => { + spinner.finish_and_clear(); + ui.print_warn(format!( + "Failed to prewarm NeMo 
ASR backend after setup: {err}" + )); + } + } + + Ok(()) +} + +fn asr_model_prewarm(config: &config::Config) -> Result<()> { + let prepared = crate::asr::prepare::prepare_transcriber(config)?; + crate::asr::prepare::prewarm_transcriber(&prepared, "setup"); + Ok(()) +} diff --git a/src/setup/tests.rs b/src/setup/tests.rs new file mode 100644 index 0000000..4c27fbd --- /dev/null +++ b/src/setup/tests.rs @@ -0,0 +1,51 @@ +use crate::config::Config; + +use super::{CloudSetup, apply}; +use crate::config::{ + self, RewriteBackend, RewriteFallback, TranscriptionBackend, TranscriptionFallback, +}; + +#[test] +fn runtime_selection_resets_cloud_asr_when_disabled() { + let config_path = crate::test_support::unique_temp_path("setup-runtime-reset", "toml"); + config::write_default_config(&config_path, "~/model.bin").expect("write config"); + config::update_config_transcription_runtime( + &config_path, + TranscriptionBackend::Cloud, + TranscriptionFallback::None, + ) + .expect("set cloud runtime"); + + let cloud = CloudSetup::default(); + apply::apply_runtime_backend_selection(&config_path, TranscriptionBackend::WhisperCpp, &cloud) + .expect("reset runtime"); + + let config = Config::load(Some(&config_path)).expect("load config"); + assert_eq!( + config.transcription.backend, + TranscriptionBackend::WhisperCpp + ); + assert_eq!( + config.transcription.fallback, + TranscriptionFallback::ConfiguredLocal + ); +} + +#[cfg(not(feature = "local-rewrite"))] +#[test] +fn runtime_selection_disables_local_rewrite_fallback_when_build_lacks_local_rewrite() { + let config_path = crate::test_support::unique_temp_path("setup-rewrite-fallback-reset", "toml"); + config::write_default_config(&config_path, "~/model.bin").expect("write config"); + + let cloud = CloudSetup { + rewrite_enabled: true, + rewrite_fallback: RewriteFallback::Local, + ..CloudSetup::default() + }; + apply::apply_runtime_backend_selection(&config_path, TranscriptionBackend::WhisperCpp, &cloud) + .expect("apply runtime"); + + 
let config = Config::load(Some(&config_path)).expect("load config"); + assert_eq!(config.rewrite.backend, RewriteBackend::Cloud); + assert_eq!(config.rewrite.fallback, RewriteFallback::None); +} diff --git a/src/ui.rs b/src/ui.rs index 2fcf5c8..85df93e 100644 --- a/src/ui.rs +++ b/src/ui.rs @@ -354,6 +354,12 @@ impl SetupUi { } } +impl Default for SetupUi { + fn default() -> Self { + Self::new() + } +} + fn prompt_error(err: dialoguer::Error) -> WhsprError { WhsprError::Config(format!("prompt cancelled: {err}")) } diff --git a/vendor/whisper-rs-sys/Cargo.toml b/vendor/whisper-rs-sys/Cargo.toml new file mode 100644 index 0000000..5b94e37 --- /dev/null +++ b/vendor/whisper-rs-sys/Cargo.toml @@ -0,0 +1,70 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. 
+ +[package] +edition = "2021" +name = "whisper-rs-sys" +version = "0.14.1" +build = "build.rs" +links = "whisper" +include = [ + "whisper.cpp/bindings/javascript/package-tmpl.json", + "whisper.cpp/bindings/CMakeLists.txt", + "whisper.cpp/CMakeLists.txt", + "whisper.cpp/cmake", + "whisper.cpp/src/**", + "whisper.cpp/include/whisper.h", + "whisper.cpp/ggml/cmake", + "whisper.cpp/ggml/CMakeLists.txt", + "whisper.cpp/ggml/src/**", + "whisper.cpp/ggml/include/*.h", + "whisper.cpp/LICENSE", + "src/*.rs", + "build.rs", + "wrapper.h", +] +autolib = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "Rust bindings for whisper.cpp (FFI bindings)" +documentation = "https://docs.rs/whisper-rs-sys" +readme = false +license = "Unlicense" +repository = "https://codeberg.org/tazz4843/whisper-rs" + +[features] +coreml = [] +cuda = [] +force-debug = [] +hipblas = [] +intel-sycl = [] +metal = [] +openblas = [] +openmp = [] +vulkan = [] + +[lib] +name = "whisper_rs_sys" +path = "src/lib.rs" + +[build-dependencies.bindgen] +version = "0.71" + +[build-dependencies.cfg-if] +version = "1" + +[build-dependencies.cmake] +version = "0.1" + +[build-dependencies.fs_extra] +version = "1.3" diff --git a/vendor/whisper-rs-sys/build.rs b/vendor/whisper-rs-sys/build.rs new file mode 100644 index 0000000..b7d5bf7 --- /dev/null +++ b/vendor/whisper-rs-sys/build.rs @@ -0,0 +1,377 @@ +#![allow(clippy::uninlined_format_args)] + +extern crate bindgen; + +use cmake::Config; +use std::env; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; + +fn main() { + let target = env::var("TARGET").unwrap(); + // Link C++ standard library + if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) { + println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib); + } + // Link macOS Accelerate framework for matrix calculations + if target.contains("apple") { + println!("cargo:rustc-link-lib=framework=Accelerate"); + #[cfg(feature = "coreml")] + { + 
println!("cargo:rustc-link-lib=framework=Foundation"); + println!("cargo:rustc-link-lib=framework=CoreML"); + } + #[cfg(feature = "metal")] + { + println!("cargo:rustc-link-lib=framework=Foundation"); + println!("cargo:rustc-link-lib=framework=Metal"); + println!("cargo:rustc-link-lib=framework=MetalKit"); + } + } + + #[cfg(feature = "coreml")] + println!("cargo:rustc-link-lib=static=whisper.coreml"); + + #[cfg(feature = "openblas")] + { + if let Ok(openblas_path) = env::var("OPENBLAS_PATH") { + println!( + "cargo::rustc-link-search={}", + PathBuf::from(openblas_path).join("lib").display() + ); + } + if cfg!(windows) { + println!("cargo:rustc-link-lib=libopenblas"); + } else { + println!("cargo:rustc-link-lib=openblas"); + } + } + #[cfg(feature = "cuda")] + { + println!("cargo:rustc-link-lib=cublas"); + println!("cargo:rustc-link-lib=cudart"); + println!("cargo:rustc-link-lib=cublasLt"); + println!("cargo:rustc-link-lib=cuda"); + cfg_if::cfg_if! { + if #[cfg(target_os = "windows")] { + let cuda_path = PathBuf::from(env::var("CUDA_PATH").unwrap()).join("lib/x64"); + println!("cargo:rustc-link-search={}", cuda_path.display()); + } else { + println!("cargo:rustc-link-lib=culibos"); + println!("cargo:rustc-link-search=/usr/local/cuda/lib64"); + println!("cargo:rustc-link-search=/usr/local/cuda/lib64/stubs"); + println!("cargo:rustc-link-search=/opt/cuda/lib64"); + println!("cargo:rustc-link-search=/opt/cuda/lib64/stubs"); + } + } + } + #[cfg(feature = "hipblas")] + { + println!("cargo:rustc-link-lib=hipblas"); + println!("cargo:rustc-link-lib=rocblas"); + println!("cargo:rustc-link-lib=amdhip64"); + + cfg_if::cfg_if! 
{ + if #[cfg(target_os = "windows")] { + panic!("Due to a problem with the last revision of the ROCm 5.7 library, it is not possible to compile the library for the windows environment.\nSee https://github.com/ggerganov/whisper.cpp/issues/2202 for more details.") + } else { + println!("cargo:rerun-if-env-changed=HIP_PATH"); + + let hip_path = match env::var("HIP_PATH") { + Ok(path) =>PathBuf::from(path), + Err(_) => PathBuf::from("/opt/rocm"), + }; + let hip_lib_path = hip_path.join("lib"); + + println!("cargo:rustc-link-search={}",hip_lib_path.display()); + } + } + } + + #[cfg(feature = "openmp")] + { + if target.contains("gnu") { + println!("cargo:rustc-link-lib=gomp"); + } else if target.contains("apple") { + println!("cargo:rustc-link-lib=omp"); + println!("cargo:rustc-link-search=/opt/homebrew/opt/libomp/lib"); + } + } + + println!("cargo:rerun-if-changed=wrapper.h"); + println!("cargo:rerun-if-env-changed=WHISPER_FORCE_GENERATE_BINDINGS"); + + let out = PathBuf::from(env::var("OUT_DIR").unwrap()); + let whisper_root = out.join("whisper.cpp/"); + + if !whisper_root.exists() { + std::fs::create_dir_all(&whisper_root).unwrap(); + fs_extra::dir::copy("./whisper.cpp", &out, &Default::default()).unwrap_or_else(|e| { + panic!( + "Failed to copy whisper sources into {}: {}", + whisper_root.display(), + e + ) + }); + } + + if env::var("WHISPER_FORCE_GENERATE_BINDINGS").is_ok() { + let bindings = bindgen::Builder::default().header("wrapper.h"); + + #[cfg(feature = "metal")] + { + bindings = bindings.header("whisper.cpp/ggml/include/ggml-metal.h"); + } + #[cfg(feature = "vulkan")] + { + bindings = bindings + .header("whisper.cpp/ggml/include/ggml-vulkan.h") + .clang_arg("-DGGML_USE_VULKAN=1"); + } + + let bindings = bindings + .clang_arg("-I./whisper.cpp/") + .clang_arg("-I./whisper.cpp/include") + .clang_arg("-I./whisper.cpp/ggml/include") + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .generate(); + + match bindings { + Ok(b) => { + let out_path = 
PathBuf::from(env::var("OUT_DIR").unwrap()); + b.write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings!"); + } + Err(e) => { + println!("cargo:warning=Unable to generate bindings: {}", e); + println!("cargo:warning=Using bundled bindings.rs, which may be out of date"); + // copy src/bindings.rs to OUT_DIR + std::fs::copy("src/bindings.rs", out.join("bindings.rs")) + .expect("Unable to copy bindings.rs"); + } + } + } else { + let _: u64 = std::fs::copy("src/bindings.rs", out.join("bindings.rs")) + .expect("Failed to copy bindings.rs"); + }; + + // stop if we're on docs.rs + if env::var("DOCS_RS").is_ok() { + return; + } + + let mut config = Config::new(&whisper_root); + + config + .profile("Release") + .define("BUILD_SHARED_LIBS", "OFF") + .define("WHISPER_ALL_WARNINGS", "OFF") + .define("WHISPER_ALL_WARNINGS_3RD_PARTY", "OFF") + .define("WHISPER_BUILD_TESTS", "OFF") + .define("WHISPER_BUILD_EXAMPLES", "OFF") + .very_verbose(true) + .pic(true); + + if cfg!(target_os = "windows") { + config.cxxflag("/utf-8"); + println!("cargo:rustc-link-lib=advapi32"); + } + + if cfg!(feature = "coreml") { + config.define("WHISPER_COREML", "ON"); + config.define("WHISPER_COREML_ALLOW_FALLBACK", "1"); + } + + if cfg!(feature = "cuda") { + config.define("GGML_CUDA", "ON"); + config.define("CMAKE_POSITION_INDEPENDENT_CODE", "ON"); + config.define("CMAKE_CUDA_FLAGS", "-Xcompiler=-fPIC"); + } + + if cfg!(feature = "hipblas") { + config.define("GGML_HIP", "ON"); + config.define("CMAKE_C_COMPILER", "hipcc"); + config.define("CMAKE_CXX_COMPILER", "hipcc"); + println!("cargo:rerun-if-env-changed=AMDGPU_TARGETS"); + if let Ok(gpu_targets) = env::var("AMDGPU_TARGETS") { + config.define("AMDGPU_TARGETS", gpu_targets); + } + } + + if cfg!(feature = "vulkan") { + config.define("GGML_VULKAN", "ON"); + if cfg!(windows) { + println!("cargo:rerun-if-env-changed=VULKAN_SDK"); + println!("cargo:rustc-link-lib=vulkan-1"); + let vulkan_path = match env::var("VULKAN_SDK") { + 
Ok(path) => PathBuf::from(path), + Err(_) => panic!( + "Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set" + ), + }; + let vulkan_lib_path = vulkan_path.join("Lib"); + println!("cargo:rustc-link-search={}", vulkan_lib_path.display()); + } else if cfg!(target_os = "macos") { + println!("cargo:rerun-if-env-changed=VULKAN_SDK"); + println!("cargo:rustc-link-lib=vulkan"); + let vulkan_path = match env::var("VULKAN_SDK") { + Ok(path) => PathBuf::from(path), + Err(_) => panic!( + "Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set" + ), + }; + let vulkan_lib_path = vulkan_path.join("lib"); + println!("cargo:rustc-link-search={}", vulkan_lib_path.display()); + } else { + println!("cargo:rustc-link-lib=vulkan"); + } + } + + if cfg!(feature = "openblas") { + config.define("GGML_BLAS", "ON"); + config.define("GGML_BLAS_VENDOR", "OpenBLAS"); + if env::var("BLAS_INCLUDE_DIRS").is_err() { + panic!("BLAS_INCLUDE_DIRS environment variable must be set when using OpenBLAS"); + } + config.define("BLAS_INCLUDE_DIRS", env::var("BLAS_INCLUDE_DIRS").unwrap()); + println!("cargo:rerun-if-env-changed=BLAS_INCLUDE_DIRS"); + } + + if cfg!(feature = "metal") { + config.define("GGML_METAL", "ON"); + config.define("GGML_METAL_NDEBUG", "ON"); + config.define("GGML_METAL_EMBED_LIBRARY", "ON"); + } else { + // Metal is enabled by default, so we need to explicitly disable it + config.define("GGML_METAL", "OFF"); + } + + if cfg!(debug_assertions) || cfg!(feature = "force-debug") { + // debug builds are too slow to even remotely be usable, + // so we build with optimizations even in debug mode + config.define("CMAKE_BUILD_TYPE", "RelWithDebInfo"); + config.cxxflag("-DWHISPER_DEBUG"); + } else { + // we're in release mode, explicitly set to release mode + // see also https://codeberg.org/tazz4843/whisper-rs/issues/226 + config.define("CMAKE_BUILD_TYPE", "Release"); + } + + // Allow passing any WHISPER or CMAKE compile flags + for (key, value) in 
env::vars() { + let is_whisper_flag = + key.starts_with("WHISPER_") && key != "WHISPER_DONT_GENERATE_BINDINGS"; + let is_cmake_flag = key.starts_with("CMAKE_"); + if is_whisper_flag || is_cmake_flag { + config.define(&key, &value); + } + } + + if cfg!(not(feature = "openmp")) { + config.define("GGML_OPENMP", "OFF"); + } + + if cfg!(feature = "intel-sycl") { + config.define("BUILD_SHARED_LIBS", "ON"); + config.define("GGML_SYCL", "ON"); + config.define("GGML_SYCL_TARGET", "INTEL"); + config.define("CMAKE_C_COMPILER", "icx"); + config.define("CMAKE_CXX_COMPILER", "icpx"); + } + + let destination = config.build(); + + add_link_search_path(&out.join("build")).unwrap(); + + println!("cargo:rustc-link-search=native={}", destination.display()); + if cfg!(feature = "intel-sycl") { + println!("cargo:rustc-link-lib=whisper"); + println!("cargo:rustc-link-lib=ggml"); + println!("cargo:rustc-link-lib=ggml-base"); + println!("cargo:rustc-link-lib=ggml-cpu"); + } else { + println!("cargo:rustc-link-lib=static=whisper"); + println!("cargo:rustc-link-lib=static=ggml"); + println!("cargo:rustc-link-lib=static=ggml-base"); + println!("cargo:rustc-link-lib=static=ggml-cpu"); + } + if cfg!(target_os = "macos") || cfg!(feature = "openblas") { + println!("cargo:rustc-link-lib=static=ggml-blas"); + } + if cfg!(feature = "vulkan") { + if cfg!(feature = "intel-sycl") { + println!("cargo:rustc-link-lib=ggml-vulkan"); + } else { + println!("cargo:rustc-link-lib=static=ggml-vulkan"); + } + } + + if cfg!(feature = "hipblas") { + println!("cargo:rustc-link-lib=static=ggml-hip"); + } + + if cfg!(feature = "metal") { + println!("cargo:rustc-link-lib=static=ggml-metal"); + } + + if cfg!(feature = "cuda") { + println!("cargo:rustc-link-lib=static=ggml-cuda"); + } + + if cfg!(feature = "openblas") { + println!("cargo:rustc-link-lib=static=ggml-blas"); + } + + if cfg!(feature = "intel-sycl") { + println!("cargo:rustc-link-lib=ggml-sycl"); + } + + println!( + "cargo:WHISPER_CPP_VERSION={}", + 
get_whisper_cpp_version(&whisper_root) + .expect("Failed to read whisper.cpp CMake config") + .expect("Could not find whisper.cpp version declaration"), + ); + + // for whatever reason this file is generated during build and triggers cargo complaining + _ = std::fs::remove_file("bindings/javascript/package.json"); +} + +// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462 +fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> { + if target.contains("msvc") { + None + } else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") { + Some("c++") + } else if target.contains("android") { + Some("c++_shared") + } else { + Some("stdc++") + } +} + +fn add_link_search_path(dir: &std::path::Path) -> std::io::Result<()> { + if dir.is_dir() { + println!("cargo:rustc-link-search={}", dir.display()); + for entry in std::fs::read_dir(dir)? { + add_link_search_path(&entry?.path())?; + } + } + Ok(()) +} + +fn get_whisper_cpp_version(whisper_root: &std::path::Path) -> std::io::Result> { + let cmake_lists = BufReader::new(File::open(whisper_root.join("CMakeLists.txt"))?); + + for line in cmake_lists.lines() { + let line = line?; + + if let Some(suffix) = line.strip_prefix(r#"project("whisper.cpp" VERSION "#) { + let whisper_cpp_version = suffix.trim_end_matches(')'); + return Ok(Some(whisper_cpp_version.into())); + } + } + + Ok(None) +} diff --git a/vendor/whisper-rs-sys/src/bindings.rs b/vendor/whisper-rs-sys/src/bindings.rs new file mode 100644 index 0000000..9f4be14 --- /dev/null +++ b/vendor/whisper-rs-sys/src/bindings.rs @@ -0,0 +1,5698 @@ +/* automatically generated by rust-bindgen 0.71.1 */ + +#[repr(C)] +#[derive(Copy, Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct __BindgenBitfieldUnit { + storage: Storage, +} +impl __BindgenBitfieldUnit { + #[inline] + pub const fn new(storage: Storage) -> Self { + Self { storage } + } +} +impl __BindgenBitfieldUnit 
+where + Storage: AsRef<[u8]> + AsMut<[u8]>, +{ + #[inline] + fn extract_bit(byte: u8, index: usize) -> bool { + let bit_index = if cfg!(target_endian = "big") { + 7 - (index % 8) + } else { + index % 8 + }; + let mask = 1 << bit_index; + byte & mask == mask + } + #[inline] + pub fn get_bit(&self, index: usize) -> bool { + debug_assert!(index / 8 < self.storage.as_ref().len()); + let byte_index = index / 8; + let byte = self.storage.as_ref()[byte_index]; + Self::extract_bit(byte, index) + } + #[inline] + pub unsafe fn raw_get_bit(this: *const Self, index: usize) -> bool { + debug_assert!(index / 8 < core::mem::size_of::()); + let byte_index = index / 8; + let byte = *(core::ptr::addr_of!((*this).storage) as *const u8).offset(byte_index as isize); + Self::extract_bit(byte, index) + } + #[inline] + fn change_bit(byte: u8, index: usize, val: bool) -> u8 { + let bit_index = if cfg!(target_endian = "big") { + 7 - (index % 8) + } else { + index % 8 + }; + let mask = 1 << bit_index; + if val { + byte | mask + } else { + byte & !mask + } + } + #[inline] + pub fn set_bit(&mut self, index: usize, val: bool) { + debug_assert!(index / 8 < self.storage.as_ref().len()); + let byte_index = index / 8; + let byte = &mut self.storage.as_mut()[byte_index]; + *byte = Self::change_bit(*byte, index, val); + } + #[inline] + pub unsafe fn raw_set_bit(this: *mut Self, index: usize, val: bool) { + debug_assert!(index / 8 < core::mem::size_of::()); + let byte_index = index / 8; + let byte = + (core::ptr::addr_of_mut!((*this).storage) as *mut u8).offset(byte_index as isize); + *byte = Self::change_bit(*byte, index, val); + } + #[inline] + pub fn get(&self, bit_offset: usize, bit_width: u8) -> u64 { + debug_assert!(bit_width <= 64); + debug_assert!(bit_offset / 8 < self.storage.as_ref().len()); + debug_assert!((bit_offset + (bit_width as usize)) / 8 <= self.storage.as_ref().len()); + let mut val = 0; + for i in 0..(bit_width as usize) { + if self.get_bit(i + bit_offset) { + let index = if 
cfg!(target_endian = "big") { + bit_width as usize - 1 - i + } else { + i + }; + val |= 1 << index; + } + } + val + } + #[inline] + pub unsafe fn raw_get(this: *const Self, bit_offset: usize, bit_width: u8) -> u64 { + debug_assert!(bit_width <= 64); + debug_assert!(bit_offset / 8 < core::mem::size_of::()); + debug_assert!((bit_offset + (bit_width as usize)) / 8 <= core::mem::size_of::()); + let mut val = 0; + for i in 0..(bit_width as usize) { + if Self::raw_get_bit(this, i + bit_offset) { + let index = if cfg!(target_endian = "big") { + bit_width as usize - 1 - i + } else { + i + }; + val |= 1 << index; + } + } + val + } + #[inline] + pub fn set(&mut self, bit_offset: usize, bit_width: u8, val: u64) { + debug_assert!(bit_width <= 64); + debug_assert!(bit_offset / 8 < self.storage.as_ref().len()); + debug_assert!((bit_offset + (bit_width as usize)) / 8 <= self.storage.as_ref().len()); + for i in 0..(bit_width as usize) { + let mask = 1 << i; + let val_bit_is_set = val & mask == mask; + let index = if cfg!(target_endian = "big") { + bit_width as usize - 1 - i + } else { + i + }; + self.set_bit(index + bit_offset, val_bit_is_set); + } + } + #[inline] + pub unsafe fn raw_set(this: *mut Self, bit_offset: usize, bit_width: u8, val: u64) { + debug_assert!(bit_width <= 64); + debug_assert!(bit_offset / 8 < core::mem::size_of::()); + debug_assert!((bit_offset + (bit_width as usize)) / 8 <= core::mem::size_of::()); + for i in 0..(bit_width as usize) { + let mask = 1 << i; + let val_bit_is_set = val & mask == mask; + let index = if cfg!(target_endian = "big") { + bit_width as usize - 1 - i + } else { + i + }; + Self::raw_set_bit(this, index + bit_offset, val_bit_is_set); + } + } +} +#[derive(PartialEq, Copy, Clone, Hash, Debug, Default)] +#[repr(C)] +pub struct __BindgenComplex { + pub re: T, + pub im: T, +} +pub const __bool_true_false_are_defined: u32 = 1; +pub const true_: u32 = 1; +pub const false_: u32 = 0; +pub const _STDINT_H: u32 = 1; +pub const _FEATURES_H: u32 = 1; 
+pub const _DEFAULT_SOURCE: u32 = 1; +pub const __GLIBC_USE_ISOC2Y: u32 = 0; +pub const __GLIBC_USE_ISOC23: u32 = 0; +pub const __USE_ISOC11: u32 = 1; +pub const __USE_ISOC99: u32 = 1; +pub const __USE_ISOC95: u32 = 1; +pub const __USE_POSIX_IMPLICITLY: u32 = 1; +pub const _POSIX_SOURCE: u32 = 1; +pub const _POSIX_C_SOURCE: u32 = 200809; +pub const __USE_POSIX: u32 = 1; +pub const __USE_POSIX2: u32 = 1; +pub const __USE_POSIX199309: u32 = 1; +pub const __USE_POSIX199506: u32 = 1; +pub const __USE_XOPEN2K: u32 = 1; +pub const __USE_XOPEN2K8: u32 = 1; +pub const _ATFILE_SOURCE: u32 = 1; +pub const __WORDSIZE: u32 = 64; +pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; +pub const __SYSCALL_WORDSIZE: u32 = 64; +pub const __TIMESIZE: u32 = 64; +pub const __USE_TIME_BITS64: u32 = 1; +pub const __USE_MISC: u32 = 1; +pub const __USE_ATFILE: u32 = 1; +pub const __USE_FORTIFY_LEVEL: u32 = 0; +pub const __GLIBC_USE_DEPRECATED_GETS: u32 = 0; +pub const __GLIBC_USE_DEPRECATED_SCANF: u32 = 0; +pub const __GLIBC_USE_C23_STRTOL: u32 = 0; +pub const _STDC_PREDEF_H: u32 = 1; +pub const __STDC_IEC_559__: u32 = 1; +pub const __STDC_IEC_60559_BFP__: u32 = 201404; +pub const __STDC_IEC_559_COMPLEX__: u32 = 1; +pub const __STDC_IEC_60559_COMPLEX__: u32 = 201404; +pub const __STDC_ISO_10646__: u32 = 201706; +pub const __GNU_LIBRARY__: u32 = 6; +pub const __GLIBC__: u32 = 2; +pub const __GLIBC_MINOR__: u32 = 41; +pub const _SYS_CDEFS_H: u32 = 1; +pub const __glibc_c99_flexarr_available: u32 = 1; +pub const __LDOUBLE_REDIRECTS_TO_FLOAT128_ABI: u32 = 0; +pub const __HAVE_GENERIC_SELECTION: u32 = 1; +pub const __GLIBC_USE_LIB_EXT2: u32 = 0; +pub const __GLIBC_USE_IEC_60559_BFP_EXT: u32 = 0; +pub const __GLIBC_USE_IEC_60559_BFP_EXT_C23: u32 = 0; +pub const __GLIBC_USE_IEC_60559_EXT: u32 = 0; +pub const __GLIBC_USE_IEC_60559_FUNCS_EXT: u32 = 0; +pub const __GLIBC_USE_IEC_60559_FUNCS_EXT_C23: u32 = 0; +pub const __GLIBC_USE_IEC_60559_TYPES_EXT: u32 = 0; +pub const _BITS_TYPES_H: u32 = 1; +pub 
const _BITS_TYPESIZES_H: u32 = 1; +pub const __OFF_T_MATCHES_OFF64_T: u32 = 1; +pub const __INO_T_MATCHES_INO64_T: u32 = 1; +pub const __RLIM_T_MATCHES_RLIM64_T: u32 = 1; +pub const __STATFS_MATCHES_STATFS64: u32 = 1; +pub const __KERNEL_OLD_TIMEVAL_MATCHES_TIMEVAL64: u32 = 1; +pub const __FD_SETSIZE: u32 = 1024; +pub const _BITS_TIME64_H: u32 = 1; +pub const _BITS_WCHAR_H: u32 = 1; +pub const _BITS_STDINT_INTN_H: u32 = 1; +pub const _BITS_STDINT_UINTN_H: u32 = 1; +pub const _BITS_STDINT_LEAST_H: u32 = 1; +pub const INT8_MIN: i32 = -128; +pub const INT16_MIN: i32 = -32768; +pub const INT32_MIN: i32 = -2147483648; +pub const INT8_MAX: u32 = 127; +pub const INT16_MAX: u32 = 32767; +pub const INT32_MAX: u32 = 2147483647; +pub const UINT8_MAX: u32 = 255; +pub const UINT16_MAX: u32 = 65535; +pub const UINT32_MAX: u32 = 4294967295; +pub const INT_LEAST8_MIN: i32 = -128; +pub const INT_LEAST16_MIN: i32 = -32768; +pub const INT_LEAST32_MIN: i32 = -2147483648; +pub const INT_LEAST8_MAX: u32 = 127; +pub const INT_LEAST16_MAX: u32 = 32767; +pub const INT_LEAST32_MAX: u32 = 2147483647; +pub const UINT_LEAST8_MAX: u32 = 255; +pub const UINT_LEAST16_MAX: u32 = 65535; +pub const UINT_LEAST32_MAX: u32 = 4294967295; +pub const INT_FAST8_MIN: i32 = -128; +pub const INT_FAST16_MIN: i64 = -9223372036854775808; +pub const INT_FAST32_MIN: i64 = -9223372036854775808; +pub const INT_FAST8_MAX: u32 = 127; +pub const INT_FAST16_MAX: u64 = 9223372036854775807; +pub const INT_FAST32_MAX: u64 = 9223372036854775807; +pub const UINT_FAST8_MAX: u32 = 255; +pub const UINT_FAST16_MAX: i32 = -1; +pub const UINT_FAST32_MAX: i32 = -1; +pub const INTPTR_MIN: i64 = -9223372036854775808; +pub const INTPTR_MAX: u64 = 9223372036854775807; +pub const UINTPTR_MAX: i32 = -1; +pub const PTRDIFF_MIN: i64 = -9223372036854775808; +pub const PTRDIFF_MAX: u64 = 9223372036854775807; +pub const SIG_ATOMIC_MIN: i32 = -2147483648; +pub const SIG_ATOMIC_MAX: u32 = 2147483647; +pub const SIZE_MAX: i32 = -1; +pub const 
WINT_MIN: u32 = 0; +pub const WINT_MAX: u32 = 4294967295; +pub const _STDIO_H: u32 = 1; +pub const _____fpos_t_defined: u32 = 1; +pub const ____mbstate_t_defined: u32 = 1; +pub const _____fpos64_t_defined: u32 = 1; +pub const ____FILE_defined: u32 = 1; +pub const __FILE_defined: u32 = 1; +pub const __struct_FILE_defined: u32 = 1; +pub const _IO_EOF_SEEN: u32 = 16; +pub const _IO_ERR_SEEN: u32 = 32; +pub const _IO_USER_LOCK: u32 = 32768; +pub const __cookie_io_functions_t_defined: u32 = 1; +pub const _IOFBF: u32 = 0; +pub const _IOLBF: u32 = 1; +pub const _IONBF: u32 = 2; +pub const BUFSIZ: u32 = 8192; +pub const EOF: i32 = -1; +pub const SEEK_SET: u32 = 0; +pub const SEEK_CUR: u32 = 1; +pub const SEEK_END: u32 = 2; +pub const P_tmpdir: &[u8; 5] = b"/tmp\0"; +pub const L_tmpnam: u32 = 20; +pub const TMP_MAX: u32 = 238328; +pub const _BITS_STDIO_LIM_H: u32 = 1; +pub const FILENAME_MAX: u32 = 4096; +pub const L_ctermid: u32 = 9; +pub const FOPEN_MAX: u32 = 16; +pub const __HAVE_FLOAT128: u32 = 1; +pub const __HAVE_DISTINCT_FLOAT128: u32 = 1; +pub const __HAVE_FLOAT64X: u32 = 1; +pub const __HAVE_FLOAT64X_LONG_DOUBLE: u32 = 1; +pub const __HAVE_FLOAT16: u32 = 0; +pub const __HAVE_FLOAT32: u32 = 1; +pub const __HAVE_FLOAT64: u32 = 1; +pub const __HAVE_FLOAT32X: u32 = 1; +pub const __HAVE_FLOAT128X: u32 = 0; +pub const __HAVE_DISTINCT_FLOAT16: u32 = 0; +pub const __HAVE_DISTINCT_FLOAT32: u32 = 0; +pub const __HAVE_DISTINCT_FLOAT64: u32 = 0; +pub const __HAVE_DISTINCT_FLOAT32X: u32 = 0; +pub const __HAVE_DISTINCT_FLOAT64X: u32 = 0; +pub const __HAVE_DISTINCT_FLOAT128X: u32 = 0; +pub const __HAVE_FLOATN_NOT_TYPEDEF: u32 = 0; +pub const GGML_FILE_MAGIC: u32 = 1734831468; +pub const GGML_FILE_VERSION: u32 = 2; +pub const GGML_QNT_VERSION: u32 = 2; +pub const GGML_QNT_VERSION_FACTOR: u32 = 1000; +pub const GGML_MAX_DIMS: u32 = 4; +pub const GGML_MAX_PARAMS: u32 = 2048; +pub const GGML_MAX_SRC: u32 = 10; +pub const GGML_MAX_N_THREADS: u32 = 512; +pub const GGML_MAX_OP_PARAMS: 
u32 = 64; +pub const GGML_MAX_NAME: u32 = 64; +pub const GGML_DEFAULT_N_THREADS: u32 = 4; +pub const GGML_DEFAULT_GRAPH_SIZE: u32 = 2048; +pub const GGML_MEM_ALIGN: u32 = 16; +pub const GGML_EXIT_SUCCESS: u32 = 0; +pub const GGML_EXIT_ABORTED: u32 = 1; +pub const GGML_ROPE_TYPE_NEOX: u32 = 2; +pub const GGML_ROPE_TYPE_MROPE: u32 = 8; +pub const GGML_ROPE_TYPE_VISION: u32 = 24; +pub const GGML_MROPE_SECTIONS: u32 = 4; +pub const GGML_KQ_MASK_PAD: u32 = 64; +pub const GGML_N_TASKS_MAX: i32 = -1; +pub const WHISPER_SAMPLE_RATE: u32 = 16000; +pub const WHISPER_N_FFT: u32 = 400; +pub const WHISPER_HOP_LENGTH: u32 = 160; +pub const WHISPER_CHUNK_SIZE: u32 = 30; +pub type wchar_t = ::std::os::raw::c_int; +#[repr(C)] +#[repr(align(16))] +#[derive(Debug, Copy, Clone)] +pub struct max_align_t { + pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, + pub __bindgen_padding_0: u64, + pub __clang_max_align_nonce2: u128, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of max_align_t"][::std::mem::size_of::() - 32usize]; + ["Alignment of max_align_t"][::std::mem::align_of::() - 16usize]; + ["Offset of field: max_align_t::__clang_max_align_nonce1"] + [::std::mem::offset_of!(max_align_t, __clang_max_align_nonce1) - 0usize]; + ["Offset of field: max_align_t::__clang_max_align_nonce2"] + [::std::mem::offset_of!(max_align_t, __clang_max_align_nonce2) - 16usize]; +}; +pub type __u_char = ::std::os::raw::c_uchar; +pub type __u_short = ::std::os::raw::c_ushort; +pub type __u_int = ::std::os::raw::c_uint; +pub type __u_long = ::std::os::raw::c_ulong; +pub type __int8_t = ::std::os::raw::c_schar; +pub type __uint8_t = ::std::os::raw::c_uchar; +pub type __int16_t = ::std::os::raw::c_short; +pub type __uint16_t = ::std::os::raw::c_ushort; +pub type __int32_t = ::std::os::raw::c_int; +pub type __uint32_t = ::std::os::raw::c_uint; +pub type __int64_t = ::std::os::raw::c_long; +pub type __uint64_t = ::std::os::raw::c_ulong; +pub type 
__int_least8_t = __int8_t; +pub type __uint_least8_t = __uint8_t; +pub type __int_least16_t = __int16_t; +pub type __uint_least16_t = __uint16_t; +pub type __int_least32_t = __int32_t; +pub type __uint_least32_t = __uint32_t; +pub type __int_least64_t = __int64_t; +pub type __uint_least64_t = __uint64_t; +pub type __quad_t = ::std::os::raw::c_long; +pub type __u_quad_t = ::std::os::raw::c_ulong; +pub type __intmax_t = ::std::os::raw::c_long; +pub type __uintmax_t = ::std::os::raw::c_ulong; +pub type __dev_t = ::std::os::raw::c_ulong; +pub type __uid_t = ::std::os::raw::c_uint; +pub type __gid_t = ::std::os::raw::c_uint; +pub type __ino_t = ::std::os::raw::c_ulong; +pub type __ino64_t = ::std::os::raw::c_ulong; +pub type __mode_t = ::std::os::raw::c_uint; +pub type __nlink_t = ::std::os::raw::c_ulong; +pub type __off_t = ::std::os::raw::c_long; +pub type __off64_t = ::std::os::raw::c_long; +pub type __pid_t = ::std::os::raw::c_int; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct __fsid_t { + pub __val: [::std::os::raw::c_int; 2usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of __fsid_t"][::std::mem::size_of::<__fsid_t>() - 8usize]; + ["Alignment of __fsid_t"][::std::mem::align_of::<__fsid_t>() - 4usize]; + ["Offset of field: __fsid_t::__val"][::std::mem::offset_of!(__fsid_t, __val) - 0usize]; +}; +pub type __clock_t = ::std::os::raw::c_long; +pub type __rlim_t = ::std::os::raw::c_ulong; +pub type __rlim64_t = ::std::os::raw::c_ulong; +pub type __id_t = ::std::os::raw::c_uint; +pub type __time_t = ::std::os::raw::c_long; +pub type __useconds_t = ::std::os::raw::c_uint; +pub type __suseconds_t = ::std::os::raw::c_long; +pub type __suseconds64_t = ::std::os::raw::c_long; +pub type __daddr_t = ::std::os::raw::c_int; +pub type __key_t = ::std::os::raw::c_int; +pub type __clockid_t = ::std::os::raw::c_int; +pub type __timer_t = *mut ::std::os::raw::c_void; +pub type __blksize_t = ::std::os::raw::c_long; +pub type 
__blkcnt_t = ::std::os::raw::c_long; +pub type __blkcnt64_t = ::std::os::raw::c_long; +pub type __fsblkcnt_t = ::std::os::raw::c_ulong; +pub type __fsblkcnt64_t = ::std::os::raw::c_ulong; +pub type __fsfilcnt_t = ::std::os::raw::c_ulong; +pub type __fsfilcnt64_t = ::std::os::raw::c_ulong; +pub type __fsword_t = ::std::os::raw::c_long; +pub type __ssize_t = ::std::os::raw::c_long; +pub type __syscall_slong_t = ::std::os::raw::c_long; +pub type __syscall_ulong_t = ::std::os::raw::c_ulong; +pub type __loff_t = __off64_t; +pub type __caddr_t = *mut ::std::os::raw::c_char; +pub type __intptr_t = ::std::os::raw::c_long; +pub type __socklen_t = ::std::os::raw::c_uint; +pub type __sig_atomic_t = ::std::os::raw::c_int; +pub type int_least8_t = __int_least8_t; +pub type int_least16_t = __int_least16_t; +pub type int_least32_t = __int_least32_t; +pub type int_least64_t = __int_least64_t; +pub type uint_least8_t = __uint_least8_t; +pub type uint_least16_t = __uint_least16_t; +pub type uint_least32_t = __uint_least32_t; +pub type uint_least64_t = __uint_least64_t; +pub type int_fast8_t = ::std::os::raw::c_schar; +pub type int_fast16_t = ::std::os::raw::c_long; +pub type int_fast32_t = ::std::os::raw::c_long; +pub type int_fast64_t = ::std::os::raw::c_long; +pub type uint_fast8_t = ::std::os::raw::c_uchar; +pub type uint_fast16_t = ::std::os::raw::c_ulong; +pub type uint_fast32_t = ::std::os::raw::c_ulong; +pub type uint_fast64_t = ::std::os::raw::c_ulong; +pub type intmax_t = __intmax_t; +pub type uintmax_t = __uintmax_t; +pub type __gnuc_va_list = __builtin_va_list; +#[repr(C)] +#[derive(Copy, Clone)] +pub struct __mbstate_t { + pub __count: ::std::os::raw::c_int, + pub __value: __mbstate_t__bindgen_ty_1, +} +#[repr(C)] +#[derive(Copy, Clone)] +pub union __mbstate_t__bindgen_ty_1 { + pub __wch: ::std::os::raw::c_uint, + pub __wchb: [::std::os::raw::c_char; 4usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of 
__mbstate_t__bindgen_ty_1"] + [::std::mem::size_of::<__mbstate_t__bindgen_ty_1>() - 4usize]; + ["Alignment of __mbstate_t__bindgen_ty_1"] + [::std::mem::align_of::<__mbstate_t__bindgen_ty_1>() - 4usize]; + ["Offset of field: __mbstate_t__bindgen_ty_1::__wch"] + [::std::mem::offset_of!(__mbstate_t__bindgen_ty_1, __wch) - 0usize]; + ["Offset of field: __mbstate_t__bindgen_ty_1::__wchb"] + [::std::mem::offset_of!(__mbstate_t__bindgen_ty_1, __wchb) - 0usize]; +}; +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of __mbstate_t"][::std::mem::size_of::<__mbstate_t>() - 8usize]; + ["Alignment of __mbstate_t"][::std::mem::align_of::<__mbstate_t>() - 4usize]; + ["Offset of field: __mbstate_t::__count"] + [::std::mem::offset_of!(__mbstate_t, __count) - 0usize]; + ["Offset of field: __mbstate_t::__value"] + [::std::mem::offset_of!(__mbstate_t, __value) - 4usize]; +}; +#[repr(C)] +#[derive(Copy, Clone)] +pub struct _G_fpos_t { + pub __pos: __off_t, + pub __state: __mbstate_t, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _G_fpos_t"][::std::mem::size_of::<_G_fpos_t>() - 16usize]; + ["Alignment of _G_fpos_t"][::std::mem::align_of::<_G_fpos_t>() - 8usize]; + ["Offset of field: _G_fpos_t::__pos"][::std::mem::offset_of!(_G_fpos_t, __pos) - 0usize]; + ["Offset of field: _G_fpos_t::__state"][::std::mem::offset_of!(_G_fpos_t, __state) - 8usize]; +}; +pub type __fpos_t = _G_fpos_t; +#[repr(C)] +#[derive(Copy, Clone)] +pub struct _G_fpos64_t { + pub __pos: __off64_t, + pub __state: __mbstate_t, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _G_fpos64_t"][::std::mem::size_of::<_G_fpos64_t>() - 16usize]; + ["Alignment of _G_fpos64_t"][::std::mem::align_of::<_G_fpos64_t>() - 8usize]; + ["Offset of field: _G_fpos64_t::__pos"][::std::mem::offset_of!(_G_fpos64_t, __pos) - 0usize]; + ["Offset of field: _G_fpos64_t::__state"] + [::std::mem::offset_of!(_G_fpos64_t, 
__state) - 8usize]; +}; +pub type __fpos64_t = _G_fpos64_t; +pub type __FILE = _IO_FILE; +pub type FILE = _IO_FILE; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _IO_marker { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _IO_codecvt { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _IO_wide_data { + _unused: [u8; 0], +} +pub type _IO_lock_t = ::std::os::raw::c_void; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _IO_FILE { + pub _flags: ::std::os::raw::c_int, + pub _IO_read_ptr: *mut ::std::os::raw::c_char, + pub _IO_read_end: *mut ::std::os::raw::c_char, + pub _IO_read_base: *mut ::std::os::raw::c_char, + pub _IO_write_base: *mut ::std::os::raw::c_char, + pub _IO_write_ptr: *mut ::std::os::raw::c_char, + pub _IO_write_end: *mut ::std::os::raw::c_char, + pub _IO_buf_base: *mut ::std::os::raw::c_char, + pub _IO_buf_end: *mut ::std::os::raw::c_char, + pub _IO_save_base: *mut ::std::os::raw::c_char, + pub _IO_backup_base: *mut ::std::os::raw::c_char, + pub _IO_save_end: *mut ::std::os::raw::c_char, + pub _markers: *mut _IO_marker, + pub _chain: *mut _IO_FILE, + pub _fileno: ::std::os::raw::c_int, + pub _bitfield_align_1: [u32; 0], + pub _bitfield_1: __BindgenBitfieldUnit<[u8; 3usize]>, + pub _short_backupbuf: [::std::os::raw::c_char; 1usize], + pub _old_offset: __off_t, + pub _cur_column: ::std::os::raw::c_ushort, + pub _vtable_offset: ::std::os::raw::c_schar, + pub _shortbuf: [::std::os::raw::c_char; 1usize], + pub _lock: *mut _IO_lock_t, + pub _offset: __off64_t, + pub _codecvt: *mut _IO_codecvt, + pub _wide_data: *mut _IO_wide_data, + pub _freeres_list: *mut _IO_FILE, + pub _freeres_buf: *mut ::std::os::raw::c_void, + pub _prevchain: *mut *mut _IO_FILE, + pub _mode: ::std::os::raw::c_int, + pub _unused2: [::std::os::raw::c_char; 20usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _IO_FILE"][::std::mem::size_of::<_IO_FILE>() - 216usize]; 
+ ["Alignment of _IO_FILE"][::std::mem::align_of::<_IO_FILE>() - 8usize]; + ["Offset of field: _IO_FILE::_flags"][::std::mem::offset_of!(_IO_FILE, _flags) - 0usize]; + ["Offset of field: _IO_FILE::_IO_read_ptr"] + [::std::mem::offset_of!(_IO_FILE, _IO_read_ptr) - 8usize]; + ["Offset of field: _IO_FILE::_IO_read_end"] + [::std::mem::offset_of!(_IO_FILE, _IO_read_end) - 16usize]; + ["Offset of field: _IO_FILE::_IO_read_base"] + [::std::mem::offset_of!(_IO_FILE, _IO_read_base) - 24usize]; + ["Offset of field: _IO_FILE::_IO_write_base"] + [::std::mem::offset_of!(_IO_FILE, _IO_write_base) - 32usize]; + ["Offset of field: _IO_FILE::_IO_write_ptr"] + [::std::mem::offset_of!(_IO_FILE, _IO_write_ptr) - 40usize]; + ["Offset of field: _IO_FILE::_IO_write_end"] + [::std::mem::offset_of!(_IO_FILE, _IO_write_end) - 48usize]; + ["Offset of field: _IO_FILE::_IO_buf_base"] + [::std::mem::offset_of!(_IO_FILE, _IO_buf_base) - 56usize]; + ["Offset of field: _IO_FILE::_IO_buf_end"] + [::std::mem::offset_of!(_IO_FILE, _IO_buf_end) - 64usize]; + ["Offset of field: _IO_FILE::_IO_save_base"] + [::std::mem::offset_of!(_IO_FILE, _IO_save_base) - 72usize]; + ["Offset of field: _IO_FILE::_IO_backup_base"] + [::std::mem::offset_of!(_IO_FILE, _IO_backup_base) - 80usize]; + ["Offset of field: _IO_FILE::_IO_save_end"] + [::std::mem::offset_of!(_IO_FILE, _IO_save_end) - 88usize]; + ["Offset of field: _IO_FILE::_markers"][::std::mem::offset_of!(_IO_FILE, _markers) - 96usize]; + ["Offset of field: _IO_FILE::_chain"][::std::mem::offset_of!(_IO_FILE, _chain) - 104usize]; + ["Offset of field: _IO_FILE::_fileno"][::std::mem::offset_of!(_IO_FILE, _fileno) - 112usize]; + ["Offset of field: _IO_FILE::_short_backupbuf"] + [::std::mem::offset_of!(_IO_FILE, _short_backupbuf) - 119usize]; + ["Offset of field: _IO_FILE::_old_offset"] + [::std::mem::offset_of!(_IO_FILE, _old_offset) - 120usize]; + ["Offset of field: _IO_FILE::_cur_column"] + [::std::mem::offset_of!(_IO_FILE, _cur_column) - 128usize]; + ["Offset 
of field: _IO_FILE::_vtable_offset"] + [::std::mem::offset_of!(_IO_FILE, _vtable_offset) - 130usize]; + ["Offset of field: _IO_FILE::_shortbuf"] + [::std::mem::offset_of!(_IO_FILE, _shortbuf) - 131usize]; + ["Offset of field: _IO_FILE::_lock"][::std::mem::offset_of!(_IO_FILE, _lock) - 136usize]; + ["Offset of field: _IO_FILE::_offset"][::std::mem::offset_of!(_IO_FILE, _offset) - 144usize]; + ["Offset of field: _IO_FILE::_codecvt"][::std::mem::offset_of!(_IO_FILE, _codecvt) - 152usize]; + ["Offset of field: _IO_FILE::_wide_data"] + [::std::mem::offset_of!(_IO_FILE, _wide_data) - 160usize]; + ["Offset of field: _IO_FILE::_freeres_list"] + [::std::mem::offset_of!(_IO_FILE, _freeres_list) - 168usize]; + ["Offset of field: _IO_FILE::_freeres_buf"] + [::std::mem::offset_of!(_IO_FILE, _freeres_buf) - 176usize]; + ["Offset of field: _IO_FILE::_prevchain"] + [::std::mem::offset_of!(_IO_FILE, _prevchain) - 184usize]; + ["Offset of field: _IO_FILE::_mode"][::std::mem::offset_of!(_IO_FILE, _mode) - 192usize]; + ["Offset of field: _IO_FILE::_unused2"][::std::mem::offset_of!(_IO_FILE, _unused2) - 196usize]; +}; +impl _IO_FILE { + #[inline] + pub fn _flags2(&self) -> ::std::os::raw::c_int { + unsafe { ::std::mem::transmute(self._bitfield_1.get(0usize, 24u8) as u32) } + } + #[inline] + pub fn set__flags2(&mut self, val: ::std::os::raw::c_int) { + unsafe { + let val: u32 = ::std::mem::transmute(val); + self._bitfield_1.set(0usize, 24u8, val as u64) + } + } + #[inline] + pub unsafe fn _flags2_raw(this: *const Self) -> ::std::os::raw::c_int { + unsafe { + ::std::mem::transmute(<__BindgenBitfieldUnit<[u8; 3usize]>>::raw_get( + ::std::ptr::addr_of!((*this)._bitfield_1), + 0usize, + 24u8, + ) as u32) + } + } + #[inline] + pub unsafe fn set__flags2_raw(this: *mut Self, val: ::std::os::raw::c_int) { + unsafe { + let val: u32 = ::std::mem::transmute(val); + <__BindgenBitfieldUnit<[u8; 3usize]>>::raw_set( + ::std::ptr::addr_of_mut!((*this)._bitfield_1), + 0usize, + 24u8, + val as u64, + ) + 
} + } + #[inline] + pub fn new_bitfield_1(_flags2: ::std::os::raw::c_int) -> __BindgenBitfieldUnit<[u8; 3usize]> { + let mut __bindgen_bitfield_unit: __BindgenBitfieldUnit<[u8; 3usize]> = Default::default(); + __bindgen_bitfield_unit.set(0usize, 24u8, { + let _flags2: u32 = unsafe { ::std::mem::transmute(_flags2) }; + _flags2 as u64 + }); + __bindgen_bitfield_unit + } +} +pub type cookie_read_function_t = ::std::option::Option< + unsafe extern "C" fn( + __cookie: *mut ::std::os::raw::c_void, + __buf: *mut ::std::os::raw::c_char, + __nbytes: usize, + ) -> __ssize_t, +>; +pub type cookie_write_function_t = ::std::option::Option< + unsafe extern "C" fn( + __cookie: *mut ::std::os::raw::c_void, + __buf: *const ::std::os::raw::c_char, + __nbytes: usize, + ) -> __ssize_t, +>; +pub type cookie_seek_function_t = ::std::option::Option< + unsafe extern "C" fn( + __cookie: *mut ::std::os::raw::c_void, + __pos: *mut __off64_t, + __w: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int, +>; +pub type cookie_close_function_t = ::std::option::Option< + unsafe extern "C" fn(__cookie: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int, +>; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _IO_cookie_io_functions_t { + pub read: cookie_read_function_t, + pub write: cookie_write_function_t, + pub seek: cookie_seek_function_t, + pub close: cookie_close_function_t, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _IO_cookie_io_functions_t"] + [::std::mem::size_of::<_IO_cookie_io_functions_t>() - 32usize]; + ["Alignment of _IO_cookie_io_functions_t"] + [::std::mem::align_of::<_IO_cookie_io_functions_t>() - 8usize]; + ["Offset of field: _IO_cookie_io_functions_t::read"] + [::std::mem::offset_of!(_IO_cookie_io_functions_t, read) - 0usize]; + ["Offset of field: _IO_cookie_io_functions_t::write"] + [::std::mem::offset_of!(_IO_cookie_io_functions_t, write) - 8usize]; + ["Offset of field: _IO_cookie_io_functions_t::seek"] + 
[::std::mem::offset_of!(_IO_cookie_io_functions_t, seek) - 16usize]; + ["Offset of field: _IO_cookie_io_functions_t::close"] + [::std::mem::offset_of!(_IO_cookie_io_functions_t, close) - 24usize]; +}; +pub type cookie_io_functions_t = _IO_cookie_io_functions_t; +pub type va_list = __gnuc_va_list; +pub type off_t = __off_t; +pub type fpos_t = __fpos_t; +unsafe extern "C" { + pub static mut stdin: *mut FILE; +} +unsafe extern "C" { + pub static mut stdout: *mut FILE; +} +unsafe extern "C" { + pub static mut stderr: *mut FILE; +} +unsafe extern "C" { + pub fn remove(__filename: *const ::std::os::raw::c_char) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn rename( + __old: *const ::std::os::raw::c_char, + __new: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn renameat( + __oldfd: ::std::os::raw::c_int, + __old: *const ::std::os::raw::c_char, + __newfd: ::std::os::raw::c_int, + __new: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fclose(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn tmpfile() -> *mut FILE; +} +unsafe extern "C" { + pub fn tmpnam(arg1: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn tmpnam_r(__s: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn tempnam( + __dir: *const ::std::os::raw::c_char, + __pfx: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn fflush(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fflush_unlocked(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fopen( + __filename: *const ::std::os::raw::c_char, + __modes: *const ::std::os::raw::c_char, + ) -> *mut FILE; +} +unsafe extern "C" { + pub fn freopen( + __filename: *const ::std::os::raw::c_char, + __modes: *const ::std::os::raw::c_char, + __stream: 
*mut FILE, + ) -> *mut FILE; +} +unsafe extern "C" { + pub fn fdopen(__fd: ::std::os::raw::c_int, __modes: *const ::std::os::raw::c_char) + -> *mut FILE; +} +unsafe extern "C" { + pub fn fopencookie( + __magic_cookie: *mut ::std::os::raw::c_void, + __modes: *const ::std::os::raw::c_char, + __io_funcs: cookie_io_functions_t, + ) -> *mut FILE; +} +unsafe extern "C" { + pub fn fmemopen( + __s: *mut ::std::os::raw::c_void, + __len: usize, + __modes: *const ::std::os::raw::c_char, + ) -> *mut FILE; +} +unsafe extern "C" { + pub fn open_memstream( + __bufloc: *mut *mut ::std::os::raw::c_char, + __sizeloc: *mut usize, + ) -> *mut FILE; +} +unsafe extern "C" { + pub fn setbuf(__stream: *mut FILE, __buf: *mut ::std::os::raw::c_char); +} +unsafe extern "C" { + pub fn setvbuf( + __stream: *mut FILE, + __buf: *mut ::std::os::raw::c_char, + __modes: ::std::os::raw::c_int, + __n: usize, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn setbuffer(__stream: *mut FILE, __buf: *mut ::std::os::raw::c_char, __size: usize); +} +unsafe extern "C" { + pub fn setlinebuf(__stream: *mut FILE); +} +unsafe extern "C" { + pub fn fprintf( + __stream: *mut FILE, + __format: *const ::std::os::raw::c_char, + ... + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn printf(__format: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn sprintf( + __s: *mut ::std::os::raw::c_char, + __format: *const ::std::os::raw::c_char, + ... 
+ ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn vfprintf( + __s: *mut FILE, + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn vprintf( + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn vsprintf( + __s: *mut ::std::os::raw::c_char, + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn snprintf( + __s: *mut ::std::os::raw::c_char, + __maxlen: ::std::os::raw::c_ulong, + __format: *const ::std::os::raw::c_char, + ... + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn vsnprintf( + __s: *mut ::std::os::raw::c_char, + __maxlen: ::std::os::raw::c_ulong, + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn vasprintf( + __ptr: *mut *mut ::std::os::raw::c_char, + __f: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn __asprintf( + __ptr: *mut *mut ::std::os::raw::c_char, + __fmt: *const ::std::os::raw::c_char, + ... + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn asprintf( + __ptr: *mut *mut ::std::os::raw::c_char, + __fmt: *const ::std::os::raw::c_char, + ... + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn vdprintf( + __fd: ::std::os::raw::c_int, + __fmt: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn dprintf( + __fd: ::std::os::raw::c_int, + __fmt: *const ::std::os::raw::c_char, + ... + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fscanf( + __stream: *mut FILE, + __format: *const ::std::os::raw::c_char, + ... + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn scanf(__format: *const ::std::os::raw::c_char, ...) 
-> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn sscanf( + __s: *const ::std::os::raw::c_char, + __format: *const ::std::os::raw::c_char, + ... + ) -> ::std::os::raw::c_int; +} +pub type __cfloat128 = __BindgenComplex; +pub type _Float128 = u128; +pub type _Float32 = f32; +pub type _Float64 = f64; +pub type _Float32x = f64; +pub type _Float64x = u128; +unsafe extern "C" { + #[link_name = "\u{1}__isoc99_fscanf"] + pub fn fscanf1( + __stream: *mut FILE, + __format: *const ::std::os::raw::c_char, + ... + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + #[link_name = "\u{1}__isoc99_scanf"] + pub fn scanf1(__format: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + #[link_name = "\u{1}__isoc99_sscanf"] + pub fn sscanf1( + __s: *const ::std::os::raw::c_char, + __format: *const ::std::os::raw::c_char, + ... + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn vfscanf( + __s: *mut FILE, + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn vscanf( + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn vsscanf( + __s: *const ::std::os::raw::c_char, + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + #[link_name = "\u{1}__isoc99_vfscanf"] + pub fn vfscanf1( + __s: *mut FILE, + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + #[link_name = "\u{1}__isoc99_vscanf"] + pub fn vscanf1( + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + #[link_name = "\u{1}__isoc99_vsscanf"] + pub fn vsscanf1( + __s: *const ::std::os::raw::c_char, + __format: *const ::std::os::raw::c_char, + __arg: *mut __va_list_tag, + ) -> ::std::os::raw::c_int; +} 
+unsafe extern "C" { + pub fn fgetc(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn getc(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn getchar() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn getc_unlocked(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn getchar_unlocked() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fgetc_unlocked(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fputc(__c: ::std::os::raw::c_int, __stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn putc(__c: ::std::os::raw::c_int, __stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn putchar(__c: ::std::os::raw::c_int) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fputc_unlocked(__c: ::std::os::raw::c_int, __stream: *mut FILE) + -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn putc_unlocked(__c: ::std::os::raw::c_int, __stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn putchar_unlocked(__c: ::std::os::raw::c_int) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn getw(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn putw(__w: ::std::os::raw::c_int, __stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fgets( + __s: *mut ::std::os::raw::c_char, + __n: ::std::os::raw::c_int, + __stream: *mut FILE, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn __getdelim( + __lineptr: *mut *mut ::std::os::raw::c_char, + __n: *mut usize, + __delimiter: ::std::os::raw::c_int, + __stream: *mut FILE, + ) -> __ssize_t; +} +unsafe extern "C" { + pub fn getdelim( + __lineptr: *mut *mut ::std::os::raw::c_char, + __n: *mut usize, + __delimiter: ::std::os::raw::c_int, + __stream: *mut FILE, + ) -> __ssize_t; +} +unsafe extern "C" { + pub fn getline( + __lineptr: *mut *mut 
::std::os::raw::c_char, + __n: *mut usize, + __stream: *mut FILE, + ) -> __ssize_t; +} +unsafe extern "C" { + pub fn fputs(__s: *const ::std::os::raw::c_char, __stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn puts(__s: *const ::std::os::raw::c_char) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ungetc(__c: ::std::os::raw::c_int, __stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fread( + __ptr: *mut ::std::os::raw::c_void, + __size: ::std::os::raw::c_ulong, + __n: ::std::os::raw::c_ulong, + __stream: *mut FILE, + ) -> ::std::os::raw::c_ulong; +} +unsafe extern "C" { + pub fn fwrite( + __ptr: *const ::std::os::raw::c_void, + __size: ::std::os::raw::c_ulong, + __n: ::std::os::raw::c_ulong, + __s: *mut FILE, + ) -> ::std::os::raw::c_ulong; +} +unsafe extern "C" { + pub fn fread_unlocked( + __ptr: *mut ::std::os::raw::c_void, + __size: usize, + __n: usize, + __stream: *mut FILE, + ) -> usize; +} +unsafe extern "C" { + pub fn fwrite_unlocked( + __ptr: *const ::std::os::raw::c_void, + __size: usize, + __n: usize, + __stream: *mut FILE, + ) -> usize; +} +unsafe extern "C" { + pub fn fseek( + __stream: *mut FILE, + __off: ::std::os::raw::c_long, + __whence: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ftell(__stream: *mut FILE) -> ::std::os::raw::c_long; +} +unsafe extern "C" { + pub fn rewind(__stream: *mut FILE); +} +unsafe extern "C" { + pub fn fseeko( + __stream: *mut FILE, + __off: __off_t, + __whence: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ftello(__stream: *mut FILE) -> __off_t; +} +unsafe extern "C" { + pub fn fgetpos(__stream: *mut FILE, __pos: *mut fpos_t) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fsetpos(__stream: *mut FILE, __pos: *const fpos_t) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn clearerr(__stream: *mut FILE); +} +unsafe extern "C" { + pub fn feof(__stream: *mut 
FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ferror(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn clearerr_unlocked(__stream: *mut FILE); +} +unsafe extern "C" { + pub fn feof_unlocked(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ferror_unlocked(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn perror(__s: *const ::std::os::raw::c_char); +} +unsafe extern "C" { + pub fn fileno(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fileno_unlocked(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn pclose(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn popen( + __command: *const ::std::os::raw::c_char, + __modes: *const ::std::os::raw::c_char, + ) -> *mut FILE; +} +unsafe extern "C" { + pub fn ctermid(__s: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn flockfile(__stream: *mut FILE); +} +unsafe extern "C" { + pub fn ftrylockfile(__stream: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn funlockfile(__stream: *mut FILE); +} +unsafe extern "C" { + pub fn __uflow(arg1: *mut FILE) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn __overflow(arg1: *mut FILE, arg2: ::std::os::raw::c_int) -> ::std::os::raw::c_int; +} +pub type ggml_abort_callback_t = + ::std::option::Option; +unsafe extern "C" { + pub fn ggml_set_abort_callback(callback: ggml_abort_callback_t) -> ggml_abort_callback_t; +} +unsafe extern "C" { + pub fn ggml_abort( + file: *const ::std::os::raw::c_char, + line: ::std::os::raw::c_int, + fmt: *const ::std::os::raw::c_char, + ... 
+ ); +} +pub const ggml_status_GGML_STATUS_ALLOC_FAILED: ggml_status = -2; +pub const ggml_status_GGML_STATUS_FAILED: ggml_status = -1; +pub const ggml_status_GGML_STATUS_SUCCESS: ggml_status = 0; +pub const ggml_status_GGML_STATUS_ABORTED: ggml_status = 1; +pub type ggml_status = ::std::os::raw::c_int; +unsafe extern "C" { + pub fn ggml_status_to_string(status: ggml_status) -> *const ::std::os::raw::c_char; +} +pub type ggml_fp16_t = u16; +unsafe extern "C" { + pub fn ggml_fp16_to_fp32(arg1: ggml_fp16_t) -> f32; +} +unsafe extern "C" { + pub fn ggml_fp32_to_fp16(arg1: f32) -> ggml_fp16_t; +} +unsafe extern "C" { + pub fn ggml_fp16_to_fp32_row(arg1: *const ggml_fp16_t, arg2: *mut f32, arg3: i64); +} +unsafe extern "C" { + pub fn ggml_fp32_to_fp16_row(arg1: *const f32, arg2: *mut ggml_fp16_t, arg3: i64); +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_bf16_t { + pub bits: u16, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_bf16_t"][::std::mem::size_of::() - 2usize]; + ["Alignment of ggml_bf16_t"][::std::mem::align_of::() - 2usize]; + ["Offset of field: ggml_bf16_t::bits"][::std::mem::offset_of!(ggml_bf16_t, bits) - 0usize]; +}; +unsafe extern "C" { + pub fn ggml_fp32_to_bf16(arg1: f32) -> ggml_bf16_t; +} +unsafe extern "C" { + pub fn ggml_bf16_to_fp32(arg1: ggml_bf16_t) -> f32; +} +unsafe extern "C" { + pub fn ggml_bf16_to_fp32_row(arg1: *const ggml_bf16_t, arg2: *mut f32, arg3: i64); +} +unsafe extern "C" { + pub fn ggml_fp32_to_bf16_row_ref(arg1: *const f32, arg2: *mut ggml_bf16_t, arg3: i64); +} +unsafe extern "C" { + pub fn ggml_fp32_to_bf16_row(arg1: *const f32, arg2: *mut ggml_bf16_t, arg3: i64); +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_object { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_context { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_cgraph { + _unused: [u8; 0], +} +pub const 
ggml_type_GGML_TYPE_F32: ggml_type = 0; +pub const ggml_type_GGML_TYPE_F16: ggml_type = 1; +pub const ggml_type_GGML_TYPE_Q4_0: ggml_type = 2; +pub const ggml_type_GGML_TYPE_Q4_1: ggml_type = 3; +pub const ggml_type_GGML_TYPE_Q5_0: ggml_type = 6; +pub const ggml_type_GGML_TYPE_Q5_1: ggml_type = 7; +pub const ggml_type_GGML_TYPE_Q8_0: ggml_type = 8; +pub const ggml_type_GGML_TYPE_Q8_1: ggml_type = 9; +pub const ggml_type_GGML_TYPE_Q2_K: ggml_type = 10; +pub const ggml_type_GGML_TYPE_Q3_K: ggml_type = 11; +pub const ggml_type_GGML_TYPE_Q4_K: ggml_type = 12; +pub const ggml_type_GGML_TYPE_Q5_K: ggml_type = 13; +pub const ggml_type_GGML_TYPE_Q6_K: ggml_type = 14; +pub const ggml_type_GGML_TYPE_Q8_K: ggml_type = 15; +pub const ggml_type_GGML_TYPE_IQ2_XXS: ggml_type = 16; +pub const ggml_type_GGML_TYPE_IQ2_XS: ggml_type = 17; +pub const ggml_type_GGML_TYPE_IQ3_XXS: ggml_type = 18; +pub const ggml_type_GGML_TYPE_IQ1_S: ggml_type = 19; +pub const ggml_type_GGML_TYPE_IQ4_NL: ggml_type = 20; +pub const ggml_type_GGML_TYPE_IQ3_S: ggml_type = 21; +pub const ggml_type_GGML_TYPE_IQ2_S: ggml_type = 22; +pub const ggml_type_GGML_TYPE_IQ4_XS: ggml_type = 23; +pub const ggml_type_GGML_TYPE_I8: ggml_type = 24; +pub const ggml_type_GGML_TYPE_I16: ggml_type = 25; +pub const ggml_type_GGML_TYPE_I32: ggml_type = 26; +pub const ggml_type_GGML_TYPE_I64: ggml_type = 27; +pub const ggml_type_GGML_TYPE_F64: ggml_type = 28; +pub const ggml_type_GGML_TYPE_IQ1_M: ggml_type = 29; +pub const ggml_type_GGML_TYPE_BF16: ggml_type = 30; +pub const ggml_type_GGML_TYPE_TQ1_0: ggml_type = 34; +pub const ggml_type_GGML_TYPE_TQ2_0: ggml_type = 35; +pub const ggml_type_GGML_TYPE_MXFP4: ggml_type = 39; +pub const ggml_type_GGML_TYPE_COUNT: ggml_type = 40; +pub type ggml_type = ::std::os::raw::c_uint; +pub const ggml_prec_GGML_PREC_DEFAULT: ggml_prec = 0; +pub const ggml_prec_GGML_PREC_F32: ggml_prec = 10; +pub type ggml_prec = ::std::os::raw::c_uint; +pub const ggml_ftype_GGML_FTYPE_UNKNOWN: ggml_ftype = -1; 
+pub const ggml_ftype_GGML_FTYPE_ALL_F32: ggml_ftype = 0; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_F16: ggml_ftype = 1; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q4_0: ggml_ftype = 2; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q4_1: ggml_ftype = 3; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: ggml_ftype = 4; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q8_0: ggml_ftype = 7; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q5_0: ggml_ftype = 8; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q5_1: ggml_ftype = 9; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q2_K: ggml_ftype = 10; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q3_K: ggml_ftype = 11; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q4_K: ggml_ftype = 12; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q5_K: ggml_ftype = 13; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_Q6_K: ggml_ftype = 14; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_IQ2_XXS: ggml_ftype = 15; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_IQ2_XS: ggml_ftype = 16; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_IQ3_XXS: ggml_ftype = 17; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_IQ1_S: ggml_ftype = 18; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_IQ4_NL: ggml_ftype = 19; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_IQ3_S: ggml_ftype = 20; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_IQ2_S: ggml_ftype = 21; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_IQ4_XS: ggml_ftype = 22; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_IQ1_M: ggml_ftype = 23; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_BF16: ggml_ftype = 24; +pub const ggml_ftype_GGML_FTYPE_MOSTLY_MXFP4: ggml_ftype = 25; +pub type ggml_ftype = ::std::os::raw::c_int; +pub const ggml_op_GGML_OP_NONE: ggml_op = 0; +pub const ggml_op_GGML_OP_DUP: ggml_op = 1; +pub const ggml_op_GGML_OP_ADD: ggml_op = 2; +pub const ggml_op_GGML_OP_ADD_ID: ggml_op = 3; +pub const ggml_op_GGML_OP_ADD1: ggml_op = 4; +pub const ggml_op_GGML_OP_ACC: ggml_op = 5; +pub const ggml_op_GGML_OP_SUB: ggml_op = 6; +pub const ggml_op_GGML_OP_MUL: ggml_op = 7; +pub const ggml_op_GGML_OP_DIV: ggml_op = 8; +pub const 
ggml_op_GGML_OP_SQR: ggml_op = 9; +pub const ggml_op_GGML_OP_SQRT: ggml_op = 10; +pub const ggml_op_GGML_OP_LOG: ggml_op = 11; +pub const ggml_op_GGML_OP_SIN: ggml_op = 12; +pub const ggml_op_GGML_OP_COS: ggml_op = 13; +pub const ggml_op_GGML_OP_SUM: ggml_op = 14; +pub const ggml_op_GGML_OP_SUM_ROWS: ggml_op = 15; +pub const ggml_op_GGML_OP_MEAN: ggml_op = 16; +pub const ggml_op_GGML_OP_ARGMAX: ggml_op = 17; +pub const ggml_op_GGML_OP_COUNT_EQUAL: ggml_op = 18; +pub const ggml_op_GGML_OP_REPEAT: ggml_op = 19; +pub const ggml_op_GGML_OP_REPEAT_BACK: ggml_op = 20; +pub const ggml_op_GGML_OP_CONCAT: ggml_op = 21; +pub const ggml_op_GGML_OP_SILU_BACK: ggml_op = 22; +pub const ggml_op_GGML_OP_NORM: ggml_op = 23; +pub const ggml_op_GGML_OP_RMS_NORM: ggml_op = 24; +pub const ggml_op_GGML_OP_RMS_NORM_BACK: ggml_op = 25; +pub const ggml_op_GGML_OP_GROUP_NORM: ggml_op = 26; +pub const ggml_op_GGML_OP_L2_NORM: ggml_op = 27; +pub const ggml_op_GGML_OP_MUL_MAT: ggml_op = 28; +pub const ggml_op_GGML_OP_MUL_MAT_ID: ggml_op = 29; +pub const ggml_op_GGML_OP_OUT_PROD: ggml_op = 30; +pub const ggml_op_GGML_OP_SCALE: ggml_op = 31; +pub const ggml_op_GGML_OP_SET: ggml_op = 32; +pub const ggml_op_GGML_OP_CPY: ggml_op = 33; +pub const ggml_op_GGML_OP_CONT: ggml_op = 34; +pub const ggml_op_GGML_OP_RESHAPE: ggml_op = 35; +pub const ggml_op_GGML_OP_VIEW: ggml_op = 36; +pub const ggml_op_GGML_OP_PERMUTE: ggml_op = 37; +pub const ggml_op_GGML_OP_TRANSPOSE: ggml_op = 38; +pub const ggml_op_GGML_OP_GET_ROWS: ggml_op = 39; +pub const ggml_op_GGML_OP_GET_ROWS_BACK: ggml_op = 40; +pub const ggml_op_GGML_OP_SET_ROWS: ggml_op = 41; +pub const ggml_op_GGML_OP_DIAG: ggml_op = 42; +pub const ggml_op_GGML_OP_DIAG_MASK_INF: ggml_op = 43; +pub const ggml_op_GGML_OP_DIAG_MASK_ZERO: ggml_op = 44; +pub const ggml_op_GGML_OP_SOFT_MAX: ggml_op = 45; +pub const ggml_op_GGML_OP_SOFT_MAX_BACK: ggml_op = 46; +pub const ggml_op_GGML_OP_ROPE: ggml_op = 47; +pub const ggml_op_GGML_OP_ROPE_BACK: ggml_op = 48; +pub 
const ggml_op_GGML_OP_CLAMP: ggml_op = 49; +pub const ggml_op_GGML_OP_CONV_TRANSPOSE_1D: ggml_op = 50; +pub const ggml_op_GGML_OP_IM2COL: ggml_op = 51; +pub const ggml_op_GGML_OP_IM2COL_BACK: ggml_op = 52; +pub const ggml_op_GGML_OP_CONV_2D: ggml_op = 53; +pub const ggml_op_GGML_OP_CONV_2D_DW: ggml_op = 54; +pub const ggml_op_GGML_OP_CONV_TRANSPOSE_2D: ggml_op = 55; +pub const ggml_op_GGML_OP_POOL_1D: ggml_op = 56; +pub const ggml_op_GGML_OP_POOL_2D: ggml_op = 57; +pub const ggml_op_GGML_OP_POOL_2D_BACK: ggml_op = 58; +pub const ggml_op_GGML_OP_UPSCALE: ggml_op = 59; +pub const ggml_op_GGML_OP_PAD: ggml_op = 60; +pub const ggml_op_GGML_OP_PAD_REFLECT_1D: ggml_op = 61; +pub const ggml_op_GGML_OP_ROLL: ggml_op = 62; +pub const ggml_op_GGML_OP_ARANGE: ggml_op = 63; +pub const ggml_op_GGML_OP_TIMESTEP_EMBEDDING: ggml_op = 64; +pub const ggml_op_GGML_OP_ARGSORT: ggml_op = 65; +pub const ggml_op_GGML_OP_LEAKY_RELU: ggml_op = 66; +pub const ggml_op_GGML_OP_FLASH_ATTN_EXT: ggml_op = 67; +pub const ggml_op_GGML_OP_FLASH_ATTN_BACK: ggml_op = 68; +pub const ggml_op_GGML_OP_SSM_CONV: ggml_op = 69; +pub const ggml_op_GGML_OP_SSM_SCAN: ggml_op = 70; +pub const ggml_op_GGML_OP_WIN_PART: ggml_op = 71; +pub const ggml_op_GGML_OP_WIN_UNPART: ggml_op = 72; +pub const ggml_op_GGML_OP_GET_REL_POS: ggml_op = 73; +pub const ggml_op_GGML_OP_ADD_REL_POS: ggml_op = 74; +pub const ggml_op_GGML_OP_RWKV_WKV6: ggml_op = 75; +pub const ggml_op_GGML_OP_GATED_LINEAR_ATTN: ggml_op = 76; +pub const ggml_op_GGML_OP_RWKV_WKV7: ggml_op = 77; +pub const ggml_op_GGML_OP_UNARY: ggml_op = 78; +pub const ggml_op_GGML_OP_MAP_CUSTOM1: ggml_op = 79; +pub const ggml_op_GGML_OP_MAP_CUSTOM2: ggml_op = 80; +pub const ggml_op_GGML_OP_MAP_CUSTOM3: ggml_op = 81; +pub const ggml_op_GGML_OP_CUSTOM: ggml_op = 82; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS: ggml_op = 83; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_op = 84; +pub const ggml_op_GGML_OP_OPT_STEP_ADAMW: ggml_op = 85; +pub const 
ggml_op_GGML_OP_OPT_STEP_SGD: ggml_op = 86; +pub const ggml_op_GGML_OP_GLU: ggml_op = 87; +pub const ggml_op_GGML_OP_COUNT: ggml_op = 88; +pub type ggml_op = ::std::os::raw::c_uint; +pub const ggml_unary_op_GGML_UNARY_OP_ABS: ggml_unary_op = 0; +pub const ggml_unary_op_GGML_UNARY_OP_SGN: ggml_unary_op = 1; +pub const ggml_unary_op_GGML_UNARY_OP_NEG: ggml_unary_op = 2; +pub const ggml_unary_op_GGML_UNARY_OP_STEP: ggml_unary_op = 3; +pub const ggml_unary_op_GGML_UNARY_OP_TANH: ggml_unary_op = 4; +pub const ggml_unary_op_GGML_UNARY_OP_ELU: ggml_unary_op = 5; +pub const ggml_unary_op_GGML_UNARY_OP_RELU: ggml_unary_op = 6; +pub const ggml_unary_op_GGML_UNARY_OP_SIGMOID: ggml_unary_op = 7; +pub const ggml_unary_op_GGML_UNARY_OP_GELU: ggml_unary_op = 8; +pub const ggml_unary_op_GGML_UNARY_OP_GELU_QUICK: ggml_unary_op = 9; +pub const ggml_unary_op_GGML_UNARY_OP_SILU: ggml_unary_op = 10; +pub const ggml_unary_op_GGML_UNARY_OP_HARDSWISH: ggml_unary_op = 11; +pub const ggml_unary_op_GGML_UNARY_OP_HARDSIGMOID: ggml_unary_op = 12; +pub const ggml_unary_op_GGML_UNARY_OP_EXP: ggml_unary_op = 13; +pub const ggml_unary_op_GGML_UNARY_OP_GELU_ERF: ggml_unary_op = 14; +pub const ggml_unary_op_GGML_UNARY_OP_COUNT: ggml_unary_op = 15; +pub type ggml_unary_op = ::std::os::raw::c_uint; +pub const ggml_glu_op_GGML_GLU_OP_REGLU: ggml_glu_op = 0; +pub const ggml_glu_op_GGML_GLU_OP_GEGLU: ggml_glu_op = 1; +pub const ggml_glu_op_GGML_GLU_OP_SWIGLU: ggml_glu_op = 2; +pub const ggml_glu_op_GGML_GLU_OP_SWIGLU_OAI: ggml_glu_op = 3; +pub const ggml_glu_op_GGML_GLU_OP_GEGLU_ERF: ggml_glu_op = 4; +pub const ggml_glu_op_GGML_GLU_OP_GEGLU_QUICK: ggml_glu_op = 5; +pub const ggml_glu_op_GGML_GLU_OP_COUNT: ggml_glu_op = 6; +pub type ggml_glu_op = ::std::os::raw::c_uint; +pub const ggml_object_type_GGML_OBJECT_TYPE_TENSOR: ggml_object_type = 0; +pub const ggml_object_type_GGML_OBJECT_TYPE_GRAPH: ggml_object_type = 1; +pub const ggml_object_type_GGML_OBJECT_TYPE_WORK_BUFFER: ggml_object_type = 2; +pub type 
ggml_object_type = ::std::os::raw::c_uint; +pub const ggml_log_level_GGML_LOG_LEVEL_NONE: ggml_log_level = 0; +pub const ggml_log_level_GGML_LOG_LEVEL_DEBUG: ggml_log_level = 1; +pub const ggml_log_level_GGML_LOG_LEVEL_INFO: ggml_log_level = 2; +pub const ggml_log_level_GGML_LOG_LEVEL_WARN: ggml_log_level = 3; +pub const ggml_log_level_GGML_LOG_LEVEL_ERROR: ggml_log_level = 4; +pub const ggml_log_level_GGML_LOG_LEVEL_CONT: ggml_log_level = 5; +pub type ggml_log_level = ::std::os::raw::c_uint; +pub const ggml_tensor_flag_GGML_TENSOR_FLAG_INPUT: ggml_tensor_flag = 1; +pub const ggml_tensor_flag_GGML_TENSOR_FLAG_OUTPUT: ggml_tensor_flag = 2; +pub const ggml_tensor_flag_GGML_TENSOR_FLAG_PARAM: ggml_tensor_flag = 4; +pub const ggml_tensor_flag_GGML_TENSOR_FLAG_LOSS: ggml_tensor_flag = 8; +pub type ggml_tensor_flag = ::std::os::raw::c_uint; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_init_params { + pub mem_size: usize, + pub mem_buffer: *mut ::std::os::raw::c_void, + pub no_alloc: bool, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_init_params"][::std::mem::size_of::() - 24usize]; + ["Alignment of ggml_init_params"][::std::mem::align_of::() - 8usize]; + ["Offset of field: ggml_init_params::mem_size"] + [::std::mem::offset_of!(ggml_init_params, mem_size) - 0usize]; + ["Offset of field: ggml_init_params::mem_buffer"] + [::std::mem::offset_of!(ggml_init_params, mem_buffer) - 8usize]; + ["Offset of field: ggml_init_params::no_alloc"] + [::std::mem::offset_of!(ggml_init_params, no_alloc) - 16usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_tensor { + pub type_: ggml_type, + pub buffer: *mut ggml_backend_buffer, + pub ne: [i64; 4usize], + pub nb: [usize; 4usize], + pub op: ggml_op, + pub op_params: [i32; 16usize], + pub flags: i32, + pub src: [*mut ggml_tensor; 10usize], + pub view_src: *mut ggml_tensor, + pub view_offs: usize, + pub data: *mut ::std::os::raw::c_void, + pub name: 
[::std::os::raw::c_char; 64usize], + pub extra: *mut ::std::os::raw::c_void, + pub padding: [::std::os::raw::c_char; 8usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_tensor"][::std::mem::size_of::() - 336usize]; + ["Alignment of ggml_tensor"][::std::mem::align_of::() - 8usize]; + ["Offset of field: ggml_tensor::type_"][::std::mem::offset_of!(ggml_tensor, type_) - 0usize]; + ["Offset of field: ggml_tensor::buffer"][::std::mem::offset_of!(ggml_tensor, buffer) - 8usize]; + ["Offset of field: ggml_tensor::ne"][::std::mem::offset_of!(ggml_tensor, ne) - 16usize]; + ["Offset of field: ggml_tensor::nb"][::std::mem::offset_of!(ggml_tensor, nb) - 48usize]; + ["Offset of field: ggml_tensor::op"][::std::mem::offset_of!(ggml_tensor, op) - 80usize]; + ["Offset of field: ggml_tensor::op_params"] + [::std::mem::offset_of!(ggml_tensor, op_params) - 84usize]; + ["Offset of field: ggml_tensor::flags"][::std::mem::offset_of!(ggml_tensor, flags) - 148usize]; + ["Offset of field: ggml_tensor::src"][::std::mem::offset_of!(ggml_tensor, src) - 152usize]; + ["Offset of field: ggml_tensor::view_src"] + [::std::mem::offset_of!(ggml_tensor, view_src) - 232usize]; + ["Offset of field: ggml_tensor::view_offs"] + [::std::mem::offset_of!(ggml_tensor, view_offs) - 240usize]; + ["Offset of field: ggml_tensor::data"][::std::mem::offset_of!(ggml_tensor, data) - 248usize]; + ["Offset of field: ggml_tensor::name"][::std::mem::offset_of!(ggml_tensor, name) - 256usize]; + ["Offset of field: ggml_tensor::extra"][::std::mem::offset_of!(ggml_tensor, extra) - 320usize]; + ["Offset of field: ggml_tensor::padding"] + [::std::mem::offset_of!(ggml_tensor, padding) - 328usize]; +}; +pub const GGML_TENSOR_SIZE: usize = 336; +pub type ggml_abort_callback = + ::std::option::Option bool>; +pub type ggml_guid = [u8; 16usize]; +pub type ggml_guid_t = *mut ggml_guid; +unsafe extern "C" { + pub fn ggml_guid_matches(guid_a: ggml_guid_t, guid_b: ggml_guid_t) -> 
bool; +} +unsafe extern "C" { + pub fn ggml_version() -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_commit() -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_time_init(); +} +unsafe extern "C" { + pub fn ggml_time_ms() -> i64; +} +unsafe extern "C" { + pub fn ggml_time_us() -> i64; +} +unsafe extern "C" { + pub fn ggml_cycles() -> i64; +} +unsafe extern "C" { + pub fn ggml_cycles_per_ms() -> i64; +} +unsafe extern "C" { + pub fn ggml_fopen( + fname: *const ::std::os::raw::c_char, + mode: *const ::std::os::raw::c_char, + ) -> *mut FILE; +} +unsafe extern "C" { + pub fn ggml_print_object(obj: *const ggml_object); +} +unsafe extern "C" { + pub fn ggml_print_objects(ctx: *const ggml_context); +} +unsafe extern "C" { + pub fn ggml_nelements(tensor: *const ggml_tensor) -> i64; +} +unsafe extern "C" { + pub fn ggml_nrows(tensor: *const ggml_tensor) -> i64; +} +unsafe extern "C" { + pub fn ggml_nbytes(tensor: *const ggml_tensor) -> usize; +} +unsafe extern "C" { + pub fn ggml_nbytes_pad(tensor: *const ggml_tensor) -> usize; +} +unsafe extern "C" { + pub fn ggml_blck_size(type_: ggml_type) -> i64; +} +unsafe extern "C" { + pub fn ggml_type_size(type_: ggml_type) -> usize; +} +unsafe extern "C" { + pub fn ggml_row_size(type_: ggml_type, ne: i64) -> usize; +} +unsafe extern "C" { + pub fn ggml_type_sizef(type_: ggml_type) -> f64; +} +unsafe extern "C" { + pub fn ggml_type_name(type_: ggml_type) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_op_name(op: ggml_op) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_op_symbol(op: ggml_op) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_unary_op_name(op: ggml_unary_op) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_glu_op_name(op: ggml_glu_op) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_op_desc(t: *const ggml_tensor) -> *const ::std::os::raw::c_char; +} +unsafe 
extern "C" { + pub fn ggml_element_size(tensor: *const ggml_tensor) -> usize; +} +unsafe extern "C" { + pub fn ggml_is_quantized(type_: ggml_type) -> bool; +} +unsafe extern "C" { + pub fn ggml_ftype_to_ggml_type(ftype: ggml_ftype) -> ggml_type; +} +unsafe extern "C" { + pub fn ggml_is_transposed(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_permuted(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_empty(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_scalar(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_vector(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_matrix(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_3d(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_n_dims(tensor: *const ggml_tensor) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_is_contiguous(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_contiguous_0(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_contiguous_1(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_contiguous_2(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_contiguously_allocated(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_contiguous_channels(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_is_contiguous_rows(tensor: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_are_same_shape(t0: *const ggml_tensor, t1: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_are_same_stride(t0: *const ggml_tensor, t1: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_can_repeat(t0: *const ggml_tensor, t1: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_tensor_overhead() -> usize; +} +unsafe 
extern "C" { + pub fn ggml_validate_row_data( + type_: ggml_type, + data: *const ::std::os::raw::c_void, + nbytes: usize, + ) -> bool; +} +unsafe extern "C" { + pub fn ggml_init(params: ggml_init_params) -> *mut ggml_context; +} +unsafe extern "C" { + pub fn ggml_reset(ctx: *mut ggml_context); +} +unsafe extern "C" { + pub fn ggml_free(ctx: *mut ggml_context); +} +unsafe extern "C" { + pub fn ggml_used_mem(ctx: *const ggml_context) -> usize; +} +unsafe extern "C" { + pub fn ggml_get_no_alloc(ctx: *mut ggml_context) -> bool; +} +unsafe extern "C" { + pub fn ggml_set_no_alloc(ctx: *mut ggml_context, no_alloc: bool); +} +unsafe extern "C" { + pub fn ggml_get_mem_buffer(ctx: *const ggml_context) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn ggml_get_mem_size(ctx: *const ggml_context) -> usize; +} +unsafe extern "C" { + pub fn ggml_get_max_tensor_size(ctx: *const ggml_context) -> usize; +} +unsafe extern "C" { + pub fn ggml_new_tensor( + ctx: *mut ggml_context, + type_: ggml_type, + n_dims: ::std::os::raw::c_int, + ne: *const i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_new_tensor_1d( + ctx: *mut ggml_context, + type_: ggml_type, + ne0: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_new_tensor_2d( + ctx: *mut ggml_context, + type_: ggml_type, + ne0: i64, + ne1: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_new_tensor_3d( + ctx: *mut ggml_context, + type_: ggml_type, + ne0: i64, + ne1: i64, + ne2: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_new_tensor_4d( + ctx: *mut ggml_context, + type_: ggml_type, + ne0: i64, + ne1: i64, + ne2: i64, + ne3: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_new_buffer(ctx: *mut ggml_context, nbytes: usize) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn ggml_dup_tensor(ctx: *mut ggml_context, src: *const ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_view_tensor(ctx: 
*mut ggml_context, src: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_get_first_tensor(ctx: *const ggml_context) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_get_next_tensor( + ctx: *const ggml_context, + tensor: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_get_tensor( + ctx: *mut ggml_context, + name: *const ::std::os::raw::c_char, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_unravel_index( + tensor: *const ggml_tensor, + i: i64, + i0: *mut i64, + i1: *mut i64, + i2: *mut i64, + i3: *mut i64, + ); +} +unsafe extern "C" { + pub fn ggml_get_unary_op(tensor: *const ggml_tensor) -> ggml_unary_op; +} +unsafe extern "C" { + pub fn ggml_get_glu_op(tensor: *const ggml_tensor) -> ggml_glu_op; +} +unsafe extern "C" { + pub fn ggml_get_data(tensor: *const ggml_tensor) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn ggml_get_data_f32(tensor: *const ggml_tensor) -> *mut f32; +} +unsafe extern "C" { + pub fn ggml_get_name(tensor: *const ggml_tensor) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_set_name( + tensor: *mut ggml_tensor, + name: *const ::std::os::raw::c_char, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_format_name( + tensor: *mut ggml_tensor, + fmt: *const ::std::os::raw::c_char, + ... 
+ ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_set_input(tensor: *mut ggml_tensor); +} +unsafe extern "C" { + pub fn ggml_set_output(tensor: *mut ggml_tensor); +} +unsafe extern "C" { + pub fn ggml_set_param(tensor: *mut ggml_tensor); +} +unsafe extern "C" { + pub fn ggml_set_loss(tensor: *mut ggml_tensor); +} +unsafe extern "C" { + pub fn ggml_dup(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_dup_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_add( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_add_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_add_cast( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + type_: ggml_type, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_add_id( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ids: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_add1( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_add1_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_acc( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + nb1: usize, + nb2: usize, + nb3: usize, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_acc_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + nb1: usize, + nb2: usize, + nb3: usize, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sub( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut 
ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sub_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_mul( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_mul_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_div( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_div_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sqr(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sqr_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sqrt(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sqrt_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_log(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_log_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sin(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sin_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cos(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cos_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sum(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn 
ggml_sum_rows(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_mean(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_argmax(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_count_equal( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_repeat( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_repeat_4d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ne2: i64, + ne3: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_repeat_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_concat( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + dim: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_abs(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_abs_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sgn(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sgn_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_neg(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_neg_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_step(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_step_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern 
"C" { + pub fn ggml_tanh(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_tanh_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_elu(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_elu_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_relu(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_leaky_relu( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + negative_slope: f32, + inplace: bool, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_relu_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sigmoid(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_sigmoid_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_gelu(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_gelu_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_gelu_erf(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_gelu_erf_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_gelu_quick(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_gelu_quick_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) + -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_silu(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_silu_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern 
"C" { + pub fn ggml_silu_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_hardswish(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_hardsigmoid(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_exp(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_exp_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_glu( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + op: ggml_glu_op, + swapped: bool, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_reglu(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_reglu_swapped(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_geglu(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_geglu_swapped(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_swiglu(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_swiglu_swapped(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_geglu_erf(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_geglu_erf_swapped(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_geglu_quick(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_geglu_quick_swapped( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_glu_split( + ctx: *mut ggml_context, + a: *mut 
ggml_tensor, + b: *mut ggml_tensor, + op: ggml_glu_op, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_reglu_split( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_geglu_split( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_swiglu_split( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_geglu_erf_split( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_geglu_quick_split( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_swiglu_oai( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + alpha: f32, + limit: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_norm(ctx: *mut ggml_context, a: *mut ggml_tensor, eps: f32) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_norm_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + eps: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rms_norm(ctx: *mut ggml_context, a: *mut ggml_tensor, eps: f32) + -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rms_norm_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + eps: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_group_norm( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_groups: ::std::os::raw::c_int, + eps: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_group_norm_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_groups: ::std::os::raw::c_int, + eps: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_l2_norm(ctx: *mut ggml_context, a: *mut ggml_tensor, eps: f32) -> *mut 
ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_l2_norm_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + eps: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rms_norm_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + eps: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_mul_mat( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_mul_mat_set_prec(a: *mut ggml_tensor, prec: ggml_prec); +} +unsafe extern "C" { + pub fn ggml_mul_mat_id( + ctx: *mut ggml_context, + as_: *mut ggml_tensor, + b: *mut ggml_tensor, + ids: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_out_prod( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_scale(ctx: *mut ggml_context, a: *mut ggml_tensor, s: f32) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_scale_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + s: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_scale_bias( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + s: f32, + b: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_scale_bias_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + s: f32, + b: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_set( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + nb1: usize, + nb2: usize, + nb3: usize, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_set_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + nb1: usize, + nb2: usize, + nb3: usize, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_set_1d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + offset: usize, + ) -> *mut ggml_tensor; +} 
+unsafe extern "C" { + pub fn ggml_set_1d_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_set_2d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + nb1: usize, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_set_2d_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + nb1: usize, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cpy( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cast( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + type_: ggml_type, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cont(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cont_1d(ctx: *mut ggml_context, a: *mut ggml_tensor, ne0: i64) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cont_2d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cont_3d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ne2: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cont_4d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ne2: i64, + ne3: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_reshape( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_reshape_1d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_reshape_2d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_reshape_3d( + ctx: *mut 
ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ne2: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_reshape_4d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ne2: i64, + ne3: i64, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_view_1d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_view_2d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + nb1: usize, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_view_3d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ne2: i64, + nb1: usize, + nb2: usize, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_view_4d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ne2: i64, + ne3: i64, + nb1: usize, + nb2: usize, + nb3: usize, + offset: usize, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_permute( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + axis0: ::std::os::raw::c_int, + axis1: ::std::os::raw::c_int, + axis2: ::std::os::raw::c_int, + axis3: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_transpose(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_get_rows( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_get_rows_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_set_rows( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_diag(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe 
extern "C" { + pub fn ggml_diag_mask_inf( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_past: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_diag_mask_inf_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_past: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_diag_mask_zero( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_past: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_diag_mask_zero_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_past: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_soft_max(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_soft_max_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_soft_max_ext( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + mask: *mut ggml_tensor, + scale: f32, + max_bias: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_soft_max_add_sinks(a: *mut ggml_tensor, sinks: *mut ggml_tensor); +} +unsafe extern "C" { + pub fn ggml_soft_max_ext_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + scale: f32, + max_bias: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_soft_max_ext_back_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + scale: f32, + max_bias: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rope( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rope_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn 
ggml_rope_ext( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + n_ctx_orig: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + ext_factor: f32, + attn_factor: f32, + beta_fast: f32, + beta_slow: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rope_multi( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + sections: *mut ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + n_ctx_orig: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + ext_factor: f32, + attn_factor: f32, + beta_fast: f32, + beta_slow: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rope_ext_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + n_ctx_orig: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + ext_factor: f32, + attn_factor: f32, + beta_fast: f32, + beta_slow: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rope_multi_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + sections: *mut ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + n_ctx_orig: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + ext_factor: f32, + attn_factor: f32, + beta_fast: f32, + beta_slow: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rope_custom( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + n_ctx_orig: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + ext_factor: f32, + attn_factor: f32, + beta_fast: f32, + beta_slow: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn 
ggml_rope_custom_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + n_ctx_orig: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + ext_factor: f32, + attn_factor: f32, + beta_fast: f32, + beta_slow: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rope_yarn_corr_dims( + n_dims: ::std::os::raw::c_int, + n_ctx_orig: ::std::os::raw::c_int, + freq_base: f32, + beta_fast: f32, + beta_slow: f32, + dims: *mut f32, + ); +} +unsafe extern "C" { + pub fn ggml_rope_ext_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + n_ctx_orig: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + ext_factor: f32, + attn_factor: f32, + beta_fast: f32, + beta_slow: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rope_multi_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + n_dims: ::std::os::raw::c_int, + sections: *mut ::std::os::raw::c_int, + mode: ::std::os::raw::c_int, + n_ctx_orig: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + ext_factor: f32, + attn_factor: f32, + beta_fast: f32, + beta_slow: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_clamp( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + min: f32, + max: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_im2col( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s0: ::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + d1: ::std::os::raw::c_int, + is_2D: bool, + dst_type: ggml_type, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_im2col_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ne: *mut i64, + s0: 
::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + d1: ::std::os::raw::c_int, + is_2D: bool, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_1d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s0: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_1d_ph( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s: ::std::os::raw::c_int, + d: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_1d_dw( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s0: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_1d_dw_ph( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s0: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_transpose_1d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s0: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_2d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s0: ::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + d1: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_2d_sk_p0( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_2d_s1_ph( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_2d_dw( + ctx: 
*mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s0: ::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + d1: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_2d_dw_direct( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + stride0: ::std::os::raw::c_int, + stride1: ::std::os::raw::c_int, + pad0: ::std::os::raw::c_int, + pad1: ::std::os::raw::c_int, + dilation0: ::std::os::raw::c_int, + dilation1: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_transpose_2d_p0( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + stride: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_2d_direct( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s0: ::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + d1: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +pub const ggml_op_pool_GGML_OP_POOL_MAX: ggml_op_pool = 0; +pub const ggml_op_pool_GGML_OP_POOL_AVG: ggml_op_pool = 1; +pub const ggml_op_pool_GGML_OP_POOL_COUNT: ggml_op_pool = 2; +pub type ggml_op_pool = ::std::os::raw::c_uint; +unsafe extern "C" { + pub fn ggml_pool_1d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + op: ggml_op_pool, + k0: ::std::os::raw::c_int, + s0: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_pool_2d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + op: ggml_op_pool, + k0: ::std::os::raw::c_int, + k1: ::std::os::raw::c_int, + s0: ::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + p0: f32, + p1: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_pool_2d_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + af: *mut ggml_tensor, + 
op: ggml_op_pool, + k0: ::std::os::raw::c_int, + k1: ::std::os::raw::c_int, + s0: ::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + p0: f32, + p1: f32, + ) -> *mut ggml_tensor; +} +pub const ggml_scale_mode_GGML_SCALE_MODE_NEAREST: ggml_scale_mode = 0; +pub const ggml_scale_mode_GGML_SCALE_MODE_BILINEAR: ggml_scale_mode = 1; +pub const ggml_scale_mode_GGML_SCALE_MODE_COUNT: ggml_scale_mode = 2; +pub type ggml_scale_mode = ::std::os::raw::c_uint; +pub const ggml_scale_flag_GGML_SCALE_FLAG_ALIGN_CORNERS: ggml_scale_flag = 256; +pub type ggml_scale_flag = ::std::os::raw::c_uint; +unsafe extern "C" { + pub fn ggml_upscale( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + scale_factor: ::std::os::raw::c_int, + mode: ggml_scale_mode, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_upscale_ext( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: ::std::os::raw::c_int, + ne1: ::std::os::raw::c_int, + ne2: ::std::os::raw::c_int, + ne3: ::std::os::raw::c_int, + mode: ggml_scale_mode, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_interpolate( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + ne0: i64, + ne1: i64, + ne2: i64, + ne3: i64, + mode: u32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_pad( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + p2: ::std::os::raw::c_int, + p3: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_pad_reflect_1d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_roll( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + shift0: ::std::os::raw::c_int, + shift1: ::std::os::raw::c_int, + shift2: ::std::os::raw::c_int, + shift3: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_timestep_embedding( + ctx: *mut ggml_context, + timesteps: *mut 
ggml_tensor, + dim: ::std::os::raw::c_int, + max_period: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +pub const ggml_sort_order_GGML_SORT_ORDER_ASC: ggml_sort_order = 0; +pub const ggml_sort_order_GGML_SORT_ORDER_DESC: ggml_sort_order = 1; +pub type ggml_sort_order = ::std::os::raw::c_uint; +unsafe extern "C" { + pub fn ggml_argsort( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + order: ggml_sort_order, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_arange( + ctx: *mut ggml_context, + start: f32, + stop: f32, + step: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_top_k( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + k: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_flash_attn_ext( + ctx: *mut ggml_context, + q: *mut ggml_tensor, + k: *mut ggml_tensor, + v: *mut ggml_tensor, + mask: *mut ggml_tensor, + scale: f32, + max_bias: f32, + logit_softcap: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_flash_attn_ext_set_prec(a: *mut ggml_tensor, prec: ggml_prec); +} +unsafe extern "C" { + pub fn ggml_flash_attn_ext_get_prec(a: *const ggml_tensor) -> ggml_prec; +} +unsafe extern "C" { + pub fn ggml_flash_attn_ext_add_sinks(a: *mut ggml_tensor, sinks: *mut ggml_tensor); +} +unsafe extern "C" { + pub fn ggml_flash_attn_back( + ctx: *mut ggml_context, + q: *mut ggml_tensor, + k: *mut ggml_tensor, + v: *mut ggml_tensor, + d: *mut ggml_tensor, + masked: bool, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_ssm_conv( + ctx: *mut ggml_context, + sx: *mut ggml_tensor, + c: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_ssm_scan( + ctx: *mut ggml_context, + s: *mut ggml_tensor, + x: *mut ggml_tensor, + dt: *mut ggml_tensor, + A: *mut ggml_tensor, + B: *mut ggml_tensor, + C: *mut ggml_tensor, + ids: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_win_part( + ctx: *mut ggml_context, + a: *mut 
ggml_tensor, + w: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_win_unpart( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + w0: ::std::os::raw::c_int, + h0: ::std::os::raw::c_int, + w: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_unary( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + op: ggml_unary_op, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_unary_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + op: ggml_unary_op, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_get_rel_pos( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + qh: ::std::os::raw::c_int, + kh: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_add_rel_pos( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + pw: *mut ggml_tensor, + ph: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_add_rel_pos_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + pw: *mut ggml_tensor, + ph: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rwkv_wkv6( + ctx: *mut ggml_context, + k: *mut ggml_tensor, + v: *mut ggml_tensor, + r: *mut ggml_tensor, + tf: *mut ggml_tensor, + td: *mut ggml_tensor, + state: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_gated_linear_attn( + ctx: *mut ggml_context, + k: *mut ggml_tensor, + v: *mut ggml_tensor, + q: *mut ggml_tensor, + g: *mut ggml_tensor, + state: *mut ggml_tensor, + scale: f32, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_rwkv_wkv7( + ctx: *mut ggml_context, + r: *mut ggml_tensor, + w: *mut ggml_tensor, + k: *mut ggml_tensor, + v: *mut ggml_tensor, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + state: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +pub type ggml_custom1_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + a: *const ggml_tensor, + ith: 
::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +pub type ggml_custom2_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + a: *const ggml_tensor, + b: *const ggml_tensor, + ith: ::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +pub type ggml_custom3_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + a: *const ggml_tensor, + b: *const ggml_tensor, + c: *const ggml_tensor, + ith: ::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +unsafe extern "C" { + pub fn ggml_map_custom1( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + fun: ggml_custom1_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_map_custom1_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + fun: ggml_custom1_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_map_custom2( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + fun: ggml_custom2_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_map_custom2_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + fun: ggml_custom2_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_map_custom3( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + fun: ggml_custom3_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_map_custom3_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: 
*mut ggml_tensor, + c: *mut ggml_tensor, + fun: ggml_custom3_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +pub type ggml_custom_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + ith: ::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +unsafe extern "C" { + pub fn ggml_custom_4d( + ctx: *mut ggml_context, + type_: ggml_type, + ne0: i64, + ne1: i64, + ne2: i64, + ne3: i64, + args: *mut *mut ggml_tensor, + n_args: ::std::os::raw::c_int, + fun: ggml_custom_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_custom_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + args: *mut *mut ggml_tensor, + n_args: ::std::os::raw::c_int, + fun: ggml_custom_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cross_entropy_loss( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_cross_entropy_loss_back( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_opt_step_adamw( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + grad: *mut ggml_tensor, + m: *mut ggml_tensor, + v: *mut ggml_tensor, + adamw_params: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_opt_step_sgd( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + grad: *mut ggml_tensor, + sgd_params: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_build_forward_expand(cgraph: *mut ggml_cgraph, tensor: *mut ggml_tensor); +} +unsafe extern "C" { + pub fn ggml_build_backward_expand( + ctx: *mut ggml_context, + cgraph: *mut ggml_cgraph, + grad_accs: 
*mut *mut ggml_tensor, + ); +} +unsafe extern "C" { + pub fn ggml_new_graph(ctx: *mut ggml_context) -> *mut ggml_cgraph; +} +unsafe extern "C" { + pub fn ggml_new_graph_custom( + ctx: *mut ggml_context, + size: usize, + grads: bool, + ) -> *mut ggml_cgraph; +} +unsafe extern "C" { + pub fn ggml_graph_dup( + ctx: *mut ggml_context, + cgraph: *mut ggml_cgraph, + force_grads: bool, + ) -> *mut ggml_cgraph; +} +unsafe extern "C" { + pub fn ggml_graph_cpy(src: *mut ggml_cgraph, dst: *mut ggml_cgraph); +} +unsafe extern "C" { + pub fn ggml_graph_reset(cgraph: *mut ggml_cgraph); +} +unsafe extern "C" { + pub fn ggml_graph_clear(cgraph: *mut ggml_cgraph); +} +unsafe extern "C" { + pub fn ggml_graph_size(cgraph: *mut ggml_cgraph) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_graph_node(cgraph: *mut ggml_cgraph, i: ::std::os::raw::c_int) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_graph_nodes(cgraph: *mut ggml_cgraph) -> *mut *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_graph_n_nodes(cgraph: *mut ggml_cgraph) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_graph_add_node(cgraph: *mut ggml_cgraph, tensor: *mut ggml_tensor); +} +unsafe extern "C" { + pub fn ggml_graph_overhead() -> usize; +} +unsafe extern "C" { + pub fn ggml_graph_overhead_custom(size: usize, grads: bool) -> usize; +} +unsafe extern "C" { + pub fn ggml_graph_get_tensor( + cgraph: *const ggml_cgraph, + name: *const ::std::os::raw::c_char, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_graph_get_grad( + cgraph: *const ggml_cgraph, + node: *const ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_graph_get_grad_acc( + cgraph: *const ggml_cgraph, + node: *const ggml_tensor, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_graph_print(cgraph: *const ggml_cgraph); +} +unsafe extern "C" { + pub fn ggml_graph_dump_dot( + gb: *const ggml_cgraph, + gf: *const ggml_cgraph, + filename: *const 
::std::os::raw::c_char, + ); +} +pub type ggml_log_callback = ::std::option::Option< + unsafe extern "C" fn( + level: ggml_log_level, + text: *const ::std::os::raw::c_char, + user_data: *mut ::std::os::raw::c_void, + ), +>; +unsafe extern "C" { + pub fn ggml_log_set(log_callback: ggml_log_callback, user_data: *mut ::std::os::raw::c_void); +} +unsafe extern "C" { + pub fn ggml_set_zero(tensor: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_quantize_init(type_: ggml_type); +} +unsafe extern "C" { + pub fn ggml_quantize_free(); +} +unsafe extern "C" { + pub fn ggml_quantize_requires_imatrix(type_: ggml_type) -> bool; +} +unsafe extern "C" { + pub fn ggml_quantize_chunk( + type_: ggml_type, + src: *const f32, + dst: *mut ::std::os::raw::c_void, + start: i64, + nrows: i64, + n_per_row: i64, + imatrix: *const f32, + ) -> usize; +} +pub type ggml_to_float_t = ::std::option::Option< + unsafe extern "C" fn(x: *const ::std::os::raw::c_void, y: *mut f32, k: i64), +>; +pub type ggml_from_float_t = ::std::option::Option< + unsafe extern "C" fn(x: *const f32, y: *mut ::std::os::raw::c_void, k: i64), +>; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_type_traits { + pub type_name: *const ::std::os::raw::c_char, + pub blck_size: i64, + pub blck_size_interleave: i64, + pub type_size: usize, + pub is_quantized: bool, + pub to_float: ggml_to_float_t, + pub from_float_ref: ggml_from_float_t, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_type_traits"][::std::mem::size_of::() - 56usize]; + ["Alignment of ggml_type_traits"][::std::mem::align_of::() - 8usize]; + ["Offset of field: ggml_type_traits::type_name"] + [::std::mem::offset_of!(ggml_type_traits, type_name) - 0usize]; + ["Offset of field: ggml_type_traits::blck_size"] + [::std::mem::offset_of!(ggml_type_traits, blck_size) - 8usize]; + ["Offset of field: ggml_type_traits::blck_size_interleave"] + [::std::mem::offset_of!(ggml_type_traits, 
blck_size_interleave) - 16usize]; + ["Offset of field: ggml_type_traits::type_size"] + [::std::mem::offset_of!(ggml_type_traits, type_size) - 24usize]; + ["Offset of field: ggml_type_traits::is_quantized"] + [::std::mem::offset_of!(ggml_type_traits, is_quantized) - 32usize]; + ["Offset of field: ggml_type_traits::to_float"] + [::std::mem::offset_of!(ggml_type_traits, to_float) - 40usize]; + ["Offset of field: ggml_type_traits::from_float_ref"] + [::std::mem::offset_of!(ggml_type_traits, from_float_ref) - 48usize]; +}; +unsafe extern "C" { + pub fn ggml_get_type_traits(type_: ggml_type) -> *const ggml_type_traits; +} +pub const ggml_sched_priority_GGML_SCHED_PRIO_LOW: ggml_sched_priority = -1; +pub const ggml_sched_priority_GGML_SCHED_PRIO_NORMAL: ggml_sched_priority = 0; +pub const ggml_sched_priority_GGML_SCHED_PRIO_MEDIUM: ggml_sched_priority = 1; +pub const ggml_sched_priority_GGML_SCHED_PRIO_HIGH: ggml_sched_priority = 2; +pub const ggml_sched_priority_GGML_SCHED_PRIO_REALTIME: ggml_sched_priority = 3; +pub type ggml_sched_priority = ::std::os::raw::c_int; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_threadpool_params { + pub cpumask: [bool; 512usize], + pub n_threads: ::std::os::raw::c_int, + pub prio: ggml_sched_priority, + pub poll: u32, + pub strict_cpu: bool, + pub paused: bool, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_threadpool_params"][::std::mem::size_of::() - 528usize]; + ["Alignment of ggml_threadpool_params"] + [::std::mem::align_of::() - 4usize]; + ["Offset of field: ggml_threadpool_params::cpumask"] + [::std::mem::offset_of!(ggml_threadpool_params, cpumask) - 0usize]; + ["Offset of field: ggml_threadpool_params::n_threads"] + [::std::mem::offset_of!(ggml_threadpool_params, n_threads) - 512usize]; + ["Offset of field: ggml_threadpool_params::prio"] + [::std::mem::offset_of!(ggml_threadpool_params, prio) - 516usize]; + ["Offset of field: ggml_threadpool_params::poll"] + 
[::std::mem::offset_of!(ggml_threadpool_params, poll) - 520usize]; + ["Offset of field: ggml_threadpool_params::strict_cpu"] + [::std::mem::offset_of!(ggml_threadpool_params, strict_cpu) - 524usize]; + ["Offset of field: ggml_threadpool_params::paused"] + [::std::mem::offset_of!(ggml_threadpool_params, paused) - 525usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_threadpool { + _unused: [u8; 0], +} +pub type ggml_threadpool_t = *mut ggml_threadpool; +unsafe extern "C" { + pub fn ggml_threadpool_params_default( + n_threads: ::std::os::raw::c_int, + ) -> ggml_threadpool_params; +} +unsafe extern "C" { + pub fn ggml_threadpool_params_init( + p: *mut ggml_threadpool_params, + n_threads: ::std::os::raw::c_int, + ); +} +unsafe extern "C" { + pub fn ggml_threadpool_params_match( + p0: *const ggml_threadpool_params, + p1: *const ggml_threadpool_params, + ) -> bool; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_buffer_type { + _unused: [u8; 0], +} +pub type ggml_backend_buffer_type_t = *mut ggml_backend_buffer_type; +pub type ggml_backend_buffer_t = *mut ggml_backend_buffer; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend { + _unused: [u8; 0], +} +pub type ggml_backend_t = *mut ggml_backend; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_tallocr { + pub buffer: ggml_backend_buffer_t, + pub base: *mut ::std::os::raw::c_void, + pub alignment: usize, + pub offset: usize, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_tallocr"][::std::mem::size_of::() - 32usize]; + ["Alignment of ggml_tallocr"][::std::mem::align_of::() - 8usize]; + ["Offset of field: ggml_tallocr::buffer"] + [::std::mem::offset_of!(ggml_tallocr, buffer) - 0usize]; + ["Offset of field: ggml_tallocr::base"][::std::mem::offset_of!(ggml_tallocr, base) - 8usize]; + ["Offset of field: ggml_tallocr::alignment"] + [::std::mem::offset_of!(ggml_tallocr, alignment) - 16usize]; + ["Offset of 
field: ggml_tallocr::offset"] + [::std::mem::offset_of!(ggml_tallocr, offset) - 24usize]; +}; +unsafe extern "C" { + pub fn ggml_tallocr_new(buffer: ggml_backend_buffer_t) -> ggml_tallocr; +} +unsafe extern "C" { + pub fn ggml_tallocr_alloc(talloc: *mut ggml_tallocr, tensor: *mut ggml_tensor) -> ggml_status; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_gallocr { + _unused: [u8; 0], +} +pub type ggml_gallocr_t = *mut ggml_gallocr; +unsafe extern "C" { + pub fn ggml_gallocr_new(buft: ggml_backend_buffer_type_t) -> ggml_gallocr_t; +} +unsafe extern "C" { + pub fn ggml_gallocr_new_n( + bufts: *mut ggml_backend_buffer_type_t, + n_bufs: ::std::os::raw::c_int, + ) -> ggml_gallocr_t; +} +unsafe extern "C" { + pub fn ggml_gallocr_free(galloc: ggml_gallocr_t); +} +unsafe extern "C" { + pub fn ggml_gallocr_reserve(galloc: ggml_gallocr_t, graph: *mut ggml_cgraph) -> bool; +} +unsafe extern "C" { + pub fn ggml_gallocr_reserve_n( + galloc: ggml_gallocr_t, + graph: *mut ggml_cgraph, + node_buffer_ids: *const ::std::os::raw::c_int, + leaf_buffer_ids: *const ::std::os::raw::c_int, + ) -> bool; +} +unsafe extern "C" { + pub fn ggml_gallocr_alloc_graph(galloc: ggml_gallocr_t, graph: *mut ggml_cgraph) -> bool; +} +unsafe extern "C" { + pub fn ggml_gallocr_get_buffer_size( + galloc: ggml_gallocr_t, + buffer_id: ::std::os::raw::c_int, + ) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_alloc_ctx_tensors_from_buft( + ctx: *mut ggml_context, + buft: ggml_backend_buffer_type_t, + ) -> *mut ggml_backend_buffer; +} +unsafe extern "C" { + pub fn ggml_backend_alloc_ctx_tensors( + ctx: *mut ggml_context, + backend: ggml_backend_t, + ) -> *mut ggml_backend_buffer; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_event { + _unused: [u8; 0], +} +pub type ggml_backend_event_t = *mut ggml_backend_event; +pub type ggml_backend_graph_plan_t = *mut ::std::os::raw::c_void; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_reg { + _unused: 
[u8; 0], +} +pub type ggml_backend_reg_t = *mut ggml_backend_reg; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_device { + _unused: [u8; 0], +} +pub type ggml_backend_dev_t = *mut ggml_backend_device; +unsafe extern "C" { + pub fn ggml_backend_buft_name( + buft: ggml_backend_buffer_type_t, + ) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_backend_buft_alloc_buffer( + buft: ggml_backend_buffer_type_t, + size: usize, + ) -> ggml_backend_buffer_t; +} +unsafe extern "C" { + pub fn ggml_backend_buft_get_alignment(buft: ggml_backend_buffer_type_t) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_buft_get_max_size(buft: ggml_backend_buffer_type_t) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_buft_get_alloc_size( + buft: ggml_backend_buffer_type_t, + tensor: *const ggml_tensor, + ) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_buft_is_host(buft: ggml_backend_buffer_type_t) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_buft_get_device(buft: ggml_backend_buffer_type_t) -> ggml_backend_dev_t; +} +pub const ggml_backend_buffer_usage_GGML_BACKEND_BUFFER_USAGE_ANY: ggml_backend_buffer_usage = 0; +pub const ggml_backend_buffer_usage_GGML_BACKEND_BUFFER_USAGE_WEIGHTS: ggml_backend_buffer_usage = + 1; +pub const ggml_backend_buffer_usage_GGML_BACKEND_BUFFER_USAGE_COMPUTE: ggml_backend_buffer_usage = + 2; +pub type ggml_backend_buffer_usage = ::std::os::raw::c_uint; +unsafe extern "C" { + pub fn ggml_backend_buffer_name(buffer: ggml_backend_buffer_t) + -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_free(buffer: ggml_backend_buffer_t); +} +unsafe extern "C" { + pub fn ggml_backend_buffer_get_base( + buffer: ggml_backend_buffer_t, + ) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_get_size(buffer: ggml_backend_buffer_t) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_init_tensor( + buffer: 
ggml_backend_buffer_t, + tensor: *mut ggml_tensor, + ) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_get_alignment(buffer: ggml_backend_buffer_t) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_get_max_size(buffer: ggml_backend_buffer_t) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_get_alloc_size( + buffer: ggml_backend_buffer_t, + tensor: *const ggml_tensor, + ) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_clear(buffer: ggml_backend_buffer_t, value: u8); +} +unsafe extern "C" { + pub fn ggml_backend_buffer_is_host(buffer: ggml_backend_buffer_t) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_set_usage( + buffer: ggml_backend_buffer_t, + usage: ggml_backend_buffer_usage, + ); +} +unsafe extern "C" { + pub fn ggml_backend_buffer_get_usage( + buffer: ggml_backend_buffer_t, + ) -> ggml_backend_buffer_usage; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_get_type( + buffer: ggml_backend_buffer_t, + ) -> ggml_backend_buffer_type_t; +} +unsafe extern "C" { + pub fn ggml_backend_buffer_reset(buffer: ggml_backend_buffer_t); +} +unsafe extern "C" { + pub fn ggml_backend_tensor_copy(src: *mut ggml_tensor, dst: *mut ggml_tensor); +} +unsafe extern "C" { + pub fn ggml_backend_guid(backend: ggml_backend_t) -> ggml_guid_t; +} +unsafe extern "C" { + pub fn ggml_backend_name(backend: ggml_backend_t) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_backend_free(backend: ggml_backend_t); +} +unsafe extern "C" { + pub fn ggml_backend_get_default_buffer_type( + backend: ggml_backend_t, + ) -> ggml_backend_buffer_type_t; +} +unsafe extern "C" { + pub fn ggml_backend_alloc_buffer(backend: ggml_backend_t, size: usize) + -> ggml_backend_buffer_t; +} +unsafe extern "C" { + pub fn ggml_backend_get_alignment(backend: ggml_backend_t) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_get_max_size(backend: ggml_backend_t) -> usize; +} +unsafe extern "C" { + pub fn 
ggml_backend_tensor_set_async( + backend: ggml_backend_t, + tensor: *mut ggml_tensor, + data: *const ::std::os::raw::c_void, + offset: usize, + size: usize, + ); +} +unsafe extern "C" { + pub fn ggml_backend_tensor_get_async( + backend: ggml_backend_t, + tensor: *const ggml_tensor, + data: *mut ::std::os::raw::c_void, + offset: usize, + size: usize, + ); +} +unsafe extern "C" { + pub fn ggml_backend_tensor_set( + tensor: *mut ggml_tensor, + data: *const ::std::os::raw::c_void, + offset: usize, + size: usize, + ); +} +unsafe extern "C" { + pub fn ggml_backend_tensor_get( + tensor: *const ggml_tensor, + data: *mut ::std::os::raw::c_void, + offset: usize, + size: usize, + ); +} +unsafe extern "C" { + pub fn ggml_backend_tensor_memset( + tensor: *mut ggml_tensor, + value: u8, + offset: usize, + size: usize, + ); +} +unsafe extern "C" { + pub fn ggml_backend_synchronize(backend: ggml_backend_t); +} +unsafe extern "C" { + pub fn ggml_backend_graph_plan_create( + backend: ggml_backend_t, + cgraph: *mut ggml_cgraph, + ) -> ggml_backend_graph_plan_t; +} +unsafe extern "C" { + pub fn ggml_backend_graph_plan_free(backend: ggml_backend_t, plan: ggml_backend_graph_plan_t); +} +unsafe extern "C" { + pub fn ggml_backend_graph_plan_compute( + backend: ggml_backend_t, + plan: ggml_backend_graph_plan_t, + ) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_backend_graph_compute( + backend: ggml_backend_t, + cgraph: *mut ggml_cgraph, + ) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_backend_graph_compute_async( + backend: ggml_backend_t, + cgraph: *mut ggml_cgraph, + ) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_backend_supports_op(backend: ggml_backend_t, op: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_supports_buft( + backend: ggml_backend_t, + buft: ggml_backend_buffer_type_t, + ) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_offload_op(backend: ggml_backend_t, op: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + 
pub fn ggml_backend_tensor_copy_async( + backend_src: ggml_backend_t, + backend_dst: ggml_backend_t, + src: *mut ggml_tensor, + dst: *mut ggml_tensor, + ); +} +unsafe extern "C" { + pub fn ggml_backend_get_device(backend: ggml_backend_t) -> ggml_backend_dev_t; +} +unsafe extern "C" { + pub fn ggml_backend_event_new(device: ggml_backend_dev_t) -> ggml_backend_event_t; +} +unsafe extern "C" { + pub fn ggml_backend_event_free(event: ggml_backend_event_t); +} +unsafe extern "C" { + pub fn ggml_backend_event_record(event: ggml_backend_event_t, backend: ggml_backend_t); +} +unsafe extern "C" { + pub fn ggml_backend_event_synchronize(event: ggml_backend_event_t); +} +unsafe extern "C" { + pub fn ggml_backend_event_wait(backend: ggml_backend_t, event: ggml_backend_event_t); +} +pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_CPU: ggml_backend_dev_type = 0; +pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_GPU: ggml_backend_dev_type = 1; +pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_ACCEL: ggml_backend_dev_type = 2; +pub type ggml_backend_dev_type = ::std::os::raw::c_uint; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_dev_caps { + pub async_: bool, + pub host_buffer: bool, + pub buffer_from_host_ptr: bool, + pub events: bool, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_backend_dev_caps"][::std::mem::size_of::() - 4usize]; + ["Alignment of ggml_backend_dev_caps"] + [::std::mem::align_of::() - 1usize]; + ["Offset of field: ggml_backend_dev_caps::async_"] + [::std::mem::offset_of!(ggml_backend_dev_caps, async_) - 0usize]; + ["Offset of field: ggml_backend_dev_caps::host_buffer"] + [::std::mem::offset_of!(ggml_backend_dev_caps, host_buffer) - 1usize]; + ["Offset of field: ggml_backend_dev_caps::buffer_from_host_ptr"] + [::std::mem::offset_of!(ggml_backend_dev_caps, buffer_from_host_ptr) - 2usize]; + ["Offset of field: ggml_backend_dev_caps::events"] + 
[::std::mem::offset_of!(ggml_backend_dev_caps, events) - 3usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_dev_props { + pub name: *const ::std::os::raw::c_char, + pub description: *const ::std::os::raw::c_char, + pub memory_free: usize, + pub memory_total: usize, + pub type_: ggml_backend_dev_type, + pub caps: ggml_backend_dev_caps, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_backend_dev_props"][::std::mem::size_of::() - 40usize]; + ["Alignment of ggml_backend_dev_props"] + [::std::mem::align_of::() - 8usize]; + ["Offset of field: ggml_backend_dev_props::name"] + [::std::mem::offset_of!(ggml_backend_dev_props, name) - 0usize]; + ["Offset of field: ggml_backend_dev_props::description"] + [::std::mem::offset_of!(ggml_backend_dev_props, description) - 8usize]; + ["Offset of field: ggml_backend_dev_props::memory_free"] + [::std::mem::offset_of!(ggml_backend_dev_props, memory_free) - 16usize]; + ["Offset of field: ggml_backend_dev_props::memory_total"] + [::std::mem::offset_of!(ggml_backend_dev_props, memory_total) - 24usize]; + ["Offset of field: ggml_backend_dev_props::type_"] + [::std::mem::offset_of!(ggml_backend_dev_props, type_) - 32usize]; + ["Offset of field: ggml_backend_dev_props::caps"] + [::std::mem::offset_of!(ggml_backend_dev_props, caps) - 36usize]; +}; +unsafe extern "C" { + pub fn ggml_backend_dev_name(device: ggml_backend_dev_t) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_backend_dev_description( + device: ggml_backend_dev_t, + ) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_backend_dev_memory(device: ggml_backend_dev_t, free: *mut usize, total: *mut usize); +} +unsafe extern "C" { + pub fn ggml_backend_dev_type(device: ggml_backend_dev_t) -> ggml_backend_dev_type; +} +unsafe extern "C" { + pub fn ggml_backend_dev_get_props( + device: ggml_backend_dev_t, + props: *mut ggml_backend_dev_props, + ); +} +unsafe extern 
"C" { + pub fn ggml_backend_dev_backend_reg(device: ggml_backend_dev_t) -> ggml_backend_reg_t; +} +unsafe extern "C" { + pub fn ggml_backend_dev_init( + device: ggml_backend_dev_t, + params: *const ::std::os::raw::c_char, + ) -> ggml_backend_t; +} +unsafe extern "C" { + pub fn ggml_backend_dev_buffer_type(device: ggml_backend_dev_t) -> ggml_backend_buffer_type_t; +} +unsafe extern "C" { + pub fn ggml_backend_dev_host_buffer_type( + device: ggml_backend_dev_t, + ) -> ggml_backend_buffer_type_t; +} +unsafe extern "C" { + pub fn ggml_backend_dev_buffer_from_host_ptr( + device: ggml_backend_dev_t, + ptr: *mut ::std::os::raw::c_void, + size: usize, + max_tensor_size: usize, + ) -> ggml_backend_buffer_t; +} +unsafe extern "C" { + pub fn ggml_backend_dev_supports_op(device: ggml_backend_dev_t, op: *const ggml_tensor) + -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_dev_supports_buft( + device: ggml_backend_dev_t, + buft: ggml_backend_buffer_type_t, + ) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_dev_offload_op(device: ggml_backend_dev_t, op: *const ggml_tensor) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_reg_name(reg: ggml_backend_reg_t) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ggml_backend_reg_dev_count(reg: ggml_backend_reg_t) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_reg_dev_get(reg: ggml_backend_reg_t, index: usize) -> ggml_backend_dev_t; +} +unsafe extern "C" { + pub fn ggml_backend_reg_get_proc_address( + reg: ggml_backend_reg_t, + name: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_void; +} +pub type ggml_backend_split_buffer_type_t = ::std::option::Option< + unsafe extern "C" fn( + main_device: ::std::os::raw::c_int, + tensor_split: *const f32, + ) -> ggml_backend_buffer_type_t, +>; +pub type ggml_backend_set_n_threads_t = ::std::option::Option< + unsafe extern "C" fn(backend: ggml_backend_t, n_threads: ::std::os::raw::c_int), +>; +pub type ggml_backend_dev_get_extra_bufts_t = 
::std::option::Option< + unsafe extern "C" fn(device: ggml_backend_dev_t) -> *mut ggml_backend_buffer_type_t, +>; +pub type ggml_backend_set_abort_callback_t = ::std::option::Option< + unsafe extern "C" fn( + backend: ggml_backend_t, + abort_callback: ggml_abort_callback, + abort_callback_data: *mut ::std::os::raw::c_void, + ), +>; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_feature { + pub name: *const ::std::os::raw::c_char, + pub value: *const ::std::os::raw::c_char, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_backend_feature"][::std::mem::size_of::() - 16usize]; + ["Alignment of ggml_backend_feature"][::std::mem::align_of::() - 8usize]; + ["Offset of field: ggml_backend_feature::name"] + [::std::mem::offset_of!(ggml_backend_feature, name) - 0usize]; + ["Offset of field: ggml_backend_feature::value"] + [::std::mem::offset_of!(ggml_backend_feature, value) - 8usize]; +}; +pub type ggml_backend_get_features_t = ::std::option::Option< + unsafe extern "C" fn(reg: ggml_backend_reg_t) -> *mut ggml_backend_feature, +>; +unsafe extern "C" { + pub fn ggml_backend_device_register(device: ggml_backend_dev_t); +} +unsafe extern "C" { + pub fn ggml_backend_reg_count() -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_reg_get(index: usize) -> ggml_backend_reg_t; +} +unsafe extern "C" { + pub fn ggml_backend_reg_by_name(name: *const ::std::os::raw::c_char) -> ggml_backend_reg_t; +} +unsafe extern "C" { + pub fn ggml_backend_dev_count() -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_dev_get(index: usize) -> ggml_backend_dev_t; +} +unsafe extern "C" { + pub fn ggml_backend_dev_by_name(name: *const ::std::os::raw::c_char) -> ggml_backend_dev_t; +} +unsafe extern "C" { + pub fn ggml_backend_dev_by_type(type_: ggml_backend_dev_type) -> ggml_backend_dev_t; +} +unsafe extern "C" { + pub fn ggml_backend_init_by_name( + name: *const ::std::os::raw::c_char, + params: *const 
::std::os::raw::c_char, + ) -> ggml_backend_t; +} +unsafe extern "C" { + pub fn ggml_backend_init_by_type( + type_: ggml_backend_dev_type, + params: *const ::std::os::raw::c_char, + ) -> ggml_backend_t; +} +unsafe extern "C" { + pub fn ggml_backend_init_best() -> ggml_backend_t; +} +unsafe extern "C" { + pub fn ggml_backend_load(path: *const ::std::os::raw::c_char) -> ggml_backend_reg_t; +} +unsafe extern "C" { + pub fn ggml_backend_unload(reg: ggml_backend_reg_t); +} +unsafe extern "C" { + pub fn ggml_backend_load_all(); +} +unsafe extern "C" { + pub fn ggml_backend_load_all_from_path(dir_path: *const ::std::os::raw::c_char); +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_sched { + _unused: [u8; 0], +} +pub type ggml_backend_sched_t = *mut ggml_backend_sched; +pub type ggml_backend_sched_eval_callback = ::std::option::Option< + unsafe extern "C" fn( + t: *mut ggml_tensor, + ask: bool, + user_data: *mut ::std::os::raw::c_void, + ) -> bool, +>; +unsafe extern "C" { + pub fn ggml_backend_sched_new( + backends: *mut ggml_backend_t, + bufts: *mut ggml_backend_buffer_type_t, + n_backends: ::std::os::raw::c_int, + graph_size: usize, + parallel: bool, + op_offload: bool, + ) -> ggml_backend_sched_t; +} +unsafe extern "C" { + pub fn ggml_backend_sched_free(sched: ggml_backend_sched_t); +} +unsafe extern "C" { + pub fn ggml_backend_sched_reserve( + sched: ggml_backend_sched_t, + measure_graph: *mut ggml_cgraph, + ) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_sched_get_n_backends(sched: ggml_backend_sched_t) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_backend_sched_get_backend( + sched: ggml_backend_sched_t, + i: ::std::os::raw::c_int, + ) -> ggml_backend_t; +} +unsafe extern "C" { + pub fn ggml_backend_sched_get_n_splits(sched: ggml_backend_sched_t) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_backend_sched_get_n_copies(sched: ggml_backend_sched_t) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + 
pub fn ggml_backend_sched_get_buffer_size( + sched: ggml_backend_sched_t, + backend: ggml_backend_t, + ) -> usize; +} +unsafe extern "C" { + pub fn ggml_backend_sched_set_tensor_backend( + sched: ggml_backend_sched_t, + node: *mut ggml_tensor, + backend: ggml_backend_t, + ); +} +unsafe extern "C" { + pub fn ggml_backend_sched_get_tensor_backend( + sched: ggml_backend_sched_t, + node: *mut ggml_tensor, + ) -> ggml_backend_t; +} +unsafe extern "C" { + pub fn ggml_backend_sched_alloc_graph( + sched: ggml_backend_sched_t, + graph: *mut ggml_cgraph, + ) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_sched_graph_compute( + sched: ggml_backend_sched_t, + graph: *mut ggml_cgraph, + ) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_backend_sched_graph_compute_async( + sched: ggml_backend_sched_t, + graph: *mut ggml_cgraph, + ) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_backend_sched_synchronize(sched: ggml_backend_sched_t); +} +unsafe extern "C" { + pub fn ggml_backend_sched_reset(sched: ggml_backend_sched_t); +} +unsafe extern "C" { + pub fn ggml_backend_sched_set_eval_callback( + sched: ggml_backend_sched_t, + callback: ggml_backend_sched_eval_callback, + user_data: *mut ::std::os::raw::c_void, + ); +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_graph_copy { + pub buffer: ggml_backend_buffer_t, + pub ctx_allocated: *mut ggml_context, + pub ctx_unallocated: *mut ggml_context, + pub graph: *mut ggml_cgraph, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_backend_graph_copy"][::std::mem::size_of::() - 32usize]; + ["Alignment of ggml_backend_graph_copy"] + [::std::mem::align_of::() - 8usize]; + ["Offset of field: ggml_backend_graph_copy::buffer"] + [::std::mem::offset_of!(ggml_backend_graph_copy, buffer) - 0usize]; + ["Offset of field: ggml_backend_graph_copy::ctx_allocated"] + [::std::mem::offset_of!(ggml_backend_graph_copy, ctx_allocated) - 8usize]; + ["Offset of field: 
ggml_backend_graph_copy::ctx_unallocated"] + [::std::mem::offset_of!(ggml_backend_graph_copy, ctx_unallocated) - 16usize]; + ["Offset of field: ggml_backend_graph_copy::graph"] + [::std::mem::offset_of!(ggml_backend_graph_copy, graph) - 24usize]; +}; +unsafe extern "C" { + pub fn ggml_backend_graph_copy( + backend: ggml_backend_t, + graph: *mut ggml_cgraph, + ) -> ggml_backend_graph_copy; +} +unsafe extern "C" { + pub fn ggml_backend_graph_copy_free(copy: ggml_backend_graph_copy); +} +pub type ggml_backend_eval_callback = ::std::option::Option< + unsafe extern "C" fn( + node_index: ::std::os::raw::c_int, + t1: *mut ggml_tensor, + t2: *mut ggml_tensor, + user_data: *mut ::std::os::raw::c_void, + ) -> bool, +>; +unsafe extern "C" { + pub fn ggml_backend_compare_graph_backend( + backend1: ggml_backend_t, + backend2: ggml_backend_t, + graph: *mut ggml_cgraph, + callback: ggml_backend_eval_callback, + user_data: *mut ::std::os::raw::c_void, + test_node: *mut ggml_tensor, + ) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_tensor_alloc( + buffer: ggml_backend_buffer_t, + tensor: *mut ggml_tensor, + addr: *mut ::std::os::raw::c_void, + ) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_backend_view_init(tensor: *mut ggml_tensor) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_backend_cpu_buffer_from_ptr( + ptr: *mut ::std::os::raw::c_void, + size: usize, + ) -> ggml_backend_buffer_t; +} +unsafe extern "C" { + pub fn ggml_backend_cpu_buffer_type() -> ggml_backend_buffer_type_t; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_cplan { + pub work_size: usize, + pub work_data: *mut u8, + pub n_threads: ::std::os::raw::c_int, + pub threadpool: *mut ggml_threadpool, + pub abort_callback: ggml_abort_callback, + pub abort_callback_data: *mut ::std::os::raw::c_void, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_cplan"][::std::mem::size_of::() - 48usize]; + ["Alignment of 
ggml_cplan"][::std::mem::align_of::() - 8usize]; + ["Offset of field: ggml_cplan::work_size"] + [::std::mem::offset_of!(ggml_cplan, work_size) - 0usize]; + ["Offset of field: ggml_cplan::work_data"] + [::std::mem::offset_of!(ggml_cplan, work_data) - 8usize]; + ["Offset of field: ggml_cplan::n_threads"] + [::std::mem::offset_of!(ggml_cplan, n_threads) - 16usize]; + ["Offset of field: ggml_cplan::threadpool"] + [::std::mem::offset_of!(ggml_cplan, threadpool) - 24usize]; + ["Offset of field: ggml_cplan::abort_callback"] + [::std::mem::offset_of!(ggml_cplan, abort_callback) - 32usize]; + ["Offset of field: ggml_cplan::abort_callback_data"] + [::std::mem::offset_of!(ggml_cplan, abort_callback_data) - 40usize]; +}; +pub const ggml_numa_strategy_GGML_NUMA_STRATEGY_DISABLED: ggml_numa_strategy = 0; +pub const ggml_numa_strategy_GGML_NUMA_STRATEGY_DISTRIBUTE: ggml_numa_strategy = 1; +pub const ggml_numa_strategy_GGML_NUMA_STRATEGY_ISOLATE: ggml_numa_strategy = 2; +pub const ggml_numa_strategy_GGML_NUMA_STRATEGY_NUMACTL: ggml_numa_strategy = 3; +pub const ggml_numa_strategy_GGML_NUMA_STRATEGY_MIRROR: ggml_numa_strategy = 4; +pub const ggml_numa_strategy_GGML_NUMA_STRATEGY_COUNT: ggml_numa_strategy = 5; +pub type ggml_numa_strategy = ::std::os::raw::c_uint; +unsafe extern "C" { + pub fn ggml_numa_init(numa: ggml_numa_strategy); +} +unsafe extern "C" { + pub fn ggml_is_numa() -> bool; +} +unsafe extern "C" { + pub fn ggml_new_i32(ctx: *mut ggml_context, value: i32) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_new_f32(ctx: *mut ggml_context, value: f32) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_set_i32(tensor: *mut ggml_tensor, value: i32) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_set_f32(tensor: *mut ggml_tensor, value: f32) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_get_i32_1d(tensor: *const ggml_tensor, i: ::std::os::raw::c_int) -> i32; +} +unsafe extern "C" { + pub fn ggml_set_i32_1d(tensor: *const ggml_tensor, 
i: ::std::os::raw::c_int, value: i32); +} +unsafe extern "C" { + pub fn ggml_get_i32_nd( + tensor: *const ggml_tensor, + i0: ::std::os::raw::c_int, + i1: ::std::os::raw::c_int, + i2: ::std::os::raw::c_int, + i3: ::std::os::raw::c_int, + ) -> i32; +} +unsafe extern "C" { + pub fn ggml_set_i32_nd( + tensor: *const ggml_tensor, + i0: ::std::os::raw::c_int, + i1: ::std::os::raw::c_int, + i2: ::std::os::raw::c_int, + i3: ::std::os::raw::c_int, + value: i32, + ); +} +unsafe extern "C" { + pub fn ggml_get_f32_1d(tensor: *const ggml_tensor, i: ::std::os::raw::c_int) -> f32; +} +unsafe extern "C" { + pub fn ggml_set_f32_1d(tensor: *const ggml_tensor, i: ::std::os::raw::c_int, value: f32); +} +unsafe extern "C" { + pub fn ggml_get_f32_nd( + tensor: *const ggml_tensor, + i0: ::std::os::raw::c_int, + i1: ::std::os::raw::c_int, + i2: ::std::os::raw::c_int, + i3: ::std::os::raw::c_int, + ) -> f32; +} +unsafe extern "C" { + pub fn ggml_set_f32_nd( + tensor: *const ggml_tensor, + i0: ::std::os::raw::c_int, + i1: ::std::os::raw::c_int, + i2: ::std::os::raw::c_int, + i3: ::std::os::raw::c_int, + value: f32, + ); +} +unsafe extern "C" { + pub fn ggml_threadpool_new(params: *mut ggml_threadpool_params) -> *mut ggml_threadpool; +} +unsafe extern "C" { + pub fn ggml_threadpool_free(threadpool: *mut ggml_threadpool); +} +unsafe extern "C" { + pub fn ggml_threadpool_get_n_threads(threadpool: *mut ggml_threadpool) + -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_threadpool_pause(threadpool: *mut ggml_threadpool); +} +unsafe extern "C" { + pub fn ggml_threadpool_resume(threadpool: *mut ggml_threadpool); +} +unsafe extern "C" { + pub fn ggml_graph_plan( + cgraph: *const ggml_cgraph, + n_threads: ::std::os::raw::c_int, + threadpool: *mut ggml_threadpool, + ) -> ggml_cplan; +} +unsafe extern "C" { + pub fn ggml_graph_compute(cgraph: *mut ggml_cgraph, cplan: *mut ggml_cplan) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_graph_compute_with_ctx( + ctx: *mut ggml_context, 
+ cgraph: *mut ggml_cgraph, + n_threads: ::std::os::raw::c_int, + ) -> ggml_status; +} +unsafe extern "C" { + pub fn ggml_cpu_has_sse3() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_ssse3() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_avx() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_avx_vnni() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_avx2() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_bmi2() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_f16c() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_fma() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_avx512() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_avx512_vbmi() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_avx512_vnni() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_avx512_bf16() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_amx_int8() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_neon() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_arm_fma() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_fp16_va() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_dotprod() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_matmul_int8() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_sve() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_get_sve_cnt() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_sme() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_riscv_v() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_vsx() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn 
ggml_cpu_has_vxe() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_nnpa() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_wasm_simd() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ggml_cpu_has_llamafile() -> ::std::os::raw::c_int; +} +pub type ggml_vec_dot_t = ::std::option::Option< + unsafe extern "C" fn( + n: ::std::os::raw::c_int, + s: *mut f32, + bs: usize, + x: *const ::std::os::raw::c_void, + bx: usize, + y: *const ::std::os::raw::c_void, + by: usize, + nrc: ::std::os::raw::c_int, + ), +>; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_type_traits_cpu { + pub from_float: ggml_from_float_t, + pub vec_dot: ggml_vec_dot_t, + pub vec_dot_type: ggml_type, + pub nrows: i64, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of ggml_type_traits_cpu"][::std::mem::size_of::() - 32usize]; + ["Alignment of ggml_type_traits_cpu"][::std::mem::align_of::() - 8usize]; + ["Offset of field: ggml_type_traits_cpu::from_float"] + [::std::mem::offset_of!(ggml_type_traits_cpu, from_float) - 0usize]; + ["Offset of field: ggml_type_traits_cpu::vec_dot"] + [::std::mem::offset_of!(ggml_type_traits_cpu, vec_dot) - 8usize]; + ["Offset of field: ggml_type_traits_cpu::vec_dot_type"] + [::std::mem::offset_of!(ggml_type_traits_cpu, vec_dot_type) - 16usize]; + ["Offset of field: ggml_type_traits_cpu::nrows"] + [::std::mem::offset_of!(ggml_type_traits_cpu, nrows) - 24usize]; +}; +unsafe extern "C" { + pub fn ggml_get_type_traits_cpu(type_: ggml_type) -> *const ggml_type_traits_cpu; +} +unsafe extern "C" { + pub fn ggml_cpu_init(); +} +unsafe extern "C" { + pub fn ggml_backend_cpu_init() -> ggml_backend_t; +} +unsafe extern "C" { + pub fn ggml_backend_is_cpu(backend: ggml_backend_t) -> bool; +} +unsafe extern "C" { + pub fn ggml_backend_cpu_set_n_threads( + backend_cpu: ggml_backend_t, + n_threads: ::std::os::raw::c_int, + ); +} +unsafe extern "C" { + pub fn 
ggml_backend_cpu_set_threadpool( + backend_cpu: ggml_backend_t, + threadpool: ggml_threadpool_t, + ); +} +unsafe extern "C" { + pub fn ggml_backend_cpu_set_abort_callback( + backend_cpu: ggml_backend_t, + abort_callback: ggml_abort_callback, + abort_callback_data: *mut ::std::os::raw::c_void, + ); +} +unsafe extern "C" { + pub fn ggml_backend_cpu_reg() -> ggml_backend_reg_t; +} +unsafe extern "C" { + pub fn ggml_cpu_fp32_to_fp32(arg1: *const f32, arg2: *mut f32, arg3: i64); +} +unsafe extern "C" { + pub fn ggml_cpu_fp32_to_fp16(arg1: *const f32, arg2: *mut ggml_fp16_t, arg3: i64); +} +unsafe extern "C" { + pub fn ggml_cpu_fp16_to_fp32(arg1: *const ggml_fp16_t, arg2: *mut f32, arg3: i64); +} +unsafe extern "C" { + pub fn ggml_cpu_fp32_to_bf16(arg1: *const f32, arg2: *mut ggml_bf16_t, arg3: i64); +} +unsafe extern "C" { + pub fn ggml_cpu_bf16_to_fp32(arg1: *const ggml_bf16_t, arg2: *mut f32, arg3: i64); +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_context { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_state { + _unused: [u8; 0], +} +pub type whisper_pos = i32; +pub type whisper_token = i32; +pub type whisper_seq_id = i32; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_NONE: whisper_alignment_heads_preset = 0; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST: whisper_alignment_heads_preset = + 1; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM: whisper_alignment_heads_preset = 2; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN: whisper_alignment_heads_preset = 3; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_TINY: whisper_alignment_heads_preset = 4; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN: whisper_alignment_heads_preset = 5; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_BASE: whisper_alignment_heads_preset = 6; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN: 
whisper_alignment_heads_preset = + 7; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL: whisper_alignment_heads_preset = 8; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN: whisper_alignment_heads_preset = + 9; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM: whisper_alignment_heads_preset = 10; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1: whisper_alignment_heads_preset = + 11; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2: whisper_alignment_heads_preset = + 12; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3: whisper_alignment_heads_preset = + 13; +pub const whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3_TURBO: + whisper_alignment_heads_preset = 14; +pub type whisper_alignment_heads_preset = ::std::os::raw::c_uint; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_ahead { + pub n_text_layer: ::std::os::raw::c_int, + pub n_head: ::std::os::raw::c_int, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_ahead"][::std::mem::size_of::() - 8usize]; + ["Alignment of whisper_ahead"][::std::mem::align_of::() - 4usize]; + ["Offset of field: whisper_ahead::n_text_layer"] + [::std::mem::offset_of!(whisper_ahead, n_text_layer) - 0usize]; + ["Offset of field: whisper_ahead::n_head"] + [::std::mem::offset_of!(whisper_ahead, n_head) - 4usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_aheads { + pub n_heads: usize, + pub heads: *const whisper_ahead, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_aheads"][::std::mem::size_of::() - 16usize]; + ["Alignment of whisper_aheads"][::std::mem::align_of::() - 8usize]; + ["Offset of field: whisper_aheads::n_heads"] + [::std::mem::offset_of!(whisper_aheads, n_heads) - 0usize]; + ["Offset of field: whisper_aheads::heads"] + [::std::mem::offset_of!(whisper_aheads, heads) - 8usize]; +}; 
+#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_context_params { + pub use_gpu: bool, + pub flash_attn: bool, + pub gpu_device: ::std::os::raw::c_int, + pub dtw_token_timestamps: bool, + pub dtw_aheads_preset: whisper_alignment_heads_preset, + pub dtw_n_top: ::std::os::raw::c_int, + pub dtw_aheads: whisper_aheads, + pub dtw_mem_size: usize, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_context_params"][::std::mem::size_of::() - 48usize]; + ["Alignment of whisper_context_params"] + [::std::mem::align_of::() - 8usize]; + ["Offset of field: whisper_context_params::use_gpu"] + [::std::mem::offset_of!(whisper_context_params, use_gpu) - 0usize]; + ["Offset of field: whisper_context_params::flash_attn"] + [::std::mem::offset_of!(whisper_context_params, flash_attn) - 1usize]; + ["Offset of field: whisper_context_params::gpu_device"] + [::std::mem::offset_of!(whisper_context_params, gpu_device) - 4usize]; + ["Offset of field: whisper_context_params::dtw_token_timestamps"] + [::std::mem::offset_of!(whisper_context_params, dtw_token_timestamps) - 8usize]; + ["Offset of field: whisper_context_params::dtw_aheads_preset"] + [::std::mem::offset_of!(whisper_context_params, dtw_aheads_preset) - 12usize]; + ["Offset of field: whisper_context_params::dtw_n_top"] + [::std::mem::offset_of!(whisper_context_params, dtw_n_top) - 16usize]; + ["Offset of field: whisper_context_params::dtw_aheads"] + [::std::mem::offset_of!(whisper_context_params, dtw_aheads) - 24usize]; + ["Offset of field: whisper_context_params::dtw_mem_size"] + [::std::mem::offset_of!(whisper_context_params, dtw_mem_size) - 40usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_token_data { + pub id: whisper_token, + pub tid: whisper_token, + pub p: f32, + pub plog: f32, + pub pt: f32, + pub ptsum: f32, + pub t0: i64, + pub t1: i64, + pub t_dtw: i64, + pub vlen: f32, +} +#[allow(clippy::unnecessary_operation, 
clippy::identity_op)] +const _: () = { + ["Size of whisper_token_data"][::std::mem::size_of::() - 56usize]; + ["Alignment of whisper_token_data"][::std::mem::align_of::() - 8usize]; + ["Offset of field: whisper_token_data::id"] + [::std::mem::offset_of!(whisper_token_data, id) - 0usize]; + ["Offset of field: whisper_token_data::tid"] + [::std::mem::offset_of!(whisper_token_data, tid) - 4usize]; + ["Offset of field: whisper_token_data::p"] + [::std::mem::offset_of!(whisper_token_data, p) - 8usize]; + ["Offset of field: whisper_token_data::plog"] + [::std::mem::offset_of!(whisper_token_data, plog) - 12usize]; + ["Offset of field: whisper_token_data::pt"] + [::std::mem::offset_of!(whisper_token_data, pt) - 16usize]; + ["Offset of field: whisper_token_data::ptsum"] + [::std::mem::offset_of!(whisper_token_data, ptsum) - 20usize]; + ["Offset of field: whisper_token_data::t0"] + [::std::mem::offset_of!(whisper_token_data, t0) - 24usize]; + ["Offset of field: whisper_token_data::t1"] + [::std::mem::offset_of!(whisper_token_data, t1) - 32usize]; + ["Offset of field: whisper_token_data::t_dtw"] + [::std::mem::offset_of!(whisper_token_data, t_dtw) - 40usize]; + ["Offset of field: whisper_token_data::vlen"] + [::std::mem::offset_of!(whisper_token_data, vlen) - 48usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_model_loader { + pub context: *mut ::std::os::raw::c_void, + pub read: ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut ::std::os::raw::c_void, + output: *mut ::std::os::raw::c_void, + read_size: usize, + ) -> usize, + >, + pub eof: ::std::option::Option bool>, + pub close: ::std::option::Option, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_model_loader"][::std::mem::size_of::() - 32usize]; + ["Alignment of whisper_model_loader"][::std::mem::align_of::() - 8usize]; + ["Offset of field: whisper_model_loader::context"] + [::std::mem::offset_of!(whisper_model_loader, context) - 
0usize]; + ["Offset of field: whisper_model_loader::read"] + [::std::mem::offset_of!(whisper_model_loader, read) - 8usize]; + ["Offset of field: whisper_model_loader::eof"] + [::std::mem::offset_of!(whisper_model_loader, eof) - 16usize]; + ["Offset of field: whisper_model_loader::close"] + [::std::mem::offset_of!(whisper_model_loader, close) - 24usize]; +}; +pub const whisper_gretype_WHISPER_GRETYPE_END: whisper_gretype = 0; +pub const whisper_gretype_WHISPER_GRETYPE_ALT: whisper_gretype = 1; +pub const whisper_gretype_WHISPER_GRETYPE_RULE_REF: whisper_gretype = 2; +pub const whisper_gretype_WHISPER_GRETYPE_CHAR: whisper_gretype = 3; +pub const whisper_gretype_WHISPER_GRETYPE_CHAR_NOT: whisper_gretype = 4; +pub const whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER: whisper_gretype = 5; +pub const whisper_gretype_WHISPER_GRETYPE_CHAR_ALT: whisper_gretype = 6; +pub type whisper_gretype = ::std::os::raw::c_uint; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_grammar_element { + pub type_: whisper_gretype, + pub value: u32, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_grammar_element"][::std::mem::size_of::() - 8usize]; + ["Alignment of whisper_grammar_element"] + [::std::mem::align_of::() - 4usize]; + ["Offset of field: whisper_grammar_element::type_"] + [::std::mem::offset_of!(whisper_grammar_element, type_) - 0usize]; + ["Offset of field: whisper_grammar_element::value"] + [::std::mem::offset_of!(whisper_grammar_element, value) - 4usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_vad_params { + pub threshold: f32, + pub min_speech_duration_ms: ::std::os::raw::c_int, + pub min_silence_duration_ms: ::std::os::raw::c_int, + pub max_speech_duration_s: f32, + pub speech_pad_ms: ::std::os::raw::c_int, + pub samples_overlap: f32, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_vad_params"][::std::mem::size_of::() - 24usize]; 
+ ["Alignment of whisper_vad_params"][::std::mem::align_of::() - 4usize]; + ["Offset of field: whisper_vad_params::threshold"] + [::std::mem::offset_of!(whisper_vad_params, threshold) - 0usize]; + ["Offset of field: whisper_vad_params::min_speech_duration_ms"] + [::std::mem::offset_of!(whisper_vad_params, min_speech_duration_ms) - 4usize]; + ["Offset of field: whisper_vad_params::min_silence_duration_ms"] + [::std::mem::offset_of!(whisper_vad_params, min_silence_duration_ms) - 8usize]; + ["Offset of field: whisper_vad_params::max_speech_duration_s"] + [::std::mem::offset_of!(whisper_vad_params, max_speech_duration_s) - 12usize]; + ["Offset of field: whisper_vad_params::speech_pad_ms"] + [::std::mem::offset_of!(whisper_vad_params, speech_pad_ms) - 16usize]; + ["Offset of field: whisper_vad_params::samples_overlap"] + [::std::mem::offset_of!(whisper_vad_params, samples_overlap) - 20usize]; +}; +unsafe extern "C" { + pub fn whisper_version() -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_init_from_file_with_params( + path_model: *const ::std::os::raw::c_char, + params: whisper_context_params, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_from_buffer_with_params( + buffer: *mut ::std::os::raw::c_void, + buffer_size: usize, + params: whisper_context_params, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_with_params( + loader: *mut whisper_model_loader, + params: whisper_context_params, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_from_file_with_params_no_state( + path_model: *const ::std::os::raw::c_char, + params: whisper_context_params, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_from_buffer_with_params_no_state( + buffer: *mut ::std::os::raw::c_void, + buffer_size: usize, + params: whisper_context_params, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_with_params_no_state( + loader: *mut 
whisper_model_loader, + params: whisper_context_params, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_from_file( + path_model: *const ::std::os::raw::c_char, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_from_buffer( + buffer: *mut ::std::os::raw::c_void, + buffer_size: usize, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init(loader: *mut whisper_model_loader) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_from_file_no_state( + path_model: *const ::std::os::raw::c_char, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_from_buffer_no_state( + buffer: *mut ::std::os::raw::c_void, + buffer_size: usize, + ) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_no_state(loader: *mut whisper_model_loader) -> *mut whisper_context; +} +unsafe extern "C" { + pub fn whisper_init_state(ctx: *mut whisper_context) -> *mut whisper_state; +} +unsafe extern "C" { + pub fn whisper_ctx_init_openvino_encoder_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + model_path: *const ::std::os::raw::c_char, + device: *const ::std::os::raw::c_char, + cache_dir: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_ctx_init_openvino_encoder( + ctx: *mut whisper_context, + model_path: *const ::std::os::raw::c_char, + device: *const ::std::os::raw::c_char, + cache_dir: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_free(ctx: *mut whisper_context); +} +unsafe extern "C" { + pub fn whisper_free_state(state: *mut whisper_state); +} +unsafe extern "C" { + pub fn whisper_free_params(params: *mut whisper_full_params); +} +unsafe extern "C" { + pub fn whisper_free_context_params(params: *mut whisper_context_params); +} +unsafe extern "C" { + pub fn whisper_pcm_to_mel( + ctx: *mut whisper_context, + samples: *const f32, + n_samples: 
::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_pcm_to_mel_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_set_mel( + ctx: *mut whisper_context, + data: *const f32, + n_len: ::std::os::raw::c_int, + n_mel: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_set_mel_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + data: *const f32, + n_len: ::std::os::raw::c_int, + n_mel: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_encode( + ctx: *mut whisper_context, + offset: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_encode_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + offset: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_decode( + ctx: *mut whisper_context, + tokens: *const whisper_token, + n_tokens: ::std::os::raw::c_int, + n_past: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_decode_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + tokens: *const whisper_token, + n_tokens: ::std::os::raw::c_int, + n_past: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_tokenize( + ctx: *mut whisper_context, + text: *const ::std::os::raw::c_char, + tokens: *mut whisper_token, + n_max_tokens: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_token_count( + ctx: *mut whisper_context, + text: *const 
::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_lang_max_id() -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_lang_id(lang: *const ::std::os::raw::c_char) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_lang_str(id: ::std::os::raw::c_int) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_lang_str_full(id: ::std::os::raw::c_int) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_lang_auto_detect( + ctx: *mut whisper_context, + offset_ms: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + lang_probs: *mut f32, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_lang_auto_detect_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + offset_ms: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + lang_probs: *mut f32, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_n_len(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_n_len_from_state(state: *mut whisper_state) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_n_vocab(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_n_text_ctx(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_n_audio_ctx(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_is_multilingual(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_vocab(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_audio_ctx(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_audio_state(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_audio_head(ctx: *mut whisper_context) -> 
::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_audio_layer(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_text_ctx(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_text_state(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_text_head(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_text_layer(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_n_mels(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_ftype(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_model_type(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_get_logits(ctx: *mut whisper_context) -> *mut f32; +} +unsafe extern "C" { + pub fn whisper_get_logits_from_state(state: *mut whisper_state) -> *mut f32; +} +unsafe extern "C" { + pub fn whisper_token_to_str( + ctx: *mut whisper_context, + token: whisper_token, + ) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_model_type_readable(ctx: *mut whisper_context) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_token_eot(ctx: *mut whisper_context) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_token_sot(ctx: *mut whisper_context) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_token_solm(ctx: *mut whisper_context) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_token_prev(ctx: *mut whisper_context) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_token_nosp(ctx: *mut whisper_context) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_token_not(ctx: *mut whisper_context) -> whisper_token; +} +unsafe extern "C" { + pub fn 
whisper_token_beg(ctx: *mut whisper_context) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_token_lang( + ctx: *mut whisper_context, + lang_id: ::std::os::raw::c_int, + ) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_token_translate(ctx: *mut whisper_context) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_token_transcribe(ctx: *mut whisper_context) -> whisper_token; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_timings { + pub sample_ms: f32, + pub encode_ms: f32, + pub decode_ms: f32, + pub batchd_ms: f32, + pub prompt_ms: f32, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_timings"][::std::mem::size_of::() - 20usize]; + ["Alignment of whisper_timings"][::std::mem::align_of::() - 4usize]; + ["Offset of field: whisper_timings::sample_ms"] + [::std::mem::offset_of!(whisper_timings, sample_ms) - 0usize]; + ["Offset of field: whisper_timings::encode_ms"] + [::std::mem::offset_of!(whisper_timings, encode_ms) - 4usize]; + ["Offset of field: whisper_timings::decode_ms"] + [::std::mem::offset_of!(whisper_timings, decode_ms) - 8usize]; + ["Offset of field: whisper_timings::batchd_ms"] + [::std::mem::offset_of!(whisper_timings, batchd_ms) - 12usize]; + ["Offset of field: whisper_timings::prompt_ms"] + [::std::mem::offset_of!(whisper_timings, prompt_ms) - 16usize]; +}; +unsafe extern "C" { + pub fn whisper_get_timings(ctx: *mut whisper_context) -> *mut whisper_timings; +} +unsafe extern "C" { + pub fn whisper_print_timings(ctx: *mut whisper_context); +} +unsafe extern "C" { + pub fn whisper_reset_timings(ctx: *mut whisper_context); +} +unsafe extern "C" { + pub fn whisper_print_system_info() -> *const ::std::os::raw::c_char; +} +pub const whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY: whisper_sampling_strategy = 0; +pub const whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH: whisper_sampling_strategy = 1; +pub type whisper_sampling_strategy = 
::std::os::raw::c_uint; +pub type whisper_new_segment_callback = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut whisper_context, + state: *mut whisper_state, + n_new: ::std::os::raw::c_int, + user_data: *mut ::std::os::raw::c_void, + ), +>; +pub type whisper_progress_callback = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut whisper_context, + state: *mut whisper_state, + progress: ::std::os::raw::c_int, + user_data: *mut ::std::os::raw::c_void, + ), +>; +pub type whisper_encoder_begin_callback = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut whisper_context, + state: *mut whisper_state, + user_data: *mut ::std::os::raw::c_void, + ) -> bool, +>; +pub type whisper_logits_filter_callback = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut whisper_context, + state: *mut whisper_state, + tokens: *const whisper_token_data, + n_tokens: ::std::os::raw::c_int, + logits: *mut f32, + user_data: *mut ::std::os::raw::c_void, + ), +>; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_full_params { + pub strategy: whisper_sampling_strategy, + pub n_threads: ::std::os::raw::c_int, + pub n_max_text_ctx: ::std::os::raw::c_int, + pub offset_ms: ::std::os::raw::c_int, + pub duration_ms: ::std::os::raw::c_int, + pub translate: bool, + pub no_context: bool, + pub no_timestamps: bool, + pub single_segment: bool, + pub print_special: bool, + pub print_progress: bool, + pub print_realtime: bool, + pub print_timestamps: bool, + pub token_timestamps: bool, + pub thold_pt: f32, + pub thold_ptsum: f32, + pub max_len: ::std::os::raw::c_int, + pub split_on_word: bool, + pub max_tokens: ::std::os::raw::c_int, + pub debug_mode: bool, + pub audio_ctx: ::std::os::raw::c_int, + pub tdrz_enable: bool, + pub suppress_regex: *const ::std::os::raw::c_char, + pub initial_prompt: *const ::std::os::raw::c_char, + pub prompt_tokens: *const whisper_token, + pub prompt_n_tokens: ::std::os::raw::c_int, + pub language: *const ::std::os::raw::c_char, + pub 
detect_language: bool, + pub suppress_blank: bool, + pub suppress_nst: bool, + pub temperature: f32, + pub max_initial_ts: f32, + pub length_penalty: f32, + pub temperature_inc: f32, + pub entropy_thold: f32, + pub logprob_thold: f32, + pub no_speech_thold: f32, + pub greedy: whisper_full_params__bindgen_ty_1, + pub beam_search: whisper_full_params__bindgen_ty_2, + pub new_segment_callback: whisper_new_segment_callback, + pub new_segment_callback_user_data: *mut ::std::os::raw::c_void, + pub progress_callback: whisper_progress_callback, + pub progress_callback_user_data: *mut ::std::os::raw::c_void, + pub encoder_begin_callback: whisper_encoder_begin_callback, + pub encoder_begin_callback_user_data: *mut ::std::os::raw::c_void, + pub abort_callback: ggml_abort_callback, + pub abort_callback_user_data: *mut ::std::os::raw::c_void, + pub logits_filter_callback: whisper_logits_filter_callback, + pub logits_filter_callback_user_data: *mut ::std::os::raw::c_void, + pub grammar_rules: *mut *const whisper_grammar_element, + pub n_grammar_rules: usize, + pub i_start_rule: usize, + pub grammar_penalty: f32, + pub vad: bool, + pub vad_model_path: *const ::std::os::raw::c_char, + pub vad_params: whisper_vad_params, +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_full_params__bindgen_ty_1 { + pub best_of: ::std::os::raw::c_int, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_full_params__bindgen_ty_1"] + [::std::mem::size_of::() - 4usize]; + ["Alignment of whisper_full_params__bindgen_ty_1"] + [::std::mem::align_of::() - 4usize]; + ["Offset of field: whisper_full_params__bindgen_ty_1::best_of"] + [::std::mem::offset_of!(whisper_full_params__bindgen_ty_1, best_of) - 0usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_full_params__bindgen_ty_2 { + pub beam_size: ::std::os::raw::c_int, + pub patience: f32, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = 
{ + ["Size of whisper_full_params__bindgen_ty_2"] + [::std::mem::size_of::() - 8usize]; + ["Alignment of whisper_full_params__bindgen_ty_2"] + [::std::mem::align_of::() - 4usize]; + ["Offset of field: whisper_full_params__bindgen_ty_2::beam_size"] + [::std::mem::offset_of!(whisper_full_params__bindgen_ty_2, beam_size) - 0usize]; + ["Offset of field: whisper_full_params__bindgen_ty_2::patience"] + [::std::mem::offset_of!(whisper_full_params__bindgen_ty_2, patience) - 4usize]; +}; +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_full_params"][::std::mem::size_of::() - 296usize]; + ["Alignment of whisper_full_params"][::std::mem::align_of::() - 8usize]; + ["Offset of field: whisper_full_params::strategy"] + [::std::mem::offset_of!(whisper_full_params, strategy) - 0usize]; + ["Offset of field: whisper_full_params::n_threads"] + [::std::mem::offset_of!(whisper_full_params, n_threads) - 4usize]; + ["Offset of field: whisper_full_params::n_max_text_ctx"] + [::std::mem::offset_of!(whisper_full_params, n_max_text_ctx) - 8usize]; + ["Offset of field: whisper_full_params::offset_ms"] + [::std::mem::offset_of!(whisper_full_params, offset_ms) - 12usize]; + ["Offset of field: whisper_full_params::duration_ms"] + [::std::mem::offset_of!(whisper_full_params, duration_ms) - 16usize]; + ["Offset of field: whisper_full_params::translate"] + [::std::mem::offset_of!(whisper_full_params, translate) - 20usize]; + ["Offset of field: whisper_full_params::no_context"] + [::std::mem::offset_of!(whisper_full_params, no_context) - 21usize]; + ["Offset of field: whisper_full_params::no_timestamps"] + [::std::mem::offset_of!(whisper_full_params, no_timestamps) - 22usize]; + ["Offset of field: whisper_full_params::single_segment"] + [::std::mem::offset_of!(whisper_full_params, single_segment) - 23usize]; + ["Offset of field: whisper_full_params::print_special"] + [::std::mem::offset_of!(whisper_full_params, print_special) - 24usize]; + ["Offset of 
field: whisper_full_params::print_progress"] + [::std::mem::offset_of!(whisper_full_params, print_progress) - 25usize]; + ["Offset of field: whisper_full_params::print_realtime"] + [::std::mem::offset_of!(whisper_full_params, print_realtime) - 26usize]; + ["Offset of field: whisper_full_params::print_timestamps"] + [::std::mem::offset_of!(whisper_full_params, print_timestamps) - 27usize]; + ["Offset of field: whisper_full_params::token_timestamps"] + [::std::mem::offset_of!(whisper_full_params, token_timestamps) - 28usize]; + ["Offset of field: whisper_full_params::thold_pt"] + [::std::mem::offset_of!(whisper_full_params, thold_pt) - 32usize]; + ["Offset of field: whisper_full_params::thold_ptsum"] + [::std::mem::offset_of!(whisper_full_params, thold_ptsum) - 36usize]; + ["Offset of field: whisper_full_params::max_len"] + [::std::mem::offset_of!(whisper_full_params, max_len) - 40usize]; + ["Offset of field: whisper_full_params::split_on_word"] + [::std::mem::offset_of!(whisper_full_params, split_on_word) - 44usize]; + ["Offset of field: whisper_full_params::max_tokens"] + [::std::mem::offset_of!(whisper_full_params, max_tokens) - 48usize]; + ["Offset of field: whisper_full_params::debug_mode"] + [::std::mem::offset_of!(whisper_full_params, debug_mode) - 52usize]; + ["Offset of field: whisper_full_params::audio_ctx"] + [::std::mem::offset_of!(whisper_full_params, audio_ctx) - 56usize]; + ["Offset of field: whisper_full_params::tdrz_enable"] + [::std::mem::offset_of!(whisper_full_params, tdrz_enable) - 60usize]; + ["Offset of field: whisper_full_params::suppress_regex"] + [::std::mem::offset_of!(whisper_full_params, suppress_regex) - 64usize]; + ["Offset of field: whisper_full_params::initial_prompt"] + [::std::mem::offset_of!(whisper_full_params, initial_prompt) - 72usize]; + ["Offset of field: whisper_full_params::prompt_tokens"] + [::std::mem::offset_of!(whisper_full_params, prompt_tokens) - 80usize]; + ["Offset of field: whisper_full_params::prompt_n_tokens"] + 
[::std::mem::offset_of!(whisper_full_params, prompt_n_tokens) - 88usize]; + ["Offset of field: whisper_full_params::language"] + [::std::mem::offset_of!(whisper_full_params, language) - 96usize]; + ["Offset of field: whisper_full_params::detect_language"] + [::std::mem::offset_of!(whisper_full_params, detect_language) - 104usize]; + ["Offset of field: whisper_full_params::suppress_blank"] + [::std::mem::offset_of!(whisper_full_params, suppress_blank) - 105usize]; + ["Offset of field: whisper_full_params::suppress_nst"] + [::std::mem::offset_of!(whisper_full_params, suppress_nst) - 106usize]; + ["Offset of field: whisper_full_params::temperature"] + [::std::mem::offset_of!(whisper_full_params, temperature) - 108usize]; + ["Offset of field: whisper_full_params::max_initial_ts"] + [::std::mem::offset_of!(whisper_full_params, max_initial_ts) - 112usize]; + ["Offset of field: whisper_full_params::length_penalty"] + [::std::mem::offset_of!(whisper_full_params, length_penalty) - 116usize]; + ["Offset of field: whisper_full_params::temperature_inc"] + [::std::mem::offset_of!(whisper_full_params, temperature_inc) - 120usize]; + ["Offset of field: whisper_full_params::entropy_thold"] + [::std::mem::offset_of!(whisper_full_params, entropy_thold) - 124usize]; + ["Offset of field: whisper_full_params::logprob_thold"] + [::std::mem::offset_of!(whisper_full_params, logprob_thold) - 128usize]; + ["Offset of field: whisper_full_params::no_speech_thold"] + [::std::mem::offset_of!(whisper_full_params, no_speech_thold) - 132usize]; + ["Offset of field: whisper_full_params::greedy"] + [::std::mem::offset_of!(whisper_full_params, greedy) - 136usize]; + ["Offset of field: whisper_full_params::beam_search"] + [::std::mem::offset_of!(whisper_full_params, beam_search) - 140usize]; + ["Offset of field: whisper_full_params::new_segment_callback"] + [::std::mem::offset_of!(whisper_full_params, new_segment_callback) - 152usize]; + ["Offset of field: 
whisper_full_params::new_segment_callback_user_data"] + [::std::mem::offset_of!(whisper_full_params, new_segment_callback_user_data) - 160usize]; + ["Offset of field: whisper_full_params::progress_callback"] + [::std::mem::offset_of!(whisper_full_params, progress_callback) - 168usize]; + ["Offset of field: whisper_full_params::progress_callback_user_data"] + [::std::mem::offset_of!(whisper_full_params, progress_callback_user_data) - 176usize]; + ["Offset of field: whisper_full_params::encoder_begin_callback"] + [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback) - 184usize]; + ["Offset of field: whisper_full_params::encoder_begin_callback_user_data"] + [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback_user_data) - 192usize]; + ["Offset of field: whisper_full_params::abort_callback"] + [::std::mem::offset_of!(whisper_full_params, abort_callback) - 200usize]; + ["Offset of field: whisper_full_params::abort_callback_user_data"] + [::std::mem::offset_of!(whisper_full_params, abort_callback_user_data) - 208usize]; + ["Offset of field: whisper_full_params::logits_filter_callback"] + [::std::mem::offset_of!(whisper_full_params, logits_filter_callback) - 216usize]; + ["Offset of field: whisper_full_params::logits_filter_callback_user_data"] + [::std::mem::offset_of!(whisper_full_params, logits_filter_callback_user_data) - 224usize]; + ["Offset of field: whisper_full_params::grammar_rules"] + [::std::mem::offset_of!(whisper_full_params, grammar_rules) - 232usize]; + ["Offset of field: whisper_full_params::n_grammar_rules"] + [::std::mem::offset_of!(whisper_full_params, n_grammar_rules) - 240usize]; + ["Offset of field: whisper_full_params::i_start_rule"] + [::std::mem::offset_of!(whisper_full_params, i_start_rule) - 248usize]; + ["Offset of field: whisper_full_params::grammar_penalty"] + [::std::mem::offset_of!(whisper_full_params, grammar_penalty) - 256usize]; + ["Offset of field: whisper_full_params::vad"] + 
[::std::mem::offset_of!(whisper_full_params, vad) - 260usize]; + ["Offset of field: whisper_full_params::vad_model_path"] + [::std::mem::offset_of!(whisper_full_params, vad_model_path) - 264usize]; + ["Offset of field: whisper_full_params::vad_params"] + [::std::mem::offset_of!(whisper_full_params, vad_params) - 272usize]; +}; +unsafe extern "C" { + pub fn whisper_context_default_params_by_ref() -> *mut whisper_context_params; +} +unsafe extern "C" { + pub fn whisper_context_default_params() -> whisper_context_params; +} +unsafe extern "C" { + pub fn whisper_full_default_params_by_ref( + strategy: whisper_sampling_strategy, + ) -> *mut whisper_full_params; +} +unsafe extern "C" { + pub fn whisper_full_default_params(strategy: whisper_sampling_strategy) -> whisper_full_params; +} +unsafe extern "C" { + pub fn whisper_full( + ctx: *mut whisper_context, + params: whisper_full_params, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_full_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + params: whisper_full_params, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_full_parallel( + ctx: *mut whisper_context, + params: whisper_full_params, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + n_processors: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_full_n_segments(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_full_n_segments_from_state(state: *mut whisper_state) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_full_lang_id(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_full_lang_id_from_state(state: *mut whisper_state) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_full_get_segment_t0( + ctx: *mut 
whisper_context, + i_segment: ::std::os::raw::c_int, + ) -> i64; +} +unsafe extern "C" { + pub fn whisper_full_get_segment_t0_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> i64; +} +unsafe extern "C" { + pub fn whisper_full_get_segment_t1( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + ) -> i64; +} +unsafe extern "C" { + pub fn whisper_full_get_segment_t1_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> i64; +} +unsafe extern "C" { + pub fn whisper_full_get_segment_speaker_turn_next( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + ) -> bool; +} +unsafe extern "C" { + pub fn whisper_full_get_segment_speaker_turn_next_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> bool; +} +unsafe extern "C" { + pub fn whisper_full_get_segment_text( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_full_get_segment_text_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_full_n_tokens( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_full_n_tokens_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_full_get_token_text( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_full_get_token_text_from_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_full_get_token_id( + ctx: *mut 
whisper_context, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_full_get_token_id_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> whisper_token; +} +unsafe extern "C" { + pub fn whisper_full_get_token_data( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> whisper_token_data; +} +unsafe extern "C" { + pub fn whisper_full_get_token_data_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> whisper_token_data; +} +unsafe extern "C" { + pub fn whisper_full_get_token_p( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> f32; +} +unsafe extern "C" { + pub fn whisper_full_get_token_p_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> f32; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_vad_context { + _unused: [u8; 0], +} +unsafe extern "C" { + pub fn whisper_vad_default_params() -> whisper_vad_params; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_vad_context_params { + pub n_threads: ::std::os::raw::c_int, + pub use_gpu: bool, + pub gpu_device: ::std::os::raw::c_int, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of whisper_vad_context_params"] + [::std::mem::size_of::() - 12usize]; + ["Alignment of whisper_vad_context_params"] + [::std::mem::align_of::() - 4usize]; + ["Offset of field: whisper_vad_context_params::n_threads"] + [::std::mem::offset_of!(whisper_vad_context_params, n_threads) - 0usize]; + ["Offset of field: whisper_vad_context_params::use_gpu"] + [::std::mem::offset_of!(whisper_vad_context_params, use_gpu) - 4usize]; + ["Offset of field: whisper_vad_context_params::gpu_device"] + 
[::std::mem::offset_of!(whisper_vad_context_params, gpu_device) - 8usize]; +}; +unsafe extern "C" { + pub fn whisper_vad_default_context_params() -> whisper_vad_context_params; +} +unsafe extern "C" { + pub fn whisper_vad_init_from_file_with_params( + path_model: *const ::std::os::raw::c_char, + params: whisper_vad_context_params, + ) -> *mut whisper_vad_context; +} +unsafe extern "C" { + pub fn whisper_vad_init_with_params( + loader: *mut whisper_model_loader, + params: whisper_vad_context_params, + ) -> *mut whisper_vad_context; +} +unsafe extern "C" { + pub fn whisper_vad_detect_speech( + vctx: *mut whisper_vad_context, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + ) -> bool; +} +unsafe extern "C" { + pub fn whisper_vad_n_probs(vctx: *mut whisper_vad_context) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_vad_probs(vctx: *mut whisper_vad_context) -> *mut f32; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_vad_segments { + _unused: [u8; 0], +} +unsafe extern "C" { + pub fn whisper_vad_segments_from_probs( + vctx: *mut whisper_vad_context, + params: whisper_vad_params, + ) -> *mut whisper_vad_segments; +} +unsafe extern "C" { + pub fn whisper_vad_segments_from_samples( + vctx: *mut whisper_vad_context, + params: whisper_vad_params, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + ) -> *mut whisper_vad_segments; +} +unsafe extern "C" { + pub fn whisper_vad_segments_n_segments( + segments: *mut whisper_vad_segments, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_vad_segments_get_segment_t0( + segments: *mut whisper_vad_segments, + i_segment: ::std::os::raw::c_int, + ) -> f32; +} +unsafe extern "C" { + pub fn whisper_vad_segments_get_segment_t1( + segments: *mut whisper_vad_segments, + i_segment: ::std::os::raw::c_int, + ) -> f32; +} +unsafe extern "C" { + pub fn whisper_vad_free_segments(segments: *mut whisper_vad_segments); +} +unsafe extern "C" { + pub fn 
whisper_vad_free(ctx: *mut whisper_vad_context); +} +unsafe extern "C" { + pub fn whisper_bench_memcpy(n_threads: ::std::os::raw::c_int) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_bench_memcpy_str( + n_threads: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_bench_ggml_mul_mat(n_threads: ::std::os::raw::c_int) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn whisper_bench_ggml_mul_mat_str( + n_threads: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn whisper_log_set(log_callback: ggml_log_callback, user_data: *mut ::std::os::raw::c_void); +} +unsafe extern "C" { + pub fn whisper_full_get_segment_no_speech_prob( + ctx: *mut whisper_context, + i_segment: ::std::os::raw::c_int, + ) -> f32; +} +unsafe extern "C" { + pub fn whisper_full_get_segment_no_speech_prob_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> f32; +} +pub type __builtin_va_list = [__va_list_tag; 1usize]; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct __va_list_tag { + pub gp_offset: ::std::os::raw::c_uint, + pub fp_offset: ::std::os::raw::c_uint, + pub overflow_arg_area: *mut ::std::os::raw::c_void, + pub reg_save_area: *mut ::std::os::raw::c_void, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of __va_list_tag"][::std::mem::size_of::<__va_list_tag>() - 24usize]; + ["Alignment of __va_list_tag"][::std::mem::align_of::<__va_list_tag>() - 8usize]; + ["Offset of field: __va_list_tag::gp_offset"] + [::std::mem::offset_of!(__va_list_tag, gp_offset) - 0usize]; + ["Offset of field: __va_list_tag::fp_offset"] + [::std::mem::offset_of!(__va_list_tag, fp_offset) - 4usize]; + ["Offset of field: __va_list_tag::overflow_arg_area"] + [::std::mem::offset_of!(__va_list_tag, overflow_arg_area) - 8usize]; + ["Offset of field: __va_list_tag::reg_save_area"] + [::std::mem::offset_of!(__va_list_tag, 
reg_save_area) - 16usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_backend_buffer { + pub _address: u8, +} diff --git a/vendor/whisper-rs-sys/src/lib.rs b/vendor/whisper-rs-sys/src/lib.rs new file mode 100644 index 0000000..a38a13a --- /dev/null +++ b/vendor/whisper-rs-sys/src/lib.rs @@ -0,0 +1,5 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +include!(concat!(env!("OUT_DIR"), "/bindings.rs")); diff --git a/vendor/whisper-rs-sys/whisper.cpp/CMakeLists.txt b/vendor/whisper-rs-sys/whisper.cpp/CMakeLists.txt new file mode 100644 index 0000000..989e94b --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/CMakeLists.txt @@ -0,0 +1,255 @@ +cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. +project("whisper.cpp" C CXX) +project("whisper.cpp" VERSION 1.7.6) +include(CheckIncludeFileCXX) + +set(SOVERSION 1) + +#set(CMAKE_WARN_DEPRECATED YES) +set(CMAKE_WARN_UNUSED_CLI YES) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +# Add path to modules +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) + set(WHISPER_STANDALONE ON) + + include(git-vars) + + # configure project version + configure_file(${CMAKE_SOURCE_DIR}/bindings/javascript/package-tmpl.json ${CMAKE_SOURCE_DIR}/bindings/javascript/package.json @ONLY) +else() + set(WHISPER_STANDALONE OFF) +endif() + +if (EMSCRIPTEN) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + + option(WHISPER_WASM_SINGLE_FILE "whisper: embed WASM inside the generated whisper.js" ON) + + # TODO: without these, we get the following error: + # wasm-ld: error: --shared-memory is 
disallowed by whisper.cpp.o because it was not compiled with 'atomics' or 'bulk-memory' features. + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") + + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -s TOTAL_STACK=5242880") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -s TOTAL_STACK=5242880") + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated") +else() + if (MINGW) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + else() + set(BUILD_SHARED_LIBS_DEFAULT ON) + endif() +endif() + +option(BUILD_SHARED_LIBS "build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT}) + +# +# option list +# + +# debug +option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON) +option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF) + +# build +option(WHISPER_FATAL_WARNINGS "whisper: enable -Werror flag" OFF) +option(WHISPER_USE_SYSTEM_GGML "whisper: use system-installed GGML library" OFF) + +# sanitizers +option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF) +option(WHISPER_SANITIZE_ADDRESS "whisper: enable address sanitizer" OFF) +option(WHISPER_SANITIZE_UNDEFINED "whisper: enable undefined sanitizer" OFF) + +# extra artifacts +option(WHISPER_BUILD_TESTS "whisper: build tests" ${WHISPER_STANDALONE}) +option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDALONE}) +option(WHISPER_BUILD_SERVER "whisper: build server example" ${WHISPER_STANDALONE}) + +# 3rd party libs +option(WHISPER_CURL "whisper: use libcurl to download model from an URL" OFF) +option(WHISPER_SDL2 "whisper: support for libSDL2" OFF) + +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF) +endif() + +option(WHISPER_COREML "whisper: enable Core ML framework" OFF) +option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF) +option(WHISPER_OPENVINO 
"whisper: support for OpenVINO" OFF) + +# Required for relocatable CMake package +include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) + +# override ggml options +set(GGML_SANITIZE_THREAD ${WHISPER_SANITIZE_THREAD}) +set(GGML_SANITIZE_ADDRESS ${WHISPER_SANITIZE_ADDRESS}) +set(GGML_SANITIZE_UNDEFINED ${WHISPER_SANITIZE_UNDEFINED}) +set(GGML_ALL_WARNINGS ${WHISPER_ALL_WARNINGS}) +set(GGML_FATAL_WARNINGS ${WHISPER_FATAL_WARNINGS}) + +# transition helpers +function (whisper_option_depr TYPE OLD NEW) + if (${OLD}) + message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n") + set(${NEW} ON) + endif() +endfunction() + +whisper_option_depr(FATAL_ERROR WHISPER_CUBLAS GGML_CUDA) +whisper_option_depr(WARNING WHISPER_CUDA GGML_CUDA) +whisper_option_depr(WARNING WHISPER_KOMPUTE GGML_KOMPUTE) +whisper_option_depr(WARNING WHISPER_METAL GGML_METAL) +whisper_option_depr(WARNING WHISPER_METAL_EMBED_LIBRARY GGML_METAL_EMBED_LIBRARY) +whisper_option_depr(WARNING WHISPER_NATIVE GGML_NATIVE) +whisper_option_depr(WARNING WHISPER_OPENMP GGML_OPENMP) +whisper_option_depr(WARNING WHISPER_RPC GGML_RPC) +whisper_option_depr(WARNING WHISPER_SYCL GGML_SYCL) +whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16) +whisper_option_depr(WARNING WHISPER_CCACHE GGML_CCACHE) + +if (GGML_CUDA AND NOT MSVC) + #GGML_CUDA enabled, add the necessary compile options -Wno-deprecated-gpu-targets + add_compile_options(-Wno-deprecated-gpu-targets) +endif() + +# +# build the library +# + +if (NOT TARGET ggml) + if (WHISPER_USE_SYSTEM_GGML) + find_package(ggml REQUIRED) + if (NOT ggml_FOUND) + message(FATAL_ERROR "System-installed GGML library not found.") + endif() + add_library(ggml ALIAS ggml::ggml) + else() + add_subdirectory(ggml) + if(WIN32) + # The following adds a _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR macro and is a workaround for + # the Windows C++ standard library which does not support constexpr mutexes. 
+ # From the release notes://github.com/microsoft/STL/wiki/Changelog + # Disable constexpr mutex constructor on Windows + # Fixed mutex's constructor to be constexpr. #3824 #4000 #4339 + # Note: Programs that aren't following the documented restrictions on binary compatibility may encounter + # null dereferences in mutex machinery. You must follow this rule: + # When you mix binaries built by different supported versions of the toolset, the Redistributable version + # must be at least as new as the latest toolset used by any app component. + # You can define _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR as an escape hatch. + # + # Specifically to whisper.cpp this would cause a crash when using the Java bindings. + # resulting in a Invalid memory access error. + target_compile_definitions(ggml-base PRIVATE _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR) + endif() + endif() + # ... otherwise assume ggml is added by a parent CMakeLists.txt +endif() +add_subdirectory(src) + +# +# install +# + +include(GNUInstallDirs) +include(CMakePackageConfigHelpers) + +set(WHISPER_BUILD_NUMBER ${BUILD_NUMBER}) +set(WHISPER_BUILD_COMMIT ${BUILD_COMMIT}) +set(WHISPER_INSTALL_VERSION ${CMAKE_PROJECT_VERSION}) + +set(WHISPER_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(WHISPER_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(WHISPER_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS) + +set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h) +install(TARGETS whisper LIBRARY PUBLIC_HEADER) + +target_compile_definitions(whisper PRIVATE + WHISPER_VERSION="${PROJECT_VERSION}" +) + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/whisper + 
PATH_VARS + WHISPER_INCLUDE_INSTALL_DIR + WHISPER_LIB_INSTALL_DIR + WHISPER_BIN_INSTALL_DIR ) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/whisper-version.cmake + VERSION ${WHISPER_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/whisper-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/whisper) + +configure_file(cmake/whisper.pc.in + "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" + @ONLY) + +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" + DESTINATION lib/pkgconfig) + +# +# programs, examples and tests +# + +if (WHISPER_BUILD_TESTS AND NOT CMAKE_JS_VERSION) + include(CTest) + add_subdirectory(tests) +endif () + +if (WHISPER_BUILD_EXAMPLES) + add_subdirectory(examples) +endif() + +if (MSVC) + set(MSVC_WARNING_FLAGS + /wd4101 # Unreferenced local variable + /wd4005 # Macro redefinition + /wd4065 # switch statement contains 'default' but no 'case' labels + /wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data + /wd4244 # Conversion from one type to another type, possible loss of ata + /wd4805 # Unsafe mix of type + /wd4305 # Truncation from 'type1' to 'type2' (often double to float) + /wd4996 # Function or variable may be unsafe/deprecated + ) + function(disable_msvc_warnings target_name) + if(TARGET ${target_name}) + target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS}) + endif() + endfunction() + + if (WHISPER_BUILD_EXAMPLES) + disable_msvc_warnings(whisper) + disable_msvc_warnings(common) + disable_msvc_warnings(common-sdl) + disable_msvc_warnings(lsp) + disable_msvc_warnings(wchess-core) + disable_msvc_warnings(whisper-command) + disable_msvc_warnings(whisper-cli) + disable_msvc_warnings(whisper-server) + disable_msvc_warnings(whisper-stream) + disable_msvc_warnings(whisper-talk-llama) + disable_msvc_warnings(whisper-bench) + disable_msvc_warnings(quantize) + 
disable_msvc_warnings(vad-speech-segments) + endif() +endif() diff --git a/vendor/whisper-rs-sys/whisper.cpp/LICENSE b/vendor/whisper-rs-sys/whisper.cpp/LICENSE new file mode 100644 index 0000000..acb96ce --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023-2024 The ggml authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/vendor/whisper-rs-sys/whisper.cpp/bindings/CMakeLists.txt b/vendor/whisper-rs-sys/whisper.cpp/bindings/CMakeLists.txt new file mode 100644 index 0000000..af79c51 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/bindings/CMakeLists.txt @@ -0,0 +1,19 @@ +if (EMSCRIPTEN) + add_subdirectory(javascript) + + add_custom_command( + OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/javascript/publish.log + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/whisper.js + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/libwhisper.worker.js + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/javascript/package.json + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/javascript + COMMAND npm publish + COMMAND touch publish.log + COMMENT "Publishing npm module v${PROJECT_VERSION}" + VERBATIM + ) + + add_custom_target(publish-npm + DEPENDS javascript/publish.log + ) +endif() diff --git a/vendor/whisper-rs-sys/whisper.cpp/bindings/javascript/package-tmpl.json b/vendor/whisper-rs-sys/whisper.cpp/bindings/javascript/package-tmpl.json new file mode 100644 index 0000000..d8ba210 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/bindings/javascript/package-tmpl.json @@ -0,0 +1,26 @@ +{ + "name": "whisper.cpp", + "version": "@PROJECT_VERSION@", + "description": "Whisper speech recognition", + "main": "whisper.js", + "scripts": { + "test": "echo \"todo: add tests\" && exit 0" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/ggerganov/whisper.cpp" + }, + "keywords": [ + "openai", + "whisper", + "speech-to-text", + "speech-recognition", + "transformer" + ], + "author": "Georgi Gerganov", + "license": "MIT", + "bugs": { + "url": "https://github.com/ggerganov/whisper.cpp/issues" + }, + "homepage": "https://github.com/ggerganov/whisper.cpp#readme" +} diff --git a/vendor/whisper-rs-sys/whisper.cpp/cmake/DefaultTargetOptions.cmake b/vendor/whisper-rs-sys/whisper.cpp/cmake/DefaultTargetOptions.cmake new file mode 100644 index 0000000..0cfbb34 --- /dev/null +++ 
b/vendor/whisper-rs-sys/whisper.cpp/cmake/DefaultTargetOptions.cmake @@ -0,0 +1,16 @@ +# Set the default compile features and properties for a target. + +if (NOT TARGET) + message(FATAL_ERROR "TARGET not set before including DefaultTargetOptions") +endif() + +target_compile_features(${TARGET} + PRIVATE + cxx_std_11 + ) + +set_target_properties(${TARGET} + PROPERTIES + EXPORT_COMPILE_COMMANDS ON + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" +) diff --git a/vendor/whisper-rs-sys/whisper.cpp/cmake/FindFFmpeg.cmake b/vendor/whisper-rs-sys/whisper.cpp/cmake/FindFFmpeg.cmake new file mode 100644 index 0000000..bb6aff6 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/cmake/FindFFmpeg.cmake @@ -0,0 +1,163 @@ +# From +# https://github.com/snikulov/cmake-modules/blob/master/FindFFmpeg.cmake +# +# vim: ts=2 sw=2 +# - Try to find the required ffmpeg components(default: AVFORMAT, AVUTIL, AVCODEC) +# +# Once done this will define +# FFMPEG_FOUND - System has the all required components. +# FFMPEG_INCLUDE_DIRS - Include directory necessary for using the required components headers. +# FFMPEG_LIBRARIES - Link these to use the required ffmpeg components. +# FFMPEG_DEFINITIONS - Compiler switches required for using the required ffmpeg components. +# +# For each of the components it will additionally set. +# - AVCODEC +# - AVDEVICE +# - AVFORMAT +# - AVFILTER +# - AVUTIL +# - POSTPROC +# - SWSCALE +# the following variables will be defined +# _FOUND - System has +# _INCLUDE_DIRS - Include directory necessary for using the headers +# _LIBRARIES - Link these to use +# _DEFINITIONS - Compiler switches required for using +# _VERSION - The components version +# +# Copyright (c) 2006, Matthias Kretz, +# Copyright (c) 2008, Alexander Neundorf, +# Copyright (c) 2011, Michael Jansen, +# +# Redistribution and use is allowed according to the terms of the BSD license. +# For details see the accompanying COPYING-CMAKE-SCRIPTS file. 
+ +include(FindPackageHandleStandardArgs) + +# The default components were taken from a survey over other FindFFMPEG.cmake files +if (NOT FFmpeg_FIND_COMPONENTS) + set(FFmpeg_FIND_COMPONENTS AVFORMAT AVCODEC AVUTIL SWRESAMPLE) +endif() + +# +### Macro: set_component_found +# +# Marks the given component as found if both *_LIBRARIES AND *_INCLUDE_DIRS is present. +# +macro(set_component_found _component ) + if (${_component}_LIBRARIES AND ${_component}_INCLUDE_DIRS) + message(DEBUG " - ${_component} found.") + set(${_component}_FOUND TRUE) + else () + message(DEBUG " - ${_component} not found.") + endif () +endmacro() + +# +### Macro: find_component +# +# Checks for the given component by invoking pkgconfig and then looking up the libraries and +# include directories. +# +macro(find_component _component _pkgconfig _library _header) + + if (NOT WIN32) + # use pkg-config to get the directories and then use these values + # in the FIND_PATH() and FIND_LIBRARY() calls + find_package(PkgConfig) + if (PKG_CONFIG_FOUND) + pkg_check_modules(PC_${_component} ${_pkgconfig}) + message(STATUS "Pkgconfig found: ${PC_${_component}_INCLUDEDIR}") + message(STATUS "Pkgconfig found: ${PC_${_component}_INCLUDE_DIRS}") + message(STATUS "${PC_${_component}_CFLAGS}") + endif () + endif (NOT WIN32) + + + find_path(${_component}_INCLUDE_DIRS ${_header} + HINTS + ${PC_${_component}_INCLUDEDIR} + ${PC_${_component}_INCLUDE_DIRS} + PATH_SUFFIXES + ffmpeg + ) + + # CMake's default is to search first for shared libraries and then for static libraries. 
+ # Todo later: add option to prefer static libs over dynamic: + find_library(${_component}_LIBRARIES NAMES ${_library} lib${_library}.a + HINTS + ${PC_${_component}_LIBDIR} + ${PC_${_component}_LIBRARY_DIRS} + ) + + set(${_component}_DEFINITIONS ${PC_${_component}_CFLAGS_OTHER} CACHE STRING "The ${_component} CFLAGS.") + set(${_component}_VERSION ${PC_${_component}_VERSION} CACHE STRING "The ${_component} version number.") + + set_component_found(${_component}) + + mark_as_advanced( + ${_component}_INCLUDE_DIRS + ${_component}_LIBRARIES + ${_component}_DEFINITIONS + ${_component}_VERSION) + +endmacro() + + +# Check for cached results. If there are skip the costly part. +if (NOT FFMPEG_LIBRARIES) + + # Check for all possible component. + find_component(AVCODEC libavcodec avcodec libavcodec/avcodec.h) + find_component(AVFORMAT libavformat avformat libavformat/avformat.h) + find_component(AVDEVICE libavdevice avdevice libavdevice/avdevice.h) + #find_component(AVRESAMPLE libavresample avresample libavresample/avresample.h) # old name for swresample + find_component(AVUTIL libavutil avutil libavutil/avutil.h) + find_component(AVFILTER libavfilter avfilter libavfilter/avfilter.h) + find_component(SWSCALE libswscale swscale libswscale/swscale.h) + find_component(POSTPROC libpostproc postproc libpostproc/postprocess.h) + find_component(SWRESAMPLE libswresample swresample libswresample/swresample.h) + + # Check if the required components were found and add their stuff to the FFMPEG_* vars. 
+ foreach (_component ${FFmpeg_FIND_COMPONENTS}) + if (${_component}_FOUND) + # message(STATUS "Required component ${_component} present.") + set(FFMPEG_LIBRARIES ${FFMPEG_LIBRARIES} ${${_component}_LIBRARIES}) + set(FFMPEG_DEFINITIONS ${FFMPEG_DEFINITIONS} ${${_component}_DEFINITIONS}) + list(APPEND FFMPEG_INCLUDE_DIRS ${${_component}_INCLUDE_DIRS}) + else () + # message(STATUS "Required component ${_component} missing.") + endif () + endforeach () + + # Build the include path with duplicates removed. + if (FFMPEG_INCLUDE_DIRS) + list(REMOVE_DUPLICATES FFMPEG_INCLUDE_DIRS) + endif () + + # cache the vars. + set(FFMPEG_INCLUDE_DIRS ${FFMPEG_INCLUDE_DIRS} CACHE STRING "The FFmpeg include directories." FORCE) + set(FFMPEG_LIBRARIES ${FFMPEG_LIBRARIES} CACHE STRING "The FFmpeg libraries." FORCE) + set(FFMPEG_DEFINITIONS ${FFMPEG_DEFINITIONS} CACHE STRING "The FFmpeg cflags." FORCE) + + mark_as_advanced(FFMPEG_INCLUDE_DIRS + FFMPEG_LIBRARIES + FFMPEG_DEFINITIONS) + +endif () + +# Now set the noncached _FOUND vars for the components. +# whisper.cpp does not need SWSCALE +foreach (_component AVCODEC AVDEVICE AVFORMAT AVRESAMPLE AVUTIL POSTPROCESS) + set_component_found(${_component}) +endforeach () + +# Compile the list of required vars +set(_FFmpeg_REQUIRED_VARS FFMPEG_LIBRARIES FFMPEG_INCLUDE_DIRS) +foreach (_component ${FFmpeg_FIND_COMPONENTS}) + list(APPEND _FFmpeg_REQUIRED_VARS ${_component}_LIBRARIES ${_component}_INCLUDE_DIRS) +endforeach () + +# Give a nice error message if some of the required vars are missing. 
+find_package_handle_standard_args(FFmpeg DEFAULT_MSG ${_FFmpeg_REQUIRED_VARS}) + diff --git a/vendor/whisper-rs-sys/whisper.cpp/cmake/build-info.cmake b/vendor/whisper-rs-sys/whisper.cpp/cmake/build-info.cmake new file mode 100644 index 0000000..b293c9b --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/cmake/build-info.cmake @@ -0,0 +1,60 @@ +set(BUILD_NUMBER 0) +set(BUILD_COMMIT "unknown") +set(BUILD_COMPILER "unknown") +set(BUILD_TARGET "unknown") + +# Look for git +find_package(Git) +if(NOT Git_FOUND) + find_program(GIT_EXECUTABLE NAMES git git.exe) + if(GIT_EXECUTABLE) + set(Git_FOUND TRUE) + message(STATUS "Found Git: ${GIT_EXECUTABLE}") + else() + message(WARNING "Git not found. Build info will not be accurate.") + endif() +endif() + +# Get the commit count and hash +if(Git_FOUND) + execute_process( + COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE HEAD + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE RES + ) + if (RES EQUAL 0) + set(BUILD_COMMIT ${HEAD}) + endif() + execute_process( + COMMAND ${GIT_EXECUTABLE} rev-list --count HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE COUNT + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE RES + ) + if (RES EQUAL 0) + set(BUILD_NUMBER ${COUNT}) + endif() +endif() + +if(MSVC) + set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") + set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME}) + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/utf-8>") +else() + execute_process( + COMMAND sh -c "$@ --version | head -1" _ ${CMAKE_C_COMPILER} + OUTPUT_VARIABLE OUT + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + set(BUILD_COMPILER ${OUT}) + execute_process( + COMMAND ${CMAKE_C_COMPILER} -dumpmachine + OUTPUT_VARIABLE OUT + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + set(BUILD_TARGET ${OUT}) +endif() diff --git a/vendor/whisper-rs-sys/whisper.cpp/cmake/git-vars.cmake 
b/vendor/whisper-rs-sys/whisper.cpp/cmake/git-vars.cmake new file mode 100644 index 0000000..1a4c24e --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/cmake/git-vars.cmake @@ -0,0 +1,22 @@ +find_package(Git) + +# the commit's SHA1 +execute_process(COMMAND + "${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8 + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_SHA1 + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the date of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%ad --date=local + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_DATE + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the subject of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%s + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_COMMIT_SUBJECT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) diff --git a/vendor/whisper-rs-sys/whisper.cpp/cmake/whisper-config.cmake.in b/vendor/whisper-rs-sys/whisper.cpp/cmake/whisper-config.cmake.in new file mode 100644 index 0000000..6a3fa22 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/cmake/whisper-config.cmake.in @@ -0,0 +1,65 @@ +set(WHISPER_VERSION @WHISPER_INSTALL_VERSION@) +set(WHISPER_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) +set(WHISPER_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) +set(WHISPER_SHARED_LIB @BUILD_SHARED_LIBS@) + +set(GGML_BLAS @GGML_BLAS@) +set(GGML_CUDA @GGML_CUDA@) +set(GGML_METAL @GGML_METAL@) +set(GGML_HIPBLAS @GGML_HIPBLAS@) +set(GGML_ACCELERATE @GGML_ACCELERATE@) + +@PACKAGE_INIT@ + +set_and_check(WHISPER_INCLUDE_DIR "@PACKAGE_WHISPER_INCLUDE_INSTALL_DIR@") +set_and_check(WHISPER_LIB_DIR "@PACKAGE_WHISPER_LIB_INSTALL_DIR@") +set_and_check(WHISPER_BIN_DIR "@PACKAGE_WHISPER_BIN_INSTALL_DIR@") + +# Ensure transient dependencies satisfied + +find_package(Threads REQUIRED) + +if (APPLE AND GGML_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) +endif() + +if (GGML_BLAS) + find_package(BLAS REQUIRED) 
+endif() + +if (GGML_CUDA) + find_package(CUDAToolkit REQUIRED) +endif() + +if (GGML_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) +endif() + +if (GGML_HIPBLAS) + find_package(hip REQUIRED) + find_package(hipblas REQUIRED) + find_package(rocblas REQUIRED) +endif() + +find_library(whisper_LIBRARY whisper + REQUIRED + HINTS ${WHISPER_LIB_DIR}) + +set(_whisper_link_deps "Threads::Threads" "@WHISPER_EXTRA_LIBS@") +set(_whisper_transient_defines "@WHISPER_TRANSIENT_DEFINES@") + +add_library(whisper UNKNOWN IMPORTED) + +set_target_properties(whisper + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${WHISPER_INCLUDE_DIR}" + INTERFACE_LINK_LIBRARIES "${_whisper_link_deps}" + INTERFACE_COMPILE_DEFINITIONS "${_whisper_transient_defines}" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${whisper_LIBRARY}" + INTERFACE_COMPILE_FEATURES cxx_std_11 + POSITION_INDEPENDENT_CODE ON ) + +check_required_components(whisper) diff --git a/vendor/whisper-rs-sys/whisper.cpp/cmake/whisper.pc.in b/vendor/whisper-rs-sys/whisper.cpp/cmake/whisper.pc.in new file mode 100644 index 0000000..00ec791 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/cmake/whisper.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${exec_prefix}/lib +includedir=${prefix}/include + +Name: whisper +Description: Port of OpenAI's Whisper model in C/C++ +Version: @PROJECT_VERSION@ +Libs: -L${libdir} -lggml -lggml-base -lwhisper +Cflags: -I${includedir} diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/CMakeLists.txt b/vendor/whisper-rs-sys/whisper.cpp/ggml/CMakeLists.txt new file mode 100644 index 0000000..90e274c --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/CMakeLists.txt @@ -0,0 +1,449 @@ +cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. 
+project("ggml" C CXX) +include(CheckIncludeFileCXX) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) + set(GGML_STANDALONE ON) + + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + + # configure project version + # TODO +else() + set(GGML_STANDALONE OFF) +endif() + +if (EMSCRIPTEN) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + + option(GGML_WASM_SINGLE_FILE "ggml: embed WASM inside the generated ggml.js" ON) +else() + if (MINGW) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + else() + set(BUILD_SHARED_LIBS_DEFAULT ON) + endif() +endif() + +# remove the lib prefix on win32 mingw +if (WIN32) + set(CMAKE_STATIC_LIBRARY_PREFIX "") + set(CMAKE_SHARED_LIBRARY_PREFIX "") + set(CMAKE_SHARED_MODULE_PREFIX "") +endif() + +option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT}) +option(GGML_BACKEND_DL "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF) +set(GGML_BACKEND_DIR "" CACHE PATH "ggml: directory to load dynamic backends from (requires GGML_BACKEND_DL") + +# +# option list +# + +# TODO: mark all options as advanced when not GGML_STANDALONE + +if (APPLE) + set(GGML_METAL_DEFAULT ON) + set(GGML_BLAS_DEFAULT ON) + set(GGML_BLAS_VENDOR_DEFAULT "Apple") +else() + set(GGML_METAL_DEFAULT OFF) + set(GGML_BLAS_DEFAULT OFF) + set(GGML_BLAS_VENDOR_DEFAULT "Generic") +endif() + +if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH}) + message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF") + set(GGML_NATIVE_DEFAULT OFF) +else() + set(GGML_NATIVE_DEFAULT ON) +endif() + +# defaults +if (NOT GGML_LLAMAFILE_DEFAULT) + set(GGML_LLAMAFILE_DEFAULT OFF) +endif() + +if (NOT GGML_CUDA_GRAPHS_DEFAULT) + set(GGML_CUDA_GRAPHS_DEFAULT OFF) +endif() + +# general 
+option(GGML_STATIC "ggml: static link libraries" OFF) +option(GGML_NATIVE "ggml: optimize the build for the current system" ${GGML_NATIVE_DEFAULT}) +option(GGML_LTO "ggml: enable link time optimization" OFF) +option(GGML_CCACHE "ggml: use ccache if available" ON) + +# debug +option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON) +option(GGML_ALL_WARNINGS_3RD_PARTY "ggml: enable all compiler warnings in 3rd party libs" OFF) +option(GGML_GPROF "ggml: enable gprof" OFF) + +# build +option(GGML_FATAL_WARNINGS "ggml: enable -Werror flag" OFF) + +# sanitizers +option(GGML_SANITIZE_THREAD "ggml: enable thread sanitizer" OFF) +option(GGML_SANITIZE_ADDRESS "ggml: enable address sanitizer" OFF) +option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF) + +# instruction set specific +if (GGML_NATIVE OR NOT GGML_NATIVE_DEFAULT) + set(INS_ENB OFF) +else() + set(INS_ENB ON) +endif() + +message(DEBUG "GGML_NATIVE : ${GGML_NATIVE}") +message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}") +message(DEBUG "INS_ENB : ${INS_ENB}") + +option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF) +option(GGML_CPU_REPACK "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON) +option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF) +option(GGML_SSE42 "ggml: enable SSE 4.2" ${INS_ENB}) +option(GGML_AVX "ggml: enable AVX" ${INS_ENB}) +option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF) +option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB}) +option(GGML_BMI2 "ggml: enable BMI2" ${INS_ENB}) +option(GGML_AVX512 "ggml: enable AVX512F" OFF) +option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF) +option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF) +option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF) +if (NOT MSVC) + # in MSVC F16C and FMA is implied with AVX2/AVX512 + option(GGML_FMA "ggml: enable FMA" ${INS_ENB}) + option(GGML_F16C "ggml: enable F16C" ${INS_ENB}) + # MSVC does not seem to support AMX + option(GGML_AMX_TILE "ggml: 
enable AMX-TILE" OFF) + option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF) + option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF) +endif() +option(GGML_LASX "ggml: enable lasx" ON) +option(GGML_LSX "ggml: enable lsx" ON) +option(GGML_RVV "ggml: enable rvv" ON) +option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) +option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) +option(GGML_VXE "ggml: enable vxe" ON) +option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877 + +option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) +set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") +set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC") + + +if (MINGW) + set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows version") +endif() + +# ggml core +set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism") +option(GGML_CPU "ggml: enable CPU backend" ON) + +# 3rd party libs / backends +option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON) +option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT}) +set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING + "ggml: BLAS library vendor") +option(GGML_LLAMAFILE "ggml: use LLAMAFILE" ${GGML_LLAMAFILE_DEFAULT}) + +option(GGML_CUDA "ggml: use CUDA" OFF) +option(GGML_MUSA "ggml: use MUSA" OFF) +option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF) +option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF) +option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) +set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING + "ggml: max. 
batch size for using peer access") +option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) +option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF) +option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON) +option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) +option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) +set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING + "ggml: cuda link binary compression mode; requires cuda 12.8+") +set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size") + +option(GGML_HIP "ggml: use HIP" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) +option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) +option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF) +option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) +option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF) +option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF) +option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF) +option(GGML_VULKAN "ggml: use Vulkan" OFF) +option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) +option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF) +option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF) +option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF) +option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF) +option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) +option(GGML_WEBGPU "ggml: use WebGPU" OFF) +option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) +option(GGML_ZDNN "ggml: use zDNN" OFF) +option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) +option(GGML_METAL_USE_BF16 "ggml: 
use bfloat if available" OFF) +option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) +option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF) +option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL}) +set (GGML_METAL_MACOSX_VERSION_MIN "" CACHE STRING + "ggml: metal minimum macOS version") +set (GGML_METAL_STD "" CACHE STRING "ggml: metal standard version (-std flag)") +option(GGML_OPENMP "ggml: use OpenMP" ON) +option(GGML_RPC "ggml: use RPC" OFF) +option(GGML_SYCL "ggml: use SYCL" OFF) +option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) +option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) +option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) +set (GGML_SYCL_TARGET "INTEL" CACHE STRING + "ggml: sycl target device") +set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING + "ggml: sycl device architecture") + +option(GGML_OPENCL "ggml: use OpenCL" OFF) +option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF) +option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) +option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) +set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING + "gmml: OpenCL API version to target") + +# toolchain for vulkan-shaders-gen +set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") + +# extra artifacts +option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) +option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) + +# +# dependencies +# + +set(CMAKE_C_STANDARD 11) +set(CMAKE_C_STANDARD_REQUIRED true) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED true) + +set(THREADS_PREFER_PTHREAD_FLAG ON) + +find_package(Threads REQUIRED) + +include(GNUInstallDirs) + +# +# build the library +# + +add_subdirectory(src) + +# +# tests and examples +# + +if (GGML_BUILD_TESTS) + enable_testing() + add_subdirectory(tests) +endif () + 
+if (GGML_BUILD_EXAMPLES) + add_subdirectory(examples) +endif () + +# +# install +# + +include(CMakePackageConfigHelpers) + +# all public headers +set(GGML_PUBLIC_HEADERS + include/ggml.h + include/ggml-cpu.h + include/ggml-alloc.h + include/ggml-backend.h + include/ggml-blas.h + include/ggml-cann.h + include/ggml-cpp.h + include/ggml-cuda.h + include/ggml-opt.h + include/ggml-metal.h + include/ggml-rpc.h + include/ggml-sycl.h + include/ggml-vulkan.h + include/ggml-webgpu.h + include/gguf.h) + +set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") +#if (GGML_METAL) +# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal") +#endif() +install(TARGETS ggml LIBRARY PUBLIC_HEADER) +install(TARGETS ggml-base LIBRARY) + +if (GGML_STANDALONE) + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in + ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc + @ONLY) + + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc + DESTINATION share/pkgconfig) +endif() + +# +# Create CMake package +# + +# Generate version info based on git commit. + +if(NOT DEFINED GGML_BUILD_NUMBER) + find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH) + execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_NUMBER + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + + if(GGML_BUILD_NUMBER EQUAL 1) + message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.") + endif() + + execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_COMMIT + OUTPUT_STRIP_TRAILING_WHITESPACE + ) +endif() + + +# Capture variables prefixed with GGML_. 
+ +set(variable_set_statements +" +####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run ####### + +") + +set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS}) + +get_cmake_property(all_variables VARIABLES) +foreach(variable_name IN LISTS all_variables) + if(variable_name MATCHES "^GGML_") + string(REPLACE ";" "\\;" + variable_value "${${variable_name}}") + + set(variable_set_statements + "${variable_set_statements}set(${variable_name} \"${variable_value}\")\n") + endif() +endforeach() + +set(GGML_VARIABLES_EXPANDED ${variable_set_statements}) + +# Create the CMake package and set install location. + +set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER}) +set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml + PATH_VARS GGML_INCLUDE_INSTALL_DIR + GGML_LIB_INSTALL_DIR + GGML_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake + VERSION ${GGML_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +target_compile_definitions(ggml-base PRIVATE + GGML_VERSION="${GGML_INSTALL_VERSION}" + GGML_COMMIT="${GGML_BUILD_COMMIT}" +) +message(STATUS "ggml version: ${GGML_INSTALL_VERSION}") +message(STATUS "ggml commit: ${GGML_BUILD_COMMIT}") + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml) + +if (MSVC) + set(MSVC_WARNING_FLAGS + /wd4005 # Macro redefinition + /wd4244 # Conversion from one type to another type, 
possible loss of data + /wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data + /wd4305 # Conversion from 'type1' to 'type2', possible loss of data + /wd4566 # Conversion from 'char' to 'wchar_t', possible loss of data + /wd4996 # Disable POSIX deprecation warnings + /wd4702 # Unreachable code warnings + ) + function(disable_msvc_warnings target_name) + if(TARGET ${target_name}) + target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS}) + endif() + endfunction() + + disable_msvc_warnings(ggml-base) + disable_msvc_warnings(ggml) + disable_msvc_warnings(ggml-cpu) + disable_msvc_warnings(ggml-cpu-x64) + disable_msvc_warnings(ggml-cpu-sse42) + disable_msvc_warnings(ggml-cpu-sandybridge) + disable_msvc_warnings(ggml-cpu-haswell) + disable_msvc_warnings(ggml-cpu-skylakex) + disable_msvc_warnings(ggml-cpu-icelake) + disable_msvc_warnings(ggml-cpu-alderlake) + + if (GGML_BUILD_EXAMPLES) + disable_msvc_warnings(common-ggml) + disable_msvc_warnings(common) + + disable_msvc_warnings(mnist-common) + disable_msvc_warnings(mnist-eval) + disable_msvc_warnings(mnist-train) + + disable_msvc_warnings(gpt-2-ctx) + disable_msvc_warnings(gpt-2-alloc) + disable_msvc_warnings(gpt-2-backend) + disable_msvc_warnings(gpt-2-sched) + disable_msvc_warnings(gpt-2-quantize) + disable_msvc_warnings(gpt-2-batched) + + disable_msvc_warnings(gpt-j) + disable_msvc_warnings(gpt-j-quantize) + + disable_msvc_warnings(magika) + disable_msvc_warnings(yolov3-tiny) + disable_msvc_warnings(sam) + + disable_msvc_warnings(simple-ctx) + disable_msvc_warnings(simple-backend) + endif() + + if (GGML_BUILD_TESTS) + disable_msvc_warnings(test-mul-mat) + disable_msvc_warnings(test-arange) + disable_msvc_warnings(test-backend-ops) + disable_msvc_warnings(test-cont) + disable_msvc_warnings(test-conv-transpose) + disable_msvc_warnings(test-conv-transpose-1d) + disable_msvc_warnings(test-conv1d) + disable_msvc_warnings(test-conv2d) + disable_msvc_warnings(test-conv2d-dw) + 
disable_msvc_warnings(test-customop) + disable_msvc_warnings(test-dup) + disable_msvc_warnings(test-opt) + disable_msvc_warnings(test-pool) + endif () +endif() diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/BuildTypes.cmake b/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/BuildTypes.cmake new file mode 100644 index 0000000..a9c7b6c --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/BuildTypes.cmake @@ -0,0 +1,54 @@ +# Add new build types + +# ReleaseGG - Release with enabled asserts + +SET(CMAKE_CXX_FLAGS_RELEASEGG + "-O3" + CACHE STRING "Flags used by the c++ compiler during release builds with enabled asserts." + FORCE ) +SET(CMAKE_C_FLAGS_RELEASEGG + "-O3" + CACHE STRING "Flags used by the compiler during release builds with enabled asserts." + FORCE ) +SET(CMAKE_EXE_LINKER_FLAGS_RELEASEGG + "" + CACHE STRING "Flags used for linking binaries during release builds with enabled asserts." + FORCE ) +SET(CMAKE_SHARED_LINKER_FLAGS_RELEASEGG + "" + CACHE STRING "Flags used by the shared libraries linker during release builds with enabled asserts." + FORCE ) +MARK_AS_ADVANCED( + CMAKE_CXX_FLAGS_RELEASEGG + CMAKE_C_FLAGS_RELEASEGG + CMAKE_EXE_LINKER_FLAGS_RELEASEGG + CMAKE_SHARED_LINKER_FLAGS_RELEASEGG ) + +# RelWithDebInfoGG - RelWithDebInfo with enabled asserts + +SET(CMAKE_CXX_FLAGS_RELWITHDEBINFOGG + "-O2 -g" + CACHE STRING "Flags used by the c++ compiler during release builds with debug symbols and enabled asserts." + FORCE ) +SET(CMAKE_C_FLAGS_RELWITHDEBINFOGG + "-O2 -g" + CACHE STRING "Flags used by the compiler during release builds with debug symbols and enabled asserts." + FORCE ) +SET(CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG + "" + CACHE STRING "Flags used for linking binaries during release builds with debug symbols and enabled asserts." + FORCE ) +SET(CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG + "" + CACHE STRING "Flags used by the shared libraries linker during release builds with debug symbols and enabled asserts." 
+ FORCE ) +MARK_AS_ADVANCED( + CMAKE_CXX_FLAGS_RELWITHDEBINFOGG + CMAKE_C_FLAGS_RELWITHDEBINFOGG + CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG + CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG ) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo" "ReleaseGG" "RelWithDebInfoGG") +endif() diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/GitVars.cmake b/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/GitVars.cmake new file mode 100644 index 0000000..1a4c24e --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/GitVars.cmake @@ -0,0 +1,22 @@ +find_package(Git) + +# the commit's SHA1 +execute_process(COMMAND + "${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8 + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_SHA1 + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the date of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%ad --date=local + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_DATE + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the subject of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%s + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_COMMIT_SUBJECT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/common.cmake b/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/common.cmake new file mode 100644 index 0000000..cb66388 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/common.cmake @@ -0,0 +1,50 @@ +function(ggml_get_flags CCID CCVER) + set(C_FLAGS "") + set(CXX_FLAGS "") + + if (CCID MATCHES "Clang") + set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return) + set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi) + + if ( + (CCID STREQUAL 
"Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR + (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0) + ) + list(APPEND C_FLAGS -Wdouble-promotion) + endif() + elseif (CCID STREQUAL "GNU") + set(C_FLAGS -Wdouble-promotion) + set(CXX_FLAGS -Wno-array-bounds) + + if (CCVER VERSION_GREATER_EQUAL 8.1.0) + list(APPEND CXX_FLAGS -Wextra-semi) + endif() + endif() + + set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE) + set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE) +endfunction() + +function(ggml_get_system_arch) + if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR + CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR + (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$")) + set(GGML_SYSTEM_ARCH "ARM" PARENT_SCOPE) + elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR + CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR + (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$")) + set(GGML_SYSTEM_ARCH "x86" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc|power") + set(GGML_SYSTEM_ARCH "PowerPC" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") + set(GGML_SYSTEM_ARCH "loongarch64" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64") + set(GGML_SYSTEM_ARCH "riscv64" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + set(GGML_SYSTEM_ARCH "s390x" PARENT_SCOPE) + else() + set(GGML_SYSTEM_ARCH "UNKNOWN" PARENT_SCOPE) + endif() +endfunction() diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/ggml-config.cmake.in b/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/ggml-config.cmake.in new file mode 100644 index 0000000..91c9d5c --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/cmake/ggml-config.cmake.in @@ -0,0 +1,191 @@ +@PACKAGE_INIT@ + +@GGML_VARIABLES_EXPANDED@ + +# Find all dependencies before creating any target. 
+include(CMakeFindDependencyMacro) +find_dependency(Threads) +if (NOT GGML_SHARED_LIB) + set(GGML_CPU_INTERFACE_LINK_LIBRARIES "") + set(GGML_CPU_INTERFACE_LINK_OPTIONS "") + + if (APPLE AND GGML_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate) + if(NOT ACCELERATE_FRAMEWORK) + set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0) + return() + endif() + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK}) + endif() + + if (GGML_OPENMP_ENABLED) + find_dependency(OpenMP) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + endif() + + if (GGML_CPU_HBM) + find_library(memkind memkind) + if(NOT memkind) + set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0) + return() + endif() + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind) + endif() + + if (GGML_BLAS) + find_dependency(BLAS) + list(APPEND GGML_BLAS_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES}) + list(APPEND GGML_BLAS_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS}) + endif() + + if (GGML_CUDA) + set(GGML_CUDA_INTERFACE_LINK_LIBRARIES "") + find_dependency(CUDAToolkit) + if (GGML_STATIC) + list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $) + if (WIN32) + list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $ $) + else() + list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $ $) + endif() + endif() + if (NOT GGML_CUDA_NO_VMM) + list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $) + endif() + endif() + + if (GGML_METAL) + find_library(FOUNDATION_LIBRARY Foundation) + find_library(METAL_FRAMEWORK Metal) + find_library(METALKIT_FRAMEWORK MetalKit) + if(NOT FOUNDATION_LIBRARY OR NOT METAL_FRAMEWORK OR NOT METALKIT_FRAMEWORK) + set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0) + return() + endif() + set(GGML_METAL_INTERFACE_LINK_LIBRARIES + ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK}) + endif() + + if (GGML_OPENCL) + find_dependency(OpenCL) + set(GGML_OPENCL_INTERFACE_LINK_LIBRARIES $) + endif() + + if (GGML_VULKAN) + find_dependency(Vulkan) + set(GGML_VULKAN_INTERFACE_LINK_LIBRARIES $) + 
endif() + + if (GGML_HIP) + find_dependency(hip) + find_dependency(hipblas) + find_dependency(rocblas) + set(GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas) + endif() + + if (GGML_SYCL) + set(GGML_SYCL_INTERFACE_LINK_LIBRARIES "") + find_package(DNNL) + if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") + list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl) + endif() + if (WIN32) + find_dependency(IntelSYCL) + find_dependency(MKL) + list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL) + endif() + endif() +endif() + +set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@") +set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@") +#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") + +if(NOT TARGET ggml::ggml) + find_package(Threads REQUIRED) + + find_library(GGML_LIBRARY ggml + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + + add_library(ggml::ggml UNKNOWN IMPORTED) + set_target_properties(ggml::ggml + PROPERTIES + IMPORTED_LOCATION "${GGML_LIBRARY}") + + find_library(GGML_BASE_LIBRARY ggml-base + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + + add_library(ggml::ggml-base UNKNOWN IMPORTED) + set_target_properties(ggml::ggml-base + PROPERTIES + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") + + set(_ggml_all_targets "") + if (NOT GGML_BACKEND_DL) + foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) + string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}") + string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx) + + find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend} + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + + message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}") + + add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION 
"${${_ggml_backend_pfx}_LIBRARY}" + INTERFACE_COMPILE_FEATURES c_std_90 + POSITION_INDEPENDENT_CODE ON) + + string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}") + if(is_cpu_variant) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}") + + if(GGML_CPU_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}") + endif() + + else() + list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}") + + if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}") + endif() + endif() + + list(APPEND _ggml_all_targets ggml::${_ggml_backend}) + endforeach() + endif() + + list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}") + set_target_properties(ggml::ggml + PROPERTIES + INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}") + + add_library(ggml::all INTERFACE IMPORTED) + set_target_properties(ggml::all + PROPERTIES + INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}") + +endif() + +check_required_components(ggml) diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-alloc.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-alloc.h new file mode 100644 index 0000000..2cb150f --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-alloc.h @@ -0,0 +1,76 @@ +#pragma once + +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t; +typedef struct ggml_backend_buffer * ggml_backend_buffer_t; +typedef struct ggml_backend * 
ggml_backend_t; + +// Tensor allocator +struct ggml_tallocr { + ggml_backend_buffer_t buffer; + void * base; + size_t alignment; + size_t offset; +}; + +GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer); +GGML_API enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor); + +// Graph allocator +/* + Example usage: + ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type()); + + // optional: create a worst-case graph and reserve the buffers to avoid reallocations + ggml_gallocr_reserve(galloc, build_graph(max_batch)); + + // allocate the graph + struct ggml_cgraph * graph = build_graph(batch); + ggml_gallocr_alloc_graph(galloc, graph); + + printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0)); + + // evaluate the graph + ggml_backend_graph_compute(backend, graph); +*/ + +// special tensor flags for use with the graph allocator: +// ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses +// ggml_set_output(): output tensors are never freed and never overwritten + +typedef struct ggml_gallocr * ggml_gallocr_t; + +GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft); +GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs); +GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc); + +// pre-allocate buffers from a measure graph - does not allocate or modify the graph +// call with a worst-case graph to avoid buffer reallocations +// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed +// returns false if the buffer allocation failed +GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph); +GGML_API bool ggml_gallocr_reserve_n( + ggml_gallocr_t galloc, + struct ggml_cgraph * graph, + const int * node_buffer_ids, + const int * leaf_buffer_ids); + +// 
automatic reallocation if the topology changes when using a single buffer +// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers) +GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph); + +GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); + +// Utils +// Create a buffer and allocate all the tensors in a ggml_context +GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); +GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-backend.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-backend.h new file mode 100644 index 0000000..a2977ea --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-backend.h @@ -0,0 +1,354 @@ +#pragma once + +#include "ggml.h" +#include "ggml-alloc.h" + +#ifdef GGML_BACKEND_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BACKEND_BUILD +# define GGML_BACKEND_API __declspec(dllexport) extern +# else +# define GGML_BACKEND_API __declspec(dllimport) extern +# endif +# else +# define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern +# endif +#else +# define GGML_BACKEND_API extern +#endif + +#ifdef __cplusplus +extern "C" { +#endif + + typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t; + typedef struct ggml_backend_buffer * ggml_backend_buffer_t; + typedef struct ggml_backend_event * ggml_backend_event_t; + typedef struct ggml_backend * ggml_backend_t; + typedef void * ggml_backend_graph_plan_t; + typedef struct ggml_backend_reg * ggml_backend_reg_t; + typedef struct ggml_backend_device * ggml_backend_dev_t; + + + // + // Backend buffer type + // + + GGML_API const char * 
ggml_backend_buft_name (ggml_backend_buffer_type_t buft); + GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); + GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); + GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); + GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); + GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft); + GGML_API ggml_backend_dev_t ggml_backend_buft_get_device (ggml_backend_buffer_type_t buft); + + // + // Backend buffer + // + + enum ggml_backend_buffer_usage { + GGML_BACKEND_BUFFER_USAGE_ANY = 0, + GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1, + GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2, + }; + + GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); + GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); + GGML_API enum ggml_status ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor); + GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value); + GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage (ggml_backend_buffer_t buffer); + GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type 
(ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); + + // tensor copy between different backends + GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + + // + // Backend (stream) + // + + GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend); + GGML_API const char * ggml_backend_name(ggml_backend_t backend); + GGML_API void ggml_backend_free(ggml_backend_t backend); + + GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend); + GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); + GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); + GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend); + + GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + + // "offset" refers to the offset in tensor->data for setting/getting data + GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + + GGML_API void ggml_backend_synchronize(ggml_backend_t backend); + + GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + + GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + GGML_API enum ggml_status 
ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph); + + // NOTE: will be removed, use device version instead + GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op); + GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft); + GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op); + + // asynchronous copy + // the copy is performed after all the currently queued operations in backend_src + // backend_dst will wait for the copy to complete before performing other operations + // automatic fallback to sync copy if async is not supported + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst); + + GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend); + + // + // Events + // + + GGML_API ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device); + GGML_API void ggml_backend_event_free(ggml_backend_event_t event); + GGML_API void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend); + GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event); + GGML_API void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event); + + // + // Backend device + // + + enum ggml_backend_dev_type { + // CPU device using system memory + GGML_BACKEND_DEVICE_TYPE_CPU, + // GPU device using dedicated memory + GGML_BACKEND_DEVICE_TYPE_GPU, + // accelerator devices intended to be used together with the CPU backend (e.g. 
BLAS or AMX) + GGML_BACKEND_DEVICE_TYPE_ACCEL + }; + + // functionality supported by the device + struct ggml_backend_dev_caps { + // asynchronous operations + bool async; + // pinned host buffer + bool host_buffer; + // creating buffers from host ptr + bool buffer_from_host_ptr; + // event synchronization + bool events; + }; + + // all the device properties + struct ggml_backend_dev_props { + const char * name; + const char * description; + size_t memory_free; + size_t memory_total; + enum ggml_backend_dev_type type; + struct ggml_backend_dev_caps caps; + }; + + GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); + GGML_API const char * ggml_backend_dev_description(ggml_backend_dev_t device); + GGML_API void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total); + GGML_API enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device); + GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props); + GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device); + GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params); + GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device); + GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device); + GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); + + GGML_API bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op); + GGML_API bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft); + GGML_API bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op); + + // + // Backend (reg) + // + + GGML_API const char * ggml_backend_reg_name(ggml_backend_reg_t reg); + GGML_API size_t 
ggml_backend_reg_dev_count(ggml_backend_reg_t reg); + GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index); + GGML_API void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name); + + // Common functions that may be obtained using ggml_backend_reg_get_proc_address + + // Split buffer type for tensor parallelism + typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); + // Set the number of threads for the backend + typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads); + // Get additional buffer types provided by the device (returns a NULL-terminated array) + typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device); + // Set the abort callback for the backend + typedef void (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data); + // Get a list of feature flags supported by the backend (returns a NULL-terminated array) + struct ggml_backend_feature { + const char * name; + const char * value; + }; + typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg); + + // + // Backend registry + // + + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); + + // Backend (reg) enumeration + GGML_API size_t ggml_backend_reg_count(void); + GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index); + GGML_API ggml_backend_reg_t ggml_backend_reg_by_name(const char * name); + + // Device enumeration + GGML_API size_t ggml_backend_dev_count(void); + GGML_API ggml_backend_dev_t ggml_backend_dev_get(size_t index); + GGML_API ggml_backend_dev_t ggml_backend_dev_by_name(const char * name); + GGML_API ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type); + + // Direct backend (stream) initialization + // = 
ggml_backend_dev_init(ggml_backend_dev_by_name(name), params) + GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params); + // = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params) + GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params); + // = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL) + GGML_API ggml_backend_t ggml_backend_init_best(void); + + // Load a backend from a dynamic library and register it + GGML_API ggml_backend_reg_t ggml_backend_load(const char * path); + // Unload a backend if loaded dynamically and unregister it + GGML_API void ggml_backend_unload(ggml_backend_reg_t reg); + // Load all known backends from dynamic libraries + GGML_API void ggml_backend_load_all(void); + GGML_API void ggml_backend_load_all_from_path(const char * dir_path); + + // + // Backend scheduler + // + + // The backend scheduler allows for multiple backend devices to be used together + // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends + // The backends are selected based on: + // - the backend that supports the operation + // - the location of the pre-allocated tensors (e.g. 
the weights) + /* + Example usage: + + // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned + // preferrably to run on the same backend as the buffer + ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true); + + // initialize buffers from a max size graph (optional) + reserve_graph = build_graph(sched, max_batch_size); + + // manually assign nodes to a backend (optional, should not be needed in most cases) + struct ggml_tensor * node = ggml_mul_mat(ctx, ...); + ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu); + + ggml_backend_sched_reserve(sched, reserve_graph); + + // compute + graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation + for (int i = 0; i < 10; ++i) { + ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically + } + + // if there are graph inputs: + graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called) + ggml_backend_sched_reset(sched); // clear the allocation of the previous graph + ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it + ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors + ggml_backend_sched_graph_compute(sched, graph); // execute the graph + + // as an alternative to the above it is also possible to assign the inputs to a dedicated context and + // allocate them statically via ggml_backend_alloc_ctx_tensors + } + */ + + typedef struct ggml_backend_sched * ggml_backend_sched_t; + + // Evaluation callback for each node in the graph (set with ggml_backend_sched_set_eval_callback) + // when ask == true, the scheduler wants to know 
if the user wants to observe this node + // this allows the scheduler to batch nodes together in order to evaluate them in a single call + // + // when ask == false, the scheduler is passing the node tensor to the user for observation + // if the user returns false, the scheduler will cancel the graph compute + // + typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); + + // Initialize a backend scheduler, backends with low index are given priority over backends with high index + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); + GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); + + // Initialize backend buffers from a measure graph + GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success + + GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched); + GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i); + + // Get the number of splits of the last graph + GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched); + GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched); + + GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + + GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); + GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); + + // Allocate and compute graph on the backend scheduler + GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success + GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * 
graph); + GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched); + + // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph. + // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers. + // The correct way to use this API is to discard the deallocated tensors and create new ones. + GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); + + // Set a callback to be called for each resulting node during graph compute + GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); + + // + // Utils + // + + struct ggml_backend_graph_copy { + ggml_backend_buffer_t buffer; + struct ggml_context * ctx_allocated; + struct ggml_context * ctx_unallocated; + struct ggml_cgraph * graph; + }; + + // Copy a graph to a different backend + GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph); + GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy); + + typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); + + // Compare the output of two backends + GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); + + // Tensor initialization + GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); + GGML_API enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor); + + // CPU buffer types are always available + GGML_API ggml_backend_buffer_t 
ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); + GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-blas.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-blas.h new file mode 100644 index 0000000..87a81b3 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-blas.h @@ -0,0 +1,25 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + + +#ifdef __cplusplus +extern "C" { +#endif + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void); + +GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend); + +// number of threads used for conversion to float +// for openblas and blis, this will also set the number of threads used for blas operations +GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void); + + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cann.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cann.h new file mode 100644 index 0000000..b469e22 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cann.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#pragma once + +#include "ggml-backend.h" +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Maximum number of CANN devices supported. + */ +#define GGML_CANN_MAX_DEVICES 16 + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void); + +/** + * @brief Initializes the CANN backend for a specified device. + * + * This function initializes the CANN backend for the given device. + * It verifies the device index, allocates a context, and creates a backend + * instance. + * + * @param device The index of the device to initialize. + * @return A pointer to the initialized backend instance, or nullptr on failure. + */ +GGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device); + +/** + * @brief Checks if a given backend is a CANN backend. + * + * This function verifies if the provided backend is a CANN backend by comparing + * its GUID with the CANN backend's GUID. + * + * @param backend The backend instance to check. + * @return True if the backend is a CANN backend, false otherwise. + */ +GGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend); + +/** + * @brief Retrieves the CANN buffer type for a specified device. + * + * This function initializes and returns the buffer type interface associated + * with the given device. It ensures thread-safe access using a mutex. + * + * @param device The device index for which to retrieve the buffer type. 
+ * @return A pointer to the buffer type interface for the specified device, or + * nullptr if the device index is out of range. + */ +GGML_BACKEND_API ggml_backend_buffer_type_t +ggml_backend_cann_buffer_type(int32_t device); + +/** + * @brief Retrieves the number of CANN devices available. + * + * This function returns the number of CANN devices available based on + * information obtained from `ggml_cann_info()`. + * + * @return The number of CANN devices available. + */ +GGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void); + +/** + * @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU. + * + * @return A pointer to the host buffer type interface. + */ +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void); + +/** + * @brief Retrieves the description of a specific CANN device. + * + * This function sets the specified device, retrieves the SoC name, + * and writes it into the provided description buffer. + * + * @param device The device index to retrieve the description for. + * @param description Pointer to a buffer where the description will be written. + * @param description_size Size of the description buffer. + */ +GGML_BACKEND_API void ggml_backend_cann_get_device_description( + int32_t device, char* description, size_t description_size); + +/** + * @brief Retrieves the memory information of a specific CANN device. + * + * This function sets the specified device, retrieves the free and total + * memory information of the specified type (ACL_HBM_MEM), and stores them + * in the provided pointers. + * + * @param device The device index to retrieve memory information for. + * @param free Pointer to a variable where the free memory size will be stored. + * @param total Pointer to a variable where the total memory size will be + * stored. 
+ */ +GGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device, + size_t* free, + size_t* total); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cpp.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cpp.h new file mode 100644 index 0000000..48aa796 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cpp.h @@ -0,0 +1,39 @@ +#pragma once + +#ifndef __cplusplus +#error "This header is for C++ only" +#endif + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "gguf.h" +#include + +// Smart pointers for ggml types + +// ggml + +struct ggml_context_deleter { void operator()(ggml_context * ctx) { ggml_free(ctx); } }; +struct gguf_context_deleter { void operator()(gguf_context * ctx) { gguf_free(ctx); } }; + +typedef std::unique_ptr ggml_context_ptr; +typedef std::unique_ptr gguf_context_ptr; + +// ggml-alloc + +struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } }; + +typedef std::unique_ptr ggml_gallocr_ptr; + +// ggml-backend + +struct ggml_backend_deleter { void operator()(ggml_backend_t backend) { ggml_backend_free(backend); } }; +struct ggml_backend_buffer_deleter { void operator()(ggml_backend_buffer_t buffer) { ggml_backend_buffer_free(buffer); } }; +struct ggml_backend_event_deleter { void operator()(ggml_backend_event_t event) { ggml_backend_event_free(event); } }; +struct ggml_backend_sched_deleter { void operator()(ggml_backend_sched_t sched) { ggml_backend_sched_free(sched); } }; + +typedef std::unique_ptr ggml_backend_ptr; +typedef std::unique_ptr ggml_backend_buffer_ptr; +typedef std::unique_ptr ggml_backend_event_ptr; +typedef std::unique_ptr ggml_backend_sched_ptr; diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cpu.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cpu.h new file mode 100644 index 0000000..be40b10 --- /dev/null +++ 
b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cpu.h @@ -0,0 +1,145 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + + // the compute plan that needs to be prepared for ggml_graph_compute() + // since https://github.com/ggml-org/ggml/issues/287 + struct ggml_cplan { + size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` + uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` + + int n_threads; + struct ggml_threadpool * threadpool; + + // abort ggml_graph_compute when true + ggml_abort_callback abort_callback; + void * abort_callback_data; + }; + + // numa strategies + enum ggml_numa_strategy { + GGML_NUMA_STRATEGY_DISABLED = 0, + GGML_NUMA_STRATEGY_DISTRIBUTE = 1, + GGML_NUMA_STRATEGY_ISOLATE = 2, + GGML_NUMA_STRATEGY_NUMACTL = 3, + GGML_NUMA_STRATEGY_MIRROR = 4, + GGML_NUMA_STRATEGY_COUNT + }; + + GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems + GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node + + GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); + GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); + + GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); + GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); + + GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); + GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + + GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); + + 
GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); + GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); + + GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); + + GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params); + GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool); + GGML_BACKEND_API int ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool); + GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool); + GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool); + + // ggml_graph_plan() has to be called before ggml_graph_compute() + // when plan.work_size > 0, caller must allocate memory for plan.work_data + GGML_BACKEND_API struct ggml_cplan ggml_graph_plan( + const struct ggml_cgraph * cgraph, + int n_threads, /* = GGML_DEFAULT_N_THREADS */ + struct ggml_threadpool * threadpool /* = NULL */ ); + GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + + // same as ggml_graph_compute() but the work data is allocated as a part of the context + // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data + GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); + + // + // system info + // + + // x86 + GGML_BACKEND_API int ggml_cpu_has_sse3 (void); + GGML_BACKEND_API int ggml_cpu_has_ssse3 (void); + GGML_BACKEND_API int ggml_cpu_has_avx (void); + GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void); + GGML_BACKEND_API int ggml_cpu_has_avx2 (void); + GGML_BACKEND_API int ggml_cpu_has_bmi2 
(void); + GGML_BACKEND_API int ggml_cpu_has_f16c (void); + GGML_BACKEND_API int ggml_cpu_has_fma (void); + GGML_BACKEND_API int ggml_cpu_has_avx512 (void); + GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void); + GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void); + GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void); + GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void); + // ARM + GGML_BACKEND_API int ggml_cpu_has_neon (void); + GGML_BACKEND_API int ggml_cpu_has_arm_fma (void); + GGML_BACKEND_API int ggml_cpu_has_fp16_va (void); + GGML_BACKEND_API int ggml_cpu_has_dotprod (void); + GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void); + GGML_BACKEND_API int ggml_cpu_has_sve (void); + GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes + GGML_BACKEND_API int ggml_cpu_has_sme (void); + // other + GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); + GGML_BACKEND_API int ggml_cpu_has_vsx (void); + GGML_BACKEND_API int ggml_cpu_has_vxe (void); + GGML_BACKEND_API int ggml_cpu_has_nnpa (void); + GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); + GGML_BACKEND_API int ggml_cpu_has_llamafile (void); + + // Internal types and functions exposed for tests and benchmarks + + typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, + const void * GGML_RESTRICT y, size_t by, int nrc); + + struct ggml_type_traits_cpu { + ggml_from_float_t from_float; + ggml_vec_dot_t vec_dot; + enum ggml_type vec_dot_type; + int64_t nrows; // number of rows to process simultaneously + }; + + GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type); + + GGML_BACKEND_API void ggml_cpu_init(void); + + // + // CPU backend + // + + GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void); + + GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend); + GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); + 
GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); + GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); + + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); + + GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); + GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t); + GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t); + GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t); + GGML_BACKEND_API void ggml_cpu_bf16_to_fp32(const ggml_bf16_t *, float *, int64_t); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cuda.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cuda.h new file mode 100644 index 0000000..22ad2c0 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-cuda.h @@ -0,0 +1,47 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef GGML_USE_HIP +#define GGML_CUDA_NAME "ROCm" +#define GGML_CUBLAS_NAME "hipBLAS" +#elif defined(GGML_USE_MUSA) +#define GGML_CUDA_NAME "MUSA" +#define GGML_CUBLAS_NAME "muBLAS" +#else +#define GGML_CUDA_NAME "CUDA" +#define GGML_CUBLAS_NAME "cuBLAS" +#endif +#define GGML_CUDA_MAX_DEVICES 16 + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device); + +GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend); + +// device buffer +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); + +// split tensor buffer that splits matrices by rows across multiple devices +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split); + +// pinned host buffer for use with the CPU backend 
for faster copies between CPU and GPU +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void); + +GGML_BACKEND_API int ggml_backend_cuda_get_device_count(void); +GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size); +GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total); + +GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size); +GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-metal.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-metal.h new file mode 100644 index 0000000..a610694 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-metal.h @@ -0,0 +1,66 @@ +// Note: this description is outdated +// +// An interface allowing to compute ggml_cgraph with Metal +// +// This is a fully functional interface that extends ggml with GPU support for Apple devices. +// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.) +// +// How it works? +// +// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this +// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you +// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.) +// +// You only need to make sure that all memory buffers that you used during the graph creation +// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is +// used during the graph evaluation to determine the arguments of the compute kernels. 
+// +// Synchronization between device and host memory (for example for input and output tensors) +// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions. +// + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include +#include + +struct ggml_tensor; +struct ggml_cgraph; + +#ifdef __cplusplus +extern "C" { +#endif + +// +// backend API +// user-code should use only these functions +// + +GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void); + +GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); + +GGML_DEPRECATED( + GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), + "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713"); + +GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); + +// helper to check if the device supports a specific family +// ideally, the user code should be doing these checks +// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf +GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family); + +// capture all command buffers committed the next time `ggml_backend_graph_compute` is called +GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-opencl.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-opencl.h new file mode 100644 index 0000000..6b61771 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-opencl.h @@ -0,0 +1,26 @@ +#ifndef GGML_OPENCL_H +#define GGML_OPENCL_H + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus 
+extern "C" { +#endif + +// +// backend API +// +GGML_BACKEND_API ggml_backend_t ggml_backend_opencl_init(void); +GGML_BACKEND_API bool ggml_backend_is_opencl(ggml_backend_t backend); + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_opencl_reg(void); + +#ifdef __cplusplus +} +#endif + +#endif // GGML_OPENCL_H diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-opt.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-opt.h new file mode 100644 index 0000000..4703a05 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-opt.h @@ -0,0 +1,256 @@ +// This file contains functionality for training models using GGML. +// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets. +// At the bottom of this file especially there are relatively high-level functions that are suitable use or adaptation in user code. +// +// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include + +#ifdef __cplusplus +extern "C" { +#endif + + struct ggml_opt_dataset; + struct ggml_opt_context; + struct ggml_opt_result; + + typedef struct ggml_opt_dataset * ggml_opt_dataset_t; + typedef struct ggml_opt_context * ggml_opt_context_t; + typedef struct ggml_opt_result * ggml_opt_result_t; + + // ====== Loss ====== + + // built-in loss types, i.e. 
the built-in quantities minimized by the optimizer + // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value + enum ggml_opt_loss_type { + GGML_OPT_LOSS_TYPE_MEAN, + GGML_OPT_LOSS_TYPE_SUM, + GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, + GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, + }; + + // ====== Dataset ====== + + GGML_API ggml_opt_dataset_t ggml_opt_dataset_init( + enum ggml_type type_data, // the type for the internal data tensor + enum ggml_type type_label, // the type for the internal labels tensor + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) + GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); + + // get underlying tensors that store the data + GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset); + GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] + GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] + + // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative + GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata); + + // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch + GGML_API void ggml_opt_dataset_get_batch( + ggml_opt_dataset_t dataset, + struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] + struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] + int64_t ibatch); + GGML_API void ggml_opt_dataset_get_batch_host( + ggml_opt_dataset_t dataset, + void * data_batch, + size_t nb_data_batch, + void * labels_batch, + int64_t ibatch); + + 
// ====== Model / Context ====== + + enum ggml_opt_build_type { + GGML_OPT_BUILD_TYPE_FORWARD = 10, + GGML_OPT_BUILD_TYPE_GRAD = 20, + GGML_OPT_BUILD_TYPE_OPT = 30, + }; + + enum ggml_opt_optimizer_type { + GGML_OPT_OPTIMIZER_TYPE_ADAMW, + GGML_OPT_OPTIMIZER_TYPE_SGD, + + GGML_OPT_OPTIMIZER_TYPE_COUNT + }; + + // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss + struct ggml_opt_optimizer_params { + struct { + float alpha; // learning rate + float beta1; // first AdamW momentum + float beta2; // second AdamW momentum + float eps; // epsilon for numerical stability + float wd; // weight decay - 0.0f to disable + } adamw; + struct { + float alpha; // learning rate + float wd; // weight decay + } sgd; + }; + + // callback to calculate optimizer parameters prior to a backward pass + // userdata can be used to pass arbitrary data + typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); + + // returns the default optimizer params (constant, hard-coded values) + // userdata is not used + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata); + + // casts userdata to ggml_opt_optimizer_params and returns it + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata); + + // parameters for initializing a new optimization context + struct ggml_opt_params { + ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs + + // by default the forward graph needs to be reconstructed for each eval + // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically + struct ggml_context * ctx_compute; + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; + + enum ggml_opt_loss_type loss_type; + enum ggml_opt_build_type build_type; + + int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done + + 
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer parameters + + // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor + enum ggml_opt_optimizer_type optimizer; + }; + + // get parameters for an optimization context with defaults set where possible + // parameters for which no sensible defaults exist are supplied as arguments to this function + GGML_API struct ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + enum ggml_opt_loss_type loss_type); + + GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); + GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); + + // set gradients to zero, initilize loss, and optionally reset the optimizer + GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); + + GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically + + // get underlying tensors that store data + // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc + GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor + GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor + GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against + GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss + GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs + GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels + + // get the gradient accumulator for a node from the forward graph + GGML_API struct ggml_tensor * 
ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); + + GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme + + GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type); + + // ====== Optimization Result ====== + + GGML_API ggml_opt_result_t ggml_opt_result_init(void); + GGML_API void ggml_opt_result_free(ggml_opt_result_t result); + GGML_API void ggml_opt_result_reset(ggml_opt_result_t result); + + // get data from result, uncertainties are optional and can be ignored by passing NULL + GGML_API void ggml_opt_result_ndata( ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints + GGML_API void ggml_opt_result_loss( ggml_opt_result_t result, double * loss, double * unc); // writes 1 value + GGML_API void ggml_opt_result_pred( ggml_opt_result_t result, int32_t * pred); // writes ndata values + GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value + + // ====== Computation ====== + + // if not using static graphs, this function must be called prior to ggml_opt_alloc + GGML_API void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs); + + // allocate the next graph for evaluation, either forward or forward + backward + // must be called exactly once prior to calling ggml_opt_eval + GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward); + + // do forward pass, increment result if not NULL, do backward pass if allocated + GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + + // ############################################################################ + // ## The high-level functions start here. 
They do not depend on any private ## + // ## functions or structs and can be copied to and adapted for user code. ## + // ############################################################################ + + // ====== Intended Usage ====== + // + // 1. Select the appropriate loss for your problem. + // 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them. + // Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster). + // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors. + // The first context should contain the model parameters and inputs and be allocated statically in user code. + // The second context should contain all other tensors and will be (re)allocated automatically. + // Due to this automated allocation the data of the second context is not defined when accessed in user code. + // Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors. + // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead. 
+ + // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation + typedef void (*ggml_opt_epoch_callback)( + bool train, // true after training evaluation, false after validation evaluation + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, // result associated with the dataset subsection + int64_t ibatch, // number of batches that have been evaluated so far + int64_t ibatch_max, // total number of batches in this dataset subsection + int64_t t_start_us); // time at which the evaluation on the dataset subsection was started + + // do training on front of dataset, do evaluation only on back of dataset + GGML_API void ggml_opt_epoch( + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, // result to increment during training, ignored if NULL + ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL + int64_t idata_split, // data index at which to split training and evaluation + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + + // callback that prints a progress bar on stderr + GGML_API void ggml_opt_epoch_callback_progress_bar( + bool train, + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us); + + // fit model defined by inputs and outputs to dataset + GGML_API void ggml_opt_fit( + ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs + struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs + struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] + struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used + ggml_opt_dataset_t dataset, // dataset with data and optionally also labels + enum ggml_opt_loss_type loss_type, // loss to 
minimize + enum ggml_opt_optimizer_type optimizer, // sgd or adamw + ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) + int64_t nepoch, // how many times the dataset should be iterated over + int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs + float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f) + bool silent); // whether or not info prints to stderr should be suppressed + + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-rpc.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-rpc.h new file mode 100644 index 0000000..1e67411 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-rpc.h @@ -0,0 +1,33 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define RPC_PROTO_MAJOR_VERSION 2 +#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 0 +#define GGML_RPC_MAX_SERVERS 16 + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint); +GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); + +GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); + +GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, + const char * cache_dir, + size_t free_mem, size_t total_mem); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); + +GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-sycl.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-sycl.h new file mode 100644 index 
0000000..5ce349a --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-sycl.h @@ -0,0 +1,49 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#define GGML_SYCL_NAME "SYCL" +#define GGML_SYCL_MAX_DEVICES 48 + +#ifdef __cplusplus +extern "C" { +#endif + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_sycl_init(int device); + +GGML_BACKEND_API bool ggml_backend_is_sycl(ggml_backend_t backend); + +// devide buffer +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device); + +// split tensor buffer that splits matrices by rows across multiple devices +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split); + +// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void); + +GGML_BACKEND_API void ggml_backend_sycl_print_sycl_devices(void); +GGML_BACKEND_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len); +GGML_BACKEND_API void ggml_backend_sycl_get_device_description(int device, + char *description, + size_t description_size); +GGML_BACKEND_API int ggml_backend_sycl_get_device_count(); +GGML_BACKEND_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total); + +// SYCL doesn't support registering host memory, keep here for reference +// GGML_BACKEND_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size); +// GGML_BACKEND_API void ggml_backend_sycl_unregister_host_buffer(void * buffer); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_sycl_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-vulkan.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-vulkan.h new file mode 100644 index 0000000..ed5ea5f --- 
/dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-vulkan.h @@ -0,0 +1,29 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_VK_NAME "Vulkan" +#define GGML_VK_MAX_DEVICES 16 + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num); + +GGML_BACKEND_API bool ggml_backend_is_vk(ggml_backend_t backend); +GGML_BACKEND_API int ggml_backend_vk_get_device_count(void); +GGML_BACKEND_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size); +GGML_BACKEND_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total); + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num); +// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_vk_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-webgpu.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-webgpu.h new file mode 100644 index 0000000..65b8ed9 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-webgpu.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_WEBGPU_NAME "WebGPU" + +// Needed for examples in ggml +GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-zdnn.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-zdnn.h new file mode 100644 index 0000000..c2c30c9 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml-zdnn.h @@ -0,0 +1,16 @@ +#pragma once + +#include 
"ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +GGML_BACKEND_API ggml_backend_t ggml_backend_zdnn_init(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml.h new file mode 100644 index 0000000..da8813f --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/ggml.h @@ -0,0 +1,2467 @@ +#pragma once + +// +// GGML Tensor Library +// +// This documentation is still a work in progress. +// If you wish some specific topics to be covered, feel free to drop a comment: +// +// https://github.com/ggerganov/whisper.cpp/issues/40 +// +// ## Overview +// +// This library implements: +// +// - a set of tensor operations +// - automatic differentiation +// - basic optimization algorithms +// +// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, +// but is not limited to, the following: +// +// - linear regression +// - support vector machines +// - neural networks +// +// The library allows the user to define a certain function using the available tensor operations. This function +// definition is represented internally via a computation graph. Each tensor operation in the function definition +// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the +// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized +// using one of the available optimization algorithms. 
+// +// For example, here we define the function: f(x) = a*x^2 + b +// +// { +// struct ggml_init_params params = { +// .mem_size = 16*1024*1024, +// .mem_buffer = NULL, +// }; +// +// // memory allocation happens here +// struct ggml_context * ctx = ggml_init(params); +// +// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// +// ggml_set_param(ctx, x); // x is an input variable +// +// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); +// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); +// +// ... +// } +// +// Notice that the function definition above does not involve any actual computation. The computation is performed only +// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: +// +// { +// ... +// +// struct ggml_cgraph * gf = ggml_new_graph(ctx); +// ggml_build_forward_expand(gf, f); +// +// // set the input variable and parameter values +// ggml_set_f32(x, 2.0f); +// ggml_set_f32(a, 3.0f); +// ggml_set_f32(b, 4.0f); +// +// ggml_graph_compute_with_ctx(ctx, &gf, n_threads); +// +// printf("f = %f\n", ggml_get_f32_1d(f, 0)); +// +// ... +// } +// +// The actual computation is performed in the ggml_graph_compute() function. +// +// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the +// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know +// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory +// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was +// actually needed. +// +// The ggml_set_param() function marks a tensor as an input variable. 
This is used by the automatic +// differentiation and optimization algorithms. +// +// The described approach allows to define the function graph once and then compute its forward or backward graphs +// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way +// the user can avoid the memory allocation overhead at runtime. +// +// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class +// citizens, but in theory the library can be extended to support FP8 and integer data types. +// +// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary +// and binary operations. Most of the available operations fall into one of these two categories. With time, it became +// clear that the library needs to support more complex operations. The way to support these operations is not clear +// yet, but a few examples are demonstrated in the following operations: +// +// - ggml_permute() +// - ggml_conv_1d_1s() +// - ggml_conv_1d_2s() +// +// For each tensor operator, the library implements a forward and backward computation function. The forward function +// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the +// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a +// calculus class, or watch the following video: +// +// What is Automatic Differentiation? +// https://www.youtube.com/watch?v=wG_nF1awSSY +// +// +// ## Tensor data (struct ggml_tensor) +// +// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of +// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains +// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. 
For example: +// +// { +// struct ggml_tensor * c = ggml_add(ctx, a, b); +// +// assert(c->src[0] == a); +// assert(c->src[1] == b); +// } +// +// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the +// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows +// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and +// permutation. All tensor operations have to take the stride into account and not assume that the tensor is +// contiguous in memory. +// +// The data of the tensor is accessed via the "data" pointer. For example: +// +// { +// const int nx = 2; +// const int ny = 3; +// +// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny); +// +// for (int y = 0; y < ny; y++) { +// for (int x = 0; x < nx; x++) { +// *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y; +// } +// } +// +// ... +// } +// +// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. 
+// +// ## The matrix multiplication operator (ggml_mul_mat) +// +// TODO +// +// +// ## Multi-threading +// +// TODO +// +// +// ## Overview of ggml.c +// +// TODO +// +// +// ## SIMD optimizations +// +// TODO +// +// +// ## Debugging ggml +// +// TODO +// +// + +#ifdef GGML_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BUILD +# define GGML_API __declspec(dllexport) extern +# else +# define GGML_API __declspec(dllimport) extern +# endif +# else +# define GGML_API __attribute__ ((visibility ("default"))) extern +# endif +#else +# define GGML_API extern +#endif + +// TODO: support for clang +#ifdef __GNUC__ +# define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define GGML_DEPRECATED(func, hint) func +#endif + +#ifndef __GNUC__ +# define GGML_ATTRIBUTE_FORMAT(...) +#elif defined(__MINGW32__) && !defined(__clang__) +# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +# define GGML_ATTRIBUTE_FORMAT(...) 
__attribute__((format(printf, __VA_ARGS__))) +#endif + +#include +#include +#include +#include + +#define GGML_FILE_MAGIC 0x67676d6c // "ggml" +#define GGML_FILE_VERSION 2 + +#define GGML_QNT_VERSION 2 // bump this on quantization format changes +#define GGML_QNT_VERSION_FACTOR 1000 // do not change this + +#define GGML_MAX_DIMS 4 +#define GGML_MAX_PARAMS 2048 +#define GGML_MAX_SRC 10 +#define GGML_MAX_N_THREADS 512 +#define GGML_MAX_OP_PARAMS 64 + +#ifndef GGML_MAX_NAME +# define GGML_MAX_NAME 64 +#endif + +#define GGML_DEFAULT_N_THREADS 4 +#define GGML_DEFAULT_GRAPH_SIZE 2048 + +#if UINTPTR_MAX == 0xFFFFFFFF + #define GGML_MEM_ALIGN 4 +#else + #define GGML_MEM_ALIGN 16 +#endif + +#define GGML_EXIT_SUCCESS 0 +#define GGML_EXIT_ABORTED 1 + +#define GGML_ROPE_TYPE_NEOX 2 +#define GGML_ROPE_TYPE_MROPE 8 +#define GGML_ROPE_TYPE_VISION 24 + +#define GGML_MROPE_SECTIONS 4 + +#define GGML_UNUSED(x) (void)(x) + +#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#ifndef NDEBUG +# define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0) +#elif defined(__GNUC__) +# define GGML_UNREACHABLE() __builtin_unreachable() +#elif defined(_MSC_VER) +# define GGML_UNREACHABLE() __assume(0) +#else +# define GGML_UNREACHABLE() ((void) 0) +#endif + +#ifdef __cplusplus +# define GGML_NORETURN [[noreturn]] +#elif defined(_MSC_VER) +# define GGML_NORETURN __declspec(noreturn) +#else +# define GGML_NORETURN _Noreturn +#endif + +#define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__) +#define GGML_ASSERT(x) if (!(x)) GGML_ABORT("GGML_ASSERT(%s) failed", #x) + +// used to copy the number of elements and stride in bytes of tensors into local variables. +// main purpose is to reduce code duplication and improve readability. 
+// +// example: +// +// GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); +// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); +// +#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ + const type prefix##0 = (pointer)->array[0]; \ + GGML_UNUSED(prefix##0); +#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ + const type prefix##1 = (pointer)->array[1]; \ + GGML_UNUSED(prefix##1); +#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ + const type prefix##2 = (pointer)->array[2]; \ + GGML_UNUSED(prefix##2); +#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ + const type prefix##3 = (pointer)->array[3]; \ + GGML_UNUSED(prefix##3); + +#define GGML_TENSOR_UNARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + +#define GGML_TENSOR_BINARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + +#define GGML_TENSOR_TERNARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \ + GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + +#define GGML_TENSOR_BINARY_OP_LOCALS01 \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, 
nb) + +#ifdef __cplusplus +extern "C" { +#endif + + // Function type used in fatal error callbacks + typedef void (*ggml_abort_callback_t)(const char * error_message); + + // Set the abort callback (passing null will restore original abort functionality: printing a message to stdout) + // Returns the old callback for chaining + GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback); + + GGML_NORETURN GGML_ATTRIBUTE_FORMAT(3, 4) + GGML_API void ggml_abort(const char * file, int line, const char * fmt, ...); + + enum ggml_status { + GGML_STATUS_ALLOC_FAILED = -2, + GGML_STATUS_FAILED = -1, + GGML_STATUS_SUCCESS = 0, + GGML_STATUS_ABORTED = 1, + }; + + // get ggml_status name string + GGML_API const char * ggml_status_to_string(enum ggml_status status); + + // ieee 754-2008 half-precision float16 + // todo: make this not an integral type + typedef uint16_t ggml_fp16_t; + GGML_API float ggml_fp16_to_fp32(ggml_fp16_t); + GGML_API ggml_fp16_t ggml_fp32_to_fp16(float); + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t); + GGML_API void ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t); + + // google brain half-precision bfloat16 + typedef struct { uint16_t bits; } ggml_bf16_t; + GGML_API ggml_bf16_t ggml_fp32_to_bf16(float); + GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16 + GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t); + GGML_API void ggml_fp32_to_bf16_row_ref(const float *, ggml_bf16_t *, int64_t); + GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t); + + struct ggml_object; + struct ggml_context; + struct ggml_cgraph; + + // NOTE: always add types at the end of the enum to keep backward compatibility + enum ggml_type { + GGML_TYPE_F32 = 0, + GGML_TYPE_F16 = 1, + GGML_TYPE_Q4_0 = 2, + GGML_TYPE_Q4_1 = 3, + // GGML_TYPE_Q4_2 = 4, support has been removed + // GGML_TYPE_Q4_3 = 5, support has been removed + GGML_TYPE_Q5_0 = 
6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, + GGML_TYPE_Q2_K = 10, + GGML_TYPE_Q3_K = 11, + GGML_TYPE_Q4_K = 12, + GGML_TYPE_Q5_K = 13, + GGML_TYPE_Q6_K = 14, + GGML_TYPE_Q8_K = 15, + GGML_TYPE_IQ2_XXS = 16, + GGML_TYPE_IQ2_XS = 17, + GGML_TYPE_IQ3_XXS = 18, + GGML_TYPE_IQ1_S = 19, + GGML_TYPE_IQ4_NL = 20, + GGML_TYPE_IQ3_S = 21, + GGML_TYPE_IQ2_S = 22, + GGML_TYPE_IQ4_XS = 23, + GGML_TYPE_I8 = 24, + GGML_TYPE_I16 = 25, + GGML_TYPE_I32 = 26, + GGML_TYPE_I64 = 27, + GGML_TYPE_F64 = 28, + GGML_TYPE_IQ1_M = 29, + GGML_TYPE_BF16 = 30, + // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files + // GGML_TYPE_Q4_0_4_8 = 32, + // GGML_TYPE_Q4_0_8_8 = 33, + GGML_TYPE_TQ1_0 = 34, + GGML_TYPE_TQ2_0 = 35, + // GGML_TYPE_IQ4_NL_4_4 = 36, + // GGML_TYPE_IQ4_NL_4_8 = 37, + // GGML_TYPE_IQ4_NL_8_8 = 38, + GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) + GGML_TYPE_COUNT = 40, + }; + + // precision + enum ggml_prec { + GGML_PREC_DEFAULT = 0, // stored as ggml_tensor.op_params, 0 by default + GGML_PREC_F32 = 10, + }; + + // model file types + enum ggml_ftype { + GGML_FTYPE_UNKNOWN = -1, + GGML_FTYPE_ALL_F32 = 0, + GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors + GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ1_S = 18, // 
except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors + GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors + GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors + }; + + // available tensor operations: + enum ggml_op { + GGML_OP_NONE = 0, + + GGML_OP_DUP, + GGML_OP_ADD, + GGML_OP_ADD_ID, + GGML_OP_ADD1, + GGML_OP_ACC, + GGML_OP_SUB, + GGML_OP_MUL, + GGML_OP_DIV, + GGML_OP_SQR, + GGML_OP_SQRT, + GGML_OP_LOG, + GGML_OP_SIN, + GGML_OP_COS, + GGML_OP_SUM, + GGML_OP_SUM_ROWS, + GGML_OP_MEAN, + GGML_OP_ARGMAX, + GGML_OP_COUNT_EQUAL, + GGML_OP_REPEAT, + GGML_OP_REPEAT_BACK, + GGML_OP_CONCAT, + GGML_OP_SILU_BACK, + GGML_OP_NORM, // normalize + GGML_OP_RMS_NORM, + GGML_OP_RMS_NORM_BACK, + GGML_OP_GROUP_NORM, + GGML_OP_L2_NORM, + + GGML_OP_MUL_MAT, + GGML_OP_MUL_MAT_ID, + GGML_OP_OUT_PROD, + + GGML_OP_SCALE, + GGML_OP_SET, + GGML_OP_CPY, + GGML_OP_CONT, + GGML_OP_RESHAPE, + GGML_OP_VIEW, + GGML_OP_PERMUTE, + GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, + GGML_OP_GET_ROWS_BACK, + GGML_OP_SET_ROWS, + GGML_OP_DIAG, + GGML_OP_DIAG_MASK_INF, + GGML_OP_DIAG_MASK_ZERO, + GGML_OP_SOFT_MAX, + GGML_OP_SOFT_MAX_BACK, + GGML_OP_ROPE, + GGML_OP_ROPE_BACK, + GGML_OP_CLAMP, + GGML_OP_CONV_TRANSPOSE_1D, + GGML_OP_IM2COL, + GGML_OP_IM2COL_BACK, + GGML_OP_CONV_2D, + GGML_OP_CONV_2D_DW, + GGML_OP_CONV_TRANSPOSE_2D, + GGML_OP_POOL_1D, + GGML_OP_POOL_2D, + GGML_OP_POOL_2D_BACK, + GGML_OP_UPSCALE, + GGML_OP_PAD, + GGML_OP_PAD_REFLECT_1D, + GGML_OP_ROLL, + GGML_OP_ARANGE, + GGML_OP_TIMESTEP_EMBEDDING, + GGML_OP_ARGSORT, + GGML_OP_LEAKY_RELU, + + GGML_OP_FLASH_ATTN_EXT, + GGML_OP_FLASH_ATTN_BACK, + GGML_OP_SSM_CONV, + GGML_OP_SSM_SCAN, + GGML_OP_WIN_PART, + GGML_OP_WIN_UNPART, + GGML_OP_GET_REL_POS, + GGML_OP_ADD_REL_POS, + GGML_OP_RWKV_WKV6, + GGML_OP_GATED_LINEAR_ATTN, + 
GGML_OP_RWKV_WKV7, + + GGML_OP_UNARY, + + GGML_OP_MAP_CUSTOM1, + GGML_OP_MAP_CUSTOM2, + GGML_OP_MAP_CUSTOM3, + + GGML_OP_CUSTOM, + + GGML_OP_CROSS_ENTROPY_LOSS, + GGML_OP_CROSS_ENTROPY_LOSS_BACK, + GGML_OP_OPT_STEP_ADAMW, + GGML_OP_OPT_STEP_SGD, + + GGML_OP_GLU, + + GGML_OP_COUNT, + }; + + enum ggml_unary_op { + GGML_UNARY_OP_ABS, + GGML_UNARY_OP_SGN, + GGML_UNARY_OP_NEG, + GGML_UNARY_OP_STEP, + GGML_UNARY_OP_TANH, + GGML_UNARY_OP_ELU, + GGML_UNARY_OP_RELU, + GGML_UNARY_OP_SIGMOID, + GGML_UNARY_OP_GELU, + GGML_UNARY_OP_GELU_QUICK, + GGML_UNARY_OP_SILU, + GGML_UNARY_OP_HARDSWISH, + GGML_UNARY_OP_HARDSIGMOID, + GGML_UNARY_OP_EXP, + GGML_UNARY_OP_GELU_ERF, + + GGML_UNARY_OP_COUNT, + }; + + enum ggml_glu_op { + GGML_GLU_OP_REGLU, + GGML_GLU_OP_GEGLU, + GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_SWIGLU_OAI, + GGML_GLU_OP_GEGLU_ERF, + GGML_GLU_OP_GEGLU_QUICK, + + GGML_GLU_OP_COUNT, + }; + + enum ggml_object_type { + GGML_OBJECT_TYPE_TENSOR, + GGML_OBJECT_TYPE_GRAPH, + GGML_OBJECT_TYPE_WORK_BUFFER + }; + + enum ggml_log_level { + GGML_LOG_LEVEL_NONE = 0, + GGML_LOG_LEVEL_DEBUG = 1, + GGML_LOG_LEVEL_INFO = 2, + GGML_LOG_LEVEL_WARN = 3, + GGML_LOG_LEVEL_ERROR = 4, + GGML_LOG_LEVEL_CONT = 5, // continue previous log + }; + + // this tensor... 
+ enum ggml_tensor_flag { + GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph + GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph + GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters + GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + }; + + struct ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally + bool no_alloc; // don't allocate memory for the tensor data + }; + + // n-dimensional tensor + struct ggml_tensor { + enum ggml_type type; + + struct ggml_backend_buffer * buffer; + + int64_t ne[GGML_MAX_DIMS]; // number of elements + size_t nb[GGML_MAX_DIMS]; // stride in bytes: + // nb[0] = ggml_type_size(type) + // nb[1] = nb[0] * (ne[0] / ggml_blck_size(type)) + padding + // nb[i] = nb[i-1] * ne[i-1] + + // compute data + enum ggml_op op; + + // op params - allocated as int32_t for alignment + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + + int32_t flags; + + struct ggml_tensor * src[GGML_MAX_SRC]; + + // source tensor and offset for views + struct ggml_tensor * view_src; + size_t view_offs; + + void * data; + + char name[GGML_MAX_NAME]; + + void * extra; // extra things e.g. 
for ggml-cuda.cu + + char padding[8]; + }; + + static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); + + // Abort callback + // If not NULL, called before ggml computation + // If it returns true, the computation is aborted + typedef bool (*ggml_abort_callback)(void * data); + + + // + // GUID + // + + // GUID types + typedef uint8_t ggml_guid[16]; + typedef ggml_guid * ggml_guid_t; + + GGML_API bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b); + + // misc + + GGML_API const char * ggml_version(void); + GGML_API const char * ggml_commit(void); + + GGML_API void ggml_time_init(void); // call this once at the beginning of the program + GGML_API int64_t ggml_time_ms(void); + GGML_API int64_t ggml_time_us(void); + GGML_API int64_t ggml_cycles(void); + GGML_API int64_t ggml_cycles_per_ms(void); + + // accepts a UTF-8 path, even on Windows + GGML_API FILE * ggml_fopen(const char * fname, const char * mode); + + GGML_API void ggml_print_object (const struct ggml_object * obj); + GGML_API void ggml_print_objects(const struct ggml_context * ctx); + + GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor); + GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN + + GGML_API int64_t ggml_blck_size(enum ggml_type type); + GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block + GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row + + GGML_DEPRECATED( + GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float + "use ggml_row_size() instead"); + + GGML_API const char * ggml_type_name(enum ggml_type type); + GGML_API const char * ggml_op_name (enum ggml_op op); + GGML_API const char * ggml_op_symbol(enum 
ggml_op op); + + GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op); + GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op); + GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name + + GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); + + GGML_API bool ggml_is_quantized(enum ggml_type type); + + // TODO: temporary until model loading of ggml examples is refactored + GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); + + GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); + GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars + + // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation) + GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous() + GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 + GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 + + // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok) + GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor); + + // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN + GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor); + + // true if the elements in dimension 0 are 
contiguous, or there is just 1 block of elements + GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor); + + GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); + GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); + + GGML_API bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1); + + // use this to compute the memory overhead of a tensor + GGML_API size_t ggml_tensor_overhead(void); + + GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes); + + // main + + GGML_API struct ggml_context * ggml_init (struct ggml_init_params params); + GGML_API void ggml_reset(struct ggml_context * ctx); + GGML_API void ggml_free (struct ggml_context * ctx); + + GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); + + GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx); + GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); + + GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx); + GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx); + GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx); + + GGML_API struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t *ne); + + GGML_API struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1); + + GGML_API struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + GGML_API void * 
ggml_new_buffer(struct ggml_context * ctx, size_t nbytes); + + GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); + GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); + + // Context tensor enumeration and lookup + GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx); + GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); + + // Converts a flat index into coordinates + GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); + + GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); + GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor); + + GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); + GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); + + GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name); + GGML_ATTRIBUTE_FORMAT(2, 3) + GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...); + + // Tensor flags + GGML_API void ggml_set_input(struct ggml_tensor * tensor); + GGML_API void ggml_set_output(struct ggml_tensor * tensor); + GGML_API void ggml_set_param(struct ggml_tensor * tensor); + GGML_API void ggml_set_loss(struct ggml_tensor * tensor); + + // + // operations on tensors with backpropagation + // + + GGML_API struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API 
struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type); + + // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]] + GGML_API struct ggml_tensor * ggml_add_id( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * ids); + + GGML_API struct ggml_tensor * ggml_add1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // dst = a + // view(dst, nb1, nb2, nb3, offset) += b + // return dst + GGML_API struct ggml_tensor * ggml_acc( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_acc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct 
ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_log( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_log_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sin( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sin_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_cos( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_cos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // return scalar + GGML_API struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d] + GGML_API struct ggml_tensor * ggml_sum_rows( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // mean along rows + GGML_API struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // argmax along rows + GGML_API struct ggml_tensor * ggml_argmax( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // count number of equal elements in a and b + GGML_API struct ggml_tensor * ggml_count_equal( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // if a is the same shape as b, and a is not parameter, return a + // otherwise, return a new tensor: repeat(a) to 
fit in b + GGML_API struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // repeat a to the specified shape + GGML_API struct ggml_tensor * ggml_repeat_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // sums repetitions in a into shape of b + GGML_API struct ggml_tensor * ggml_repeat_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride + + // concat a and b along dim + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_concat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int dim); + + GGML_API struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_tanh( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_tanh_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu_inplace( + 
struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_leaky_relu( + struct ggml_context * ctx, + struct ggml_tensor * a, float negative_slope, bool inplace); + + GGML_API struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sigmoid( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sigmoid_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // GELU using erf (error function) when possible + // some backends may fallback to approximation based on Abramowitz and Stegun formula + GGML_API struct ggml_tensor * ggml_gelu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_erf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_silu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // a - x + // b - dy + GGML_API struct ggml_tensor * ggml_silu_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // hardswish(x) = x * relu6(x + 3) / 6 + GGML_API struct ggml_tensor * ggml_hardswish( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // hardsigmoid(x) = relu6(x + 3) / 6 + GGML_API struct ggml_tensor * 
ggml_hardsigmoid( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_exp( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_exp_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // gated linear unit ops + // A: n columns, r rows, + // result is n / 2 columns, r rows, + // expects gate in second half of row, unless swapped is true + GGML_API struct ggml_tensor * ggml_glu( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_glu_op op, + bool swapped); + + GGML_API struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_reglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_swiglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_erf_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_quick_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // A: n columns, r rows, + // B: n columns, r rows, + GGML_API struct ggml_tensor * ggml_glu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op); + + GGML_API struct ggml_tensor * ggml_reglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * 
b); + + GGML_API struct ggml_tensor * ggml_geglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_swiglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_erf_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_quick_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_swiglu_oai( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float alpha, + float limit); + + // normalize along rows + GGML_API struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + // group normalize along ne0*ne1*n_groups + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_group_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps); + + GGML_API struct ggml_tensor * ggml_group_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps); + + // l2 normalize along rows + // used in rwkv v7 + GGML_API struct ggml_tensor * ggml_l2_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_l2_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + // a - x + // b - dy + GGML_API struct ggml_tensor * ggml_rms_norm_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + 
struct ggml_tensor * b, + float eps); + + // A: k columns, n rows => [ne03, ne02, n, k] + // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k] + // result is n columns, m rows => [ne03 * x, ne02 * y, m, n] + GGML_API struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // change the precision of a matrix multiplication + // set to GGML_PREC_F32 for higher precision (useful for phi-2) + GGML_API void ggml_mul_mat_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + + // indirect matrix multiplication + GGML_API struct ggml_tensor * ggml_mul_mat_id( + struct ggml_context * ctx, + struct ggml_tensor * as, + struct ggml_tensor * b, + struct ggml_tensor * ids); + + // A: m columns, n rows, + // B: p columns, n rows, + // result is m columns, p rows + GGML_API struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // + // operations on tensors without backpropagation + // + + GGML_API struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s); + + // x = s * a + b + GGML_API struct ggml_tensor * ggml_scale_bias( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + + GGML_API struct ggml_tensor * ggml_scale_bias_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + + // b -> view(a,offset,nb1,nb2,3), return modified a + GGML_API struct ggml_tensor * ggml_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); // in bytes + + // b -> view(a,offset,nb1,nb2,3), return view(a) + GGML_API struct ggml_tensor * ggml_set_inplace( + struct ggml_context * ctx, + struct 
ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); // in bytes + + GGML_API struct ggml_tensor * ggml_set_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset); // in bytes + + GGML_API struct ggml_tensor * ggml_set_1d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset); // in bytes + + // b -> view(a,offset,nb1,nb2,3), return modified a + GGML_API struct ggml_tensor * ggml_set_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset); // in bytes + + // b -> view(a,offset,nb1,nb2,3), return view(a) + GGML_API struct ggml_tensor * ggml_set_2d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset); // in bytes + + // a -> b, return view(b) + GGML_API struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_type type); + + // make contiguous + GGML_API struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // make contiguous, with new shape + GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // return view(a), b specifies the new shape + // TODO: when we start computing gradient, make a 
copy instead of view + GGML_API struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_reshape_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // offset in bytes + GGML_API struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, // row stride in bytes + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3); + + // alias for ggml_permute(ctx, a, 1, 0, 2, 3) + GGML_API struct 
ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // supports 3D: a->ne[2] == b->ne[1] + GGML_API struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, // data + struct ggml_tensor * b); // row indices + + GGML_API struct ggml_tensor * ggml_get_rows_back( + struct ggml_context * ctx, + struct ggml_tensor * a, // gradients of ggml_get_rows result + struct ggml_tensor * b, // row indices + struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape + + // a TD [n_embd, ne1, ne2, ne3] + // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3 + // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1) + // + // undefined behavior if destination rows overlap + // + // broadcast: + // ne2 % ne11 == 0 + // ne3 % ne12 == 0 + // + // return view(a) + GGML_API struct ggml_tensor * ggml_set_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, // destination + struct ggml_tensor * b, // source + struct ggml_tensor * c); // row indices + + GGML_API struct ggml_tensor * ggml_diag( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // set elements above the diagonal to -INF + GGML_API struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // set elements above the diagonal to 0 + GGML_API struct ggml_tensor * ggml_diag_mask_zero( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + GGML_API struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_soft_max_inplace( + 
struct ggml_context * ctx, + struct ggml_tensor * a); + + // a [ne0, ne01, ne02, ne03] + // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional + // + // broadcast: + // ne02 % ne12 == 0 + // ne03 % ne13 == 0 + // + // fused soft_max(a*scale + mask*(ALiBi slope)) + // max_bias = 0.0f for no ALiBi + GGML_API struct ggml_tensor * ggml_soft_max_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias); + + GGML_API void ggml_soft_max_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks); + + GGML_API struct ggml_tensor * ggml_soft_max_ext_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float scale, + float max_bias); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float scale, + float max_bias); + + // rotary position embedding + // if (mode & 1) - skip n_past elements (NOT SUPPORTED) + // if (mode & GGML_ROPE_TYPE_NEOX) - GPT-NeoX style + // + // b is an int32 vector with size a->ne[2], it contains the positions + GGML_API struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode); + + // custom RoPE + // c is freq factors (e.g. 
phi3-128k), (optional) + GGML_API struct ggml_tensor * ggml_rope_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + GGML_API struct ggml_tensor * ggml_rope_multi( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[GGML_MROPE_SECTIONS], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + GGML_API struct ggml_tensor * ggml_rope_multi_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[GGML_MROPE_SECTIONS], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow), + "use ggml_rope_ext instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float 
freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow), + "use ggml_rope_ext_inplace instead"); + + // compute correction dims for YaRN RoPE scaling + GGML_API void ggml_rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]); + + // rotary position embedding backward, i.e compute dx from dy + // a - dy + GGML_API struct ggml_tensor * ggml_rope_ext_back( + struct ggml_context * ctx, + struct ggml_tensor * a, // gradients of ggml_rope result + struct ggml_tensor * b, // positions + struct ggml_tensor * c, // freq factors + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + GGML_API struct ggml_tensor * ggml_rope_multi_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + + // clamp + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max); + + // im2col + // converts data into a format that effectively results in a convolution when combined with matrix multiplication + GGML_API struct ggml_tensor * ggml_im2col( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1, // dilation dimension 1 + bool is_2D, + enum ggml_type dst_type); + + GGML_API struct ggml_tensor * ggml_im2col_back( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // gradient of 
im2col output + int64_t * ne, // shape of im2col input + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1, // dilation dimension 1 + bool is_2D); + + GGML_API struct ggml_tensor * ggml_conv_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride + int p0, // padding + int d0); // dilation + + // conv_1d with padding = half + // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) + GGML_API struct ggml_tensor* ggml_conv_1d_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s, // stride + int d); // dilation + + // depthwise + // TODO: this is very likely wrong for some cases! - needs more testing + GGML_API struct ggml_tensor * ggml_conv_1d_dw( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride + int p0, // padding + int d0); // dilation + + GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride + int d0); // dilation + + GGML_API struct ggml_tensor * ggml_conv_transpose_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride + int p0, // padding + int d0); // dilation + + GGML_API struct ggml_tensor * ggml_conv_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 + + // kernel size is a->ne[0] x a->ne[1] + // stride is equal to kernel size + // padding is zero + // 
example: + // a: 16 16 3 768 + // b: 1024 1024 3 1 + // res: 64 64 768 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // kernel size is a->ne[0] x a->ne[1] + // stride is 1 + // padding is half + // example: + // a: 3 3 256 256 + // b: 64 64 256 1 + // res: 64 64 256 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // depthwise (via im2col and mul_mat) + GGML_API struct ggml_tensor * ggml_conv_2d_dw( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 + + // Depthwise 2D convolution + // may be faster than ggml_conv_2d_dw, but not available in all backends + // a: KW KH 1 C convolution kernel + // b: W H C N input data + // res: W_out H_out C N + GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride0, + int stride1, + int pad0, + int pad1, + int dilation0, + int dilation1); + + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride); + + GGML_API struct ggml_tensor * ggml_conv_2d_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] + struct ggml_tensor * b, // input data [W, H, C, N] + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 + + enum ggml_op_pool { + GGML_OP_POOL_MAX, + GGML_OP_POOL_AVG, + GGML_OP_POOL_COUNT, 
+ }; + + GGML_API struct ggml_tensor * ggml_pool_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, // kernel size + int s0, // stride + int p0); // padding + + // the result will have 2*p0 padding for the first dimension + // and 2*p1 padding for the second dimension + GGML_API struct ggml_tensor * ggml_pool_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1); + + GGML_API struct ggml_tensor * ggml_pool_2d_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * af, // "a"/input used in forward pass + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1); + + enum ggml_scale_mode { + GGML_SCALE_MODE_NEAREST = 0, + GGML_SCALE_MODE_BILINEAR = 1, + + GGML_SCALE_MODE_COUNT + }; + + enum ggml_scale_flag { + GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8) + }; + + // interpolate + // multiplies ne0 and ne1 by scale factor + GGML_API struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor, + enum ggml_scale_mode mode); + + // interpolate + // interpolate scale to specified dimensions + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_upscale_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3, + enum ggml_scale_mode mode), + "use ggml_interpolate instead"); + + // Up- or downsamples the input to the specified size. + // 2D scale modes (eg. bilinear) are applied to the first two dimensions. + GGML_API struct ggml_tensor * ggml_interpolate( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + uint32_t mode); // ggml_scale_mode [ | ggml_scale_flag...] 
+ + // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0] + GGML_API struct ggml_tensor * ggml_pad( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1, + int p2, + int p3); + + // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c] + GGML_API struct ggml_tensor * ggml_pad_reflect_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1); + + // Move tensor elements by an offset given for each dimension. Elements that + // are shifted beyond the last position are wrapped around to the beginning. + GGML_API struct ggml_tensor * ggml_roll( + struct ggml_context * ctx, + struct ggml_tensor * a, + int shift0, + int shift1, + int shift2, + int shift3); + + + // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 + // timesteps: [N,] + // return: [N, dim] + GGML_API struct ggml_tensor * ggml_timestep_embedding( + struct ggml_context * ctx, + struct ggml_tensor * timesteps, + int dim, + int max_period); + + // sort rows + enum ggml_sort_order { + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, + }; + + GGML_API struct ggml_tensor * ggml_argsort( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_sort_order order); + + GGML_API struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step); + + // top k elements per row + GGML_API struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k); + +#define GGML_KQ_MASK_PAD 64 + + // q: [n_embd_k, n_batch, n_head, ne3 ] + // k: [n_embd_k, n_kv, n_head_kv, ne3 ] + // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !! + // mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !! 
+ // + // broadcast: + // n_head % n_head_kv == 0 + // n_head % ne32 == 0 + // ne3 % ne33 == 0 + // + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale, + float max_bias, + float logit_softcap); + + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + + GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec( + const struct ggml_tensor * a); + + GGML_API void ggml_flash_attn_ext_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks); + + // TODO: needs to be adapted to ggml_flash_attn_ext + GGML_API struct ggml_tensor * ggml_flash_attn_back( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * d, + bool masked); + + GGML_API struct ggml_tensor * ggml_ssm_conv( + struct ggml_context * ctx, + struct ggml_tensor * sx, + struct ggml_tensor * c); + + GGML_API struct ggml_tensor * ggml_ssm_scan( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * dt, + struct ggml_tensor * A, + struct ggml_tensor * B, + struct ggml_tensor * C, + struct ggml_tensor * ids); + + // partition into non-overlapping windows with padding if needed + // example: + // a: 768 64 64 1 + // w: 14 + // res: 768 14 14 25 + // used in sam + GGML_API struct ggml_tensor * ggml_win_part( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w); + + // reverse of ggml_win_part + // used in sam + GGML_API struct ggml_tensor * ggml_win_unpart( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int w); + + GGML_API struct ggml_tensor * ggml_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op); + + GGML_API struct ggml_tensor * ggml_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum 
ggml_unary_op op); + + // used in sam + GGML_API struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh); + + // used in sam + GGML_API struct ggml_tensor * ggml_add_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + + GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + + GGML_API struct ggml_tensor * ggml_rwkv_wkv6( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * r, + struct ggml_tensor * tf, + struct ggml_tensor * td, + struct ggml_tensor * state); + + GGML_API struct ggml_tensor * ggml_gated_linear_attn( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * q, + struct ggml_tensor * g, + struct ggml_tensor * state, + float scale); + + GGML_API struct ggml_tensor * ggml_rwkv_wkv7( + struct ggml_context * ctx, + struct ggml_tensor * r, + struct ggml_tensor * w, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * state); + + // custom operators + + typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); + typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata); + typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata); + +#define GGML_N_TASKS_MAX (-1) + // n_tasks == GGML_N_TASKS_MAX means to use max number of tasks + + GGML_API struct ggml_tensor * ggml_map_custom1( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_t fun, + int 
n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom2( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom2_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom3( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom3_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_t fun, + int n_tasks, + void * userdata); + + typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata); + + GGML_API struct ggml_tensor * ggml_custom_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata); + + // loss function + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss( + struct ggml_context * ctx, + struct ggml_tensor * a, // logits + struct ggml_tensor * b); // labels + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back( + struct ggml_context * ctx, + struct ggml_tensor * a, // logits + struct ggml_tensor * b, // labels + struct ggml_tensor * c); // 
gradients of cross_entropy_loss result + + // AdamW optimizer step + // Paper: https://arxiv.org/pdf/1711.05101v3.pdf + // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html + GGML_API struct ggml_tensor * ggml_opt_step_adamw( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * grad, + struct ggml_tensor * m, + struct ggml_tensor * v, + struct ggml_tensor * adamw_params); // parameters such as the learning rate + + // stochastic gradient descent step (with weight decay) + GGML_API struct ggml_tensor * ggml_opt_step_sgd( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * grad, + struct ggml_tensor * sgd_params); // alpha, weight decay + + // + // automatic differentiation + // + + GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + GGML_API void ggml_build_backward_expand( + struct ggml_context * ctx, // context for gradient computation + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs); + + // graph allocation in a context + GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false + GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads); + GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads); + GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); + GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1 + GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); + + GGML_API int ggml_graph_size (struct ggml_cgraph * cgraph); + GGML_API struct ggml_tensor * ggml_graph_node (struct ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i] + GGML_API struct ggml_tensor ** ggml_graph_nodes (struct ggml_cgraph * cgraph); + GGML_API int 
ggml_graph_n_nodes(struct ggml_cgraph * cgraph); + + GGML_API void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + + GGML_API size_t ggml_graph_overhead(void); + GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads); + + GGML_API struct ggml_tensor * ggml_graph_get_tensor (const struct ggml_cgraph * cgraph, const char * name); + GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); + GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); + + // print info and performance information for the graph + GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); + + // dump the graph into a file using the dot format + GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + + // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? + typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); + + // Set callback for all future logging events. + // If this is not called, or NULL is supplied, everything is output on stderr. 
+ GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data); + + GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); + + // + // quantization + // + + // - ggml_quantize_init can be called multiple times with the same type + // it will only initialize the quantization tables for the first call or after ggml_quantize_free + // automatically called by ggml_quantize_chunk for convenience + // + // - ggml_quantize_free will free any memory allocated by ggml_quantize_init + // call this at the end of the program to avoid memory leaks + // + // note: these are thread-safe + // + GGML_API void ggml_quantize_init(enum ggml_type type); + GGML_API void ggml_quantize_free(void); + + // some quantization type cannot be used without an importance matrix + GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type); + + // calls ggml_quantize_init internally (i.e. can allocate memory) + GGML_API size_t ggml_quantize_chunk( + enum ggml_type type, + const float * src, + void * dst, + int64_t start, + int64_t nrows, + int64_t n_per_row, + const float * imatrix); + +#ifdef __cplusplus + // restrict not standard in C++ +# if defined(__GNUC__) +# define GGML_RESTRICT __restrict__ +# elif defined(__clang__) +# define GGML_RESTRICT __restrict +# elif defined(_MSC_VER) +# define GGML_RESTRICT __restrict +# else +# define GGML_RESTRICT +# endif +#else +# if defined (_MSC_VER) && (__STDC_VERSION__ < 201112L) +# define GGML_RESTRICT __restrict +# else +# define GGML_RESTRICT restrict +# endif +#endif + typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + + struct ggml_type_traits { + const char * type_name; + int64_t blck_size; + int64_t blck_size_interleave; // interleave elements in blocks + size_t type_size; + bool is_quantized; + ggml_to_float_t to_float; + ggml_from_float_t from_float_ref; 
+ }; + + GGML_API const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type); + + // ggml threadpool + // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend + // the goal should be to create an API that other backends can use move everything to the ggml base + + // scheduling priorities + enum ggml_sched_priority { + GGML_SCHED_PRIO_LOW = -1, + GGML_SCHED_PRIO_NORMAL, + GGML_SCHED_PRIO_MEDIUM, + GGML_SCHED_PRIO_HIGH, + GGML_SCHED_PRIO_REALTIME + }; + + // threadpool params + // Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults + struct ggml_threadpool_params { + bool cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) + int n_threads; // number of threads + enum ggml_sched_priority prio; // thread priority + uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) + bool strict_cpu; // strict cpu placement + bool paused; // start in paused state + }; + + struct ggml_threadpool; // forward declaration, see ggml.c + + typedef struct ggml_threadpool * ggml_threadpool_t; + + GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads); + GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads); + GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/include/gguf.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/gguf.h new file mode 100644 index 0000000..79ee202 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/include/gguf.h @@ -0,0 +1,202 @@ +// This file contains functionality related to "GGUF" files, the binary file format used by ggml. +// GGUF files have the following structure: +// +// 1. File magic "GGUF" (4 bytes). +// 2. File version (uint32_t). +// 3. 
Number of ggml tensors in file (int64_t). +// 4. Number of key-value-pairs in file (int64_t). +// 5. For each KV pair: +// 1. The key (string). +// 2. The value type (gguf_type). +// 3a. If the value type is GGUF_TYPE_ARRAY: +// 1. The type of the array (gguf_type). +// 2. The number of elements in the array (uint64_t). +// 3. The binary representation of each element in the array. +// 3b. Otherwise: +// 1. The binary representation of the value. +// 6. For each ggml tensor: +// 1. The tensor name (string). +// 2. The number of dimensions of the tensor (uint32_t). +// 3. For each dimension: +// 1. The size of the tensor in the dimension (int64_t). +// 4. The tensor data type (ggml_type). +// 5. The tensor data offset in the tensor data binary blob (uint64_t). +// 7. The tensor data binary blob (optional, aligned). +// +// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator. +// All enums are stored as int32_t. +// All bool values are stored as int8_t. +// If the special key "general.alignment" (uint32_t) is defined it is used for alignment, +// otherwise GGUF_DEFAULT_ALIGNMENT is used. 
+// +// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) + +#pragma once + +#include "ggml.h" + +#include +#include + +#define GGUF_MAGIC "GGUF" +#define GGUF_VERSION 3 + +#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment" + +#define GGUF_DEFAULT_ALIGNMENT 32 + +#ifdef __cplusplus +extern "C" { +#endif + + // types that can be stored as GGUF KV data + enum gguf_type { + GGUF_TYPE_UINT8 = 0, + GGUF_TYPE_INT8 = 1, + GGUF_TYPE_UINT16 = 2, + GGUF_TYPE_INT16 = 3, + GGUF_TYPE_UINT32 = 4, + GGUF_TYPE_INT32 = 5, + GGUF_TYPE_FLOAT32 = 6, + GGUF_TYPE_BOOL = 7, + GGUF_TYPE_STRING = 8, + GGUF_TYPE_ARRAY = 9, + GGUF_TYPE_UINT64 = 10, + GGUF_TYPE_INT64 = 11, + GGUF_TYPE_FLOAT64 = 12, + GGUF_TYPE_COUNT, // marks the end of the enum + }; + + struct gguf_context; + + struct gguf_init_params { + bool no_alloc; + + // if not NULL, create a ggml_context and allocate the tensor data in it + struct ggml_context ** ctx; + }; + + GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); + //GGML_API struct gguf_context * gguf_init_from_buffer(..); + + GGML_API void gguf_free(struct gguf_context * ctx); + + GGML_API const char * gguf_type_name(enum gguf_type type); + + GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx); + GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + + GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx); + GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found + GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id); + + GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id); + GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id); + + // will abort if the 
wrong type is used for the key + GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int64_t key_id); + GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id); + GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id); + GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id); + GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id); + GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id); + GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id); + GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id); + GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id); + GGML_API size_t gguf_get_arr_n (const struct gguf_context * ctx, int64_t key_id); + + // get raw pointer to the first element of the array with the given key_id + // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference) + GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id); + + // get ith C string from array with given key_id + GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); + + GGML_API int64_t gguf_get_n_tensors (const struct gguf_context * ctx); + GGML_API int64_t gguf_find_tensor (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found + GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id); + 
GGML_API const char * gguf_get_tensor_name (const struct gguf_context * ctx, int64_t tensor_id); + GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int64_t tensor_id); + GGML_API size_t gguf_get_tensor_size (const struct gguf_context * ctx, int64_t tensor_id); + + // removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist) + GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key); + + // overrides an existing KV pair or adds a new one, the new KV pair is always at the back + GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); + GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); + GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); + GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); + GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); + GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); + GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); + GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); + GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); + GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val); + GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); + GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); + + // creates a new array with n elements of the given type and copies the corresponding number of bytes from data + GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n); + + // creates a new array with n strings and copies the corresponding 
strings from data + GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n); + + // set or add KV pairs from another context + GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src); + + // add tensor to GGUF context, tensor name must be unique + GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); + + // after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated + // in such a way that the tensor data remains as one contiguous block (except for padding) + GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); + + // assumes that at least gguf_get_tensor_size bytes can be read from data + GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data); + + // writing gguf files can be done in 3 ways: + // + // - write the entire gguf_context to a binary file in a single pass: + // + // gguf_write_to_file(ctx, fname, /*only_meta =*/ false); + // + // - write only the meta data to a file, then re-open the file and append the tensor data: + // + // gguf_write_to_file(ctx, fname, /*only_meta =*/ true); + // FILE * f = fopen(fname, "ab"); + // fwrite(f, ...); // write tensor data + // fclose(f); + // + // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: + // + // FILE * f = fopen(fname, "wb"); + // const size_t size_meta = gguf_get_meta_size(ctx); + // fseek(f, size_meta, SEEK_SET); + // fwrite(f, ...); // write tensor data + // void * data = malloc(size_meta); + // gguf_get_meta_data(ctx, data); + // rewind(f); + // fwrite(data, 1, data, f); + // free(data); + // fclose(f); + // + + // write the entire context to a binary file + GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); + + // get the size in bytes 
of the meta data (header, kv pairs, tensor info) including padding + GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); + + // writes the meta data to pointer "data" + GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/CMakeLists.txt b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/CMakeLists.txt new file mode 100644 index 0000000..2b5b816 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/CMakeLists.txt @@ -0,0 +1,416 @@ +include(CheckCXXCompilerFlag) +include("../cmake/common.cmake") + +add_compile_definitions(GGML_SCHED_MAX_COPIES=${GGML_SCHED_MAX_COPIES}) + +# enable libstdc++ assertions for debug builds +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + add_compile_definitions($<$:_GLIBCXX_ASSERTIONS>) +endif() + +if (NOT MSVC) + if (GGML_SANITIZE_THREAD) + add_compile_options(-fsanitize=thread) + link_libraries (-fsanitize=thread) + endif() + + if (GGML_SANITIZE_ADDRESS) + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries (-fsanitize=address) + endif() + + if (GGML_SANITIZE_UNDEFINED) + add_compile_options(-fsanitize=undefined) + link_libraries (-fsanitize=undefined) + endif() +endif() + +if (GGML_FATAL_WARNINGS) + if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + list(APPEND C_FLAGS -Werror) + list(APPEND CXX_FLAGS -Werror) + elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + add_compile_options(/WX) + endif() +endif() + +if (GGML_ALL_WARNINGS) + if (NOT MSVC) + list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function) + list(APPEND C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes + -Werror=implicit-int -Werror=implicit-function-declaration) + list(APPEND CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn) + + list(APPEND C_FLAGS ${WARNING_FLAGS}) + list(APPEND CXX_FLAGS ${WARNING_FLAGS}) + + 
ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}) + + add_compile_options("$<$:${C_FLAGS};${GF_C_FLAGS}>" + "$<$:${CXX_FLAGS};${GF_CXX_FLAGS}>") + else() + # todo : msvc + set(C_FLAGS "") + set(CXX_FLAGS "") + endif() +endif() + +if (GGML_LTO) + include(CheckIPOSupported) + check_ipo_supported(RESULT result OUTPUT output) + if (result) + set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) + else() + message(WARNING "IPO is not supported: ${output}") + endif() +endif() + +if (GGML_CCACHE AND NOT CMAKE_C_COMPILER_LAUNCHER AND NOT CMAKE_CXX_COMPILER_LAUNCHER) + find_program(GGML_CCACHE_FOUND ccache) + find_program(GGML_SCCACHE_FOUND sccache) + + if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND) + if(GGML_CCACHE_FOUND) + set(GGML_CCACHE_VARIANT ccache) + else() + set(GGML_CCACHE_VARIANT sccache) + endif() + # TODO: should not be set globally + if (GGML_SYCL AND GGML_CCACHE_FOUND AND WIN32) + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache compiler_type=icl") + else () + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}") + endif () + set(ENV{CCACHE_SLOPPINESS} time_macros) + message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. 
Disable with GGML_CCACHE=OFF.") + else() + message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF") + endif () +endif() + +# this version of Apple ld64 is buggy +execute_process( + COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v + ERROR_VARIABLE output + OUTPUT_QUIET +) + +if (output MATCHES "dyld-1015\.7") + add_compile_definitions(HAVE_BUGGY_APPLE_LINKER) +endif() + +# architecture specific +# TODO: probably these flags need to be tweaked on some architectures +# feel free to update the Makefile for your architecture and send a pull request or issue +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") +if (MSVC) + string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR) + message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}") +else () + set(CMAKE_GENERATOR_PLATFORM_LWR "") +endif () +ggml_get_system_arch() +message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}") + +if (NOT MSVC) + if (GGML_STATIC) + add_link_options(-static) + if (MINGW) + add_link_options(-static-libgcc -static-libstdc++) + endif() + endif() + if (GGML_GPROF) + add_compile_options(-pg) + endif() +endif() + +if (MINGW) + add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) +endif() + +# +# POSIX conformance +# + +# clock_gettime came in POSIX.1b (1993) +# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional +# posix_memalign came in POSIX.1-2001 / SUSv3 +# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985) + +# Somehow in OpenBSD whenever POSIX conformance is specified +# some string functions rely on locale_t availability, +# which was introduced in POSIX.1-2008, forcing us to go higher +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + add_compile_definitions(_XOPEN_SOURCE=700) +else() + add_compile_definitions(_XOPEN_SOURCE=600) +endif() + +# Data types, macros and functions related to controlling CPU affinity and +# some memory 
allocation are available on Linux through GNU extensions in libc +if (CMAKE_SYSTEM_NAME MATCHES "Linux" OR CMAKE_SYSTEM_NAME MATCHES "Android") + add_compile_definitions(_GNU_SOURCE) +endif() + +# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, +# and on macOS its availability depends on enabling Darwin extensions +# similarly on DragonFly, enabling BSD extensions is necessary +if ( + CMAKE_SYSTEM_NAME MATCHES "Darwin" OR + CMAKE_SYSTEM_NAME MATCHES "iOS" OR + CMAKE_SYSTEM_NAME MATCHES "tvOS" OR + CMAKE_SYSTEM_NAME MATCHES "DragonFly" +) + add_compile_definitions(_DARWIN_C_SOURCE) +endif() + +# alloca is a non-standard interface that is not visible on BSDs when +# POSIX conformance is specified, but not all of them provide a clean way +# to enable it in such cases +if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD") + add_compile_definitions(__BSD_VISIBLE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "NetBSD") + add_compile_definitions(_NETBSD_SOURCE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + add_compile_definitions(_BSD_SOURCE) +endif() + +if (WIN32) + add_compile_definitions(_CRT_SECURE_NO_WARNINGS) +endif() + +# ggml + +if (GGML_BACKEND_DL AND NOT BUILD_SHARED_LIBS) + message(FATAL_ERROR "GGML_BACKEND_DL requires BUILD_SHARED_LIBS") +endif() + +add_library(ggml-base + ../include/ggml.h + ../include/ggml-alloc.h + ../include/ggml-backend.h + ../include/ggml-cpp.h + ../include/ggml-opt.h + ../include/gguf.h + ggml.c + ggml.cpp + ggml-alloc.c + ggml-backend.cpp + ggml-opt.cpp + ggml-threading.cpp + ggml-threading.h + ggml-quants.c + ggml-quants.h + gguf.cpp) + +target_include_directories(ggml-base PRIVATE .) 
+if (GGML_BACKEND_DL) + target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL) +endif() + +add_library(ggml + ggml-backend-reg.cpp) +add_library(ggml::ggml ALIAS ggml) + +if (GGML_BACKEND_DIR) + if (NOT GGML_BACKEND_DL) + message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL") + endif() + target_compile_definitions(ggml PUBLIC GGML_BACKEND_DIR="${GGML_BACKEND_DIR}") +endif() + +target_link_libraries(ggml PUBLIC ggml-base) + +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + target_link_libraries(ggml PRIVATE dl) +endif() + +function(ggml_add_backend_library backend) + if (GGML_BACKEND_DL) + add_library(${backend} MODULE ${ARGN}) + # write the shared library to the output directory + set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL) + add_dependencies(ggml ${backend}) + if (GGML_BACKEND_DIR) + install(TARGETS ${backend} LIBRARY DESTINATION ${GGML_BACKEND_DIR}) + else() + install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() + else() + add_library(${backend} ${ARGN}) + target_link_libraries(ggml PUBLIC ${backend}) + install(TARGETS ${backend} LIBRARY) + endif() + + target_link_libraries(${backend} PRIVATE ggml-base) + target_include_directories(${backend} PRIVATE ..) 
+ + if (${BUILD_SHARED_LIBS}) + target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD) + target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED) + endif() + + if(NOT GGML_AVAILABLE_BACKENDS) + set(GGML_AVAILABLE_BACKENDS "${backend}" + CACHE INTERNAL "List of backends for cmake package") + else() + list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend) + if(has_backend EQUAL -1) + set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}" + CACHE INTERNAL "List of backends for cmake package") + endif() + endif() +endfunction() + +function(ggml_add_backend backend) + string(TOUPPER "GGML_${backend}" backend_id) + if (${backend_id}) + string(TOLOWER "ggml-${backend}" backend_target) + add_subdirectory(${backend_target}) + message(STATUS "Including ${backend} backend") + if (NOT GGML_BACKEND_DL) + string(TOUPPER "GGML_USE_${backend}" backend_use) + target_compile_definitions(ggml PUBLIC ${backend_use}) + endif() + endif() +endfunction() + +function(ggml_add_cpu_backend_variant tag_name) + set(GGML_CPU_TAG_NAME ${tag_name}) + # other: OPENMP LLAMAFILE CPU_HBM + if (GGML_SYSTEM_ARCH STREQUAL "x86") + foreach (feat NATIVE + SSE42 + AVX AVX2 BMI2 AVX_VNNI FMA F16C + AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 + AMX_TILE AMX_INT8 AMX_BF16) + set(GGML_${feat} OFF) + endforeach() + + foreach (feat ${ARGN}) + set(GGML_${feat} ON) + endforeach() + elseif (GGML_SYSTEM_ARCH STREQUAL "ARM") + foreach (feat ${ARGN}) + set(GGML_INTERNAL_${feat} ON) + endforeach() + elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC") + foreach (feat ${ARGN}) + set(GGML_INTERNAL_${feat} ON) + endforeach() + endif() + + ggml_add_cpu_backend_variant_impl(${tag_name}) +endfunction() + +ggml_add_backend(CPU) + +if (GGML_CPU_ALL_VARIANTS) + if (NOT GGML_BACKEND_DL) + message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL") + elseif (GGML_CPU_ARM_ARCH) + message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS") + endif() + if 
(GGML_SYSTEM_ARCH STREQUAL "x86") + ggml_add_cpu_backend_variant(x64) + ggml_add_cpu_backend_variant(sse42 SSE42) + ggml_add_cpu_backend_variant(sandybridge SSE42 AVX) + ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA) + ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) + ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) + ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) + if (NOT MSVC) + # MSVC doesn't support AMX + ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) + endif() + elseif(GGML_SYSTEM_ARCH STREQUAL "ARM") + if (CMAKE_SYSTEM_NAME MATCHES "Linux") + # Many of these features are optional so we build versions with popular + # combinations and name the backends based on the version they were + # first released with + ggml_add_cpu_backend_variant(armv8.0_1) + ggml_add_cpu_backend_variant(armv8.2_1 DOTPROD) + ggml_add_cpu_backend_variant(armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC) + ggml_add_cpu_backend_variant(armv8.2_3 DOTPROD FP16_VECTOR_ARITHMETIC SVE) + ggml_add_cpu_backend_variant(armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8) + ggml_add_cpu_backend_variant(armv8.6_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2) + ggml_add_cpu_backend_variant(armv9.2_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SME) + ggml_add_cpu_backend_variant(armv9.2_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2 SME) + elseif (CMAKE_SYSTEM_NAME MATCHES "Android") + # Android-specific backends with SoC-compatible feature sets + ggml_add_cpu_backend_variant(android_armv8.0_1) + ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD) + ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC) + ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8) + elseif (APPLE) + 
ggml_add_cpu_backend_variant(apple_m1 DOTPROD) + ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8) + ggml_add_cpu_backend_variant(apple_m4 DOTPROD MATMUL_INT8 NOSVE SME) + else() + message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}") + endif() + elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC") + if (CMAKE_SYSTEM_NAME MATCHES "Linux") + ggml_add_cpu_backend_variant(power0) + ggml_add_cpu_backend_variant(power7_1 POWER7) + ggml_add_cpu_backend_variant(power7_2 POWER7 VSX) + ggml_add_cpu_backend_variant(power8_1 POWER8) + ggml_add_cpu_backend_variant(power8_2 POWER8 VSX) + ggml_add_cpu_backend_variant(power9 POWER9 VSX) + ggml_add_cpu_backend_variant(power10 POWER10 VSX) + ggml_add_cpu_backend_variant(power11 POWER11 VSX) + else() + message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}") + endif() + else() + message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}") + endif() +elseif (GGML_CPU) + ggml_add_cpu_backend_variant_impl("") +endif() + +ggml_add_backend(BLAS) +ggml_add_backend(CANN) +ggml_add_backend(CUDA) +ggml_add_backend(HIP) +ggml_add_backend(METAL) +ggml_add_backend(MUSA) +ggml_add_backend(RPC) +ggml_add_backend(SYCL) +ggml_add_backend(Vulkan) +ggml_add_backend(WebGPU) +ggml_add_backend(zDNN) +ggml_add_backend(OpenCL) + +foreach (target ggml-base ggml) + target_include_directories(${target} PUBLIC $ $) + target_compile_features (${target} PRIVATE c_std_11 cxx_std_17) # don't bump +endforeach() + +target_link_libraries(ggml-base PRIVATE Threads::Threads) + +find_library(MATH_LIBRARY m) +if (MATH_LIBRARY) + if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT}) + target_link_libraries(ggml-base PRIVATE m) + endif() +endif() + +if (CMAKE_SYSTEM_NAME MATCHES "Android") + target_link_libraries(ggml-base PRIVATE dl) +endif() + +if(CMAKE_SYSTEM_NAME MATCHES "visionOS") + target_compile_definitions(ggml-base PUBLIC _DARWIN_C_SOURCE) +endif() + +if (BUILD_SHARED_LIBS) + 
foreach (target ggml-base ggml) + set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(${target} PRIVATE GGML_BUILD) + target_compile_definitions(${target} PUBLIC GGML_SHARED) + endforeach() +endif() diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-alloc.c b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-alloc.c new file mode 100644 index 0000000..8b6e602 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-alloc.c @@ -0,0 +1,1028 @@ +#include "ggml-alloc.h" +#include "ggml-backend-impl.h" +#include "ggml.h" +#include "ggml-impl.h" +#include +#include +#include +#include +#include +#include + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MAX_FREE_BLOCKS 256 + +//#define GGML_ALLOCATOR_DEBUG + +//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__) +#define AT_PRINTF(...) + + +static bool ggml_is_view(const struct ggml_tensor * t) { + return t->view_src != NULL; +} + +// ops that return true for this function must not use restrict pointers for their backend implementations +static bool ggml_op_can_inplace(enum ggml_op op) { + switch (op) { + case GGML_OP_SCALE: + case GGML_OP_DIAG_MASK_ZERO: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_ADD: + case GGML_OP_ADD_ID: + case GGML_OP_ADD1: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_UNARY: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_SILU_BACK: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + return true; + + default: + return false; + } +} + +static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) { + assert(alignment && !(alignment & (alignment - 1))); // power of 2 + size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment; + return offset + align; +} + +// tallocr + +struct ggml_tallocr 
ggml_tallocr_new(ggml_backend_buffer_t buffer) { + void * base = ggml_backend_buffer_get_base(buffer); + size_t align = ggml_backend_buffer_get_alignment(buffer); + + assert(align && !(align & (align - 1))); // power of 2 + + struct ggml_tallocr talloc = (struct ggml_tallocr) { + /*.buffer = */ buffer, + /*.base = */ base, + /*.alignment = */ align, + /*.offset = */ aligned_offset(base, 0, align), + }; + return talloc; +} + +enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) { + size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor); + size = GGML_PAD(size, talloc->alignment); + + if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) { + GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n", + __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset); + GGML_ABORT("not enough space in the buffer"); + } + + void * addr = (char *)ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset; + talloc->offset += size; + + assert(((uintptr_t)addr % talloc->alignment) == 0); + + return ggml_backend_tensor_alloc(talloc->buffer, tensor, addr); +} + +// dynamic tensor allocator + +struct free_block { + size_t offset; + size_t size; +}; + +struct ggml_dyn_tallocr { + size_t alignment; + int n_free_blocks; + struct free_block free_blocks[MAX_FREE_BLOCKS]; + size_t max_size; + +#ifdef GGML_ALLOCATOR_DEBUG + struct { + const struct ggml_tensor * tensor; + size_t offset; + } allocated_tensors[1024]; +#endif +}; + +#ifdef GGML_ALLOCATOR_DEBUG +static void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) { + for (int i = 0; i < 1024; i++) { + if (alloc->allocated_tensors[i].tensor == NULL) { + alloc->allocated_tensors[i].tensor = tensor; + alloc->allocated_tensors[i].offset = offset; + return; + } + } + GGML_ABORT("out of allocated_tensors"); +} +static void 
remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) { + for (int i = 0; i < 1024; i++) { + if (alloc->allocated_tensors[i].offset == offset) { + alloc->allocated_tensors[i].tensor = NULL; + return; + } + } + GGML_ABORT("tried to free tensor %s not found\n", tensor->name); +} +#endif + +static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t size, const struct ggml_tensor * tensor) { + size = aligned_offset(NULL, size, alloc->alignment); + + AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size); + + size_t max_avail = 0; + + // find the best fitting free block besides the last block + int best_fit_block = -1; + size_t best_fit_size = SIZE_MAX; + for (int i = 0; i < alloc->n_free_blocks - 1; i++) { + struct free_block * block = &alloc->free_blocks[i]; + max_avail = MAX(max_avail, block->size); + if (block->size >= size && block->size <= best_fit_size) { + best_fit_block = i; + best_fit_size = block->size; + } + } + + if (best_fit_block == -1) { + // the last block is our last resort + struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; + max_avail = MAX(max_avail, block->size); + if (block->size >= size) { + best_fit_block = alloc->n_free_blocks - 1; + } else { + // this should never happen + GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", + __func__, size, max_avail); + GGML_ABORT("not enough space in the buffer"); + } + } + + struct free_block * block = &alloc->free_blocks[best_fit_block]; + size_t offset = block->offset; + block->offset = offset + size; + block->size -= size; + if (block->size == 0) { + // remove block if empty + alloc->n_free_blocks--; + for (int j = best_fit_block; j < alloc->n_free_blocks; j++) { + alloc->free_blocks[j] = alloc->free_blocks[j+1]; + } + } + + AT_PRINTF("block %d, offset %zu\n", best_fit_block, offset); + +#ifdef GGML_ALLOCATOR_DEBUG + 
add_allocated_tensor(alloc, offset, tensor); + size_t cur_max = offset + size; + if (cur_max > alloc->max_size) { + // sort allocated_tensors by offset + for (int i = 0; i < 1024; i++) { + for (int j = i + 1; j < 1024; j++) { + if (alloc->allocated_tensors[i].offset > alloc->allocated_tensors[j].offset) { + const struct ggml_tensor * tmp_tensor = alloc->allocated_tensors[i].tensor; + size_t tmp_offset = alloc->allocated_tensors[i].offset; + alloc->allocated_tensors[i].tensor = alloc->allocated_tensors[j].tensor; + alloc->allocated_tensors[i].offset = alloc->allocated_tensors[j].offset; + alloc->allocated_tensors[j].tensor = tmp_tensor; + alloc->allocated_tensors[j].offset = tmp_offset; + } + } + } + GGML_LOG_DEBUG("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); + for (int i = 0; i < 1024; i++) { + if (alloc->allocated_tensors[i].tensor) { + GGML_LOG_DEBUG("%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name, + alloc->allocated_tensors[i].offset, + alloc->allocated_tensors[i].offset + ggml_nbytes(alloc->allocated_tensors[i].tensor), + ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0); + } + } + GGML_LOG_DEBUG("\n"); + } +#endif + + alloc->max_size = MAX(alloc->max_size, offset + size); + + return offset; + + GGML_UNUSED(tensor); +} + +// this is a very naive implementation, but for our case the number of free blocks should be very small +static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, size_t size, const struct ggml_tensor * tensor) { + size = aligned_offset(NULL, size, alloc->alignment); + + AT_PRINTF("%s: freeing %s at %zu (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, offset, size, alloc->n_free_blocks); + +#ifdef GGML_ALLOCATOR_DEBUG + remove_allocated_tensor(alloc, offset, tensor); +#endif + + // see if we can merge with an existing block + for (int i = 0; i < alloc->n_free_blocks; i++) { + struct free_block * block = &alloc->free_blocks[i]; + // check if ptr is 
at the end of the block + if (block->offset + block->size == offset) { + block->size += size; + // check if we can merge with the next block + if (i < alloc->n_free_blocks - 1 && block->offset + block->size == alloc->free_blocks[i+1].offset) { + block->size += alloc->free_blocks[i+1].size; + alloc->n_free_blocks--; + for (int j = i+1; j < alloc->n_free_blocks; j++) { + alloc->free_blocks[j] = alloc->free_blocks[j+1]; + } + } + return; + } + // check if ptr is at the beginning of the block + if (offset + size == block->offset) { + block->offset = offset; + block->size += size; + // check if we can merge with the previous block + if (i > 0 && alloc->free_blocks[i-1].offset + alloc->free_blocks[i-1].size == block->offset) { + alloc->free_blocks[i-1].size += block->size; + alloc->n_free_blocks--; + for (int j = i; j < alloc->n_free_blocks; j++) { + alloc->free_blocks[j] = alloc->free_blocks[j+1]; + } + } + return; + } + } + // otherwise, add a new block + GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks"); + // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster) + int insert_pos = 0; + while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].offset < offset) { + insert_pos++; + } + // shift all blocks from insert_pos onward to make room for the new block + for (int i = alloc->n_free_blocks; i > insert_pos; i--) { + alloc->free_blocks[i] = alloc->free_blocks[i-1]; + } + // insert the new block + alloc->free_blocks[insert_pos].offset = offset; + alloc->free_blocks[insert_pos].size = size; + alloc->n_free_blocks++; + + GGML_UNUSED(tensor); +} + +static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) { + alloc->n_free_blocks = 1; + alloc->free_blocks[0].offset = 0; + alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows + alloc->max_size = 0; + +#ifdef GGML_ALLOCATOR_DEBUG + for (int i 
= 0; i < 1024; i++) { + alloc->allocated_tensors[i].tensor = NULL; + } +#endif +} + +static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) { + struct ggml_dyn_tallocr * alloc = (struct ggml_dyn_tallocr *)malloc(sizeof(struct ggml_dyn_tallocr)); + + *alloc = (struct ggml_dyn_tallocr) { + /*.alignment = */ alignment, + /*.n_free_blocks = */ 0, + /*.free_blocks = */ {{0}}, + /*.max_size = */ 0, +#ifdef GGML_ALLOCATOR_DEBUG + /*.allocated_tensors = */ {{0}}, +#endif + }; + + ggml_dyn_tallocr_reset(alloc); + + return alloc; +} + +static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) { + free(alloc); +} + +static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) { + return alloc->max_size; +} + + +///////////////////////////////////// + +// graph allocator + +struct hash_node { + int n_children; + int n_views; + int buffer_id; + size_t offset; // offset within the buffer + bool allocated; +}; + +struct tensor_alloc { + int buffer_id; + size_t offset; + size_t size_max; // 0 = pre-allocated, unused, or view +}; + +struct leaf_alloc { + struct tensor_alloc leaf; +}; + +struct node_alloc { + struct tensor_alloc dst; + struct tensor_alloc src[GGML_MAX_SRC]; +}; + +struct ggml_gallocr { + ggml_backend_buffer_type_t * bufts; // [n_buffers] + ggml_backend_buffer_t * buffers; // [n_buffers] + struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] + int n_buffers; + + struct ggml_hash_set hash_set; + struct hash_node * hash_values; // [hash_set.size] + + struct node_alloc * node_allocs; // [n_nodes] + int n_nodes; + + struct leaf_alloc * leaf_allocs; // [n_leafs] + int n_leafs; +}; + +ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs) { + ggml_gallocr_t galloc = (ggml_gallocr_t)calloc(1, sizeof(struct ggml_gallocr)); + GGML_ASSERT(galloc != NULL); + + galloc->bufts = calloc(n_bufs, sizeof(ggml_backend_buffer_type_t)); + GGML_ASSERT(galloc->bufts != NULL); + + galloc->buffers = calloc(n_bufs, 
sizeof(ggml_backend_buffer_t)); + GGML_ASSERT(galloc->buffers != NULL); + + galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); + GGML_ASSERT(galloc->buf_tallocs != NULL); + + for (int i = 0; i < n_bufs; i++) { + galloc->bufts[i] = bufts[i]; + galloc->buffers[i] = NULL; + + // check if the same buffer type is used multiple times and reuse the same allocator + for (int j = 0; j < i; j++) { + if (bufts[i] == bufts[j]) { + galloc->buf_tallocs[i] = galloc->buf_tallocs[j]; + break; + } + } + + if (galloc->buf_tallocs[i] == NULL) { + size_t alignment = ggml_backend_buft_get_alignment(bufts[i]); + galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment); + } + } + galloc->n_buffers = n_bufs; + + return galloc; +} + +ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft) { + return ggml_gallocr_new_n(&buft, 1); +} + +void ggml_gallocr_free(ggml_gallocr_t galloc) { + if (galloc == NULL) { + return; + } + + for (int i = 0; i < galloc->n_buffers; i++) { + if (galloc->buffers != NULL) { + // skip if already freed + bool freed = false; + for (int j = 0; j < i; j++) { + if (galloc->buffers[j] == galloc->buffers[i]) { + freed = true; + break; + } + } + if (!freed) { + ggml_backend_buffer_free(galloc->buffers[i]); + } + } + if (galloc->buf_tallocs != NULL) { + // skip if already freed + bool freed = false; + for (int j = 0; j < i; j++) { + if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) { + freed = true; + break; + } + } + if (!freed) { + ggml_dyn_tallocr_free(galloc->buf_tallocs[i]); + } + } + } + + ggml_hash_set_free(&galloc->hash_set); + free(galloc->hash_values); + free(galloc->bufts); + free(galloc->buffers); + free(galloc->buf_tallocs); + free(galloc->node_allocs); + free(galloc->leaf_allocs); + free(galloc); +} + +typedef struct ggml_gallocr * ggml_gallocr_t; + +static struct hash_node * ggml_gallocr_hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) { + size_t i = ggml_hash_find_or_insert(&galloc->hash_set, t); + return 
&galloc->hash_values[i]; +} + +static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) { + return ggml_gallocr_hash_get(galloc, t)->allocated; +} + +static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) { + return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated; +} + +static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) { + GGML_ASSERT(buffer_id >= 0); + struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); + + if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) { + hn->allocated = true; + assert(hn->offset == 0); + + // try to reuse a parent's buffer (inplace) + if (ggml_op_can_inplace(node->op)) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * parent = node->src[i]; + if (parent == NULL) { + continue; + } + + // if the node's data is external, then we cannot re-use it + if (!ggml_gallocr_is_own(galloc, parent)) { + AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data); + continue; + } + + // outputs cannot be reused + if (parent->flags & GGML_TENSOR_FLAG_OUTPUT || (parent->view_src != NULL && parent->view_src->flags & GGML_TENSOR_FLAG_OUTPUT)) { + AT_PRINTF("not reusing parent %s for %s as it is an output\n", parent->name, node->name); + continue; + } + + if (!ggml_are_same_layout(node, parent)) { + AT_PRINTF("not reusing parent %s for %s as layouts are different\n", parent->name, node->name); + continue; + } + + struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent); + if (p_hn->n_children == 1 && p_hn->n_views == 0) { + if (ggml_is_view(parent)) { + struct ggml_tensor * view_src = parent->view_src; + struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); + if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { + AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, 
view_src->name, node->name); + assert(view_src_hn->offset == p_hn->offset); + hn->buffer_id = p_hn->buffer_id; + hn->offset = p_hn->offset; + p_hn->allocated = false; // avoid freeing the parent + view_src_hn->allocated = false; + return; + } + } else { + AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name); + hn->buffer_id = p_hn->buffer_id; + hn->offset = p_hn->offset; + p_hn->allocated = false; // avoid freeing the parent + return; + } + } + } + } + // allocate tensor from the buffer + struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; + ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; + size_t size = ggml_backend_buft_get_alloc_size(buft, node); + size_t offset = ggml_dyn_tallocr_alloc(alloc, size, node); + hn->buffer_id = buffer_id; + hn->offset = offset; + } +} + +static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) { + // graph outputs are never freed + if (node->flags & GGML_TENSOR_FLAG_OUTPUT) { + AT_PRINTF("not freeing output %s\n", node->name); + return; + } + + struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); + size_t offset = hn->offset; + int buffer_id = hn->buffer_id; + struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; + ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; + size_t size = ggml_backend_buft_get_alloc_size(buft, node); + ggml_dyn_tallocr_free_tensor(alloc, offset, size, node); + hn->allocated = false; +} + +static int get_node_buffer_id(const int * node_buffer_ids, int i) { + return node_buffer_ids ? 
node_buffer_ids[i] : 0; +} + +static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) { + // clear hash tables + ggml_hash_set_reset(&galloc->hash_set); + memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size); + + // allocate leafs + // these may be tensors that the application is not using in the graph, but may still want to allocate for other purposes + for (int i = 0; i < graph->n_leafs; i++) { + struct ggml_tensor * leaf = graph->leafs[i]; + ggml_gallocr_allocate_node(galloc, leaf, get_node_buffer_id(leaf_buffer_ids, i)); + } + + // count number of children and views + // allocate other graph inputs and leafs first to avoid overwriting them + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + + // TODO: better way to add external dependencies + // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to + // control when some tensors are allocated and freed. 
in this case, the dependencies are in `src`, but the node + // itself is never used and should not be considered a dependency + if (ggml_is_view(node) && node->op != GGML_OP_NONE) { + struct ggml_tensor * view_src = node->view_src; + ggml_gallocr_hash_get(galloc, view_src)->n_views += 1; + } + + if (node->flags & GGML_TENSOR_FLAG_INPUT) { + ggml_gallocr_allocate_node(galloc, graph->nodes[i], get_node_buffer_id(node_buffer_ids, i)); + } + + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + + ggml_gallocr_hash_get(galloc, src)->n_children += 1; + + // allocate explicit inputs + if (src->flags & GGML_TENSOR_FLAG_INPUT) { + ggml_gallocr_allocate_node(galloc, src, get_node_buffer_id(node_buffer_ids, i)); + } + } + } + + // allocate tensors + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + int buffer_id = get_node_buffer_id(node_buffer_ids, i); + + // allocate parents (only leafs need to be allocated at this point) + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + continue; + } + ggml_gallocr_allocate_node(galloc, parent, buffer_id); + } + + // allocate node + ggml_gallocr_allocate_node(galloc, node, buffer_id); + + AT_PRINTF("exec: %s (%s) <= ", ggml_op_desc(node), node->name); + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + continue; + } + AT_PRINTF("%s", parent->name); + if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) { + AT_PRINTF(", "); + } + } + AT_PRINTF("\n"); + + // update parents + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + continue; + } + struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent); + p_hn->n_children -= 1; + + AT_PRINTF("parent %s: %d children, %d views, allocated: %d\n", + parent->name, p_hn->n_children, p_hn->n_views, 
p_hn->allocated); + + if (p_hn->n_children == 0 && p_hn->n_views == 0) { + if (ggml_is_view(parent)) { + struct ggml_tensor * view_src = parent->view_src; + struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); + view_src_hn->n_views -= 1; + AT_PRINTF("view_src %s: %d children, %d views\n", + view_src->name, view_src_hn->n_children, view_src_hn->n_views); + if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) { + ggml_gallocr_free_node(galloc, view_src); + } + } + else if (p_hn->allocated) { + ggml_gallocr_free_node(galloc, parent); + } + } + AT_PRINTF("\n"); + } + } +} + +bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) { + size_t min_hash_size = graph->n_nodes + graph->n_leafs; + // add 25% margin to avoid hash collisions + min_hash_size += min_hash_size / 4; + + // initialize hash table + if (galloc->hash_set.size < min_hash_size) { + ggml_hash_set_free(&galloc->hash_set); + galloc->hash_set = ggml_hash_set_new(min_hash_size); + GGML_ASSERT(galloc->hash_set.keys != NULL); + + free(galloc->hash_values); + galloc->hash_values = malloc(sizeof(struct hash_node) * galloc->hash_set.size); + GGML_ASSERT(galloc->hash_values != NULL); + } + + // reset allocators + for (int i = 0; i < galloc->n_buffers; i++) { + ggml_dyn_tallocr_reset(galloc->buf_tallocs[i]); + } + + // allocate in hash table + ggml_gallocr_alloc_graph_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids); + + // set the node_allocs from the hash table + if (galloc->n_nodes < graph->n_nodes) { + free(galloc->node_allocs); + galloc->node_allocs = calloc(graph->n_nodes, sizeof(struct node_alloc)); + GGML_ASSERT(galloc->node_allocs != NULL); + } + galloc->n_nodes = graph->n_nodes; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + struct node_alloc * node_alloc = &galloc->node_allocs[i]; + if (node->view_src || node->data) { 
+ node_alloc->dst.buffer_id = -1; + node_alloc->dst.offset = SIZE_MAX; + node_alloc->dst.size_max = 0; + } else { + struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); + node_alloc->dst.buffer_id = hn->buffer_id; + node_alloc->dst.offset = hn->offset; + node_alloc->dst.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node); + } + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (!src || src->view_src || src->data) { + node_alloc->src[j].buffer_id = -1; + node_alloc->src[j].offset = SIZE_MAX; + node_alloc->src[j].size_max = 0; + } else { + struct hash_node * hn = ggml_gallocr_hash_get(galloc, src); + node_alloc->src[j].buffer_id = hn->buffer_id; + node_alloc->src[j].offset = hn->offset; + node_alloc->src[j].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src); + } + } + } + if (galloc->n_leafs < graph->n_leafs) { + free(galloc->leaf_allocs); + galloc->leaf_allocs = calloc(graph->n_leafs, sizeof(galloc->leaf_allocs[0])); + GGML_ASSERT(galloc->leaf_allocs != NULL); + } + galloc->n_leafs = graph->n_leafs; + for (int i = 0; i < graph->n_leafs; i++) { + struct ggml_tensor * leaf = graph->leafs[i]; + struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf); + if (leaf->view_src || leaf->data) { + galloc->leaf_allocs[i].leaf.buffer_id = -1; + galloc->leaf_allocs[i].leaf.offset = SIZE_MAX; + galloc->leaf_allocs[i].leaf.size_max = 0; + } else { + galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id; + galloc->leaf_allocs[i].leaf.offset = hn->offset; + galloc->leaf_allocs[i].leaf.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf); + } + } + + // reallocate buffers if needed + for (int i = 0; i < galloc->n_buffers; i++) { + // if the buffer type is used multiple times, we reuse the same buffer + for (int j = 0; j < i; j++) { + if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) { + galloc->buffers[i] = galloc->buffers[j]; + break; + } + } + + 
size_t cur_size = galloc->buffers[i] ? ggml_backend_buffer_get_size(galloc->buffers[i]) : 0; + size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]); + + // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views + if (new_size > cur_size || galloc->buffers[i] == NULL) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); +#endif + + ggml_backend_buffer_free(galloc->buffers[i]); + galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); + if (galloc->buffers[i] == NULL) { + GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); + return false; + } + ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); + } + } + + return true; +} + +bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { + return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL); +} + +static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) { + int buffer_id = tensor_alloc->buffer_id; + assert(tensor->data || tensor->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max); + + if (tensor->view_src != NULL) { + if (tensor->buffer == NULL) { + assert(tensor_alloc->offset == SIZE_MAX); + if (tensor->view_src->buffer == NULL) { + // this tensor was allocated without ggml-backend + return; + } + ggml_backend_view_init(tensor); + } + } else { + if (tensor->data == NULL) { + assert(tensor_alloc->offset != SIZE_MAX); + assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max); + void * base = ggml_backend_buffer_get_base(galloc->buffers[buffer_id]); + void * addr = (char *)base 
+ tensor_alloc->offset; + ggml_backend_tensor_alloc(galloc->buffers[buffer_id], tensor, addr); + } else { + if (tensor->buffer == NULL) { + // this tensor was allocated without ggml-backend + return; + } + } + } +} + +static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) { + size_t node_size = 0; + if (!node->data && !node->view_src) { + // If we previously had data but don't now then reallocate + if (talloc->buffer_id < 0) { + return false; + } + node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); + } + return talloc->size_max >= node_size; +} + +static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph * graph) { + if (galloc->n_nodes != graph->n_nodes) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: graph has different number of nodes\n", __func__); +#endif + return true; + } + + if (galloc->n_leafs != graph->n_leafs) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: graph has different number of leafs\n", __func__); +#endif + return true; + } + + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + struct node_alloc * node_alloc = &galloc->node_allocs[i]; + + if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: node %s is not valid\n", __func__, node->name); +#endif + return true; + } + + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name); +#endif + return true; + } + } + } + + return false; +} + +bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) { + if (ggml_gallocr_needs_realloc(galloc, graph)) { + if (galloc->n_buffers == 1) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: 
reallocating buffers automatically\n", __func__); +#endif + if (!ggml_gallocr_reserve(galloc, graph)) { + return false; + } + } else { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__); +#endif + return false; + } + } + + // reset buffers + for (int i = 0; i < galloc->n_buffers; i++) { + if (galloc->buffers[i] != NULL) { + ggml_backend_buffer_reset(galloc->buffers[i]); + } + } + + // allocate the graph tensors from the previous assignments + // leafs + for (int i = 0; i < graph->n_leafs; i++) { + struct ggml_tensor * leaf = graph->leafs[i]; + struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i]; + ggml_gallocr_init_tensor(galloc, leaf, &leaf_alloc->leaf); + } + // nodes + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + struct node_alloc * node_alloc = &galloc->node_allocs[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + ggml_gallocr_init_tensor(galloc, src, &node_alloc->src[j]); + } + ggml_gallocr_init_tensor(galloc, node, &node_alloc->dst); + } + + return true; +} + +size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { + GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers); + + if (galloc->buffers[buffer_id] == NULL) { + return 0; + } + + for (int i = 0; i < buffer_id; i++) { + if (galloc->buffers[i] == galloc->buffers[buffer_id]) { + // this buffer is the same as a previous one due to the same buffer type being used multiple times + // only return the buffer size the first time it appears to avoid double counting + return 0; + } + } + + return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); +} + +// utils + +static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { + for (size_t i = 0; i < *n_buffers; i++) { + ggml_backend_buffer_free((*buffers)[i]); + } + free(*buffers); +} + +static bool 
alloc_tensor_range(struct ggml_context * ctx, + struct ggml_tensor * first, struct ggml_tensor * last, + ggml_backend_buffer_type_t buft, size_t size, + ggml_backend_buffer_t ** buffers, size_t * n_buffers) { + + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size); + if (buffer == NULL) { + GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size); + free_buffers(buffers, n_buffers); + return false; + } + + *buffers = realloc(*buffers, sizeof(ggml_backend_buffer_t) * (*n_buffers + 1)); + (*buffers)[(*n_buffers)++] = buffer; + + struct ggml_tallocr tallocr = ggml_tallocr_new(buffer); + + for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) { + enum ggml_status status = GGML_STATUS_SUCCESS; + if (t->data == NULL) { + if (t->view_src == NULL) { + status = ggml_tallocr_alloc(&tallocr, t); + } else if (t->buffer == NULL) { + status = ggml_backend_view_init(t); + } + } else { + if (t->view_src != NULL && t->buffer == NULL) { + // view of a pre-allocated tensor + status = ggml_backend_view_init(t); + } + } + if (status != GGML_STATUS_SUCCESS) { + GGML_LOG_ERROR("%s: failed to initialize tensor %s\n", __func__, t->name); + free_buffers(buffers, n_buffers); + return false; + } + } + + return true; +} + +ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_get_no_alloc(ctx) == true); + + size_t alignment = ggml_backend_buft_get_alignment(buft); + size_t max_size = ggml_backend_buft_get_max_size(buft); + + ggml_backend_buffer_t * buffers = NULL; + size_t n_buffers = 0; + + size_t cur_buf_size = 0; + struct ggml_tensor * first = ggml_get_first_tensor(ctx); + for (struct ggml_tensor * t = first; t != NULL; t = ggml_get_next_tensor(ctx, t)) { + size_t this_size = 0; + if (t->data == NULL && t->view_src == NULL) { + this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), 
alignment); + } + + if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) { + // allocate tensors in the current buffer + if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) { + return NULL; + } + first = t; + cur_buf_size = this_size; + } else { + cur_buf_size += this_size; + } + } + + // allocate remaining tensors + if (cur_buf_size > 0) { + if (!alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) { + return NULL; + } + } + + if (n_buffers == 0) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__); +#endif + return NULL; + } + + ggml_backend_buffer_t buffer; + if (n_buffers == 1) { + buffer = buffers[0]; + } else { + buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers); + } + free(buffers); + return buffer; +} + +ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) { + return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend)); +} diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/CMakeLists.txt b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/CMakeLists.txt new file mode 100644 index 0000000..d6676f3 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/CMakeLists.txt @@ -0,0 +1,107 @@ +if (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR + (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$") AND + CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 11.0) + message(STATUS "Using AMX") + + file(GLOB GGML_HEADERS_AMX "*.h") + list(APPEND GGML_HEADERS_AMX "../../include/ggml-amx.h") + + file(GLOB GGML_SOURCES_AMX "*.cpp") + + add_library(ggml-amx + ${GGML_HEADERS_AMX} + ${GGML_SOURCES_AMX}) + + target_link_libraries(ggml-amx PRIVATE ggml-base) + 
target_include_directories(ggml-amx PRIVATE . ..) + + # this is duplicated from the CPU backend, since the AMX backend also depends on the architecture flags + # TODO: integrate AMX backend into the CPU backend + if (MSVC) + # instruction set detection for MSVC only + if (GGML_NATIVE) + # TODO: improve, should not reference files from the parent folder + include(../ggml-cpu/cmake/FindSIMD.cmake) + endif () + if (GGML_AVX512) + list(APPEND ARCH_FLAGS /arch:AVX512) + # MSVC has no compile-time flags enabling specific + # AVX512 extensions, neither it defines the + # macros corresponding to the extensions. + # Do it manually. + if (GGML_AVX512_VBMI) + add_compile_definitions($<$:__AVX512VBMI__>) + add_compile_definitions($<$:__AVX512VBMI__>) + endif() + if (GGML_AVX512_VNNI) + add_compile_definitions($<$:__AVX512VNNI__>) + add_compile_definitions($<$:__AVX512VNNI__>) + endif() + if (GGML_AVX512_BF16) + add_compile_definitions($<$:__AVX512BF16__>) + add_compile_definitions($<$:__AVX512BF16__>) + endif() + if (GGML_AMX_TILE) + add_compile_definitions($<$:__AMX_TILE__>) + add_compile_definitions($<$:__AMX_TILE__>) + endif() + if (GGML_AMX_INT8) + add_compile_definitions($<$:__AMX_INT8__>) + add_compile_definitions($<$:__AMX_INT8__>) + endif() + if (GGML_AMX_BF16) + add_compile_definitions($<$:__AMX_BF16__>) + add_compile_definitions($<$:__AMX_BF16__>) + endif() + elseif (GGML_AVX2) + list(APPEND ARCH_FLAGS /arch:AVX2) + elseif (GGML_AVX) + list(APPEND ARCH_FLAGS /arch:AVX) + endif() + else() + if (GGML_NATIVE) + list(APPEND ARCH_FLAGS -march=native) + endif() + if (GGML_F16C) + list(APPEND ARCH_FLAGS -mf16c) + endif() + if (GGML_FMA) + list(APPEND ARCH_FLAGS -mfma) + endif() + if (GGML_AVX) + list(APPEND ARCH_FLAGS -mavx) + endif() + if (GGML_AVX2) + list(APPEND ARCH_FLAGS -mavx2) + endif() + if (GGML_AVX512) + list(APPEND ARCH_FLAGS -mavx512f) + list(APPEND ARCH_FLAGS -mavx512dq) + list(APPEND ARCH_FLAGS -mavx512bw) + endif() + if (GGML_AVX512_VBMI) + list(APPEND 
ARCH_FLAGS -mavx512vbmi) + endif() + if (GGML_AVX512_VNNI) + list(APPEND ARCH_FLAGS -mavx512vnni) + endif() + if (GGML_AVX512_BF16) + list(APPEND ARCH_FLAGS -mavx512bf16) + endif() + if (GGML_AMX_TILE) + list(APPEND ARCH_FLAGS -mamx-tile) + endif() + if (GGML_AMX_INT8) + list(APPEND ARCH_FLAGS -mamx-int8) + endif() + if (GGML_AMX_BF16) + list(APPEND ARCH_FLAGS -mamx-bf16) + endif() + endif() + + target_compile_options(ggml-amx PRIVATE ${ARCH_FLAGS}) +else() + set(GGML_AMX OFF PARENT_SCOPE) + message(WARNING "AMX requires x86 and gcc version > 11.0. Turning off GGML_AMX.") +endif() diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/common.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/common.h new file mode 100644 index 0000000..5db8ce3 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/common.h @@ -0,0 +1,94 @@ +#pragma once + +#include "ggml.h" +// hack until AMX is moved into the CPU backend +#include "../ggml-cpu/ggml-cpu-impl.h" // + +#include +#include +#include + +#if defined(_OPENMP) +#include +#endif + +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 +#define VNNI_BLK 4 + +#define AMX_BLK_SIZE 32 + +#define TMM0 0 +#define TMM1 1 +#define TMM2 2 +#define TMM3 3 +#define TMM4 4 +#define TMM5 5 +#define TMM6 6 +#define TMM7 7 + +// parallel routines +template ::value, int>::type = 0> +inline T div_up(T x, T y) { return (x + y - 1) / y; } + +template +inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { +#if 0 + // onednn partition pattern + T& n_my = n_end; + if (nth <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else { + T n1 = div_up(n, nth); + T n2 = n1 - 1; + T T1 = n - n2 * nth; + n_my = ith < T1 ? n1 : n2; + n_start = ith <= T1 ? 
ith*n1 : T1 * n1 + (ith - T1) * n2; + } + n_end += n_start; +#else + // pytorch aten partition pattern + T n_my = div_up(n, nth); + n_start = ith * n_my; + n_end = std::min(n_start + n_my, n); +#endif +} + +template +inline void parallel_for(int nth, int n, const func_t& f) { +#if defined(_OPENMP) +#pragma omp parallel num_threads(nth) +{ + //int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); +} +#else + f(0, n); + + GGML_UNUSED(nth); +#endif +} + +// quantized types that have AMX support +inline bool qtype_has_amx_kernels(const enum ggml_type type) { + // TODO: fix padding for vnni format + return (type == GGML_TYPE_Q4_0) || + (type == GGML_TYPE_Q4_1); + //(type == GGML_TYPE_Q8_0) || + //(type == GGML_TYPE_Q4_K) || + //(type == GGML_TYPE_Q5_K) || + //(type == GGML_TYPE_Q6_K) || + //(type == GGML_TYPE_IQ4_XS); +} + +// ggml backend context +struct ggml_backend_amx_context { + int n_threads = GGML_DEFAULT_N_THREADS; + std::unique_ptr work_data; + size_t work_size = 0; +}; diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/ggml-amx.cpp b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/ggml-amx.cpp new file mode 100644 index 0000000..8568e79 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/ggml-amx.cpp @@ -0,0 +1,446 @@ +#include "ggml-amx.h" +#include "ggml-amx/common.h" +#include "ggml-amx/mmq.h" +#include "ggml-backend-impl.h" +#include "ggml-impl.h" + +#if defined(__gnu_linux__) +#include +#include +#endif + +#include +#include +#include + +#if defined(__AMX_INT8__) + +// AMX buffer interface +static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) { + free(buffer->context); +} + +static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) { + return (void *)(buffer->context); +} + +static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, 
uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + if (qtype_has_amx_kernels(tensor->type)) { + ggml_backend_amx_convert_weight(tensor, data, offset, size); + } else { + memcpy((char *)tensor->data + offset, data, size); + } + + GGML_UNUSED(buffer); +} + +static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(!qtype_has_amx_kernels(tensor->type)); + memcpy(data, (const char *)tensor->data + offset, size); + + GGML_UNUSED(buffer); +} + +static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_buffer_is_host(src->buffer)) { + if (qtype_has_amx_kernels(src->type)) { + ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst)); + } else { + memcpy(dst->data, src->data, ggml_nbytes(src)); + } + return true; + } + return false; + + GGML_UNUSED(buffer); +} + +static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); +} + +static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = { + /* .free_buffer = */ ggml_backend_amx_buffer_free_buffer, + /* .get_base = */ ggml_backend_amx_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_amx_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_amx_buffer_cpy_tensor, + /* .clear = */ ggml_backend_amx_buffer_clear, + /* .reset = */ NULL, +}; + +static const char * 
ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "AMX"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * data = aligned_alloc(TENSOR_ALIGNMENT, size); + if (data == NULL) { + fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size); + return NULL; + } + + return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size); +} + +static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) { + return ggml_backend_amx_get_alloc_size(tensor); + + GGML_UNUSED(buft); +} + +static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = { + /* .iface = */ { + /* .get_name = */ ggml_backend_amx_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_amx_buffer_type_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0), + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_amx; +} + +// backend interface + +static const char * ggml_backend_amx_name(ggml_backend_t backend) { + return "AMX"; + + GGML_UNUSED(backend); +} + +static void ggml_backend_amx_free(ggml_backend_t backend) { + ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context; + delete ctx; + delete backend; +} + +static 
enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context; + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + switch (node->op) { + case GGML_OP_MUL_MAT: + ggml_backend_amx_mul_mat(ctx, node); + break; + + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + break; + + default: + fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node)); + GGML_ASSERT(false); + } + } + + return GGML_STATUS_SUCCESS; + + GGML_UNUSED(backend); +} + +static struct ggml_backend_i ggml_backend_amx_i = { + /* .get_name = */ ggml_backend_amx_name, + /* .free = */ ggml_backend_amx_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_amx_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static ggml_guid_t ggml_backend_amx_guid() { + static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e }; + return &guid; +} + +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 + +static bool ggml_amx_init() { +#if defined(__gnu_linux__) + if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) { + fprintf(stderr, "AMX is not ready to be used!\n"); + return false; + } + return true; +#elif defined(_WIN32) + return true; +#endif +} + +ggml_backend_t ggml_backend_amx_init() { + + // invoke a Linux system call to request access to AMX features + ggml_amx_init(); + + // backend context + ggml_backend_amx_context * ctx = 
new ggml_backend_amx_context; + + // ggml amx backend + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_amx_guid(), + /* .interface = */ ggml_backend_amx_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0), + /* .context = */ ctx, + }; + + return backend; +} + +bool ggml_backend_is_amx(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid()); +} + +void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) { + GGML_ASSERT(ggml_backend_is_amx(backend_amx)); + + ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context; + ctx->n_threads = n_threads; +} + +// device interface + +static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) { + return "AMX"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) { + return "Intel Advanced Matrix Extensions"; + + GGML_UNUSED(dev); +} + +static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + // TODO + *free = 0; + *total = 0; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_ACCEL; + + GGML_UNUSED(dev); +} + +static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_amx_device_get_name(dev); + props->description = ggml_backend_amx_device_get_description(dev); + props->type = ggml_backend_amx_device_get_type(dev); + ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total); + + // `buffer_from_host_ptr` is intended to be used in mmap, when memory layout unchanged + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t 
ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) { + return ggml_backend_amx_init(); + + GGML_UNUSED(dev); + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_amx_buffer_type(); + + GGML_UNUSED(dev); +} + +static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + + // handle only 2d gemm for now + auto is_contiguous_2d = [](const struct ggml_tensor * t) { + return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; + }; + + switch (op->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + return true; + + case GGML_OP_MUL_MAT: { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + + const enum ggml_type type = src0->type; + const int64_t ne0 = op->ne[0]; + + // amx kernels enables for Q4_0, Q4_1, Q8_0, F16 + // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256 + bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16); + + bool can_use_amx = + is_contiguous_2d(src0) && // src0 must be contiguous + is_contiguous_2d(src1) && // src1 must be contiguous + src1->type == GGML_TYPE_F32 && // src1 must be float32 + has_amx_kernels && // with amx kernel impls + ne0 % (TILE_N * 2) == 0; // out_features is 32x + + return can_use_amx; + } + default: + return false; + } + + GGML_UNUSED(dev); +} + +static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name; + + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_amx_device_i = { + /* .get_name = */ ggml_backend_amx_device_get_name, + /* .get_description = */ ggml_backend_amx_device_get_description, + /* .get_memory = */ ggml_backend_amx_device_get_memory, + /* .get_type = */ 
ggml_backend_amx_device_get_type, + /* .get_props = */ ggml_backend_amx_device_get_props, + /* .init_backend = */ ggml_backend_amx_device_init, + /* .get_buffer_type = */ ggml_backend_amx_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_amx_device_supports_op, + /* .supports_buft = */ ggml_backend_amx_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend reg interface + +static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) { + return "AMX"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + static ggml_backend_device ggml_backend_amx_device = { + /* .iface = */ ggml_backend_amx_device_i, + /* .reg = */ reg, + /* .context = */ nullptr, + }; + + return &ggml_backend_amx_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) { + return (void *)ggml_backend_amx_set_n_threads; + } + return NULL; + + GGML_UNUSED(reg); + GGML_UNUSED(name); +} + +static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = { + /* .get_name = */ ggml_backend_amx_reg_get_name, + /* .get_device_count = */ ggml_backend_amx_reg_get_device_count, + /* .get_device = */ ggml_backend_amx_reg_get_device, + /* .get_proc_address = */ ggml_backend_amx_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_amx_reg(void) { + static struct ggml_backend_reg ggml_backend_amx_reg = { + /* .iface = */ ggml_backend_amx_reg_i, + /* .context = */ NULL, + }; + + return &ggml_backend_amx_reg; +} + +#else // if 
defined(__AMX_INT8__) + +ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void) { + return nullptr; +} + +bool ggml_backend_is_amx(ggml_backend_t backend) { + GGML_UNUSED(backend); + return false; +} + +ggml_backend_t ggml_backend_amx_init(void) { + fprintf(stderr, "GGML is not compiled with AMX support!\n"); + return nullptr; +} + +void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) { + fprintf(stderr, "GGML is not compiled with AMX support!\n"); + + GGML_UNUSED(backend_amx); + GGML_UNUSED(n_threads); +} + +ggml_backend_reg_t ggml_backend_amx_reg(void) { + return nullptr; +} + +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/mmq.cpp b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/mmq.cpp new file mode 100644 index 0000000..529bee2 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/mmq.cpp @@ -0,0 +1,2510 @@ + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wpedantic" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + +#include "mmq.h" +#include "ggml-impl.h" +#include "ggml-quants.h" +#include +#include + +#if defined(__gnu_linux__) +#include +#include +#endif + +#if defined(_OPENMP) +#include +#endif + +#if (defined(_WIN32) || defined(_WIN64)) +#define RESTRICT __restrict +#else +#define RESTRICT __restrict__ +#endif + +#if (defined(_WIN32) || defined(_WIN64)) +#define ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +#if defined(__AMX_INT8__) + +namespace { + +// Forced unrolling +template +struct Unroll { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct Unroll<1> { + template + ALWAYS_INLINE void operator()(const Func& f, Args... 
args) const { + f(std::integral_constant{}, args...); + } +}; + +// type traits +template struct PackedTypes {}; +template <> struct PackedTypes { using type = int8_t; }; +template <> struct PackedTypes { using type = uint8_t; }; +template <> struct PackedTypes { using type = int8_t; }; +template using packed_B_type = typename PackedTypes::type; + +template +struct do_compensate : std::integral_constant::value> {}; + +template +struct do_unpack : std::integral_constant::value || + std::is_same::value> {}; + +template +struct is_type_qkk : std::integral_constant::value || + std::is_same::value || + std::is_same::value || + std::is_same::value> {}; + +#define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...) \ + [&] { \ + switch (TYPE) { \ + case GGML_TYPE_F16: { \ + using type = ggml_fp16_t; \ + constexpr int blck_size = 16; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_BF16: { \ + using type = ggml_bf16_t; \ + constexpr int blck_size = 32; \ + return __VA_ARGS__(); \ + } \ + default: \ + fprintf(stderr, "Unsupported floating data type\n"); \ + } \ + }() + +#define GGML_DISPATCH_QTYPES(QT, ...) 
\ + [&] { \ + switch (QT) { \ + case GGML_TYPE_Q4_0: { \ + using type = block_q4_0; \ + using vec_dot_type = block_q8_0; \ + constexpr int blck_size = QK4_0; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q4_1: { \ + using type = block_q4_1; \ + using vec_dot_type = block_q8_1; \ + constexpr int blck_size = QK4_1; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q8_0: { \ + using type = block_q8_0; \ + using vec_dot_type = block_q8_0; \ + constexpr int blck_size = QK8_0; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q4_K: { \ + using type = block_q4_K; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q5_K: { \ + using type = block_q5_K; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q6_K: { \ + using type = block_q6_K; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_IQ4_XS: { \ + using type = block_iq4_xs; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + default: \ + fprintf(stderr, "Unsupported quantized data type: %d\n", int(TYPE)); \ + } \ + }() + +#define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) 
\ + [&] { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// define amx tile config data structure +struct tile_config_t{ + uint8_t palette_id = 0; + uint8_t start_row = 0; + uint8_t reserved_0[14] = {0}; + uint16_t colsb[16] = {0}; + uint8_t rows[16] = {0}; +}; + +// Notes: amx tile config +// +// Typically, TMUL calculates A and B of size 16 x 64 containing INT8 values, +// and accumulate the result to a 16 x 16 matrix C containing INT32 values, +// +// As many GGUF quantized types as `block_size` of 32, so a 16-16-32 config is used +// instead of the normally used 16-16-64 config. +// +// Block A: {16, 32}, dtype = int8_t +// Block B: {16, 32}, dtype = uint8_t/int8_t +// Block C: {16, 16}, dtype = int32_t +// +// Block B needs to be prepacked to vnni format before feeding into TMUL: +// packed_B: from {n, k} to {k/vnni_blk, n, vnni_blck}, viewed in 2d, we get {8, 64} +// +// Therefore, we get tileconfig: +// A B C +// rows 16 8 16 +// colsb 32 64 16 +// +// For tile distribution, follow a 2-2-4 pattern, e.g. A used TMM2-TMM3, B used TMM0-TMM1, +// C used TMM4-TMM7: +// B TMM0 B TMM1 +// A TMM2 C TMM4 C TMM6 +// A TMM3 C TMM5 C TMM7 +// +// Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB, when m < 2 * BLOCK_M, unpack A +// will be needed. +// +// Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16; +// and the sinlge batch gemm (m=1) has a special fast path with `avx512-vnni`. 
+// +// ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/ +// advanced-matrix-extensions-intrinsics-functions.html +// + +#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb +void ggml_tile_config_init(void) { + static thread_local bool is_first_time = true; + + if (!is_first_time) { + return; + } + + static thread_local tile_config_t tc; + tile_config_t current_tc; + _tile_storeconfig(¤t_tc); + + // load only when config changes + if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && + memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { + tc.palette_id = 1; + tc.start_row = 0; + TC_CONFIG_TILE(TMM0, 8, 64); + TC_CONFIG_TILE(TMM1, 8, 64); + TC_CONFIG_TILE(TMM2, 16, 32); + TC_CONFIG_TILE(TMM3, 16, 32); + TC_CONFIG_TILE(TMM4, 16, 64); + TC_CONFIG_TILE(TMM5, 16, 64); + TC_CONFIG_TILE(TMM6, 16, 64); + TC_CONFIG_TILE(TMM7, 16, 64); + _tile_loadconfig(&tc); + } + + is_first_time = false; +} + +// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation. +// See the notes `s8s8 igemm compensation in avx512-vnni` for detail. 
+template +int get_tile_size() { + int tile_size = TILE_N * sizeof(TB); + if (do_compensate::value) { + tile_size += TILE_N * sizeof(int32_t); + } + if (std::is_same::value || + std::is_same::value) { + tile_size += TILE_N * 4; + } + if (std::is_same::value) { + tile_size += TILE_N * 2; + } + return tile_size; +} + +template +int get_row_size(int K) { + int KB = K / BLOCK_K; + int row_size = KB * sizeof(TB); + if (do_compensate::value) { + row_size += KB * sizeof(int32_t); + } + if (std::is_same::value || + std::is_same::value) { + row_size += KB * 4; + } + if (std::is_same::value) { + row_size += KB * 2; + } + return row_size; +} + +// vectorized dtype conversion +inline float FP16_TO_FP32(ggml_half val) { + __m256i v = _mm256_setr_epi16( + val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512 o = _mm512_cvtph_ps(v); + return _mm512_cvtss_f32(o); +} + +inline __m512 FP16_TO_FP32_VEC(ggml_half val) { + __m256i v = _mm256_set1_epi16(val); + return _mm512_cvtph_ps(v); +} + +// horizontal reduce +inline float _mm512_reduce_max_ps(const __m512 x) { + __m512 v = x; + __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_f32x4(v, v, 0xB1); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_ps(v, v, 0x4E); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_ps(v, v, 0xB1); + v = _mm512_max_ps(v, v1); + return _mm512_cvtss_f32(v); +} + +// transpose utils +#define SHUFFLE_EPI32(a, b, mask) \ + _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) +inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) { + // unpacking and 32-bit elements + v1[0] = _mm256_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm256_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm256_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm256_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm256_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm256_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm256_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm256_unpackhi_epi32(v[6], 
v[7]); + + // shuffling the 32-bit elements + v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44); + v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee); + v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44); + v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee); + v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44); + v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee); + v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44); + v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee); + + // shuffling 128-bit elements + v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02); + v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02); + v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02); + v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02); + v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13); + v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13); + v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13); + v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13); +} + +inline void transpose_16x4_32bit(__m512i * r, __m512i * d) { + + static const __m512i index1 = _mm512_set_epi32( + 0x0f, 0x0b, 0x07, 0x03, + 0x0e, 0x0a, 0x06, 0x02, + 0x0d, 0x09, 0x05, 0x01, + 0x0c, 0x08, 0x04, 0x00); + + d[0] = _mm512_permutexvar_epi32(index1, r[0]); + d[1] = _mm512_permutexvar_epi32(index1, r[1]); + d[2] = _mm512_permutexvar_epi32(index1, r[2]); + d[3] = _mm512_permutexvar_epi32(index1, r[3]); + + r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); + r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); + r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); + r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); + + d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); + d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); + d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); + d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); +} + +inline void transpose_16x16_32bit(__m512i * v) { + __m512i v1[16]; + v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); + v1[5] = 
_mm512_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); + v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); + v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); + v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); + v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); + v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); + v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); + v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); + v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); + + v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); + v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); + v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); + v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); + v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); + v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); + v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); + v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); + v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); + v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); + v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); + v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); + v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); + v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); + v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); + v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); + v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); + v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); + v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); + v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); + v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); + v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); + v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); + v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); + v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); + v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); + v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); + v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); + v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); + v1[14] = 
_mm512_shuffle_i32x4(v[10], v[14], 0xdd); + v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); + + v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); + v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); + v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + const int KB = k / QK_K; + constexpr int kVecs = QK_K / 16; + + block_q8_K * y = reinterpret_cast(vy); + + // hold 16 float vecs from x + __m512 v[kVecs]; + + // hold the quants vecs + __m512i vq[kVecs / 4]; + + // hold the packed quants vecs + __m512i vq_packed[kVecs / 4]; + + const __m512 signBit = _mm512_set1_ps(-0.f); + + for (int i = 0; i < KB; ++i) { + // Compute max(abs(e)) for the block + __m512 vamax = _mm512_set1_ps(0.f); + for (int j = 0; j < kVecs; ++j) { + v[j] = _mm512_loadu_ps(x); x += 16; + vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j])); + } + const float amax = _mm512_reduce_max_ps(vamax); + + // Quantize these floats + const float iscale = 127.f / amax; + y[i].d = GGML_FP32_TO_FP16(1 / iscale); + const float id = ( amax != 0.0f ) ? 
iscale : 0.f; + const __m512 vscale = _mm512_set1_ps(id); + + // Apply multiplier and round to nearest integer + for (int j = 0; j < kVecs; ++j) { + v[j] = _mm512_mul_ps(v[j], vscale); + v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + + // Pack to epi8 vecs + for (int j = 0; j < kVecs / 4; ++j) { + __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0])); + __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1])); + __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2])); + __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3])); + + __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1); + __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1); + + vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1); + _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]); + } + + // Compute the bsums with vnni + transpose_16x4_32bit(vq, vq_packed); + + const __m512i one = _mm512_set1_epi8(1); + __m512i sum = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]); + } + _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum)); + } +} + +// quantize A from float to `vec_dot_type` +template +inline void from_float(const float * x, char * vy, int64_t k); + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { + // FIXME: using unoptimized reference impl until moved to CPU backend + quantize_row_q8_0_ref(x, (block_q8_0 *)vy, k); +} + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { + quantize_row_q8_1_ref(x, (block_q8_1 *)vy, k); +} + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { +#if 1 + // TODO: this is reference impl! 
+ quantize_row_q8_K_ref(x, (block_q8_K *)vy, k); +#else + quantize_row_q8_K_vnni(x, vy, k); +#endif +} + +// load A from memory to array when nrows can not fill in whole tile +void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) { + assert(nr != TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) { + assert(nr != TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +template +void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { + assert(nr <= TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +template <> +void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { + assert(nr <= TILE_M); + // zero padding k from 16 to 32, so that we don't have to re-config amx + const __m128i zero = _mm_setzero_si128(); + for (int m = 0; m < nr; ++m) { + const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16)); + const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r); + } +} + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8(0xF); + return _mm256_and_si256(lowMask, bytes); +} + +// used for block_q4_K +inline __m512i 
bytes_from_nibbles_64(const uint8_t * rsi) { + const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi); + const __m256i lowMask = _mm256_set1_epi8(0xF); + const __m256i q4l = _mm256_and_si256(tmp, lowMask); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask); + return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1); +} + +// used for block_q5_K +inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) { + const __m256i lowMask = _mm256_set1_epi8(0xF); + __m256i hmask = _mm256_set1_epi8(1); + hmask = _mm256_slli_epi16(hmask, k); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs); + const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + + return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1); +} + +// used for block_q6_K +inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) { + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(0x3); + + const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs); + const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32)); + const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh); + + const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256( q6bitsH, m2), 4); + const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4); + const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), 
m2), 4); + const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4); + + const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0); + const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1); + const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2); + const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3); + + r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1); + r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1); +} + +inline __m512i packNibbles(__m512i r0, __m512i r1) { + return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4)); +} + +template +inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) { + int8_t tmp[8 * 64]; + __m256i v[8], v2[8]; + for (int n = 0; n < 8; ++n) { + v[n] = bytes_from_nibbles_32(B[n * KB].qs); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]); + } + for (int n = 0; n < 8; ++n) { + v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]); + } + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < 8; n += 2) { + __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64)); + __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64)); + __m512i r1r0 = packNibbles(r0, r1); + _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0); + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { + __m256i v[8], v2[8]; + for (int n = 0; n < 8; ++n) { + v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs)); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]); + } + for (int n = 
0; n < 8; ++n) { + v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs)); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]); + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { + __m512i v[16]; + // QK_K 256 with 8 groups, handle 2 groups at a time + char * pb = (char *)packed_B; + for (int k = 0; k < QK_K / 64; ++k) { + // pack 2 groups { n, g, k} to {g, k/4, 4n} + // e.g. {16, 2, 32} to {2, 8, 64} + for (int n = 0; n < TILE_N; ++n) { + v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32); + } + + transpose_16x16_32bit(v); + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < TILE_N; n += 2) { + _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); + pb += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { + __m512i v[16]; + const __m512i lowMask = _mm512_set1_epi8(0xF); + // QK_K 256 with 8 groups, handle 2 groups at a time + char * pb = (char *)packed_B; + char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; + for (int k = 0; k < QK_K / 64; ++k) { + // pack 2 groups { n, g, k} to {g, k/4, 4n} + // e.g. {16, 2, 32} to {2, 8, 64} + for (int n = 0; n < TILE_N; ++n) { + v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k); + } + + transpose_16x16_32bit(v); + + // 1. pack lower 4bits with 2 groups + for (int n = 0; n < TILE_N; n += 2) { + // get lower 4 bits + const __m512i r0 = _mm512_and_si512(v[n], lowMask); + const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); + _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; + } + + // 2. 
pack higher 1bit with 2 groups + const __m512i hmask = _mm512_set1_epi8(0x10); + for (int g = 0; g < 2; ++g) { + __m512i hbits = _mm512_setzero_si512(); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1)); + hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) ); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1)); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3)); + _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { + __m512i v[32]; + const __m512i lowMask = _mm512_set1_epi8(0xF); + // QK_K 256 with 8 groups, handle 4 groups at a time + char * pb = (char *)packed_B; + char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; + for (int k = 0; k < QK_K / 128; ++k) { + for (int n = 0; n < TILE_N; ++n) { + bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32); + } + + // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7 + transpose_16x16_32bit(v); + transpose_16x16_32bit(v + 16); + + // 1. pack lower 4bits with 4 groups + for (int n = 0; n < 32; n += 2) { + const __m512i r0 = _mm512_and_si512(v[n], lowMask); + const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); + _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; + } + + // 2. 
pack higher 2bit with 4 groups + const __m512i hmask = _mm512_set1_epi8(0x30); + for (int g = 0; g < 8; ++g) { + __m512i hbits = _mm512_setzero_si512(); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) ); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2)); + _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { + __m512i v[16]; + char * pb = (char *)packed_B; + for (int k = 0; k < QK_K / 64; ++k) { + for (int n = 0; n < TILE_N; ++n) { + __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0); + __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16); + v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); + } + + transpose_16x16_32bit(v); + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < TILE_N; n += 2) { + _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); + pb += 64; + } + } +} + +// pack B to vnni formats in 4bits or 8 bits +void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K / 2); + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + } +} + +void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K / 2); + ggml_half * m0 = d0 + TILE_N; + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + m0[n] = B[n * KB].m; + } +} + +inline void s8s8_compensation(void * RESTRICT packed_B) { + // packed_B layout: + // quants {TILE_N, TILEK} int8_t + // d0 {TILE_N} ggml_half + // comp {TILE_N} 
int32_t + const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); + __m512i vcomp = _mm512_setzero_si512(); + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + for (int k = 0; k < 8; ++k) { + __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64)); + vcomp = _mm512_dpbusd_epi32(vcomp, off, vb); + } + _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp); +} + +void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K); + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + } + s8s8_compensation(packed_B); +} + +// convert 8 * {min, scale} from int6 to int8 +inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) { + const uint32_t kmask1 = 0x3f3f3f3f; + const uint32_t kmask2 = 0x0f0f0f0f; + const uint32_t kmask3 = 0x03030303; + + memcpy(utmp, scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// scales {8, TILE_N} uint8 +// mins {8, TILE_N} uint8 +// d {TILE_N} ggml_half +// dmin {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N); + uint8_t * mins = scales + 8 * TILE_N; + ggml_half * d = reinterpret_cast(mins + 8 * TILE_N); + ggml_half * dmin = d + TILE_N; + + union { + uint32_t u32[4]; + uint8_t u8[16]; + } s; + + for (int n = 0; n < TILE_N; ++n) { + unpack_mins_and_scales(B[n * KB].scales, s.u32); + for (int k = 0; k < 8; ++k) { + scales[k * TILE_N + n] = s.u8[k]; + mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; + } + d[n] = B[n * KB].d; + dmin[n] = 
B[n * KB].dmin; + } +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// qh {8, TILE_N, 4} uint8 +// scales {8, TILE_N} uint8 +// mins {8, TILE_N} uint8 +// d {TILE_N} ggml_half +// dmin {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); + uint8_t * mins = scales + 8 * TILE_N; + ggml_half * d = reinterpret_cast(mins + 8 * TILE_N); + ggml_half * dmin = d + TILE_N; + + union { + uint32_t u32[4]; + uint8_t u8[16]; + } s; + + for (int n = 0; n < TILE_N; ++n) { + unpack_mins_and_scales(B[n * KB].scales, s.u32); + for (int k = 0; k < 8; ++k) { + scales[k * TILE_N + n] = s.u8[k]; + mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; + } + d[n] = B[n * KB].d; + dmin[n] = B[n * KB].dmin; + } +} + +// packed_B layout: +// quants {16, TILE_N, 8} uint8 +// qh {16, TILE_N, 4} uint8 +// scales {16, TILE_N} uint8 +// d {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); + ggml_half * d = reinterpret_cast(scales + 16 * TILE_N); + for (int n = 0; n < TILE_N; ++n) { + const int8_t * ps = B[n * KB].scales; + for (int k = 0; k < 16; ++k) { + scales[k * TILE_N + n] = ps[k]; + } + d[n] = B[n * KB].d; + } +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// scales {8, TILE_N} int8 +// d {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + int8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N); + ggml_half * d = reinterpret_cast(scales + 8 * TILE_N); + + // pack the scales + for (int n = 0; n < TILE_N; ++n) { + uint16_t sh = B[n * KB].scales_h; + for (int k = 0; k < 8; k += 2) { + const int16_t ls1 = ((B[n * 
KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32; + scales[(k + 0) * TILE_N + n] = ls1; + scales[(k + 1) * TILE_N + n] = ls2; + sh >>= 4; + } + d[n] = B[n * KB].d; + } +} + +template> +void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) { + GGML_UNUSED(tile); + GGML_UNUSED(packed_B); +}; + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B) { + const __m512i off = _mm512_set1_epi8(8); + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); + const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off); + const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) { + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); + const __m512i r0 = _mm512_and_si512(bytes, lowMask); + const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +// packed_B_t for QKK is int8_t +template +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + const int packed_B_group_size = QK_K / 2 * TILE_N / 8; + const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size; + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32); + const __m512i r0 = _mm512_and_si512(bytes, lowMask); + 
const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + // lower 4bits, stride 256 bytes + const int packed_l4_group_size = QK_K / 2 * TILE_N / 8; + const char * pb = (const char *)packed_B + k * packed_l4_group_size; + + // higher 1bit, stride 64 bytes + const int packed_h1_group_size = QK_K / 8 * TILE_N / 8; + const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size; + const __m512i hbits = _mm512_loadu_si512(ph); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + __m512i hmask0 = _mm512_set1_epi8(0x1); + __m512i hmask1 = _mm512_set1_epi8(0x2); + + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(pb + n * 32); + __m512i r0 = _mm512_and_si512(bytes, lowMask); + __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4); + __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4); + + hmask0 = _mm512_slli_epi16(hmask0, 2); + hmask1 = _mm512_slli_epi16(hmask1, 2); + r0 = _mm512_add_epi8(r0, h0); + r1 = _mm512_add_epi8(r1, h1); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + // lower 4bits, stride 128 bytes + const int packed_l4_group_size = QK_K / 2 * TILE_N / 16; + const char * pb = (const char *)packed_B + k * packed_l4_group_size; + + // higher 2bits, stride 64 bytes + const int packed_h2_group_size = QK_K / 4 * TILE_N / 16; + const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size; + const __m512i hbits = _mm512_loadu_si512(ph); + 
+ const __m512i off = _mm512_set1_epi8(32); + const __m512i lowMask = _mm512_set1_epi8(0xF); + __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011 + __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100 + + // notes: skip zero padding from row4 to row7 as we have done so in `unpack_A` + __m512i bytes = _mm512_loadu_si512(pb); + __m512i r0 = _mm512_and_si512(bytes, lowMask); + __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4); + __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2); + _mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); + _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); + + hmask0 = _mm512_slli_epi16(hmask0, 4); + hmask1 = _mm512_slli_epi16(hmask1, 4); + + bytes = _mm512_loadu_si512(pb + 64); + r0 = _mm512_and_si512(bytes, lowMask); + r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + h0 = _mm512_and_si512(hbits, hmask0); + h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2); + _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); + _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + static const __m512i values128 = _mm512_set_epi8( + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 + ); + + const int packed_B_group_size = QK_K / 2 * TILE_N / 8; + const char * pb = (const char *)packed_B + k * packed_B_group_size; + const __m512i lowMask = _mm512_set1_epi8(0xF); + + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(pb 
+ n * 32); + const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask)); + const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template +struct acc_C {}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half)))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].s)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), 
vsum); + vsum = _mm512_fmadd_ps(vm0, vs1, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N); + const uint8_t * mins = scales + 8 * TILE_N; + const ggml_half * d0 = reinterpret_cast(mins + 8 * TILE_N); + const ggml_half * dmin = d0 + TILE_N; + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); 
+ + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); + const uint8_t * mins = scales + 8 * TILE_N; + const ggml_half * d0 = reinterpret_cast(mins + 8 * TILE_N); + const ggml_half * dmin = d0 + TILE_N; + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, 
va, vb); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); + const ggml_half * d0 = reinterpret_cast(scales + 16 * TILE_N); + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const int8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N); + const ggml_half * d0 = reinterpret_cast(scales + 8 * TILE_N); + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template constexpr int get_quants_size(); +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N; } +template <> 
constexpr int get_quants_size() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N; } + +// used for QKK format +template ::value, int>::type = 0> +inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + get_quants_size()); + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N))); + + for (int m = 0; m < nr; ++m) { + __m512i vsumi; + if (is_acc) { + vsumi = _mm512_loadu_si512(sumi + m * TILE_N); + } else { + vsumi = _mm512_setzero_si512(); + } + __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N); + vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale)); + _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi); + } +} + +template +struct tinygemm_kernel_avx { + static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) { + GGML_UNUSED(K); + GGML_UNUSED(A); + GGML_UNUSED(B); + GGML_UNUSED(C); + GGML_UNUSED(ldc); + } +}; + +template +struct tinygemm_kernel_avx { + static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N; + assert(BLOCK_K == 16); + + __m512 va; + __m512 vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](int idx) { + vc[idx] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int idx, int k) { + // TODO: use `constexpr` here to get rid of interger div + // when upgraded to C++17 + const int row = idx / COLS; + const int col = idx % COLS; + + if (col == 0) { + va = _mm512_loadu_ps(A + row * K + k); + } + if (row == 0) { + vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k))); + } + 
vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]); + }; + + for (int k = 0; k < K; k += 16) { + Unroll{}(compute, k); + } + + auto storec = [&](int idx) { + const int row = idx / COLS; + const int col = idx % COLS; + C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]); + }; + Unroll{}(storec); + } +}; + +#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_avx::apply( \ + K, (const float *)src1->data + mb_start * K, \ + (const type *)src0->data + nb_start * K, \ + (float *)dst->data + mb_start * ldc + nb_start, ldc); + + +// re-organize in the format {NB, KB, TILE_SIZE}: +#define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size + +template +void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K, int n_threads) { + const int NB = N / TILE_N; + const int KB = K / BLOCK_K; + const int TILE_SIZE = get_tile_size(); + + // parallel on NB should be enough + parallel_for(n_threads, NB, [&](int begin, int end) { + for (int n = begin; n < end; ++n) { + for (int k = 0; k < KB; ++k) { + int n0 = n * TILE_N; + pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB); + } + } + }); +} + +template +struct tinygemm_kernel_vnni {}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_0); + + const block_q8_0 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512 vc[COLS]; + __m512 vd1; + + // sum of offsets, shared across COLS + // + // avx512-vnni does not have `_mm512_dpbssd_epi32`, + // need to transfrom ss to us: + // a * (b - 8) is equavilent to b * a - 8 * a + // s u u u s u s + // + __m512i vcomp; + + const __m512i off = _mm512_set1_epi8(8); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); 
+ }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + // load a and compute compensation + if (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + vcomp = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]); + } + vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d)); + } + + // load b + __m512i vsum = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; k += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]); + } + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + vsum = _mm512_sub_epi32(vsum, vcomp); + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_1); + + const block_q8_1 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512i vb[8]; + __m512 vc[COLS]; + __m512 vd1, vs1; + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + // load a + if (col == 0) { + const int32_t * a_ptr = 
reinterpret_cast(A[0 * KB + i].qs); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + } + vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d)); + vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].s)); + } + + // load b + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; k += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); + vb[k + 0] = _mm512_and_si512(bytes, lowMask); + vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + } + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half)))); + + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]); + } + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t); + + const block_q8_0 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512i vb[8]; + __m512 vc[COLS]; + __m512 vd1; + + // Notes: s8s8 igemm compensation in avx512-vnni + // change s8s8 to u8s8 with compensate + // a * b = (a + 128) * b - 128 * b + // s s u s u s + // + // (128 * b is pre-computed when packing B to vnni formats) + // + const __m512i off = 
_mm512_set1_epi8(static_cast(0x80)); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + // load a and add offset 128 + if (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + va[k] = _mm512_add_epi8(va[k], off); + } + vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d)); + } + + // load b + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; ++k) { + vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64)); + } + const int offset = TILE_N * TILE_K; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); + const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2)); + + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]); + } + vsum = _mm512_sub_epi32(vsum, vcomp); + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // a.qs: 8 groups, 32 bytes each group (m256i) + __m512i va[8]; + // a.bsum: 8 groups, 2 bytes each group (m128i) + __m512i va_bsum; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_scales = (QK_K / 2) * TILE_N; + const int 
offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N; + const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + // Notes: vnni formats in QK_K + // a) quants vnni format + // int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32 + // from {16, 32} to {8, 64} + // + // b) min vnni format + // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8 + // from {16, 8} to {4, 32} + // + auto compute = [&](int col, int i) { + // load a + if (col == 0) { + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); + } + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + va_bsum = _mm512_castsi128_si512(q8s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // step 1: accumultate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); + + __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + + b_qs += 64; + } + // vacc += scale * (q8 @ q4) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i 
*)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + + // step 2: accumulate the mins + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); + vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // a.qs: 8 groups, 32 bytes each group (m256i) + __m512i va[8]; + // a.bsum: 8 groups, 2 bytes each group (m128i) + __m512i va_bsum; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_qh = (QK_K / 2) * TILE_N; + const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; + const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N; + const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * 
sizeof(ggml_half); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + // Q5_K and Q4_K shares the same vnni formats, refer to notes above. + auto compute = [&](int col, int i) { + // load a + if (col == 0) { + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); + } + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + va_bsum = _mm512_castsi128_si512(q8s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // step 1: accumultate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + const char * b_qh = b_ptr + offset_qh; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + __m512i vsum = _mm512_setzero_si512(); + __m512i hmask0 = _mm512_set1_epi8(0x1); + __m512i hmask1 = _mm512_set1_epi8(0x2); + __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64)); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); + + __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + + __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4); + __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4); + + hmask0 = _mm512_slli_epi16(hmask0, 2); + hmask1 = _mm512_slli_epi16(hmask1, 2); + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + + vsum = _mm512_dpbusd_epi32(vsum, vb0, 
va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + + b_qs += 64; + } + // vacc += scale * (q8 @ q5) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + + // step 2: accumulate the mins + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); + vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q6_K); + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // load the 256 bytes from A to 4 avx512 vectors + __m512i va[4]; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_qh = (QK_K / 2) * TILE_N; + const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N; + + // compensation + __m512i vcomp; + + const __m512i m32s = _mm512_set1_epi32(32); + 
const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + if (col == 0) { + // load a + va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); + va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); + va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); + va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // accmulate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + const char * b_qh = b_ptr + offset_qh; + int mask = 0; + for (int k_group = 0; k_group < QK_K / 16; ++k_group) { + int r = k_group >> 2; + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + __m512i vsum = _mm512_setzero_si512(); + __m512i hmask = _mm512_set1_epi8(0x3); + + __m512i bytes = _mm512_loadu_si512(b_qs); + __m512i hbits = _mm512_loadu_si512(b_qh); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4); + __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2); + + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + + va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + bytes = _mm512_loadu_si512(b_qs); + vb0 = _mm512_and_si512(bytes, lowMask); 
+ vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4)); + vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2); + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + b_qh += 64; + + // B * A - 32 * A + __m512i vmask = _mm512_set1_epi32(k_group); + vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); + + // vacc += scale * (q8 @ q6) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // load the 256 bytes from A to 4 avx512 vectors + __m512i va[4]; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_scales = (QK_K / 2) * TILE_N ; + const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N; + + // compensation + __m512i vcomp; + + const __m256i m128s = _mm256_set1_epi16(128); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + const __m512i values128 = _mm512_set_epi8( + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, 
-104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 + ); + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + const __m512i values256 = _mm512_add_epi8(values128, off); + + auto loadc = [&](int col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](int col, int i) { + if (col == 0) { + // load a + va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); + va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); + va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); + va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); + + // compensation: 128 * A + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s)); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // accmulate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + int mask = 0; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + int r = k_group >> 1; + __m512i vmask = _mm512_set1_epi32(k_group); + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + __m512i bytes = _mm512_loadu_si512(b_qs); + __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask)); + __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); + + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + } + // (B + 128) * A - 128 * A + vsum = _mm512_sub_epi32(vsum, 
_mm512_permutexvar_epi32(vmask, vcomp)); + + // vacc += scale * (q8 @ q4) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + KB, (const char *)wdata + 0 * row_size_A, \ + (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ + (float *) dst->data + 0 * N + nb_start, ldc) + +template ::value, int>::type = 0> +void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) { + using packed_B_t = packed_B_type; + const int TILE_SIZE = get_tile_size(); + const bool need_unpack = do_unpack::value; + + GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); + const TA * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + const int m0 = std::min(M, TILE_M); + const int m1 = std::max(M - TILE_M, 0); + const int lda = KB * sizeof(TA); + //const int ldb = KB * sizeof(TB); + + static thread_local packed_B_t Tile0[TILE_N * TILE_K]; + static thread_local packed_B_t Tile1[TILE_N * TILE_K]; + static thread_local int8_t Tile23[TILE_M * TILE_K]; + + static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; + static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; + + // double buffering C to interleave avx512 and amx + int32_t * C_cur = TileC0; + int32_t * C_pre = TileC1; + + auto Tile4 = [&](int32_t * base) { return base; }; + auto Tile5 = [&](int32_t * base) { 
return base + TILE_M * TILE_N; }; + auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; }; + auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; }; + + if (M == 2 * TILE_M) { + // i = 0 + const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE); + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + + _tile_zero(TMM4); + _tile_loadd(TMM2, A[0].qs, lda); + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM5); + _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda); + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); + + if (need_unpack) { + unpack_B(Tile1, B_blk0); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + + _tile_zero(TMM6); + _tile_dpbssd(TMM6, TMM2, TMM1); + _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM7); + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t)); + + for (int i = 1; i < KB; ++i) { + // index of previous iter + const int ii = i - 1; + const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); + GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] { + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + _tile_zero(TMM4); + _tile_loadd(TMM2, A[i].qs, lda); + acc_C::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM5); + _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda); + acc_C::apply(C + TILE_M * ldc, ldc, 
Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); + + if (need_unpack) { + unpack_B(Tile1, B_blk1); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + _tile_zero(TMM6); + acc_C::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM6, TMM2, TMM1); + _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM7); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); + + std::swap(C_cur, C_pre); + }); + } + // final accumulation + { + int ii = KB - 1; + acc_C::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + } + } else { + for (int i = 0; i < KB; ++i) { + _tile_zero(TMM4); + _tile_zero(TMM6); + if (m1 != 0) { + _tile_zero(TMM5); + _tile_zero(TMM7); + } + + const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + + if (need_unpack) { + unpack_B(Tile1, B_blk1); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + + if (m0 == TILE_M) 
{ + _tile_loadd(TMM2, A[i].qs, lda); + } else { + unpack_A(Tile23, &A[i], KB, m0); + _tile_loadd(TMM2, Tile23, TILE_K); + } + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_dpbssd(TMM6, TMM2, TMM1); + + _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); + _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); + + GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); + acc_C::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); + }); + + if (m1 != 0) { + unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1); + _tile_loadd(TMM3, Tile23, TILE_K); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); + _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); + GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); + }); + } + } + } + return; +} + +template ::value, int>::type = 0> +void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + static_assert(std::is_same::value); + const int TILE_SIZE = get_tile_size(); + + GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); + const TA * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + const int m0 = std::min(M, TILE_M); + const int m1 = std::max(M - TILE_M, 0); + //const int lda = KB * sizeof(TA); + + static thread_local int8_t Tile0[TILE_N * TILE_K]; + static thread_local int8_t Tile1[TILE_N * TILE_K]; + static thread_local int8_t Tile23[TILE_M * TILE_K]; + + // mat mul result for each group + static thread_local int32_t Tile4[TILE_M * TILE_N]; + static thread_local int32_t Tile5[TILE_M * TILE_N]; 
+ static thread_local int32_t Tile6[TILE_M * TILE_N]; + static thread_local int32_t Tile7[TILE_M * TILE_N]; + + // sum of each QK_K block, contains 8 groups, int32 + static thread_local int32_t Sumi4[TILE_M * TILE_N]; + static thread_local int32_t Sumi5[TILE_M * TILE_N]; + static thread_local int32_t Sumi6[TILE_M * TILE_N]; + static thread_local int32_t Sumi7[TILE_M * TILE_N]; + + const int k_group_size = std::is_same::value ? 16 : 32; + for (int i = 0; i < KB; ++i) { + // step 1: accumulate the quants across 8 groups, each group with 32 + for (int k = 0; k < QK_K / k_group_size; ++k) { + GGML_DISPATCH_BOOL(k > 0, is_acc, [&] { + _tile_zero(TMM4); + _tile_zero(TMM6); + + unpack_B(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + + unpack_B(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + + unpack_A(Tile23, &A[i], KB, k, m0); + _tile_loadd(TMM2, Tile23, TILE_K); + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_dpbssd(TMM6, TMM2, TMM1); + + _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + + scale_C(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0); + scale_C(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0); + + if (m1 != 0) { + _tile_zero(TMM5); + _tile_zero(TMM7); + + unpack_A(Tile23, &A[TILE_M * KB + i], KB, k, m1); + _tile_loadd(TMM3, Tile23, TILE_K); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_dpbssd(TMM7, TMM3, TMM1); + + _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + + scale_C(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1); + scale_C(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1); + } + }); + } + + // step 2: accmulate the mins + GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); + acc_C::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, 
B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); + if (m1 != 0) { + acc_C::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); + } + }); + } + return; +} + +} // anonymous namespace + +// get the packed tensor size for quantized weights +size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) { + const enum ggml_type TYPE = tensor->type; + + const int K = tensor->ne[0]; // ne0: in_features + const int N = tensor->ne[1]; // ne1: out_features + + auto get_tensor_size = [&] { + size_t row_size_B{0}; + GGML_DISPATCH_QTYPES(TYPE, [&] { + row_size_B = get_row_size(K); + }); + return N * row_size_B; + }; + + if (qtype_has_amx_kernels(TYPE)) { + return get_tensor_size(); + } else { + // for f16, bf16 we don't do packing + return ggml_nbytes(tensor); + } +} + +// pack weight to vnni format +void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + + size_t alloc_size = ggml_backend_amx_get_alloc_size(tensor); + GGML_ASSERT(alloc_size == size); + + const enum ggml_type TYPE = tensor->type; + + const int K = tensor->ne[0]; // ne0: in_features + const int N = tensor->ne[1]; // ne1: out_features + +#if defined(_OPENMP) + // the buffer ctx is not initialized when .set_tensor is called + int n_threads = omp_get_num_threads(); +#else + int n_threads = 1; +#endif + + GGML_DISPATCH_QTYPES(TYPE, [&] { + convert_B_packed_format((void *)((char *)tensor->data + offset), (const type *)data, N, K, n_threads); + }); +} + +// NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX) +// +// src0: weight in shape of {N, K}, quantized +// src1: input in shape of {M, K}, float32 +// dst: output in shape of {M, N}, float32 +// +// the function performs: dst = src1 @ src0.T +// +void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, 
struct ggml_tensor * dst) { + struct ggml_tensor * src0 = dst->src[0]; + struct ggml_tensor * src1 = dst->src[1]; + + const enum ggml_type TYPE = src0->type; + + const int n_threads = ctx->n_threads; + + // f16 only has avx512 kernels for now, + // amx kernels will be added once 6th gen xeon is released. + const bool is_floating_type = TYPE == GGML_TYPE_F16; + + const int M = dst->ne[1]; + const int N = dst->ne[0]; + const int K = src0->ne[0]; + const int ldc = dst->nb[1] / dst->nb[0]; + + if (is_floating_type) { + constexpr int BLOCK_M = 4; + constexpr int BLOCK_N = 6; + const int MB = div_up(M, BLOCK_M); + const int NB = div_up(N, BLOCK_N); + + parallel_for(n_threads, MB * NB, [&](int begin, int end) { + GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] { + for (int i = begin; i < end; ++i) { + int mb = i / NB; + int nb = i % NB; + + int mb_start = mb * BLOCK_M; + int mb_size = std::min(BLOCK_M, M - mb_start); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break; + case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break; + case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break; + case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break; + case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break; + default: fprintf(stderr, "Unexpected block size!\n"); + } + } + }); + }); + return; + } + + // pointer to work space, used convert A from float to quantized type + void * wdata = nullptr; + + //TODO: performance improvement: merge quant A + GGML_DISPATCH_QTYPES(TYPE, [&] { + const size_t row_size_A = K / blck_size 
* sizeof(vec_dot_type); + const size_t desired_wsize = M * row_size_A; + if (ctx->work_size < desired_wsize) { + ctx->work_data.reset(new char[desired_wsize]); + ctx->work_size = desired_wsize; + } + wdata = ctx->work_data.get(); + + // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size + // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size + GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); + + const float * A_data = static_cast(src1->data); + for (int m = 0; m < M; ++m) { + from_float(A_data + m * K, (char *)wdata + m * row_size_A, K); + } + }); + + if (M == 1) { + // MB = 1 and handle 8 tiles in each block + constexpr int kTilesN = 4; + constexpr int BLOCK_N = TILE_N * kTilesN; + const int NB = div_up(N, BLOCK_N); + + parallel_for(n_threads, NB, [&](int begin, int end) { + GGML_DISPATCH_QTYPES(TYPE, [&] { + const int KB = K / blck_size; + const int TILE_SIZE = get_tile_size(); + const int row_size_A = KB * sizeof(vec_dot_type); + for (int i = begin; i < end; ++i) { + int nb = i; + int nb_start = nb * BLOCK_N; + int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 + + switch (nb_size) { + //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break; + case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break; + case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break; + case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break; + case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break; + default: fprintf(stderr, "Unexpected n block size!\n"); + } + } + }); + }); + return; + } + + // handle 4 tiles at a tile + constexpr int BLOCK_M = TILE_M * 2; + constexpr int BLOCK_N = TILE_N * 2; + const int MB = div_up(M, BLOCK_M); + const int NB = div_up(N, BLOCK_N); + + parallel_for(n_threads, MB * NB, [&](int begin, int end) { + // init tile config for each thread + ggml_tile_config_init(); + + GGML_DISPATCH_QTYPES(TYPE, [&] { + const int KB = K / blck_size; + const int TILE_SIZE = get_tile_size(); + const int row_size_A = KB * sizeof(vec_dot_type); + + for (int i = begin; i < end; ++i) { + int 
mb = i / NB; + int nb = i % NB; + + int mb_start = mb * BLOCK_M; + int mb_size = std::min(BLOCK_M, M - mb_start); + int nb_start = nb * BLOCK_N; + int nb_size = BLOCK_N; + + tinygemm_kernel_amx( + mb_size, nb_size, KB, + (const char *)wdata + mb_start * row_size_A, + (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), + (float *) dst->data + mb_start * N + nb_start, ldc); + } + }); + }); +} + +#else // if defined(__AMX_INT8__) + +void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) { + fprintf(stderr, "GGML is not compiled with AMX support!\n"); + + GGML_UNUSED(ctx); + GGML_UNUSED(dst); +} + +#endif // if defined(__AMX_INT8__) diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/mmq.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/mmq.h new file mode 100644 index 0000000..cf09206 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-amx/mmq.h @@ -0,0 +1,17 @@ +#pragma once +#include "common.h" +#include + +#ifdef __cplusplus +extern "C" { +#endif + +size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor); + +void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + +void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst); + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-backend-impl.h b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-backend-impl.h new file mode 100644 index 0000000..c36c12d --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-backend-impl.h @@ -0,0 +1,255 @@ +#pragma once + +// ggml-backend internal header + +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + + #define GGML_BACKEND_API_VERSION 1 + + // + // Backend buffer type + // + + struct ggml_backend_buffer_type_i { + const char * (*get_name) (ggml_backend_buffer_type_t buft); + // allocate a buffer of this type + 
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size); + // tensor alignment + size_t (*get_alignment) (ggml_backend_buffer_type_t buft); + // (optional) max buffer size that can be allocated (defaults to SIZE_MAX) + size_t (*get_max_size) (ggml_backend_buffer_type_t buft); + // (optional) data size needed to allocate the tensor, including padding (defaults to ggml_nbytes) + size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); + // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false) + bool (*is_host) (ggml_backend_buffer_type_t buft); + }; + + struct ggml_backend_buffer_type { + struct ggml_backend_buffer_type_i iface; + ggml_backend_dev_t device; + void * context; + }; + + // + // Backend buffer + // + + struct ggml_backend_buffer_i { + // (optional) free the buffer + void (*free_buffer) (ggml_backend_buffer_t buffer); + // base address of the buffer + void * (*get_base) (ggml_backend_buffer_t buffer); + // (optional) initialize a tensor in the buffer (eg. 
add tensor extras) + enum ggml_status (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + // tensor data access + void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported) + bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); + // clear the entire buffer + void (*clear) (ggml_backend_buffer_t buffer, uint8_t value); + // (optional) reset any internal state due to tensor initialization, such as tensor extras + void (*reset) (ggml_backend_buffer_t buffer); + }; + + struct ggml_backend_buffer { + struct ggml_backend_buffer_i iface; + ggml_backend_buffer_type_t buft; + void * context; + size_t size; + enum ggml_backend_buffer_usage usage; + }; + + GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( + ggml_backend_buffer_type_t buft, + struct ggml_backend_buffer_i iface, + void * context, + size_t size); + + // do not use directly, use ggml_backend_tensor_copy instead + GGML_API bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst); + + // multi-buffer + // buffer that contains a collection of buffers + GGML_API ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers); + GGML_API bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + + // + // Backend (stream) + // + + struct 
ggml_backend_i { + const char * (*get_name)(ggml_backend_t backend); + + void (*free)(ggml_backend_t backend); + + // (optional) asynchronous tensor data access + void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); + + // (optional) complete all pending operations (required if the backend supports async operations) + void (*synchronize)(ggml_backend_t backend); + + // (optional) graph plans (not used currently) + // compute graph with a plan + ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph); + void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology + void (*graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph); + // compute the graph with the plan + enum ggml_status (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); + + // compute graph (always async if supported by the backend) + enum ggml_status (*graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph); + + // (optional) event synchronization + // record an event on this stream + void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); + // wait for an event on on a different stream + void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + }; + + struct ggml_backend { + ggml_guid_t guid; + struct ggml_backend_i iface; + ggml_backend_dev_t device; + void * context; + }; + + struct ggml_backend_event { + struct ggml_backend_device * device; 
+ void * context; + }; + + // + // Backend device + // + + // Note: if additional properties are needed, we should add a struct with all of them + // the current functions to obtain the properties can remain, since they are more convenient for often used properties + struct ggml_backend_device_i { + // device name: short identifier for this device, such as "CPU" or "CUDA0" + const char * (*get_name)(ggml_backend_dev_t dev); + + // device description: short informative description of the device, could be the model name + const char * (*get_description)(ggml_backend_dev_t dev); + + // device memory in bytes + void (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total); + + // device type + enum ggml_backend_dev_type (*get_type)(ggml_backend_dev_t dev); + + // device properties + void (*get_props)(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props); + + // backend (stream) initialization + ggml_backend_t (*init_backend)(ggml_backend_dev_t dev, const char * params); + + // preferred buffer type + ggml_backend_buffer_type_t (*get_buffer_type)(ggml_backend_dev_t dev); + + // (optional) host buffer type (in system memory, typically this is a pinned memory buffer for faster transfers between host and device) + ggml_backend_buffer_type_t (*get_host_buffer_type)(ggml_backend_dev_t dev); + + // (optional) buffer from pointer: create a buffer from a host pointer (useful for memory mapped models and importing data from other libraries) + ggml_backend_buffer_t (*buffer_from_host_ptr)(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size); + + // check if the backend can compute an operation + bool (*supports_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op); + + // check if the backend can use tensors allocated in a buffer type + bool (*supports_buft)(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft); + + // (optional) check if the backend wants to run an operation, even if the weights are allocated in an incompatible 
buffer + // these should be expensive operations that may benefit from running on this backend instead of the CPU backend + bool (*offload_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op); + + // (optional) event synchronization + ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev); + void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event); + void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event); + }; + + struct ggml_backend_device { + struct ggml_backend_device_i iface; + ggml_backend_reg_t reg; + void * context; + }; + + // + // Backend (reg) + // + + struct ggml_backend_reg_i { + const char * (*get_name)(ggml_backend_reg_t reg); + + // enumerate available devices + size_t (*get_device_count)(ggml_backend_reg_t reg); + ggml_backend_dev_t (*get_device)(ggml_backend_reg_t reg, size_t index); + + // (optional) get a pointer to a function in the backend + // backends can add custom functions that are not part of the standard ggml-backend interface + void * (*get_proc_address)(ggml_backend_reg_t reg, const char * name); + }; + + struct ggml_backend_reg { + int api_version; // initialize to GGML_BACKEND_API_VERSION + struct ggml_backend_reg_i iface; + void * context; + }; + + // Internal backend registry API + GGML_API void ggml_backend_register(ggml_backend_reg_t reg); + + // Add backend dynamic loading support to the backend + + // Initialize the backend + typedef ggml_backend_reg_t (*ggml_backend_init_t)(void); + // Optional: obtain a score for the backend based on the system configuration + // Higher scores are preferred, 0 means the backend is not supported in the current system + typedef int (*ggml_backend_score_t)(void); + +#ifdef GGML_BACKEND_DL +# ifdef __cplusplus +# define GGML_BACKEND_DL_IMPL(reg_fn) \ + extern "C" { \ + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void); \ + } \ + ggml_backend_reg_t ggml_backend_init(void) { \ + return reg_fn(); \ + } +# define 
GGML_BACKEND_DL_SCORE_IMPL(score_fn) \ + extern "C" { \ + GGML_BACKEND_API int ggml_backend_score(void); \ + } \ + int ggml_backend_score(void) { \ + return score_fn(); \ + } +# else +# define GGML_BACKEND_DL_IMPL(reg_fn) \ + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void); \ + ggml_backend_reg_t ggml_backend_init(void) { \ + return reg_fn(); \ + } +# define GGML_BACKEND_DL_SCORE_IMPL(score_fn) \ + GGML_BACKEND_API int ggml_backend_score(void); \ + int ggml_backend_score(void) { \ + return score_fn(); \ + } +# endif +#else +# define GGML_BACKEND_DL_IMPL(reg_fn) +# define GGML_BACKEND_DL_SCORE_IMPL(score_fn) +#endif + +#ifdef __cplusplus +} +#endif diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-backend-reg.cpp b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-backend-reg.cpp new file mode 100644 index 0000000..5f02a71 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-backend-reg.cpp @@ -0,0 +1,600 @@ +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +#elif defined(__APPLE__) +# include +# include +#else +# include +# include +#endif + +// Backend registry +#ifdef GGML_USE_CPU +#include "ggml-cpu.h" +#endif + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#ifdef GGML_USE_SYCL +#include "ggml-sycl.h" +#endif + +#ifdef GGML_USE_VULKAN +#include "ggml-vulkan.h" +#endif + +#ifdef GGML_USE_WEBGPU +#include "ggml-webgpu.h" +#endif + +#ifdef GGML_USE_ZDNN +#include "ggml-zdnn.h" +#endif + +#ifdef GGML_USE_OPENCL +#include "ggml-opencl.h" +#endif + +#ifdef GGML_USE_BLAS +#include "ggml-blas.h" +#endif + +#ifdef GGML_USE_RPC +#include "ggml-rpc.h" +#endif + +#ifdef GGML_USE_CANN +#include "ggml-cann.h" +#endif + +// disable C++17 deprecation warning 
for std::codecvt_utf8 +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wdeprecated-declarations" +#elif defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif + +namespace fs = std::filesystem; + +static std::string path_str(const fs::path & path) { + std::string u8path; + try { +#if defined(__cpp_lib_char8_t) + // C++20 and later: u8string() returns std::u8string + std::u8string u8str = path.u8string(); + u8path = std::string(reinterpret_cast(u8str.c_str())); +#else + // C++17: u8string() returns std::string + u8path = path.u8string(); +#endif + } catch (...) { + } + return u8path; +} + +#if defined(__clang__) +# pragma clang diagnostic pop +#elif defined(__GNUC__) +# pragma GCC diagnostic pop +#endif + +#ifdef _WIN32 + +using dl_handle = std::remove_pointer_t; + +struct dl_handle_deleter { + void operator()(HMODULE handle) { + FreeLibrary(handle); + } +}; + +static dl_handle * dl_load_library(const fs::path & path) { + // suppress error dialogs for missing DLLs + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + HMODULE handle = LoadLibraryW(path.wstring().c_str()); + + SetErrorMode(old_mode); + + return handle; +} + +static void * dl_get_sym(dl_handle * handle, const char * name) { + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + void * p = (void *) GetProcAddress(handle, name); + + SetErrorMode(old_mode); + + return p; +} + +#else + +using dl_handle = void; + +struct dl_handle_deleter { + void operator()(void * handle) { + dlclose(handle); + } +}; + +static void * dl_load_library(const fs::path & path) { + dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); + + return handle; +} + +static void * dl_get_sym(dl_handle * handle, const char * name) { + return dlsym(handle, name); +} + +#endif + +using dl_handle_ptr = 
std::unique_ptr; + +struct ggml_backend_reg_entry { + ggml_backend_reg_t reg; + dl_handle_ptr handle; +}; + +struct ggml_backend_registry { + std::vector backends; + std::vector devices; + + ggml_backend_registry() { +#ifdef GGML_USE_CUDA + register_backend(ggml_backend_cuda_reg()); +#endif +#ifdef GGML_USE_METAL + register_backend(ggml_backend_metal_reg()); +#endif +#ifdef GGML_USE_SYCL + register_backend(ggml_backend_sycl_reg()); +#endif +#ifdef GGML_USE_VULKAN + register_backend(ggml_backend_vk_reg()); +#endif +#ifdef GGML_USE_WEBGPU + register_backend(ggml_backend_webgpu_reg()); +#endif +#ifdef GGML_USE_ZDNN + register_backend(ggml_backend_zdnn_reg()); +#endif +#ifdef GGML_USE_OPENCL + register_backend(ggml_backend_opencl_reg()); +#endif +#ifdef GGML_USE_CANN + register_backend(ggml_backend_cann_reg()); +#endif +#ifdef GGML_USE_BLAS + register_backend(ggml_backend_blas_reg()); +#endif +#ifdef GGML_USE_RPC + register_backend(ggml_backend_rpc_reg()); +#endif +#ifdef GGML_USE_CPU + register_backend(ggml_backend_cpu_reg()); +#endif + } + + ~ggml_backend_registry() { + // FIXME: backends cannot be safely unloaded without a function to destroy all the backend resources, + // since backend threads may still be running and accessing resources from the dynamic library + for (auto & entry : backends) { + if (entry.handle) { + entry.handle.release(); // NOLINT + } + } + } + + void register_backend(ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) { + if (!reg) { + return; + } + +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", + __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg)); +#endif + backends.push_back({ reg, std::move(handle) }); + for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { + register_device(ggml_backend_reg_dev_get(reg, i)); + } + } + + void register_device(ggml_backend_dev_t device) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, 
ggml_backend_dev_name(device), ggml_backend_dev_description(device)); +#endif + devices.push_back(device); + } + + ggml_backend_reg_t load_backend(const fs::path & path, bool silent) { + dl_handle_ptr handle { dl_load_library(path) }; + if (!handle) { + if (!silent) { + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str()); + } + return nullptr; + } + + auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); + if (score_fn && score_fn() == 0) { + if (!silent) { + GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_str(path).c_str()); + } + return nullptr; + } + + auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init"); + if (!backend_init_fn) { + if (!silent) { + GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_str(path).c_str()); + } + return nullptr; + } + + ggml_backend_reg_t reg = backend_init_fn(); + if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) { + if (!silent) { + if (!reg) { + GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", + __func__, path_str(path).c_str()); + } else { + GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n", + __func__, path_str(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION); + } + } + return nullptr; + } + + GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str()); + + register_backend(reg, std::move(handle)); + + return reg; + } + + void unload_backend(ggml_backend_reg_t reg, bool silent) { + auto it = std::find_if(backends.begin(), backends.end(), + [reg](const ggml_backend_reg_entry & entry) { return entry.reg == reg; }); + + if (it == backends.end()) { + if (!silent) { + GGML_LOG_ERROR("%s: backend not found\n", __func__); + } + return; + } + + if (!silent) { + GGML_LOG_DEBUG("%s: unloading %s backend\n", 
__func__, ggml_backend_reg_name(reg)); + } + + // remove devices + devices.erase( + std::remove_if(devices.begin(), devices.end(), + [reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }), + devices.end()); + + // remove backend + backends.erase(it); + } +}; + +static ggml_backend_registry & get_reg() { + static ggml_backend_registry reg; + return reg; +} + +// Internal API +void ggml_backend_register(ggml_backend_reg_t reg) { + get_reg().register_backend(reg); +} + +void ggml_backend_device_register(ggml_backend_dev_t device) { + get_reg().register_device(device); +} + +// Backend (reg) enumeration +static bool striequals(const char * a, const char * b) { + for (; *a && *b; a++, b++) { + if (std::tolower(*a) != std::tolower(*b)) { + return false; + } + } + return *a == *b; +} + +size_t ggml_backend_reg_count() { + return get_reg().backends.size(); +} + +ggml_backend_reg_t ggml_backend_reg_get(size_t index) { + GGML_ASSERT(index < ggml_backend_reg_count()); + return get_reg().backends[index].reg; +} + +ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) { + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + ggml_backend_reg_t reg = ggml_backend_reg_get(i); + if (striequals(ggml_backend_reg_name(reg), name)) { + return reg; + } + } + return nullptr; +} + +// Device enumeration +size_t ggml_backend_dev_count() { + return get_reg().devices.size(); +} + +ggml_backend_dev_t ggml_backend_dev_get(size_t index) { + GGML_ASSERT(index < ggml_backend_dev_count()); + return get_reg().devices[index]; +} + +ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) { + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (striequals(ggml_backend_dev_name(dev), name)) { + return dev; + } + } + return nullptr; +} + +ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) { + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t 
dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == type) { + return dev; + } + } + return nullptr; +} + +// Convenience functions +ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) { + ggml_backend_dev_t dev = ggml_backend_dev_by_name(name); + if (!dev) { + return nullptr; + } + return ggml_backend_dev_init(dev, params); +} + +ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) { + ggml_backend_dev_t dev = ggml_backend_dev_by_type(type); + if (!dev) { + return nullptr; + } + return ggml_backend_dev_init(dev, params); +} + +ggml_backend_t ggml_backend_init_best(void) { + ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); + if (!dev) { + dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + } + if (!dev) { + return nullptr; + } + return ggml_backend_dev_init(dev, nullptr); +} + +// Dynamic loading +ggml_backend_reg_t ggml_backend_load(const char * path) { + return get_reg().load_backend(path, false); +} + +void ggml_backend_unload(ggml_backend_reg_t reg) { + get_reg().unload_backend(reg, true); +} + +static fs::path get_executable_path() { +#if defined(__APPLE__) + // get executable path + std::vector path; + uint32_t size; + while (true) { + size = path.size(); + if (_NSGetExecutablePath(path.data(), &size) == 0) { + break; + } + path.resize(size); + } + std::string base_path(path.data(), size); + // remove executable name + auto last_slash = base_path.find_last_of('/'); + if (last_slash != std::string::npos) { + base_path = base_path.substr(0, last_slash); + } + return base_path + "/"; +#elif defined(__linux__) || defined(__FreeBSD__) + std::string base_path = "."; + std::vector path(1024); + while (true) { + // get executable path +# if defined(__linux__) + ssize_t len = readlink("/proc/self/exe", path.data(), path.size()); +# elif defined(__FreeBSD__) + ssize_t len = readlink("/proc/curproc/file", path.data(), path.size()); +# endif + 
if (len == -1) { + break; + } + if (len < (ssize_t) path.size()) { + base_path = std::string(path.data(), len); + // remove executable name + auto last_slash = base_path.find_last_of('/'); + if (last_slash != std::string::npos) { + base_path = base_path.substr(0, last_slash); + } + break; + } + path.resize(path.size() * 2); + } + + return base_path + "/"; +#elif defined(_WIN32) + std::vector path(MAX_PATH); + DWORD len = GetModuleFileNameW(NULL, path.data(), path.size()); + if (len == 0) { + return {}; + } + std::wstring base_path(path.data(), len); + // remove executable name + auto last_slash = base_path.find_last_of('\\'); + if (last_slash != std::string::npos) { + base_path = base_path.substr(0, last_slash); + } + return base_path + L"\\"; +#else + return {}; +#endif +} + +static fs::path backend_filename_prefix() { +#ifdef _WIN32 + return fs::u8path("ggml-"); +#else + return fs::u8path("libggml-"); +#endif +} + +static fs::path backend_filename_extension() { +#ifdef _WIN32 + return fs::u8path(".dll"); +#else + return fs::u8path(".so"); +#endif +} + +static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) { + // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths + const fs::path name_path = fs::u8path(name); + const fs::path file_prefix = backend_filename_prefix().native() + name_path.native() + fs::u8path("-").native(); + const fs::path file_extension = backend_filename_extension(); + + std::vector search_paths; + if (user_search_path == nullptr) { +#ifdef GGML_BACKEND_DIR + search_paths.push_back(fs::u8path(GGML_BACKEND_DIR)); +#endif + // default search paths: executable directory, current directory + search_paths.push_back(get_executable_path()); + search_paths.push_back(fs::current_path()); + } else { + search_paths.push_back(fs::u8path(user_search_path)); + } + + int best_score = 0; + fs::path best_path; + + for (const auto & search_path : search_paths) { + if 
(!fs::exists(search_path)) { + GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str()); + continue; + } + fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); + for (const auto & entry : dir_it) { + if (entry.is_regular_file()) { + auto filename = entry.path().filename(); + auto ext = entry.path().extension(); + if (filename.native().find(file_prefix) == 0 && ext == file_extension) { + dl_handle_ptr handle { dl_load_library(entry) }; + if (!handle && !silent) { + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str()); + } + if (handle) { + auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); + if (score_fn) { + int s = score_fn(); +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_str(entry.path()).c_str(), s); +#endif + if (s > best_score) { + best_score = s; + best_path = entry.path(); + } + } else { + if (!silent) { + GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, path_str(entry.path()).c_str()); + } + } + } + } + } + } + } + + if (best_score == 0) { + // try to load the base backend + for (const auto & search_path : search_paths) { + fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native(); + fs::path path = search_path / filename; + if (fs::exists(path)) { + return get_reg().load_backend(path, silent); + } + } + return nullptr; + } + + return get_reg().load_backend(best_path, silent); +} + +void ggml_backend_load_all() { + ggml_backend_load_all_from_path(nullptr); +} + +void ggml_backend_load_all_from_path(const char * dir_path) { +#ifdef NDEBUG + bool silent = true; +#else + bool silent = false; +#endif + + ggml_backend_load_best("blas", silent, dir_path); + ggml_backend_load_best("cann", silent, dir_path); + ggml_backend_load_best("cuda", silent, dir_path); + ggml_backend_load_best("hip", silent, dir_path); + 
ggml_backend_load_best("metal", silent, dir_path); + ggml_backend_load_best("rpc", silent, dir_path); + ggml_backend_load_best("sycl", silent, dir_path); + ggml_backend_load_best("vulkan", silent, dir_path); + ggml_backend_load_best("opencl", silent, dir_path); + ggml_backend_load_best("musa", silent, dir_path); + ggml_backend_load_best("cpu", silent, dir_path); + // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend + const char * backend_path = std::getenv("GGML_BACKEND_PATH"); + if (backend_path) { + ggml_backend_load(backend_path); + } +} diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-backend.cpp b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-backend.cpp new file mode 100644 index 0000000..1b9d29e --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-backend.cpp @@ -0,0 +1,2027 @@ +// Note: porting this file to C++ is a work in progress + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __APPLE__ +#include +#include +#endif + + +// backend buffer type + +const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name(buft); +} + +ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + if (size == 0) { + // return a dummy buffer for zero-sized allocations + return ggml_backend_buffer_init(buft, {}, NULL, 0); + } + + return buft->iface.alloc_buffer(buft, size); +} + +size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) { + return buft->iface.get_alignment(buft); +} + +size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { + // get_max_size is optional, defaults to SIZE_MAX + if (buft->iface.get_max_size) { + return 
buft->iface.get_max_size(buft); + } + return SIZE_MAX; +} + +size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + // get_alloc_size is optional, defaults to ggml_nbytes + if (buft->iface.get_alloc_size) { + size_t size = buft->iface.get_alloc_size(buft, tensor); + assert(size >= ggml_nbytes(tensor)); + return size; + } + return ggml_nbytes(tensor); +} + +bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { + if (buft->iface.is_host) { + return buft->iface.is_host(buft); + } + return false; +} + +ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) { + return buft->device; +} + +// backend buffer + +ggml_backend_buffer_t ggml_backend_buffer_init( + ggml_backend_buffer_type_t buft, + struct ggml_backend_buffer_i iface, + void * context, + size_t size) { + ggml_backend_buffer_t buffer = new ggml_backend_buffer { + /* .interface = */ iface, + /* .buft = */ buft, + /* .context = */ context, + /* .size = */ size, + /* .usage = */ GGML_BACKEND_BUFFER_USAGE_ANY + }; + + return buffer; +} + +const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) { + return ggml_backend_buft_name(ggml_backend_buffer_get_type(buffer)); +} + +void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { + if (buffer == NULL) { + return; + } + + if (buffer->iface.free_buffer != NULL) { + buffer->iface.free_buffer(buffer); + } + delete buffer; +} + +size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { + return buffer->size; +} + +void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { + // get_base is optional if the buffer is zero-sized + if (buffer->size == 0) { + return NULL; + } + + void * base = buffer->iface.get_base(buffer); + + GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); + + return base; +} + +enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + // init_tensor is 
optional + if (buffer->iface.init_tensor) { + return buffer->iface.init_tensor(buffer, tensor); + } + return GGML_STATUS_SUCCESS; +} + +void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + // clear is optional if the buffer is zero-sized + if (buffer->size == 0) { + return; + } + + buffer->iface.clear(buffer, value); +} + +size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) { + return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer)); +} + +size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) { + return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer)); +} + +size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor) { + return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor); +} + +bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) { + return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer)); +} + +void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + buffer->usage = usage; + + // FIXME: add a generic callback to the buffer interface + if (ggml_backend_buffer_is_multi_buffer(buffer)) { + ggml_backend_multi_buffer_set_usage(buffer, usage); + } +} + +enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) { + return buffer->usage; +} + +ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) { + return buffer->buft; +} + +void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) { + if (buffer->iface.reset) { + buffer->iface.reset(buffer); + } +} + +bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_buffer_t dst_buf = dst->view_src ? 
dst->view_src->buffer : dst->buffer; + if (dst_buf->iface.cpy_tensor) { + return dst_buf->iface.cpy_tensor(dst_buf, src, dst); + } + return false; +} + +// backend + +ggml_guid_t ggml_backend_guid(ggml_backend_t backend) { + if (backend == NULL) { + return NULL; + } + return backend->guid; +} + +const char * ggml_backend_name(ggml_backend_t backend) { + if (backend == NULL) { + return "NULL"; + } + return backend->iface.get_name(backend); +} + +void ggml_backend_free(ggml_backend_t backend) { + if (backend == NULL) { + return; + } + + backend->iface.free(backend); +} + +ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) { + return ggml_backend_dev_buffer_type(backend->device); +} + +ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) { + return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size); +} + +size_t ggml_backend_get_alignment(ggml_backend_t backend) { + return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend)); +} + +size_t ggml_backend_get_max_size(ggml_backend_t backend) { + return ggml_backend_buft_get_max_size(ggml_backend_get_default_buffer_type(backend)); +} + +void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + if (backend->iface.set_tensor_async == NULL) { + ggml_backend_tensor_set(tensor, data, offset, size); + } else { + backend->iface.set_tensor_async(backend, tensor, data, offset, size); + } +} + +void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + if 
(backend->iface.get_tensor_async == NULL) { + ggml_backend_tensor_get(tensor, data, offset, size); + } else { + backend->iface.get_tensor_async(backend, tensor, data, offset, size); + } +} + +void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + if (size == 0) { + return; + } + + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + buf->iface.set_tensor(buf, tensor, data, offset, size); +} + +void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + if (size == 0) { + return; + } + + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + buf->iface.get_tensor(buf, tensor, data, offset, size); +} + +void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_buffer_t buf = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; + + if (size == 0) { + return; + } + + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not implemented by backend buffer"); + + buf->iface.memset_tensor(buf, tensor, value, offset, size); +} + +void ggml_backend_synchronize(ggml_backend_t backend) { + if (backend->iface.synchronize == NULL) { + return; + } + + backend->iface.synchronize(backend); +} + +ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend->iface.graph_plan_create != NULL); + + return backend->iface.graph_plan_create(backend, cgraph); +} + +void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend->iface.graph_plan_free != NULL); + + backend->iface.graph_plan_free(backend, plan); +} + +enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend->iface.graph_plan_compute != NULL); + + return backend->iface.graph_plan_compute(backend, plan); +} + +enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph); + ggml_backend_synchronize(backend); + return err; +} + +enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + return backend->iface.graph_compute(backend, cgraph); +} + +bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + return ggml_backend_dev_supports_op(backend->device, op); +} + +bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { + return ggml_backend_dev_supports_buft(backend->device, buft); +} + +bool 
ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) { + return ggml_backend_dev_offload_op(backend->device, op); +} + +ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { + return backend->device; +} + +// backend copy + +void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); + + if (src == dst) { + return; + } + + if (ggml_backend_buffer_is_host(src->buffer)) { + ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src)); + } else if (ggml_backend_buffer_is_host(dst->buffer)) { + ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); + } else if (!ggml_backend_buffer_copy_tensor(src, dst)) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); +#endif + size_t nbytes = ggml_nbytes(src); + void * data = malloc(nbytes); + ggml_backend_tensor_get(src, data, 0, nbytes); + ggml_backend_tensor_set(dst, data, 0, nbytes); + free(data); + } +} + +void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); + + if (src == dst) { + return; + } + + if (backend_dst->iface.cpy_tensor_async != NULL) { + if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) { + return; + } + } + + // an async copy would normally happen after all the queued operations on both backends are completed + // to simulate the same behavior, we need to synchronize both backends first, and do a blocking copy + ggml_backend_synchronize(backend_src); + ggml_backend_synchronize(backend_dst); + ggml_backend_tensor_copy(src, dst); +} + +// events + +ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device) { + // 
null device is allowed for the transition period to the device interface + if (device == NULL || device->iface.event_new == NULL) { + return NULL; + } + return device->iface.event_new(device); +} + +void ggml_backend_event_free(ggml_backend_event_t event) { + if (event == NULL) { + return; + } + event->device->iface.event_free(event->device, event); +} + +void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) { + GGML_ASSERT(backend->iface.event_record != NULL); + + backend->iface.event_record(backend, event); +} + +void ggml_backend_event_synchronize(ggml_backend_event_t event) { + GGML_ASSERT(event->device->iface.event_synchronize); + + event->device->iface.event_synchronize(event->device, event); +} + +void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_ASSERT(backend->iface.event_wait != NULL); + + backend->iface.event_wait(backend, event); +} + +// Backend device + +const char * ggml_backend_dev_name(ggml_backend_dev_t device) { + return device->iface.get_name(device); +} + +const char * ggml_backend_dev_description(ggml_backend_dev_t device) { + return device->iface.get_description(device); +} + +void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + device->iface.get_memory(device, free, total); +} + +enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { + return device->iface.get_type(device); +} + +void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) { + memset(props, 0, sizeof(*props)); + device->iface.get_props(device, props); +} + +ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) { + return device->reg; +} + +ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) { + return device->iface.init_backend(device, params); +} + +ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) { + return 
device->iface.get_buffer_type(device); +} + +ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) { + if (device->iface.get_host_buffer_type == NULL) { + return NULL; + } + + return device->iface.get_host_buffer_type(device); +} + +ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) { + return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size); +} + +bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { + return device->iface.supports_op(device, op); +} + +bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) { + return device->iface.supports_buft(device, buft); +} + +bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { + if (device->iface.offload_op != NULL) { + return device->iface.offload_op(device, op); + } + + return false; +} + +// Backend (reg) + +const char * ggml_backend_reg_name(ggml_backend_reg_t reg) { + return reg->iface.get_name(reg); +} + +size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) { + return reg->iface.get_device_count(reg); +} + +ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) { + return reg->iface.get_device(reg, index); +} + +void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (!reg->iface.get_proc_address) { + return NULL; + } + return reg->iface.get_proc_address(reg, name); +} + +// multi-buffer buffer + +struct ggml_backend_multi_buffer_context { + ggml_backend_buffer_t * buffers; + size_t n_buffers; +}; + +static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; + for (size_t i = 0; i < ctx->n_buffers; i++) { + ggml_backend_buffer_free(ctx->buffers[i]); + } + + 
free(ctx->buffers); + free(ctx); +} + +static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; + for (size_t i = 0; i < ctx->n_buffers; i++) { + ggml_backend_buffer_clear(ctx->buffers[i], value); + } +} + +static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = { + /* .free_buffer = */ ggml_backend_multi_buffer_free_buffer, + /* .get_base = */ NULL, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ NULL, + /* .get_tensor = */ NULL, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_multi_buffer_clear, + /* .reset = */ NULL, +}; + +ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) { + ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) malloc(sizeof(struct ggml_backend_multi_buffer_context)); + ctx->n_buffers = n_buffers; + ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t)); + + GGML_ASSERT(ctx->buffers != NULL); + + size_t total_size = 0; + for (size_t i = 0; i < n_buffers; i++) { + ctx->buffers[i] = buffers[i]; + total_size += ggml_backend_buffer_get_size(buffers[i]); + } + + return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_i, ctx, total_size); +} + +bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer; +} + +void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer)); + ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; + for (size_t i = 0; i < ctx->n_buffers; i++) { + ggml_backend_buffer_set_usage(ctx->buffers[i], usage); + } +} + +// creates a copy of the tensor with the 
same memory layout +static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) { + struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor); + for (int i = 0; i < GGML_MAX_DIMS; i++) { + dup->nb[i] = tensor->nb[i]; + } + return dup; +} + +static bool ggml_is_view_op(enum ggml_op op) { + return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE; +} + +// scheduler + +#ifndef GGML_SCHED_MAX_BACKENDS +#define GGML_SCHED_MAX_BACKENDS 16 +#endif + +#ifndef GGML_SCHED_MAX_SPLIT_INPUTS +#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC +#endif + +#ifndef GGML_SCHED_MAX_COPIES +#define GGML_SCHED_MAX_COPIES 4 +#endif + +struct ggml_backend_sched_split { + int backend_id; + int i_start; + int i_end; + struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS]; + int n_inputs; + // graph view of this split + struct ggml_cgraph graph; +}; + +struct ggml_backend_sched { + bool is_reset; // true if the scheduler has been reset since the last graph split + bool is_alloc; + + int n_backends; + + ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS]; + ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS]; + ggml_gallocr_t galloc; + + // hash map of the nodes in the graph + struct ggml_hash_set hash_set; + int * hv_tensor_backend_ids; // [hash_set.size] + struct ggml_tensor ** hv_tensor_copies; // [hash_set.size][n_backends][n_copies] + + int * node_backend_ids; // [graph_size] + int * leaf_backend_ids; // [graph_size] + + int * prev_node_backend_ids; // [graph_size] + int * prev_leaf_backend_ids; // [graph_size] + + // copy of the graph with modified inputs + struct ggml_cgraph graph; + + // graph splits + struct ggml_backend_sched_split * splits; + int n_splits; + int splits_capacity; + + // pipeline parallelism support + int n_copies; + int cur_copy; + int next_copy; + ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES]; + struct ggml_tensor * 
graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS]; + int n_graph_inputs; + + struct ggml_context * ctx; + + ggml_backend_sched_eval_callback callback_eval; + void * callback_eval_user_data; + + char * context_buffer; + size_t context_buffer_size; + + bool op_offload; + + int debug; +}; + +#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) +#define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)] +#define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)] +#define tensor_copy(tensor, backend_id, copy_id) tensor_id_copy(hash_id(tensor), backend_id, copy_id) + +// returns the priority of the backend, lower id is higher priority +static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) { + for (int i = 0; i < sched->n_backends; i++) { + if (sched->backends[i] == backend) { + return i; + } + } + return -1; +} + +static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) { + ggml_backend_buffer_t buffer = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; + if (buffer == NULL) { + return -1; + } + + // find highest prio backend that supports the buffer type and the op + for (int i = 0; i < sched->n_backends; i++) { + if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) && + ggml_backend_supports_op(sched->backends[i], op)) { + return i; + } + } + +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n", + __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name); +#endif + + return -1; +} + +#if 0 +#define GGML_SCHED_MAX_SPLITS_DEBUG 4096 +static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only +#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__) +#define GET_CAUSE(node) causes[hash_id(node)] +#else +#define SET_CAUSE(node, ...) +#define GET_CAUSE(node) "" +#endif + +// returns the backend that should be used for the node based on the current locations +static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) { + // assign pre-allocated nodes to their backend + int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor); + if (cur_backend_id != -1) { + SET_CAUSE(tensor, "1.dst"); + return cur_backend_id; + } + + // view_src + if (tensor->view_src != NULL) { + cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor); + if (cur_backend_id != -1) { + SET_CAUSE(tensor, "1.vsrc"); + return cur_backend_id; + } + } + + if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) { + // since the tensor is pre-allocated, it cannot be moved to another backend + ggml_backend_buffer_t buffer = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; + GGML_ABORT("pre-allocated tensor (%s) in a buffer (%s) that cannot run the operation (%s)", tensor->name, ggml_backend_buffer_name(buffer), ggml_op_name(tensor->op)); + } + + // graph input + if (tensor->flags & GGML_TENSOR_FLAG_INPUT) { + cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU) + SET_CAUSE(tensor, "1.inp"); + return cur_backend_id; + } + + // operations with weights are preferably run on the same backend as the weights + for (int i = 0; i < GGML_MAX_SRC; i++) { + const struct ggml_tensor * src = tensor->src[i]; + if (src == NULL) { + continue; + } + // skip ROPE since the rope freqs tensor is too small to choose a backend based on it + // not an ideal solution + if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); + // check if a backend with higher prio wants to offload the op + if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { + for (int b = 0; b < src_backend_id; b++) { + if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { + SET_CAUSE(tensor, "1.off"); + return b; + } + } + } + SET_CAUSE(tensor, "1.wgt%d", i); + return src_backend_id; + } + } + + return -1; +} + +static char * fmt_size(size_t size) { + static char buffer[128]; + if (size >= 1024*1024) { + snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024); + } else { + snprintf(buffer, sizeof(buffer), "%zuK", size/1024); + } + return buffer; +} + +static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + int cur_split = 0; + for (int i = 0; i < graph->n_nodes; i++) { + if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) { + ggml_backend_t split_backend = 
sched->backends[sched->splits[cur_split].backend_id]; + GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs", cur_split, ggml_backend_name(split_backend), + sched->splits[cur_split].n_inputs); + for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) { + if (j == 0) { + GGML_LOG_DEBUG(": "); + } + GGML_LOG_DEBUG("[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, + fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j]))); + } + GGML_LOG_DEBUG("\n"); + cur_split++; + } + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + if (sched->debug > 1) { + ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name, + fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), + graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]); + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src); + GGML_LOG_DEBUG(" %20.20s (%5.5s) [%5.5s %8.8s]", src->name, + fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src)); + } + GGML_LOG_DEBUG("\n"); + } + } +} + +static bool ggml_backend_sched_buffer_supported(ggml_backend_sched_t sched, struct ggml_tensor * t, int backend_id) { + ggml_backend_buffer_t buf = t->view_src ? 
t->view_src->buffer : t->buffer; + ggml_backend_buffer_type_t buft = NULL; + + if (buf) { + // the tensor is already allocated + buft = buf->buft; + } else { + // see if the tensor already has a backend assigned, and use the buffer type of that backend + int tensor_backend_id = tensor_backend_id(t); + if (tensor_backend_id == -1 && t->view_src) { + tensor_backend_id = tensor_backend_id(t->view_src); + } + if (tensor_backend_id != -1) { + buft = sched->bufts[tensor_backend_id]; + } + } + + return buft != NULL && ggml_backend_supports_buft(sched->backends[backend_id], buft); +} + +static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) { + if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.sup"); + } +} + +// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend +static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + // reset splits + sched->n_splits = 0; + sched->n_graph_inputs = 0; + sched->is_reset = false; + + struct ggml_init_params params = { + /* .mem_size = */ sched->context_buffer_size, + /* .mem_buffer = */ sched->context_buffer, + /* .no_alloc = */ true + }; + + ggml_free(sched->ctx); + + sched->ctx = ggml_init(params); + if (sched->ctx == NULL) { + GGML_ABORT("%s: failed to initialize context\n", __func__); + } + + // pass 1: assign backends to ops with pre-allocated inputs + for (int i = 0; i < graph->n_leafs; i++) { + struct ggml_tensor * leaf = graph->leafs[i]; + int * leaf_backend_id = &tensor_backend_id(leaf); + // do not overwrite user assignments + if (*leaf_backend_id == -1) { + *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf); + } + } + + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + int * node_backend_id = 
&tensor_backend_id(node); + // do not overwrite user assignments + if (*node_backend_id == -1) { + *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node); + +#if 0 + // src + if (node->op == GGML_OP_NONE) { + continue; + } + + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + int * src_backend_id = &tensor_backend_id(src); + if (*src_backend_id == -1) { + *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src); + } + } +#endif + } + } + + // pass 2: expand current backend assignments + // assign the same backend to adjacent nodes + // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend) + // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops + // ops unsupported by the backend being expanded will be left unassigned so that they can be assigned later when the locations of its inputs are known + // expand gpu down + { + int cur_backend_id = -1; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + if (*node_backend_id == sched->n_backends - 1) { + // skip cpu (lowest prio backend) + cur_backend_id = -1; + } else { + cur_backend_id = *node_backend_id; + } + } else if (cur_backend_id != -1) { + ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); + } + } + } + // expand gpu up + { + int cur_backend_id = -1; + for (int i = graph->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + if (*node_backend_id == sched->n_backends - 1) { + // skip cpu (lowest prio backend) + cur_backend_id = -1; + } else { + cur_backend_id = 
*node_backend_id; + } + } else if (cur_backend_id != -1) { + ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); + } + } + } + // expand rest down + { + int cur_backend_id = -1; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + cur_backend_id = *node_backend_id; + } else if (cur_backend_id != -1) { + ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); + } + } + } + // expand rest up + { + int cur_backend_id = -1; + for (int i = graph->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + cur_backend_id = *node_backend_id; + } else if (cur_backend_id != -1) { + ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); + } + } + } + + // pass 3: upgrade nodes to higher prio backends with compatible buffer types + // if the tensor is already in the same buffer type (*) as another higher priority backend, we should move it there + // however, we also need to verify that the sources are in compatible buffer types + // (*) the actual requirement is more relaxed, the buffer type of the backend should be supported by all the users of this tensor further down the graph + // however, this is slow to verify, so we have a more strict requirement that the buffer type is the same + // this is not uncommon since multiple backends can use host memory, with the same buffer type (eg. 
BLAS and CPU) + // additionally, set remaining unassigned nodes to the backend with the most supported inputs + // only nodes that could not be assigned during expansion due to the backend not supporting the op should be unassigned at this point + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id == -1) { + // unassigned node: find the backend with the most supported inputs + int n_supported_best = -1; + for (int b = 0; b < sched->n_backends; b++) { + if (ggml_backend_supports_op(sched->backends[b], node)) { + int n_supported = 0; + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + if ((tensor_backend_id(src) != -1 || tensor_backend_id(src->view_src) != -1) && ggml_backend_sched_buffer_supported(sched, src, b)) { + n_supported++; + } + } + if (n_supported > n_supported_best) { + n_supported_best = n_supported; + *node_backend_id = b; + SET_CAUSE(node, "3.best"); + } + } + } + } else { + // assigned node: upgrade to higher prio backend if possible + for (int b = 0; b < *node_backend_id; b++) { + if (sched->bufts[b] == sched->bufts[*node_backend_id] && ggml_backend_supports_op(sched->backends[b], node)) { + bool supported = true; + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + if (!ggml_backend_sched_buffer_supported(sched, src, b)) { + supported = false; + break; + } + } + if (supported) { + *node_backend_id = b; + SET_CAUSE(node, "3.upg"); + break; + } + } + } + } + } + + // pass 4: assign backends to remaining src from dst and view_src + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + int * cur_backend_id = &tensor_backend_id(node); + if (node->view_src != NULL && *cur_backend_id == -1) { + *cur_backend_id = 
tensor_backend_id(node->view_src); + SET_CAUSE(node, "4.vsrc"); + } + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + int * src_backend_id = &tensor_backend_id(src); + if (*src_backend_id == -1) { + if (src->view_src != NULL) { + // views are always on the same backend as the source + *src_backend_id = tensor_backend_id(src->view_src); + SET_CAUSE(src, "4.vsrc"); + } else { + *src_backend_id = *cur_backend_id; + SET_CAUSE(src, "4.cur"); + } + } + } + // if the node is still unassigned, assign it to the first backend that supports it + for (int b = 0; b < sched->n_backends && *cur_backend_id == -1; b++) { + ggml_backend_sched_set_if_supported(sched, node, b, cur_backend_id); + } + GGML_ASSERT(*cur_backend_id != -1); + } + + // pass 5: split graph, find tensors that need to be copied + { + int i_split = 0; + struct ggml_backend_sched_split * split = &sched->splits[0]; + // find the backend of the first split, skipping view ops + int i = 0; + for (; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (!ggml_is_view_op(node->op)) { + split->backend_id = tensor_backend_id(node); + break; + } + } + split->i_start = 0; + split->n_inputs = 0; + int cur_backend_id = split->backend_id; + for (; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + + if (ggml_is_view_op(node->op)) { + continue; + } + + const int node_backend_id = tensor_backend_id(node); + + GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback + + // check if we should start a new split based on the sources of the current node + bool need_new_split = false; + if (node_backend_id == cur_backend_id && split->n_inputs > 0) { + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + // check if a weight is on a different and incompatible backend + // by 
starting a new split, the memory of the previously offloaded weights can be reused + if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + int src_backend_id = tensor_backend_id(src); + if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) { + need_new_split = true; + break; + } + } + // check if the split has too many inputs + // FIXME: count the number of inputs instead of only checking when full + if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) { + const size_t id = hash_id(src); + int src_backend_id = sched->hv_tensor_backend_ids[id]; + bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id); + if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) { + need_new_split = true; + break; + } + } + } + } + + if (node_backend_id != cur_backend_id || need_new_split) { + split->i_end = i; + i_split++; + if (i_split >= sched->splits_capacity) { + sched->splits_capacity *= 2; + sched->splits = (ggml_backend_sched_split *) + realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); + GGML_ASSERT(sched->splits != NULL); + } + split = &sched->splits[i_split]; + split->backend_id = node_backend_id; + split->i_start = i; + split->n_inputs = 0; + cur_backend_id = node_backend_id; + } + + // find inputs that are not on the same backend + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + + size_t src_id = hash_id(src); + const int src_backend_id = sched->hv_tensor_backend_ids[src_id]; + GGML_ASSERT(src_backend_id != -1); // all inputs should be assigned by now + + if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) { + if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) { + ggml_backend_t backend = sched->backends[src_backend_id]; + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * 
tensor_copy; + if (c == sched->cur_copy) { + tensor_copy = src; // use the original tensor as the current copy + } else { + tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); + ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); + } + if (sched->n_copies > 1) { + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor + } + tensor_id_copy(src_id, src_backend_id, c) = tensor_copy; + SET_CAUSE(tensor_copy, "4.cpy"); + } + int n_graph_inputs = sched->n_graph_inputs++; + GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS); + sched->graph_inputs[n_graph_inputs] = src; + } + } + + if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) { + // create a copy of the input in the split's backend + if (tensor_id_copy(src_id, cur_backend_id, 0) == NULL) { + ggml_backend_t backend = sched->backends[cur_backend_id]; + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); + ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); + if (sched->n_copies > 1) { + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor + } + tensor_id_copy(src_id, cur_backend_id, c) = tensor_copy; + SET_CAUSE(tensor_copy, "4.cpy"); + } + int n_inputs = split->n_inputs++; + GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS); + split->inputs[n_inputs] = src; + } + node->src[j] = tensor_id_copy(src_id, cur_backend_id, sched->cur_copy); + } + } + } + split->i_end = graph->n_nodes; + sched->n_splits = i_split + 1; + } + + if (sched->debug) { + ggml_backend_sched_print_assignments(sched, graph); + } + + // swap node_backend_ids and leaf _backend_ids with prevs + { + int * tmp = sched->node_backend_ids; + sched->node_backend_ids = sched->prev_node_backend_ids; + sched->prev_node_backend_ids = tmp; + + tmp = 
sched->leaf_backend_ids; + sched->leaf_backend_ids = sched->prev_leaf_backend_ids; + sched->prev_leaf_backend_ids = tmp; + } + + int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies; + if (sched->graph.size < graph_size) { + sched->graph.size = graph_size; + sched->graph.nodes = (ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *)); + sched->graph.leafs = (ggml_tensor **) realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *)); + GGML_ASSERT(sched->graph.nodes != NULL); + GGML_ASSERT(sched->graph.leafs != NULL); + } + sched->graph.n_nodes = 0; + sched->graph.n_leafs = 0; + + struct ggml_cgraph * graph_copy = &sched->graph; + + for (int i = 0; i < sched->n_splits; i++) { + struct ggml_backend_sched_split * split = &sched->splits[i]; + split->graph = ggml_graph_view(graph, split->i_start, split->i_end); + + // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split + for (int j = 0; j < split->n_inputs; j++) { + assert(graph_copy->size > (graph_copy->n_nodes + 1)); + + struct ggml_tensor * input = split->inputs[j]; + const size_t input_id = hash_id(input); + struct ggml_tensor * input_cpy = tensor_id_copy(input_id, split->backend_id, sched->cur_copy); + + // add a dependency to the input source so that it is not freed before the copy is done + struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input); + input_dep->src[0] = input; + sched->node_backend_ids[graph_copy->n_nodes] = sched->hv_tensor_backend_ids[input_id]; + graph_copy->nodes[graph_copy->n_nodes++] = input_dep; + + // add a dependency to the input copy so that it is allocated at the start of the split + sched->node_backend_ids[graph_copy->n_nodes] = split->backend_id; + graph_copy->nodes[graph_copy->n_nodes++] = input_cpy; + } + + for (int j = split->i_start; j < split->i_end; j++) { + assert(graph_copy->size > graph_copy->n_nodes); + 
sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]); + graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j]; + } + } + + if (sched->n_copies > 1) { + // add input copies as leafs so that they are allocated first + for (int i = 0; i < sched->n_graph_inputs; i++) { + struct ggml_tensor * input = sched->graph_inputs[i]; + size_t id = hash_id(input); + int backend_id = tensor_backend_id(input); + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c); + sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; + assert(graph_copy->size > graph_copy->n_leafs); + graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; + } + } + + for (int i = 0; i < sched->n_splits; i++) { + struct ggml_backend_sched_split * split = &sched->splits[i]; + int backend_id = split->backend_id; + for (int j = 0; j < split->n_inputs; j++) { + struct ggml_tensor * input = split->inputs[j]; + size_t id = hash_id(input); + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c); + sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; + assert(graph_copy->size > graph_copy->n_leafs); + graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; + } + } + } + } + + // add leafs from the original graph + for (int i = 0; i < graph->n_leafs; i++) { + struct ggml_tensor * leaf = graph->leafs[i]; + sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf); + assert(graph_copy->size > graph_copy->n_leafs); + graph_copy->leafs[graph_copy->n_leafs++] = leaf; + } +} + +static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { + bool backend_ids_changed = false; + for (int i = 0; i < sched->graph.n_nodes; i++) { + if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] && + sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) { + backend_ids_changed = true; + break; + } + } + if 
(!backend_ids_changed) { + for (int i = 0; i < sched->graph.n_leafs; i++) { + if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] && + sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) { + backend_ids_changed = true; + break; + } + } + } + + // allocate graph + if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { + // the re-allocation may cause the split inputs to be moved to a different address + // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy + for (int i = 0; i < sched->n_backends; i++) { + ggml_backend_synchronize(sched->backends[i]); + } +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); +#endif + ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); + if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { + GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__); + return false; + } + } + + return true; +} + +static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { + struct ggml_backend_sched_split * splits = sched->splits; + + for (int i = 0; i < sched->n_splits; i++) { + struct ggml_backend_sched_split * split = &splits[i]; + int split_backend_id = split->backend_id; + ggml_backend_t split_backend = sched->backends[split_backend_id]; + + // copy the input tensors to the split backend + for (int j = 0; j < split->n_inputs; j++) { + ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]); + struct ggml_tensor * input = split->inputs[j]; + struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy); + + if (input->flags & GGML_TENSOR_FLAG_INPUT) { + // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done + if 
(sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); + } + ggml_backend_tensor_copy(input, input_cpy); + } else { + // wait for the split backend to finish using the input before overwriting it + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); + } + // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events + // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface + if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) { + ggml_backend_synchronize(input_backend); + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); + } + ggml_backend_tensor_copy(input, input_cpy); + } + } + } + + if (!sched->callback_eval) { + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); + if (ec != GGML_STATUS_SUCCESS) { + return ec; + } + } else { + // similar to ggml_backend_compare_graph_backend + for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { + struct ggml_tensor * t = split->graph.nodes[j0]; + + // check if the user needs data from this node + bool need = sched->callback_eval(t, true, sched->callback_eval_user_data); + + int j1 = j0; + + // determine the range [j0, j1] of nodes that can be computed together + while (!need && j1 < split->graph.n_nodes - 1) { + t = split->graph.nodes[++j1]; + need = sched->callback_eval(t, true, 
sched->callback_eval_user_data); + } + + struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); + + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv); + if (ec != GGML_STATUS_SUCCESS) { + return ec; + } + + // TODO: pass backend to the callback, then the user can decide if they want to synchronize + ggml_backend_synchronize(split_backend); + + if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) { + break; + } + + j0 = j1; + } + } + + // record the event of this copy + if (split->n_inputs > 0) { + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy], split_backend); + } + } + } + + return GGML_STATUS_SUCCESS; +} + +ggml_backend_sched_t ggml_backend_sched_new( + ggml_backend_t * backends, + ggml_backend_buffer_type_t * bufts, + int n_backends, + size_t graph_size, + bool parallel, + bool op_offload) { + GGML_ASSERT(n_backends > 0); + GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); + GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); + + struct ggml_backend_sched * sched = (ggml_backend_sched *) calloc(1, sizeof(struct ggml_backend_sched)); + + const char * GGML_SCHED_DEBUG = getenv("GGML_SCHED_DEBUG"); + sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0; + sched->n_backends = n_backends; + sched->n_copies = parallel ? 
GGML_SCHED_MAX_COPIES : 1; + + // initialize hash table + // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead) + sched->hash_set = ggml_hash_set_new(graph_size); + sched->hv_tensor_backend_ids = (int *) malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); + sched->hv_tensor_copies = (ggml_tensor **) malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); + + const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph + const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2; + sched->node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->node_backend_ids[0])); + sched->leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); + sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); + sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); + + sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); + sched->context_buffer = (char *) malloc(sched->context_buffer_size); + + const int initial_splits_capacity = 16; + sched->splits = (ggml_backend_sched_split *) calloc(initial_splits_capacity, sizeof(sched->splits[0])); + sched->splits_capacity = initial_splits_capacity; + + for (int b = 0; b < n_backends; b++) { + sched->backends[b] = backends[b]; + sched->bufts[b] = bufts ? 
bufts[b] : ggml_backend_get_default_buffer_type(backends[b]); + GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b])); + + if (sched->n_copies > 1) { + for (int c = 0; c < sched->n_copies; c++) { + sched->events[b][c] = ggml_backend_event_new(backends[b]->device); + } + } + } + + sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); + sched->op_offload = op_offload; + + ggml_backend_sched_reset(sched); + + return sched; +} + +void ggml_backend_sched_free(ggml_backend_sched_t sched) { + if (sched == NULL) { + return; + } + for (int b = 0; b < sched->n_backends; b++) { + for (int c = 0; c < sched->n_copies; c++) { + ggml_backend_event_free(sched->events[b][c]); + } + } + ggml_gallocr_free(sched->galloc); + ggml_free(sched->ctx); + ggml_hash_set_free(&sched->hash_set); + free(sched->splits); + free(sched->hv_tensor_backend_ids); + free(sched->hv_tensor_copies); + free(sched->node_backend_ids); + free(sched->leaf_backend_ids); + free(sched->prev_node_backend_ids); + free(sched->prev_leaf_backend_ids); + free(sched->context_buffer); + free(sched->graph.nodes); + free(sched->graph.leafs); + free(sched); +} + +void ggml_backend_sched_reset(ggml_backend_sched_t sched) { + // reset state for the next run + if (!sched->is_reset) { + ggml_hash_set_reset(&sched->hash_set); + memset(sched->hv_tensor_backend_ids, -1, sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); + memset(sched->hv_tensor_copies, 0, sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); + sched->is_reset = true; + } + sched->is_alloc = false; +} + +bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { + GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs); + + ggml_backend_sched_synchronize(sched); + + ggml_backend_sched_split_graph(sched, measure_graph); + + if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, 
sched->leaf_backend_ids)) { + return false; + } + + ggml_backend_sched_reset(sched); + + return true; +} + +bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs); + GGML_ASSERT(!sched->is_alloc); + + sched->cur_copy = sched->next_copy; + sched->next_copy = (sched->next_copy + 1) % sched->n_copies; + + ggml_backend_sched_split_graph(sched, graph); + + if (!ggml_backend_sched_alloc_splits(sched)) { + return false; + } + + sched->is_alloc = true; + + return true; +} + +enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph); + ggml_backend_sched_synchronize(sched); + return err; +} + +enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + if (!sched->is_reset && !sched->is_alloc) { + ggml_backend_sched_reset(sched); + } + + if (!sched->is_alloc) { + if (!ggml_backend_sched_alloc_graph(sched, graph)) { + return GGML_STATUS_ALLOC_FAILED; + } + } + + return ggml_backend_sched_compute_splits(sched); +} + +void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { + for (int i = 0; i < sched->n_backends; i++) { + ggml_backend_synchronize(sched->backends[i]); + } + if (!sched->is_alloc) { + // if the graph is not already allocated, always use copy 0 after a synchronization + // this ensures that during generation the same copy is used every time, + // which avoids changes in the graph that could cause CUDA or other graphs to be disabled + sched->next_copy = 0; + } +} + +void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { + sched->callback_eval = callback; + sched->callback_eval_user_data = user_data; +} + +int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) { + return 
sched->n_splits; +} + +int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) { + return sched->n_copies; +} + +int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) { + return sched->n_backends; +} + +ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) { + GGML_ASSERT(i >= 0 && i < sched->n_backends); + return sched->backends[i]; +} + +size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + + return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); +} + +void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + tensor_backend_id(node) = backend_index; + SET_CAUSE(node, "usr"); + sched->is_reset = false; +} + +ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) { + int backend_index = tensor_backend_id(node); + if (backend_index == -1) { + return NULL; + } + return sched->backends[backend_index]; +} + +// utils + +enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->buffer == NULL); + GGML_ASSERT(tensor->view_src != NULL); + GGML_ASSERT(tensor->view_src->buffer != NULL); + GGML_ASSERT(tensor->view_src->data != NULL); + + tensor->buffer = tensor->view_src->buffer; + tensor->data = (char *)tensor->view_src->data + tensor->view_offs; + return ggml_backend_buffer_init_tensor(tensor->buffer, tensor); +} + +enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { + GGML_ASSERT(tensor->buffer == NULL); + GGML_ASSERT(tensor->data == NULL); + GGML_ASSERT(tensor->view_src == 
NULL); + GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer)); + GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= + (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); + + tensor->buffer = buffer; + tensor->data = addr; + return ggml_backend_buffer_init_tensor(buffer, tensor); +} + +static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, + struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) { + + GGML_ASSERT(src != NULL); + GGML_ASSERT(src->data && "graph must be allocated"); + + size_t id = ggml_hash_insert(&hash_set, src); + if (id == GGML_HASHSET_ALREADY_EXISTS) { + return node_copies[ggml_hash_find(&hash_set, src)]; + } + + struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src); + if (src->view_src != NULL) { + dst->view_src = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src); + dst->view_offs = src->view_offs; + } + dst->op = src->op; + memcpy(dst->op_params, src->op_params, sizeof(dst->op_params)); + ggml_set_name(dst, src->name); + + // copy src + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * s = src->src[i]; + if (s == NULL) { + continue; + } + dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s); + } + + node_copies[id] = dst; + return dst; +} + +static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) { + size_t id = ggml_hash_find(hash_set, src); + if (node_init[id]) { + return; + } + node_init[id] = true; + + struct ggml_tensor * dst = node_copies[id]; + if (dst->view_src != NULL) { + graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src); + enum ggml_status status = ggml_backend_view_init(dst); + 
GGML_ASSERT(status == GGML_STATUS_SUCCESS); + } + else { + ggml_backend_tensor_copy(src, dst); + } + + // init src + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * s = src->src[i]; + if (s == NULL) { + continue; + } + graph_copy_init_tensor(hash_set, node_copies, node_init, s); + } +} + +struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { + struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size); + struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT + bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0])); + + struct ggml_init_params params = { + /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false), + /* .mem_buffer = */ NULL, + /* .no_alloc = */ true + }; + + struct ggml_context * ctx_allocated = ggml_init(params); + struct ggml_context * ctx_unallocated = ggml_init(params); + + if (ctx_allocated == NULL || ctx_unallocated == NULL) { + GGML_LOG_ERROR("%s: failed to allocate context for graph copy\n", __func__); + ggml_hash_set_free(&hash_set); + free(node_copies); + free(node_init); + ggml_free(ctx_allocated); + ggml_free(ctx_unallocated); + return { + /* .buffer = */ NULL, + /* .ctx_allocated = */ NULL, + /* .ctx_unallocated = */ NULL, + /* .graph = */ NULL, + }; + } + + // dup nodes + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node); + } + + // allocate nodes + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend); + if (buffer == NULL) { + GGML_LOG_ERROR("%s: failed to allocate buffer for graph copy\n", __func__); + ggml_hash_set_free(&hash_set); + free(node_copies); + free(node_init); + ggml_free(ctx_allocated); + ggml_free(ctx_unallocated); + return { + /* .buffer = */ NULL, + /* 
.ctx_allocated = */ NULL, + /* .ctx_unallocated = */ NULL, + /* .graph = */ NULL, + }; + } + + //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024); + + // copy data and init views + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + graph_copy_init_tensor(&hash_set, node_copies, node_init, node); + } + + // build graph copy + struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false); + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + struct ggml_tensor * node_copy = node_copies[ggml_hash_find(&hash_set, node)]; + graph_copy->nodes[i] = node_copy; + } + graph_copy->n_nodes = graph->n_nodes; + + ggml_hash_set_free(&hash_set); + free(node_copies); + free(node_init); + + return { + /* .buffer = */ buffer, + /* .ctx_allocated = */ ctx_allocated, + /* .ctx_unallocated = */ ctx_unallocated, + /* .graph = */ graph_copy, + }; +} + +void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) { + ggml_backend_buffer_free(copy.buffer); + ggml_free(copy.ctx_allocated); + ggml_free(copy.ctx_unallocated); +} + +bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) { + struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); + if (copy.buffer == NULL) { + return false; + } + + struct ggml_cgraph * g1 = graph; + struct ggml_cgraph * g2 = copy.graph; + + assert(g1->n_nodes == g2->n_nodes); + + if (test_node != nullptr) { + // Compute the whole graph and only test the output for a specific tensor + ggml_backend_graph_compute(backend1, g1); + ggml_backend_graph_compute(backend2, g2); + + int test_node_idx = -1; + for (int i = 0; i < g1->n_nodes; i++) { + struct ggml_tensor * t1 = g1->nodes[i]; + if (t1 == test_node) { + test_node_idx = i; + break; + 
} + } + GGML_ASSERT(test_node_idx != -1); + + callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data); + } else { + for (int i = 0; i < g1->n_nodes; i++) { + struct ggml_tensor * t1 = g1->nodes[i]; + struct ggml_tensor * t2 = g2->nodes[i]; + + assert(t1->op == t2->op && ggml_are_same_layout(t1, t2)); + + struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); + struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); + + ggml_backend_graph_compute(backend1, &g1v); + ggml_backend_graph_compute(backend2, &g2v); + + if (ggml_is_view_op(t1->op)) { + continue; + } + + // compare results, calculate rms etc + if (!callback(i, t1, t2, user_data)) { + break; + } + } + } + ggml_backend_graph_copy_free(copy); + + return true; +} + +// CPU backend - buffer + +static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { + uintptr_t data = (uintptr_t)buffer->context; + + // align the buffer + if (data % TENSOR_ALIGNMENT != 0) { + data = GGML_PAD(data, TENSOR_ALIGNMENT); + } + + return (void *)data; +} + +static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_aligned_free(buffer->context, buffer->size); +} + +static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + memcpy((char *)tensor->data + offset, data, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + memcpy(data, (const char *)tensor->data + offset, size); + + GGML_UNUSED(buffer); +} + +static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const 
struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; + } + return false; + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); +} + +static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { + /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; + +static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { + /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; + +// CPU backend buffer type + +// this buffer type is defined here to make it available to all backends + +static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * data = ggml_aligned_malloc(size); + + if (data == NULL) { + GGML_LOG_ERROR("%s: failed to 
allocate buffer of size %zu\n", __func__, size); + return NULL; + } + + return ggml_backend_buffer_init(buft, ggml_backend_cpu_buffer_i, data, size); +} + +static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; + + GGML_UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .device = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_buffer_type; +} + +static const char * ggml_backend_cpu_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_Mapped"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_cpu_buffer_from_ptr_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_buffer_from_ptr_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .device = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_buffer_type; +} + 
+ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { + GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); + return ggml_backend_buffer_init(ggml_backend_cpu_buffer_from_ptr_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size); +} diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-blas/CMakeLists.txt b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-blas/CMakeLists.txt new file mode 100644 index 0000000..76064c3 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-blas/CMakeLists.txt @@ -0,0 +1,87 @@ +if (GGML_STATIC) + set(BLA_STATIC ON) +endif() +#if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22) +# set(BLA_SIZEOF_INTEGER 8) +#endif() + +set(BLA_VENDOR ${GGML_BLAS_VENDOR}) +find_package(BLAS) + +if (BLAS_FOUND) + message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") + + ggml_add_backend_library(ggml-blas + ggml-blas.cpp + ) + + if (${GGML_BLAS_VENDOR} MATCHES "Apple") + add_compile_definitions(ACCELERATE_NEW_LAPACK) + add_compile_definitions(ACCELERATE_LAPACK_ILP64) + add_compile_definitions(GGML_BLAS_USE_ACCELERATE) + elseif ("${BLAS_INCLUDE_DIRS}" STREQUAL "") + # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. 
+ # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 + find_package(PkgConfig REQUIRED) + if (${GGML_BLAS_VENDOR} MATCHES "Generic") + pkg_check_modules(DepBLAS blas) + elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS") + # As of openblas v0.3.22, the 64-bit is named openblas64.pc + pkg_check_modules(DepBLAS openblas64) + if (NOT DepBLAS_FOUND) + pkg_check_modules(DepBLAS openblas) + endif() + elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME") + add_compile_definitions(GGML_BLAS_USE_BLIS) + pkg_check_modules(DepBLAS blis) + elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS") + pkg_check_modules(DepBLAS blas-atlas) + elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS") + pkg_check_modules(DepBLAS flexiblas_api) + elseif (${GGML_BLAS_VENDOR} MATCHES "Intel") + add_compile_definitions(GGML_BLAS_USE_MKL) + # all Intel* libraries share the same include path + pkg_check_modules(DepBLAS mkl-sdl) + elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC") + # this doesn't provide pkg-config + # suggest to assign BLAS_INCLUDE_DIRS on your own + if ("${NVHPC_VERSION}" STREQUAL "") + message(WARNING "Better to set NVHPC_VERSION") + else() + set(DepBLAS_FOUND ON) + set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") + endif() + endif() + if (DepBLAS_FOUND) + set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) + else() + message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" + " detected by pkgconfig, trying to find cblas.h from possible paths...") + find_path(BLAS_INCLUDE_DIRS + NAMES cblas.h + HINTS + /usr/include + /usr/local/include + /usr/include/openblas + /opt/homebrew/opt/openblas/include + /usr/local/opt/openblas/include + /usr/include/x86_64-linux-gnu/openblas/include + ) + endif() + endif() + + message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + + target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS}) + + if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} 
MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel")) + add_compile_definitions(GGML_BLAS_USE_MKL) + endif() + + target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES}) + target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS}) +else() + message(FATAL_ERROR "BLAS not found, please refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct GGML_BLAS_VENDOR") +endif() diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-blas/ggml-blas.cpp b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-blas/ggml-blas.cpp new file mode 100644 index 0000000..aeac2e5 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-blas/ggml-blas.cpp @@ -0,0 +1,517 @@ +#include "ggml-impl.h" +#include "ggml-blas.h" +#include "ggml-backend-impl.h" + +#include +#include +#include + +#if defined(GGML_BLAS_USE_ACCELERATE) +# include +#elif defined(GGML_BLAS_USE_MKL) +# include +#elif defined(GGML_BLAS_USE_BLIS) +# include +#elif defined(GGML_BLAS_USE_NVPL) +# include +#else +# include +#endif + +struct ggml_backend_blas_context { + int n_threads = GGML_DEFAULT_N_THREADS; + std::unique_ptr work_data; + size_t work_size = 0; +#ifndef GGML_USE_OPENMP + std::vector> tasks; +#endif +}; + +static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t r2 = 
ne12/ne02; + const int64_t r3 = ne13/ne03; + + const int64_t ne_plane = ne01*ne00; + const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float); + + if (ctx->work_size < desired_wsize) { + ctx->work_data.reset(new char[desired_wsize]); + ctx->work_size = desired_wsize; + } + void * wdata = ctx->work_data.get(); + + // convert src0 to float + if (type != GGML_TYPE_F32) { + const auto * type_traits = ggml_get_type_traits(type); + ggml_to_float_t const to_float = type_traits->to_float; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; + + const int min_cols_per_thread = 4096; + const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1); + const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1); + +#ifdef GGML_USE_OPENMP + #pragma omp parallel for num_threads(n_threads) + for (int64_t i01 = 0; i01 < ne01; i01++) { + to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00); + } +#else + for (int i = 1; i < n_threads; i++) { + const int64_t start = i*ne01/n_threads; + const int64_t end = (i + 1)*ne01/n_threads; + if (start < end) { + ctx->tasks.push_back(std::async(std::launch::async, [=]() { + for (int64_t i01 = start; i01 < end; i01++) { + to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00); + } + })); + } + } + { + // reuse the current thread for the first task + const int64_t start = 0; + const int64_t end = ne01/n_threads; + for (int64_t i01 = start; i01 < end; i01++) { + to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00); + } + } +#endif + } + } + +#ifndef GGML_USE_OPENMP + // wait for all tasks to finish + for (auto & task : ctx->tasks) { + task.get(); + } + ctx->tasks.clear(); +#endif + } + +#if defined(OPENBLAS_VERSION) + openblas_set_num_threads(ctx->n_threads); +#endif 
+ +#if defined(GGML_BLAS_USE_BLIS) + bli_thread_set_num_threads(ctx->n_threads); +#endif + +#if defined(GGML_BLAS_USE_NVPL) + nvpl_blas_set_num_threads(ctx->n_threads); +#endif + + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); + const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); + + if (type != GGML_TYPE_F32) { + x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; + } + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne1, ne01, ne10, + 1.0f, y, ne10, + x, ne00, + 0.0f, d, ne01); + } + } +} + +static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne3 == ne13); + GGML_ASSERT(ne03 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + // Arguments to ggml_compute_forward_out_prod (expressed as major,minor) + // src0: (k,n) + // src1: (k,m) + // dst: (m,n) + // + // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f) + // Also expressed as (major,minor) + // a: (m,k): so src1 transposed + // b: (k,n): so src0 + // c: (m,n) + // + // However, if ggml_is_transposed(src1) is true, then + // src1->data already contains a transposed version, so sgemm mustn't + // transpose it further. 
+ + int n = src0->ne[0]; + int k = src0->ne[1]; + int m = src1->ne[0]; + + CBLAS_TRANSPOSE transposeA; + int lda; + + if (!ggml_is_transposed(src1)) { + transposeA = CblasTrans; + lda = m; + } else { + transposeA = CblasNoTrans; + lda = k; + } + + float * a = (float *) ((char *) src1->data); + float * b = (float *) ((char *) src0->data); + float * c = (float *) ((char *) dst->data); + + cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n); + + GGML_UNUSED(ctx); +} + +// backend interface + +static const char * ggml_backend_blas_get_name(ggml_backend_t backend) { + return "BLAS"; + + GGML_UNUSED(backend); +} + +static void ggml_backend_blas_free(ggml_backend_t backend) { + ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; + delete ctx; + delete backend; +} + +static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + switch (node->op) { + case GGML_OP_MUL_MAT: + ggml_backend_blas_mul_mat(ctx, node); + break; + + case GGML_OP_OUT_PROD: + ggml_backend_blas_out_prod(ctx, node); + break; + + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + break; + + default: + GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node)); + } + } + + return GGML_STATUS_SUCCESS; + + GGML_UNUSED(backend); +} + +static struct ggml_backend_i blas_backend_i = { + /* .get_name = */ ggml_backend_blas_get_name, + /* .free = */ ggml_backend_blas_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* 
.graph_compute = */ ggml_backend_blas_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static ggml_guid_t ggml_backend_blas_guid(void) { + static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d }; + return &guid; +} + +ggml_backend_t ggml_backend_blas_init(void) { + ggml_backend_blas_context * ctx = new ggml_backend_blas_context; + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_blas_guid(), + /* .iface = */ blas_backend_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0), + /* .context = */ ctx, + }; + +#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP) + if (openblas_get_parallel() != OPENBLAS_OPENMP) { + GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__); + } +#endif + +#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP) + GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__); +#endif + + return backend; +} + +bool ggml_backend_is_blas(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid()); +} + +void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) { + GGML_ASSERT(ggml_backend_is_blas(backend_blas)); + + ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context; + ctx->n_threads = n_threads; +} + +// device interface + +static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) { + return "BLAS"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) { + #if defined(GGML_BLAS_USE_ACCELERATE) + return "Accelerate"; + #elif defined(GGML_BLAS_USE_MKL) + return "MKL"; + #elif defined(GGML_BLAS_USE_BLIS) + return "BLIS"; + #elif defined(GGML_BLAS_USE_NVPL) + return "NVPL"; + 
#elif defined(OPENBLAS_VERSION) + return "OpenBLAS"; + #else + return "BLAS"; + #endif + + GGML_UNUSED(dev); +} + +static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + // TODO + *free = 0; + *total = 0; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_ACCEL; + + GGML_UNUSED(dev); +} + +static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_blas_device_get_name(dev); + props->description = ggml_backend_blas_device_get_description(dev); + props->type = ggml_backend_blas_device_get_type(dev); + ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) { + return ggml_backend_blas_init(); + + GGML_UNUSED(dev); + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(dev); +} + +static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + return ggml_backend_cpu_buffer_from_ptr(ptr, size); + + GGML_UNUSED(dev); + GGML_UNUSED(max_tensor_size); +} + +static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + + switch (op->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + return true; + + case GGML_OP_MUL_MAT: + { + // BLAS 
usually is only faster for large matrices + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = op->ne[0]; + const int64_t ne1 = op->ne[1]; + + // TODO: find the optimal value + const int64_t min_batch = 32; + + return ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + src1->type == GGML_TYPE_F32 && + (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) && + (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL); + } + + case GGML_OP_OUT_PROD: + return op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + ggml_is_matrix(src0) && + ggml_is_matrix(src1) && + ggml_is_contiguous(src0) && + (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) && + (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL); + + default: + return false; + + } + + GGML_UNUSED(dev); +} + +static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_is_host(buft); + + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_blas_device_i = { + /* .get_name = */ ggml_backend_blas_device_get_name, + /* .get_description = */ ggml_backend_blas_device_get_description, + /* .get_memory = */ ggml_backend_blas_device_get_memory, + /* .get_type = */ ggml_backend_blas_device_get_type, + /* .get_props = */ ggml_backend_blas_device_get_props, + /* .init_backend = */ ggml_backend_blas_device_init_backend, + /* .get_buffer_type = */ ggml_backend_blas_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr, + /* .supports_op = */ ggml_backend_blas_device_supports_op, + /* .supports_buft = */ ggml_backend_blas_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* 
.event_synchronize = */ NULL, +}; + +// backend reg interface + +static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) { + return "BLAS"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + static ggml_backend_device ggml_backend_blas_device = { + /* .iface = */ ggml_backend_blas_device_i, + /* .reg = */ reg, + /* .context = */ nullptr, + }; + + return &ggml_backend_blas_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) { + return (void *)ggml_backend_blas_set_n_threads; + } + return NULL; + + GGML_UNUSED(reg); + GGML_UNUSED(name); +} + +static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = { + /* .get_name = */ ggml_backend_blas_reg_get_name, + /* .get_device_count = */ ggml_backend_blas_reg_get_device_count, + /* .get_device = */ ggml_backend_blas_reg_get_device, + /* .get_proc_address = */ ggml_backend_blas_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_blas_reg(void) { + static struct ggml_backend_reg ggml_backend_blas_reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_blas_reg_i, + /* .context = */ NULL, + }; + + return &ggml_backend_blas_reg; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_blas_reg) diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-cann/CMakeLists.txt b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-cann/CMakeLists.txt new file mode 100755 index 0000000..aee5e7b --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-cann/CMakeLists.txt @@ -0,0 +1,89 @@ +if ("cann${CANN_INSTALL_DIR}" STREQUAL "cann" AND DEFINED ENV{ASCEND_TOOLKIT_HOME}) + set(CANN_INSTALL_DIR 
$ENV{ASCEND_TOOLKIT_HOME}) + message(STATUS "CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}") +endif() + +# Auto-detech Soc type and Soc version, if detect failed, will abort build +set(SOC_VERSION "") +function(detect_ascend_soc_type SOC_VERSION) + execute_process( + COMMAND bash -c "npu-smi info|awk -F' ' 'NF > 0 && NR==7 {print $3}'" + OUTPUT_VARIABLE npu_info + RESULT_VARIABLE npu_result + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if("${npu_info}" STREQUAL "" OR ${npu_result}) + message(FATAL_ERROR "Auto-detech ascend soc type failed, please specify manually or check ascend device working normally.") + endif() + set(${SOC_VERSION} "Ascend${npu_info}" PARENT_SCOPE) +endfunction() + +if(NOT SOC_TYPE) + detect_ascend_soc_type(SOC_VERSION) + set(SOC_TYPE "${SOC_VERSION}") + message(STATUS "CANN: SOC_VERSION auto-detected is:${SOC_VERSION}") +endif() + +string(TOLOWER ${SOC_TYPE} SOC_VERSION) # SOC_VERSION need lower + +# Construct Soc specify compile option: ASCEND_#Soc_Major_SN. Such as ASCEND_910B, ASCEND_310P. +string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}") +set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}") +string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION) +message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}") +option(USE_ACL_GRAPH "Enable CANN graph execution (ACL graph mode)" OFF) + +if(USE_ACL_GRAPH AND (SOC_TYPE_MAJOR_SN STREQUAL "310P" OR SOC_TYPE_COMPILE_OPTION STREQUAL "ASCEND_310P")) + message(FATAL_ERROR + "CANN Graph (ACL graph mode) is not supported on 310P devices. " + "Please build with -DUSE_ACL_GRAPH=OFF or use a supported SOC.") +endif() + +if (CANN_INSTALL_DIR) + # Only Support Linux. 
+ if (NOT UNIX) + message(FATAL_ERROR "CANN: CANN toolkit supports unix but not ${CMAKE_SYSTEM_NAME}") + endif() + + # Supported platforms: x86-64, arm64 + if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64") + else() + message(FATAL_ERROR "CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}") + endif() + + # Set header and libs + set(CANN_INCLUDE_DIRS + ${CANN_INSTALL_DIR}/include + ${CANN_INSTALL_DIR}/include/aclnn + ${CANN_INSTALL_DIR}/acllib/include + ) + + list(APPEND CANN_LIBRARIES + ascendcl + nnopbase + opapi + acl_op_compiler + ) + + file(GLOB GGML_SOURCES_CANN "*.cpp") + + ggml_add_backend_library(ggml-cann ${GGML_SOURCES_CANN}) + target_link_libraries(ggml-cann PRIVATE ${CANN_LIBRARIES}) + target_include_directories(ggml-cann PRIVATE ${CANN_INCLUDE_DIRS}) + target_link_directories(ggml-cann PRIVATE ${CANN_INSTALL_DIR}/lib64) + + target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}") + + if (USE_ACL_GRAPH) + target_compile_definitions(ggml-cann PRIVATE USE_ACL_GRAPH) + message(STATUS "CANN: USE_ACL_GRAPH is enabled.") + else() + message(STATUS "CANN: USE_ACL_GRAPH is disabled.") + endif() + + message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}") + message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}") +else() + message(FATAL_ERROR "CANN: Can't find CANN_INSTALL_DIR, did you forget to source set_var.sh?") +endif() diff --git a/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-cann/Doxyfile b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-cann/Doxyfile new file mode 100755 index 0000000..3290a48 --- /dev/null +++ b/vendor/whisper-rs-sys/whisper.cpp/ggml/src/ggml-cann/Doxyfile @@ -0,0 +1,2579 @@ +# Doxyfile 1.8.17 + +# This file describes the settings to be used by the documentation system +# doxygen (www.doxygen.org) for a project. 
+# +# All text after a double hash (##) is considered a comment and is placed in +# front of the TAG it is preceding. +# +# All text after a single hash (#) is considered a comment and will be ignored. +# The format is: +# TAG = value [value, ...] +# For lists, items can also be appended using: +# TAG += value [value, ...] +# Values that contain spaces should be placed between quotes (\" \"). + +#--------------------------------------------------------------------------- +# Project related configuration options +#--------------------------------------------------------------------------- + +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. +# The default value is: UTF-8. + +DOXYFILE_ENCODING = UTF-8 + +# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by +# double-quotes, unless you are using Doxywizard) that should identify the +# project for which the documentation is generated. This name is used in the +# title of most generated pages and in a few other places. +# The default value is: My Project. + +PROJECT_NAME = "ggml" + +# The PROJECT_NUMBER tag can be used to enter a project or revision number. This +# could be handy for archiving the generated documentation or if some version +# control system is used. + +PROJECT_NUMBER = + +# Using the PROJECT_BRIEF tag one can provide an optional one line description +# for a project that appears at the top of each page and should give viewer a +# quick idea about the purpose of the project. Keep the description short. + +PROJECT_BRIEF = "Tensor library for machine learning" + +# With the PROJECT_LOGO tag one can specify a logo or an icon that is included +# in the documentation. 
The maximum height of the logo should not exceed 55 +# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy +# the logo to the output directory. + +PROJECT_LOGO = + +# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path +# into which the generated documentation will be written. If a relative path is +# entered, it will be relative to the location where doxygen was started. If +# left blank the current directory will be used. + +OUTPUT_DIRECTORY = docs + +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- +# directories (in 2 levels) under the output directory of each output format and +# will distribute the generated files over these directories. Enabling this +# option can be useful when feeding doxygen a huge amount of source files, where +# putting all generated files in the same directory would otherwise causes +# performance problems for the file system. +# The default value is: NO. + +CREATE_SUBDIRS = NO + +# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII +# characters to appear in the names of generated files. If set to NO, non-ASCII +# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode +# U+3044. +# The default value is: NO. + +ALLOW_UNICODE_NAMES = NO + +# The OUTPUT_LANGUAGE tag is used to specify the language in which all +# documentation generated by doxygen is written. Doxygen will use this +# information to generate all constant output in the proper language. 
+# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, +# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), +# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, +# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), +# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, +# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, +# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, +# Ukrainian and Vietnamese. +# The default value is: English. + +OUTPUT_LANGUAGE = English + +# The OUTPUT_TEXT_DIRECTION tag is used to specify the direction in which all +# documentation generated by doxygen is written. Doxygen will use this +# information to generate all generated output in the proper direction. +# Possible values are: None, LTR, RTL and Context. +# The default value is: None. + +OUTPUT_TEXT_DIRECTION = None + +# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member +# descriptions after the members that are listed in the file and class +# documentation (similar to Javadoc). Set to NO to disable this. +# The default value is: YES. + +BRIEF_MEMBER_DESC = YES + +# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief +# description of a member or function before the detailed description +# +# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the +# brief descriptions will be completely suppressed. +# The default value is: YES. + +REPEAT_BRIEF = YES + +# This tag implements a quasi-intelligent brief description abbreviator that is +# used to form the text in various listings. Each string in this list, if found +# as the leading text of the brief description, will be stripped from the text +# and the result, after processing the whole list, is used as the annotated +# text. Otherwise, the brief description is used as-is. 
If left blank, the +# following values are used ($name is automatically replaced with the name of +# the entity):The $name class, The $name widget, The $name file, is, provides, +# specifies, contains, represents, a, an and the. + +ABBREVIATE_BRIEF = "The $name class" \ + "The $name widget" \ + "The $name file" \ + is \ + provides \ + specifies \ + contains \ + represents \ + a \ + an \ + the + +# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then +# doxygen will generate a detailed section even if there is only a brief +# description. +# The default value is: NO. + +ALWAYS_DETAILED_SEC = NO + +# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all +# inherited members of a class in the documentation of that class as if those +# members were ordinary class members. Constructors, destructors and assignment +# operators of the base classes will not be shown. +# The default value is: NO. + +INLINE_INHERITED_MEMB = NO + +# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path +# before files name in the file list and in the header files. If set to NO the +# shortest path that makes the file name unique will be used +# The default value is: YES. + +FULL_PATH_NAMES = YES + +# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. +# Stripping is only done if one of the specified strings matches the left-hand +# part of the path. The tag can be used to show relative paths in the file list. +# If left blank the directory from which doxygen is run is used as the path to +# strip. +# +# Note that you can specify absolute paths here, but also relative paths, which +# will be relative from the directory where doxygen is started. +# This tag requires that the tag FULL_PATH_NAMES is set to YES. 
+ +STRIP_FROM_PATH = + +# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the +# path mentioned in the documentation of a class, which tells the reader which +# header file to include in order to use a class. If left blank only the name of +# the header file containing the class definition is used. Otherwise one should +# specify the list of include paths that are normally passed to the compiler +# using the -I flag. + +STRIP_FROM_INC_PATH = + +# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but +# less readable) file names. This can be useful is your file systems doesn't +# support long names like on DOS, Mac, or CD-ROM. +# The default value is: NO. + +SHORT_NAMES = NO + +# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the +# first line (until the first dot) of a Javadoc-style comment as the brief +# description. If set to NO, the Javadoc-style will behave just like regular Qt- +# style comments (thus requiring an explicit @brief command for a brief +# description.) +# The default value is: NO. + +JAVADOC_AUTOBRIEF = NO + +# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line +# such as +# /*************** +# as being the beginning of a Javadoc-style comment "banner". If set to NO, the +# Javadoc-style will behave just like regular comments and it will not be +# interpreted by doxygen. +# The default value is: NO. + +JAVADOC_BANNER = NO + +# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first +# line (until the first dot) of a Qt-style comment as the brief description. If +# set to NO, the Qt-style will behave just like regular Qt-style comments (thus +# requiring an explicit \brief command for a brief description.) +# The default value is: NO. + +QT_AUTOBRIEF = NO + +# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a +# multi-line C++ special comment block (i.e. a block of //! 
or /// comments) as +# a brief description. This used to be the default behavior. The new default is +# to treat a multi-line C++ comment block as a detailed description. Set this +# tag to YES if you prefer the old behavior instead. +# +# Note that setting this tag to YES also means that rational rose comments are +# not recognized any more. +# The default value is: NO. + +MULTILINE_CPP_IS_BRIEF = NO + +# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the +# documentation from any documented member that it re-implements. +# The default value is: YES. + +INHERIT_DOCS = YES + +# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new +# page for each member. If set to NO, the documentation of a member will be part +# of the file/class/namespace that contains it. +# The default value is: NO. + +SEPARATE_MEMBER_PAGES = NO + +# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen +# uses this value to replace tabs by spaces in code fragments. +# Minimum value: 1, maximum value: 16, default value: 4. + +TAB_SIZE = 4 + +# This tag can be used to specify a number of aliases that act as commands in +# the documentation. An alias has the form: +# name=value +# For example adding +# "sideeffect=@par Side Effects:\n" +# will allow you to put the command \sideeffect (or @sideeffect) in the +# documentation, which will result in a user-defined paragraph with heading +# "Side Effects:". You can put \n's in the value part of an alias to insert +# newlines (in the resulting output). You can put ^^ in the value part of an +# alias to insert a newline as if a physical newline was in the original file. 
+# When you need a literal { or } or , in the value part of an alias you have to +# escape them by means of a backslash (\), this can lead to conflicts with the +# commands \{ and \} for these it is advised to use the version @{ and @} or use +# a double escape (\\{ and \\}) + +ALIASES = + +# This tag can be used to specify a number of word-keyword mappings (TCL only). +# A mapping has the form "name=value". For example adding "class=itcl::class" +# will allow you to use the command class in the itcl::class meaning. + +TCL_SUBST = + +# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources +# only. Doxygen will then generate output that is more tailored for C. For +# instance, some of the names that are used will be different. The list of all +# members will be omitted, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_FOR_C = NO + +# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or +# Python sources only. Doxygen will then generate output that is more tailored +# for that language. For instance, namespaces will be presented as packages, +# qualified scopes will look different, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_JAVA = NO + +# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran +# sources. Doxygen will then generate output that is tailored for Fortran. +# The default value is: NO. + +OPTIMIZE_FOR_FORTRAN = NO + +# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL +# sources. Doxygen will then generate output that is tailored for VHDL. +# The default value is: NO. + +OPTIMIZE_OUTPUT_VHDL = NO + +# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice +# sources only. Doxygen will then generate output that is more tailored for that +# language. For instance, namespaces will be presented as modules, types will be +# separated into more groups, etc. +# The default value is: NO. 
+ +OPTIMIZE_OUTPUT_SLICE = NO + +# Doxygen selects the parser to use depending on the extension of the files it +# parses. With this tag you can assign which parser to use for a given +# extension. Doxygen has a built-in mapping, but you can override or extend it +# using this tag. The format is ext=language, where ext is a file extension, and +# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, +# Csharp (C#), C, C++, D, PHP, md (Markdown), Objective-C, Python, Slice, +# Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: +# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser +# tries to guess whether the code is fixed or free formatted code, this is the +# default for Fortran type files), VHDL, tcl. For instance to make doxygen treat +# .inc files as Fortran files (default is PHP), and .f files as C (default is +# Fortran), use: inc=Fortran f=C. +# +# Note: For files without extension you can use no_extension as a placeholder. +# +# Note that for custom extensions you also need to set FILE_PATTERNS otherwise +# the files are not read by doxygen. + +EXTENSION_MAPPING = + +# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments +# according to the Markdown format, which allows for more readable +# documentation. See https://daringfireball.net/projects/markdown/ for details. +# The output of markdown processing is further processed by doxygen, so you can +# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in +# case of backward compatibilities issues. +# The default value is: YES. + +MARKDOWN_SUPPORT = YES + +# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up +# to that level are automatically included in the table of contents, even if +# they do not have an id attribute. +# Note: This feature currently applies only to Markdown headings. +# Minimum value: 0, maximum value: 99, default value: 5. 
+# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +TOC_INCLUDE_HEADINGS = 5 + +# When enabled doxygen tries to link words that correspond to documented +# classes, or namespaces to their corresponding documentation. Such a link can +# be prevented in individual cases by putting a % sign in front of the word or +# globally by setting AUTOLINK_SUPPORT to NO. +# The default value is: YES. + +AUTOLINK_SUPPORT = YES + +# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want +# to include (a tag file for) the STL sources as input, then you should set this +# tag to YES in order to let doxygen match functions declarations and +# definitions whose arguments contain STL classes (e.g. func(std::string); +# versus func(std::string) {}). This also make the inheritance and collaboration +# diagrams that involve STL classes more complete and accurate. +# The default value is: NO. + +BUILTIN_STL_SUPPORT = NO + +# If you use Microsoft's C++/CLI language, you should set this option to YES to +# enable parsing support. +# The default value is: NO. + +CPP_CLI_SUPPORT = NO + +# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: +# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen +# will parse them like normal C++ but will assume all classes use public instead +# of private inheritance when no explicit protection keyword is present. +# The default value is: NO. + +SIP_SUPPORT = NO + +# For Microsoft's IDL there are propget and propput attributes to indicate +# getter and setter methods for a property. Setting this option to YES will make +# doxygen to replace the get and set methods by a property in the documentation. +# This will only work if the methods are indeed getting or setting a simple +# type. If this is not the case, or you want to show the methods anyway, you +# should set this option to NO. +# The default value is: YES. 
+ +IDL_PROPERTY_SUPPORT = YES + +# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC +# tag is set to YES then doxygen will reuse the documentation of the first +# member in the group (if any) for the other members of the group. By default +# all members of a group must be documented explicitly. +# The default value is: NO. + +DISTRIBUTE_GROUP_DOC = NO + +# If one adds a struct or class to a group and this option is enabled, then also +# any nested class or struct is added to the same group. By default this option +# is disabled and one has to add nested compounds explicitly via \ingroup. +# The default value is: NO. + +GROUP_NESTED_COMPOUNDS = NO + +# Set the SUBGROUPING tag to YES to allow class member groups of the same type +# (for instance a group of public functions) to be put as a subgroup of that +# type (e.g. under the Public Functions section). Set it to NO to prevent +# subgrouping. Alternatively, this can be done per class using the +# \nosubgrouping command. +# The default value is: YES. + +SUBGROUPING = YES + +# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions +# are shown inside the group in which they are included (e.g. using \ingroup) +# instead of on a separate page (for HTML and Man pages) or section (for LaTeX +# and RTF). +# +# Note that this feature does not work in combination with +# SEPARATE_MEMBER_PAGES. +# The default value is: NO. + +INLINE_GROUPED_CLASSES = NO + +# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions +# with only public data fields or simple typedef fields will be shown inline in +# the documentation of the scope in which they are defined (i.e. file, +# namespace, or group documentation), provided this scope is documented. If set +# to NO, structs, classes, and unions are shown on a separate page (for HTML and +# Man pages) or section (for LaTeX and RTF). +# The default value is: NO. 
+ +INLINE_SIMPLE_STRUCTS = NO + +# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or +# enum is documented as struct, union, or enum with the name of the typedef. So +# typedef struct TypeS {} TypeT, will appear in the documentation as a struct +# with name TypeT. When disabled the typedef will appear as a member of a file, +# namespace, or class. And the struct will be named TypeS. This can typically be +# useful for C code in case the coding convention dictates that all compound +# types are typedef'ed and only the typedef is referenced, never the tag name. +# The default value is: NO. + +TYPEDEF_HIDES_STRUCT = NO + +# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This +# cache is used to resolve symbols given their name and scope. Since this can be +# an expensive process and often the same symbol appears multiple times in the +# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small +# doxygen will become slower. If the cache is too large, memory is wasted. The +# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range +# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 +# symbols. At the end of a run doxygen will report the cache usage and suggest +# the optimal cache size from a speed point of view. +# Minimum value: 0, maximum value: 9, default value: 0. + +LOOKUP_CACHE_SIZE = 0 + +#--------------------------------------------------------------------------- +# Build related configuration options +#--------------------------------------------------------------------------- + +# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in +# documentation are documented, even if no documentation was available. Private +# class members and static file members will be hidden unless the +# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. 
+# Note: This will also disable the warnings about undocumented members that are +# normally produced when WARNINGS is set to YES. +# The default value is: NO. + +EXTRACT_ALL = YES + +# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will +# be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIVATE = YES + +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIV_VIRTUAL = YES + +# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal +# scope will be included in the documentation. +# The default value is: NO. + +EXTRACT_PACKAGE = YES + +# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be +# included in the documentation. +# The default value is: NO. + +EXTRACT_STATIC = YES + +# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined +# locally in source files will be included in the documentation. If set to NO, +# only classes defined in header files are included. Does not have any effect +# for Java sources. +# The default value is: YES. + +EXTRACT_LOCAL_CLASSES = YES + +# This flag is only useful for Objective-C code. If set to YES, local methods, +# which are defined in the implementation section but not in the interface are +# included in the documentation. If set to NO, only methods in the interface are +# included. +# The default value is: NO. + +EXTRACT_LOCAL_METHODS = YES + +# If this flag is set to YES, the members of anonymous namespaces will be +# extracted and appear in the documentation as a namespace called +# 'anonymous_namespace{file}', where file will be replaced with the base name of +# the file that contains the anonymous namespace. By default anonymous namespace +# are hidden. +# The default value is: NO. 
+ +EXTRACT_ANON_NSPACES = NO + +# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all +# undocumented members inside documented classes or files. If set to NO these +# members will be included in the various overviews, but no documentation +# section is generated. This option has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_MEMBERS = NO + +# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all +# undocumented classes that are normally visible in the class hierarchy. If set +# to NO, these classes will be included in the various overviews. This option +# has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_CLASSES = NO + +# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend +# declarations. If set to NO, these declarations will be included in the +# documentation. +# The default value is: NO. + +HIDE_FRIEND_COMPOUNDS = NO + +# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any +# documentation blocks found inside the body of a function. If set to NO, these +# blocks will be appended to the function's detailed documentation block. +# The default value is: NO. + +HIDE_IN_BODY_DOCS = NO + +# The INTERNAL_DOCS tag determines if documentation that is typed after a +# \internal command is included. If the tag is set to NO then the documentation +# will be excluded. Set it to YES to include the internal documentation. +# The default value is: NO. + +INTERNAL_DOCS = NO + +# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file +# names in lower-case letters. If set to YES, upper-case letters are also +# allowed. This is useful if you have classes or files whose names only differ +# in case and if your file system supports case sensitive file names. Windows +# (including Cygwin) ands Mac users are advised to set this option to NO. +# The default value is: system dependent. 
+ +CASE_SENSE_NAMES = YES + +# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with +# their full class and namespace scopes in the documentation. If set to YES, the +# scope will be hidden. +# The default value is: NO. + +HIDE_SCOPE_NAMES = NO + +# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will +# append additional text to a page's title, such as Class Reference. If set to +# YES the compound reference will be hidden. +# The default value is: NO. + +HIDE_COMPOUND_REFERENCE= NO + +# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of +# the files that are included by a file in the documentation of that file. +# The default value is: YES. + +SHOW_INCLUDE_FILES = YES + +# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each +# grouped member an include statement to the documentation, telling the reader +# which file to include in order to use the member. +# The default value is: NO. + +SHOW_GROUPED_MEMB_INC = NO + +# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include +# files with double quotes in the documentation rather than with sharp brackets. +# The default value is: NO. + +FORCE_LOCAL_INCLUDES = NO + +# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the +# documentation for inline members. +# The default value is: YES. + +INLINE_INFO = YES + +# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the +# (detailed) documentation of file and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. +# The default value is: YES. + +SORT_MEMBER_DOCS = YES + +# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief +# descriptions of file, namespace and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. 
Note that +# this will also influence the order of the classes in the class list. +# The default value is: NO. + +SORT_BRIEF_DOCS = NO + +# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the +# (brief and detailed) documentation of class members so that constructors and +# destructors are listed first. If set to NO the constructors will appear in the +# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. +# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief +# member documentation. +# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting +# detailed member documentation. +# The default value is: NO. + +SORT_MEMBERS_CTORS_1ST = NO + +# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy +# of group names into alphabetical order. If set to NO the group names will +# appear in their defined order. +# The default value is: NO. + +SORT_GROUP_NAMES = NO + +# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by +# fully-qualified names, including namespaces. If set to NO, the class list will +# be sorted only by class name, not including the namespace part. +# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. +# Note: This option applies only to the class list, not to the alphabetical +# list. +# The default value is: NO. + +SORT_BY_SCOPE_NAME = NO + +# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper +# type resolution of all parameters of a function it will reject a match between +# the prototype and the implementation of a member function even if there is +# only one candidate or it is obvious which candidate to choose by doing a +# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still +# accept a match between prototype and implementation in such cases. +# The default value is: NO. 
+ +STRICT_PROTO_MATCHING = NO + +# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo +# list. This list is created by putting \todo commands in the documentation. +# The default value is: YES. + +GENERATE_TODOLIST = YES + +# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test +# list. This list is created by putting \test commands in the documentation. +# The default value is: YES. + +GENERATE_TESTLIST = YES + +# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug +# list. This list is created by putting \bug commands in the documentation. +# The default value is: YES. + +GENERATE_BUGLIST = YES + +# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) +# the deprecated list. This list is created by putting \deprecated commands in +# the documentation. +# The default value is: YES. + +GENERATE_DEPRECATEDLIST= YES + +# The ENABLED_SECTIONS tag can be used to enable conditional documentation +# sections, marked by \if ... \endif and \cond +# ... \endcond blocks. + +ENABLED_SECTIONS = + +# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the +# initial value of a variable or macro / define can have for it to appear in the +# documentation. If the initializer consists of more lines than specified here +# it will be hidden. Use a value of 0 to hide initializers completely. The +# appearance of the value of individual variables and macros / defines can be +# controlled using \showinitializer or \hideinitializer command in the +# documentation regardless of this setting. +# Minimum value: 0, maximum value: 10000, default value: 30. + +MAX_INITIALIZER_LINES = 30 + +# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at +# the bottom of the documentation of classes and structs. If set to YES, the +# list will mention the files that were used to generate the documentation. +# The default value is: YES. 
+ +SHOW_USED_FILES = YES + +# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This +# will remove the Files entry from the Quick Index and from the Folder Tree View +# (if specified). +# The default value is: YES. + +SHOW_FILES = YES + +# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces +# page. This will remove the Namespaces entry from the Quick Index and from the +# Folder Tree View (if specified). +# The default value is: YES. + +SHOW_NAMESPACES = YES + +# The FILE_VERSION_FILTER tag can be used to specify a program or script that +# doxygen should invoke to get the current version for each file (typically from +# the version control system). Doxygen will invoke the program by executing (via +# popen()) the command command input-file, where command is the value of the +# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided +# by doxygen. Whatever the program writes to standard output is used as the file +# version. For an example see the documentation. + +FILE_VERSION_FILTER = + +# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed +# by doxygen. The layout file controls the global structure of the generated +# output files in an output format independent way. To create the layout file +# that represents doxygen's defaults, run doxygen with the -l option. You can +# optionally specify a file name after the option, if omitted DoxygenLayout.xml +# will be used as the name of the layout file. +# +# Note that if you run doxygen from a directory containing a file called +# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE +# tag is left empty. + +LAYOUT_FILE = + +# The CITE_BIB_FILES tag can be used to specify one or more bib files containing +# the reference definitions. This must be a list of .bib files. The .bib +# extension is automatically appended if omitted. This requires the bibtex tool +# to be installed. 
See also https://en.wikipedia.org/wiki/BibTeX for more info. +# For LaTeX the style of the bibliography can be controlled using +# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the +# search path. See also \cite for info how to create references. + +CITE_BIB_FILES = + +#--------------------------------------------------------------------------- +# Configuration options related to warning and progress messages +#--------------------------------------------------------------------------- + +# The QUIET tag can be used to turn on/off the messages that are generated to +# standard output by doxygen. If QUIET is set to YES this implies that the +# messages are off. +# The default value is: NO. + +QUIET = NO + +# The WARNINGS tag can be used to turn on/off the warning messages that are +# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES +# this implies that the warnings are on. +# +# Tip: Turn warnings on while writing the documentation. +# The default value is: YES. + +WARNINGS = YES + +# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate +# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: YES. + +WARN_IF_UNDOCUMENTED = YES + +# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for +# potential errors in the documentation, such as not documenting some parameters +# in a documented function, or documenting parameters that don't exist or using +# markup commands wrongly. +# The default value is: YES. + +WARN_IF_DOC_ERROR = YES + +# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that +# are documented, but have no documentation for their parameters or return +# value. If set to NO, doxygen will only warn about wrong or incomplete +# parameter documentation, but not about the absence of documentation. 
If +# EXTRACT_ALL is set to YES then this flag will automatically be disabled. +# The default value is: NO. + +WARN_NO_PARAMDOC = NO + +# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when +# a warning is encountered. +# The default value is: NO. + +WARN_AS_ERROR = NO + +# The WARN_FORMAT tag determines the format of the warning messages that doxygen +# can produce. The string should contain the $file, $line, and $text tags, which +# will be replaced by the file and line number from which the warning originated +# and the warning text. Optionally the format may contain $version, which will +# be replaced by the version of the file (if it could be obtained via +# FILE_VERSION_FILTER) +# The default value is: $file:$line: $text. + +WARN_FORMAT = "$file:$line: $text" + +# The WARN_LOGFILE tag can be used to specify a file to which warning and error +# messages should be written. If left blank the output is written to standard +# error (stderr). + +WARN_LOGFILE = + +#--------------------------------------------------------------------------- +# Configuration options related to the input files +#--------------------------------------------------------------------------- + +# The INPUT tag is used to specify the files and/or directories that contain +# documented source files. You may enter file names like myfile.cpp or +# directories like /usr/src/myproject. Separate the files or directories with +# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING +# Note: If this tag is empty the current directory is searched. + +INPUT = + +# This tag can be used to specify the character encoding of the source files +# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses +# libiconv (or the iconv built into libc) for the transcoding. See the libiconv +# documentation (see: https://www.gnu.org/software/libiconv/) for the list of +# possible encodings. +# The default value is: UTF-8. 
+ +INPUT_ENCODING = UTF-8 + +# If the value of the INPUT tag contains directories, you can use the +# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and +# *.h) to filter out the source-files in the directories. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# read by doxygen. +# +# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, +# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, +# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, +# *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C comment), +# *.doc (to be provided as doxygen C comment), *.txt (to be provided as doxygen +# C comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f, *.for, *.tcl, *.vhd, +# *.vhdl, *.ucf, *.qsf and *.ice. + +FILE_PATTERNS = *.c \ + *.cc \ + *.cxx \ + *.cpp \ + *.c++ \ + *.java \ + *.ii \ + *.ixx \ + *.ipp \ + *.i++ \ + *.inl \ + *.idl \ + *.ddl \ + *.odl \ + *.h \ + *.hh \ + *.hxx \ + *.hpp \ + *.h++ \ + *.cs \ + *.d \ + *.php \ + *.php4 \ + *.php5 \ + *.phtml \ + *.inc \ + *.m \ + *.markdown \ + *.md \ + *.mm \ + *.dox \ + *.doc \ + *.txt \ + *.py \ + *.pyw \ + *.f90 \ + *.f95 \ + *.f03 \ + *.f08 \ + *.f \ + *.for \ + *.tcl \ + *.vhd \ + *.vhdl \ + *.ucf \ + *.qsf \ + *.ice + +# The RECURSIVE tag can be used to specify whether or not subdirectories should +# be searched for input files as well. +# The default value is: NO. + +RECURSIVE = YES + +# The EXCLUDE tag can be used to specify files and/or directories that should be +# excluded from the INPUT source files. This way you can easily exclude a +# subdirectory from a directory tree whose root is specified with the INPUT tag. +# +# Note that relative paths are relative to the directory from which doxygen is +# run. 
+ +EXCLUDE = + +# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or +# directories that are symbolic links (a Unix file system feature) are excluded +# from the input. +# The default value is: NO. + +EXCLUDE_SYMLINKS = NO + +# If the value of the INPUT tag contains directories, you can use the +# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude +# certain files from those directories. +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories for example use the pattern */test/* + +EXCLUDE_PATTERNS = + +# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names +# (namespaces, classes, functions, etc.) that should be excluded from the +# output. The symbol name can be a fully qualified name, a word, or if the +# wildcard * is used, a substring. Examples: ANamespace, AClass, +# AClass::ANamespace, ANamespace::*Test +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories use the pattern */test/* + +EXCLUDE_SYMBOLS = + +# The EXAMPLE_PATH tag can be used to specify one or more files or directories +# that contain example code fragments that are included (see the \include +# command). + +EXAMPLE_PATH = + +# If the value of the EXAMPLE_PATH tag contains directories, you can use the +# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and +# *.h) to filter out the source-files in the directories. If left blank all +# files are included. + +EXAMPLE_PATTERNS = * + +# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be +# searched for input files to be used with the \include or \dontinclude commands +# irrespective of the value of the RECURSIVE tag. +# The default value is: NO. 
+
+EXAMPLE_RECURSIVE = NO
+
+# The IMAGE_PATH tag can be used to specify one or more files or directories
+# that contain images that are to be included in the documentation (see the
+# \image command).
+
+IMAGE_PATH =
+
+# The INPUT_FILTER tag can be used to specify a program that doxygen should
+# invoke to filter for each input file. Doxygen will invoke the filter program
+# by executing (via popen()) the command:
+#
+# <filter> <input-file>
+#
+# where <filter> is the value of the INPUT_FILTER tag, and <input-file> is the
+# name of an input file. Doxygen will then use the output that the filter
+# program writes to standard output. If FILTER_PATTERNS is specified, this tag
+# will be ignored.
+#
+# Note that the filter must not add or remove lines; it is applied before the
+# code is scanned, but not when the output code is generated. If lines are added
+# or removed, the anchors will not be placed correctly.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
+
+INPUT_FILTER =
+
+# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern
+# basis. Doxygen will compare the file name with each pattern and apply the
+# filter if there is a match. The filters are a list of the form: pattern=filter
+# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how
+# filters are used. If the FILTER_PATTERNS tag is empty or if none of the
+# patterns match the file name, INPUT_FILTER is applied.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
+
+FILTER_PATTERNS =
+
+# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
+# INPUT_FILTER) will also be used to filter the input files that are used for
+# producing the source files to browse (i.e. 
when SOURCE_BROWSER is set to YES). +# The default value is: NO. + +FILTER_SOURCE_FILES = NO + +# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file +# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and +# it is also possible to disable source filtering for a specific pattern using +# *.ext= (so without naming a filter). +# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. + +FILTER_SOURCE_PATTERNS = + +# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that +# is part of the input, its contents will be placed on the main page +# (index.html). This can be useful if you have a project on for instance GitHub +# and want to reuse the introduction page also for the doxygen output. + +USE_MDFILE_AS_MAINPAGE = + +#--------------------------------------------------------------------------- +# Configuration options related to source browsing +#--------------------------------------------------------------------------- + +# If the SOURCE_BROWSER tag is set to YES then a list of source files will be +# generated. Documented entities will be cross-referenced with these sources. +# +# Note: To get rid of all source code in the generated output, make sure that +# also VERBATIM_HEADERS is set to NO. +# The default value is: NO. + +SOURCE_BROWSER = NO + +# Setting the INLINE_SOURCES tag to YES will include the body of functions, +# classes and enums directly into the documentation. +# The default value is: NO. + +INLINE_SOURCES = NO + +# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any +# special comment blocks from generated source code fragments. Normal C, C++ and +# Fortran comments will always remain visible. +# The default value is: YES. + +STRIP_CODE_COMMENTS = YES + +# If the REFERENCED_BY_RELATION tag is set to YES then for each documented +# entity all documented functions referencing it will be listed. +# The default value is: NO. 
+ +REFERENCED_BY_RELATION = NO + +# If the REFERENCES_RELATION tag is set to YES then for each documented function +# all documented entities called/used by that function will be listed. +# The default value is: NO. + +REFERENCES_RELATION = NO + +# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set +# to YES then the hyperlinks from functions in REFERENCES_RELATION and +# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will +# link to the documentation. +# The default value is: YES. + +REFERENCES_LINK_SOURCE = YES + +# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the +# source code will show a tooltip with additional information such as prototype, +# brief description and links to the definition and documentation. Since this +# will make the HTML file larger and loading of large files a bit slower, you +# can opt to disable this feature. +# The default value is: YES. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +SOURCE_TOOLTIPS = YES + +# If the USE_HTAGS tag is set to YES then the references to source code will +# point to the HTML generated by the htags(1) tool instead of doxygen built-in +# source browser. The htags tool is part of GNU's global source tagging system +# (see https://www.gnu.org/software/global/global.html). You will need version +# 4.8.6 or higher. +# +# To use it do the following: +# - Install the latest version of global +# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file +# - Make sure the INPUT points to the root of the source tree +# - Run doxygen as normal +# +# Doxygen will invoke htags (and that will in turn invoke gtags), so these +# tools must be available from the command line (i.e. in the search path). +# +# The result: instead of the source browser generated by doxygen, the links to +# source code will now point to the output of htags. +# The default value is: NO. 
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+USE_HTAGS = NO
+
+# If the VERBATIM_HEADERS tag is set to YES then doxygen will generate a
+# verbatim copy of the header file for each class for which an include is
+# specified. Set to NO to disable this.
+# See also: Section \class.
+# The default value is: YES.
+
+VERBATIM_HEADERS = YES
+
+# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the
+# clang parser (see: http://clang.llvm.org/) for more accurate parsing at the
+# cost of reduced performance. This can be particularly helpful with template
+# rich C++ code for which doxygen's built-in parser lacks the necessary type
+# information.
+# Note: The availability of this option depends on whether or not doxygen was
+# generated with the -Duse_libclang=ON option for CMake.
+# The default value is: NO.
+
+CLANG_ASSISTED_PARSING = NO
+
+# If clang assisted parsing is enabled you can provide the compiler with command
+# line options that you would normally use when invoking the compiler. Note that
+# the include paths will already be set by doxygen for the files and directories
+# specified with INPUT and INCLUDE_PATH.
+# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES.
+
+CLANG_OPTIONS =
+
+# If clang assisted parsing is enabled you can provide the clang parser with the
+# path to the compilation database (see:
+# http://clang.llvm.org/docs/HowToSetupToolingForLLVM.html) used when the files
+# were built. This is equivalent to specifying the "-p" option to a clang tool,
+# such as clang-check. These options will then be passed to the parser.
+# Note: The availability of this option depends on whether or not doxygen was
+# generated with the -Duse_libclang=ON option for CMake. 
+ +CLANG_DATABASE_PATH = + +#--------------------------------------------------------------------------- +# Configuration options related to the alphabetical class index +#--------------------------------------------------------------------------- + +# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all +# compounds will be generated. Enable this if the project contains a lot of +# classes, structs, unions or interfaces. +# The default value is: YES. + +ALPHABETICAL_INDEX = YES + +# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in +# which the alphabetical index list will be split. +# Minimum value: 1, maximum value: 20, default value: 5. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + +COLS_IN_ALPHA_INDEX = 5 + +# In case all classes in a project start with a common prefix, all classes will +# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag +# can be used to specify a prefix (or a list of prefixes) that should be ignored +# while generating the index headers. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + +IGNORE_PREFIX = + +#--------------------------------------------------------------------------- +# Configuration options related to the HTML output +#--------------------------------------------------------------------------- + +# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output +# The default value is: YES. + +GENERATE_HTML = YES + +# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a +# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of +# it. +# The default directory is: html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_OUTPUT = html + +# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each +# generated HTML page (for example: .htm, .php, .asp). +# The default value is: .html. 
+# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FILE_EXTENSION = .html + +# The HTML_HEADER tag can be used to specify a user-defined HTML header file for +# each generated HTML page. If the tag is left blank doxygen will generate a +# standard header. +# +# To get valid HTML the header file that includes any scripts and style sheets +# that doxygen needs, which is dependent on the configuration options used (e.g. +# the setting GENERATE_TREEVIEW). It is highly recommended to start with a +# default header using +# doxygen -w html new_header.html new_footer.html new_stylesheet.css +# YourConfigFile +# and then modify the file new_header.html. See also section "Doxygen usage" +# for information on how to generate the default header that doxygen normally +# uses. +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. For a description +# of the possible markers and block names see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_HEADER = + +# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each +# generated HTML page. If the tag is left blank doxygen will generate a standard +# footer. See HTML_HEADER for more information on how to generate a default +# footer and what special commands can be used inside the footer. See also +# section "Doxygen usage" for information on how to generate the default footer +# that doxygen normally uses. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FOOTER = + +# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style +# sheet that is used by each HTML page. It can be used to fine-tune the look of +# the HTML output. If left blank doxygen will generate a default style sheet. +# See also section "Doxygen usage" for information on how to generate the style +# sheet that doxygen normally uses. 
+# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as +# it is more robust and this tag (HTML_STYLESHEET) will in the future become +# obsolete. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_STYLESHEET = + +# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined +# cascading style sheets that are included after the standard style sheets +# created by doxygen. Using this option one can overrule certain style aspects. +# This is preferred over using HTML_STYLESHEET since it does not replace the +# standard style sheet and is therefore more robust against future updates. +# Doxygen will copy the style sheet files to the output directory. +# Note: The order of the extra style sheet files is of importance (e.g. the last +# style sheet in the list overrules the setting of the previous ones in the +# list). For an example see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_STYLESHEET = + +# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or +# other source files which should be copied to the HTML output directory. Note +# that these files will be copied to the base HTML output directory. Use the +# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these +# files. In the HTML_STYLESHEET file, use the file name only. Also note that the +# files will be copied as-is; there are no commands or markers available. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_FILES = + +# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen +# will adjust the colors in the style sheet and background images according to +# this color. Hue is specified as an angle on a colorwheel, see +# https://en.wikipedia.org/wiki/Hue for more information. For instance the value +# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 +# purple, and 360 is red again. 
+# Minimum value: 0, maximum value: 359, default value: 220. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_HUE = 220 + +# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors +# in the HTML output. For a value of 0 the output will use grayscales only. A +# value of 255 will produce the most vivid colors. +# Minimum value: 0, maximum value: 255, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_SAT = 100 + +# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the +# luminance component of the colors in the HTML output. Values below 100 +# gradually make the output lighter, whereas values above 100 make the output +# darker. The value divided by 100 is the actual gamma applied, so 80 represents +# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not +# change the gamma. +# Minimum value: 40, maximum value: 240, default value: 80. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_GAMMA = 80 + +# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML +# page will contain the date and time when the page was generated. Setting this +# to YES can help to show when doxygen was last run and thus if the +# documentation is up to date. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_TIMESTAMP = NO + +# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML +# documentation will contain a main index with vertical navigation menus that +# are dynamically created via JavaScript. If disabled, the navigation index will +# consists of multiple levels of tabs that are statically embedded in every HTML +# page. Disable this option to support browsers that do not have JavaScript, +# like the Qt help browser. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. 
+ +HTML_DYNAMIC_MENUS = YES + +# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML +# documentation will contain sections that can be hidden and shown after the +# page has loaded. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_DYNAMIC_SECTIONS = NO + +# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries +# shown in the various tree structured indices initially; the user can expand +# and collapse entries dynamically later on. Doxygen will expand the tree to +# such a level that at most the specified number of entries are visible (unless +# a fully collapsed tree already exceeds this amount). So setting the number of +# entries 1 will produce a full collapsed tree by default. 0 is a special value +# representing an infinite number of entries and will result in a full expanded +# tree by default. +# Minimum value: 0, maximum value: 9999, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_INDEX_NUM_ENTRIES = 100 + +# If the GENERATE_DOCSET tag is set to YES, additional index files will be +# generated that can be used as input for Apple's Xcode 3 integrated development +# environment (see: https://developer.apple.com/xcode/), introduced with OSX +# 10.5 (Leopard). To create a documentation set, doxygen will generate a +# Makefile in the HTML output directory. Running make will produce the docset in +# that directory and running make install will install the docset in +# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at +# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy +# genXcode/_index.html for more information. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_DOCSET = NO + +# This tag determines the name of the docset feed. 
A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# The default value is: Doxygen generated docs. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDNAME = "Doxygen generated docs" + +# This tag specifies a string that should uniquely identify the documentation +# set bundle. This should be a reverse domain-name style string, e.g. +# com.mycompany.MyDocSet. Doxygen will append .docset to the name. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_BUNDLE_ID = org.doxygen.Project + +# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify +# the documentation publisher. This should be a reverse domain-name style +# string, e.g. com.mycompany.MyDocSet.documentation. +# The default value is: org.doxygen.Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_ID = org.doxygen.Publisher + +# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. +# The default value is: Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_NAME = Publisher + +# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three +# additional HTML index files: index.hhp, index.hhc, and index.hhk. The +# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop +# (see: https://www.microsoft.com/en-us/download/details.aspx?id=21138) on +# Windows. +# +# The HTML Help Workshop contains a compiler that can convert all HTML output +# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML +# files are now used as the Windows 98 help format, and will replace the old +# Windows help format (.hlp) on all Windows platforms in the future. 
Compressed +# HTML files also contain an index, a table of contents, and you can search for +# words in the documentation. The HTML workshop also contains a viewer for +# compressed HTML files. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_HTMLHELP = NO + +# The CHM_FILE tag can be used to specify the file name of the resulting .chm +# file. You can add a path in front of the file if the result should not be +# written to the html output directory. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_FILE = + +# The HHC_LOCATION tag can be used to specify the location (absolute path +# including file name) of the HTML help compiler (hhc.exe). If non-empty, +# doxygen will try to run the HTML help compiler on the generated index.hhp. +# The file has to be specified with full path. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +HHC_LOCATION = + +# The GENERATE_CHI flag controls if a separate .chi index file is generated +# (YES) or that it should be included in the master .chm file (NO). +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +GENERATE_CHI = NO + +# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) +# and project file content. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_INDEX_ENCODING = + +# The BINARY_TOC flag controls whether a binary table of contents is generated +# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it +# enables the Previous and Next buttons. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +BINARY_TOC = NO + +# The TOC_EXPAND flag can be set to YES to add extra items for group members to +# the table of contents of the HTML help documentation and to the tree view. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. 
+ +TOC_EXPAND = NO + +# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and +# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that +# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help +# (.qch) of the generated HTML documentation. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_QHP = NO + +# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify +# the file name of the resulting .qch file. The path specified is relative to +# the HTML output folder. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QCH_FILE = + +# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help +# Project output. For more information please see Qt Help Project / Namespace +# (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace). +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_NAMESPACE = org.doxygen.Project + +# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt +# Help Project output. For more information please see Qt Help Project / Virtual +# Folders (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual- +# folders). +# The default value is: doc. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_VIRTUAL_FOLDER = doc + +# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom +# filter to add. For more information please see Qt Help Project / Custom +# Filters (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom- +# filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_NAME = + +# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the +# custom filter to add. 
For more information please see Qt Help Project / Custom +# Filters (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom- +# filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_ATTRS = + +# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this +# project's filter section matches. Qt Help Project / Filter Attributes (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_SECT_FILTER_ATTRS = + +# The QHG_LOCATION tag can be used to specify the location of Qt's +# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the +# generated .qhp file. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHG_LOCATION = + +# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be +# generated, together with the HTML files, they form an Eclipse help plugin. To +# install this plugin and make it available under the help contents menu in +# Eclipse, the contents of the directory containing the HTML and XML files needs +# to be copied into the plugins directory of eclipse. The name of the directory +# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. +# After copying Eclipse needs to be restarted before the help appears. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_ECLIPSEHELP = NO + +# A unique identifier for the Eclipse help plugin. When installing the plugin +# the directory name containing the HTML and XML files should also have this +# name. Each documentation set should have its own identifier. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. 
+ +ECLIPSE_DOC_ID = org.doxygen.Project + +# If you want full control over the layout of the generated HTML pages it might +# be necessary to disable the index and replace it with your own. The +# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top +# of each HTML page. A value of NO enables the index and the value YES disables +# it. Since the tabs in the index contain the same information as the navigation +# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +DISABLE_INDEX = NO + +# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index +# structure should be generated to display hierarchical information. If the tag +# value is set to YES, a side panel will be generated containing a tree-like +# index structure (just like the one that is generated for HTML Help). For this +# to work a browser that supports JavaScript, DHTML, CSS and frames is required +# (i.e. any modern browser). Windows users are probably better off using the +# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can +# further fine-tune the look of the index. As an example, the default style +# sheet generated by doxygen has an example that shows how to put an image at +# the root of the tree instead of the PROJECT_NAME. Since the tree basically has +# the same information as the tab index, you could consider setting +# DISABLE_INDEX to YES when enabling this option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_TREEVIEW = NO + +# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that +# doxygen will group on one line in the generated HTML documentation. +# +# Note that a value of 0 will completely suppress the enum values from appearing +# in the overview section. +# Minimum value: 0, maximum value: 20, default value: 4. 
+# This tag requires that the tag GENERATE_HTML is set to YES. + +ENUM_VALUES_PER_LINE = 4 + +# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used +# to set the initial width (in pixels) of the frame in which the tree is shown. +# Minimum value: 0, maximum value: 1500, default value: 250. +# This tag requires that the tag GENERATE_HTML is set to YES. + +TREEVIEW_WIDTH = 250 + +# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to +# external symbols imported via tag files in a separate window. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +EXT_LINKS_IN_WINDOW = NO + +# Use this tag to change the font size of LaTeX formulas included as images in +# the HTML documentation. When you change the font size after a successful +# doxygen run you need to manually remove any form_*.png images from the HTML +# output directory to force them to be regenerated. +# Minimum value: 8, maximum value: 50, default value: 10. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FORMULA_FONTSIZE = 10 + +# Use the FORMULA_TRANSPARENT tag to determine whether or not the images +# generated for formulas are transparent PNGs. Transparent PNGs are not +# supported properly for IE 6.0, but are supported on all modern browsers. +# +# Note that when changing this option you need to delete any form_*.png files in +# the HTML output directory before the changes have effect. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FORMULA_TRANSPARENT = YES + +# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands +# to create new LaTeX commands to be used in formulas as building blocks. See +# the section "Including formulas" for details. 
+
+FORMULA_MACROFILE =
+
+# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
+# https://www.mathjax.org) which uses client side JavaScript for the rendering
+# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX
+# installed or if you want the formulas to look prettier in the HTML output.
+# When enabled you may also need to install MathJax separately and configure
+# the path to it using the MATHJAX_RELPATH option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+USE_MATHJAX = YES
+
+# When MathJax is enabled you can set the default output format to be used for
+# the MathJax output. See the MathJax site (see:
+# http://docs.mathjax.org/en/latest/output.html) for more details.
+# Possible values are: HTML-CSS (which is slower, but has the best
+# compatibility), NativeMML (i.e. MathML) and SVG.
+# The default value is: HTML-CSS.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_FORMAT = HTML-CSS
+
+# When MathJax is enabled you need to specify the location relative to the HTML
+# output directory using the MATHJAX_RELPATH option. The destination directory
+# should contain the MathJax.js script. For instance, if the mathjax directory
+# is located at the same level as the HTML output directory, then
+# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
+# Content Delivery Network so you can quickly see the result without installing
+# MathJax. However, it is strongly recommended to install a local copy of
+# MathJax from https://www.mathjax.org before deployment.
+# The default value is: https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_RELPATH = https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/
+
+# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
+# extension names that should be enabled during MathJax rendering.
For example
+# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_EXTENSIONS =
+
+# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
+# of code that will be used on startup of the MathJax code. See the MathJax site
+# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an
+# example see the documentation.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_CODEFILE =
+
+# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
+# the HTML output. The underlying search engine uses javascript and DHTML and
+# should work on any modern browser. Note that when using HTML help
+# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
+# there is already a search function so this one should typically be disabled.
+# For large projects the javascript based search engine can be slow, then
+# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to
+# search using the keyboard; to jump to the search box use <access key> + S
+# (what the <access key> is depends on the OS and browser, but it is typically
+# <CTRL>, <ALT>/