Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ check:
test:
cargo test --workspace

# Cross-validate Rust mel+ONNX pipeline against Python reference probabilities
# Cross-validate Rust mel+ONNX pipeline against Python reference probabilities.
# Builds with `wavekat-smart-turn` so the zh fine-tune rows are also emitted;
# WaveKat weights are fetched from HuggingFace on first run (cached in $HF_HOME).
accuracy:
cargo test --features pipecat --test accuracy -- --ignored accuracy_report --nocapture
cargo test --features wavekat-smart-turn --test accuracy -- --ignored accuracy_report --nocapture

# Compare Rust vs Python mel spectrograms element-wise (requires .npy fixtures)
mel:
Expand Down
147 changes: 109 additions & 38 deletions crates/wavekat-turn/tests/accuracy.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Cross-validation accuracy test: Rust pipeline vs. Python reference.
//!
//! Verifies that our mel preprocessing and ONNX inference produce probabilities
//! within ±0.02 of the Python (Pipecat) reference for each fixture audio clip.
//! within ±0.02 of the Python reference for each fixture audio clip, across
//! every enabled backend.
//!
//! Prerequisites:
//! 1. Run `python scripts/gen_reference.py` once to produce
Expand All @@ -10,6 +11,10 @@
//!
//! Run individual regression tests: `cargo test --features pipecat --test accuracy`
//! Run the full report table: `make accuracy`
//!
//! When the `wavekat-smart-turn` feature is enabled, the report additionally
//! exercises the WaveKat zh fine-tune against the `zh_*.wav` fixtures. Weights
//! are downloaded from HuggingFace on first run (cached under `$HF_HOME/hub/`).

use std::path::PathBuf;

Expand All @@ -34,10 +39,19 @@ fn fixtures_dir() -> PathBuf {
/// One record deserialized from `reference.json`: the reference probability
/// for a single `(backend, clip)` pair.
#[cfg(any(feature = "pipecat"))]
#[derive(serde::Deserialize)]
struct RefEntry {
/// Which backend produced this reference probability.
/// Defaults to "pipecat" so older `reference.json` files keep working.
#[serde(default = "default_backend")]
backend: String,
/// Fixture file name; joined onto `fixtures_dir()` to locate the audio clip.
file: String,
/// Reference probability that the Rust pipeline's output is compared against.
probability: f32,
}

#[cfg(any(feature = "pipecat"))]
/// Serde fallback for `RefEntry::backend`: entries that omit the field are
/// attributed to the "pipecat" backend.
fn default_backend() -> String {
    String::from("pipecat")
}

#[cfg(any(feature = "pipecat"))]
fn load_reference() -> Vec<RefEntry> {
let path = fixtures_dir().join("reference.json");
Expand All @@ -50,6 +64,11 @@ fn load_reference() -> Vec<RefEntry> {
serde_json::from_str(&json).expect("invalid reference.json")
}

#[cfg(any(feature = "pipecat"))]
/// Returns references to every entry whose `backend` field equals `backend`,
/// preserving the order in which they appear in `entries`.
fn entries_for<'a>(entries: &'a [RefEntry], backend: &str) -> Vec<&'a RefEntry> {
    let mut matching = Vec::new();
    for entry in entries {
        if entry.backend == backend {
            matching.push(entry);
        }
    }
    matching
}

// ---------------------------------------------------------------------------
// Report row — one entry per (backend, clip)
// ---------------------------------------------------------------------------
Expand All @@ -75,44 +94,57 @@ impl Row {
}
}

// ---------------------------------------------------------------------------
// Shared audio helpers used by backend modules
// ---------------------------------------------------------------------------

