Conversation
There was a problem hiding this comment.
Pull request overview
This PR extends the diarization pipeline to support spill-backed (heap-or-mmap) matrix storage and row-major slice-based APIs, reducing OOM risk on large inputs and enabling efficient “matrix view” style interop with nalgebra where needed. It also adds an execution-provider (EP) re-export layer and updates parity tooling/docs to better support CoreML debugging and fixture-skipping in published-crate contexts.
Changes:
- Convert several clustering/pipeline interfaces from
nalgebra::DMatrixownership to explicit row-major&[f64]+ shape parameters, and route large intermediate buffers throughSpillBytes{Mut}. - Add
diarization::epmodule + provider auto-registration for segmentation (while keeping embedding on CPU by default), and expand CI/docs to cover EP feature combinations. - Improve parity tests/harness ergonomics (fixture skip macro, better errors in scripts) and harden model download pinning + SHA verification.
Reviewed changes
Copilot reviewed 58 out of 61 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/parity/src/main.rs | Adds EP dispatch/debug env knobs for parity binary (CoreML isolation). |
| tests/parity/run.sh | Improves error messaging when default fixture clip is missing. |
| tests/parity/Cargo.toml | Enables CoreML feature for parity workspace and adds direct ort dep. |
| src/test_util.rs | Introduces parity_fixtures_or_skip! macro for skipping tests without fixtures. |
| src/streaming/offline_diarizer.rs | Updates streaming-offline docs around latency expectations. |
| src/streaming/mod.rs | Clarifies “when not to use” guidance for finalize-bound latency. |
| src/segment/model.rs | Switches to new ort::ep::* types; adds auto-provider default options for segmentation. |
| src/segment/mod.rs | Re-exports ExecutionProviderDispatch from ort::ep. |
| src/reconstruct/tests.rs | Comment/doc cleanups in reconstruct tests. |
| src/reconstruct/rttm.rs | Small doc clarification about comparator encapsulation. |
| src/reconstruct/rttm_parity_tests.rs | Uses fixture-skip macro; adapts embeddings/post-PLDA inputs to row-major slices. |
| src/reconstruct/parity_tests.rs | Uses fixture-skip macro; adapts embeddings/post-PLDA inputs to row-major slices. |
| src/reconstruct/mod.rs | Re-exports ShapeError from reconstruct. |
| src/reconstruct/algo.rs | Doc wording tweaks; removes stray reference in comment. |
| src/plda/transform.rs | Doc wording tweak for XVEC_CENTERED_MIN_NORM reference. |
| src/plda/parity_tests.rs | Uses fixture-skip macro in PLDA parity tests. |
| src/plda/error.rs | Clarifies doc text about test-only constructor usage. |
| src/pipeline/tests.rs | Adapts unit tests to new row-major slice API for embeddings/post-PLDA. |
| src/pipeline/parity_tests.rs | Uses fixture-skip macro; adapts parity harness to row-major slice API. |
| src/pipeline/mod.rs | Updates module docs to reflect common call paths (offline + streaming-offline). |
| src/pipeline/error.rs | Adds Spill error variant and NonFiniteField::PostPlda. |
| src/pipeline/algo.rs | Core refactor: embeddings/post-PLDA as row-major slices; spill-backed buffers; transposes for VBx via DMatrixView. |
| src/ops/spill.rs | Hardens spill backend: target gating, platform tempfile strategy (incl. O_TMPFILE), and new UnsupportedTarget error. |
| src/ops/mod.rs | Clarifies CPU feature detection documentation. |
| src/offline/tests.rs | Adds boundary tests for early offline input validation (before clustering/spill allocs). |
| src/offline/parity_tests.rs | Uses fixture-skip macro in offline parity tests. |
| src/offline/owned.rs | Doc clarification about !Sync models and borrowing patterns. |
| src/offline/mod.rs | Exposes new offline tests; doc updates; removes re-export of removed cap constant. |
| src/offline/algo.rs | Refactors offline pipeline to spill-backed buffers; adds early count validation gates; switches away from heap-only DMatrix. |
| src/lib.rs | Exposes new ep module (feature-gated) and test utilities (test-only). |
| src/ep.rs | New module re-exporting ORT EP types and providing auto_providers() helper. |
| src/embed/types.rs | Doc wording tweaks around embedding normalization usage. |
| src/embed/options.rs | Switches to ort::ep::ExecutionProviderDispatch. |
| src/embed/model.rs | Documents CPU-default embed dispatch rationale; clarifies _with_options behavior. |
| src/embed/mod.rs | Minor doc cleanup for layered API list. |
| src/cluster/vbx/tests.rs | Updates tests to pass DMatrixView into vbx_iterate. |
| src/cluster/vbx/parity_tests.rs | Uses fixture-skip macro; updates vbx_iterate calls for DMatrixView. |
| src/cluster/vbx/algo.rs | Changes vbx_iterate signature to take DMatrixView; doc cleanups. |
| src/cluster/hungarian/parity_tests.rs | Uses fixture-skip macro. |
| src/cluster/centroid/tests.rs | Adds adapters for row-major input; updates tests to new weighted_centroids signature. |
| src/cluster/centroid/parity_tests.rs | Uses fixture-skip macro; adapts parity test inputs to row-major slices. |
| src/cluster/centroid/algo.rs | Refactors weighted_centroids to accept row-major slice + shape params; removes redundant copy. |
| src/cluster/ahc/tests.rs | Adds adapters for row-major input; updates tests to new ahc_init signature. |
| src/cluster/ahc/parity_tests.rs | Uses fixture-skip macro; adapts parity test inputs to row-major slices. |
| src/cluster/ahc/algo.rs | Refactors ahc_init to accept row-major slice + shape params; spill-backs normalization buffer. |
| src/aggregate/parity_tests.rs | Uses fixture-skip macro. |
| src/aggregate/mod.rs | Doc reference tweak for parity test naming. |
| src/aggregate/count.rs | Changes hamming_aggregate/try_hamming_aggregate to return SpillBytes<f64> instead of Vec<f64>. |
| scripts/download-embed-model.sh | Pins model download revision; adds atomic download + SHA check; updates HF source. |
| README.md | Updates quickstart: pinned HF revision/SHA workflow; clarifies features and env var override. |
| models/SOURCE.md | Updates embed model provenance notes (single-file ONNX and CoreML caveats). |
| examples/stream_layer1.rs | Updates example invocation docs and clarifies no-ort usage. |
| examples/run_streaming_pipeline.rs | Adds DIA_EMBED_MODEL_PATH override; updates documented feature set. |
| examples/run_owned_pipeline.rs | Adds DIA_EMBED_MODEL_PATH override and improved error context. |
| examples/run_owned_pipeline_tch.rs | Updates documented feature set (bundled segmentation). |
| CHANGELOG.md | Updates unreleased notes to reflect offline + streaming-offline pipelines, EPs, and spill backend. |
| Cargo.toml | Adds packaging excludes; adds per-EP features + gpu meta-feature; docs.rs feature pinning; CI-related metadata. |
| .gitignore | Clarifies embed model ignore/unignore rationale and excludes unused sidecar. |
| .github/workflows/ci.yml | Adjusts feature-sweep strategy; adds EP bindings check and docs.rs feature-set doc build; updates cross targets. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| let mut seg = if force_cpu_seg { | ||
| SegmentModel::bundled_with_options(SegmentModelOptions::default()) | ||
| .context("load bundled segment model (CPU)")? | ||
| } else if compute_units.is_some() { | ||
| // Caller pinned a compute unit — explicitly construct the EP | ||
| // with that pin and pass via `_with_options`. Default | ||
| // `bundled()` would auto_providers() with CoreML's defaults. | ||
| let opts = SegmentModelOptions::default().with_providers(vec![coreml_provider()]); | ||
| SegmentModel::bundled_with_options(opts).context("load bundled segment model (CoreML pinned)")? | ||
| } else { | ||
| SegmentModel::bundled().context("load bundled segment model (auto)")? |
| let mut emb = if force_cpu_emb { | ||
| EmbedModel::from_file_with_options(&emb_path, EmbedModelOptions::default()) | ||
| .context("load embed model (CPU)")? | ||
| } else if compute_units.is_some() { | ||
| let opts = EmbedModelOptions::default().with_providers(vec![coreml_provider()]); | ||
| EmbedModel::from_file_with_options(&emb_path, opts).context("load embed model (CoreML pinned)")? | ||
| } else { | ||
| EmbedModel::from_file(&emb_path).context("load embed model (auto)")? | ||
| }; |
| # Pinned upstream revision + expected SHA-256 of the FP32 single-file ONNX. | ||
| DIA_EMBED_MODEL_REV="38168b544a562dec24d49e63786c16e80782eeaf" | ||
| DIA_EMBED_MODEL_SHA256="4c15c6be4235318d092c9d347e00c68ba476136d6172f675f76ad6b0c2661f01" | ||
| mkdir -p models | ||
| TMP="$(mktemp -t wespeaker_resnet34_lm.XXXX.onnx)" | ||
| ``` |
| eprintln!( | ||
| "# dia: seg={} emb={} coreml_cu={}", | ||
| if force_cpu_seg { "CPU" } else { "auto" }, | ||
| if force_cpu_emb { "CPU" } else { "auto" }, | ||
| cu_label, | ||
| ); |
| /// Build a minimal valid `OfflineInput`-shaped data set: well-formed | ||
| /// raw_embeddings + segmentations matching `num_chunks * num_speakers | ||
| /// * num_frames_per_chunk`, default sliding windows, with the count | ||
| /// tensor controlled by the caller. The PLDA transform is bundled. | ||
| fn synthetic_input_inputs( | ||
| num_chunks: usize, | ||
| num_frames_per_chunk: usize, | ||
| ) -> ( |
| /// Inputs to [`assign_embeddings`]. Grouped to keep the function | ||
| /// signature manageable. | ||
| #[derive(Debug, Clone)] | ||
| pub struct AssignEmbeddingsInput<'a> { | ||
| embeddings: &'a DMatrix<f64>, | ||
| /// Pre-PLDA per-`(chunk, speaker)` f64 embeddings, **row-major** |
| pub use algo::{Error, OfflineInput, OfflineOutput, diarize_offline}; | ||
|
|
| let expected_emb_rows = num_chunks | ||
| .checked_mul(num_speakers) | ||
| .ok_or(ShapeError::EmbeddingsRowsOverflow)?; | ||
| if embeddings.nrows() != expected_emb_rows { | ||
| let expected_emb_len = expected_emb_rows | ||
| .checked_mul(embed_dim) | ||
| .ok_or(ShapeError::EmbeddingsRowsOverflow)?; | ||
| if embeddings.len() != expected_emb_len { |
| let expected_emb_len = num_train | ||
| .checked_mul(embed_dim) | ||
| .ok_or(ShapeError::EmbeddingsQRowMismatch)?; | ||
| if embeddings.len() != expected_emb_len { | ||
| return Err(ShapeError::EmbeddingsQRowMismatch.into()); | ||
| } |
Welcome to Codecov 🎉Once you merge this PR into your default branch, you're all set! Codecov will compare coverage reports and display results in all future pull requests. ℹ️ You can also turn on project coverage checks and project coverage reporting on Pull Request comment Thanks for integrating Codecov - We've got you covered ☂️ |
| // Default (all unset) auto-registers `dia::ep::auto_providers()` | ||
| // for both — at build time with `--features coreml`, the CoreML EP. | ||
| let disable_auto = std::env::var("DIA_DISABLE_AUTO_PROVIDERS").ok().as_deref() == Some("1"); |
| /// Implementation: a private `cmp_cluster_id_str` is the canonical | ||
| /// pyannote-equivalent comparator. It renders both ids into stack- | ||
| /// allocated `itoa::Buffer`s and compares the resulting `&str` | ||
| /// slices — zero heap allocation. |
No description provided.