diff --git a/README.md b/README.md index afbe4f2..26f70e5 100644 --- a/README.md +++ b/README.md @@ -32,12 +32,17 @@ Safe Rust bindings for [whisper.cpp][whisper-cpp] speech-to-text inference. - **Backend matrix.** Metal, CoreML, Vulkan, OpenCL, CUDA, ROCm (HIP), oneAPI (SYCL), Moore Threads (MUSA), OpenVINO, OpenBLAS — all opt-in via Cargo features. +- **DTW token timestamps.** Built-in token-level timing via DTW + over the configured alignment heads (`AlignmentHeadsPreset`), + with safe per-token availability through + `Token::t_dtw() -> Option`. See + [DTW timestamps](#dtw-timestamps). ## Installation ```toml [dependencies] -whispercpp = "0.1" +whispercpp = "0.2" ``` The default build is plain CPU. Opt into accelerators per-target: @@ -45,11 +50,11 @@ The default build is plain CPU. Opt into accelerators per-target: ```toml # macOS Apple Silicon [target.'cfg(all(target_os = "macos", target_arch = "aarch64"))'.dependencies] -whispercpp = { version = "0.1", features = ["metal", "coreml"] } +whispercpp = { version = "0.2", features = ["metal", "coreml"] } # Linux + NVIDIA [target.'cfg(all(target_os = "linux", target_arch = "x86_64"))'.dependencies] -whispercpp = { version = "0.1", features = ["cuda"] } +whispercpp = { version = "0.2", features = ["cuda"] } ``` ## Examples @@ -80,6 +85,70 @@ GPU backends require the corresponding vendor SDK (CUDA Toolkit, ROCm, oneAPI, etc.) installed at link time. CI exercises the bundled CPU path on Linux/macOS/Windows and Metal+CoreML on macOS. +## DTW timestamps + +Token-level timestamps via DTW over the decoder's +cross-attention weights. 
Enable at `Context` construction: + +```rust +use whispercpp::{Context, ContextParams, AlignmentHeadsPreset}; + +let ctx = Context::new( + "ggml-large-v3-turbo.bin", + ContextParams::new() + .with_use_gpu(true) + .with_dtw_token_timestamps(true) + .with_dtw_aheads_preset(AlignmentHeadsPreset::LargeV3Turbo), +)?; +``` + +Match `AlignmentHeadsPreset` to your model — the safe API +ships every standard checkpoint preset (`TinyEn` through +`LargeV3Turbo`). Mismatched presets produce noisy timings +without erroring; bound-checked by `required_dtw_mem_size_for` +and rejected at load if the model's `n_text_ctx` exceeds +`SUPPORTED_DTW_N_TEXT_CTX`. + +After `state.full(&params, &samples)`, read per-token DTW +timing as `Option<i64>` (centiseconds): + +```rust +for i in 0..state.n_segments() { + let seg = state.segment(i).unwrap(); + for j in 0..seg.n_tokens() { + let token = seg.token(j).unwrap(); + match token.t_dtw() { + Some(t) => println!("token={} t_dtw={:.2}s", + token.id(), t as f64 / 100.0), + None => /* DTW unavailable for this token */ (), + } + } +} +``` + +`None` covers four cases: DTW not enabled at construction, +non-text token (special / timestamp), per-segment DTW skip +because `Params::set_audio_ctx` was overridden too small, or +audio window too short for the median-filter pass. The +underlying C-side patch (`whispercpp-sys: dtw t_dtw sentinel +init`) initialises `t_dtw = -1` before every DTW pass so the +sentinel uniquely identifies "unavailable" — `Some(0)` is a +valid timestamp (token at audio offset 0), not the sentinel. + +Constraints (enforced at `Context::new`): + +| Constraint | What it does | +|---|---| +| `dtw + flash_attn` | Rejected. Whisper.cpp silently disables DTW under flash-attn; the wrapper refuses the combination explicitly. | +| `dtw + custom n_text_ctx > 448` | Rejected. The DTW scratch arena is sized for standard whisper checkpoints; non-standard models with larger text context would overflow it. 
| +| `dtw_mem_size` | Clamped to `[MIN_DTW_MEM_SIZE, MAX_DTW_MEM_SIZE]`, then raised to the per-preset minimum from `required_dtw_mem_size_for`. | + +Native abort paths inside the DTW helper +(allocation failures, invalid windows, decoder errors) are +all converted to `WhisperError::StateLost` via the existing +exception shim — no `abort()` is reachable from safe Rust +through this surface. + ## Memory safety `whisper.cpp` is a binary parser of attacker-controllable model files diff --git a/whispercpp-sys/Cargo.toml b/whispercpp-sys/Cargo.toml index eedc6e2..49872be 100644 --- a/whispercpp-sys/Cargo.toml +++ b/whispercpp-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "whispercpp-sys" -version = "0.1.0" +version = "0.2.0" edition.workspace = true rust-version.workspace = true license.workspace = true diff --git a/whispercpp-sys/build.rs b/whispercpp-sys/build.rs index 37a3088..d06f112 100644 --- a/whispercpp-sys/build.rs +++ b/whispercpp-sys/build.rs @@ -116,51 +116,142 @@ fn bundled_build() { /// upstream AND someone manually replacing the submodule with /// a different tree. fn verify_patched_source(whisper_src: &Path) { - let target = whisper_src.join("src").join("whisper.cpp"); - let body = match std::fs::read_to_string(&target) { - Ok(b) => b, - Err(e) => panic!( - "whispercpp-sys: failed to read {} for patch verification: {e}", - target.display() - ), - }; - // Sentinels chosen from the highest-leverage patches — // the ones whose absence would re-introduce the - // double-free / null-deref / leak hazards the Rust - // wrapper assumes are closed. 
- const REQUIRED_MARKERS: &[&str] = &[ - "whispercpp-sys: kv_cache_free idempotent fix", - "whispercpp-sys: read_safe zero-init", - "whispercpp-sys: init_state RAII entry", - "whispercpp-sys: init_context RAII entry", - "whispercpp-sys: tensor header validation (model_load)", - "whispercpp-sys: ggml_log_set once-per-process", - "whispercpp-sys: hparams validation", - "whispercpp-sys: lang_str null guard", - "whispercpp-sys: special-token bounds check", - "whispercpp-sys: path_model assignment guard", - "whispercpp-sys: sched abort callback wiring", - "whispercpp-sys: vad_init RAII guard", + // double-free / null-deref / leak / native-abort hazards + // the Rust wrapper assumes are closed. Each entry is + // `(file_relative_to_whisper_src, expected_marker)`; the + // build hard-fails if any are absent. + // + // We split across both `src/whisper.cpp` and + // `ggml/src/ggml.c` because some safety patches sit in + // each. The ggml patch (OOM-safe `ggml_init`) is what + // turns the DTW scratch-allocation OOM path from + // `abort()`-uncatchable into a `WhisperError::StateLost` + // recovery — without it the wrapper's `dtw scratch + // alloc-fail throws` patch is dead code. 
+ const REQUIRED_MARKERS: &[(&str, &str)] = &[ + ( + "src/whisper.cpp", + "whispercpp-sys: kv_cache_free idempotent fix", + ), + ("src/whisper.cpp", "whispercpp-sys: read_safe zero-init"), + ("src/whisper.cpp", "whispercpp-sys: init_state RAII entry"), + ("src/whisper.cpp", "whispercpp-sys: init_context RAII entry"), + ( + "src/whisper.cpp", + "whispercpp-sys: tensor header validation (model_load)", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: ggml_log_set once-per-process", + ), + ("src/whisper.cpp", "whispercpp-sys: hparams validation"), + ("src/whisper.cpp", "whispercpp-sys: lang_str null guard"), + ( + "src/whisper.cpp", + "whispercpp-sys: special-token bounds check", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: path_model assignment guard", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: sched abort callback wiring", + ), + ("src/whisper.cpp", "whispercpp-sys: vad_init RAII guard"), + ("src/whisper.cpp", "whispercpp-sys: dtw scratch RAII guard"), + ( + "src/whisper.cpp", + "whispercpp-sys: dtw scratch alloc-fail throws", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: dtw token assignment bounded", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: dtw short-window medfilt clamp", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: dtw audio_ctx override guard", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: ggml_init throw-on-null wrapper", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: dtw decode failure throws", + ), + ("src/whisper.cpp", "whispercpp-sys: kv buffer null throws"), + ( + "src/whisper.cpp", + "whispercpp-sys: dtw backtrace impossible-case throws", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: dtw aheads_cross_QKs invariants throw", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: token_to_str sparse-vocab no-throw", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: hparams head divisibility check", + ), + ( + "src/whisper.cpp", + "whispercpp-sys: dtw backend compute throws", + ), + ("src/whisper.cpp", "whispercpp-sys: dtw t_dtw 
sentinel init"), + ( + "ggml/src/ggml.c", + "whispercpp-sys: ggml_init OOM-safe context alloc", + ), ]; - let missing: Vec<&str> = REQUIRED_MARKERS - .iter() - .copied() - .filter(|m| !body.contains(m)) - .collect(); + // Read each referenced file once, then check every + // marker that points at it. Group markers by file so we + // don't re-read the same source on every iteration. + use std::collections::HashMap; + let mut by_file: HashMap<&str, Vec<&str>> = HashMap::new(); + for (file, marker) in REQUIRED_MARKERS { + by_file.entry(*file).or_default().push(*marker); + } + + let mut missing: Vec<(&str, &str)> = Vec::new(); + for (rel, markers) in &by_file { + let target = whisper_src.join(rel); + let body = match std::fs::read_to_string(&target) { + Ok(b) => b, + Err(e) => panic!( + "whispercpp-sys: failed to read {} for patch verification: {e}", + target.display() + ), + }; + for m in markers { + if !body.contains(*m) { + missing.push((*rel, *m)); + } + } + } if !missing.is_empty() { panic!( - "whispercpp-sys: the linked whisper.cpp source at {} is missing the rust-branch patches \ + "whispercpp-sys: the linked whisper.cpp source under {} is missing rust-branch patches \ (required marker{} absent: {:?}).\n\n\ The Rust safety surface depends on these patches; building against unpatched upstream \ - reintroduces multi-decoder double-free / use-after-free / null-deref classes.\n\n\ + reintroduces multi-decoder double-free / use-after-free / null-deref / native-abort \ + classes.\n\n\ Fix: ensure the submodule tracks `Findit-AI/whisper.cpp` branch `rust`. Run\n \ git submodule update --init --recursive\n\ from the repo root. 
If you intentionally pointed at a different source, add equivalent \ patches and the matching marker comments before retrying.", - target.display(), + whisper_src.display(), if missing.len() == 1 { "" } else { "s" }, missing, ); diff --git a/whispercpp-sys/whisper.cpp b/whispercpp-sys/whisper.cpp index 9c4881d..8932cae 160000 --- a/whispercpp-sys/whisper.cpp +++ b/whispercpp-sys/whisper.cpp @@ -1 +1 @@ -Subproject commit 9c4881d8f5cd2224224e46ec8d012cce348be39d +Subproject commit 8932cae83df538fd40bbc045b6e57d14986263fb diff --git a/whispercpp/Cargo.toml b/whispercpp/Cargo.toml index 06d2c45..9988ce4 100644 --- a/whispercpp/Cargo.toml +++ b/whispercpp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "whispercpp" -version = "0.1.0" +version = "0.2.0" edition.workspace = true rust-version.workspace = true license.workspace = true @@ -76,7 +76,7 @@ openblas = ["whispercpp-sys/openblas"] # OpenBLAS # `../whispercpp-sys/`. All `unsafe extern "C"` declarations # live there; this crate only ever calls them behind safe # wrappers. -whispercpp-sys = { version = "0.1", path = "../whispercpp-sys", default-features = false } +whispercpp-sys = { version = "0.2", path = "../whispercpp-sys", default-features = false } # Public error type. `thiserror` keeps things light. thiserror = { version = "2", default-features = false } # Inline small strings (≤23 bytes) for error payloads — paths, diff --git a/whispercpp/TODO.md b/whispercpp/TODO.md index 63a03f3..a561ee2 100644 --- a/whispercpp/TODO.md +++ b/whispercpp/TODO.md @@ -100,18 +100,6 @@ requires more design than a 1:1 port. Symbols: `whisper_set_log_callback`, `set_debug_mode`, `whisper_log_callback`. -### DTW token timestamps - -Whispery uses wav2vec2 forced alignment for word-level timing. -whisper.cpp's DTW path is a parallel mechanism with its own -configuration (`dtw_aheads`, `dtw_n_top`, `dtw_mem_size`). Wrapping -it would invite confusion about which timestamping path is -authoritative. 
- -Symbols: `whisper_full_params::dtw_token_timestamps` (true at -construction, but `Params::set_dtw_*` and `dtw_aheads` array are -not exposed), `whisper_aheads`, `whisper_full_get_token_dtw_t0_*`. - ### Buffer-load constructors We support `Context::new(path, params)` only. Loading from an diff --git a/whispercpp/src/context.rs b/whispercpp/src/context.rs index 9601892..2d2aaa0 100644 --- a/whispercpp/src/context.rs +++ b/whispercpp/src/context.rs @@ -53,6 +53,326 @@ pub(crate) fn init_lock() -> MutexGuard<'static, ()> { LOCK.lock().unwrap_or_else(|e| e.into_inner()) } +/// Default DTW working-memory budget (128 MiB). +/// +/// Forwarded into [`ContextParams::dtw_mem_size`] when callers +/// don't override it. Adequate for the small-head presets +/// (`Tiny*`, `Base*`, `Small`, `Medium`, `LargeV1`, `LargeV3`, +/// `LargeV3Turbo` — all ≤ 10 alignment heads). For higher-head +/// presets (`SmallEn` 19 heads, `MediumEn` 18 heads, `LargeV2` +/// 23 heads), [`Context::new`] silently raises the value to +/// the per-preset requirement returned by +/// [`required_dtw_mem_size_for`]; callers don't need to do +/// the math themselves. +/// +/// Whisper.cpp's struct comment marks `dtw_mem_size` as +/// "TODO: remove" — the buffer is expected to migrate behind +/// the encoder's standard arena. Until then we forward it +/// faithfully; the safe API stays compatible when upstream +/// drops the field. +pub const DEFAULT_DTW_MEM_SIZE: usize = 128 * 1024 * 1024; + +/// Absolute lower bound applied by +/// [`ContextParams::with_dtw_mem_size`] (and by +/// [`ContextParams::new`]). +/// +/// Whisper.cpp's +/// `whisper_exp_compute_token_level_timestamps_dtw` allocates +/// a scratch `ggml_context` sized by `dtw_mem_size`. 
The DTW +/// pipeline then materialises three live `n_tokens × +/// n_audio_tokens × n_heads × f32` tensors (the working +/// cross-attention tensor, the `ggml_norm` output, and the +/// `ggml_map_custom1` median-filter output) plus a small +/// backtrace lattice. The ggml context header alone needs a +/// few MiB before any tensor lands — anything below that +/// floor makes `ggml_init` return NULL and the next access +/// fault. +/// +/// `ggml_init` returns NULL when the requested arena is too +/// small, and `ggml_new_tensor_3d` aborts (via `GGML_ASSERT`) +/// when the arena cannot fit a tensor. Both shapes terminate +/// the process from inside whisper.cpp without giving the +/// `whispercpp_full_with_state` exception shim a chance to +/// catch — `GGML_ASSERT` calls `abort()`, and the NULL deref +/// is a fatal signal. **Both are reachable from safe Rust** +/// if the budget is unconstrained. +/// +/// Floor at 128 MiB covers the smallest preset's realistic +/// peak with comfortable headroom. Higher-head presets +/// require more; [`Context::new`] enforces the per-preset +/// minimum on top of this absolute floor — see +/// [`required_dtw_mem_size_for`] for the formula. +pub const MIN_DTW_MEM_SIZE: usize = DEFAULT_DTW_MEM_SIZE; + +/// Upper bound applied by [`ContextParams::with_dtw_mem_size`] +/// (and by [`ContextParams::new`]). +/// +/// `ggml_init` mallocs `dtw_mem_size + WHISPER_GGML_OBJECT_SIZE` +/// internally; passing `usize::MAX` overflows that addition and +/// drives `ggml_init` to NULL on the malloc step, with the +/// same null-deref / GGML_ASSERT consequences as the lower- +/// bound failure shape (see [`MIN_DTW_MEM_SIZE`]). 
+/// +/// The cap is **target-pointer-width-dependent**: +/// +/// * 64-bit (`target_pointer_width = "64"`): 4 GiB — roughly +/// 15× the realistic worst case +/// ([`required_dtw_mem_size_for(LargeV2)`][required_dtw_mem_size_for] +/// ≈ 265 MiB), so a `usize::MAX` slip collapses to a +/// large-but-allocatable value rather than an +/// overflow-induced abort. +/// * 32-bit (`target_pointer_width = "32"`): 1 GiB. +/// `4 * 1024 * 1024 * 1024 = 2^32` exceeds `usize::MAX = +/// 2^32 - 1` on 32-bit targets, which would make the crate +/// fail to compile there. 1 GiB is still ~3.9× the +/// `LargeV2` per-preset minimum and well below +/// `usize::MAX`, so the safety property (saturate above +/// the realistic worst case to dodge `ggml_init` overflow) +/// is preserved. +/// * 16-bit (`target_pointer_width = "16"`): same value as +/// 32-bit; falls back to the smaller cap. Whisper.cpp +/// does not realistically run on 16-bit targets. +#[cfg(target_pointer_width = "64")] +pub const MAX_DTW_MEM_SIZE: usize = 4 * 1024 * 1024 * 1024; + +/// 32-bit / 16-bit ceiling — see [`MAX_DTW_MEM_SIZE`]'s +/// docstring on the 64-bit variant for the full explanation. +#[cfg(not(target_pointer_width = "64"))] +pub const MAX_DTW_MEM_SIZE: usize = 1024 * 1024 * 1024; + +/// Clamp a DTW memory budget to `[MIN_DTW_MEM_SIZE, +/// MAX_DTW_MEM_SIZE]`. `const fn` so it composes inside +/// [`ContextParams::new`]'s defaults. +/// +/// `0` and other below-floor values rise to +/// [`MIN_DTW_MEM_SIZE`]; `usize::MAX` and other ceiling-busting +/// values fall to [`MAX_DTW_MEM_SIZE`]. Both ends close +/// crash-from-safe-Rust paths inside whisper.cpp's DTW +/// allocator (see those constants' docs). +#[cfg_attr(not(tarpaulin), inline(always))] +const fn clamp_dtw_mem_size(n: usize) -> usize { + if n < MIN_DTW_MEM_SIZE { + MIN_DTW_MEM_SIZE + } else if n > MAX_DTW_MEM_SIZE { + MAX_DTW_MEM_SIZE + } else { + n + } +} + +/// Alignment-head count for a given preset. 
+/// +/// Verified against `whisper.cpp/src/whisper.cpp:399-410` +/// (the `g_aheads_*` static arrays) and `:412-424` (the +/// `g_aheads` map that pairs each preset with its head count). +/// `None` and any never-mapped variant return `0`. +#[cfg_attr(not(tarpaulin), inline(always))] +const fn alignment_head_count(preset: AlignmentHeadsPreset) -> usize { + match preset { + AlignmentHeadsPreset::None => 0, + AlignmentHeadsPreset::TinyEn => 8, + AlignmentHeadsPreset::Tiny => 6, + AlignmentHeadsPreset::BaseEn => 5, + AlignmentHeadsPreset::Base => 8, + AlignmentHeadsPreset::SmallEn => 19, + AlignmentHeadsPreset::Small => 10, + AlignmentHeadsPreset::MediumEn => 18, + AlignmentHeadsPreset::Medium => 6, + AlignmentHeadsPreset::LargeV1 => 9, + AlignmentHeadsPreset::LargeV2 => 23, + AlignmentHeadsPreset::LargeV3 => 10, + AlignmentHeadsPreset::LargeV3Turbo => 6, + } +} + +/// Largest `n_text_ctx` (decoder text-context window) the +/// safe DTW wrapper budgets for. +/// +/// Every standard whisper checkpoint — `tiny.en` through +/// `large-v3-turbo` — has `n_text_ctx = 448`, so this matches +/// the universe of officially-released models. Some +/// fine-tuned / extended-context checkpoints carry larger +/// values (the bundled GGUF loader accepts `n_text_ctx` up +/// to several thousand), and the DTW helper sizes its +/// working tensor from the actual decoder output rather than +/// from this constant. To prevent a non-standard model from +/// silently overflowing the [`required_dtw_mem_size_for`] +/// budget and tripping `GGML_ASSERT` inside +/// `ggml_new_tensor_3d` during decode, [`Context::new`] +/// reads the loaded model's `n_text_ctx` after init and +/// refuses to publish a `Context` that has DTW enabled +/// together with `n_text_ctx > SUPPORTED_DTW_N_TEXT_CTX`. +/// Affected callers can either: +/// +/// 1. Disable DTW for that model — call +/// [`ContextParams::with_dtw_token_timestamps`] with +/// `false`. The rest of the API stays available. +/// 2. 
Use a standard checkpoint. +/// +/// Pre-allocating for a higher upper bound (e.g. 2048) was +/// considered and rejected: it would force a ~3-4× larger +/// DTW arena (≥ 1.27 GiB on `LargeV2`) on every DTW-enabled +/// context, including the standard-checkpoint case that +/// dominates real usage. +pub const SUPPORTED_DTW_N_TEXT_CTX: i32 = 448; + +/// Worst-case DTW scratch requirement for a given preset. +/// +/// Whisper.cpp's +/// `whisper_exp_compute_token_level_timestamps_dtw` materialises +/// up to three live `n_tokens × n_audio_tokens × n_heads × f32` +/// tensors during the DTW pipeline (the working cross-attention +/// tensor, the `ggml_norm` output, and the `ggml_map_custom1` +/// median-filter output). Worst-case bounds: +/// +/// * `n_tokens` ≤ [`SUPPORTED_DTW_N_TEXT_CTX`] — whisper's +/// `n_text_ctx` for every standard checkpoint. Non-standard +/// checkpoints with larger values are rejected by +/// [`Context::new`] when DTW is on; see that constant's +/// docs for the contract. +/// * `n_audio_tokens` = `n_frames / 2`, with `n_frames` capped +/// at `WHISPER_CHUNK_SIZE * 100 = 3000` (centiseconds for a +/// 30 s chunk), giving a max of 1500. +/// * `n_heads` from `alignment_head_count(preset)` — +/// 23 for `LargeV2`, 19 for `SmallEn`, 18 for `MediumEn`, +/// ≤ 10 for the rest. +/// +/// Per-tensor: `SUPPORTED_DTW_N_TEXT_CTX × 1500 × n_heads × +/// 4 bytes`. With three live tensors plus a 50% safety +/// margin for the backtrace lattice / tensor metadata / +/// backend compute scratch, the minimum scales linearly in +/// `n_heads`. For presets whose computed minimum falls below +/// [`MIN_DTW_MEM_SIZE`], the floor wins (the small-preset +/// case where ggml context overhead dominates). +/// +/// Returns `0` for [`AlignmentHeadsPreset::None`] — DTW is +/// disabled and no scratch is needed. 
+/// +/// Earlier versions of this wrapper derived the floor from +/// the wrong worst-case shapes (`n_audio_tokens ≤ 750` instead +/// of 1500, and an underestimate of the per-preset head +/// counts). The 128 MiB floor those numbers produced was too +/// small for `LargeV2` / `SmallEn` / `MediumEn` — `ggml_init` +/// could return NULL or `ggml_new_tensor_3d` could `GGML_ASSERT` +/// during decode. This function fixes the analysis by reading +/// head counts straight from +/// `whisper.cpp:399-424` and using whisper.cpp's actual +/// dimension caps. +#[cfg_attr(not(tarpaulin), inline(always))] +pub const fn required_dtw_mem_size_for(preset: AlignmentHeadsPreset) -> usize { + let n_heads = alignment_head_count(preset); + if n_heads == 0 { + return 0; + } + // Worst-case dimensions baked from whisper.cpp: + // - n_tokens (text context cap): SUPPORTED_DTW_N_TEXT_CTX + // - n_audio_tokens: WHISPER_CHUNK_SIZE * 100 / 2 = 1500 + // - bytes per element (f32): 4 + let per_tensor = (SUPPORTED_DTW_N_TEXT_CTX as usize) * 1500 * n_heads * 4; + // Three live large tensors during the DTW pipeline + 50% + // safety margin for backtrace state, tensor metadata, and + // ggml backend scratch. + let with_safety = (per_tensor * 3) * 3 / 2; + // Floor at MIN_DTW_MEM_SIZE — even small presets need at + // least this much for ggml context overhead. + if with_safety < MIN_DTW_MEM_SIZE { + MIN_DTW_MEM_SIZE + } else if with_safety > MAX_DTW_MEM_SIZE { + MAX_DTW_MEM_SIZE + } else { + with_safety + } +} + +/// Which set of cross-attention heads whisper.cpp samples for +/// the DTW backtrace. Mirrors `whisper_alignment_heads_preset`. +/// +/// Each shipping whisper checkpoint has its own set of "alignment +/// heads" — the decoder heads whose attention patterns correlate +/// best with the underlying acoustic timing. The presets below +/// pick those known-good heads; using the wrong preset for a +/// given checkpoint produces noisy timestamps. Match the preset +/// to the model file. 
+/// +/// # Why `NTopMost` and `Custom` are not exposed +/// +/// `WHISPER_AHEADS_N_TOP_MOST` is intentionally omitted because +/// the resulting alignment-head count is `n_top × n_text_head`, +/// which on a large model (32 layers × 20 heads) reaches 640 +/// heads — pushing the DTW working tensor (`n_tokens × +/// n_audio_tokens × n_heads × f32`) to ~1.6 GiB and overflowing +/// even [`MAX_DTW_MEM_SIZE`] under realistic decoder context. +/// `ggml_new_tensor_3d` aborts via `GGML_ASSERT` when the arena +/// cannot fit, terminating the process from inside whisper.cpp +/// before the exception shim can catch. Exposing the preset +/// requires the wrapper to compute scratch size from the +/// loaded model's head count first; that's not in this iteration. +/// +/// `WHISPER_AHEADS_CUSTOM` is also omitted: the C variant +/// requires a pointer to a caller-owned `whisper_ahead` array, +/// which would force `ContextParams` to own a `Vec` and lose +/// the `Copy` derive — and shares the same scratch-size +/// validation gap as `N_TOP_MOST`. +/// +/// Result: only the validated per-checkpoint presets ship, each +/// with a known-bounded alignment-head count whose scratch +/// requirement comfortably fits [`DEFAULT_DTW_MEM_SIZE`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AlignmentHeadsPreset { + /// No DTW alignment heads — disables DTW even when + /// [`ContextParams::dtw_token_timestamps`] is `true`. + None, + /// `tiny.en` — English-only. + TinyEn, + /// `tiny` — multilingual. + Tiny, + /// `base.en` — English-only. + BaseEn, + /// `base` — multilingual. + Base, + /// `small.en` — English-only. + SmallEn, + /// `small` — multilingual. + Small, + /// `medium.en` — English-only. + MediumEn, + /// `medium` — multilingual. + Medium, + /// `large-v1`. + LargeV1, + /// `large-v2`. + LargeV2, + /// `large-v3`. + LargeV3, + /// `large-v3-turbo` — the distilled-decoder variant used by + /// whispery in production. 
+ LargeV3Turbo, +} + +impl AlignmentHeadsPreset { + /// Map to the C enum value bindgen produced. `const fn` so + /// the conversion participates in `ContextParams`'s + /// `with_*` chain. + #[cfg_attr(not(tarpaulin), inline(always))] + const fn to_raw(self) -> sys::whisper_alignment_heads_preset { + match self { + Self::None => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE, + Self::TinyEn => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN, + Self::Tiny => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY, + Self::BaseEn => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN, + Self::Base => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE, + Self::SmallEn => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN, + Self::Small => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL, + Self::MediumEn => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN, + Self::Medium => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM, + Self::LargeV1 => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1, + Self::LargeV2 => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2, + Self::LargeV3 => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3, + Self::LargeV3Turbo => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3_TURBO, + } + } +} + /// Knobs forwarded to `whisper_context_default_params` before /// loading. Mirrors the subset of `whisper_context_params` whispery /// uses today. @@ -61,22 +381,43 @@ pub(crate) fn init_lock() -> MutexGuard<'static, ()> { /// accessors and `with_*` builder methods so the type's invariants /// stay encapsulated and the public surface evolves /// independently of the underlying C struct. +/// +/// # DTW (token-level alignment via cross-attention) +/// +/// DTW is enabled at MODEL LOAD time, not per-decode. 
Whisper.cpp +/// builds a slightly different decoder graph when DTW is on (the +/// alignment heads' attention weights need to be exposed to the +/// post-decode DTW pass), so the choice has to be made before +/// [`Context::new`] runs. +/// +/// Once enabled, every [`crate::State::full`] call against the +/// resulting context populates [`crate::Token::t_dtw`] alongside +/// the standard `t0`/`t1` timestamp-token timings. The DTW +/// timestamp is independently derived from the cross-attention +/// pattern and is generally more robust to long silences and +/// repeated tokens than the timestamp-token path. #[derive(Debug, Clone, Copy)] pub struct ContextParams { use_gpu: bool, gpu_device: i32, flash_attn: bool, + dtw_token_timestamps: bool, + dtw_aheads_preset: AlignmentHeadsPreset, + dtw_mem_size: usize, } impl ContextParams { /// Defaults: GPU on (Metal/CUDA where compiled in), device 0, - /// flash-attn off. + /// flash-attn off, DTW off. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn new() -> Self { Self { use_gpu: true, gpu_device: 0, flash_attn: false, + dtw_token_timestamps: false, + dtw_aheads_preset: AlignmentHeadsPreset::None, + dtw_mem_size: DEFAULT_DTW_MEM_SIZE, } } @@ -123,6 +464,76 @@ impl ContextParams { self.flash_attn = on; self } + + // ── DTW (token-level alignment via cross-attention) ────────── + + /// Whether the loaded context will compute DTW per-token + /// timestamps during decode. + /// + /// When `true`, the decoder graph is built to expose + /// cross-attention weights from the heads selected by + /// [`Self::dtw_aheads_preset`], and each + /// [`crate::Token::t_dtw`] is populated after decode. Costs + /// ~5–15% extra decode time and a one-time + /// [`Self::dtw_mem_size`] allocation; eliminates a separate + /// forced-alignment pass for callers that only need + /// approximate per-token timing. 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn dtw_token_timestamps(&self) -> bool { + self.dtw_token_timestamps + } + + /// Chained setter for [`Self::dtw_token_timestamps`]. + /// + /// When enabling DTW, also pick a matching preset via + /// [`Self::with_dtw_aheads_preset`] — leaving the preset on + /// [`AlignmentHeadsPreset::None`] disables DTW even when this + /// flag is `true`. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_dtw_token_timestamps(mut self, on: bool) -> Self { + self.dtw_token_timestamps = on; + self + } + + /// Which alignment-heads preset DTW samples. Each shipping + /// whisper checkpoint has its own validated preset; mismatched + /// presets produce noisy timestamps without erroring. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn dtw_aheads_preset(&self) -> AlignmentHeadsPreset { + self.dtw_aheads_preset + } + + /// Chained setter for [`Self::dtw_aheads_preset`]. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_dtw_aheads_preset(mut self, preset: AlignmentHeadsPreset) -> Self { + self.dtw_aheads_preset = preset; + self + } + + /// Working-memory budget (in bytes) for the DTW backtrace. + /// Default [`DEFAULT_DTW_MEM_SIZE`] (128 MiB). + /// + /// Whisper.cpp's struct comment flags this field as + /// "TODO: remove" — the buffer is expected to migrate behind + /// the encoder's standard arena. The Rust API will keep the + /// setter when that lands so callers don't break; the value + /// will simply become a no-op. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn dtw_mem_size(&self) -> usize { + self.dtw_mem_size + } + + /// Chained setter for [`Self::dtw_mem_size`]. + /// + /// Clamped to `[MIN_DTW_MEM_SIZE, MAX_DTW_MEM_SIZE]`. Both + /// ends close native-code abort paths reachable from safe + /// Rust through whisper.cpp's DTW arena allocator — see the + /// constants' docs for the full failure analysis. 
+ #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_dtw_mem_size(mut self, n: usize) -> Self { + self.dtw_mem_size = clamp_dtw_mem_size(n); + self + } } impl Default for ContextParams { @@ -198,6 +609,95 @@ impl Context { cparams.use_gpu = params.use_gpu(); cparams.gpu_device = params.gpu_device(); cparams.flash_attn = params.flash_attn(); + // DTW (token-level alignment via cross-attention). The + // `dtw_aheads` pointer-and-length pair stays at whatever + // `whisper_context_default_params` sets it to (currently + // `{ n_heads: 0, heads: NULL }`, which whisper.cpp reads + // as "no custom heads — fall back to the preset"). + // `AlignmentHeadsPreset::Custom` is not exposed at the + // safe-API level today; if a downstream caller needs + // hand-tuned heads, that's the field to thread through. + // + // Two safety conversions before forwarding: + // + // 1. `dtw_token_timestamps && preset == None` is a + // misconfiguration whisper.cpp aborts on + // (`WHISPER_ASSERT(ctx->params.dtw_aheads_preset != + // WHISPER_AHEADS_NONE)` in + // `whisper_exp_compute_token_level_timestamps_dtw`). + // Reachable from safe Rust because `with_dtw_*` + // setters compose independently. Coerce to "DTW off" + // instead of letting the abort cross the FFI: no + // preset means there's no useful DTW work to do + // anyway. + // + // 2. `dtw_mem_size` clamps to + // `[MIN_DTW_MEM_SIZE, MAX_DTW_MEM_SIZE]`. The setter + // already clamps, but we re-clamp here to defend + // against `ContextParams` constructed via field-init + // syntax in some future internal path (or callers + // poking through `Default + struct update`). Cheap. + let dtw_on = + params.dtw_token_timestamps() && params.dtw_aheads_preset() != AlignmentHeadsPreset::None; + + // Reject `dtw_on + flash_attn` BEFORE the FFI init. + // whisper.cpp's loader logs a warning and silently + // disables DTW under flash-attention + // (`whisper.cpp:3956`). 
Without this check the safe Rust + // API would return `Ok(Context)` for a configuration + // whose docs promise `Token::t_dtw` will be populated — + // every t_dtw stays at the default 0 with no signal to + // the caller. Refuse the combination explicitly so the + // caller has to disable one knob and document which. + // The check happens before `init_lock` so we avoid + // taking the global init mutex for a configuration we + // know is going to fail. + if dtw_on && params.flash_attn() { + return Err(WhisperError::ContextLoad { + path: smol_str::SmolStr::new(path_str.as_ref()), + reason: smol_str::SmolStr::new( + "DTW token timestamps cannot be combined with flash_attn — \ + whisper.cpp silently disables DTW under flash_attn. \ + Set with_flash_attn(false) or with_dtw_token_timestamps(false).", + ), + }); + } + + cparams.dtw_token_timestamps = dtw_on; + cparams.dtw_aheads_preset = if dtw_on { + params.dtw_aheads_preset().to_raw() + } else { + AlignmentHeadsPreset::None.to_raw() + }; + // `dtw_n_top` only matters when preset is N_TOP_MOST, + // which is not exposed by the safe API + // (see `AlignmentHeadsPreset`'s doc-comment for the + // scratch-size analysis that motivated its omission). + // Leave the C field at whatever + // `whisper_context_default_params()` set it to. + // + // DTW memory budget: clamp the user value first, then — + // when DTW is actually on — raise to the per-preset + // minimum from `required_dtw_mem_size_for`. The 128 MiB + // floor is adequate for small-head presets but + // dangerously low for `SmallEn` / `MediumEn` / `LargeV2`, + // whose 18–23 alignment heads drive the DTW working + // tensor past the budget; without this raise the + // `ggml_new_tensor_3d` call inside the DTW path + // `GGML_ASSERT`s and aborts the process. Silent raise + // matches the existing "clamp invalid inputs to safe" + // pattern in `Params::new`. 
+ let clamped_user = clamp_dtw_mem_size(params.dtw_mem_size()); + cparams.dtw_mem_size = if dtw_on { + let required = required_dtw_mem_size_for(params.dtw_aheads_preset()); + if clamped_user >= required { + clamped_user + } else { + required + } + } else { + clamped_user + }; // Serialise init: backend probing inside whisper.cpp // touches ggml's global logger state. @@ -253,6 +753,41 @@ impl Context { let raw = unsafe { sys::whispercpp_init_from_file_no_state(cpath.as_ptr(), cparams) }; if let Some(ptr) = NonNull::new(raw) { + // DTW-enabled contexts validate that the loaded model's + // text-context window fits the budget assumed by + // [`required_dtw_mem_size_for`]. Standard whisper + // checkpoints all carry `n_text_ctx = 448`, but the GGUF + // loader accepts larger values from custom / extended- + // context fine-tunes. If a non-standard model with + // `n_text_ctx > SUPPORTED_DTW_N_TEXT_CTX` is loaded + // alongside DTW, the DTW helper sizes its working tensor + // from `state->aheads_cross_QKs->ne[0]` (= actual decoded + // tokens, bounded by `n_text_ctx`) and overflows the + // pre-allocated arena — `ggml_new_tensor_3d` + // `GGML_ASSERT`s and the process aborts. Pre-allocating + // for a higher `n_text_ctx` upper bound (e.g. 2048) + // would force ~3-4× more DTW arena on every context; + // refusing here keeps the common-case budget tight and + // gives the caller an explicit recovery path. + if dtw_on { + // SAFETY: ptr is non-null (just unwrapped from + // NonNull); pure C accessor reading a const field. + let n_text_ctx = unsafe { sys::whisper_n_text_ctx(ptr.as_ptr()) }; + if n_text_ctx > SUPPORTED_DTW_N_TEXT_CTX { + // SAFETY: ptr was returned by + // `whispercpp_init_from_file_no_state` and held only + // by us; nothing else has observed it. + // `whisper_free` is the matching deallocator. 
+ unsafe { sys::whisper_free(ptr.as_ptr()) }; + return Err(WhisperError::ContextLoad { + path: smol_str::SmolStr::new(path_str.as_ref()), + reason: smol_str::SmolStr::new( + "DTW enabled with a model whose n_text_ctx exceeds SUPPORTED_DTW_N_TEXT_CTX (448) — \ + disable DTW (with_dtw_token_timestamps(false)) or use a standard checkpoint", + ), + }); + } + } return Ok(Self { ptr, lost: AtomicBool::new(false), @@ -504,11 +1039,17 @@ impl Context { /// Decode a single token id back to its surface form. Useful /// for token-level diagnostics. Returns `None` when: /// - /// * `token` is outside `[0, n_vocab)` — would otherwise - /// throw `std::out_of_range` from - /// `id_to_token.at(token)` across the C ABI (UB) per - /// `whisper.cpp:4201`. Pre-checking the bound here keeps - /// the unwound exception from crossing `extern "C"`. + /// * `token` is outside `[0, n_vocab)` (cheap pre-check; + /// avoids an FFI round-trip on caller-supplied invalid + /// ids). + /// * `token` is in `[0, n_vocab)` but absent from the + /// loaded vocab table — sparse-vocab models can have + /// `hparams.n_vocab` larger than the number of entries + /// actually populated by the loader. The `whispercpp-sys: + /// token_to_str sparse-vocab no-throw` patch in + /// `whisper_token_to_str` returns NULL in this case + /// (was: `id_to_token.at(token)` threw `std::out_of_range` + /// across `extern "C"`, undefined behaviour). /// * the underlying `c_str` is NULL or non-UTF-8 (model /// corruption). /// @@ -518,13 +1059,18 @@ impl Context { /// alias mutable C++ state — `id_to_token` is built once at /// load time and never modified.) pub fn token_to_str(&self, token: i32) -> Option<&str> { - // Validate before the FFI call — the upstream `at` throw - // would cross `extern "C"` and is UB. + // Cheap pre-check — saves an FFI round-trip when the + // caller supplies an obviously-invalid id. 
The C-side + // patch in `whisper_token_to_str` is what actually + // makes the call no-throw on sparse-vocab misses. let n = self.n_vocab(); if token < 0 || token >= n { return None; } - // SAFETY: token bound checked above; ctx pointer invariant. + // SAFETY: ctx pointer invariant. The C-side + // `whisper_token_to_str` returns NULL on any miss + // (out-of-range OR sparse-vocab gap), no throw across + // the boundary. let raw = unsafe { sys::whisper_token_to_str(self.ptr.as_ptr(), token) }; if raw.is_null() { return None; @@ -689,6 +1235,358 @@ mod tests { core::mem::forget(Arc::try_unwrap(ctx).ok().unwrap()); } + /// Fresh `ContextParams` defaults to DTW disabled with a + /// 128 MiB working budget. Pin the contract so a future + /// refactor can't quietly enable DTW for callers that didn't + /// ask for it (they'd silently pay the ~5–15% decode-time + /// overhead). + #[test] + fn default_context_params_have_dtw_off_and_default_mem_budget() { + let p = ContextParams::new(); + assert!(!p.dtw_token_timestamps()); + assert_eq!(p.dtw_aheads_preset(), AlignmentHeadsPreset::None); + assert_eq!(p.dtw_mem_size(), DEFAULT_DTW_MEM_SIZE); + assert_eq!(DEFAULT_DTW_MEM_SIZE, 128 * 1024 * 1024); + } + + /// `with_dtw_*` chained setters compose end-to-end without + /// consuming intermediate state. Mirrors the existing + /// `with_use_gpu` / `with_gpu_device` builder shape. + /// + /// Uses a `dtw_mem_size` above [`MIN_DTW_MEM_SIZE`] so the + /// clamp passes the value through unchanged. The clamp's + /// own boundary behaviour is pinned by + /// [`with_dtw_mem_size_clamps_zero_and_usize_max`]. 
+ #[test] + fn context_params_chained_dtw_setters_compose() { + let custom_mem = MIN_DTW_MEM_SIZE * 2; + let p = ContextParams::new() + .with_use_gpu(false) + .with_dtw_token_timestamps(true) + .with_dtw_aheads_preset(AlignmentHeadsPreset::LargeV3Turbo) + .with_dtw_mem_size(custom_mem); + assert!(!p.use_gpu()); + assert!(p.dtw_token_timestamps()); + assert_eq!(p.dtw_aheads_preset(), AlignmentHeadsPreset::LargeV3Turbo); + assert_eq!(p.dtw_mem_size(), custom_mem); + } + + /// `clamp_dtw_mem_size` raises below-floor inputs to + /// `MIN_DTW_MEM_SIZE` and lowers above-ceiling inputs to + /// `MAX_DTW_MEM_SIZE`. Both ends close + /// native-code abort / null-deref paths inside whisper.cpp's + /// DTW arena allocator (see constants' docs for the + /// failure analysis). `usize::MIN` (= 0) and `usize::MAX` + /// are the boundary cases that motivated the clamp; pin + /// them so a future refactor can't quietly drop a guard. + #[test] + fn clamp_dtw_mem_size_pins_invariants() { + // Below floor → MIN. + assert_eq!(clamp_dtw_mem_size(0), MIN_DTW_MEM_SIZE); + assert_eq!(clamp_dtw_mem_size(1), MIN_DTW_MEM_SIZE); + assert_eq!(clamp_dtw_mem_size(1024), MIN_DTW_MEM_SIZE); + assert_eq!(clamp_dtw_mem_size(MIN_DTW_MEM_SIZE - 1), MIN_DTW_MEM_SIZE); + // At and just above floor → passthrough. + assert_eq!(clamp_dtw_mem_size(MIN_DTW_MEM_SIZE), MIN_DTW_MEM_SIZE); + assert_eq!( + clamp_dtw_mem_size(MIN_DTW_MEM_SIZE + 1), + MIN_DTW_MEM_SIZE + 1 + ); + // Inside range → passthrough. + assert_eq!( + clamp_dtw_mem_size(256 * 1024 * 1024), + 256 * 1024 * 1024, + "256 MiB sits between MIN ({MIN_DTW_MEM_SIZE}) and MAX ({MAX_DTW_MEM_SIZE})", + ); + // At and just below ceiling → passthrough. + assert_eq!(clamp_dtw_mem_size(MAX_DTW_MEM_SIZE), MAX_DTW_MEM_SIZE); + assert_eq!( + clamp_dtw_mem_size(MAX_DTW_MEM_SIZE - 1), + MAX_DTW_MEM_SIZE - 1 + ); + // Above ceiling → MAX. 
+ assert_eq!(clamp_dtw_mem_size(MAX_DTW_MEM_SIZE + 1), MAX_DTW_MEM_SIZE); + assert_eq!(clamp_dtw_mem_size(usize::MAX), MAX_DTW_MEM_SIZE); + // Floor & ceiling order pin. Comparison is between two + // `const`s, so clippy's `assertions_on_constants` lint + // wants a `const { ... }` block to make the compile-time + // evaluation explicit. + const { assert!(MIN_DTW_MEM_SIZE <= MAX_DTW_MEM_SIZE) }; + assert_eq!(MIN_DTW_MEM_SIZE, DEFAULT_DTW_MEM_SIZE); + } + + /// `with_dtw_mem_size` clamps caller-supplied values into + /// the safe range. The clamp is the safe API's defense + /// against a `dtw_mem_size = 0` / `usize::MAX` slip + /// triggering whisper.cpp's `ggml_init` NULL-return / + /// arena-overflow abort path. Pin both directions. + #[test] + fn with_dtw_mem_size_clamps_zero_and_usize_max() { + let p = ContextParams::new().with_dtw_mem_size(0); + assert_eq!( + p.dtw_mem_size(), + MIN_DTW_MEM_SIZE, + "0 → MIN (defends against ggml_init NULL on zero arena)", + ); + let p = ContextParams::new().with_dtw_mem_size(usize::MAX); + assert_eq!( + p.dtw_mem_size(), + MAX_DTW_MEM_SIZE, + "usize::MAX → MAX (defends against ggml_init internal arena math overflow)", + ); + // In-range value passes through. + let p = ContextParams::new().with_dtw_mem_size(MIN_DTW_MEM_SIZE * 2); + assert_eq!(p.dtw_mem_size(), MIN_DTW_MEM_SIZE * 2); + } + + /// `Context::new` refuses to publish a context configured + /// with both `flash_attn` and DTW token-timestamps. + /// Whisper.cpp silently disables DTW under flash-attention + /// (`whisper.cpp:3956`); without an explicit Rust-side + /// rejection, callers would observe `Ok(Context)` for a + /// configuration that promises `Token::t_dtw` to be + /// populated and then receive only zeros. + /// + /// The check fires before any FFI file-load attempt, so + /// the test path doesn't need to exist on disk — the + /// validation is decided from `ContextParams` alone. 
+ #[test] + #[cfg_attr(miri, ignore = "FFI: whisper_context_default_params")] + fn context_new_rejects_dtw_plus_flash_attn() { + let params = ContextParams::new() + .with_flash_attn(true) + .with_dtw_token_timestamps(true) + .with_dtw_aheads_preset(AlignmentHeadsPreset::LargeV3Turbo); + let result = Context::new("/nonexistent/dtw+flash-attn-test.bin", params); + match result { + Err(WhisperError::ContextLoad { reason, .. }) => { + assert!( + reason.contains("DTW") && reason.contains("flash_attn"), + "ContextLoad reason must explain DTW + flash_attn incompatibility — got: {}", + reason, + ); + } + Err(e) => panic!( + "expected ContextLoad with DTW + flash_attn rejection, got: {:?}", + e, + ), + Ok(_) => panic!("expected error for DTW + flash_attn config, got Ok(Context)"), + } + } + + /// Mirror of the rejection test for the inverse setter + /// order — `with_dtw_token_timestamps` before + /// `with_flash_attn`. The chained-builder shape means + /// either order can land in `ContextParams`; both must + /// reach the rejection. + #[test] + #[cfg_attr(miri, ignore = "FFI: whisper_context_default_params")] + fn context_new_rejects_dtw_plus_flash_attn_setter_order_invariant() { + let params = ContextParams::new() + .with_dtw_token_timestamps(true) + .with_dtw_aheads_preset(AlignmentHeadsPreset::LargeV3Turbo) + .with_flash_attn(true); + let result = Context::new("/nonexistent/dtw+flash-attn-test2.bin", params); + let err = result + .err() + .expect("expected ContextLoad error, got Ok(Context)"); + assert!( + matches!(err, WhisperError::ContextLoad { .. }), + "expected ContextLoad regardless of setter order, got: {:?}", + err, + ); + } + + /// `flash_attn` + `dtw_token_timestamps(true)` but + /// `preset = None` is NOT rejected — the effective DTW + /// state is "off" (per the preset coercion in + /// `Context::new`), so flash_attn is fine. Pin this so a + /// future tightening of the rejection doesn't accidentally + /// reject a valid configuration. 
+ #[test] + #[cfg_attr(miri, ignore = "FFI: whisper_context_default_params")] + fn context_new_accepts_flash_attn_with_dtw_timestamps_but_no_preset() { + let params = ContextParams::new() + .with_flash_attn(true) + .with_dtw_token_timestamps(true); + // Preset stays at the default `None` — DTW is effectively off. + // Context::new should NOT trip the flash_attn + DTW rejection, + // but the file-load FFI will still fail (path doesn't exist). + // Either ContextLoad with a generic load message OR + // InvalidCString is acceptable — what we care about here is + // that the DTW + flash_attn rejection did NOT fire. + let result = Context::new("/nonexistent/no-dtw-fine.bin", params); + if let Err(WhisperError::ContextLoad { reason, .. }) = &result { + assert!( + !(reason.contains("DTW") && reason.contains("flash_attn")), + "DTW + flash_attn rejection fired for a config where DTW is off: {}", + reason, + ); + } + } + + /// Pin the [`SUPPORTED_DTW_N_TEXT_CTX`] constant. + /// + /// The value `448` is the `n_text_ctx` for every standard + /// whisper checkpoint (`tiny.en` through + /// `large-v3-turbo`). [`required_dtw_mem_size_for`] uses it + /// as the worst-case `n_tokens` axis when sizing the DTW + /// scratch arena, and [`Context::new`] uses it as the + /// model-load gate that refuses non-standard checkpoints + /// when DTW is enabled. Drift here invalidates both the + /// budget calc and the load-time validation — pin so a + /// future refactor has to be deliberate. + #[test] + fn supported_dtw_n_text_ctx_pins_to_standard_whisper_value() { + assert_eq!( + SUPPORTED_DTW_N_TEXT_CTX, 448, + "Standard whisper checkpoints all use n_text_ctx = 448. \ + If you changed this, also re-derive the DTW scratch budget \ + and update the byte-count pins in \ + `required_dtw_mem_size_pins_per_preset_minimums`.", + ); + } + + /// Per-preset alignment-head counts must match the + /// `g_aheads_*` tables in + /// `whisper.cpp/src/whisper.cpp:399-410`. A drift here + /// (e.g. 
an upstream rebuild renumbers a preset's heads or + /// our match arm typo'd a count) makes + /// [`required_dtw_mem_size_for`] return an under-sized + /// budget, which lets `ggml_new_tensor_3d` inside the DTW + /// path abort the process from safe Rust. Pin every + /// shipping preset's head count. + #[test] + fn alignment_head_count_matches_whisper_cpp_tables() { + use AlignmentHeadsPreset::*; + // Counts taken from g_aheads at whisper.cpp:412-424 of + // the patched submodule. + assert_eq!(alignment_head_count(None), 0); + assert_eq!(alignment_head_count(TinyEn), 8); + assert_eq!(alignment_head_count(Tiny), 6); + assert_eq!(alignment_head_count(BaseEn), 5); + assert_eq!(alignment_head_count(Base), 8); + assert_eq!(alignment_head_count(SmallEn), 19); + assert_eq!(alignment_head_count(Small), 10); + assert_eq!(alignment_head_count(MediumEn), 18); + assert_eq!(alignment_head_count(Medium), 6); + assert_eq!(alignment_head_count(LargeV1), 9); + assert_eq!(alignment_head_count(LargeV2), 23); + assert_eq!(alignment_head_count(LargeV3), 10); + assert_eq!(alignment_head_count(LargeV3Turbo), 6); + } + + /// `required_dtw_mem_size_for` must keep every shipping + /// preset above its realistic worst-case scratch peak. + /// The original 128 MiB floor was too small for the + /// 18–23-head presets (`SmallEn`, `MediumEn`, `LargeV2`), + /// whose DTW working tensor + `ggml_norm` output + + /// median-filter output add up to 145–186 MiB just for the + /// three live tensors. Pin the per-preset minimums so a + /// future refactor can't quietly shrink them back below + /// the abort threshold. + #[test] + fn required_dtw_mem_size_pins_per_preset_minimums() { + use AlignmentHeadsPreset::*; + // None → 0 (DTW disabled, no scratch needed). + assert_eq!(required_dtw_mem_size_for(None), 0); + // Small-head presets (≤10 heads) collapse to MIN floor. 
+ for preset in [ + TinyEn, + Tiny, + BaseEn, + Base, + Small, + Medium, + LargeV1, + LargeV3, + LargeV3Turbo, + ] { + let req = required_dtw_mem_size_for(preset); + assert!( + req >= MIN_DTW_MEM_SIZE, + "{:?} requires {} bytes; must be ≥ MIN_DTW_MEM_SIZE ({})", + preset, + req, + MIN_DTW_MEM_SIZE, + ); + assert!( + req <= MAX_DTW_MEM_SIZE, + "{:?} requires {} bytes; must be ≤ MAX_DTW_MEM_SIZE ({})", + preset, + req, + MAX_DTW_MEM_SIZE, + ); + } + // High-head presets (the 128 MiB regression class) MUST + // exceed the floor — this is the regression sentinel for + // the original analysis bug. + for preset in [SmallEn, MediumEn, LargeV2] { + let req = required_dtw_mem_size_for(preset); + assert!( + req > MIN_DTW_MEM_SIZE, + "{:?} requires only {} bytes — must exceed MIN_DTW_MEM_SIZE ({}) \ + to fit its high-head DTW pipeline; without this the wrapper's \ + floor would let whisper.cpp abort during decode", + preset, + req, + MIN_DTW_MEM_SIZE, + ); + } + // Spot-check the explicit byte counts so a future change + // to the formula has to update this pin too. Math: + // per_tensor = 448 * 1500 * n_heads * 4 + // required = (per_tensor * 3) * 3 / 2 + // For LargeV2 (n_heads=23): + // per_tensor = 448 * 1500 * 23 * 4 = 61_824_000 bytes + // required = 61_824_000 * 3 * 3 / 2 = 278_208_000 bytes + assert_eq!(required_dtw_mem_size_for(LargeV2), 278_208_000); + // SmallEn (n_heads=19): 448 * 1500 * 19 * 4 = 51_072_000 bytes + // required = 51_072_000 * 3 * 3 / 2 = 229_824_000 bytes + assert_eq!(required_dtw_mem_size_for(SmallEn), 229_824_000); + // MediumEn (n_heads=18): 448 * 1500 * 18 * 4 = 48_384_000 bytes + // required = 48_384_000 * 3 * 3 / 2 = 217_728_000 bytes + assert_eq!(required_dtw_mem_size_for(MediumEn), 217_728_000); + } + + /// Every `AlignmentHeadsPreset` variant maps to a distinct + /// `whisper_alignment_heads_preset` raw value. 
If a future
+    /// upstream renumbering collapsed two presets to the same
+    /// value (or our match arm typo'd one), the timestamps
+    /// produced for the affected models would silently become
+    /// noise. Pin the bijection here as a regression sentinel.
+    #[test]
+    fn alignment_heads_preset_maps_to_distinct_raw_values() {
+        use AlignmentHeadsPreset::*;
+        let presets = [
+            None,
+            TinyEn,
+            Tiny,
+            BaseEn,
+            Base,
+            SmallEn,
+            Small,
+            MediumEn,
+            Medium,
+            LargeV1,
+            LargeV2,
+            LargeV3,
+            LargeV3Turbo,
+        ];
+        let raws: Vec<_> =
+            presets.iter().map(|p| p.to_raw()).collect();
+        // No duplicates: every preset maps somewhere different.
+        let mut sorted = raws.clone();
+        sorted.sort();
+        sorted.dedup();
+        assert_eq!(
+            sorted.len(),
+            presets.len(),
+            "AlignmentHeadsPreset → raw mapping must be injective: got {:?}",
+            raws,
+        );
+    }
+
     /// `full_lock` survives the documented
     /// concurrent-worker pattern. Two threads contend on the
     /// same lock, both eventually finish, neither panics. The
diff --git a/whispercpp/src/lib.rs b/whispercpp/src/lib.rs
index 3c81b8c..4c7c468 100644
--- a/whispercpp/src/lib.rs
+++ b/whispercpp/src/lib.rs
@@ -10,7 +10,10 @@ mod params;
 mod state;
 mod sys;
 
-pub use context::{Context, ContextParams, system_info};
+pub use context::{
+    AlignmentHeadsPreset, Context, ContextParams, DEFAULT_DTW_MEM_SIZE, MAX_DTW_MEM_SIZE,
+    MIN_DTW_MEM_SIZE, SUPPORTED_DTW_N_TEXT_CTX, required_dtw_mem_size_for, system_info,
+};
 pub use error::{WhisperError, WhisperResult};
 pub use lang::Lang;
 pub use params::{
diff --git a/whispercpp/src/state.rs b/whispercpp/src/state.rs
index d0ea7bb..dd0ca45 100644
--- a/whispercpp/src/state.rs
+++ b/whispercpp/src/state.rs
@@ -488,6 +488,28 @@ impl<'a> Segment<'a> {
 /// Read-only snapshot. All fields are private; access goes
 /// through `const fn` accessors to keep the public surface
 /// stable as `whisper_token_data` evolves upstream. 
+/// +/// # Two timestamp sources +/// +/// Whisper.cpp can produce per-token timing two different ways, +/// and they live in different fields: +/// +/// * [`Self::t0`] / [`Self::t1`] come from the **timestamp-token +/// path** — whisper.cpp's standard heuristic that pairs each +/// token with the surrounding `<|t_x|>` markers. These are +/// populated when [`crate::Params::set_token_timestamps`] is +/// `true`, regardless of DTW. +/// * [`Self::t_dtw`] comes from the **DTW backtrace** — +/// independently derived from cross-attention weights of the +/// alignment heads. Only populated when DTW is enabled at +/// [`Context`] load time via +/// [`crate::ContextParams::with_dtw_token_timestamps`] (and +/// a non-`None` +/// [`crate::ContextParams::with_dtw_aheads_preset`]). +/// +/// DTW is generally more robust to long silences and repeated +/// tokens than the timestamp-token path; the timestamp-token +/// path is cheaper but more sensitive to attention misallocation. #[derive(Debug, Clone, Copy)] pub struct Token { id: i32, @@ -497,6 +519,7 @@ pub struct Token { ptsum: f32, t0: i64, t1: i64, + t_dtw: i64, vlen: f32, } @@ -531,18 +554,70 @@ impl Token { self.ptsum } - /// DTW-derived start time (centiseconds), if available. + /// Token start time, in centiseconds, derived from the + /// timestamp-token path. `0` when + /// [`crate::Params::set_token_timestamps`] is `false`. + /// **Not** the DTW timestamp — see [`Self::t_dtw`]. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn t0(&self) -> i64 { self.t0 } - /// DTW-derived end time (centiseconds), if available. + /// Token end time, in centiseconds, derived from the + /// timestamp-token path. `0` when + /// [`crate::Params::set_token_timestamps`] is `false`. + /// **Not** the DTW timestamp — see [`Self::t_dtw`]. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn t1(&self) -> i64 { self.t1 } + /// DTW-derived token timestamp, in centiseconds. 
Roughly the
+    /// moment in the audio at which whisper.cpp emitted this
+    /// token, computed by running DTW over the configured
+    /// alignment heads' cross-attention weights.
+    ///
+    /// Returns `Some(t)` when DTW computed a real timestamp for
+    /// this token. Returns `None` when DTW timing is unavailable
+    /// for any of these reasons:
+    ///
+    /// * DTW was not enabled at [`Context`] construction
+    ///   (`with_dtw_token_timestamps(false)`, or preset left at
+    ///   [`crate::AlignmentHeadsPreset::None`]).
+    /// * The token is a non-text (special / timestamp) token —
+    ///   DTW only writes timing for text tokens.
+    /// * DTW skipped this segment because the chunk's
+    ///   `audio_ctx` (overridden by
+    ///   [`crate::Params::set_audio_ctx`]) was too small for
+    ///   the chunk duration.
+    /// * DTW skipped this segment because the audio window was
+    ///   too short for the median-filter pass
+    ///   (`n_audio_tokens <= 1`, ≤20 ms).
+    ///
+    /// The `whispercpp-sys: dtw t_dtw sentinel init` patch in
+    /// `whisper.cpp` initialises every text token's `t_dtw` to
+    /// `-1` at the start of the DTW pass; successful
+    /// computation overwrites with a non-negative timestamp,
+    /// while skip paths leave the sentinel in place. Negative
+    /// timestamps are unreachable for valid DTW output, so `-1`
+    /// uniquely identifies "unavailable."
+    ///
+    /// Whisper.cpp ships validated alignment-head presets for
+    /// every standard checkpoint (see
+    /// [`crate::AlignmentHeadsPreset`]); using a preset that
+    /// doesn't match the loaded model produces unreliable DTW
+    /// timings without erroring — but this method still returns
+    /// `Some(...)` because the values were "computed", just
+    /// not meaningfully. Match the preset to the model.
+    #[cfg_attr(not(tarpaulin), inline(always))]
+    pub const fn t_dtw(&self) -> Option<i64> {
+        if self.t_dtw < 0 {
+            None
+        } else {
+            Some(self.t_dtw)
+        }
+    }
+
     /// Voice activity score, if available. 
#[cfg_attr(not(tarpaulin), inline(always))] pub const fn vlen(&self) -> f32 { @@ -561,7 +636,131 @@ impl Token { ptsum: raw.ptsum, t0: raw.t0, t1: raw.t1, + t_dtw: raw.t_dtw, vlen: raw.vlen, } } } + +#[cfg(test)] +mod tests { + use super::*; + + /// `Token::from_raw` projects every field — including + /// [`Token::t_dtw`], which earlier + /// versions of this wrapper missed entirely (the C struct + /// carried it but the safe view didn't surface it). Pin + /// the projection so a future refactor can't quietly drop + /// a field again. + #[test] + fn token_from_raw_projects_every_field_including_t_dtw() { + let raw = sys::whisper_token_data { + id: 1234, + tid: 5678, + p: 0.8, + plog: -0.22, + pt: 0.05, + ptsum: 0.12, + t0: 100, + t1: 250, + t_dtw: 175, + vlen: 0.42, + }; + let tok = Token::from_raw(raw); + assert_eq!(tok.id(), 1234); + assert!((tok.p() - 0.8).abs() < 1e-6); + assert!((tok.plog() - -0.22).abs() < 1e-6); + assert!((tok.pt() - 0.05).abs() < 1e-6); + assert!((tok.ptsum() - 0.12).abs() < 1e-6); + assert_eq!(tok.t0(), 100); + assert_eq!(tok.t1(), 250); + assert_eq!( + tok.t_dtw(), + Some(175), + "Token::from_raw must project the DTW timestamp", + ); + assert!((tok.vlen() - 0.42).abs() < 1e-6); + } + + /// The DTW timestamp is independent from `t0`/`t1` — it + /// comes from a different mechanism (cross-attention DTW vs. + /// timestamp-token decoding). Confirm the safe view exposes + /// distinct values rather than aliasing them. + #[test] + fn t_dtw_is_independent_of_t0_t1() { + let raw = sys::whisper_token_data { + id: 0, + tid: 0, + p: 0.0, + plog: 0.0, + pt: 0.0, + ptsum: 0.0, + t0: 100, + t1: 200, + t_dtw: 150, + vlen: 0.0, + }; + let tok = Token::from_raw(raw); + assert_eq!(tok.t0(), 100); + assert_eq!(tok.t1(), 200); + assert_eq!(tok.t_dtw(), Some(150)); + // Sanity: distinct values flow through distinct accessors, + // not collapsed into one. 
+ assert_ne!(tok.t_dtw(), Some(tok.t0())); + assert_ne!(tok.t_dtw(), Some(tok.t1())); + } + + /// `t_dtw == -1` is the sentinel set by the + /// `whispercpp-sys: dtw t_dtw sentinel init` patch when DTW + /// is enabled but skipped for a segment (audio_ctx mismatch + /// or short-window medfilt). The wrapper must surface that + /// as `None` so callers can distinguish "DTW skipped" from + /// "DTW computed at audio offset 0." + #[test] + fn t_dtw_sentinel_minus_one_maps_to_none() { + let raw = sys::whisper_token_data { + id: 0, + tid: 0, + p: 0.0, + plog: 0.0, + pt: 0.0, + ptsum: 0.0, + t0: 0, + t1: 0, + t_dtw: -1, + vlen: 0.0, + }; + let tok = Token::from_raw(raw); + assert_eq!( + tok.t_dtw(), + None, + "t_dtw == -1 must surface as None (DTW unavailable for token)", + ); + } + + /// `t_dtw == 0` is a *valid* DTW result for a token that + /// starts at audio offset 0. It must NOT be confused with + /// the unavailable sentinel — pin so a future "treat 0 as + /// missing" refactor can't silently break this. + #[test] + fn t_dtw_zero_maps_to_some_zero() { + let raw = sys::whisper_token_data { + id: 0, + tid: 0, + p: 0.0, + plog: 0.0, + pt: 0.0, + ptsum: 0.0, + t0: 0, + t1: 0, + t_dtw: 0, + vlen: 0.0, + }; + let tok = Token::from_raw(raw); + assert_eq!( + tok.t_dtw(), + Some(0), + "t_dtw == 0 is a valid timestamp (token at audio start), not the sentinel", + ); + } +}