/// Reads a WAV fixture into `f32` samples.
///
/// Panics if the file cannot be opened, is not 16 kHz, is not mono, or if any
/// sample fails to decode. Integer samples are read as `i16` and divided by
/// 32768; float samples are returned unchanged.
#[cfg(feature = "pipecat")]
fn load_wav_f32(path: &std::path::Path) -> Vec<f32> {
    let mut wav = match hound::WavReader::open(path) {
        Ok(wav) => wav,
        Err(e) => panic!("failed to open {}: {}", path.display(), e),
    };

    let spec = wav.spec();
    assert_eq!(spec.sample_rate, 16_000, "expected 16 kHz");
    assert_eq!(spec.channels, 1, "expected mono");

    match spec.sample_format {
        hound::SampleFormat::Float => wav.samples::<f32>().map(Result::unwrap).collect(),
        hound::SampleFormat::Int => {
            // match soundfile's normalization
            const I16_SCALE: f32 = 32768.0;
            wav.samples::<i16>()
                .map(|sample| f32::from(sample.unwrap()) / I16_SCALE)
                .collect()
        }
    }
}

/// Converts a prediction into a single probability on the "finished" scale:
/// the confidence itself for `Finished`, and `1.0 - confidence` for
/// `Unfinished`. A `Wait` state hits `unreachable!()` and panics.
#[cfg(feature = "pipecat")]
fn raw_prob(pred: &wavekat_turn::TurnPrediction) -> f32 {
    use wavekat_turn::TurnState::{Finished, Unfinished, Wait};

    let confidence = pred.confidence;
    match pred.state {
        Unfinished => 1.0 - confidence,
        Finished => confidence,
        Wait => unreachable!(),
    }
}

// ---------------------------------------------------------------------------
// Pipecat backend
// ---------------------------------------------------------------------------

