diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 545e1d8..3007db5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,343 +2,359 @@ name: CI on: push: - branches: - - main + branches: + - 'main' paths-ignore: - - 'README' + - 'README*' - 'COPYRIGHT' - 'LICENSE-*' - '**.md' - '**.txt' pull_request: paths-ignore: - - 'README' + - 'README*' - 'COPYRIGHT' - 'LICENSE-*' - '**.md' - '**.txt' workflow_dispatch: - schedule: [cron: "0 1 */7 * *"] + schedule: + - cron: "0 1 1 * *" env: CARGO_TERM_COLOR: always RUSTFLAGS: -Dwarnings RUST_BACKTRACE: 1 +# Notes on the matrix +# ------------------- +# `whispercpp-sys` invokes cmake on the bundled `whisper.cpp/` +# submodule and emits link directives for the resulting static +# libraries. Three consequences for CI: +# +# * Every checkout that builds the bundled crate MUST set +# `submodules: recursive`. The default +# `actions/checkout@v6` strips them, leaving build.rs to fail +# at the cmake step with a missing `CMakeLists.txt`. +# * Cross-compilation to wasm / mobile / niche-arch targets +# cannot drive cmake against a vendored C++ tree without +# per-target sysroot work. We pin the supported matrix to +# ubuntu / macos x86_64+aarch64 / windows. +# * GPU backends (cuda / hipblas / sycl / vulkan / opencl / +# musa / openvino) all need vendor SDKs not present on +# GitHub-hosted runners. CI exercises `bundled` (CPU) and, +# on macOS only, `metal,coreml`. +# +# Safety gates layered: +# +# * `rustfmt` / `clippy` — style + lint floor (deny-warnings) +# * `build-bundled` — the bundled CMake link path on +# ubuntu/macos/windows +# * `build-macos-gpu` — Apple Silicon Metal + CoreML link +# * `test` — `cargo test --lib` on linux/macos +# * `sanitizer` — Linux ASan + UBSan against the +# bundled FFI smoke test (the +# safety surface that flashed under +# every codex round). +# Reinstated after an earlier +# refactor accidentally dropped it. 
+# * `miri` — Miri on the safe wrapper crate's +# tests. Skipped FFI-touching tests +# via cfg(miri); covers the safe +# Rust invariants (atomic ordering +# on Context poison flag, drop +# order, lang serde, etc.). +# * `doc` — rustdoc -Dwarnings + jobs: - # Check formatting (platform-independent, one OS is enough) rustfmt: name: rustfmt runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 - - name: Install Rust - run: rustup update stable && rustup default stable && rustup component add rustfmt - - name: Check formatting - run: cargo fmt --all -- --check + - uses: actions/checkout@v6 + - run: rustup update stable && rustup default stable && rustup component add rustfmt + - run: cargo fmt --all -- --check - # Apply clippy lints clippy: name: clippy strategy: + fail-fast: false matrix: - os: - - ubuntu-latest - - macos-latest - - windows-latest + os: [ubuntu-latest, macos-latest] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v6 - - name: Install Rust - # --no-self-update is necessary because the windows environment cannot self-update rustup.exe. 
- run: rustup update stable --no-self-update && rustup default stable && rustup component add clippy - - name: Install cargo-hack - run: cargo install cargo-hack - - name: Apply clippy lints - run: cargo hack clippy --each-feature --exclude-no-default-features + - uses: actions/checkout@v6 + with: + submodules: recursive + - run: rustup update stable && rustup default stable && rustup component add clippy + - run: cargo clippy --workspace --all-targets - # Run tests on some extra platforms - cross: - name: cross + build-bundled: + name: build (bundled CPU) on ${{ matrix.os }} strategy: + fail-fast: false matrix: - target: - - aarch64-unknown-linux-gnu - - aarch64-linux-android - - aarch64-unknown-linux-musl - - i686-linux-android - - x86_64-linux-android - - i686-pc-windows-gnu - - x86_64-pc-windows-gnu - - i686-unknown-linux-gnu - - powerpc64-unknown-linux-gnu - - riscv64gc-unknown-linux-gnu - - wasm32-unknown-unknown - - wasm32-unknown-emscripten - - wasm32-wasip1 - - wasm32-wasip1-threads - - wasm32-wasip2 - runs-on: ubuntu-latest + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v6 - - name: Cache cargo build and registry - uses: actions/cache@v5 + with: + submodules: recursive + - uses: actions/cache@v5 with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-cross-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cross- - - name: Install Rust - run: rustup update stable && rustup default stable - - name: cargo build --target ${{ matrix.target }} - run: | - rustup target add ${{ matrix.target }} - cargo build --target ${{ matrix.target }} + key: ${{ runner.os }}-build-bundled-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ runner.os }}-build-bundled- + - run: rustup update stable --no-self-update && rustup default stable + - run: cargo build -p whispercpp -p whispercpp-sys - build: - name: build - strategy: - matrix: - os: - - ubuntu-latest - - 
macos-latest - - windows-latest - runs-on: ${{ matrix.os }} + build-macos-gpu: + name: build (metal + coreml) + runs-on: macos-latest steps: - - uses: actions/checkout@v6 - - name: Cache cargo build and registry - uses: actions/cache@v5 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-build-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-build- - - name: Install Rust - # --no-self-update is necessary because the windows environment cannot self-update rustup.exe. - run: rustup update stable --no-self-update && rustup default stable - - name: Install cargo-hack - run: cargo install cargo-hack - - name: Run build - run: cargo hack build --feature-powerset --exclude-no-default-features + - uses: actions/checkout@v6 + with: + submodules: recursive + - uses: actions/cache@v5 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: macos-gpu-${{ hashFiles('**/Cargo.lock') }} + restore-keys: macos-gpu- + - run: rustup update stable && rustup default stable + - run: cargo build -p whispercpp --features metal,coreml,serde test: - name: test + name: test on ${{ matrix.os }} strategy: + fail-fast: false matrix: - os: - - ubuntu-latest - - macos-latest - - windows-latest + os: [ubuntu-latest, macos-latest] runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v6 - - name: Cache cargo build and registry - uses: actions/cache@v5 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-test-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-test- - - name: Install Rust - # --no-self-update is necessary because the windows environment cannot self-update rustup.exe. 
- run: rustup update stable --no-self-update && rustup default stable - - name: Install cargo-hack - run: cargo install cargo-hack - - name: Run test - run: cargo hack test --feature-powerset --exclude-no-default-features --exclude-features loom - - sanitizer: - name: sanitizer - runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - - name: Cache cargo build and registry - uses: actions/cache@v5 + with: + submodules: recursive + - uses: actions/cache@v5 with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-sanitizer-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-sanitizer- - - name: Install Rust - run: rustup update nightly && rustup default nightly - - name: Install rust-src - run: rustup component add rust-src - - name: ASAN / LSAN / MSAN / TSAN - run: bash ci/sanitizer.sh + key: ${{ runner.os }}-test-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ runner.os }}-test- + - run: rustup update stable && rustup default stable + - run: cargo test -p whispercpp --features serde --lib - miri-tb: - name: miri-tb-${{ matrix.target }} - strategy: - matrix: - include: - - os: ubuntu-latest - target: x86_64-unknown-linux-gnu - - os: ubuntu-latest - target: aarch64-unknown-linux-gnu - - os: ubuntu-latest - target: i686-unknown-linux-gnu - - os: ubuntu-latest - target: powerpc64-unknown-linux-gnu - - os: ubuntu-latest - target: s390x-unknown-linux-gnu - - os: ubuntu-latest - target: riscv64gc-unknown-linux-gnu - - os: macos-latest - target: aarch64-apple-darwin - runs-on: ${{ matrix.os }} + doc: + name: doc + runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - - name: Cache cargo build and registry - uses: actions/cache@v5 + with: + submodules: recursive + - run: rustup update stable && rustup default stable + - run: cargo doc --workspace --no-deps --features whispercpp/serde + env: + RUSTDOCFLAGS: -Dwarnings + + sanitizer: + name: sanitizer (asan) + runs-on: ubuntu-latest + env: + # Unset the workspace 
deny-warnings while Sanitizer + # builds the C++ tree — ggml's CMake emits warnings + # under sanitizer flags that are not actionable from + # this crate. + RUSTFLAGS: "" + steps: + - uses: actions/checkout@v6 + with: + submodules: recursive + - uses: actions/cache@v5 with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-miri-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-miri- - - name: Miri + key: ubuntu-sanitizer-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ubuntu-sanitizer- + # The Rust sanitizer support requires nightly. + - run: rustup toolchain install nightly --profile minimal --component rust-src + - run: rustup default nightly + # AddressSanitizer is compiled into both the Rust crate and + # the vendored C++ via `-Zsanitizer=address` plus + # `-Zbuild-std`. We build *only* the bundled CPU path; + # GPU backends pull in vendor SDKs that bring their own + # UB into scope and aren't worth adjudicating here. + - name: cargo test (asan) run: | - bash ci/miri_tb.sh "${{ matrix.target }}" + export ASAN_OPTIONS=detect_leaks=0:abort_on_error=1 + cargo +nightly test \ + -p whispercpp \ + --features serde \ + --lib \ + -Zbuild-std \ + --target x86_64-unknown-linux-gnu \ + -- --test-threads 1 + env: + RUSTFLAGS: "-Zsanitizer=address" + # Leak detection is off (`detect_leaks=0`) because the + # static-init paths in libstdc++ / ggml's TLS logger have + # benign one-shot leaks at init time. The leak surface we + # actually care about (post-throw state cleanup) is + # exercised by the test suite, not by ASan's leak + # tracker. 
- miri-sb: - name: miri-sb-${{ matrix.target }} - strategy: - matrix: - include: - - os: ubuntu-latest - target: x86_64-unknown-linux-gnu - - os: ubuntu-latest - target: aarch64-unknown-linux-gnu - - os: ubuntu-latest - target: i686-unknown-linux-gnu - - os: ubuntu-latest - target: powerpc64-unknown-linux-gnu - - os: ubuntu-latest - target: s390x-unknown-linux-gnu - - os: ubuntu-latest - target: riscv64gc-unknown-linux-gnu - - os: macos-latest - target: aarch64-apple-darwin - runs-on: ${{ matrix.os }} + ubsan: + name: ubsan (C++ side) + runs-on: ubuntu-latest + # rustc has no `-Zsanitizer=undefined` pass; UBSan in this + # crate's threat model lives on the C++ side (whisper.cpp + + # ggml + the shim). We pass `-fsanitize=undefined` to the + # C++ compiler via `CFLAGS` / `CXXFLAGS` / `LDFLAGS` and run + # the test suite. cmake-rs forwards these env vars to the + # CMake invocation; `cc::Build` (used for the shim) honors + # them too. Stable Rust is enough — no `-Zbuild-std` needed + # because the sanitizer lives entirely in the C++ link. + env: + RUSTFLAGS: "" steps: - uses: actions/checkout@v6 - - name: Cache cargo build and registry - uses: actions/cache@v5 + with: + submodules: recursive + - uses: actions/cache@v5 with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-miri-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-miri- - - name: Miri + key: ubuntu-ubsan-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ubuntu-ubsan- + - run: rustup update stable && rustup default stable + - name: cargo test (ubsan, C++) run: | - bash ci/miri_sb.sh "${{ matrix.target }}" + # `halt_on_error=1` + `abort_on_error=1` makes UBSan + # terminate the process on the first hit (equivalent + # to `-fno-sanitize-recover=all` at the runtime level + # but using the recovering handlers `__ubsan_handle_*` + # that are always present in libubsan, instead of the + # `_abort` variants that need an extra runtime). 
+ # `print_stacktrace=1` makes triage easier when CI + # does fire. + export UBSAN_OPTIONS="halt_on_error=1:abort_on_error=1:print_stacktrace=1" + cargo test \ + -p whispercpp \ + --features serde \ + --lib \ + -- --test-threads 1 + env: + # Instrument the C++ side with UBSan. Recovering form + # (NOT `-fno-sanitize-recover=all`) — see the + # UBSAN_OPTIONS comment above. + CFLAGS: "-fsanitize=undefined" + CXXFLAGS: "-fsanitize=undefined" + LDFLAGS: "-fsanitize=undefined" + # rustc invokes the linker with `-nodefaultlibs`, + # which suppresses gcc's normal libubsan auto- + # injection despite `-fsanitize=undefined` being on + # the link command. We add `-lubsan` explicitly, + # wrapped in `--no-as-needed` so the linker can't + # discard it before the C++ rlibs that reference + # `__ubsan_handle_*` get scanned. + RUSTFLAGS: "-Clink-arg=-Wl,--no-as-needed -Clink-arg=-lubsan -Clink-arg=-Wl,--as-needed" - loom: - name: loom - strategy: - matrix: - os: - - ubuntu-latest - - macos-latest - - windows-latest - runs-on: ${{ matrix.os }} + sanitizer-aarch64: + name: sanitizer (asan, aarch64-linux) + runs-on: ubuntu-24.04-arm + # AddressSanitizer on aarch64 catches arch-specific bugs + # (alignment, atomic ordering, intrinsic codegen) that the + # x86_64 ASan job can miss. + # + # We initially targeted HWAddressSanitizer here for its + # cheaper memory overhead, but `-Zsanitizer=hwaddress` + + # `+tagged-globals` produces ADRP relocations against + # `compiler_builtins`'s outline-atomics dispatcher whose + # tagged-section addresses are out of range for any + # currently-shipping linker (bfd, gold, lld all hit + # `R_AARCH64_ADR_PREL_PG_HI21 out of range`). That's a + # known rustc / compiler_builtins / HWASan incompatibility + # specific to aarch64 outline atomics; ASan side-steps it + # entirely. 
+ env: + RUSTFLAGS: "" steps: - uses: actions/checkout@v6 - - name: Cache cargo build and registry - uses: actions/cache@v5 + with: + submodules: recursive + - uses: actions/cache@v5 with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-loom-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-loom- - - name: Install Rust - run: rustup update nightly --no-self-update && rustup default nightly - - name: Loom tests - run: cargo test --tests --features loom - - # valgrind: - # name: valgrind - # runs-on: ubuntu-latest - # steps: - # - uses: actions/checkout@v6 - # - name: Cache cargo build and registry - # uses: actions/cache@v5 - # with: - # path: | - # ~/.cargo/registry - # ~/.cargo/git - # target - # key: ubuntu-latest-valgrind-${{ hashFiles('**/Cargo.lock') }} - # restore-keys: | - # ubuntu-latest-valgrind- - # - name: Install Rust - # run: rustup update stable && rustup default stable - # - name: Install Valgrind - # run: | - # sudo apt-get update -y - # sudo apt-get install -y valgrind - # # Uncomment and customize when you have binaries to test: - # # - name: cargo build foo - # # run: cargo build --bin foo - # # working-directory: integration - # # - name: Run valgrind foo - # # run: valgrind --error-exitcode=1 --leak-check=full --show-leak-kinds=all ./target/debug/foo - # # working-directory: integration + key: aarch64-asan-${{ hashFiles('**/Cargo.lock') }} + restore-keys: aarch64-asan- + - run: rustup toolchain install nightly --profile minimal --component rust-src + - run: rustup default nightly + - name: cargo test (asan, aarch64) + run: | + export ASAN_OPTIONS=detect_leaks=0:abort_on_error=1 + cargo +nightly test \ + -p whispercpp \ + --features serde \ + --lib \ + -Zbuild-std \ + --target aarch64-unknown-linux-gnu \ + -- --test-threads 1 + env: + RUSTFLAGS: "-Zsanitizer=address" - coverage: - name: coverage + miri: + name: miri (safe wrapper) runs-on: ubuntu-latest - needs: - - rustfmt - - clippy - - build - - 
cross - - test - - sanitizer - - loom + env: + RUSTFLAGS: "" steps: - uses: actions/checkout@v6 - - name: Install Rust - run: rustup update nightly && rustup default nightly - - name: Install cargo-tarpaulin - run: cargo install cargo-tarpaulin - - name: Cache cargo build and registry - uses: actions/cache@v5 + with: + submodules: recursive + - uses: actions/cache@v5 with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-coverage-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-coverage- - - name: Run tarpaulin + key: ubuntu-miri-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ubuntu-miri- + - run: rustup toolchain install nightly --profile minimal --component miri + - run: rustup default nightly + # Miri does not support FFI / cmake-built C++. Tag + # FFI-touching tests with #[cfg_attr(miri, ignore)] + # in source. The remaining tests cover: + # + # * `params.rs`: clamp helpers, set_n_threads_unchecked + # bypass, beam-size + best-of bypass, default param + # normalisation + # * `context.rs`: AtomicBool poisoning visibility + # across threads, full_lock serialisation + # * `lang.rs`: serde round-trips, canonicalisation + # + # That's the surface where we have new unsafe code + # (the AbortCallback typedef + UnsafeCell trampoline + # storage). Miri verifies the Rust invariants there + # without ever entering the FFI. Useful for catching + # aliasing / strict-provenance / drop-order regressions. 
+ - name: cargo miri test (safe wrapper, FFI tests skipped) + run: | + cargo +nightly miri test \ + -p whispercpp \ + --features serde \ + --lib env: - RUSTFLAGS: "--cfg tarpaulin" - run: cargo tarpaulin --all-features --run-types tests --run-types doctests --workspace --out xml - - name: Upload to codecov.io - uses: codecov/codecov-action@v6 - with: - token: ${{ secrets.CODECOV_TOKEN }} - slug: ${{ github.repository }} - fail_ci_if_error: true + MIRIFLAGS: "-Zmiri-strict-provenance -Zmiri-symbolic-alignment-check" diff --git a/.github/workflows/loc.yml b/.github/workflows/loc.yml index 6e176a6..bc76e34 100644 --- a/.github/workflows/loc.yml +++ b/.github/workflows/loc.yml @@ -51,7 +51,7 @@ jobs: await github.rest.gists.update({ gist_id: gistId, files: { - "template-rs": { + "whispercpp": { content: output } } diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..abbcedd --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "whispercpp-sys/whisper.cpp"] + path = whispercpp-sys/whisper.cpp + url = https://github.com/Findit-AI/whisper.cpp.git + branch = rust diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index bd7a668..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,7 +0,0 @@ -# UNRELEASED - -# 0.1.2 (January 6th, 2022) - -FEATURES - - diff --git a/Cargo.toml b/Cargo.toml index ff7fe91..eddec1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,48 +1,25 @@ -[package] -name = "template-rs" -version = "0.0.0" -edition = "2021" -repository = "https://github.com/al8n/template-rs" -homepage = "https://github.com/al8n/template-rs" -documentation = "https://docs.rs/template-rs" -description = "A template for creating Rust open-source repo on GitHub" -license = "MIT OR Apache-2.0" -rust-version = "1.73" +[workspace] +resolver = "3" +members = ["whispercpp-sys", "whispercpp"] -[[bench]] -path = "benches/foo.rs" -name = "foo" -harness = false +# `whisper.cpp/` is the upstream C++ source pulled in as a git +# submodule under 
`whispercpp-sys/whisper.cpp/`. Cargo's +# workspace discovery walks subdirectories looking for +# Cargo.toml; the submodule has none, so no exclude is needed — +# but we list it defensively so a future "Cargo.toml in +# whisper.cpp/" surprise doesn't pull it into the workspace. +exclude = ["whispercpp-sys/whisper.cpp"] -[features] -default = ["std"] -alloc = [] -std = [] +[workspace.package] +edition = "2024" +rust-version = "1.95" +license = "MIT OR Apache-2.0" +repository = "https://github.com/findit-studio/whispercpp" +readme = "README.md" -[dependencies] - -[dev-dependencies] -criterion = "0.8" -tempfile = "3" - -[profile.bench] -opt-level = 3 -debug = false -codegen-units = 1 -lto = 'thin' -incremental = false -debug-assertions = false -overflow-checks = false -rpath = false - -[package.metadata.docs.rs] -all-features = true -rustdoc-args = ["--cfg", "docsrs"] - -[lints.rust] +[workspace.lints.rust] rust_2018_idioms = "warn" single_use_lifetimes = "warn" unexpected_cfgs = { level = "warn", check-cfg = [ - 'cfg(all_tests)', 'cfg(tarpaulin)', ] } diff --git a/README-zh_CN.md b/README-zh_CN.md deleted file mode 100644 index 7a07f4d..0000000 --- a/README-zh_CN.md +++ /dev/null @@ -1,51 +0,0 @@ -
-

template-rs

-
-
- -开源Rust代码库GitHub模版 - -[github][Github-url] -LoC -[Build][CI-url] -[codecov][codecov-url] - -[docs.rs][doc-url] -[crates.io][crates-url] -[crates.io][crates-url] -license - -[English][en-url] | 简体中文 - -
- -## Installation - -```toml -[dependencies] -template_rs = "0.1" -``` - -## Features - -- [x] 更快的创建GitHub开源Rust代码库 - -#### License - -`Template-rs` is under the terms of both the MIT license and the -Apache License (Version 2.0). - -See [LICENSE-APACHE](LICENSE-APACHE), [LICENSE-MIT](LICENSE-MIT) for details. - -Copyright (c) 2021 Al Liu. - -[Github-url]: https://github.com/al8n/template-rs/ -[CI-url]: https://github.com/al8n/template/actions/workflows/template.yml -[doc-url]: https://docs.rs/template-rs -[crates-url]: https://crates.io/crates/template-rs -[codecov-url]: https://app.codecov.io/gh/al8n/template-rs/ -[license-url]: https://opensource.org/licenses/Apache-2.0 -[rustc-url]: https://github.com/rust-lang/rust/blob/master/RELEASES.md -[license-apache-url]: https://opensource.org/licenses/Apache-2.0 -[license-mit-url]: https://opensource.org/licenses/MIT -[en-url]: https://github.com/al8n/template-rs/tree/main/README.md diff --git a/README.md b/README.md index 1af27e2..afbe4f2 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,166 @@
-

template-rs

+

whispercpp

-A template for creating Rust open-source GitHub repo. +Safe Rust bindings for [whisper.cpp][whisper-cpp] speech-to-text inference. -[github][Github-url] -LoC -[Build][CI-url] -[codecov][codecov-url] +[github][Github-url] +LoC +[Build][CI-url] +[codecov][codecov-url] -[docs.rs][doc-url] -[crates.io][crates-url] -[crates.io][crates-url] +[docs.rs][doc-url] +[crates.io][crates-url] +[crates.io][crates-url] license -English | [简体中文][zh-cn-url] -
+Safe Rust bindings for [whisper.cpp][whisper-cpp] speech-to-text inference. + +- **Always-bundled build.** `whispercpp-sys` cmake-builds a vendored, + patched whisper.cpp; there is no pkg-config / system-install path. + The patched source lives on a fork branch with each fix as a + reviewable commit (see [Memory safety](#memory-safety) below). +- **Panic-free safe surface.** Every FFI call is wrapped in a C++ + exception-catching shim, every fallible setter returns + `WhisperError`, every accessor short-circuits on poisoned state. +- **`Send + Sync`** `Context`; per-`Context` `State` is `Send`. + Concurrent inference is serialized through a per-`Context` mutex + so per-call leak budgets are structural, not documentary. +- **Backend matrix.** Metal, CoreML, Vulkan, OpenCL, CUDA, ROCm + (HIP), oneAPI (SYCL), Moore Threads (MUSA), OpenVINO, OpenBLAS — + all opt-in via Cargo features. + ## Installation ```toml [dependencies] -template_rs = "0.1" +whispercpp = "0.1" +``` + +The default build is plain CPU. Opt into accelerators per-target: + +```toml +# macOS Apple Silicon +[target.'cfg(all(target_os = "macos", target_arch = "aarch64"))'.dependencies] +whispercpp = { version = "0.1", features = ["metal", "coreml"] } + +# Linux + NVIDIA +[target.'cfg(all(target_os = "linux", target_arch = "x86_64"))'.dependencies] +whispercpp = { version = "0.1", features = ["cuda"] } ``` -## Features -- [x] Create a Rust open-source repo fast +## Examples + +A working end-to-end example lives at +[`whispercpp/examples/smoke.rs`](whispercpp/examples/smoke.rs). + +## Backends + +All backend features chain to the matching `whispercpp-sys` feature +which toggles the corresponding ggml / whisper CMake flag. 
+ +| Feature | Backend | Platforms | +|------------|--------------------------------------|--------------------------| +| `metal` | Metal GPU | Apple | +| `coreml` | CoreML / ANE encoder | Apple (with `.mlmodelc`) | +| `vulkan` | Vulkan compute | Linux / Windows / Android / MoltenVK on macOS | +| `opencl` | OpenCL (mobile / Adreno) | Linux / Android | +| `cuda` | NVIDIA CUDA | Linux / Windows | +| `hipblas` | AMD ROCm / HIP | Linux | +| `sycl` | Intel oneAPI / Arc | Linux / Windows | +| `musa` | Moore Threads MUSA | Linux | +| `openvino` | Intel OpenVINO encoder | Linux / Windows | +| `openblas` | OpenBLAS CPU | Any | +| `serde` | `Serialize` / `Deserialize` for `Lang` (lowercase ISO-639-1) | — | + +GPU backends require the corresponding vendor SDK (CUDA Toolkit, +ROCm, oneAPI, etc.) installed at link time. CI exercises the +bundled CPU path on Linux/macOS/Windows and Metal+CoreML on macOS. + +## Memory safety + +`whisper.cpp` is a binary parser of attacker-controllable model files +plus a substantial C++ inference path. The vendored submodule is +pinned to our fork branch +([`Findit-AI/whisper.cpp@rust`][fork-rust-branch]), which carries +fixes for upstream issues reachable from safe Rust: + +- `whisper_kv_cache_free` made idempotent (closes a multi-decoder + OOM double-free of a ggml backend buffer). +- `whisper_init_state` / `whisper_init_with_params_no_state` / + `whisper_vad_init_with_params` wrapped in RAII so a throw mid-init + releases the partial allocation rather than leaking the + whisper_context / whisper_state. +- Tensor headers fully validated: `n_dims ∈ [0, 4]`, name length + bounded, `ttype < GGML_TYPE_COUNT`, per-dim positivity, 64-bit + overflow check on `nelements`. +- Hparams validated against generous-but-bounded ranges; min + `n_text_ctx` enforced so the decode batch can hold the + worst-case prompt. +- Special-token ids verified to fit `n_vocab` after the + multilingual shift (closes a corrupt-vocab OOB into `logits[]`). 
+- File / buffer loaders throw on partial reads (peek-based EOF + detection so clean end-of-tensor-list still terminates). +- Tensor-name set tracking rejects models that satisfy the + loaded-count check by repeating one name. +- `ggml_log_set` installed once per process via `std::atomic` + so concurrent `create_state` + `State::full` don't race on + ggml's static logger globals. +- `vocab.num_languages()` synthesis null-checks + `whisper_lang_str` (closes `std::string(nullptr)` UB). +- The abort callback is wired through every sched-based graph + compute so cancellation interrupts the long-running encoder / + decoder paths, not just the gaps between them. + +A C++ exception-catching shim layer (`whispercpp_shim.cpp`) sits +between the safe Rust API and every throwing entry point. The +bindgen allowlist is enumerated symbol-by-symbol — only no-throw +raw `whisper_*` functions are exposed; every throwing function +goes through a `whispercpp_*` shim that catches and surfaces the +exception class as a sentinel (`WhisperError::ConstructorLost`, +`StateLost`, etc.). + +`build.rs` includes a canary that scans the linked source for the +required patch markers and hard-fails the build if any are missing. + +For design details, see the per-finding analysis in the fork +branch's commit history. + +## Crate structure + +| Crate | Purpose | +|------------------|-------------------------------------------------------------------------------------------------| +| `whispercpp` | Safe Rust API (`Context`, `State`, `Params`, `Lang`, `WhisperError`). End-user dependency. | +| `whispercpp-sys` | Bindgen output + `build.rs` (cmake build, link directives) + the C++ exception-catching shim. | + +End users should depend on `whispercpp`. `whispercpp-sys` is +re-exported as `whispercpp::sys` for callers who need a raw +escape hatch (review every use carefully — only no-throw symbols +are exposed but it's `unsafe` regardless). 
+ +## Supported platforms + +CI runs on `ubuntu-latest`, `macos-latest`, and `windows-latest`. +Sanitizer (ASan + UBSan) and Miri jobs gate the `unsafe` boundary +on every PR. MSRV is pinned in `Cargo.toml` and enforced via +`rust-version`. -#### License +## License -`template-rs` is under the terms of both the MIT license and the +`whispercpp` is under the terms of both the MIT license and the Apache License (Version 2.0). See [LICENSE-APACHE](LICENSE-APACHE), [LICENSE-MIT](LICENSE-MIT) for details. -Copyright (c) 2021 Al Liu. +Copyright (c) 2026 FinDIT Studio authors. -[Github-url]: https://github.com/al8n/template-rs/ -[CI-url]: https://github.com/al8n/template-rs/actions/workflows/ci.yml -[doc-url]: https://docs.rs/template-rs -[crates-url]: https://crates.io/crates/template-rs -[codecov-url]: https://app.codecov.io/gh/al8n/template-rs/ -[zh-cn-url]: https://github.com/al8n/template-rs/tree/main/README-zh_CN.md +[whisper-cpp]: https://github.com/ggerganov/whisper.cpp +[fork-rust-branch]: https://github.com/Findit-AI/whisper.cpp/tree/rust +[Github-url]: https://github.com/findit-ai/whispercpp/ +[CI-url]: https://github.com/findit-ai/whispercpp/actions/workflows/ci.yml +[doc-url]: https://docs.rs/whispercpp +[crates-url]: https://crates.io/crates/whispercpp +[codecov-url]: https://app.codecov.io/gh/findit-ai/whispercpp/ diff --git a/benches/foo.rs b/benches/foo.rs deleted file mode 100644 index f328e4d..0000000 --- a/benches/foo.rs +++ /dev/null @@ -1 +0,0 @@ -fn main() {} diff --git a/ci/miri_sb.sh b/ci/miri_sb.sh deleted file mode 100755 index cc3c6e0..0000000 --- a/ci/miri_sb.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -set -e - -if [ -z "$1" ]; then - echo "Error: TARGET is not provided" - exit 1 -fi - -TARGET="$1" - -# Install cross-compilation toolchain on Linux -if [ "$(uname)" = "Linux" ]; then - case "$TARGET" in - aarch64-unknown-linux-gnu) - sudo apt-get update && sudo apt-get install -y gcc-aarch64-linux-gnu - ;; - i686-unknown-linux-gnu) - sudo 
apt-get update && sudo apt-get install -y gcc-multilib - ;; - powerpc64-unknown-linux-gnu) - sudo apt-get update && sudo apt-get install -y gcc-powerpc64-linux-gnu - ;; - s390x-unknown-linux-gnu) - sudo apt-get update && sudo apt-get install -y gcc-s390x-linux-gnu - ;; - riscv64gc-unknown-linux-gnu) - sudo apt-get update && sudo apt-get install -y gcc-riscv64-linux-gnu - ;; - esac -fi - -rustup toolchain install nightly --component miri -rustup override set nightly -cargo miri setup - -export MIRIFLAGS="-Zmiri-strict-provenance -Zmiri-disable-isolation -Zmiri-symbolic-alignment-check" - -cargo miri test --all-targets --target "$TARGET" diff --git a/ci/miri_tb.sh b/ci/miri_tb.sh deleted file mode 100755 index 5d374c7..0000000 --- a/ci/miri_tb.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -set -e - -if [ -z "$1" ]; then - echo "Error: TARGET is not provided" - exit 1 -fi - -TARGET="$1" - -# Install cross-compilation toolchain on Linux -if [ "$(uname)" = "Linux" ]; then - case "$TARGET" in - aarch64-unknown-linux-gnu) - sudo apt-get update && sudo apt-get install -y gcc-aarch64-linux-gnu - ;; - i686-unknown-linux-gnu) - sudo apt-get update && sudo apt-get install -y gcc-multilib - ;; - powerpc64-unknown-linux-gnu) - sudo apt-get update && sudo apt-get install -y gcc-powerpc64-linux-gnu - ;; - s390x-unknown-linux-gnu) - sudo apt-get update && sudo apt-get install -y gcc-s390x-linux-gnu - ;; - riscv64gc-unknown-linux-gnu) - sudo apt-get update && sudo apt-get install -y gcc-riscv64-linux-gnu - ;; - esac -fi - -rustup toolchain install nightly --component miri -rustup override set nightly -cargo miri setup - -export MIRIFLAGS="-Zmiri-strict-provenance -Zmiri-disable-isolation -Zmiri-symbolic-alignment-check -Zmiri-tree-borrows" - -cargo miri test --all-targets --target "$TARGET" diff --git a/ci/sanitizer.sh b/ci/sanitizer.sh deleted file mode 100755 index 4ff6819..0000000 --- a/ci/sanitizer.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash -set -ex - -export 
ASAN_OPTIONS="detect_odr_violation=0 detect_leaks=0" - -TARGET="x86_64-unknown-linux-gnu" - -# Run address sanitizer -RUSTFLAGS="-Z sanitizer=address" \ -cargo test --tests --target "$TARGET" --all-features - -# Run leak sanitizer -RUSTFLAGS="-Z sanitizer=leak" \ -cargo test --tests --target "$TARGET" --all-features - -# Run memory sanitizer (requires -Zbuild-std for instrumented std) -RUSTFLAGS="-Z sanitizer=memory" \ -cargo -Zbuild-std test --tests --target "$TARGET" --all-features - -# Run thread sanitizer (requires -Zbuild-std for instrumented std) -RUSTFLAGS="-Z sanitizer=thread" \ -cargo -Zbuild-std test --tests --target "$TARGET" --all-features diff --git a/examples/foo.rs b/examples/foo.rs deleted file mode 100644 index f328e4d..0000000 --- a/examples/foo.rs +++ /dev/null @@ -1 +0,0 @@ -fn main() {} diff --git a/src/lib.rs b/src/lib.rs deleted file mode 100644 index 0a58390..0000000 --- a/src/lib.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! A template for creating Rust open-source repo on GitHub -#![cfg_attr(not(feature = "std"), no_std)] -#![cfg_attr(docsrs, feature(doc_cfg))] -#![cfg_attr(docsrs, allow(unused_attributes))] -#![deny(missing_docs)] - -#[cfg(all(not(feature = "std"), feature = "alloc"))] -extern crate alloc as std; - -#[cfg(feature = "std")] -extern crate std; diff --git a/tests/foo.rs b/tests/foo.rs deleted file mode 100644 index 8b13789..0000000 --- a/tests/foo.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/whispercpp-sys/.gitignore b/whispercpp-sys/.gitignore new file mode 100644 index 0000000..e998183 --- /dev/null +++ b/whispercpp-sys/.gitignore @@ -0,0 +1,2 @@ +# cmake-rs build output for whisper.cpp. 
+target/ diff --git a/whispercpp-sys/Cargo.toml b/whispercpp-sys/Cargo.toml new file mode 100644 index 0000000..eedc6e2 --- /dev/null +++ b/whispercpp-sys/Cargo.toml @@ -0,0 +1,67 @@ +[package] +name = "whispercpp-sys" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +readme = "README.md" +description = "Low-level Rust FFI bindings to whisper.cpp. Cmake-builds a patched fork that closes upstream OOM / UB / leak hazards reachable from safe Rust." +keywords = ["whisper", "ffi", "bindings", "asr", "speech-to-text"] +categories = ["external-ffi-bindings", "multimedia::audio", "science"] +links = "whisper" + +[lib] +name = "whispercpp_sys" +path = "src/lib.rs" + +[features] +# Default to a CPU build — no GPU dep pulled. The `whispercpp` +# safe wrapper turns Metal/CoreML on by default via its own +# feature defaults; this crate stays minimal so servers / CI +# runners that build `--no-default-features` get a fast +# plain-CPU compile. +# +# whisper.cpp is **always** built from the vendored submodule +# (`whisper.cpp/`, pinned to a patched fork branch). There is +# no pkg-config / system-install path: routing safe Rust code +# through a stock libwhisper would silently lose the +# memory-safety guarantees the bundled patches provide. +default = [] + +# Backend feature flags map 1:1 to whisper.cpp/ggml CMake +# options. Each feature toggles the matching `-DGGML_*=ON` +# (or `-DWHISPER_*=ON`) flag plus the link directives for the +# static library cmake produces and any system framework / +# library it depends on. 
+# +# Apple-only: +metal = [] # GGML_METAL: Metal GPU encoder/decoder +coreml = [] # WHISPER_COREML: CoreML encoder dispatch (ANE) +# Cross-platform GPU: +vulkan = [] # GGML_VULKAN: Vulkan compute (Linux/Win/Android) +opencl = [] # GGML_OPENCL: OpenCL (mobile GPUs, Adreno) +# Vendor-specific GPU: +cuda = [] # GGML_CUDA: NVIDIA CUDA +hipblas = [] # GGML_HIP: AMD ROCm/HIP (formerly GGML_HIPBLAS) +sycl = [] # GGML_SYCL: Intel oneAPI / Arc GPUs +musa = [] # GGML_MUSA: Moore Threads MUSA +# Encoder accelerators (similar role to CoreML, but other vendors): +openvino = [] # WHISPER_OPENVINO: Intel OpenVINO encoder +# CPU-side BLAS: +openblas = [] # GGML_BLAS=ON, GGML_BLAS_VENDOR=OpenBLAS + +[build-dependencies] +# `cmake` drives the bundled whisper.cpp build. +cmake = "0.1" +# `bindgen` generates Rust FFI for whisper.h + the shim +# header. Output lands in `OUT_DIR/generated.rs` so the +# source tree stays read-only for cargo vendor / Nix builds. +bindgen = "0.72" +# `cc` compiles `whispercpp_shim.cpp` into a tiny static lib +# that catches C++ exceptions before they unwind across the +# `extern "C"` boundary into Rust. +cc = "1" + +[lints] +workspace = true diff --git a/whispercpp-sys/README.md b/whispercpp-sys/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/whispercpp-sys/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/whispercpp-sys/build.rs b/whispercpp-sys/build.rs new file mode 100644 index 0000000..37a3088 --- /dev/null +++ b/whispercpp-sys/build.rs @@ -0,0 +1,553 @@ +//! Build script for the whisper.cpp FFI bindings. +//! +//! Compiles the vendored `whisper.cpp/` git submodule via +//! cmake-rs, links static. Feature flags translate to +//! `-DGGML_METAL=ON` etc. Output is a static `libwhisper.a` +//! plus the ggml satellite libraries that whisper.cpp's +//! CMakeLists produces. +//! +//! There is no pkg-config / system-install path: the bundled +//! source is patched in `OUT_DIR/whisper-src/` to close +//! 
several upstream memory-safety bugs, +//! and routing safe-Rust code through a stock libwhisper +//! would silently drop those guarantees. +//! +//! Bindgen runs against the resolved header set so the Rust +//! FFI matches the linked library's ABI. Output goes to +//! `OUT_DIR/generated.rs` (build scripts must NOT mutate +//! the source tree). +//! +//! Bootstrap behaviour: when the submodule is missing this +//! script emits clear `cargo:warning=`s rather than panicking, +//! so `cargo check` still resolves the API. The actual link +//! step fails downstream, by design. + +use std::{ + env, + path::{Path, PathBuf}, +}; + +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-changed=wrapper.h"); + + bundled_build(); +} + +// ─── Bundled path ──────────────────────────────────────────── + +fn bundled_build() { + let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + // The vendored submodule pinned via `.gitmodules` to the + // `Findit-AI/whisper.cpp` fork's `rust` branch — which + // carries our memory-safety patches as committed history + // — is the SOLE source of truth for the whisper.cpp build. + // + // No `WHISPER_CPP_DIR` override: the Rust safety surface + // (e.g. `State::full`'s free-on-sentinel path) relies on + // the fork's idempotent `whisper_kv_cache_free` and other + // patches being present in the linked binary. A pristine + // upstream checkout shares the same ABI but lacks those + // patches, so an env-var override would silently + // reintroduce the double-free / use-after-free class the + // wrapper closes. Users who need a different source must + // edit `.gitmodules` (reviewable) rather than flip an env + // var.
+ let whisper_src = crate_dir.join("whisper.cpp"); + + if !whisper_src.join("CMakeLists.txt").is_file() { + println!( + "cargo:warning=whisper.cpp source not found at {:?}.", + whisper_src + ); + println!("cargo:warning=Run `git submodule update --init --recursive` from the repo root."); + println!( + "cargo:warning=Skipping cmake + bindgen for now; link step will fail until the source is available." + ); + return; + } + + // Verify the linked source carries our patches. Cheap + // canary: scan for the sentinel comments from the patch + // set. If any is absent, the build hard-fails — Rust safety + // assumptions in the wrapper depend on this. + verify_patched_source(&whisper_src); + + // Tell cargo to rerun build.rs when files in the submodule + // change so `git submodule update` picks up automatically. + for top in ["CMakeLists.txt", "cmake", "include", "src", "ggml"] { + let p = whisper_src.join(top); + if p.exists() { + println!("cargo:rerun-if-changed={}", p.display()); + } + } + + let dst = build_whisper_cpp(&whisper_src); + let bundled_includes = vec![ + whisper_src.join("include"), + whisper_src.join("ggml").join("include"), + ]; + // Build the shim BEFORE emitting whisper.cpp's link + // directives. GNU ld resolves left-to-right; the shim + // depends on `whisper_*` symbols so it must appear first + // in the link list. cc::Build emits its `link-lib` line + // immediately on `compile`. + build_shim(&bundled_includes); + emit_bundled_link_directives(&dst); + let bundled_args: Vec<String> = bundled_includes + .iter() + .map(|p| format!("-I{}", p.display())) + .collect(); + generate_bindings_with_args(&bundled_args); +} + +/// Hard-fail the build if the linked whisper.cpp source is +/// missing the rust-branch patch set. The Rust wrapper's +/// memory-safety guarantees (e.g. `State::full`'s +/// free-on-sentinel path, relying on the fork's +/// idempotent `whisper_kv_cache_free`) are unsound against a +/// pristine upstream tree even though the ABI is identical.
+/// +/// Strategy: scan `src/whisper.cpp` for one or more sentinel +/// comments inserted by the rust-branch patches. If any +/// expected marker is missing the build refuses to proceed. +/// +/// This catches both `git submodule update` against unpatched +/// upstream AND someone manually replacing the submodule with +/// a different tree. +fn verify_patched_source(whisper_src: &Path) { + let target = whisper_src.join("src").join("whisper.cpp"); + let body = match std::fs::read_to_string(&target) { + Ok(b) => b, + Err(e) => panic!( + "whispercpp-sys: failed to read {} for patch verification: {e}", + target.display() + ), + }; + + // Sentinels chosen from the highest-leverage patches — + // the ones whose absence would re-introduce the + // double-free / null-deref / leak hazards the Rust + // wrapper assumes are closed. + const REQUIRED_MARKERS: &[&str] = &[ + "whispercpp-sys: kv_cache_free idempotent fix", + "whispercpp-sys: read_safe zero-init", + "whispercpp-sys: init_state RAII entry", + "whispercpp-sys: init_context RAII entry", + "whispercpp-sys: tensor header validation (model_load)", + "whispercpp-sys: ggml_log_set once-per-process", + "whispercpp-sys: hparams validation", + "whispercpp-sys: lang_str null guard", + "whispercpp-sys: special-token bounds check", + "whispercpp-sys: path_model assignment guard", + "whispercpp-sys: sched abort callback wiring", + "whispercpp-sys: vad_init RAII guard", + ]; + + let missing: Vec<&str> = REQUIRED_MARKERS + .iter() + .copied() + .filter(|m| !body.contains(m)) + .collect(); + + if !missing.is_empty() { + panic!( + "whispercpp-sys: the linked whisper.cpp source at {} is missing the rust-branch patches \ + (required marker{} absent: {:?}).\n\n\ + The Rust safety surface depends on these patches; building against unpatched upstream \ + reintroduces multi-decoder double-free / use-after-free / null-deref classes.\n\n\ + Fix: ensure the submodule tracks `Findit-AI/whisper.cpp` branch `rust`. 
Run\n \ + git submodule update --init --recursive\n\ + from the repo root. If you intentionally pointed at a different source, add equivalent \ + patches and the matching marker comments before retrying.", + target.display(), + if missing.len() == 1 { "" } else { "s" }, + missing, + ); + } +} + +/// Compile `whispercpp_shim.cpp` into a `libwhispercpp_shim.a` +/// staticlib in `OUT_DIR`, and emit the link directive for it. +/// +/// The shim catches C++ exceptions inside whisper.cpp so they +/// can't unwind across `extern "C"` into Rust. It must be +/// linked BEFORE the whisper static libs in the GNU ld +/// dependency chain so the shim's references to `whisper_*` +/// resolve. +fn build_shim(include_paths: &[PathBuf]) { + let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let mut build = cc::Build::new(); + build + .cpp(true) + .file(crate_dir.join("whispercpp_shim.cpp")) + .flag_if_supported("-std=c++17") + .flag_if_supported("/std:c++17"); + for inc in include_paths { + build.include(inc); + } + // `cc::Build::compile` emits `cargo:rustc-link-lib=static=...` + // and `cargo:rustc-link-search=native=...` automatically. + build.compile("whispercpp_shim"); + // Tell cargo to rerun the shim build when the source files + // change. (cc doesn't do this for us.) + println!("cargo:rerun-if-changed=whispercpp_shim.cpp"); + println!("cargo:rerun-if-changed=whispercpp_shim.h"); +} + +/// Drive the cmake build. Returns the install root cmake-rs +/// produced (typically `OUT_DIR/`). +fn build_whisper_cpp(whisper_src: &PathBuf) -> PathBuf { + let mut cfg = cmake::Config::new(whisper_src); + cfg + .define("BUILD_SHARED_LIBS", "OFF") + .define("WHISPER_BUILD_EXAMPLES", "OFF") + .define("WHISPER_BUILD_TESTS", "OFF") + .define("WHISPER_BUILD_SERVER", "OFF") + // Force OpenMP off. ggml's CMake auto- + // detects OpenMP; if the host has it (Linux + libgomp, + // macOS + brew libomp, etc.) 
it links against the + // OpenMP runtime which our `cargo:rustc-link-lib=` set + // doesn't emit, producing platform-specific link + // surprises. The wrapper also caps `n_threads = 1`, so + // OpenMP can't help anyway. Explicit OFF makes the + // bundled build deterministic across runners. + .define("GGML_OPENMP", "OFF") + // ggml fast-math + Apple Accelerate / OpenBLAS are decided + // per-feature below. + .profile("Release"); + + if cfg!(feature = "metal") { + cfg.define("GGML_METAL", "ON"); + cfg.define("GGML_METAL_NDEBUG", "ON"); + // Embed the metal shader library bytes into libggml-metal.a + // so the runtime doesn't need a sibling `default.metallib`. + cfg.define("GGML_METAL_EMBED_LIBRARY", "ON"); + } else { + cfg.define("GGML_METAL", "OFF"); + } + + if cfg!(feature = "coreml") { + cfg.define("WHISPER_COREML", "ON"); + // Enable the post-init fallback: if the `.mlmodelc` + // companion is missing at runtime, fall back to the GGML + // encoder rather than aborting. This is what whisper-cli + // does by default. + cfg.define("WHISPER_COREML_ALLOW_FALLBACK", "ON"); + } + + if cfg!(feature = "openblas") { + cfg.define("GGML_BLAS", "ON"); + cfg.define("GGML_BLAS_VENDOR", "OpenBLAS"); + } else if cfg!(target_vendor = "apple") && !cfg!(feature = "metal") { + // Apple CPU build: prefer the system Accelerate framework. + cfg.define("GGML_BLAS", "ON"); + cfg.define("GGML_BLAS_VENDOR", "Apple"); + } + + // ── Vendor-specific GPU backends ──────────────────────── + // Each `-DGGML_*=ON` triggers cmake's matching find_package + // / FetchContent for the SDK (CUDA Toolkit, ROCm, oneAPI, + // etc.). The user is expected to have the SDK installed; we + // don't auto-fetch. + if cfg!(feature = "cuda") { + cfg.define("GGML_CUDA", "ON"); + } + if cfg!(feature = "hipblas") { + // Renamed `GGML_HIPBLAS` → `GGML_HIP` upstream around + // ggml 0.10. 
We keep the Rust feature name `hipblas` to + // match the convention whisper-rs / llama-cpp-rs adopted + // before the upstream rename. + cfg.define("GGML_HIP", "ON"); + } + if cfg!(feature = "sycl") { + cfg.define("GGML_SYCL", "ON"); + } + if cfg!(feature = "musa") { + cfg.define("GGML_MUSA", "ON"); + } + + // ── Cross-platform GPU ───────────────────────────────── + if cfg!(feature = "vulkan") { + cfg.define("GGML_VULKAN", "ON"); + } + if cfg!(feature = "opencl") { + cfg.define("GGML_OPENCL", "ON"); + } + + // ── Encoder accelerators ─────────────────────────────── + if cfg!(feature = "openvino") { + cfg.define("WHISPER_OPENVINO", "ON"); + } + + cfg.build() +} + +/// Tell cargo which static libraries to link, in the right +/// order for the GNU/macos/MSVC linkers. cmake-rs's `build` +/// returns `/`, with libs under `lib/`. +fn emit_bundled_link_directives(install_root: &Path) { + let lib_dir = install_root.join("lib"); + println!("cargo:rustc-link-search=native={}", lib_dir.display()); + + // Order matters for GNU ld: depending libs first, low-level + // last. whisper depends on ggml; ggml's metal/blas/coreml + // sub-libs are leaves. + println!("cargo:rustc-link-lib=static=whisper"); + println!("cargo:rustc-link-lib=static=ggml"); + println!("cargo:rustc-link-lib=static=ggml-base"); + println!("cargo:rustc-link-lib=static=ggml-cpu"); + + // On Apple Silicon, whisper.cpp's CMake also builds the + // ggml-blas backend automatically (the BLAS-via-Accelerate + // path), even when Metal is the primary backend. We link it + // unconditionally on Apple targets so the resulting binary + // resolves `ggml_backend_blas_reg`. 
+ if cfg!(target_vendor = "apple") { + println!("cargo:rustc-link-lib=static=ggml-blas"); + println!("cargo:rustc-link-lib=framework=Accelerate"); + } + if cfg!(feature = "metal") { + println!("cargo:rustc-link-lib=static=ggml-metal"); + println!("cargo:rustc-link-lib=framework=Metal"); + println!("cargo:rustc-link-lib=framework=MetalKit"); + println!("cargo:rustc-link-lib=framework=Foundation"); + } + if cfg!(feature = "coreml") { + println!("cargo:rustc-link-lib=static=whisper.coreml"); + println!("cargo:rustc-link-lib=framework=CoreML"); + } + if cfg!(feature = "openblas") { + println!("cargo:rustc-link-lib=dylib=openblas"); + } + + // ── CUDA ─────────────────────────────────────────────── + // cmake produces `libggml-cuda.a`; the runtime resolves + // CUDA Toolkit symbols via `cudart`/`cublas` dylibs in + // `$CUDA_PATH/lib64` (Linux) or `\lib\x64` (Windows). The + // user must have the CUDA Toolkit installed; we don't ship + // it. `cargo:rustc-link-search` is left to the system + // default — `LD_LIBRARY_PATH` / Windows `PATH` covers it. 
+ if cfg!(feature = "cuda") { + println!("cargo:rustc-link-lib=static=ggml-cuda"); + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=cublas"); + println!("cargo:rustc-link-lib=dylib=cublasLt"); + } + + // ── ROCm / HIP (AMD) ─────────────────────────────────── + if cfg!(feature = "hipblas") { + println!("cargo:rustc-link-lib=static=ggml-hip"); + println!("cargo:rustc-link-lib=dylib=amdhip64"); + println!("cargo:rustc-link-lib=dylib=hipblas"); + println!("cargo:rustc-link-lib=dylib=rocblas"); + } + + // ── Intel SYCL / oneAPI ──────────────────────────────── + if cfg!(feature = "sycl") { + println!("cargo:rustc-link-lib=static=ggml-sycl"); + println!("cargo:rustc-link-lib=dylib=sycl"); + println!("cargo:rustc-link-lib=dylib=OpenCL"); + println!("cargo:rustc-link-lib=dylib=mkl_sycl"); + println!("cargo:rustc-link-lib=dylib=mkl_intel_ilp64"); + println!("cargo:rustc-link-lib=dylib=mkl_tbb_thread"); + println!("cargo:rustc-link-lib=dylib=mkl_core"); + } + + // ── Moore Threads MUSA ───────────────────────────────── + if cfg!(feature = "musa") { + println!("cargo:rustc-link-lib=static=ggml-musa"); + println!("cargo:rustc-link-lib=dylib=musa"); + println!("cargo:rustc-link-lib=dylib=musart"); + println!("cargo:rustc-link-lib=dylib=mublas"); + } + + // ── Vulkan (cross-platform GPU) ──────────────────────── + if cfg!(feature = "vulkan") { + println!("cargo:rustc-link-lib=static=ggml-vulkan"); + if cfg!(target_os = "macos") { + // MoltenVK ships a `vulkan` dylib that translates to Metal. 
+ println!("cargo:rustc-link-lib=dylib=vulkan"); + } else if cfg!(target_os = "windows") { + println!("cargo:rustc-link-lib=dylib=vulkan-1"); + } else { + println!("cargo:rustc-link-lib=dylib=vulkan"); + } + } + + // ── OpenCL (mobile GPUs / Adreno) ────────────────────── + if cfg!(feature = "opencl") { + println!("cargo:rustc-link-lib=static=ggml-opencl"); + if cfg!(target_os = "macos") { + println!("cargo:rustc-link-lib=framework=OpenCL"); + } else { + println!("cargo:rustc-link-lib=dylib=OpenCL"); + } + } + + // ── OpenVINO (Intel encoder accelerator) ─────────────── + if cfg!(feature = "openvino") { + println!("cargo:rustc-link-lib=static=whisper.openvino"); + println!("cargo:rustc-link-lib=dylib=openvino"); + println!("cargo:rustc-link-lib=dylib=openvino_c"); + } + + // C++ stdlib — whisper.cpp / ggml are C++. + if cfg!(target_os = "macos") { + println!("cargo:rustc-link-lib=dylib=c++"); + } else if cfg!(target_os = "linux") { + println!("cargo:rustc-link-lib=dylib=stdc++"); + } +} + +// ─── Bindgen ───────────────────────────────────────────────── + +/// Run bindgen against a curated `wrapper.h` and write the +/// result to `$OUT_DIR/generated.rs`. +/// +/// **Why OUT_DIR, not in-tree.** The +/// previous in-tree path (`src/generated.rs`) broke +/// read-only builds — cargo's standard `vendor` workflow, +/// Nix-style fixed-output derivations, Bazel sandboxes, and +/// verified-source registry checkouts all forbid build.rs +/// from mutating the source tree. Per cargo's contract, every +/// build.rs side-effect goes under `OUT_DIR`. The +/// `include!` glue lives in `src/lib.rs`. +/// +/// Trade-off: the FFI surface is no longer grep-able from a +/// fresh checkout. Inspect via `cargo expand +/// -p whispercpp-sys` or look at +/// `target/.../build/whispercpp-sys-*/out/generated.rs` +/// after a build.
+fn generate_bindings_with_args(clang_args: &[String]) { + let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let header = crate_dir.join("wrapper.h"); + + let mut builder = bindgen::Builder::default().header(header.to_string_lossy().to_string()); + for arg in clang_args { + builder = builder.clang_arg(arg); + } + let bindings = builder + // Only the symbols the safe wrapper actually consumes. + // narrowed this from `whisper_.*` because + // the broad allowlist exposed unshimmed throwing C++ + // entry points (e.g. `whisper_vad_init_*` whose file + // loaders throw `std::runtime_error` on truncated + // models, and `whisper_full_with_state` whose + // exceptions cross `extern "C"` into Rust as UB). New + // raw symbols need an explicit allowlist add and a + // matching audit: confirm the upstream function cannot + // throw, OR add a `whispercpp_*` shim wrapping it in + // try/catch. + // + // No-throw raw entry points (verified): + // - `*_default_params` — value-returning + // - `*_free`, `*_free_state` — destructors + // - `*_n_*`, `*_token_*`, + // `*_is_multilingual`, + // `*_lang_str`, + // `*_model_type_readable`, + // `*_full_get_*_from_state` — pure read accessors + // - `*_token_to_str` — would throw via + // `id_to_token.at` but the safe wrapper validates + // the bound first. + // + // Throwing entry points routed through `whispercpp_*` + // shims: + // - `whisper_init_from_file_with_params_no_state` → + // `whispercpp_init_from_file_no_state` + // - `whisper_init_state` → + // `whispercpp_init_state` + // - `whisper_full_with_state` → + // `whispercpp_full_with_state` + // - `whisper_print_system_info` → + // `whispercpp_print_system_info` + // + // VAD entry points (`whisper_vad_*`) are NOT exposed — + // the safe wrapper doesn't surface VAD, and their file + // loaders throw on truncated models. 
+ .allowlist_function("whisper_context_default_params") + .allowlist_function("whisper_full_default_params") + .allowlist_function("whisper_free") + .allowlist_function("whisper_free_state") + .allowlist_function("whisper_is_multilingual") + .allowlist_function("whisper_n_vocab") + .allowlist_function("whisper_n_audio_ctx") + .allowlist_function("whisper_n_text_ctx") + .allowlist_function("whisper_token_eot") + .allowlist_function("whisper_token_sot") + .allowlist_function("whisper_token_beg") + .allowlist_function("whisper_token_to_str") + .allowlist_function("whisper_lang_str") + .allowlist_function("whisper_model_type_readable") + .allowlist_function("whisper_full_n_segments_from_state") + .allowlist_function("whisper_full_lang_id_from_state") + .allowlist_function("whisper_full_get_segment_t0_from_state") + .allowlist_function("whisper_full_get_segment_t1_from_state") + .allowlist_function("whisper_full_get_segment_text_from_state") + .allowlist_function("whisper_full_get_segment_no_speech_prob_from_state") + .allowlist_function("whisper_full_get_segment_speaker_turn_next_from_state") + .allowlist_function("whisper_full_n_tokens_from_state") + .allowlist_function("whisper_full_get_token_data_from_state") + // ggml's logger setter is referenced from our context + // init lock comment but not directly called. We expose + // the whole ggml_log_* family for diagnostic use. + .allowlist_function("ggml_log_.*") + // Shim entry points — no-throw at the boundary. + .allowlist_function("whispercpp_.*") + // Type allowlist: every struct / enum the function + // signatures above transitively require. 
+ .allowlist_type("whisper_context") + .allowlist_type("whisper_state") + .allowlist_type("whisper_context_params") + .allowlist_type("whisper_full_params") + .allowlist_type("whisper_token") + .allowlist_type("whisper_token_data") + .allowlist_type("whisper_pos") + .allowlist_type("whisper_seq_id") + .allowlist_type("whisper_sampling_strategy") + .allowlist_type("whisper_grammar_element") + .allowlist_type("whisper_segment") + .allowlist_type("whisper_progress_callback") + .allowlist_type("whisper_new_segment_callback") + .allowlist_type("whisper_encoder_begin_callback") + .allowlist_type("whisper_logits_filter_callback") + .allowlist_type("ggml_log_.*") + .allowlist_var("WHISPER_.*") + // Shim exception sentinels (WHISPERCPP_ERR_*). state.rs + // needs them to discriminate "shim caught a C++ exception + // → state may be corrupt → poison" from "whisper.cpp + // returned a documented error code". + .allowlist_var("WHISPERCPP_.*") + // CargoCallbacks calls + // `println!("cargo:rerun-if-changed=...")` for every + // header bindgen pulled. Those land under whisper.cpp/... + // (or the system include path) so we DO want them — a + // header change should re-bindgen. + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .layout_tests(false) + .derive_default(true) + .derive_debug(true) + .generate() + .expect("bindgen failed"); + + let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set")); + let dest = out_dir.join("generated.rs"); + let body = bindings.to_string(); + let header_comment = format!( + "// @generated\n\ + //\n\ + // whisper.cpp FFI surface — produced by bindgen against\n\ + // the bundled submodule (`whispercpp-sys/whisper.cpp/`),\n\ + // patched in OUT_DIR. 
Do not edit by hand.\n\ + //\n\ + // Source crate: {pkg} {ver}\n\ + // Source header: wrapper.h -> whisper.h + whispercpp_shim.h\n\ + //\n\n", + pkg = env!("CARGO_PKG_NAME"), + ver = env!("CARGO_PKG_VERSION"), + ); + + let new_contents = format!("{header_comment}{body}"); + std::fs::write(&dest, new_contents).expect("failed to write OUT_DIR/generated.rs"); +} diff --git a/whispercpp-sys/src/lib.rs b/whispercpp-sys/src/lib.rs new file mode 100644 index 0000000..1342b68 --- /dev/null +++ b/whispercpp-sys/src/lib.rs @@ -0,0 +1,33 @@ +//! `whispercpp-sys` — raw FFI bindings to whisper.cpp. +//! +//! Everything below is `unsafe`-callable C ABI surface. Higher +//! layers (the `whispercpp` crate) wrap these in safe types; +//! end users should depend on `whispercpp` rather than this +//! crate directly. +//! +//! `build.rs` cmake-builds the vendored `whisper.cpp/` submodule +//! (pinned to a patched fork branch) and statically links the +//! resulting libraries. There is no pkg-config / system-install +//! path: the safe surface in the upper crate depends on patches +//! that only the bundled build supplies, and a stock libwhisper +//! would silently lose those guarantees. Bindgen writes the FFI +//! surface to `OUT_DIR/generated.rs`. + +#![allow(unsafe_code)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(non_upper_case_globals)] +#![allow(dead_code)] +#![allow(missing_docs)] + +// Bindgen output is written to `OUT_DIR` by build.rs and +// `include!`'d here. An in-tree path (`src/generated.rs`) +// would break read-only builds (cargo vendor, Nix, Bazel, +// verified-source registry checkouts) and could race across +// builds with different feature sets. +// +// Trade-off: the FFI surface is no longer grep-able from a +// fresh checkout. Inspect via `cargo expand -p whispercpp-sys` +// or look at `target/.../build/whispercpp-sys-*/out/generated.rs` +// after a build. 
+include!(concat!(env!("OUT_DIR"), "/generated.rs")); diff --git a/whispercpp-sys/whisper.cpp b/whispercpp-sys/whisper.cpp new file mode 160000 index 0000000..9c4881d --- /dev/null +++ b/whispercpp-sys/whisper.cpp @@ -0,0 +1 @@ +Subproject commit 9c4881d8f5cd2224224e46ec8d012cce348be39d diff --git a/whispercpp-sys/whispercpp_shim.cpp b/whispercpp-sys/whispercpp_shim.cpp new file mode 100644 index 0000000..213eefc --- /dev/null +++ b/whispercpp-sys/whispercpp_shim.cpp @@ -0,0 +1,125 @@ +// C++ exception-catching shim around whisper.cpp's public API. +// +// See whispercpp_shim.h for the rationale. Every wrapper in +// this file isolates its whisper_* call inside `try/catch (…)` +// so a `std::bad_alloc` / `std::system_error` / any other +// throw inside whisper.cpp becomes a sentinel return value +// instead of unwinding through the `extern "C"` boundary into +// Rust — which is undefined behaviour. + +#include "whispercpp_shim.h" + +#include <exception> +#include <new> +#include <system_error> + +// Per-thread "most recent caught constructor exception" slot. +// +// The constructor shims previously collapsed +// every failure (including caught exceptions) onto `nullptr`, +// indistinguishable from an upstream "init failed cleanly" +// nullptr return. Callers therefore couldn't tell a retryable +// failure (bad path, missing file) from a partial-init exception +// that leaked the `new whisper_context` / `new whisper_state` +// allocations. +// +// We expose a thread-local sentinel. Each constructor entry +// resets it to 0 and writes a `WHISPERCPP_ERR_*` value on catch. +// Callers pair every `nullptr` observation with +// `whispercpp_take_last_constructor_exception` to discriminate +// — and surface the exception case as a non-retryable fatal +// error so workers don't compound the leak. +// +// Why thread-local: concurrent context/state inits on different +// threads must not interleave their sentinels.
Cross-thread +// reads are forbidden by the API contract (read on the same +// thread that made the call). +// +// Why a single slot for both `init_from_file` and `init_state`: +// the safe Rust API reads the sentinel synchronously after each +// constructor call, before any other shim entry on the same +// thread. There's no observation window where one constructor's +// exception could be misread as another's. +static thread_local int g_last_constructor_exception = 0; + +extern "C" { + +struct whisper_context * whispercpp_init_from_file_no_state( + const char * path_model, + struct whisper_context_params params) +{ + g_last_constructor_exception = 0; + try { + return whisper_init_from_file_with_params_no_state(path_model, params); + } catch (const std::bad_alloc &) { + g_last_constructor_exception = WHISPERCPP_ERR_BAD_ALLOC; + return nullptr; + } catch (const std::system_error &) { + g_last_constructor_exception = WHISPERCPP_ERR_SYSTEM_ERROR; + return nullptr; + } catch (const std::exception &) { + g_last_constructor_exception = WHISPERCPP_ERR_STD_EXCEPTION; + return nullptr; + } catch (...) { + g_last_constructor_exception = WHISPERCPP_ERR_UNKNOWN_EXCEPTION; + return nullptr; + } +} + +struct whisper_state * whispercpp_init_state(struct whisper_context * ctx) +{ + g_last_constructor_exception = 0; + try { + return whisper_init_state(ctx); + } catch (const std::bad_alloc &) { + g_last_constructor_exception = WHISPERCPP_ERR_BAD_ALLOC; + return nullptr; + } catch (const std::system_error &) { + g_last_constructor_exception = WHISPERCPP_ERR_SYSTEM_ERROR; + return nullptr; + } catch (const std::exception &) { + g_last_constructor_exception = WHISPERCPP_ERR_STD_EXCEPTION; + return nullptr; + } catch (...) 
{ + g_last_constructor_exception = WHISPERCPP_ERR_UNKNOWN_EXCEPTION; + return nullptr; + } +} + +int whispercpp_take_last_constructor_exception(void) +{ + int v = g_last_constructor_exception; + g_last_constructor_exception = 0; + return v; +} + +int whispercpp_full_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples) +{ + try { + return whisper_full_with_state(ctx, state, params, samples, n_samples); + } catch (const std::bad_alloc &) { + return WHISPERCPP_ERR_BAD_ALLOC; + } catch (const std::system_error &) { + return WHISPERCPP_ERR_SYSTEM_ERROR; + } catch (const std::exception &) { + return WHISPERCPP_ERR_STD_EXCEPTION; + } catch (...) { + return WHISPERCPP_ERR_UNKNOWN_EXCEPTION; + } +} + +const char * whispercpp_print_system_info(void) +{ + try { + return whisper_print_system_info(); + } catch (...) { + return nullptr; + } +} + +} // extern "C" diff --git a/whispercpp-sys/whispercpp_shim.h b/whispercpp-sys/whispercpp_shim.h new file mode 100644 index 0000000..8965fa8 --- /dev/null +++ b/whispercpp-sys/whispercpp_shim.h @@ -0,0 +1,122 @@ +/// C-ABI shims around the whisper.cpp public API. +/// +/// Every function declared here wraps its whisper.cpp +/// counterpart in a `try { ... } catch (...) { ... }` block. +/// whisper.cpp's `extern "C"` +/// entry points internally allocate `std::vector` and +/// construct `std::thread`, both of which can throw +/// (`std::bad_alloc`, `std::system_error`) under realistic +/// resource pressure. Letting a C++ exception propagate across an +/// `extern "C"` boundary into Rust code that wasn't compiled +/// with `panic=unwind` ABI compatibility is undefined +/// behaviour. +/// +/// Convention: +/// +/// * Constructors that return `T*` on success return +/// `nullptr` on caught exception (matches the C API's +/// existing failure mode).
+/// * `int`-returning `whisper_full_with_state` returns a +/// negative sentinel for caught exceptions: +/// * `-100` for `std::bad_alloc` (OOM) +/// * `-101` for `std::system_error` (thread/system call) +/// * `-102` for any other `std::exception` +/// * `-103` for unknown / non-`std::exception` throws +/// These sit far below whisper.cpp's own negative return +/// codes (which stop at `-7` in v1.8.4), so they cannot collide; +/// the safe-Rust wrapper translates them into typed +/// `WhisperError` variants. + +#ifndef WHISPERCPP_SHIM_H +#define WHISPERCPP_SHIM_H + +#include "whisper.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Exception sentinels returned by `whispercpp_full_with_state` +/// and `whispercpp_take_last_constructor_exception`. Macros (not +/// enums) so bindgen emits plain integers the wrapper can match on. +#define WHISPERCPP_ERR_BAD_ALLOC -100 +#define WHISPERCPP_ERR_SYSTEM_ERROR -101 +#define WHISPERCPP_ERR_STD_EXCEPTION -102 +#define WHISPERCPP_ERR_UNKNOWN_EXCEPTION -103 + +/// `whisper_init_from_file_with_params_no_state` wrapped in +/// try/catch. +/// +/// Returns `nullptr` on either: +/// * the upstream C API's documented failure (file not found, +/// model corrupt, backend init refused, etc. — these return +/// nullptr without throwing), OR +/// * a caught C++ exception inside the upstream init path +/// (`std::bad_alloc`, `std::system_error`, +/// `std::exception`, or anything else). +/// +/// Use [`whispercpp_take_last_constructor_exception`] AFTER +/// observing `nullptr` to discriminate the two cases — the +/// caller MUST treat the exception case as fatal (the +/// upstream code has no RAII around `new whisper_context;`, +/// so any throw mid-init leaks the partial allocation). +/// +struct whisper_context * whispercpp_init_from_file_no_state( + const char * path_model, + struct whisper_context_params params); + +/// `whisper_init_state` wrapped in try/catch.
+/// +/// Same `nullptr` discrimination contract as +/// [`whispercpp_init_from_file_no_state`]: pair every +/// `nullptr` observation with +/// [`whispercpp_take_last_constructor_exception`] to +/// distinguish "upstream returned nullptr cleanly" (retryable) +/// from "exception caught, partial native allocation leaked" +/// (fatal). +struct whisper_state * whispercpp_init_state(struct whisper_context * ctx); + +/// Read-and-clear the most recent **constructor** exception +/// sentinel. +/// +/// Set by [`whispercpp_init_from_file_no_state`] and +/// [`whispercpp_init_state`] inside their `catch` blocks; reset +/// to `0` on entry to those functions and again by this +/// accessor. +/// +/// Returns one of: +/// * `0` — no exception was caught on the most recent +/// constructor call on this thread (a `nullptr` return means +/// the upstream C API returned `nullptr` cleanly, no leak). +/// * `WHISPERCPP_ERR_BAD_ALLOC` — `std::bad_alloc` during init. +/// * `WHISPERCPP_ERR_SYSTEM_ERROR` — `std::system_error`. +/// * `WHISPERCPP_ERR_STD_EXCEPTION` — other `std::exception`. +/// * `WHISPERCPP_ERR_UNKNOWN_EXCEPTION` — non-`std::exception` +/// throw. +/// +/// Thread-local: each thread observes its own most-recent +/// sentinel. Callers must invoke this on the SAME thread that +/// made the constructor call, immediately after observing the +/// `nullptr` return. Inserting other shim calls between the +/// constructor and this read clobbers the sentinel. +int whispercpp_take_last_constructor_exception(void); + +/// `whisper_full_with_state` wrapped in try/catch. +int whispercpp_full_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples); + +/// `whisper_print_system_info` wrapped in try/catch. Upstream +/// rebuilds a static `std::string` via `s = ""; s += "..."; s +/// += std::to_string(...);` which can throw `std::bad_alloc` +/// across the C ABI. 
Returns NULL on any caught exception. +const char * whispercpp_print_system_info(void); + +#ifdef __cplusplus +} +#endif + +#endif // WHISPERCPP_SHIM_H diff --git a/whispercpp-sys/wrapper.h b/whispercpp-sys/wrapper.h new file mode 100644 index 0000000..02b35c6 --- /dev/null +++ b/whispercpp-sys/wrapper.h @@ -0,0 +1,12 @@ +// bindgen entry point. Pulls only what we use. Adding new +// whisper.cpp surface to the safe wrapper means adding the +// matching `#include` here AND extending the `allowlist_*` +// directives in `build.rs` — there is no implicit re-export. +// +// `whispercpp_shim.h` exposes the exception-catching C ABI +// shim layer. Every safe-Rust entry point +// that can run user-controlled allocations / thread spawns +// goes through these shims rather than calling whisper.cpp +// directly. +#include "whisper.h" +#include "whispercpp_shim.h" diff --git a/whispercpp/.gitignore b/whispercpp/.gitignore new file mode 100644 index 0000000..5929c59 --- /dev/null +++ b/whispercpp/.gitignore @@ -0,0 +1,4 @@ +# Per-crate cargo target dir (whisper-cpp uses its own +# workspace declaration so cargo writes here rather than +# alongside whispery's main target/). +target/ diff --git a/whispercpp/Cargo.toml b/whispercpp/Cargo.toml new file mode 100644 index 0000000..04ed87b --- /dev/null +++ b/whispercpp/Cargo.toml @@ -0,0 +1,96 @@ +[package] +name = "whispercpp" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +readme = "README.md" +description = "Safe Rust bindings for whisper.cpp speech recognition. Bundled patched build with memory-safety hardening, exception-catching FFI shim, and Send + Sync types." +keywords = ["whisper", "asr", "speech-to-text", "transcription", "audio"] +categories = ["multimedia::audio", "api-bindings", "science"] + +[lib] +name = "whispercpp" +path = "src/lib.rs" + +[features] +# Default: CPU build (no GPU dep). 
`cargo add whispercpp` Just +# Works on every platform (Apple, Linux, Windows) without a +# preinstalled whisper.cpp — but ships ZERO GPU acceleration. +# Opt into GPU backends explicitly per-target via the feature +# flags below; multi-platform consumers can document which +# accelerator each target needs in their own `Cargo.toml` +# rather than silently inheriting an Apple-Silicon-only set. +# +# whisper.cpp is **always** built from the vendored submodule +# in `whispercpp-sys/whisper.cpp/` and patched in `OUT_DIR`. +# There is no system / pkg-config path: routing safe-Rust +# code through a stock unpatched libwhisper would silently +# drop the memory-safety guarantees the bundled patches +# provide. +default = [] + +# Enables `serde::{Serialize, Deserialize}` for `Lang`. The +# wire format is the lowercase ISO-639-1 string ("en", "yue", +# etc.) — see `Lang`'s impls. +serde = ["dep:serde", "smol_str/serde"] + +# ── Backends ────────────────────────────────────────────── +# +# Each chains 1:1 to `whispercpp-sys`'s matching feature, +# which is what actually toggles whisper.cpp / ggml's CMake +# `-DGGML_*=ON` flag. + +# Apple-only: +metal = ["whispercpp-sys/metal"] # GGML_METAL: Metal GPU +coreml = ["whispercpp-sys/coreml"] # WHISPER_COREML: ANE encoder + +# Cross-platform GPU: +vulkan = ["whispercpp-sys/vulkan"] # GGML_VULKAN: Vulkan compute +opencl = ["whispercpp-sys/opencl"] # GGML_OPENCL: mobile / Adreno + +# Vendor-specific GPU: +cuda = ["whispercpp-sys/cuda"] # NVIDIA +hipblas = ["whispercpp-sys/hipblas"] # AMD ROCm/HIP +sycl = ["whispercpp-sys/sycl"] # Intel oneAPI / Arc +musa = ["whispercpp-sys/musa"] # Moore Threads + +# Encoder accelerators (similar role to CoreML on other vendors): +openvino = ["whispercpp-sys/openvino"] # Intel OpenVINO + +# CPU BLAS: +openblas = ["whispercpp-sys/openblas"] # OpenBLAS + +[dependencies] +# Low-level FFI to whisper.cpp. Path dep — sibling crate at +# `../whispercpp-sys/`. 
All `unsafe extern "C"` declarations +# live there; this crate only ever calls them behind safe +# wrappers. +whispercpp-sys = { version = "0.1", path = "../whispercpp-sys", default-features = false } +# Public error type. `thiserror` keeps things light. +thiserror = { version = "2", default-features = false } +# Inline small strings (≤23 bytes) for error payloads — paths, +# language hints, single-char interior-NUL diagnostics. Avoids +# a heap allocation on every `WhisperError::ContextLoad` / +# `InvalidCString`. +smol_str = { version = "0.3", default-features = false } +# Optional serde, gated by the `serde` feature. When enabled, +# `Lang` round-trips through the canonical lowercase ISO-639-1 +# string (`"en"`, `"yue"`, …). +serde = { version = "1", optional = true, default-features = false, features = ["alloc"] } + +[dev-dependencies] +# Hound is for the `examples/smoke.rs` WAV reader only — never +# pulled into a production build of `whispercpp` itself. +hound = "3" +# `lang.rs`'s serde tests round-trip through JSON. Dev-only — +# no runtime cost in production builds. +serde_json = "1" + +[[example]] +name = "smoke" +path = "examples/smoke.rs" + +[lints] +workspace = true diff --git a/whispercpp/README.md b/whispercpp/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/whispercpp/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/whispercpp/TODO.md b/whispercpp/TODO.md new file mode 100644 index 0000000..63a03f3 --- /dev/null +++ b/whispercpp/TODO.md @@ -0,0 +1,219 @@ +# whispercpp — unsupported surface + +This crate intentionally exposes a narrow slice of whisper.cpp. +Everything below is reachable from the auto-generated FFI in the +sibling `whispercpp-sys` crate (`whispercpp-sys/src/generated.rs`) +but is NOT wrapped in safe Rust today. + +Three categories: + +1. **Deliberately omitted** — whispery doesn't need it; wrapping would + add maintenance + surface area without a caller. +2. 
**Could add on demand** — small wrapper, not justified yet. +3. **Larger work** — would need design choices about safety / lifetimes + before exposing. + +When adding something below, also extend the `allowlist_function` +/ `allowlist_type` directives in +`whispercpp-sys/build.rs::generate_bindings()` if the symbol +isn't already in `whispercpp-sys/src/generated.rs`. + +--- + +## 1. Deliberately omitted + +### Built-in VAD + +whisper.cpp ships its own VAD (Silero ONNX). Whispery uses the +`silero` crate for VAD upstream of `whispercpp`, so the in-tree path +is the canonical one. Re-wrapping whisper.cpp's wrapper duplicates +state and complicates the call chain. + +Symbols: `whisper_vad_*`, `whisper_full_params::vad`, +`vad_model_path`, `vad_params`, `set_min_speech_duration_ms`, +`set_max_speech_duration_s`, `set_min_silence_duration_ms`, +`set_speech_pad_ms`, `set_threshold`. + +### Grammar + +Whispery doesn't constrain decoding via grammar. The grammar +machinery in whisper.cpp pulls a sizeable struct hierarchy +(`whisper_grammar_element`, rules, stacks) and a non-trivial +ownership model. No caller has asked for it. + +Symbols: `whisper_full_params::grammar_rules`, `grammar_n_rules`, +`grammar_i_start_rule`, `grammar_penalty`, `set_grammar`, +`set_grammar_penalty`, `set_start_rule`, `whisper_grammar_*`. + +### Translate task + +Whispery is transcribe-only. `set_translate(true)` (translate audio +→ English) is wrapped (one-line passthrough), but the full +translate-task flow (token id remapping, prompt seeding) is not +exercised by any caller and we don't ship test coverage for it. + +### Tinydiarize controls + +`Segment::speaker_turn_next()` IS wrapped (it's a 1-byte read). +Configuring `--tdrz` on the input side (`set_tdrz_enable`) is not +— it requires a TDRZ-enabled checkpoint which whispery doesn't ship, +and whispery's diarization runs upstream via pyannote-style +clustering on word ranges. + +Symbols: `whisper_full_params::tdrz_enable`, `set_tdrz_enable`. 
+ +### Lower-level entry points + +We expose `state.full()` only. The lower-level encode/decode flow +(running the encoder, then `decode` token-by-token with custom +sampling) is meaningful for research / custom samplers but doesn't +fit whispery's pump architecture. + +Symbols: `whisper_encode`, `whisper_encode_with_state`, +`whisper_decode`, `whisper_decode_with_state`, `whisper_get_logits`, +`whisper_get_logits_from_state`, `whisper_set_mel`, +`whisper_set_mel_with_state`, `whisper_pcm_to_mel`, +`whisper_pcm_to_mel_with_state`. + +### Mid-decode callbacks + +Whisper.cpp can fire callbacks on every new segment, every logits +emission, and at encoder start. Each requires the same trampoline +discipline as the abort callback and adds another `Box` +field to `Params`. None is wired into whispery's pump (which works +chunk-at-a-time, not token-at-a-time). + +Symbols: `set_progress_callback`, `set_progress_callback_safe`, +`set_progress_callback_user_data`, `set_new_segment_callback`, +`set_segment_callback_safe`, `set_segment_callback_safe_lossy`, +`set_new_segment_callback_user_data`, `set_filter_logits_callback`, +`set_filter_logits_callback_user_data`, `set_start_encoder_callback`, +`set_start_encoder_callback_user_data`. + +### Global logging hooks + +Whispery routes diagnostics through its own `eprintln!` / `tracing` +layer. whisper.cpp's `set_log_callback` is a global hook that fires +across all instances; mixing it with Rust logging frameworks +requires more design than a 1:1 port. + +Symbols: `whisper_set_log_callback`, `set_debug_mode`, +`whisper_log_callback`. + +### DTW token timestamps + +Whispery uses wav2vec2 forced alignment for word-level timing. +whisper.cpp's DTW path is a parallel mechanism with its own +configuration (`dtw_aheads`, `dtw_n_top`, `dtw_mem_size`). Wrapping +it would invite confusion about which timestamping path is +authoritative. 
+ +Symbols: `whisper_full_params::dtw_token_timestamps` (true at +construction, but `Params::set_dtw_*` and `dtw_aheads` array are +not exposed), `whisper_aheads`, `whisper_full_get_token_dtw_t0_*`. + +### Buffer-load constructors + +We support `Context::new(path, params)` only. Loading from an +in-memory buffer (`whisper_init_from_buffer_with_params`) or via a +custom `whisper_model_loader` is rare and adds lifetime/ownership +complexity. + +Symbols: `whisper_init_from_buffer_with_params`, +`whisper_init_with_params` (custom loader), the `whisper_model_loader` +struct. + +### Beam-search + greedy sampler details (advanced) + +Symbols: `set_beam_size`, `set_patience` are reachable through +`SamplingStrategy::BeamSearch { beam_size, patience }` already. +Direct `whisper_full_params::beam_search.beam_size` / `patience` +accessors aren't exposed (use `Params::new(strategy)` with the +right variant). + +--- + +## 2. Could add on demand + +These are 5–15-line wrappers around an existing FFI symbol. None is +required for whispery's current flow; each is justifiable when a +concrete caller appears. + +| Whisper.cpp symbol | Suggested Rust API | Why might we want it | +|---|---|---| +| `whisper_token_text(ctx, token)` (alias of `token_to_str`) | already covered | — | +| `whisper_token_to_bytes` | `Context::token_to_bytes(token) -> Option<&[u8]>` | non-UTF-8 byte sequences from BPE merges | +| `whisper_lang_id(name)` | `Context::lang_id_for(name: &str) -> Option` | reverse of `detected_lang` | +| `whisper_lang_max_id()` | `pub const LANG_MAX_ID: i32 = …` | iterate languages | +| `whisper_lang_str_full(id)` | `Lang::full_name() -> &'static str` | "english" vs "en" | +| `whisper_token_translate / transcribe / prev / nosp / not / solm` | `Context::token_translate() -> i32`, etc. 
| force-prefix decoding seeds | +| `whisper_token_lang(ctx, lang_id)` | `Context::token_for_lang(Lang) -> i32` | language-specific seeds | +| `whisper_token_id(ctx, token: &str)` | `Context::tokenize_one(text) -> Option` | turn a string back into a token id | +| `whisper_tokenize(ctx, text, tokens, max)` | `Context::tokenize(text) -> Vec` | batch tokenization for `set_tokens` | +| `whisper_n_len_from_state(state)` | `State::n_mel_frames() -> i32` | mel buffer length | +| `whisper_print_timings(ctx)` | `Context::print_timings()` | end-of-run cost breakdown | +| `whisper_reset_timings(ctx)` | `Context::reset_timings()` | per-chunk timing | +| `whisper_get_whisper_version()` | `pub fn version() -> &'static str` | diagnostic | +| Model layer counts (`whisper_model_n_audio_state` / `n_audio_head` / `n_audio_layer` / `n_text_state` / `n_text_head` / `n_text_layer` / `n_mels` / `model_ftype`) | `Context::model_dims() -> ModelDims` (struct of ints) | architecture-aware diagnostics | +| `whisper_full_get_token_p_from_state` | `Token::posterior() -> f32` | already covered indirectly via `Token::p()` reading `whisper_token_data.p` — verify the two agree under wildcard / temperature | +| `whisper_full_n_tokens_from_state` | already covered (`Segment::n_tokens()`) | — | + +Adding any of these means: extend the safe wrapper module, run the +existing test suite (`cargo test -p whispercpp --features serde`), +and confirm no rebuild loop on `src/generated.rs` (build.rs short- +circuits when the bindgen output is byte-identical). + +--- + +## 3. Larger work + +### Token-stream `Iterator` + +`State::segments_iter()` and `Segment::tokens_iter()` would be nice +ergonomics. The lifetime story is non-trivial — each `Segment` / +`Token` borrows from `State` via raw pointer. A correct iterator +needs to project through that lifetime without aliasing. + +### Async-friendly `full` + +`State::full` blocks for the duration of the decode (seconds to +minutes). 
A `tokio`-friendly variant that runs the FFI on a +blocking task pool and yields completion would help server use +cases. Currently callers spawn their own threads. + +### Streaming / partial-result API + +whisper.cpp's `whisper_full` is a one-shot call. Streaming +transcription requires either (a) the new-segment callback path +(see "Mid-decode callbacks" above), or (b) external chunking + one +`Context::create_state()` per chunk. Whispery does (b) at the +runner layer. + +### CoreML companion model build + +Whispery ships `coreml` as an opt-in feature, but generating the +`.mlmodelc` companion file (whisper.cpp's `models/generate-coreml- +model.sh`) is out-of-band. A `whispercpp-tools` crate or a build.rs +helper that converts a checkpoint at install time would close the +loop, but it requires `coremltools` (a Python dep) at build time — +not great. + +--- + +## Audit policy + +Before adding new public functions to `whispercpp`: + +1. Confirm the FFI symbol is in + `whispercpp-sys/src/generated.rs`. If not, extend the + allowlists in `whispercpp-sys/build.rs::generate_bindings()`. +2. Replicate the safety rules used by the closest existing wrapper: + pointer is non-null, lifetime tied to the parent struct, no + aliasing across threads. Document the SAFETY block. +3. Keep the public surface minimal — accessors private until a + caller materialises. The crate's value is "small, audited, no + leaks"; that holds only if every `unsafe` block has an obvious + justification. + +For deliberately-omitted items, prefer documenting the omission here +rather than wrapping speculatively. diff --git a/build.rs b/whispercpp/build.rs similarity index 100% rename from build.rs rename to whispercpp/build.rs diff --git a/whispercpp/examples/smoke.rs b/whispercpp/examples/smoke.rs new file mode 100644 index 0000000..b772b25 --- /dev/null +++ b/whispercpp/examples/smoke.rs @@ -0,0 +1,92 @@ +//! Smoke test: load a model, transcribe a 16 kHz mono WAV, print +//! the segment list. 
Times each phase so we can compare against +//! whisper-cli end-to-end. +//! +//! ```text +//! whisper-cpp-smoke <model> <wav> [language] +//! ``` + +use std::time::Instant; + +use whispercpp::{Context, ContextParams, Params, SamplingStrategy}; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let mut args = std::env::args().skip(1); + let model = args.next().ok_or("usage: <model> <wav> [lang]")?; + let wav = args.next().ok_or("usage: <model> <wav> [lang]")?; + let lang = args.next().unwrap_or_else(|| "en".to_string()); + + // Load 16 kHz mono f32. We rely on the `hound` crate at the + // workspace level normally; for the smoke test, do it inline + // to keep dependencies on this crate to literally just whisper. + let samples = read_wav_16k_mono(&wav)?; + let dur_s = samples.len() as f64 / 16_000.0; + eprintln!( + "[smoke] wav={wav} samples={} dur={dur_s:.2}s", + samples.len() + ); + + let t_load = Instant::now(); + let ctx = std::sync::Arc::new(Context::new( + &model, + ContextParams::new().with_use_gpu(true), + )?); + eprintln!( + "[smoke] context loaded in {:.3}s", + t_load.elapsed().as_secs_f64() + ); + + let mut state = ctx.create_state()?; + + let mut params = Params::new(SamplingStrategy::Greedy { best_of: 1 }); + // `set_language` is fallible (interior NUL); the rest are + // infallible chained `&mut Self` setters.
+ params.set_language(&lang)?; + params + .set_n_threads(1) + .set_no_context(true) + .set_suppress_blank(true) + .set_suppress_nst(true) + .set_temperature(0.0) + .set_temperature_inc(0.0) + .set_no_speech_thold(0.6) + .silence_print_toggles(); + + let t_full = Instant::now(); + state.full(&params, &samples)?; + let full_s = t_full.elapsed().as_secs_f64(); + eprintln!("[smoke] full() in {full_s:.3}s | rtf={:.3}", full_s / dur_s); + + let n = state.n_segments(); + eprintln!("[smoke] {n} segments"); + for i in 0..n { + let seg = state.segment(i).expect("idx in range"); + let t0 = seg.t0() as f64 * 0.01; + let t1 = seg.t1() as f64 * 0.01; + eprintln!(" [{t0:7.2}s -> {t1:7.2}s] {}", seg.text()?); + } + Ok(()) +} + +fn read_wav_16k_mono(path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> { + // `hound` is a dev-dependency used by this smoke binary only, so it + // never reaches the crate's runtime deps — production callers + // bring their own audio loader (whispery uses ffmpeg-next). + let mut reader = hound::WavReader::open(path)?; + let spec = reader.spec(); + if spec.sample_rate != 16_000 { + return Err(format!("expected 16 kHz, got {} Hz", spec.sample_rate).into()); + } + if spec.channels != 1 { + return Err(format!("expected mono, got {} channels", spec.channels).into()); + } + match spec.sample_format { + hound::SampleFormat::Float => Ok(reader.samples::<f32>().collect::<Result<_, _>>()?), + hound::SampleFormat::Int => Ok( + reader + .samples::<i16>() + .map(|s| s.map(|x| x as f32 / 32768.0)) + .collect::<Result<_, _>>()?, + ), + } +} diff --git a/whispercpp/src/context.rs b/whispercpp/src/context.rs new file mode 100644 index 0000000..9601892 --- /dev/null +++ b/whispercpp/src/context.rs @@ -0,0 +1,727 @@ +//! `Context` — the loaded whisper model. +//! +//! Owns the `whisper_context*` returned by +//! `whisper_init_from_file_with_params_no_state`. Drop calls +//! `whisper_free`. Cloning is intentionally NOT supported — the +//! underlying whisper.cpp object is a unique owned resource. To +//!
run multiple inference threads against the same model, share +//! `Arc` and call [`Context::create_state`] per thread +//! (each `State` carries its own KV cache). + +#![allow(unsafe_code)] + +use core::{ + ptr::NonNull, + sync::atomic::{AtomicBool, Ordering}, +}; +use std::{ + ffi::CString, + path::Path, + sync::{Arc, Mutex, MutexGuard}, +}; + +use crate::{ + error::{WhisperError, WhisperResult}, + state::State, + sys, +}; + +/// Acquire the process-wide mutex guarding every FFI call +/// that mutates ggml's global logger state. +/// +/// `whisper_init_state` calls +/// `whisper_backend_init_gpu`, which unconditionally invokes +/// `ggml_log_set(g_state.log_callback, …)` — writing to +/// ggml's file-static logger globals without any +/// synchronisation. `whisper_init_from_file_with_params_no_state` +/// is in the same family (touches `g_state` indirectly through +/// backend probing). With `unsafe impl Sync for Context`, two +/// safe-Rust threads holding `Arc` could call +/// `create_state` (or `Context::new`) concurrently and race on +/// those globals — a C/C++ data race reachable from safe Rust. +/// +/// The mutex serialises both init paths. Cost: one mutex +/// acquire per `Context::new` and per `create_state`. Both are +/// init-time, not hot-path; whispery's worker pool +/// pre-creates one `State` per worker at startup, so this is +/// microseconds-per-startup-once. +pub(crate) fn init_lock() -> MutexGuard<'static, ()> { + static LOCK: Mutex<()> = Mutex::new(()); + // Recover a poisoned lock — we don't hold any state on + // the inner ``, so re-acquiring after an unrelated panic + // in a sibling thread is fine. + LOCK.lock().unwrap_or_else(|e| e.into_inner()) +} + +/// Knobs forwarded to `whisper_context_default_params` before +/// loading. Mirrors the subset of `whisper_context_params` whispery +/// uses today. 
+/// +/// All fields are private; access goes through `const fn` +/// accessors and `with_*` builder methods so the type's invariants +/// stay encapsulated and the public surface evolves +/// independently of the underlying C struct. +#[derive(Debug, Clone, Copy)] +pub struct ContextParams { + use_gpu: bool, + gpu_device: i32, + flash_attn: bool, +} + +impl ContextParams { + /// Defaults: GPU on (Metal/CUDA where compiled in), device 0, + /// flash-attn off. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn new() -> Self { + Self { + use_gpu: true, + gpu_device: 0, + flash_attn: false, + } + } + + /// Whether the encoder dispatches to a GPU backend (Metal / + /// CUDA). On Apple Silicon: `true` is required to avoid the + /// BLAS-only encode path that hits whisper.cpp's `failed to + /// encode` error on `large-v3-turbo`. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn use_gpu(&self) -> bool { + self.use_gpu + } + + /// Chained setter for [`Self::use_gpu`]. `const fn` so callers + /// can build a `ContextParams` in `const` context (e.g. in + /// per-runner config statics). + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_use_gpu(mut self, on: bool) -> Self { + self.use_gpu = on; + self + } + + /// GPU device index (default `0` = primary). + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn gpu_device(&self) -> i32 { + self.gpu_device + } + + /// Chained setter for [`Self::gpu_device`]. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_gpu_device(mut self, idx: i32) -> Self { + self.gpu_device = idx; + self + } + + /// Whether flash-attention is enabled. Default `false`. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn flash_attn(&self) -> bool { + self.flash_attn + } + + /// Chained setter for [`Self::flash_attn`]. 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_flash_attn(mut self, on: bool) -> Self { + self.flash_attn = on; + self + } +} + +impl Default for ContextParams { + #[cfg_attr(not(tarpaulin), inline(always))] + fn default() -> Self { + Self::new() + } +} + +/// Loaded whisper.cpp model. Cheap to share via `Arc`. +pub struct Context { + // `NonNull` (vs. `*mut`) makes the Drop impl total — there is + // no "uninitialised" representation to guard against. + ptr: NonNull, + // bound the per-Context leak budget under + // `WhisperError::StateLost`. A `State::full` exception + // poisons the State (we MUST NOT free a possibly-corrupt + // `whisper_state`) and leaks that state's native + // allocations (~360 MB on `large-v3-turbo`). Without this + // flag, callers retrying `create_state` on the same Context + // accumulate one leak per attempt until the host runs out + // of memory. With it, `create_state` short-circuits to + // `ContextPoisoned` after the FIRST `StateLost`, capping + // the total leak at one State per Context. Recovery + // requires dropping this Context and constructing a fresh + // one (model reload — slow but bounded). + lost: AtomicBool, + // Serialise `State::full` calls through this Context. + // Without this lock, multiple + // workers each holding their own `State` (the documented + // pattern) can ALL be inside `whispercpp_full_with_state` + // simultaneously when an OOM / system_error fires. Each + // would poison its own state and leak ~360 MB before any + // of them got to mark the Context lost — the per-Context + // cap claim becomes a per-concurrent-worker cap, defeating + // the point. Holding this mutex across the FFI call makes + // the cap structural: at most one in-flight call per + // Context, so at most one leaked state per Context. + // + // Throughput cost: serialised inference per Context. 
On + // GPU backends (Metal, CUDA, Vulkan) the underlying + // command queue is already serialised, so the cost is + // small. On CPU-only inference, throughput drops to one + // inference at a time per Context — callers who need + // parallel CPU inference should run multiple Contexts + // (each loads its own copy of the model). + full_lock: Mutex<()>, +} + +// SAFETY: whisper.cpp's context is read-only after init — +// `whisper_init_from_file_with_params` is the only mutator and +// runs entirely before we hand out the pointer. Per-thread state +// (KV cache, scratch buffers) lives in `State`, not in `Context`. +// Verified against whisper.cpp v1.8.4 (the submodule pin). +unsafe impl Send for Context {} +unsafe impl Sync for Context {} + +impl Context { + /// Load a `.bin` (GGML / GGUF) model from disk. + /// + /// Returns [`WhisperError::ContextLoad`] when whisper.cpp could + /// not parse the file or initialise the requested backend, or + /// [`WhisperError::InvalidCString`] if `path` contains an + /// interior NUL. **Panic-free.** + pub fn new(path: impl AsRef, params: ContextParams) -> WhisperResult { + let path_ref = path.as_ref(); + let path_str = path_ref.to_string_lossy(); + let cpath = CString::new(path_str.as_ref()) + .map_err(|_| WhisperError::InvalidCString(smol_str::SmolStr::new(path_str.as_ref())))?; + + // SAFETY: pure C call returning a value-typed defaults struct. + let mut cparams = unsafe { sys::whisper_context_default_params() }; + cparams.use_gpu = params.use_gpu(); + cparams.gpu_device = params.gpu_device(); + cparams.flash_attn = params.flash_attn(); + + // Serialise init: backend probing inside whisper.cpp + // touches ggml's global logger state. + let _lock = init_lock(); + + // SAFETY: cpath outlives the call (held on the stack); + // cparams is value-typed. 
+ // + // We use the C++ exception-catching shim + // `whispercpp_init_from_file_no_state`: + // upstream allocates `std::vector` / `std::ifstream` + // buffers that can throw `std::bad_alloc` on OOM, and + // unwinding C++ exceptions across `extern "C"` into Rust + // is undefined behaviour. The shim catches everything + // and collapses to a NULL return. + // + // The shim itself wraps the `_no_state` form — that's + // intentional: the default + // `whisper_init_from_file_with_params` allocates an + // extra ~360 MB `whisper_state` into `ctx->state` that + // we never use (every inference path creates its own via + // [`Context::create_state`]). + // (`src/whisper.cpp:3735`). + // + // # Leak-on-OOM discrimination + // + // Upstream's + // `whisper_init_from_file_with_params_no_state` does + // `whisper_context * ctx = new whisper_context;` and + // then performs throwing model-load work (vector + // allocations for tensors, GPU buffer allocations on + // Apple Silicon / CUDA, file-stream reads). If a + // `std::bad_alloc` or `std::system_error` fires AFTER + // the raw `new` succeeded but BEFORE the function's own + // explicit-cleanup branches run, the partial + // `whisper_context` and any tensor/backend buffers + // already allocated leak — the shim catches the + // exception but has no pointer to clean up. + // + // The shim keeps a thread-local sentinel that + // distinguishes the two flavours of NULL return: + // + // * `take_last_constructor_exception == 0` → + // upstream returned NULL CLEANLY (file-not-found, + // wrong magic, backend refused — no `new` happened + // yet, or upstream's own bool-failure paths cleaned + // up). Surface as `ContextLoad`, retryable. + // * `take_last_constructor_exception != 0` → the + // shim caught a C++ throw with the `new + // whisper_context` already allocated. Surface as + // `ConstructorLost`, NOT retryable — see that + // variant's docs for the recovery contract. 
+ let raw = unsafe { sys::whispercpp_init_from_file_no_state(cpath.as_ptr(), cparams) }; + + if let Some(ptr) = NonNull::new(raw) { + return Ok(Self { + ptr, + lost: AtomicBool::new(false), + full_lock: Mutex::new(()), + }); + } + // SAFETY: pure C call; thread-local read on the same + // thread that made the constructor call, with no other + // shim entry between them. + let exc = unsafe { sys::whispercpp_take_last_constructor_exception() }; + if exc != 0 { + return Err(WhisperError::ConstructorLost { + origin: "context", + code: exc, + }); + } + Err(WhisperError::ContextLoad { + path: smol_str::SmolStr::new(path_str.as_ref()), + reason: smol_str::SmolStr::new( + "whispercpp_init_from_file_no_state returned NULL (upstream load failure, no native exception caught)", + ), + }) + } + + /// Create a fresh inference [`State`] tied to this model. + /// + /// Takes `Arc` because the returned `State` owns a clone + /// of the Arc — that's what keeps the Context alive across the + /// state's lifetime without forcing callers to thread a `'ctx` + /// borrow through every storage location. Construct + /// `Arc::new(Context::new(...)?)` once per model, then call + /// `create_state` per worker. + pub fn create_state(self: &Arc) -> WhisperResult { + // refuse if a prior `State::full` on this + // Context returned `WhisperError::StateLost`. Each + // `StateLost` leaks the State's native allocations + // (~360 MB on `large-v3-turbo`); allowing `create_state` + // to allocate a fresh one would compound the leak per + // retry attempt. Callers must drop this Context and + // construct a fresh one (re-loading the model) to + // recover. + if self.lost.load(Ordering::Acquire) { + return Err(WhisperError::ContextPoisoned); + } + // Serialise init: `whisper_backend_init_gpu` calls + // `ggml_log_set(...)` on ggml's file-static logger + // globals without any synchronisation. 
Two threads + // creating states concurrently from a shared + // `Arc` would race on those globals — a C/C++ + // data race reachable from safe Rust through + // `unsafe impl Sync for Context`. + let _lock = init_lock(); + + // SAFETY: self.ptr is non-null (NonNull invariant) and + // the Arc clone we hand to State keeps the Context (and + // therefore the underlying whisper_context*) alive for + // the State's lifetime. + // + // We route through the exception-catching shim + // `whispercpp_init_state`: upstream allocates KV-cache + // and scratch buffers via `std::vector` (each potentially + // throws `std::bad_alloc`), and on Apple Silicon also + // initialises the Metal backend (which can throw on + // device-init failure). + // + // # NULL-discrimination contract + // + // Same flavour split as `Context::new`: upstream's + // `whisper_init_state` either returns NULL via its + // bool-failure paths (every `if (!whisper_kv_cache_init…)` + // branch runs `whisper_free_state(state); return nullptr;` + // before returning — leak-free) OR throws a C++ + // exception that our shim catches AFTER `new + // whisper_state` already happened (partial leak). + // + // Read the thread-local sentinel to distinguish: + // * `0` → `StateInit` (retryable, no leak) + // * `≠ 0` → `ConstructorLost { origin: "state", … }` + // (fatal, partial allocation leaked, do not auto-retry) + let raw = unsafe { sys::whispercpp_init_state(self.ptr.as_ptr()) }; + // TOCTOU close. Between the entry-time + // `lost.load` above and `whispercpp_init_state` returning, + // another thread may have transitioned an existing State + // through `StateLost` and called `mark_lost`. If we + // published this fresh State to the caller, they'd add + // another leak-prone State to a Context whose poison flag + // is now true. Re-check after the alloc; if the flag + // flipped, free the just-created state (it's intact — + // came straight out of `whisper_init_state`) and return + // `ContextPoisoned`. 
This bounds the leak window to the + // duration of the FFI call rather than zero, but the + // freshly-allocated state is always freed cleanly so no + // permanent leak accumulates. + if self.lost.load(Ordering::Acquire) { + if let Some(state_ptr) = NonNull::new(raw) { + // SAFETY: `raw` is the just-returned, never-published + // result of `whispercpp_init_state`; nothing else + // holds it. `whisper_free_state` is the matching + // deallocator. + unsafe { sys::whisper_free_state(state_ptr.as_ptr()) }; + } + // Even if the alloc threw (raw is null), drain the + // thread-local sentinel so it doesn't leak across into + // the next constructor call's catch-block. + let _ = unsafe { sys::whispercpp_take_last_constructor_exception() }; + return Err(WhisperError::ContextPoisoned); + } + if let Some(state_ptr) = NonNull::new(raw) { + return Ok(State::from_raw(state_ptr, Arc::clone(self))); + } + // SAFETY: pure C call; thread-local read on the same + // thread, no other shim call between. + let exc = unsafe { sys::whispercpp_take_last_constructor_exception() }; + if exc != 0 { + // A caught constructor exception means upstream + // `whisper_init_state` left partial native allocations + // that we cannot reliably free (the throw could have + // happened mid-init at any sub-call). Poison the + // Context so subsequent + // `create_state` calls fail with `ContextPoisoned` + // instead of repeating the same OOM / system_error + // path and compounding leaks. + self.lost.store(true, Ordering::Release); + return Err(WhisperError::ConstructorLost { + origin: "state", + code: exc, + }); + } + Err(WhisperError::StateInit) + } + + /// Internal: hand the raw pointer to siblings in this crate + /// that need to call FFI functions taking `whisper_context*`. + pub(crate) fn as_raw(&self) -> *mut sys::whisper_context { + self.ptr.as_ptr() + } + + /// Internal: mark this Context as poisoned because a + /// `State::full` on one of its States returned a + /// `WhisperError::StateLost`. 
Subsequent + /// [`Context::create_state`] calls return + /// [`WhisperError::ContextPoisoned`]. + /// + /// Idempotent: subsequent calls are cheap atomic stores. + /// `Ordering::Release` pairs with the + /// `Ordering::Acquire` load in `create_state` so threads + /// observing the flag also observe everything that led up + /// to the poisoning (per the C++ memory model: writes + /// before a Release become visible after the matching + /// Acquire). + #[cfg_attr(not(tarpaulin), inline(always))] + pub(crate) fn mark_lost(&self) { + self.lost.store(true, Ordering::Release); + } + + /// Whether [`Context::create_state`] will refuse to + /// allocate a new [`State`]. `true` after any `State::full` + /// on this Context has returned + /// [`WhisperError::StateLost`], and also after + /// [`Context::create_state`] itself has returned + /// [`WhisperError::ConstructorLost`] (that path stores the + /// same flag). Recovery requires dropping + /// this Context and constructing a fresh one. + pub fn is_poisoned(&self) -> bool { + self.lost.load(Ordering::Acquire) + } + + /// Acquire the per-Context inference lock for the duration + /// of one [`State::full`] FFI call. Held + /// across the leak-prone shim entry so concurrent workers + /// can't each leak under the same OOM event before + /// poisoning fires. Recovers from a poisoned mutex (a + /// previous holder panicked) by adopting the inner unit — + /// the inner state is `()`, so there's no value to be + /// inconsistent. + #[cfg_attr(not(tarpaulin), inline(always))] + pub(crate) fn full_lock(&self) -> MutexGuard<'_, ()> { + self + .full_lock + .lock() + .unwrap_or_else(|poison| poison.into_inner()) + } + + // ── Model introspection ──────────────────────────────────── + + /// `true` if the loaded checkpoint carries the multilingual + /// decoder (e.g. `large-v3-turbo`). `false` for English-only + /// checkpoints (`tiny.en`, `base.en`, …). + pub fn is_multilingual(&self) -> bool { + // SAFETY: ctx pointer invariant. + unsafe { sys::whisper_is_multilingual(self.ptr.as_ptr()) != 0 } + } + + /// Vocabulary size (number of tokens the decoder can emit).
+ pub fn n_vocab(&self) -> i32 { + // SAFETY: ctx pointer invariant. + unsafe { sys::whisper_n_vocab(self.ptr.as_ptr()) } + } + + /// Audio context window (encoder mel-frame budget). 1500 for + /// the vanilla 30 s checkpoints. + pub fn n_audio_ctx(&self) -> i32 { + // SAFETY: ctx pointer invariant. + unsafe { sys::whisper_n_audio_ctx(self.ptr.as_ptr()) } + } + + /// Text context window (decoder past-token budget). 448 for + /// the standard checkpoints. + pub fn n_text_ctx(&self) -> i32 { + // SAFETY: ctx pointer invariant. + unsafe { sys::whisper_n_text_ctx(self.ptr.as_ptr()) } + } + + /// Human-readable model size string baked into the checkpoint + /// (`"tiny"`, `"base"`, `"large-v3-turbo"`, …). Returns + /// `None` if whisper.cpp returned a NULL pointer or non-UTF-8 + /// (model corruption). + pub fn model_type(&self) -> Option<&'static str> { + // SAFETY: pure C accessor; pointer into a static + // const-table baked into libwhisper. + let raw = unsafe { sys::whisper_model_type_readable(self.ptr.as_ptr()) }; + if raw.is_null() { + return None; + } + // SAFETY: NUL-terminated; static lifetime per whisper.cpp. + let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() }; + core::str::from_utf8(bytes).ok() + } + + // ── Special token ids ────────────────────────────────────── + + /// `<|endoftext|>` — emitted at the end of every successful + /// decode. Useful for sentinel checks against `Token::id`. + pub fn token_eot(&self) -> i32 { + // SAFETY: ctx pointer invariant. + unsafe { sys::whisper_token_eot(self.ptr.as_ptr()) } + } + + /// `<|startoftranscript|>`. + pub fn token_sot(&self) -> i32 { + // SAFETY: ctx pointer invariant. + unsafe { sys::whisper_token_sot(self.ptr.as_ptr()) } + } + + /// First timestamp token (`<|0.00|>`). Token ids `>= token_beg` + /// encode timestamps; `< token_beg` encode text. + pub fn token_beg(&self) -> i32 { + // SAFETY: ctx pointer invariant. 
+ unsafe { sys::whisper_token_beg(self.ptr.as_ptr()) } + } + + /// Decode a single token id back to its surface form. Useful + /// for token-level diagnostics. Returns `None` when: + /// + /// * `token` is outside `[0, n_vocab)` — would otherwise + /// throw `std::out_of_range` from + /// `id_to_token.at(token)` across the C ABI (UB) per + /// `whisper.cpp:4201`. Pre-checking the bound here keeps + /// the unwound exception from crossing `extern "C"`. + /// * the underlying `c_str` is NULL or non-UTF-8 (model + /// corruption). + /// + /// The returned slice borrows from a `std::string` owned by + /// the context's vocab table; it stays valid for as long as + /// `self` is alive. (Unlike [`system_info`], this does NOT + /// alias mutable C++ state — `id_to_token` is built once at + /// load time and never modified.) + pub fn token_to_str(&self, token: i32) -> Option<&str> { + // Validate before the FFI call — the upstream `at` throw + // would cross `extern "C"` and is UB. + let n = self.n_vocab(); + if token < 0 || token >= n { + return None; + } + // SAFETY: token bound checked above; ctx pointer invariant. + let raw = unsafe { sys::whisper_token_to_str(self.ptr.as_ptr(), token) }; + if raw.is_null() { + return None; + } + // SAFETY: NUL-terminated; lives as long as Context. + let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() }; + core::str::from_utf8(bytes).ok() + } +} + +/// System-info string assembled by libwhisper — backend caps +/// (BLAS / Metal / CUDA / OpenMP), CPU SIMD flags whisper.cpp +/// detected, and the build id. Useful at startup-time logging +/// to confirm which backend the runtime linked against. +/// +/// Returns `None` if the C++ accessor handed back a NULL pointer +/// or non-UTF-8 bytes (corrupt build). +/// +/// # Soundness notes +/// +/// `whisper_print_system_info` re-builds a file-scope +/// `static std::string s` on every invocation +/// (`s = ""; s += "..."; return s.c_str;`). 
Two unsoundness +/// problems follow that we paper over here: +/// +/// 1. The `c_str` returned to a previous caller becomes +/// dangling on the next call — so we can't return +/// `&'static str`. We copy into an owned [`SmolStr`](smol_str::SmolStr). +/// 2. Two concurrent callers race on the static buffer (no +/// upstream lock). We serialise behind a Rust-side +/// [`OnceLock`](std::sync::OnceLock) AND a mutex so the +/// underlying C call runs AT MOST ONCE per process, +/// eliminating both the race AND the redundant work (the +/// system info doesn't change after libwhisper loads). +/// +/// Both hazards are documented against whisper.cpp v1.8.4 at +/// `src/whisper.cpp:4315`. +pub fn system_info() -> Option { + use std::sync::{Mutex, OnceLock}; + // OnceLock holds the cached result; the inner Mutex + // serialises the FIRST call so two threads can't race the + // upstream static buffer. After init, OnceLock returns + // without locking on every call. + static CACHE: OnceLock> = OnceLock::new(); + static INIT_LOCK: Mutex<()> = Mutex::new(()); + if let Some(v) = CACHE.get() { + return v.clone(); + } + // Recover a poisoned mutex (matches `init_lock` and + // `full_lock`). The inner `()` carries no state, so a panic + // in a sibling caller can't have left anything inconsistent. + let _guard = INIT_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + // Re-check inside the lock. + if let Some(v) = CACHE.get() { + return v.clone(); + } + // SAFETY: pure C accessor; the surrounding mutex prevents + // concurrent invocations on the static `std::string s`. + // Routed through the C++ exception-catching shim — the + // upstream `whisper_print_system_info` rebuilds the static + // string via `s = ""; s += "..."; s += std::to_string(…);`, + // any of which can throw `std::bad_alloc` across the C ABI. 
+ // + let raw = unsafe { sys::whispercpp_print_system_info() }; + let result = if raw.is_null() { + None + } else { + // SAFETY: NUL-terminated; copy IMMEDIATELY into an owned + // `SmolStr` so the borrow does not outlive the C call. + let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() }; + core::str::from_utf8(bytes).ok().map(smol_str::SmolStr::new) + }; + // Best-effort set; if a racing thread won the OnceLock between + // our `get` checks (impossible under the mutex but defensive), + // we just use whichever value got cached first. + let _ = CACHE.set(result.clone()); + result +} + +impl Drop for Context { + fn drop(&mut self) { + // SAFETY: ptr is non-null (`NonNull`) and, on the + // model-load constructor path, was returned by the + // `whispercpp_init_from_file_no_state` shim; whisper_free + // is the matching deallocator. Called exactly once per + // Context. Tests that build a Context around a dangling + // pointer must `mem::forget` it so this never runs. + unsafe { + sys::whisper_free(self.ptr.as_ptr()); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Invariant: a fresh Context starts + /// non-poisoned. Pin the initial state so a future refactor + /// of the `lost: AtomicBool` initialiser cannot quietly + /// flip the contract. + #[test] + fn fresh_context_marker_starts_unpoisoned() { + // We can't construct a real `Context` without a model + // file, so build the struct directly. SAFETY (test-only): + // the dangling pointer never crosses the FFI; we only + // exercise the lost-flag accessors. + let dangling = NonNull::::dangling(); + let ctx = Context { + ptr: dangling, + lost: AtomicBool::new(false), + full_lock: Mutex::new(()), + }; + assert!(!ctx.is_poisoned()); + ctx.mark_lost(); + assert!(ctx.is_poisoned()); + // Skip the real Drop — `whisper_free` would dereference + // the dangling pointer. + core::mem::forget(ctx); + } + + /// `mark_lost` is idempotent — extra calls are cheap + /// atomic stores, never reset the flag.
+ #[test] + fn mark_lost_is_idempotent_and_monotonic() { + let dangling = NonNull::::dangling(); + let ctx = Context { + ptr: dangling, + lost: AtomicBool::new(false), + full_lock: Mutex::new(()), + }; + ctx.mark_lost(); + ctx.mark_lost(); + ctx.mark_lost(); + assert!(ctx.is_poisoned(), "stays true across repeated marks"); + core::mem::forget(ctx); + } + + /// `mark_lost` is observable from any + /// thread that holds an `Arc<Context>` (it's the path + /// `State::full` uses to consult sibling poisoning before + /// entering FFI). Check cross-thread visibility: a + /// background thread flips the flag, the main thread + /// observes it after `join` (the join itself synchronises, + /// so this pins visibility rather than exercising a race). + #[test] + fn mark_lost_visible_across_threads() { + let dangling = NonNull::::dangling(); + let ctx = Arc::new(Context { + ptr: dangling, + lost: AtomicBool::new(false), + full_lock: Mutex::new(()), + }); + let ctx_b = Arc::clone(&ctx); + let handle = std::thread::spawn(move || { + ctx_b.mark_lost(); + }); + handle.join().unwrap(); + assert!( + ctx.is_poisoned(), + "the post-join Acquire load must see the spawn-side Release store" + ); + // Skip Drop on the Context — the dangling pointer must not + // reach `whisper_free`. `ctx_b` was moved into the spawned + // closure and dropped when it returned, so after `join` + // `ctx` is the only strong reference: `try_unwrap` + // succeeds and a single `forget` of the inner Context + // is enough. + core::mem::forget(Arc::try_unwrap(ctx).ok().unwrap()); + } + + /// `full_lock` survives the documented + /// concurrent-worker pattern. Four threads contend on the + /// same lock, all eventually finish, none panics. The + /// guard's lifetime constrains the lock window so the + /// per-Context leak cap is structural.
+ #[test] + fn full_lock_serialises_concurrent_holders() { + let dangling = NonNull::::dangling(); + let ctx = Arc::new(Context { + ptr: dangling, + lost: AtomicBool::new(false), + full_lock: Mutex::new(()), + }); + let counter = Arc::new(std::sync::atomic::AtomicU32::new(0)); + + let mut handles = Vec::new(); + for _ in 0..4 { + let ctx_t = Arc::clone(&ctx); + let counter_t = Arc::clone(&counter); + handles.push(std::thread::spawn(move || { + let _g = ctx_t.full_lock(); + // Inside the critical section: increment, sleep + // briefly to provoke contention, decrement-confirm. + let pre = counter_t.fetch_add(1, Ordering::SeqCst); + assert_eq!(pre, 0, "another holder slipped past the mutex"); + std::thread::sleep(std::time::Duration::from_millis(2)); + let post = counter_t.fetch_sub(1, Ordering::SeqCst); + assert_eq!(post, 1, "another holder is concurrent with us"); + })); + } + for h in handles { + h.join().unwrap(); + } + core::mem::forget(Arc::try_unwrap(ctx).ok().unwrap()); + } +} diff --git a/whispercpp/src/error.rs b/whispercpp/src/error.rs new file mode 100644 index 0000000..1fb1198 --- /dev/null +++ b/whispercpp/src/error.rs @@ -0,0 +1,243 @@ +//! Crate-level error type. + +use smol_str::SmolStr; +use thiserror::Error; + +/// Result alias used throughout the crate's safe API. +pub type WhisperResult = Result; + +/// Failure modes from the whisper.cpp FFI surface. +/// +/// The variants are deliberately coarse — whisper.cpp itself +/// reports outcomes via integer return codes that don't carry +/// detailed semantics. We attach context strings where the C API +/// gives us nothing structured to propagate. +/// +/// Diagnostic strings (paths, language hints, the truncated +/// interior-NUL slice) ride in `SmolStr` rather than `String`: +/// they are typically ≤ 23 bytes and inline, so the unhappy-path +/// allocator hit goes from "1 heap allocation per error" to "0". 
+#[derive(Debug, Error)] +pub enum WhisperError { + /// `whisper_init_from_file_with_params` returned `NULL` + /// **without** unwinding a C++ exception. The model path was + /// wrong, the file is corrupt, or the requested backend + /// (Metal / CoreML / CUDA) refused to initialise. + /// + /// **Retryable.** No partial native allocation leaked — + /// upstream's bool-failure paths in `whisper_init_*` all run + /// `whisper_free_state(state); return nullptr;` before + /// returning. Callers may try a different path / backend. + /// + /// The exception-caught counterpart is + /// [`ConstructorLost`](Self::ConstructorLost). + #[error("failed to load model from {path}: {reason}")] + ContextLoad { + /// Path the caller passed in. Stored so logs can pinpoint + /// which model file failed. + path: SmolStr, + /// Any extra context whisper.cpp surfaced (often empty — + /// the C API just returns NULL). + reason: SmolStr, + }, + + /// `whisper_init_state` returned `NULL` **without** unwinding + /// a C++ exception. Usually an OOM on the compute buffers + /// reported via the bool-returning failure path (encode + /// allocates the largest one). + /// + /// **Retryable.** Upstream cleans up partials on this path + /// (every `if (!whisper_kv_cache_init(...))` branch calls + /// `whisper_free_state(state); return nullptr;`). + /// + /// The exception-caught counterpart is + /// [`ConstructorLost`](Self::ConstructorLost). + #[error("failed to allocate whisper state")] + StateInit, + + /// `Context::create_state` was called on a `Context` whose + /// previous [`State::full`](crate::State::full) returned + /// [`StateLost`](Self::StateLost). The Context is poisoned; + /// further state allocation would compound the per-Context + /// leak budget. + /// + /// **Recovery contract.** Drop this `Context` and + /// construct a fresh one (model reload — slow but bounded). 
+ /// Re-using the same `Arc` against this error + /// without dropping it leaves the per-State leak (~360 MB + /// on `large-v3-turbo`) in place forever; only `Drop` of + /// the Context releases what's still freeable. + /// + /// **Why this exists.** `StateLost` cannot reliably free + /// the State's native allocations (we cannot distinguish + /// "intact" from "mid-rebuild" without upstream RAII). A + /// retry loop creating fresh States on the same Context + /// would leak per attempt; this variant caps the budget at + /// one. The fix is structural — the cap survives careless + /// callers — rather than purely documentary. + #[error("Context was poisoned by a prior StateLost; drop and reconstruct to recover")] + ContextPoisoned, + + /// **The native init path threw a C++ exception that our + /// shim caught.** Either + /// [`whispercpp_init_from_file_no_state`](crate::sys::whispercpp_init_from_file_no_state) + /// or + /// [`whispercpp_init_state`](crate::sys::whispercpp_init_state) + /// returned `nullptr` AFTER catching `std::bad_alloc` / + /// `std::system_error` / `std::exception` / unknown. + /// + /// **Not retryable.** Upstream `whisper_init_state` / + /// `whisper_init_from_file_with_params_no_state` allocate + /// raw `whisper_state` / `whisper_context` objects and + /// per-backend buffers BEFORE doing the throwing + /// model/backend/KV-cache work. Those locals are not + /// RAII-owned: a caught throw partway through leaks the + /// partial allocation (state struct + every backend / cache + /// already initialised — typically tens to hundreds of MB + /// per attempt). + /// + /// **Recovery contract.** Surface this as a fatal worker / + /// process error. Do not auto-recreate the [`Context`](crate::Context) / + /// [`State`](crate::State) inside a retry loop on the same process — + /// each attempt under the same memory / system pressure + /// leaks again. 
The recommended response is to escalate to + /// the supervisor and let the worker process recycle. A + /// future round of upstream patches that wraps the init + /// paths in RAII would let us downgrade some of these to + /// `ContextLoad` / `StateInit`. + #[error( + "whisper.cpp init threw {origin} (sentinel {code}); native partial allocation leaked, \ + not retryable" + )] + ConstructorLost { + /// Which constructor caught the exception. `"context"` for + /// the model-load path, `"state"` for the per-call state + /// allocation. + origin: &'static str, + /// The shim sentinel set inside the catch block (one of + /// `WHISPERCPP_ERR_BAD_ALLOC`, `_SYSTEM_ERROR`, + /// `_STD_EXCEPTION`, `_UNKNOWN_EXCEPTION`). + code: i32, + }, + + /// `whisper_full_with_state` returned a non-zero, **non-fatal** + /// code. The state remains intact and may be reused for a + /// fresh call. + /// + /// Examples: `-1` (whisper_pcm_to_mel), `-2` + /// (whisper_set_mel), `-3..-6` (intermediate encode/decode + /// failures whisper.cpp marks as recoverable). `-7` is **not** + /// here — it surfaces as [`StateLost`](Self::StateLost). + #[error("whisper_full failed with code {code}")] + Full { + /// The whisper.cpp return code. See `whisper.h` for the + /// (sparse) documented values. + code: i32, + }, + + /// **The native `whisper_state` is gone.** Either whisper.cpp + /// freed it from underneath us (`-7`, multi-decoder KV-cache + /// allocation failure: upstream calls `whisper_free_state` + /// before returning), or our exception shim caught a C++ + /// throw partway through `whisper_full_with_state` (sentinels + /// ≤ `-100`). + /// + /// **Not retryable on this `State`.** The Rust `State` has + /// been poisoned: every accessor short-circuits to a safe + /// zero/None. The native allocation IS released — either + /// by upstream's own `whisper_free_state` on the `-7` path, + /// or explicitly from `State::full`'s sentinel handler on + /// caught-exception paths. 
The fork's idempotent + /// `whisper_kv_cache_free` patch closed the double-free + /// hazard that previously forced us to leak the state. + /// + /// **Recovery contract.** Receivers should still treat + /// this as fatal at the worker level: + /// + /// 1. Do **not** call [`State::full`](crate::State::full) + /// again on this `State`. Drop it. + /// 2. The parent [`Context`](crate::Context) is poisoned + /// too — [`Context::create_state`](crate::Context::create_state) + /// will return [`ContextPoisoned`](Self::ContextPoisoned) + /// until the Context is dropped and reconstructed. This + /// is defensive: the underlying pressure that caused the + /// throw (OOM, thread-table exhaustion, fatal backend + /// error) is likely still present, so retries on the + /// same Context would just re-fail. + /// 3. The recommended response is to surface the error to + /// your supervisor, drop the Context, and reload the + /// model in a fresh Context once pressure has resolved. + #[error("whisper_full lost the native state (code {code}); state freed, Context poisoned")] + StateLost { + /// The whisper.cpp return code or shim sentinel that + /// triggered poisoning. `-7` = upstream KV-cache failure; + /// `≤ -100` = exception sentinel from + /// `whispercpp_full_with_state` (see + /// `whispercpp_shim.h::WHISPERCPP_ERR_*`). + code: i32, + }, + + /// A path passed to the safe API contained an interior NUL + /// byte. The whisper.cpp C API requires NUL-terminated strings. + #[error("argument contained an interior NUL byte: {0}")] + InvalidCString(SmolStr), + + /// UTF-8 decode failure on a string returned from whisper.cpp + /// (segment text or token text). The model vocabulary should + /// always emit valid UTF-8; this would indicate a corrupt model + /// file. + #[error("whisper.cpp returned non-UTF-8 text: {0}")] + Utf8(#[from] core::str::Utf8Error), + + /// Audio buffer length exceeded `i32::MAX` samples. whisper.cpp's + /// C API takes the count as `int`. 
At 16 kHz this caps at + /// ~37 hours per call — well above any realistic chunk — so this + /// surfaces only when callers misuse the API (bytes-vs-samples + /// confusion, accidental double-pad, etc.). + #[error("audio buffer too large: {samples} samples > i32::MAX")] + SamplesOverflow { + /// The provided buffer length, for diagnostics. + samples: usize, + }, + + /// Audio buffer was too short for whisper.cpp's mel + /// spectrogram preprocessor. + /// + /// `log_mel_spectrogram` performs a reflective pad at the + /// start of the buffer: + /// `std::reverse_copy(samples + 1, samples + 1 + 200, …)`, + /// so it reads `samples[1..201]`. Inputs shorter than 201 + /// samples (≈ 12.5 ms at 16 kHz) trigger an out-of-bounds + /// read in the C++ kernel before whisper.cpp's later + /// short-input check fires. The safe wrapper rejects them + /// up-front instead of forwarding the UB across the FFI. + /// + /// Callers feeding sub-201-sample buffers should pad with + /// silence (zeros) up to at least 201 samples, or batch the + /// audio into longer windows upstream. + #[error( + "audio buffer too short: {samples} samples < {min_required} (reflective-pad lower bound)" + )] + SamplesTooShort { + /// The provided buffer length, for diagnostics. + samples: usize, + /// Minimum samples whisper.cpp's mel preprocessor requires. + min_required: usize, + }, + + /// A token id fell outside the model's vocabulary + /// `[0, vocab_size)`. The C API + /// (`whisper_token_to_str`) uses `id_to_token.at(token)` + /// which throws `std::out_of_range`; a C++ exception across + /// `extern "C"` is undefined behaviour, so the safe wrapper + /// validates the bound before the FFI call. + /// + /// NOTE(review): + /// [`Context::token_to_str`](crate::Context::token_to_str) + /// returns `Option` and reports an out-of-range id as + /// `None`, not as this variant — confirm which API + /// constructs `TokenOutOfRange`, or remove the variant. + #[error("token id {token} out of range [0, {vocab_size})")] + TokenOutOfRange { + /// The id the caller passed in. + token: i32, + /// The model's vocab size at validation time.
+ vocab_size: i32, + }, +} diff --git a/whispercpp/src/lang.rs b/whispercpp/src/lang.rs new file mode 100644 index 0000000..4588da5 --- /dev/null +++ b/whispercpp/src/lang.rs @@ -0,0 +1,829 @@ +//! `Lang` — typed enum over whisper.cpp's supported languages, with +//! an `Other(SmolStr)` escape hatch for unknown ISO codes. + +use smol_str::SmolStr; + +/// Language code. Marked `#[non_exhaustive]` so new variants can be +/// added when whisper.cpp adds languages without forcing a +/// semver-major bump; carries an `Other(SmolStr)` variant so unknown +/// ISO codes flowing in from whisper's auto-detect don't fail an +/// indexing run. +/// +/// **Canonicalisation invariant.** [`Lang::from_iso639_1`] maps known +/// codes to named variants and never produces `Other` for an +/// enum-known code. This keeps structural `PartialEq`/`Hash` correct: +/// `Lang::En != Lang::Other("en")` is fine because no API path +/// constructs `Lang::Other("en")`. +/// +/// **Serde wire format.** Lowercase ISO-639-1 strings: `"en"`, +/// `"yue"`, etc. (a previous `derive(Serialize, +/// Deserialize)` produced Rust variant names like `"En"` and +/// `{"Other":"xx"}`, which contradicted documented config shapes +/// and made human-edited configs brittle. The custom impls +/// below canonicalise through [`Lang::from_iso639_1`] / +/// [`Lang::as_str`] so the in-memory representation stays as-is +/// while the wire format matches the docs.) 
+#[non_exhaustive] +#[allow(missing_docs)] // variants are ISO 639-1 codes; self-documenting by name +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum Lang { + En, + Zh, + De, + Es, + Ru, + Ko, + Fr, + Ja, + Pt, + Tr, + Pl, + Ca, + Nl, + Ar, + Sv, + It, + Id, + Hi, + Fi, + Vi, + He, + Uk, + El, + Ms, + Cs, + Ro, + Da, + Hu, + Ta, + No, + Th, + Ur, + Hr, + Bg, + Lt, + La, + Mi, + Ml, + Cy, + Sk, + Te, + Fa, + Lv, + Bn, + Sr, + Az, + Sl, + Kn, + Et, + Mk, + Br, + Eu, + Is, + Hy, + Ne, + Mn, + Bs, + Kk, + Sq, + Sw, + Gl, + Mr, + Pa, + Si, + Km, + Sn, + Yo, + So, + Af, + Oc, + Ka, + Be, + Tg, + Sd, + Gu, + Am, + Yi, + Lo, + Uz, + Fo, + Ht, + Ps, + Tk, + Nn, + Mt, + Sa, + Lb, + My, + Bo, + Tl, + Mg, + As, + Tt, + Haw, + Ln, + Ha, + Ba, + Jw, + Su, + Yue, + /// ISO 639-1 (or whisper-supplied) code that did not match any + /// known variant. `from_iso639_1` and `as_str` round-trip + /// through this for unknown codes; the indexer can log the + /// SmolStr value and continue. + Other(SmolStr), +} + +impl Lang { + /// Stable round-trip with [`Lang::from_iso639_1`]. Named variants + /// emit their canonical lowercase ISO code; `Other(s)` emits `s`. 
+ #[inline] + pub fn as_str(&self) -> &str { + match self { + Self::En => "en", + Self::Zh => "zh", + Self::De => "de", + Self::Es => "es", + Self::Ru => "ru", + Self::Ko => "ko", + Self::Fr => "fr", + Self::Ja => "ja", + Self::Pt => "pt", + Self::Tr => "tr", + Self::Pl => "pl", + Self::Ca => "ca", + Self::Nl => "nl", + Self::Ar => "ar", + Self::Sv => "sv", + Self::It => "it", + Self::Id => "id", + Self::Hi => "hi", + Self::Fi => "fi", + Self::Vi => "vi", + Self::He => "he", + Self::Uk => "uk", + Self::El => "el", + Self::Ms => "ms", + Self::Cs => "cs", + Self::Ro => "ro", + Self::Da => "da", + Self::Hu => "hu", + Self::Ta => "ta", + Self::No => "no", + Self::Th => "th", + Self::Ur => "ur", + Self::Hr => "hr", + Self::Bg => "bg", + Self::Lt => "lt", + Self::La => "la", + Self::Mi => "mi", + Self::Ml => "ml", + Self::Cy => "cy", + Self::Sk => "sk", + Self::Te => "te", + Self::Fa => "fa", + Self::Lv => "lv", + Self::Bn => "bn", + Self::Sr => "sr", + Self::Az => "az", + Self::Sl => "sl", + Self::Kn => "kn", + Self::Et => "et", + Self::Mk => "mk", + Self::Br => "br", + Self::Eu => "eu", + Self::Is => "is", + Self::Hy => "hy", + Self::Ne => "ne", + Self::Mn => "mn", + Self::Bs => "bs", + Self::Kk => "kk", + Self::Sq => "sq", + Self::Sw => "sw", + Self::Gl => "gl", + Self::Mr => "mr", + Self::Pa => "pa", + Self::Si => "si", + Self::Km => "km", + Self::Sn => "sn", + Self::Yo => "yo", + Self::So => "so", + Self::Af => "af", + Self::Oc => "oc", + Self::Ka => "ka", + Self::Be => "be", + Self::Tg => "tg", + Self::Sd => "sd", + Self::Gu => "gu", + Self::Am => "am", + Self::Yi => "yi", + Self::Lo => "lo", + Self::Uz => "uz", + Self::Fo => "fo", + Self::Ht => "ht", + Self::Ps => "ps", + Self::Tk => "tk", + Self::Nn => "nn", + Self::Mt => "mt", + Self::Sa => "sa", + Self::Lb => "lb", + Self::My => "my", + Self::Bo => "bo", + Self::Tl => "tl", + Self::Mg => "mg", + Self::As => "as", + Self::Tt => "tt", + Self::Haw => "haw", + Self::Ln => "ln", + Self::Ha => "ha", + Self::Ba => 
"ba", + Self::Jw => "jw", + Self::Su => "su", + Self::Yue => "yue", + Self::Other(s) => s.as_str(), + } + } +} + +impl Lang { + /// Total-function constructor: every `&str` produces a `Lang`. + /// Known whisper.cpp codes canonicalise to their named variant; + /// unknown codes go to `Lang::Other`. Never produces + /// `Lang::Other("en")` for an enum-known code "en" — see the + /// canonicalisation invariant on the type doc. + pub fn from_iso639_1(s: &str) -> Self { + match s { + "en" | "En" | "eN" | "EN" => Self::En, + "zh" | "Zh" | "zH" | "ZH" => Self::Zh, + "de" | "De" | "dE" | "DE" => Self::De, + "es" | "Es" | "eS" | "ES" => Self::Es, + "ru" | "Ru" | "rU" | "RU" => Self::Ru, + "ko" | "Ko" | "kO" | "KO" => Self::Ko, + "fr" | "Fr" | "fR" | "FR" => Self::Fr, + "ja" | "Ja" | "jA" | "JA" => Self::Ja, + "pt" | "Pt" | "pT" | "PT" => Self::Pt, + "tr" | "Tr" | "tR" | "TR" => Self::Tr, + "pl" | "Pl" | "pL" | "PL" => Self::Pl, + "ca" | "Ca" | "cA" | "CA" => Self::Ca, + "nl" | "Nl" | "nL" | "NL" => Self::Nl, + "ar" | "Ar" | "aR" | "AR" => Self::Ar, + "sv" | "Sv" | "sV" | "SV" => Self::Sv, + "it" | "It" | "iT" | "IT" => Self::It, + "id" | "Id" | "iD" | "ID" => Self::Id, + "hi" | "Hi" | "hI" | "HI" => Self::Hi, + "fi" | "Fi" | "fI" | "FI" => Self::Fi, + "vi" | "Vi" | "vI" | "VI" => Self::Vi, + "he" | "He" | "hE" | "HE" => Self::He, + "uk" | "Uk" | "uK" | "UK" => Self::Uk, + "el" | "El" | "eL" | "EL" => Self::El, + "ms" | "Ms" | "mS" | "MS" => Self::Ms, + "cs" | "Cs" | "cS" | "CS" => Self::Cs, + "ro" | "Ro" | "rO" | "RO" => Self::Ro, + "da" | "Da" | "dA" | "DA" => Self::Da, + "hu" | "Hu" | "hU" | "HU" => Self::Hu, + "ta" | "Ta" | "tA" | "TA" => Self::Ta, + "no" | "No" | "nO" | "NO" => Self::No, + "th" | "Th" | "tH" | "TH" => Self::Th, + "ur" | "Ur" | "uR" | "UR" => Self::Ur, + "hr" | "Hr" | "hR" | "HR" => Self::Hr, + "bg" | "Bg" | "bG" | "BG" => Self::Bg, + "lt" | "Lt" | "lT" | "LT" => Self::Lt, + "la" | "La" | "lA" | "LA" => Self::La, + "mi" | "Mi" | "mI" | "MI" => Self::Mi, 
+ "ml" | "Ml" | "mL" | "ML" => Self::Ml, + "cy" | "Cy" | "cY" | "CY" => Self::Cy, + "sk" | "Sk" | "sK" | "SK" => Self::Sk, + "te" | "Te" | "tE" | "TE" => Self::Te, + "fa" | "Fa" | "fA" | "FA" => Self::Fa, + "lv" | "Lv" | "lV" | "LV" => Self::Lv, + "bn" | "Bn" | "bN" | "BN" => Self::Bn, + "sr" | "Sr" | "sR" | "SR" => Self::Sr, + "az" | "Az" | "aZ" | "AZ" => Self::Az, + "sl" | "Sl" | "sL" | "SL" => Self::Sl, + "kn" | "Kn" | "kN" | "KN" => Self::Kn, + "et" | "Et" | "eT" | "ET" => Self::Et, + "mk" | "Mk" | "mK" | "MK" => Self::Mk, + "br" | "Br" | "bR" | "BR" => Self::Br, + "eu" | "Eu" | "eU" | "EU" => Self::Eu, + "is" | "Is" | "iS" | "IS" => Self::Is, + "hy" | "Hy" | "hY" | "HY" => Self::Hy, + "ne" | "Ne" | "nE" | "NE" => Self::Ne, + "mn" | "Mn" | "mN" | "MN" => Self::Mn, + "bs" | "Bs" | "bS" | "BS" => Self::Bs, + "kk" | "Kk" | "kK" | "KK" => Self::Kk, + "sq" | "Sq" | "sQ" | "SQ" => Self::Sq, + "sw" | "Sw" | "sW" | "SW" => Self::Sw, + "gl" | "Gl" | "gL" | "GL" => Self::Gl, + "mr" | "Mr" | "mR" | "MR" => Self::Mr, + "pa" | "Pa" | "pA" | "PA" => Self::Pa, + "si" | "Si" | "sI" | "SI" => Self::Si, + "km" | "Km" | "kM" | "KM" => Self::Km, + "sn" | "Sn" | "sN" | "SN" => Self::Sn, + "yo" | "Yo" | "yO" | "YO" => Self::Yo, + "so" | "So" | "sO" | "SO" => Self::So, + "af" | "Af" | "aF" | "AF" => Self::Af, + "oc" | "Oc" | "oC" | "OC" => Self::Oc, + "ka" | "Ka" | "kA" | "KA" => Self::Ka, + "be" | "Be" | "bE" | "BE" => Self::Be, + "tg" | "Tg" | "tG" | "TG" => Self::Tg, + "sd" | "Sd" | "sD" | "SD" => Self::Sd, + "gu" | "Gu" | "gU" | "GU" => Self::Gu, + "am" | "Am" | "aM" | "AM" => Self::Am, + "yi" | "Yi" | "yI" | "YI" => Self::Yi, + "lo" | "Lo" | "lO" | "LO" => Self::Lo, + "uz" | "Uz" | "uZ" | "UZ" => Self::Uz, + "fo" | "Fo" | "fO" | "FO" => Self::Fo, + "ht" | "Ht" | "hT" | "HT" => Self::Ht, + "ps" | "Ps" | "pS" | "PS" => Self::Ps, + "tk" | "Tk" | "tK" | "TK" => Self::Tk, + "nn" | "Nn" | "nN" | "NN" => Self::Nn, + "mt" | "Mt" | "mT" | "MT" => Self::Mt, + "sa" | "Sa" | "sA" | "SA" => 
Self::Sa, + "lb" | "Lb" | "lB" | "LB" => Self::Lb, + "my" | "My" | "mY" | "MY" => Self::My, + "bo" | "Bo" | "bO" | "BO" => Self::Bo, + "tl" | "Tl" | "tL" | "TL" => Self::Tl, + "mg" | "Mg" | "mG" | "MG" => Self::Mg, + "as" | "As" | "aS" | "AS" => Self::As, + "tt" | "Tt" | "tT" | "TT" => Self::Tt, + "haw" | "Haw" | "hAW" | "HAW" => Self::Haw, + "ln" | "Ln" | "lN" | "LN" => Self::Ln, + "ha" | "Ha" | "hA" | "HA" => Self::Ha, + "ba" | "Ba" | "bA" | "BA" => Self::Ba, + "jw" | "Jw" | "jW" | "JW" => Self::Jw, + "su" | "Su" | "sU" | "SU" => Self::Su, + "yue" | "Yue" | "yUE" | "YUE" => Self::Yue, + other => Self::Other(SmolStr::new(other)), + } + } + + /// Fallible lookup: returns `Some(variant)` only when `s` is one + /// of the spellings listed in the arms below, and `None` for + /// every other input. Unlike `from_iso639_1`, it never builds + /// `Lang::Other`. NOTE(review): this arm table duplicates the one + /// in `from_iso639_1`; the two must be kept in sync by hand. + pub fn try_from_iso639_1(s: &str) -> Option { + Some(match s { + "en" | "En" | "eN" | "EN" => Self::En, + "zh" | "Zh" | "zH" | "ZH" => Self::Zh, + "de" | "De" | "dE" | "DE" => Self::De, + "es" | "Es" | "eS" | "ES" => Self::Es, + "ru" | "Ru" | "rU" | "RU" => Self::Ru, + "ko" | "Ko" | "kO" | "KO" => Self::Ko, + "fr" | "Fr" | "fR" | "FR" => Self::Fr, + "ja" | "Ja" | "jA" | "JA" => Self::Ja, + "pt" | "Pt" | "pT" | "PT" => Self::Pt, + "tr" | "Tr" | "tR" | "TR" => Self::Tr, + "pl" | "Pl" | "pL" | "PL" => Self::Pl, + "ca" | "Ca" | "cA" | "CA" => Self::Ca, + "nl" | "Nl" | "nL" | "NL" => Self::Nl, + "ar" | "Ar" | "aR" | "AR" => Self::Ar, + "sv" | "Sv" | "sV" | "SV" => Self::Sv, + "it" | "It" | "iT" | "IT" => Self::It, + "id" | "Id" | "iD" | "ID" => Self::Id, + "hi" | "Hi" | "hI" | "HI" => Self::Hi, + "fi" | "Fi" | "fI" | "FI" => Self::Fi, + "vi" | "Vi" | "vI" | "VI" => Self::Vi, + "he" | "He" | "hE" | "HE" => Self::He, + "uk" | "Uk" | "uK" | "UK" => Self::Uk, + "el" | "El" | "eL" | "EL" => Self::El, + "ms" | "Ms" | "mS" | "MS" =>
Self::Ms, + "cs" | "Cs" | "cS" | "CS" => Self::Cs, + "ro" | "Ro" | "rO" | "RO" => Self::Ro, + "da" | "Da" | "dA" | "DA" => Self::Da, + "hu" | "Hu" | "hU" | "HU" => Self::Hu, + "ta" | "Ta" | "tA" | "TA" => Self::Ta, + "no" | "No" | "nO" | "NO" => Self::No, + "th" | "Th" | "tH" | "TH" => Self::Th, + "ur" | "Ur" | "uR" | "UR" => Self::Ur, + "hr" | "Hr" | "hR" | "HR" => Self::Hr, + "bg" | "Bg" | "bG" | "BG" => Self::Bg, + "lt" | "Lt" | "lT" | "LT" => Self::Lt, + "la" | "La" | "lA" | "LA" => Self::La, + "mi" | "Mi" | "mI" | "MI" => Self::Mi, + "ml" | "Ml" | "mL" | "ML" => Self::Ml, + "cy" | "Cy" | "cY" | "CY" => Self::Cy, + "sk" | "Sk" | "sK" | "SK" => Self::Sk, + "te" | "Te" | "tE" | "TE" => Self::Te, + "fa" | "Fa" | "fA" | "FA" => Self::Fa, + "lv" | "Lv" | "lV" | "LV" => Self::Lv, + "bn" | "Bn" | "bN" | "BN" => Self::Bn, + "sr" | "Sr" | "sR" | "SR" => Self::Sr, + "az" | "Az" | "aZ" | "AZ" => Self::Az, + "sl" | "Sl" | "sL" | "SL" => Self::Sl, + "kn" | "Kn" | "kN" | "KN" => Self::Kn, + "et" | "Et" | "eT" | "ET" => Self::Et, + "mk" | "Mk" | "mK" | "MK" => Self::Mk, + "br" | "Br" | "bR" | "BR" => Self::Br, + "eu" | "Eu" | "eU" | "EU" => Self::Eu, + "is" | "Is" | "iS" | "IS" => Self::Is, + "hy" | "Hy" | "hY" | "HY" => Self::Hy, + "ne" | "Ne" | "nE" | "NE" => Self::Ne, + "mn" | "Mn" | "mN" | "MN" => Self::Mn, + "bs" | "Bs" | "bS" | "BS" => Self::Bs, + "kk" | "Kk" | "kK" | "KK" => Self::Kk, + "sq" | "Sq" | "sQ" | "SQ" => Self::Sq, + "sw" | "Sw" | "sW" | "SW" => Self::Sw, + "gl" | "Gl" | "gL" | "GL" => Self::Gl, + "mr" | "Mr" | "mR" | "MR" => Self::Mr, + "pa" | "Pa" | "pA" | "PA" => Self::Pa, + "si" | "Si" | "sI" | "SI" => Self::Si, + "km" | "Km" | "kM" | "KM" => Self::Km, + "sn" | "Sn" | "sN" | "SN" => Self::Sn, + "yo" | "Yo" | "yO" | "YO" => Self::Yo, + "so" | "So" | "sO" | "SO" => Self::So, + "af" | "Af" | "aF" | "AF" => Self::Af, + "oc" | "Oc" | "oC" | "OC" => Self::Oc, + "ka" | "Ka" | "kA" | "KA" => Self::Ka, + "be" | "Be" | "bE" | "BE" => Self::Be, + "tg" | "Tg" | "tG" 
| "TG" => Self::Tg, + "sd" | "Sd" | "sD" | "SD" => Self::Sd, + "gu" | "Gu" | "gU" | "GU" => Self::Gu, + "am" | "Am" | "aM" | "AM" => Self::Am, + "yi" | "Yi" | "yI" | "YI" => Self::Yi, + "lo" | "Lo" | "lO" | "LO" => Self::Lo, + "uz" | "Uz" | "uZ" | "UZ" => Self::Uz, + "fo" | "Fo" | "fO" | "FO" => Self::Fo, + "ht" | "Ht" | "hT" | "HT" => Self::Ht, + "ps" | "Ps" | "pS" | "PS" => Self::Ps, + "tk" | "Tk" | "tK" | "TK" => Self::Tk, + "nn" | "Nn" | "nN" | "NN" => Self::Nn, + "mt" | "Mt" | "mT" | "MT" => Self::Mt, + "sa" | "Sa" | "sA" | "SA" => Self::Sa, + "lb" | "Lb" | "lB" | "LB" => Self::Lb, + "my" | "My" | "mY" | "MY" => Self::My, + "bo" | "Bo" | "bO" | "BO" => Self::Bo, + "tl" | "Tl" | "tL" | "TL" => Self::Tl, + "mg" | "Mg" | "mG" | "MG" => Self::Mg, + "as" | "As" | "aS" | "AS" => Self::As, + "tt" | "Tt" | "tT" | "TT" => Self::Tt, + "haw" | "Haw" | "hAW" | "HAW" => Self::Haw, + "ln" | "Ln" | "lN" | "LN" => Self::Ln, + "ha" | "Ha" | "hA" | "HA" => Self::Ha, + "ba" | "Ba" | "bA" | "BA" => Self::Ba, + "jw" | "Jw" | "jW" | "JW" => Self::Jw, + "su" | "Su" | "sU" | "SU" => Self::Su, + "yue" | "Yue" | "yUE" | "YUE" => Self::Yue, + _ => return None, + }) + } +} + +impl core::fmt::Display for Lang { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.as_str()) + } +} + +#[cfg(feature = "serde")] +#[cfg_attr(docsrs, doc(cfg(feature = "serde")))] +const _: () = { + impl serde::Serialize for Lang { + /// Serialize as the lowercase ISO-639-1 (or whisper-supplied) + /// string code. Matches what [`Lang::as_str`] returns — + /// `Lang::En` → `"en"`, `Lang::Other(SmolStr::new("xx"))` → + /// `"xx"`. The previous `derive(Serialize)` produced Rust + /// variant names like `"En"` and `{"Other":"xx"}`, + /// contradicting the config docs. 
+ fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } + } + + impl<'de> serde::Deserialize<'de> for Lang { + /// Deserialize from an ISO-639-1 string code, **case-insensitive**. + /// + /// Accepts any ASCII-letter case (`"en"`, `"EN"`, `"En"`, + /// `"eN"` all canonicalise to `Lang::En`); whisper.cpp's + /// language codes are conventionally lowercase but the ISO + /// standard treats them as case-insensitive, and human-edited + /// configs naturally use mixed case. The accepted alphabet + /// after lowercasing is `[a-z]{1,8}` — matches the + /// alignment-stage validation in `runner/whisper_pool.rs`'s + /// `validate_language_code` so an "EN" config + /// produces a Lang that the FFI layer happily accepts. + /// + /// Routes through [`Lang::from_iso639_1`] *after* lowercasing + /// so input matching a named variant canonicalises to that + /// variant rather than landing in `Other`. Unknown codes pass + /// through `Lang::Other(SmolStr::new(lowered))` — the inner + /// string is always lowercase, preserving the canonicalisation + /// invariant across the serde boundary AND keeping the + /// language-string intern table bounded. + /// + /// Round-trip asymmetry note: `"EN"` deserialises to + /// `Lang::En` which then *serialises* as `"en"`. This is + /// intentional — the on-disk canonical form is lowercase. 
+ fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error as _; + + let s = <&str as serde::Deserialize>::deserialize(deserializer)?; + if s.is_empty() { + return Err(D::Error::custom("Lang code is empty")); + } + if s.len() > 8 { + return Err(D::Error::custom(format!( + "Lang code longer than 8 bytes ({} bytes); whisper.cpp codes are 2-3 ASCII letters", + s.len() + ))); + } + if !s.bytes().all(|b| b.is_ascii_alphabetic()) { + return Err(D::Error::custom( + "Lang code must be ASCII letters [a-zA-Z] only (no digits, dashes, or non-ASCII)", + )); + } + // Avoid the lowercasing allocation when input is already canonical. + if s.bytes().all(|b| b.is_ascii_lowercase()) { + Ok(Lang::from_iso639_1(s)) + } else { + use smol_str::StrExt; + + let lowered = s.to_ascii_lowercase_smolstr(); + Ok(Lang::try_from_iso639_1(&lowered).unwrap_or(Self::Other(lowered))) + } + } + } +}; + +#[cfg(test)] +mod tests { + use super::*; + + /// Every named variant round-trips through `from_iso639_1(as_str)` + /// AND does not match `Lang::Other(_)`. This is the + /// canonicalisation invariant. 
+ #[test] + fn named_variants_canonicalise() { + let known = [ + Lang::En, + Lang::Zh, + Lang::De, + Lang::Es, + Lang::Ru, + Lang::Ko, + Lang::Fr, + Lang::Ja, + Lang::Pt, + Lang::Tr, + Lang::Pl, + Lang::Ca, + Lang::Nl, + Lang::Ar, + Lang::Sv, + Lang::It, + Lang::Id, + Lang::Hi, + Lang::Fi, + Lang::Vi, + Lang::He, + Lang::Uk, + Lang::El, + Lang::Ms, + Lang::Cs, + Lang::Ro, + Lang::Da, + Lang::Hu, + Lang::Ta, + Lang::No, + Lang::Th, + Lang::Ur, + Lang::Hr, + Lang::Bg, + Lang::Lt, + Lang::La, + Lang::Mi, + Lang::Ml, + Lang::Cy, + Lang::Sk, + Lang::Te, + Lang::Fa, + Lang::Lv, + Lang::Bn, + Lang::Sr, + Lang::Az, + Lang::Sl, + Lang::Kn, + Lang::Et, + Lang::Mk, + Lang::Br, + Lang::Eu, + Lang::Is, + Lang::Hy, + Lang::Ne, + Lang::Mn, + Lang::Bs, + Lang::Kk, + Lang::Sq, + Lang::Sw, + Lang::Gl, + Lang::Mr, + Lang::Pa, + Lang::Si, + Lang::Km, + Lang::Sn, + Lang::Yo, + Lang::So, + Lang::Af, + Lang::Oc, + Lang::Ka, + Lang::Be, + Lang::Tg, + Lang::Sd, + Lang::Gu, + Lang::Am, + Lang::Yi, + Lang::Lo, + Lang::Uz, + Lang::Fo, + Lang::Ht, + Lang::Ps, + Lang::Tk, + Lang::Nn, + Lang::Mt, + Lang::Sa, + Lang::Lb, + Lang::My, + Lang::Bo, + Lang::Tl, + Lang::Mg, + Lang::As, + Lang::Tt, + Lang::Haw, + Lang::Ln, + Lang::Ha, + Lang::Ba, + Lang::Jw, + Lang::Su, + Lang::Yue, + ]; + assert_eq!( + known.len(), + 100, + "must keep the 100-variant Appendix C list in sync" + ); + for v in known.iter() { + let round = Lang::from_iso639_1(v.as_str()); + assert_eq!(&round, v, "round-trip failed for {:?}", v); + assert!( + !matches!(round, Lang::Other(_)), + "{:?} canonicalised to Other; this breaks Eq/Hash", + v + ); + } + } + + #[test] + fn unknown_codes_land_in_other() { + let r = Lang::from_iso639_1("zzz"); + assert_eq!(r, Lang::Other(SmolStr::new("zzz"))); + assert_eq!(r.as_str(), "zzz"); + } + + #[test] + fn other_round_trips_via_as_str() { + let r = Lang::Other(SmolStr::new("xx")); + assert_eq!(r.as_str(), "xx"); + assert_eq!(Lang::from_iso639_1(r.as_str()), r); + } + + // --- custom serde wire 
format --- + + #[cfg(feature = "serde")] + #[test] + fn serde_named_variant_serializes_as_lowercase_iso() { + let json = serde_json::to_string(&Lang::En).expect("serialize"); + assert_eq!( + json, "\"en\"", + "Lang::En must serialize as \"en\", not \"En\"" + ); + let json = serde_json::to_string(&Lang::Yue).expect("serialize"); + assert_eq!(json, "\"yue\""); + } + + #[cfg(feature = "serde")] + #[test] + fn serde_other_serializes_as_inner_string() { + let v = Lang::Other(SmolStr::new("xx")); + let json = serde_json::to_string(&v).expect("serialize"); + assert_eq!( + json, "\"xx\"", + "Lang::Other(\"xx\") must serialize as \"xx\"" + ); + } + + #[cfg(feature = "serde")] + #[test] + fn serde_named_variant_round_trips() { + let json = "\"en\""; + let lang: Lang = serde_json::from_str(json).expect("deserialize"); + assert_eq!(lang, Lang::En); + // Re-serialize and verify identical wire form. + assert_eq!(serde_json::to_string(&lang).unwrap(), json); + } + + #[cfg(feature = "serde")] + #[test] + fn serde_unknown_iso_code_round_trips_via_other() { + let json = "\"xx\""; + let lang: Lang = serde_json::from_str(json).expect("deserialize"); + assert_eq!(lang, Lang::Other(SmolStr::new("xx"))); + assert_eq!(serde_json::to_string(&lang).unwrap(), json); + } + + /// Canonicalisation invariant must hold across serde: + /// deserializing a code that matches a named variant lands in + /// the named variant, not in `Other`. + #[cfg(feature = "serde")] + #[test] + fn serde_deserializes_known_codes_to_named_variants() { + let lang: Lang = serde_json::from_str("\"en\"").unwrap(); + assert!(matches!(lang, Lang::En), "must canonicalise to Lang::En"); + let lang: Lang = serde_json::from_str("\"yue\"").unwrap(); + assert!(matches!(lang, Lang::Yue)); + } + + /// Case-insensitive deserialization (UX win — users editing + /// configs naturally use mixed case): `"EN"`, `"En"`, `"eN"`, + /// `"en"` all canonicalise to `Lang::En`. 
The on-disk + /// canonical form is lowercase (so re-serialization always + /// emits `"en"`), but reading is permissive. + #[cfg(feature = "serde")] + #[test] + fn serde_accepts_any_case_for_named_variant() { + for input in ["\"en\"", "\"EN\"", "\"En\"", "\"eN\""] { + let lang: Lang = serde_json::from_str(input).expect(input); + assert_eq!( + lang, + Lang::En, + "input {input} must canonicalise to Lang::En" + ); + // Re-serialisation always emits the lowercase form. + assert_eq!(serde_json::to_string(&lang).unwrap(), "\"en\""); + } + } + + /// Mixed-case unknown codes also canonicalise — `"XX"` + /// deserialises to `Lang::Other(SmolStr::new("xx"))`, + /// preserving the canonicalisation invariant (no + /// `Lang::Other("XX")` ever exists in the type). + #[cfg(feature = "serde")] + #[test] + fn serde_lowercases_unknown_code_into_other() { + let lang: Lang = serde_json::from_str("\"XX\"").expect("deserialize"); + assert_eq!(lang, Lang::Other(SmolStr::new("xx"))); + let lang: Lang = serde_json::from_str("\"Xx\"").expect("deserialize"); + assert_eq!(lang, Lang::Other(SmolStr::new("xx"))); + } + + #[cfg(feature = "serde")] + #[test] + fn serde_rejects_empty_string() { + let res: Result = serde_json::from_str("\"\""); + assert!(res.is_err()); + } + + #[cfg(feature = "serde")] + #[test] + fn serde_rejects_overlong_code() { + let res: Result = serde_json::from_str("\"abcdefghi\""); + assert!(res.is_err(), "9-byte code must be rejected"); + } + + #[cfg(feature = "serde")] + #[test] + fn serde_rejects_non_ascii_letters() { + let res: Result = serde_json::from_str("\"français\""); + assert!(res.is_err(), "non-ASCII must be rejected"); + let res: Result = serde_json::from_str("\"a-b\""); + assert!(res.is_err(), "dash must be rejected"); + let res: Result = serde_json::from_str("\"a1b\""); + assert!(res.is_err(), "digits must be rejected"); + } + + /// Old derive-shaped JSON for `Other` (`{"Other":"xx"}`) must + /// fail with the new custom impl — it's an externally-tagged + 
/// object, not a string. Documents the breaking wire-format + /// change for migrators. + /// + /// Note: legacy `"En"` (Rust variant name) is now ACCEPTED as + /// a side-effect of case-insensitive deserialization. That's a + /// happy accident for migration — old configs that happened to + /// use the variant-name form continue to work, just with the + /// canonical lowercase form on round-trip. No special handling + /// needed. + #[cfg(feature = "serde")] + #[test] + fn serde_rejects_legacy_other_as_map() { + let res: Result = serde_json::from_str(r#"{"Other":"xx"}"#); + assert!( + res.is_err(), + "legacy Other-as-map encoding must be rejected" + ); + } +} diff --git a/whispercpp/src/lib.rs b/whispercpp/src/lib.rs new file mode 100644 index 0000000..3c81b8c --- /dev/null +++ b/whispercpp/src/lib.rs @@ -0,0 +1,20 @@ +#![doc = include_str!("../README.md")] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(docsrs, allow(unused_attributes))] +#![deny(missing_docs)] + +mod context; +mod error; +mod lang; +mod params; +mod state; +mod sys; + +pub use context::{Context, ContextParams, system_info}; +pub use error::{WhisperError, WhisperResult}; +pub use lang::Lang; +pub use params::{ + MAX_BEAM_SIZE, MAX_INITIAL_TS_S, MAX_N_THREADS, MAX_TEMPERATURE, MIN_TEMPERATURE_INC, Params, + SamplingStrategy, +}; +pub use state::{Segment, State, Token}; diff --git a/whispercpp/src/params.rs b/whispercpp/src/params.rs new file mode 100644 index 0000000..5ebaa1f --- /dev/null +++ b/whispercpp/src/params.rs @@ -0,0 +1,1096 @@ +//! `Params` — the configuration passed to a single +//! [`State::full`](crate::State::full) call. +//! +//! # Ownership model +//! +//! `Params` owns every `CString` it hands to whisper.cpp. The +//! crate's whole point — fix the leak class in whisper-rs's +//! `set_initial_prompt` / `set_language` — depends on this. Each +//! setter that takes a string stores the `CString` in the +//! `Params` struct and replaces the pointer in the FFI struct +//! 
with `as_ptr`. When `Params` drops, the strings drop with +//! it. +//! +//! # Abort callback +//! +//! `Params` owns the abort closure as +//! `Box bool>>>` — the [`AbortCallback`] +//! type alias. The outer `Box` gives a heap-stable address so +//! the FFI `user_data` pointer survives `Params` moves; the +//! `UnsafeCell` legitimises the `&mut`-from-`&` borrow the +//! trampoline takes through the C ABI; the inner +//! `Box bool>` is the type-erased closure shape +//! whisper.cpp's callback expects. The whisper-rs UB we +//! diagnosed in earlier work (`*mut F` cast to read a +//! `*mut Box` fat-pointer-on-heap) is structurally +//! absent: the trampoline only ever reads through the +//! `UnsafeCell`'s known layout. +//! +//! # Panic-free +//! +//! Every setter returns `Result` if it can fail (interior NUL in +//! a string), and field-only setters are infallible chained +//! returns of `&mut Self`. There is no `expect`/`unwrap`/`panic!` +//! anywhere in this module's safe surface. + +#![allow(unsafe_code)] + +use core::{cell::UnsafeCell, ffi::c_void}; +use std::{ + ffi::CString, + panic::{AssertUnwindSafe, catch_unwind}, +}; + +/// Boxed-in-`UnsafeCell`-boxed-`FnMut` storage for the abort +/// callback. The outer `Box` gives a stable heap address that +/// survives `Params` moves; `UnsafeCell` legitimises the +/// `&mut`-from-`&` access the trampoline performs; the inner +/// `Box bool>` is the type-erased closure shape +/// `set_abort_callback` accepts. Aliased so the +/// `_abort_callback` field stays readable. +type AbortCallback = Box bool>>>; + +/// Upper bound applied by [`Params::set_n_threads`] and +/// [`Params::new`]. +/// +/// **Capped at `1`** — the only value provably safe under +/// whisper.cpp's current threading patterns. Two distinct +/// `vector` use sites combine to require this: +/// +/// * **Mel-spectrogram parallelism** (`whisper.cpp:3212-3217`) +/// spawns workers in a loop with no caller-thread work +/// between iterations. 
For this site alone, `n = 2` was +/// safe (single iteration, atomic success/fail). +/// * **Multi-decoder process loop** +/// (`whisper.cpp:7233-7242` and `:7483-7493`) spawns +/// `threads[0..n-2]` then runs `process` on the CALLER +/// thread before joining. With `n = 2`: `threads[0]` is +/// joinable, then `process` runs and can throw +/// (vector pushes inside `whisper_sample_token_topk`, +/// `whisper_process_logits`, etc. → `std::bad_alloc`). +/// On throw, stack unwinds → `vector` +/// destructor destroys `threads[0]` while joinable → +/// `std::terminate` BEFORE our shim's `catch (...)` can +/// run. This path is reachable from safe Rust via +/// `Greedy { best_of ≥ 2 }` or `BeamSearch`. +/// +/// `n = 1` short-circuits the multi-decoder branch +/// (`if (n_threads == 1) { process; }` upstream), spawns +/// no mel workers, and is the only value that doesn't reach +/// any abort path. +/// +/// History of this constant across review rounds: +/// `1024 → 64 → 16 → 2 → 1`. Each step closed a previously- +/// unanalysed thread-spawn or caller-thread-throw shape. A +/// real fix to allow `n ≥ 2` needs upstream RAII join guards +/// or per-region exception catches (unfixed bug in +/// whisper.cpp itself). When that lands, this constant can +/// be raised. +/// +/// Callers who can prove host headroom themselves can opt +/// into higher counts via +/// [`Params::set_n_threads_unchecked`] — `unsafe`, with the +/// caller's safety contract that no `std::thread` constructor +/// AND no `process` invocation will throw under the +/// workload's pressure. +pub const MAX_N_THREADS: i32 = 1; + +/// Upper bound applied by [`SamplingStrategy::BeamSearch`]'s +/// `beam_size` and [`SamplingStrategy::Greedy`]'s `best_of` +/// before they reach `whisper_full_params`. +/// +/// `whisper.cpp`'s `whisper_sample_token_topk` indexes +/// `beam_candidates[0]` and forms an iterator from +/// `vector::begin`; `k <= 0` collapses both into invalid +/// memory access. 
Upstream's only internal +/// guard clamps the *decoder count* to ≥ 1 — the original +/// `beam_size` flows through untouched. We clamp here to +/// `[1, MAX_BEAM_SIZE]` so safe Rust cannot reach the C++ UB. +/// +/// `64` is generous: empirical work on Whisper rarely +/// exceeds `beam_size = 5` (OpenAI's default) and quality +/// saturates by 8–16; the cap is a sanity ceiling, not a +/// tuning knob. +/// +/// Multi-decoder safety relies on the kv_cache_free +/// idempotent patch the `whispercpp-sys` build script +/// applies — that's why the system / +/// pkg-config link path was removed: +/// linking against a stock unpatched libwhisper would +/// silently restore the double-free this cap depends on. +pub const MAX_BEAM_SIZE: i32 = 64; + +/// Clamp a candidate-count knob (`beam_size` / `best_of`) to +/// the safe `[1, MAX_BEAM_SIZE]` range. `const fn` so it can +/// run inside `Params::new`'s match. +#[cfg_attr(not(tarpaulin), inline(always))] +const fn clamp_topk(k: i32) -> i32 { + if k < 1 { + 1 + } else if k > MAX_BEAM_SIZE { + MAX_BEAM_SIZE + } else { + k + } +} + +/// Clamp `n_threads` to `[1, MAX_N_THREADS]`. +/// +/// Used both by [`Params::set_n_threads`] (caller-supplied +/// values) AND by [`Params::new`] (the value +/// `whisper_full_default_params` inherits from +/// `std::min(4, hardware_concurrency)` — `hardware_concurrency` +/// is allowed by the C++ spec to return `0` on hosts where it +/// can't determine the count, which would propagate `0 - 1` +/// underflow into the upstream +/// `vector(n_threads - 1)` constructor). +#[cfg_attr(not(tarpaulin), inline(always))] +const fn clamp_n_threads(n: i32) -> i32 { + if n < 1 { + 1 + } else if n > MAX_N_THREADS { + MAX_N_THREADS + } else { + n + } +} + +/// Hard ceiling for [`Params::set_max_initial_ts`], **in seconds**. 
+/// +/// `max_initial_ts` is whisper.cpp's biased timestamp ceiling +/// for the FIRST segment of a chunk — it suppresses logits at +/// timestamp tokens beyond `tid0 = round(max_initial_ts / +/// precision)` where `precision = WHISPER_CHUNK_SIZE / +/// n_audio_ctx ≈ 0.02 s/frame` (`whisper.cpp:6604-6610`). +/// OpenAI's reference uses `1.0` second (`decoding.py:L426`). +/// +/// `30.0` matches the model-native chunk width (30 s); any +/// value at or below the chunk width is well-defined. Above +/// that, `tid0` walks past the timestamp-token range and the +/// bias loop becomes a no-op — not unsound, but the knob loses +/// meaning. We cap at the chunk width so the value the safe +/// API forwards always has a legitimate effect. +/// +/// NaN, ±∞, and negatives collapse to `0.0` (the upstream +/// "ignore" sentinel — see `if (max_initial_ts > 0.0)` at +/// `whisper.cpp:6604`) because the +/// `round(t / precision) → int` conversion is UB on +/// non-finite or extreme floats. +pub const MAX_INITIAL_TS_S: f32 = 30.0; + +/// Hard ceiling applied by [`Params::set_temperature`]. +/// +/// Whisper's softmax-temperature contract is `t ∈ [0.0, 1.0]` +/// (1.0 = uniform-over-vocab). Values above 1.0 still type- +/// check inside whisper.cpp but produce no useful sampling +/// behaviour, so we cap there. The cap also bounds upstream's +/// `for (float t = temperature; t < 1.0 + 1e-6; t += inc)` +/// ladder so a `temperature = f32::MAX` start can't blow past +/// the comparison. +pub const MAX_TEMPERATURE: f32 = 1.0; + +/// Smallest positive `temperature_inc` that +/// [`Params::set_temperature_inc`] will forward to the ladder. +/// +/// `1e-3` is well above `ULP(1.0) ≈ 1.19e-7` (the precision +/// floor where `t += inc` stops advancing in `float`), and +/// also bounds the ladder length: from `temperature = 0.0` to +/// `1.0`, the longest legal ladder has `1.0 / 1e-3 ≈ 1000` +/// entries — comfortably below any allocation worry. 
The +/// upstream OpenAI default is `0.2`; whispery's runner pins +/// `inc = 0.0` for deterministic behaviour. Anything in +/// between is fine; anything below this floor (or NaN / +/// negative) clamps DOWN to `0.0` ("no ladder"). +pub const MIN_TEMPERATURE_INC: f32 = 1e-3; + +/// Clamp an `f32` timestamp index to a finite, in-range +/// value the upstream `round` → `int` conversion can +/// safely consume. `const fn` because `f32::is_nan` / +/// `is_finite` are stable in const since Rust 1.83. +#[cfg_attr(not(tarpaulin), inline(always))] +const fn clamp_max_initial_ts(t: f32) -> f32 { + // NaN, ±∞, and negatives all collapse to 0.0 (the + // upstream "ignore" sentinel — see `if (max_initial_ts > + // 0.0)` guard at `whisper.cpp` line ~7829). The ceiling + // covers the legitimate-but-extreme f32::MAX case. + if !t.is_finite() || t < 0.0 { + 0.0 + } else if t > MAX_INITIAL_TS_S { + MAX_INITIAL_TS_S + } else { + t + } +} + +/// Clamp the per-attempt decoding temperature to a finite, +/// in-range value the upstream ladder can safely loop over. +#[cfg_attr(not(tarpaulin), inline(always))] +const fn clamp_temperature(t: f32) -> f32 { + // NaN / -∞ / negatives → 0.0 (single-attempt at + // greedy / argmax). +∞ / huge → MAX_TEMPERATURE. + if !t.is_finite() || t < 0.0 { + 0.0 + } else if t > MAX_TEMPERATURE { + MAX_TEMPERATURE + } else { + t + } +} + +/// Clamp the temperature-ladder step to either `0.0` ("no +/// ladder, single attempt") or a value large enough that +/// `t += inc` actually advances `t` once it nears the +/// `1.0 + 1e-6` upstream sentinel. +#[cfg_attr(not(tarpaulin), inline(always))] +const fn clamp_temperature_inc(inc: f32) -> f32 { + // NaN / negatives / subnormal-positive → 0.0. Upstream + // treats `inc <= 0.0` as "single attempt" (`temperature_inc + // > 0.0f` guard at whisper.cpp:6845), so we don't need a + // separate sentinel. 
+ if !inc.is_finite() || inc < MIN_TEMPERATURE_INC { + 0.0 + } else if inc > 1.0 { + 1.0 + } else { + inc + } +} + +use crate::{ + error::{WhisperError, WhisperResult}, + sys, +}; + +/// Sampling strategy. Mirrors `whisper_sampling_strategy`. +#[derive(Debug, Clone, Copy)] +pub enum SamplingStrategy { + /// Greedy / argmax decoding with optional best-of resampling. + /// + /// `best_of > 1` activates whisper.cpp's multi-decoder + /// path, which recreates the KV cache when the decoder + /// count grows. Under allocation pressure that path can + /// throw between freeing the old cache and rebuilding the + /// new one, leaving the C-side `whisper_state`'s `kv_self` + /// freed while `state` itself is still live. The Rust + /// `State` poisons itself on every + /// shim exception sentinel — accessors safely return + /// zero/None — but you cannot recover the in-flight + /// `full` call. Stick to `best_of: 1` if you need + /// guaranteed forward progress under OOM. + Greedy { + /// Number of independent decoding attempts at each + /// temperature; the highest-scoring is kept. 1 = pure + /// greedy. See the type-level note about multi-decoder + /// OOM behaviour. + best_of: i32, + }, + /// Beam-search decoding. + /// + /// Always activates the multi-decoder path (see the + /// `Greedy` doc-note). Same OOM caveat applies: if + /// allocation fails inside the KV-cache rebuild, the + /// `State` poisons itself and you must construct a fresh + /// one to retry. For workloads where OOM is a credible + /// failure mode, prefer `Greedy { best_of: 1 }`. + BeamSearch { + /// Number of beams kept per step. + beam_size: i32, + /// Beam patience hyperparameter; -1 disables. + patience: f32, + }, +} + +/// Builder + storage for `whisper_full_params`. Construct via +/// [`Params::new`], chain setters, then pass an immutable +/// reference to [`State::full`](crate::State::full). 
+pub struct Params { + raw: sys::whisper_full_params, + // Stored CStrings keep the pointers in `raw` valid for the + // entire `Params` lifetime. Drop order: `raw` is plain data, + // these are dropped after the struct is unlinked from any + // FFI call (caller is required to ensure no in-flight `full` + // observes us mid-drop — enforced by `&Params` borrow on + // `State::full`). + _initial_prompt: Option, + _language: Option, + // Owned prompt-token buffer kept alongside `raw.prompt_tokens` + // (which carries `&[whisper_token]` as a raw pointer). Lifetime + // ties to `Params` like the CStrings above. + _prompt_tokens: Option>, + // Boxed abort closure, wrapped in `UnsafeCell` so the + // trampoline can `&mut`-call it through a shared `&Params` + // borrow without violating Rust's aliasing rules. The Box + // gives us a stable address that survives `Params` moves; + // `UnsafeCell` legitimises the interior mutability the + // trampoline performs. + // + // `Params` itself stays `!Sync`-by-default because + // `UnsafeCell` removes the auto-Sync impl — that matches + // `whisper.cpp`'s contract: a single `Params` may not be + // shared between two concurrent `State::full` calls. + _abort_callback: Option, +} + +impl Params { + /// Build a fresh `Params` for the given strategy. Defaults are + /// whisper.cpp's `whisper_full_default_params(strategy)`. + pub fn new(strategy: SamplingStrategy) -> Self { + let cstrategy = match strategy { + SamplingStrategy::Greedy { .. } => sys::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY, + SamplingStrategy::BeamSearch { .. } => { + sys::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH + } + }; + // SAFETY: pure C call returning a value-typed defaults + // struct. + let mut raw = unsafe { sys::whisper_full_default_params(cstrategy as _) }; + // Clamp `best_of` / `beam_size` to `[1, MAX_BEAM_SIZE]`. 
+ // Both feed into upstream `whisper_sample_token_topk` and + // related candidate-vector indexing, where `k <= 0` is + // OOB / iterator-before-begin. The clamp + // is silent because the legitimate use case for any of + // these knobs is `1..=8`; values outside that range are + // programming errors that we'd rather convert to "still + // works" than to a C++ abort across `extern "C"`. + match strategy { + SamplingStrategy::Greedy { best_of } => { + raw.greedy.best_of = clamp_topk(best_of); + } + SamplingStrategy::BeamSearch { + beam_size, + patience, + } => { + raw.beam_search.beam_size = clamp_topk(beam_size); + raw.beam_search.patience = patience; + } + } + // Sanitize the C-supplied default `n_threads` before + // anyone can call `State::full` against fresh `Params`. + // `whisper_full_default_params` derives this from + // `std::min(4, hardware_concurrency)`; the C++ spec + // allows `hardware_concurrency` to return `0`, which + // would propagate `0 - 1` underflow into the mel path's + // `vector(n - 1)` constructor before any caller + // had a chance to invoke `set_n_threads`. + raw.n_threads = clamp_n_threads(raw.n_threads); + Self { + raw, + _initial_prompt: None, + _language: None, + _prompt_tokens: None, + _abort_callback: None, + } + } + + // ── String setters (fallible: interior NUL → InvalidCString). ── + + /// Provide a language hint (e.g. `"en"`, `"zh"`, `"auto"`). + /// Stores the `CString` for the lifetime of `self` — fixing the + /// `whisper-rs` leak. + /// + /// Returns [`WhisperError::InvalidCString`] if `lang` contains + /// an interior NUL byte. **Panic-free.** + pub fn set_language(&mut self, lang: &str) -> WhisperResult<&mut Self> { + let cstr = + CString::new(lang).map_err(|_| WhisperError::InvalidCString(smol_str::SmolStr::new(lang)))?; + self.raw.language = cstr.as_ptr(); + self._language = Some(cstr); + Ok(self) + } + + /// Set the initial prompt (`<|prompt|>` text, decoded by the + /// model before generation). Owns the `CString`. 
+ /// + /// Returns [`WhisperError::InvalidCString`] on interior NUL. + /// **Panic-free.** + pub fn set_initial_prompt(&mut self, prompt: &str) -> WhisperResult<&mut Self> { + let cstr = CString::new(prompt).map_err(|_| { + // The prompt may be very long; trim the diagnostic so the + // error doesn't drag a kilobyte of audio context into log + // tails. SmolStr inlines short captures (≤23 bytes); a 64- + // char head usually allocates but stays bounded. + let head: String = prompt.chars().take(64).collect(); + WhisperError::InvalidCString(smol_str::SmolStr::new(head)) + })?; + self.raw.initial_prompt = cstr.as_ptr(); + self._initial_prompt = Some(cstr); + Ok(self) + } + + // ── Primitive setters (infallible chained `&mut Self`). ── + + /// Whether to detect language from the audio (overrides + /// `set_language`'s hint). + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_detect_language(&mut self, on: bool) -> &mut Self { + self.raw.detect_language = on; + self + } + + /// Number of CPU threads for the encode/decode loop. + /// + /// Clamped to `[1, MAX_N_THREADS]` (see [`MAX_N_THREADS`] + /// for the per-`n` safety analysis). Callers who know the + /// host has sufficient thread-table headroom can opt into + /// higher counts via [`Self::set_n_threads_unchecked`] — + /// that's an `unsafe fn`, with the safety contract that the + /// caller asserts no `std::thread` constructor will throw + /// under the workload's pressure. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_n_threads(&mut self, n: i32) -> &mut Self { + self.raw.n_threads = clamp_n_threads(n); + self + } + + /// Set `n_threads` without applying [`MAX_N_THREADS`]'s + /// upper bound. + /// + /// Negative / zero `n` are still clamped up to `1` because + /// `n ≤ 0` triggers the upstream `vector(n - 1)` + /// underflow path (a different bug class from the + /// thread-spawn-throw abort that `MAX_N_THREADS` guards + /// against). Inputs `≥ 1` pass through verbatim. 
+ /// + /// # Safety + /// + /// The caller must guarantee that, in the runtime + /// environment where [`State::full`](crate::State::full) + /// will execute, no `std::thread` constructor inside + /// whisper.cpp's mel-spectrogram worker loop can fail. + /// Practically that means: + /// + /// * the host's per-process thread limit (`ulimit -u` on + /// POSIX, container PIDs cgroup, Windows job-object + /// limits) admits at least `n_threads - 1` more threads + /// than the process has already spawned, AND + /// * sufficient address space and TLS reserve are + /// available for each new thread. + /// + /// If a `std::thread` constructor throws partway through + /// the loop, whisper.cpp's `vector` destructor + /// destroys joinable threads during stack unwinding, which + /// invokes `std::terminate` BEFORE our exception shim + /// can convert the throw into a [`WhisperError`]. The + /// process aborts. **No Rust-level recovery is possible.** + /// + /// Using this function with values outside the + /// `[1, MAX_N_THREADS]` range therefore trades the safe + /// API's "guaranteed-no-process-termination" property for + /// runtime parallelism the safe ceiling forbids. ggml's + /// internal planner usually caps useful parallelism in + /// the single-digit range, so values much beyond `~8` rarely + /// improve performance. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const unsafe fn set_n_threads_unchecked(&mut self, n: i32) -> &mut Self { + // Lower bound stays — this guards a different bug + // (size_t::MAX underflow in `vector(n - 1)`) + // that's not in the unsafe contract. + self.raw.n_threads = if n < 1 { 1 } else { n }; + self + } + + /// Set `BeamSearch::beam_size` without applying + /// [`MAX_BEAM_SIZE`]'s upper bound. + /// + /// `MAX_BEAM_SIZE = 64` is already generous (Whisper + /// quality saturates by `beam_size = 8–16` per OpenAI's + /// own work), so the safe API covers the realistic range. 
+ /// This unchecked setter exists for diagnostic / + /// stress-test scenarios that legitimately need a value + /// past the cap. + /// + /// Mirrors [`Self::set_n_threads_unchecked`]: the lower + /// bound of `1` is preserved (guards `topk` underflow into + /// `vector::begin` iteration in + /// `whisper_sample_token_topk`); only the upper cap is + /// bypassed. + /// + /// # Safety + /// + /// The cap exists as a sanity ceiling, not a memory-safety + /// requirement (the multi-decoder OOM double-free that + /// motivated previous caps is patched at build + /// time — see `whispercpp-sys/build.rs::PATCHES`). Going + /// past the cap therefore trades the safe API's "no silly + /// allocations" property for the right to over-allocate + /// candidate tables on hosts with the memory budget for it. + /// + /// `n ≤ 0` still clamps to `1` (separate underflow bug + /// outside the unsafe contract). + #[cfg_attr(not(tarpaulin), inline(always))] + pub const unsafe fn set_beam_size_unchecked(&mut self, n: i32) -> &mut Self { + self.raw.beam_search.beam_size = if n < 1 { 1 } else { n }; + self + } + + /// Set `Greedy::best_of` without applying [`MAX_BEAM_SIZE`]'s + /// upper bound. Same safety contract as + /// [`Self::set_beam_size_unchecked`] — see that function's + /// docs. + /// + /// `n ≤ 0` still clamps to `1`. + /// + /// # Safety + /// + /// See [`Self::set_beam_size_unchecked`]. Bypassing the cap + /// trades the safe API's "no silly allocations" property + /// for the right to over-allocate candidate tables; the + /// memory-safety properties of multi-decoder paths still + /// hold ('s idempotent `whisper_kv_cache_free` is + /// always applied). + #[cfg_attr(not(tarpaulin), inline(always))] + pub const unsafe fn set_best_of_unchecked(&mut self, n: i32) -> &mut Self { + self.raw.greedy.best_of = if n < 1 { 1 } else { n }; + self + } + + /// Disable transcript prompting from the previous segment's + /// tokens (matches `whisper-rs`'s `set_no_context`). 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_no_context(&mut self, on: bool) -> &mut Self { + self.raw.no_context = on; + self + } + + /// `no_speech_prob` threshold. Segments above this are flagged + /// as silence and may be retried at higher temperature. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_no_speech_thold(&mut self, t: f32) -> &mut Self { + self.raw.no_speech_thold = t; + self + } + + /// Decoding temperature for this single attempt. See + /// [`Self::set_temperature_inc`] for the internal ladder. + /// + /// Clamped to `[0.0, MAX_TEMPERATURE]`. NaN, ±∞, and + /// negatives collapse to `0.0`. Upstream's fallback ladder + /// loop is `for (float t = temperature; t < 1.0 + 1e-6; + /// t += inc)`; passing `temperature = -∞` would trip the + /// comparison forever and `push_back` into a `vector` until + /// OOM. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_temperature(&mut self, t: f32) -> &mut Self { + self.raw.temperature = clamp_temperature(t); + self + } + + /// Internal temperature-ladder step. `0.0` pins the decoder + /// to exactly one attempt at `temperature`. + /// + /// Clamped to `{0.0} ∪ [MIN_TEMPERATURE_INC, 1.0]`: + /// * NaN / negative / `(0.0, MIN_TEMPERATURE_INC)` → `0.0` + /// (treated as "ladder disabled, single attempt"). + /// * `> 1.0` → `1.0`. + /// + /// Upstream loops + /// `for (float t = temperature; t < 1.0 + 1e-6; t += inc)`. + /// With a positive `inc` smaller than `ULP(1.0) ≈ 1.19e-7`, + /// `t += inc` does not advance once `t` reaches `1.0` — the + /// loop spins forever, pushing floats into a vector until + /// OOM. Clamping subnormal positive `inc` up to `0.0` + /// (= "no ladder") closes that path while preserving the + /// common shapes (`inc = 0.0` for single-attempt, `inc = + /// 0.2` for the OpenAI-default 5-step ladder). 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_temperature_inc(&mut self, inc: f32) -> &mut Self { + self.raw.temperature_inc = clamp_temperature_inc(inc); + self + } + + /// Suppress empty output bias. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_suppress_blank(&mut self, on: bool) -> &mut Self { + self.raw.suppress_blank = on; + self + } + + /// Suppress non-speech tokens. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_suppress_nst(&mut self, on: bool) -> &mut Self { + self.raw.suppress_nst = on; + self + } + + /// Toggles every `print_*` field off in one call. Whisper.cpp + /// otherwise scribbles to stdout/stderr during decode, which + /// is rarely what production callers want. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn silence_print_toggles(&mut self) -> &mut Self { + self.raw.print_special = false; + self.raw.print_progress = false; + self.raw.print_realtime = false; + self.raw.print_timestamps = false; + self + } + + /// Install an abort callback. Whisper.cpp invokes it during + /// the encode loop; returning `true` causes `whisper_full` to + /// bail out early. + /// + /// The closure is stored as `Box bool>>` — + /// a stable address whose layout matches the C trampoline + /// installed under the hood. This is the structural fix for + /// the whisper-rs `set_abort_callback_safe` UB. + pub fn set_abort_callback(&mut self, f: F) -> &mut Self + where + F: FnMut() -> bool + 'static, + { + // ordering: the previous implementation + // first published a fresh `user_data` pointer into + // `self.raw`, then assigned `self._abort_callback = + // Some(outer)`. 
Replacement-assignment drops the OLD + // `Some(outer)`; if its captured closure's `Drop` panics + // and a caller has wrapped this setter in `catch_unwind`, + // the unwind tears down the NEW `outer` while + // `self.raw.abort_callback_user_data` still points at it + // — leaving `Params` with a dangling FFI pointer that the + // next `State::full` call would dereference. Reorder so + // the raw fields are NEVER pointing at a freed payload. + // + // 1. Null the raw hooks first. If anything else fails + // below, the state is "no callback installed", which + // is always safe. + self.raw.abort_callback = None; + self.raw.abort_callback_user_data = core::ptr::null_mut(); + // 2. Drop the old owner explicitly. If the old closure's + // `Drop` panics, we unwind out of this function with + // `raw` already cleared (step 1) — the caller's + // `catch_unwind` cleanup leaves `Params` in a + // no-callback state, no dangling pointer. + let _old = self._abort_callback.take(); + drop(_old); + // 3. Build the new owner. `Box::new` could panic on OOM; + // if it does, `self._abort_callback` stays `None` and + // `raw` stays cleared. + let outer: AbortCallback = Box::new(UnsafeCell::new(Box::new(f))); + // 4. Derive the FFI pointer from `outer` BEFORE moving it + // into `self._abort_callback`. `Box>` + // derefs to a heap address that is stable across the + // move into `Option>` (the `Box` value itself + // moves; the heap allocation it owns does not). This + // keeps the setter panic-free — no `expect` / + // `unwrap` after the assignment. + let user_data = (&*outer) as *const UnsafeCell bool>> as *mut c_void; + self._abort_callback = Some(outer); + // 5. Publish the trampoline + user_data (from the FFI's + // perspective — whisper.cpp doesn't run concurrently + // here because we hold `&mut self`). 
+ self.raw.abort_callback_user_data = user_data; + self.raw.abort_callback = Some(abort_trampoline); + self + } + + // ── Audio windowing ───────────────────────────────────────── + + /// Start offset into the audio, in milliseconds. Whisper.cpp + /// internally seeks past the first `offset_ms` worth of mel + /// frames before decoding. Defaults to `0`. + /// + /// Negative values are silently clamped to `0`. Upstream + /// turns `offset_ms` into a negative `mel_offset` that + /// reaches `mel_inp.data[j*n_len + i]` with `i < 0` — an + /// out-of-bounds native read reachable from safe Rust per + /// `whisper.cpp:2393–2398`. Clamping at the safe setter + /// keeps the UB from crossing the FFI. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_offset_ms(&mut self, ms: i32) -> &mut Self { + self.raw.offset_ms = if ms < 0 { 0 } else { ms }; + self + } + + /// Hard cap on audio duration to decode, in milliseconds. + /// `0` means "to end of input". Defaults to `0`. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_duration_ms(&mut self, ms: i32) -> &mut Self { + self.raw.duration_ms = ms; + self + } + + /// Override the encoder's audio context window. `0` keeps the + /// model's native value (e.g. 1500 frames = 30s for the + /// vanilla checkpoints); smaller values trade quality for + /// speed on chunks << 30s. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_audio_ctx(&mut self, n: i32) -> &mut Self { + self.raw.audio_ctx = n; + self + } + + // ── Decoding limits ──────────────────────────────────────── + + /// Cap on tokens decoded per attempt. `0` lets whisper.cpp + /// run to its natural EOT. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_max_tokens(&mut self, n: i32) -> &mut Self { + self.raw.max_tokens = n; + self + } + + /// Maximum characters per segment. Combined with + /// [`Self::set_split_on_word`], this is whisper.cpp's + /// segment-shaping mechanism. 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_max_len(&mut self, n: i32) -> &mut Self { + self.raw.max_len = n; + self + } + + /// Maximum index a `<|t_x|>` token may take on the FIRST + /// segment of a chunk. Whisper.cpp uses this to bias against + /// implausibly large initial timestamps. Defaults to `1.0` + /// (i.e. ≤ 1s lead-in is the typical valid range). + /// + /// Non-finite values (`NaN` / `±∞`) and negatives are + /// silently clamped to `0.0`; values above + /// [`MAX_INITIAL_TS_S`] are clamped to that ceiling. + /// Upstream converts `std::round(max_initial_ts / + /// precision)` to `int`, which is undefined behaviour in + /// C++ for non-finite or out-of-int-range floats. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_max_initial_ts(&mut self, t: f32) -> &mut Self { + self.raw.max_initial_ts = clamp_max_initial_ts(t); + self + } + + /// Cap on text context (`tokens_prev`) carried over between + /// segments. Whisper.cpp internally truncates from the head + /// when the prompt would exceed this. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_n_max_text_ctx(&mut self, n: i32) -> &mut Self { + self.raw.n_max_text_ctx = n; + self + } + + /// Force whisper.cpp to emit at most ONE segment per `full` + /// call. Useful when callers do their own segmentation + /// upstream. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_single_segment(&mut self, on: bool) -> &mut Self { + self.raw.single_segment = on; + self + } + + // ── Quality gates ────────────────────────────────────────── + + /// Whisper.cpp's internal logprob threshold for the temperature + /// fallback ladder. Lower (more negative) = accept more + /// uncertain decodes. Whispery's runner usually pins + /// `temperature_inc=0` and gates externally on its own + /// `log_prob_threshold` knob, so this is mostly a passthrough + /// for callers that want whisper.cpp's built-in ladder. 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_logprob_thold(&mut self, t: f32) -> &mut Self { + self.raw.logprob_thold = t; + self + } + + /// Whisper.cpp's entropy gate (token-prob distribution must + /// not collapse). Higher = stricter. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_entropy_thold(&mut self, t: f32) -> &mut Self { + self.raw.entropy_thold = t; + self + } + + /// Per-token timestamp probability threshold. Defaults to + /// `0.01` upstream. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_thold_pt(&mut self, t: f32) -> &mut Self { + self.raw.thold_pt = t; + self + } + + /// Sum-over-timestamps probability threshold. Defaults to + /// `0.01` upstream. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_thold_ptsum(&mut self, t: f32) -> &mut Self { + self.raw.thold_ptsum = t; + self + } + + /// Beam-search length penalty. Negative = penalise longer + /// outputs; `-1.0` disables. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_length_penalty(&mut self, p: f32) -> &mut Self { + self.raw.length_penalty = p; + self + } + + // ── Output shape ─────────────────────────────────────────── + + /// Disable timestamp tokens in the output. Faster and + /// suppresses the `<|t_x|>` markers; segment `t0`/`t1` still + /// land via whisper.cpp's segment splitter, but per-token + /// timestamps are not emitted. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_no_timestamps(&mut self, on: bool) -> &mut Self { + self.raw.no_timestamps = on; + self + } + + /// Split segments only at word boundaries. Combines with + /// [`Self::set_max_len`]. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_split_on_word(&mut self, on: bool) -> &mut Self { + self.raw.split_on_word = on; + self + } + + /// Compute per-token timestamps via DTW. Whispery handles + /// word-level alignment via wav2vec2 instead, so this is + /// rarely useful — exposed for parity. 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_token_timestamps(&mut self, on: bool) -> &mut Self { + self.raw.token_timestamps = on; + self + } + + // ── Decoding seed ────────────────────────────────────────── + + /// Switch whisper.cpp to translation (transcribe → English). + /// Whispery is transcription-only in production; exposed for + /// completeness. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn set_translate(&mut self, on: bool) -> &mut Self { + self.raw.translate = on; + self + } + + /// Seed the decoder with previously-emitted tokens (acts as + /// `tokens_prev`). The slice is COPIED into a `Vec` owned by + /// `self`; the caller's slice may be dropped after this call + /// returns. Pass `&[]` to clear a previously-set prompt. + pub fn set_tokens(&mut self, tokens: &[i32]) -> &mut Self { + if tokens.is_empty() { + self.raw.prompt_tokens = core::ptr::null(); + self.raw.prompt_n_tokens = 0; + self._prompt_tokens = None; + } else { + // Bound the slice copy at `i32::MAX` elements so the + // assignment to `prompt_n_tokens` (a C `int`) cannot + // wrap to a negative or truncated count. `prompt_n_tokens` + // crosses the FFI as `int`, and `tokens.len() as i32` + // would silently truncate / wrap on slices wider than + // 2 GiB of `whisper_token` (impossible in practice on + // any current platform, but the cast is unsound under + // the crate's panic-free contract). + let max_len = i32::MAX as usize; + let take = tokens.len().min(max_len); + let owned: Vec = tokens[..take].to_vec(); + self.raw.prompt_tokens = owned.as_ptr(); + self.raw.prompt_n_tokens = owned.len() as i32; + self._prompt_tokens = Some(owned); + } + self + } + + /// Internal: hand the raw C struct to `state::full`. + #[cfg_attr(not(tarpaulin), inline(always))] + pub(crate) const fn as_raw(&self) -> sys::whisper_full_params { + self.raw + } + + /// Internal: borrow the prompt-token slice (if any). 
Used by + /// `State::full` to range-check against the model's vocab + /// before forwarding the FFI call. Returns `None` when no + /// prompt has been set (or [`Self::set_tokens`] was last + /// called with an empty slice). + #[cfg_attr(not(tarpaulin), inline(always))] + pub(crate) fn prompt_tokens(&self) -> Option<&[i32]> { + self._prompt_tokens.as_deref() + } +} + +// Manual `Debug` because the boxed abort callback is `dyn FnMut` +// (no `Debug` impl). We elide it; the rest of the params surface +// renders fine via the bindgen-derived `Debug` on +// `whisper_full_params`. +impl core::fmt::Debug for Params { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Params") + .field("raw", &self.raw) + .field("language", &self._language) + .field("initial_prompt", &self._initial_prompt) + .field( + "abort_callback", + &self + ._abort_callback + .as_ref() + .map(|_| "") + .unwrap_or(""), + ) + .field( + "prompt_tokens", + &self._prompt_tokens.as_ref().map(|v| v.len()).unwrap_or(0), + ) + .finish() + } +} + +unsafe extern "C" fn abort_trampoline(user_data: *mut c_void) -> bool { + // SAFETY: `user_data` is the pointer we stored in + // `set_abort_callback`. It points to a live + // `UnsafeCell bool>>` whose lifetime is + // tied to the owning `Params` (which the caller of + // `State::full` borrows for the duration of the call). The + // `UnsafeCell` is what makes the `&mut`-borrow we form below + // legal — `Params` is reachable to safe code only through + // a shared `&Params` reference, and ordinary fields of a + // shared reference cannot be mutated; routing through + // `UnsafeCell` is the canonical opt-in for this pattern. + // (Whisper.cpp guarantees no concurrent invocation of the + // callback for a single state, so the FnMut borrow is + // exclusive at all times.) 
+ let cell: &UnsafeCell bool>> = + unsafe { &*(user_data as *const UnsafeCell bool>>) }; + // SAFETY: see above; whisper.cpp invokes the callback + // serially, so this is the only outstanding access. + let boxed: &mut Box bool> = unsafe { &mut *cell.get() }; + + // Catch unwinds — panicking across `extern "C"` is + // undefined behaviour. On panic, return `true` so + // whisper.cpp aborts the in-flight encode rather than + // continuing into an inconsistent state, and let the panic + // surface on the Rust side via the per-thread panic info + // (`std::thread::Result`-like — callers wrapping + // `State::full` in `catch_unwind` will see it). + catch_unwind(AssertUnwindSafe(boxed)).unwrap_or(true) +} + +// FFI-touching tests are skipped under Miri (every test +// here that constructs `Params` reaches `whisper_full_default_params`, +// which Miri can't execute). The pure-helper tests +// (`clamp_*`, constants) stay enabled. +// reinstated Miri coverage for the safe wrapper. +#[cfg(test)] +mod tests { + use super::*; + + /// Default-constructed `Params` must never carry an + /// `n_threads < 1`. `whisper_full_default_params` initialises + /// it from `std::min(4, hardware_concurrency)`, and + /// `hardware_concurrency` may legally return `0`; without + /// our `clamp_n_threads` in `Params::new`, that `0` would + /// cross the FFI boundary into a `vector(0 - 1)` + /// underflow. 
+ #[test] + #[cfg_attr(miri, ignore = "FFI: whisper_full_default_params")] + fn default_params_n_threads_normalises_to_at_least_one() { + let p = Params::new(SamplingStrategy::Greedy { best_of: 1 }); + assert!( + p.raw.n_threads >= 1, + "default n_threads = {}; must be ≥ 1 to dodge the upstream vector(n - 1) underflow", + p.raw.n_threads, + ); + assert!( + p.raw.n_threads <= MAX_N_THREADS, + "default n_threads = {} above MAX_N_THREADS = {}", + p.raw.n_threads, + MAX_N_THREADS, + ); + } + + /// Caller-supplied invalid `n_threads` clamps in both + /// directions, with no panic and no FFI involvement. + #[test] + #[cfg_attr(miri, ignore = "FFI: whisper_full_default_params")] + fn set_n_threads_clamps_zero_negative_and_oversized() { + let mut p = Params::new(SamplingStrategy::Greedy { best_of: 1 }); + p.set_n_threads(0); + assert_eq!(p.raw.n_threads, 1, "0 → 1"); + p.set_n_threads(-42); + assert_eq!(p.raw.n_threads, 1, "negative → 1"); + p.set_n_threads(i32::MIN); + assert_eq!(p.raw.n_threads, 1, "i32::MIN → 1"); + p.set_n_threads(MAX_N_THREADS + 1); + assert_eq!(p.raw.n_threads, MAX_N_THREADS, "above-cap → MAX_N_THREADS"); + p.set_n_threads(i32::MAX); + assert_eq!(p.raw.n_threads, MAX_N_THREADS, "i32::MAX → MAX_N_THREADS"); + } + + /// `clamp_n_threads` is the canonical helper used by both + /// `Params::new` and `set_n_threads`. Pin its surface so a + /// future refactor can't quietly change behaviour. + #[test] + fn clamp_n_threads_pins_invariants() { + assert_eq!(clamp_n_threads(0), 1); + assert_eq!(clamp_n_threads(-1), 1); + assert_eq!(clamp_n_threads(1), 1); + assert_eq!(clamp_n_threads(MAX_N_THREADS), MAX_N_THREADS); + assert_eq!(clamp_n_threads(MAX_N_THREADS + 1), MAX_N_THREADS); + // Anything ≥ MAX_N_THREADS clamps down. + // narrowed this to 1 (the multi-decoder process loop + // races caller-thread throws against joinable workers + // even at n=2). Pin a value that used to pass through + // under the cap of 2 to lock in the new + // ceiling. 
+ assert_eq!(clamp_n_threads(2), MAX_N_THREADS); + assert_eq!(clamp_n_threads(8), MAX_N_THREADS); + } + + /// `set_n_threads_unchecked` bypasses [`MAX_N_THREADS`] but + /// still applies the lower-bound clamp (the underflow + /// guard for `vector(n - 1)`). + #[test] + #[cfg_attr(miri, ignore = "FFI: whisper_full_default_params")] + fn set_n_threads_unchecked_bypasses_upper_cap_only() { + let mut p = Params::new(SamplingStrategy::Greedy { best_of: 1 }); + // Above the safe cap → passes through verbatim. + // SAFETY (test): we never call State::full here, so the + // unsafe contract about `std::thread` headroom is + // vacuously satisfied. + unsafe { p.set_n_threads_unchecked(8) }; + assert_eq!(p.raw.n_threads, 8); + unsafe { p.set_n_threads_unchecked(MAX_N_THREADS + 1) }; + assert_eq!(p.raw.n_threads, MAX_N_THREADS + 1); + unsafe { p.set_n_threads_unchecked(64) }; + assert_eq!(p.raw.n_threads, 64); + // Lower bound still clamps — `n ≤ 0` is a separate bug + // shape (vector underflow) that the unsafe contract + // does NOT cover. + unsafe { p.set_n_threads_unchecked(0) }; + assert_eq!(p.raw.n_threads, 1, "0 → 1 even on unchecked path"); + unsafe { p.set_n_threads_unchecked(-7) }; + assert_eq!(p.raw.n_threads, 1, "negative → 1 even on unchecked path"); + } + + /// Pin `MAX_BEAM_SIZE`. The bundled build patches + /// `whisper_kv_cache_free` to be idempotent so + /// multi-decoder is safe; `64` is the sanity ceiling. + #[test] + fn max_beam_size_pins_to_64() { + assert_eq!(MAX_BEAM_SIZE, 64); + } + + /// `set_beam_size_unchecked` / `set_best_of_unchecked` + /// bypass `MAX_BEAM_SIZE` but still clamp `n ≤ 0` to 1 + /// (separate underflow bug shape outside the unsafe + /// contract). Same shape as `set_n_threads_unchecked`. 
+ #[test] + #[cfg_attr(miri, ignore = "FFI: whisper_full_default_params")] + fn unchecked_topk_setters_bypass_upper_cap_only() { + let mut p = Params::new(SamplingStrategy::BeamSearch { + beam_size: 1, + patience: -1.0, + }); + // Above MAX_BEAM_SIZE — passes through verbatim. + // SAFETY (test): no State::full call, contract vacuously + // satisfied. + unsafe { p.set_beam_size_unchecked(MAX_BEAM_SIZE + 1) }; + assert_eq!(p.raw.beam_search.beam_size, MAX_BEAM_SIZE + 1); + unsafe { p.set_beam_size_unchecked(128) }; + assert_eq!(p.raw.beam_search.beam_size, 128); + // Lower bound still clamps. + unsafe { p.set_beam_size_unchecked(0) }; + assert_eq!(p.raw.beam_search.beam_size, 1); + unsafe { p.set_beam_size_unchecked(-9) }; + assert_eq!(p.raw.beam_search.beam_size, 1); + + let mut g = Params::new(SamplingStrategy::Greedy { best_of: 1 }); + unsafe { g.set_best_of_unchecked(MAX_BEAM_SIZE + 1) }; + assert_eq!(g.raw.greedy.best_of, MAX_BEAM_SIZE + 1); + unsafe { g.set_best_of_unchecked(0) }; + assert_eq!(g.raw.greedy.best_of, 1); + } +} diff --git a/whispercpp/src/state.rs b/whispercpp/src/state.rs new file mode 100644 index 0000000..d0ea7bb --- /dev/null +++ b/whispercpp/src/state.rs @@ -0,0 +1,567 @@ +//! Inference state, segments, and tokens. +//! +//! `State` owns an [`Arc`] of its parent [`Context`], which keeps +//! the model alive for the state's lifetime. We picked Arc- +//! ownership over a `'ctx` borrow because the realistic usage +//! pattern (worker pools storing per-thread state across jobs) +//! is hard to express with a lifetime — the borrow checker can't +//! see that the parent Arc lives in the same stack frame as the +//! State without explicit annotation. Arc-owned lets `State` be +//! `'static` and storable in `Option` / channels. +//! +//! The state is single-threaded by design (whisper.cpp scratch +//! buffers + KV cache are not thread-safe); we mark it `!Sync` +//! implicitly by holding a raw pointer. 
+ +#![allow(unsafe_code)] + +use core::{ptr::NonNull, str}; +use std::sync::Arc; + +use crate::{ + context::Context, + error::{WhisperError, WhisperResult}, + lang::Lang, + params::Params, + sys, +}; + +/// Per-call inference state. Owns an [`Arc`] so the +/// model outlives every per-call buffer. +/// +/// # Poisoning +/// +/// `whisper_full_with_state` calls `whisper_free_state` on +/// itself before returning `-7` from a multi-decoder KV-cache +/// allocation failure (see `whisper.cpp:7126` in the patched +/// fork). After that the underlying C++ pointer is gone — we +/// MUST NOT free it again on `Drop` and accessor methods MUST +/// NOT touch it. We model this by storing the pointer as +/// `Option>`: on `-7` we `take()` it, flipping the +/// `State` into a "poisoned" mode where every method +/// short-circuits to a safe zero/None and `Drop` becomes a +/// no-op. +/// +/// On caught-exception sentinels (`rc <= -100`), `State::full` +/// instead calls `whisper_free_state` EXPLICITLY to release the +/// native allocation, then nulls `self.ptr`. This is viable +/// because the fork's idempotent `whisper_kv_cache_free` patch +/// closed the double-free hazard that previously forced us to +/// leak the state. +/// +/// # Recovery contract on `StateLost` +/// +/// 1. Drop this `State`. The native allocation is already +/// released (no leak). +/// 2. The parent `Context` is poisoned — +/// [`Context::create_state`] will return +/// `WhisperError::ContextPoisoned` until you drop and +/// reconstruct the Context. This is defensive: the +/// pressure that caused the throw is likely still +/// present. +/// 3. Surface the error to your supervisor; once pressure +/// has resolved, reload the model in a fresh `Context`. +/// +/// `whispery`'s worker pool (`whisper_pool.rs`) treats +/// `WhisperError::StateLost` as a `WorkFailure::AsrFailed` +/// without auto-recreating the `State`, matching this +/// contract. Other consumers must implement equivalent +/// supervision. 
+pub struct State { + ptr: Option>, + // Keeps the parent Context alive. No `'ctx` lifetime: makes + // State `'static` for storage in `Option` / channels / + // the long-lived worker structs whispery uses. + ctx: Arc, +} + +// SAFETY: the `whisper_state` pointer is owned exclusively by us +// (no aliases). whisper.cpp permits passing a state across +// threads as long as no two threads call `whisper_full` on it +// concurrently — that's the same guarantee `Send` requires. The +// Arc is itself Send. +unsafe impl Send for State {} + +impl State { + /// Internal constructor used by [`Context::create_state`]. + pub(crate) fn from_raw(ptr: NonNull, ctx: Arc) -> Self { + Self { + ptr: Some(ptr), + ctx, + } + } + + /// Borrow the parent context. Useful when calling sites need + /// the same Arc to construct sibling state objects. + pub fn context(&self) -> &Arc { + &self.ctx + } + + /// `true` if the underlying whisper.cpp state was freed + /// behind the Rust owner (multi-decoder KV-cache allocation + /// failure → `whisper_full_with_state` returned `-7` after + /// internally calling `whisper_free_state`), or was freed by + /// [`State::full`] itself after the exception-catching shim + /// returned a sentinel code. After + /// poisoning, every accessor returns a zero/None default + /// and `Drop` skips the C-side free; the only sound move + /// for callers is to drop the [`State`] and (if desired) + /// allocate a fresh one via + /// [`Context::create_state`](crate::Context::create_state). + pub fn is_poisoned(&self) -> bool { + self.ptr.is_none() + } + + /// Internal: shorthand for accessor methods that need the + /// raw pointer. Returns `None` for a poisoned state. + #[inline] + fn raw(&self) -> Option<*mut sys::whisper_state> { + self.ptr.map(NonNull::as_ptr) + } + + /// Run the encoder + decoder over `samples` (16 kHz mono f32). + /// + /// Returns `Ok(())` on success; the segment list + /// is then accessible via [`State::n_segments`] and + /// [`State::segment`]. 
**Panic-free.** Returns + /// * [`WhisperError::SamplesOverflow`] when `samples.len()` + /// does not fit in the C `int` whisper.cpp expects. + /// * [`WhisperError::SamplesTooShort`] when the buffer is + /// smaller than the 201-sample lower bound whisper.cpp's + /// `log_mel_spectrogram` reflective pad requires. + /// * [`WhisperError::TokenOutOfRange`] when a `prompt_tokens` + /// id passed to [`Params::set_tokens`] does not lie in + /// `[0, n_vocab)`. Upstream feeds those ids straight + /// into `ggml_get_rows(model.d_te, embd)`; out-of-range + /// indices either trip a CPU-side assert or cause invalid + /// memory access in GPU kernels, both of which abort + /// across the FFI. + /// * [`WhisperError::StateLost`] when this state is already + /// poisoned, or when the call itself loses the native state + /// (`rc == -7`, or a shim sentinel at or below + /// `WHISPERCPP_ERR_BAD_ALLOC`). + /// * [`WhisperError::ContextPoisoned`] when the parent + /// context was already marked lost. + /// * [`WhisperError::Full`] for any other non-zero return + /// code; the state pointer is kept for later calls. + pub fn full(&mut self, params: &Params, samples: &[f32]) -> WhisperResult<()> { + // `log_mel_spectrogram` runs + // `reverse_copy(samples + 1, samples + 1 + 200, …)` for its + // start-of-buffer reflective pad, reading + // `samples[1..201]`. Sub-201-sample inputs trigger an + // out-of-bounds read in the C++ kernel before whisper.cpp's + // later short-input check fires. Reject them here so the + // UB never crosses the FFI boundary. + // A review surfaced this as a critical finding against whisper.cpp + // v1.8.4 (`src/whisper.cpp:3201`). + const MIN_SAMPLES_FOR_REFLECTIVE_PAD: usize = 201; + if samples.len() < MIN_SAMPLES_FOR_REFLECTIVE_PAD { + return Err(WhisperError::SamplesTooShort { + samples: samples.len(), + min_required: MIN_SAMPLES_FOR_REFLECTIVE_PAD, + }); + } + let len = i32::try_from(samples.len()).map_err(|_| WhisperError::SamplesOverflow { + samples: samples.len(), + })?; + // Prompt-token range check. `Params::set_tokens` accepts an + // arbitrary `&[i32]`; we couldn't validate at setter time + // without forcing callers to thread a `Context` through + // `Params::new`. Doing the check here (we have the Context + // via `self.ctx`) catches the unsoundness while keeping + // `Params` free of model state. A review
surfaced + // this against whisper.cpp v1.8.4 (`src/whisper.cpp:6915`): + // upstream feeds these ids into `ggml_get_rows(d_te)`, where + // CPU asserts and GPU kernels both abort on OOB. + if let Some(prompt) = params.prompt_tokens() { + let vocab = self.ctx.n_vocab(); + for &tok in prompt { + if tok < 0 || tok >= vocab { + return Err(WhisperError::TokenOutOfRange { + token: tok, + vocab_size: vocab, + }); + } + } + } + // Poisoned state can't run inference — the C-side pointer + // is gone. Surface this as `StateLost` so callers can + // distinguish "drop this State, do not auto-retry" from + // "transient error, state still usable". + let state_ptr = match self.ptr { + Some(p) => p.as_ptr(), + None => return Err(WhisperError::StateLost { code: -7 }), + }; + // Review finding 2: serialise FFI entry through + // the Context's per-Context mutex. Without this, multiple + // workers each holding their own State (the documented + // shared-`Arc` pattern) could ALL be inside + // `whispercpp_full_with_state` simultaneously when an OOM + // / system_error hits, each leaking ~360 MB on its own + // sentinel return before any of them got to mark the + // Context lost — turning the per-Context leak cap into a + // per-concurrent-worker cap. Holding the mutex across + // the FFI call makes the cap structural: at most one + // in-flight call per Context, so at most one leaked + // state per Context. + // + // The lock acquisition is the FIRST thing that happens + // on the inference path — the prompt-token / sample-len + // checks above don't touch native state and are + // contention-free, so we keep them outside the critical + // section. + // NOTE(review): the prompt-token check does call + // `self.ctx.n_vocab()` before the lock — presumably a + // read-only query; confirm it needs no serialisation. + let _full_guard = self.ctx.full_lock(); + // Poison check (re-applied under the lock): refuse to + // enter FFI if a sibling State on the same + // `Arc` has already poisoned the Context. 
The + // mutex above ensures we observe the latest poison + // state (the previous holder set `lost` via Release + // store BEFORE releasing the mutex; our Acquire load + // here sees that store). The sibling's `Drop` still + // frees its native state cleanly (its `self.ptr` is + // still `Some`); we only block the FFI entry that could + // turn that intact native state into another leaked one. + if self.ctx.is_poisoned() { + return Err(WhisperError::ContextPoisoned); + } + // SAFETY: + // - `self.ctx.as_raw()` is a non-null whisper_context + // (NonNull invariant on Context); kept alive by the Arc + // we own. + // - `state_ptr` is non-null (just unwrapped from NonNull). + // - `params.as_raw()` is a fully-initialised + // `whisper_full_params` whose owned CStrings live as long + // as `params`. + // - `samples.as_ptr()` is valid for `len` f32 reads + // (slice invariant). + // + // Routed through the exception-catching shim + // `whispercpp_full_with_state`: upstream constructs + // `std::thread` workers and allocates `std::vector` + // buffers, both of which can throw (`std::system_error`, + // `std::bad_alloc`) under realistic resource pressure. + // C++ exceptions across `extern "C"` are UB; sentinel + // codes documented on + // `whispercpp_shim.h::WHISPERCPP_ERR_*`. + let rc = unsafe { + sys::whispercpp_full_with_state( + self.ctx.as_raw(), + state_ptr, + params.as_raw(), + samples.as_ptr(), + len, + ) + }; + if rc == 0 { + return Ok(()); + } + // Failure regimes (the leak that + // earlier rounds documented as "unavoidable" is now + // freeable thanks to the idempotent + // `whisper_kv_cache_free` patch): + // + // 1. `rc == -7` — multi-decoder KV-cache rebuild + // failure. Upstream calls `whisper_free_state(state)` + // before returning the code, so our pointer is + // dangling. We MUST NOT free again. Suppress Drop + // and surface as `StateLost`. + // + // 2. 
`rc <= WHISPERCPP_ERR_BAD_ALLOC` — the shim caught + // a C++ exception (`std::bad_alloc`, + // `std::system_error`, other `std::exception`, or + // unknown). Upstream did NOT call + // `whisper_free_state`. Earlier rounds intentionally + // leaked the state (~360 MB on `large-v3-turbo`) + // because the multi-decoder rebuild path could leave + // `kv_self.buffer` freed-but-not-nulled — a second + // free would crash. The patch closed that + // door by making `whisper_kv_cache_free` idempotent + // (it now nulls `cache.buffer` after free; re-call + // short-circuits via `ggml_backend_buffer_free`'s own + // null guard at `ggml-backend.cpp:107-109`). Every + // other state member (`std::vector<...>` fields, + // backend lists, sched slots, `whisper_batch`) either + // self-destructs or is freed exactly once by + // `whisper_free_state`. So calling + // `whisper_free_state` on a state that the shim + // rescued from a throw is safe — and avoids the + // leak. + // + // 3. any other non-zero `rc` — recoverable upstream failure + // (encode/decode failure, etc.). State is intact; + // return `Full { code }` and keep the pointer alive + // so future calls can reuse the state. + if rc == -7 { + // Upstream already freed; suppress our Drop. + self.ptr = None; + // Also poison the parent Context so + // subsequent `create_state` calls fail with + // `ContextPoisoned`. The native state is gone, but + // the failure mode (multi-decoder OOM) is a strong + // signal of resource pressure on this Context. + self.ctx.mark_lost(); + return Err(WhisperError::StateLost { code: rc }); + } + if rc <= sys::WHISPERCPP_ERR_BAD_ALLOC { + // Review finding 2: poison the Context BEFORE + // touching the native state. `Context::create_state` + // does not take `full_lock`, so without this ordering + // another thread could observe `lost == false` during + // the `whisper_free_state` window, allocate a fresh + // State, and publish it before our `mark_lost` ran. 
+ // Releasing the poison flag first closes the race: + // any concurrent `create_state` either sees `lost == + // true` (returns `ContextPoisoned`) or completes + // before we reach this branch. + self.ctx.mark_lost(); + + // Shim caught a C++ throw. Free the native state + // explicitly. `Option::take` clears + // `self.ptr` so Drop becomes a no-op. + if let Some(p) = self.ptr.take() { + // SAFETY: `p` was returned by `whisper_init_state` + // and held exclusively by `self`; the shim's catch + // means `whisper_full_with_state` did NOT call + // `whisper_free_state`, so `p` is still owned by us. + // The idempotent kv_cache_free patch makes + // this call safe even when the throw left + // `kv_self.buffer` released. + unsafe { sys::whisper_free_state(p.as_ptr()) }; + } + return Err(WhisperError::StateLost { code: rc }); + } + Err(WhisperError::Full { code: rc }) + } + + /// Number of segments produced by the most recent + /// [`State::full`] call. Returns `0` for a poisoned state + /// (see [`Self::is_poisoned`]). + pub fn n_segments(&self) -> i32 { + let Some(state) = self.raw() else { return 0 }; + // SAFETY: state non-null; pure read. + unsafe { sys::whisper_full_n_segments_from_state(state) } + } + + /// Borrow segment `idx` (0-indexed). Returns `None` for a + /// poisoned state, or when `idx` is out of range. + pub fn segment(&self, idx: i32) -> Option> { + let state_ptr = self.ptr?; + if idx < 0 || idx >= self.n_segments() { + return None; + } + Some(Segment { + state: state_ptr, + idx, + _marker: core::marker::PhantomData, + }) + } + + /// Detected (or forced) language for the most recent + /// [`State::full`] call. + /// + /// Returns `None` when: + /// * the state is poisoned (see [`Self::is_poisoned`]); + /// * whisper.cpp set the internal lang id to `-1` (no + /// detection ran, no hint set); + /// * the id is out of whisper.cpp's published table; or + /// * the table entry is not valid UTF-8 (corrupt build). 
+ /// + /// Returns the strongly-typed [`Lang`] (canonicalised + /// through `Lang::from_iso639_1`) so callers don't pattern- + /// match on raw ISO strings. Known whisper.cpp codes round- + /// trip to their named variant; unknown codes land in + /// `Lang::Other` with the lowercase ISO string preserved. + pub fn detected_lang(&self) -> Option { + let state = self.raw()?; + // SAFETY: state non-null; pure read. + let id = unsafe { sys::whisper_full_lang_id_from_state(state) }; + if id < 0 { + return None; + } + // SAFETY: whisper_lang_str is a pure C accessor returning + // a pointer into a static const-table baked into + // libwhisper. The returned slice lives forever. + let raw = unsafe { sys::whisper_lang_str(id) }; + if raw.is_null() { + return None; + } + // SAFETY: NUL-terminated; static lifetime per whisper.cpp. + let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() }; + let code = str::from_utf8(bytes).ok()?; + Some(Lang::from_iso639_1(code)) + } +} + +impl Drop for State { + fn drop(&mut self) { + // Poisoned state (`self.ptr` is `None`) — the native state was + // already freed, by whisper.cpp or by `State::full`. We + // MUST NOT call `whisper_free_state` again. + if let Some(p) = self.ptr.take() { + // SAFETY: ptr was non-null and produced by + // whisper_init_state; never freed by anyone else + // because `is_poisoned` was false. + unsafe { sys::whisper_free_state(p.as_ptr()) } + } + } +} + +/// Borrowed view of one segment. +/// +/// Reaches into the `State` lazily — calling [`Segment::text`] +/// performs an FFI call each time. That matches whisper.cpp's +/// own model: segments are addressed by index, not pre-extracted. +#[derive(Clone, Copy)] +pub struct Segment<'a> { + state: NonNull, + idx: i32, + _marker: core::marker::PhantomData<&'a ()>, +} + +impl<'a> Segment<'a> { + /// Start time, in centiseconds (whisper.cpp's native unit). + /// Multiply by 0.01 for seconds. + pub fn t0(&self) -> i64 { + // SAFETY: state pointer invariant; idx is in-range (we + // checked at construction in `State::segment`). 
+ unsafe { sys::whisper_full_get_segment_t0_from_state(self.state.as_ptr(), self.idx) } + } + + /// End time, in centiseconds. + pub fn t1(&self) -> i64 { + // SAFETY: see `t0`. + unsafe { sys::whisper_full_get_segment_t1_from_state(self.state.as_ptr(), self.idx) } + } + + /// Decoded text for this segment. Returned slice is valid + /// for the `'a` borrow this `Segment` carries — whisper.cpp owns the buffer. + /// Returns `Ok("")` when whisper.cpp hands back a null pointer + /// and an error when the bytes are not valid UTF-8. + pub fn text(&self) -> WhisperResult<&'a str> { + // SAFETY: idx in-range; whisper_full_get_segment_text returns + // a pointer into the state's owned buffer; we do not store + // it past the returned &str's lifetime. + let raw = + unsafe { sys::whisper_full_get_segment_text_from_state(self.state.as_ptr(), self.idx) }; + if raw.is_null() { + return Ok(""); + } + // SAFETY: whisper.cpp guarantees NUL-terminated UTF-8 text + // for any valid model vocabulary. + let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() }; + str::from_utf8(bytes).map_err(WhisperError::from) + } + + /// `no_speech_prob` for this segment — whisper.cpp's gate for + /// the silent-segment shortcut. Higher = more confident the + /// segment is silence. + pub fn no_speech_prob(&self) -> f32 { + // SAFETY: idx in-range; pure read. + unsafe { + sys::whisper_full_get_segment_no_speech_prob_from_state(self.state.as_ptr(), self.idx) + } + } + + /// Number of tokens decoded inside this segment. + pub fn n_tokens(&self) -> i32 { + // SAFETY: idx in-range; pure read. + unsafe { sys::whisper_full_n_tokens_from_state(self.state.as_ptr(), self.idx) } + } + + /// Copy out token `tok_idx` of this segment. Returns `None` if + /// `tok_idx` is out of range. + pub fn token(&self, tok_idx: i32) -> Option { + if tok_idx < 0 || tok_idx >= self.n_tokens() { + return None; + } + // SAFETY: indices in-range; whisper.cpp returns a value- + // typed `whisper_token_data`. We project into our private + // `Token` view via `Token::from_raw`. 
+ let raw = unsafe { + sys::whisper_full_get_token_data_from_state(self.state.as_ptr(), self.idx, tok_idx) + }; + Some(Token::from_raw(raw)) + } + + /// `true` if the next segment marks a speaker change + /// (whisper.cpp's tinydiarize / `--tdrz` mode). Always + /// `false` outside TDRZ-enabled checkpoints; exposed for + /// completeness so callers running TDRZ models don't have to + /// reach into raw FFI. + pub fn speaker_turn_next(&self) -> bool { + // SAFETY: idx in-range; pure read. + unsafe { + sys::whisper_full_get_segment_speaker_turn_next_from_state(self.state.as_ptr(), self.idx) + } + } +} + +/// Per-token data exposed by whisper.cpp. +/// +/// Read-only snapshot. All fields are private; access goes +/// through `const fn` accessors to keep the public surface +/// stable as `whisper_token_data` evolves upstream. +#[derive(Debug, Clone, Copy)] +pub struct Token { + id: i32, + p: f32, + plog: f32, + pt: f32, + ptsum: f32, + t0: i64, + t1: i64, + vlen: f32, +} + +impl Token { + /// Token id in the model vocabulary. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn id(&self) -> i32 { + self.id + } + + /// Probability of this token at decode time. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn p(&self) -> f32 { + self.p + } + + /// Log-probability (matches whisper.cpp's internal score). + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn plog(&self) -> f32 { + self.plog + } + + /// Timestamp probability if this token is a `<|t|>` marker. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn pt(&self) -> f32 { + self.pt + } + + /// Sum of all timestamp-token probabilities. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn ptsum(&self) -> f32 { + self.ptsum + } + + /// Start time of the token (centiseconds), copied from + /// `whisper_token_data::t0`. NOTE(review): upstream `whisper.h` + /// appears to keep DTW timestamps in a separate `t_dtw` field, + /// so this is presumably not DTW-derived — confirm. 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn t0(&self) -> i64 { + self.t0 + } + + /// End time of the token (centiseconds), copied from + /// `whisper_token_data::t1`; see [`Token::t0`]. 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn t1(&self) -> i64 { + self.t1 + } + + /// Value of `whisper_token_data::vlen`. NOTE(review): upstream + /// `whisper.h` appears to describe this as the voice length of + /// the token rather than a voice-activity score — confirm. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn vlen(&self) -> f32 { + self.vlen + } + + /// Internal constructor used by [`Segment::token`] when projecting + /// `whisper_token_data` into the safe view. + #[cfg_attr(not(tarpaulin), inline(always))] + pub(crate) const fn from_raw(raw: crate::sys::whisper_token_data) -> Self { + Self { + id: raw.id, + p: raw.p, + plog: raw.plog, + pt: raw.pt, + ptsum: raw.ptsum, + t0: raw.t0, + t1: raw.t1, + vlen: raw.vlen, + } + } +} diff --git a/whispercpp/src/sys.rs b/whispercpp/src/sys.rs new file mode 100644 index 0000000..e5abada --- /dev/null +++ b/whispercpp/src/sys.rs @@ -0,0 +1,13 @@ +//! Raw FFI bindings to whisper.cpp, sourced from the sibling +//! `whispercpp-sys` crate. +//! +//! This module is a thin re-export: `whispercpp-sys` owns the +//! build (cmake against the vendored `whisper.cpp/` submodule, +//! pinned to a patched fork branch) and the bindgen output. We +//! re-export here so `crate::sys::whisper_*` resolves through +//! the safe-wrapper crate without a path prefix change. +//! +//! All raw `unsafe` FFI declarations live below this re-export boundary. Safe +//! wrappers in `context.rs`, `state.rs`, `params.rs` are +//! responsible for upholding lifetime + aliasing invariants. +pub use whispercpp_sys::*;