#[cfg(feature = "pipecat")]
mod pipecat {
use std::path::Path;

use wavekat_turn::audio::PipecatSmartTurn;
use wavekat_turn::{AudioFrame, AudioTurnDetector, TurnPrediction, TurnState};

use super::{fixtures_dir, RefEntry, Row, TOLERANCE};

fn load_wav_f32(path: &Path) -> Vec<f32> {
let mut reader = hound::WavReader::open(path)
.unwrap_or_else(|e| panic!("failed to open {}: {}", path.display(), e));
let spec = reader.spec();
assert_eq!(spec.sample_rate, 16_000, "expected 16 kHz");
assert_eq!(spec.channels, 1, "expected mono");
match spec.sample_format {
hound::SampleFormat::Int => reader
.samples::<i16>()
.map(|s| s.unwrap() as f32 / 32768.0) // match soundfile's normalization
.collect(),
hound::SampleFormat::Float => reader.samples::<f32>().map(|s| s.unwrap()).collect(),
}
}
use wavekat_turn::{AudioFrame, AudioTurnDetector};

use super::{entries_for, fixtures_dir, load_wav_f32, raw_prob, RefEntry, Row, TOLERANCE};

fn reference_prob(entries: &[RefEntry], name: &str) -> f32 {
entries
.iter()
.find(|e| e.file == name)
.unwrap_or_else(|| panic!("no entry for '{}' in reference.json", name))
.find(|e| e.backend == "pipecat" && e.file == name)
.unwrap_or_else(|| panic!("no pipecat entry for '{}' in reference.json", name))
.probability
}

pub(super) fn rows(entries: &[RefEntry]) -> Vec<Row> {
entries
entries_for(entries, "pipecat")
.iter()
.map(|entry| {
let samples = load_wav_f32(&fixtures_dir().join(&entry.file));
Expand All @@ -132,18 +164,11 @@ mod pipecat {
.collect()
}

fn raw_prob(pred: &TurnPrediction) -> f32 {
match pred.state {
TurnState::Finished => pred.confidence,
TurnState::Unfinished => 1.0 - pred.confidence,
TurnState::Wait => unreachable!(),
}
}

pub(super) fn run_regression(clip: &str) {
let entries = super::load_reference();
let python_prob = reference_prob(&entries, clip);
let row = rows(&[RefEntry {
backend: "pipecat".to_string(),
file: clip.to_string(),
probability: python_prob,
}])
Expand Down Expand Up @@ -173,12 +198,53 @@ mod pipecat {
}
}

// Add future audio backends here:
// ---------------------------------------------------------------------------
// WaveKat zh backend (Smart Turn fine-tune)
// ---------------------------------------------------------------------------
//
// #[cfg(feature = "livekit-audio")]
// mod livekit_audio {
// pub(super) fn rows(entries: &[super::RefEntry]) -> Vec<super::Row> { ... }
// }
// Loads `wavekat/smart-turn-ONNX` (zh) from HuggingFace on first run. Subsequent
// runs hit the HF cache under `$HF_HOME/hub/`. The shared mel/inference pipeline
// is identical to upstream Pipecat — only the weights differ — so reusing the
// pipecat helpers is intentional.

#[cfg(feature = "wavekat-smart-turn")]
mod wavekat {
    use wavekat_turn::audio::{PipecatSmartTurn, SmartTurnLang, SmartTurnVariant};
    use wavekat_turn::{AudioFrame, AudioTurnDetector};

    use super::{entries_for, fixtures_dir, load_wav_f32, raw_prob, RefEntry, Row};

    /// Builds one report row per "wavekat-zh" reference entry.
    ///
    /// Returns an empty vector, without constructing the detector, when
    /// `entries` contains no "wavekat-zh" rows. Panics if the detector cannot
    /// be constructed or if a prediction fails.
    pub(super) fn rows(entries: &[RefEntry]) -> Vec<Row> {
        let zh_entries = entries_for(entries, "wavekat-zh");
        if zh_entries.is_empty() {
            return Vec::new();
        }

        // A single detector is shared by every clip; model loading (including
        // the HF download) is the slowest step, so it happens only once.
        let variant = SmartTurnVariant::Wavekat(SmartTurnLang::Zh);
        let mut detector = PipecatSmartTurn::with_variant(variant)
            .expect("failed to load wavekat zh model from HuggingFace");

        let mut report = Vec::with_capacity(zh_entries.len());
        for entry in zh_entries {
            // Reset before each clip, since the detector is reused.
            detector.reset();

            let wav_path = fixtures_dir().join(&entry.file);
            let samples = load_wav_f32(&wav_path);
            for chunk in samples.chunks(1600) {
                detector.push_audio(&AudioFrame::new(chunk, 16_000));
            }

            let prediction = detector.predict().expect("predict failed");
            report.push(Row {
                backend: "wavekat-zh",
                clip: entry.file.clone(),
                python_prob: entry.probability,
                rust_prob: raw_prob(&prediction),
            });
        }
        report
    }
}

// ---------------------------------------------------------------------------
// Accuracy report — prints a markdown table covering all enabled backends
Expand All @@ -194,7 +260,12 @@ fn accuracy_report() {
#[allow(unused_mut)]
let mut r = Vec::new();
#[cfg(feature = "pipecat")]
r.extend(pipecat::rows(&load_reference()));
{
let entries = load_reference();
r.extend(pipecat::rows(&entries));
#[cfg(feature = "wavekat-smart-turn")]
r.extend(wavekat::rows(&entries));
}
r
};

Expand Down
13 changes: 12 additions & 1 deletion scripts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ scripts/.venv/bin/python3 scripts/gen_reference.py
| File | Description |
|------|-------------|
| `tests/fixtures/silence_2s.wav` | 2 s of zeros at 16 kHz (generated if missing) |
| `tests/fixtures/reference.json` | P(complete) for each fixture clip |
| `tests/fixtures/reference.json` | P(complete) for each `(backend, clip)` pair |

Each entry in `reference.json` is keyed by `backend` so multiple Smart Turn
variants can coexist:

- `pipecat` — upstream Pipecat Smart Turn v3 on the English fixtures, plus the
same model run on the Mandarin fixtures as a cross-lingual baseline.
- `wavekat-zh` — WaveKat zh fine-tune of Smart Turn, only run on the
Mandarin fixtures (`zh_*.wav`).

The Rust accuracy test (`make accuracy`) selects which backends to run at
compile time via Cargo features — `wavekat-smart-turn` enables the second
group — and each backend then picks out its own rows by matching the `backend`
field.

Commit both files after re-running.
Loading