Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ Cargo.lock
.DS_Store
.cargo/config.toml

# Claude Code runtime state
.claude/scheduled_tasks.lock

# Python tooling (scripts/)
scripts/.venv/
scripts/__pycache__/
Expand Down
10 changes: 9 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
.PHONY: help check test fmt lint doc ci accuracy mel example-controller
.PHONY: help check test fmt lint doc ci accuracy mel hf-smoke example-controller

help:
@echo "Available targets:"
@echo " check Check workspace compiles"
@echo " test Run all tests"
@echo " accuracy Cross-validate Rust pipeline against Python reference"
@echo " mel Compare Rust vs Python mel spectrograms element-wise"
@echo " hf-smoke Download wavekat/smart-turn-ONNX from HF and run zh fixtures"
@echo " fmt Format code"
@echo " lint Run clippy with warnings as errors"
@echo " doc Build and open docs in browser"
Expand All @@ -28,6 +29,13 @@ accuracy:
mel:
cargo test --features pipecat -- mel_report --ignored --nocapture

# Download wavekat/smart-turn-ONNX from HuggingFace and assert the zh fine-tune
# correctly classifies the Mandarin fixtures. Requires network on first run;
# subsequent runs hit the HF cache under $HF_HOME/hub/.
hf-smoke:
cargo test --features wavekat-smart-turn --test pipecat \
-- --ignored wavekat_hf_download_smoke --nocapture

# Run TurnController example
example-controller:
cargo run --features pipecat --example controller
Expand Down
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,15 @@ models behind common Rust traits. Same pattern as
| Backend | Feature flag | Input | Model size | Inference | License |
|---------|-------------|-------|------------|-----------|---------|
| [Pipecat Smart Turn v3](https://github.com/pipecat-ai/smart-turn) | `pipecat` | Audio (16 kHz PCM) | ~8 MB (int8 ONNX) | ~12 ms CPU | BSD 2-Clause |
| WaveKat Smart Turn fine-tunes ([HF](https://huggingface.co/wavekat/smart-turn-ONNX)) | `wavekat-smart-turn` | Audio (16 kHz PCM) | ~8 MB (int8 ONNX) | ~12 ms CPU | BSD 2-Clause |
| [LiveKit Turn Detector](https://github.com/livekit/turn-detector) | `livekit` | Text (ASR transcript) | ~400 MB (ONNX) | ~25 ms CPU | LiveKit Model License |

The WaveKat fine-tunes share the upstream Pipecat ONNX contract (same input
shape, same tensor names) — they're language-specialized weights for the
same architecture. Use them when you want better behavior on a specific
language; today Mandarin (`zh`) is the only one shipped, but more will land
in the same HF repo over time.

## Quick Start

```sh
Expand Down Expand Up @@ -92,8 +99,33 @@ wavekat-voice --> orchestrates VAD + turn + ASR + LLM + TTS
| Flag | Default | Description |
|------|---------|-------------|
| `pipecat` | off | Pipecat Smart Turn v3 audio backend (requires `ort`, `ndarray`) |
| `wavekat-smart-turn` | off | WaveKat language-specialized fine-tunes; implies `pipecat`, adds `hf-hub` runtime download |
| `livekit` | off | LiveKit text-based backend (requires `ort`, `ndarray`) |

### Selecting a Smart Turn variant

```rust
use wavekat_turn::audio::{PipecatSmartTurn, SmartTurnVariant};
# #[cfg(feature = "wavekat-smart-turn")]
use wavekat_turn::audio::SmartTurnLang;

// Embedded upstream weights — works offline, no setup.
let detector = PipecatSmartTurn::new()?;

# #[cfg(feature = "wavekat-smart-turn")]
// WaveKat Mandarin fine-tune — downloaded from HuggingFace on first call,
// then cached under $HF_HOME/hub/.
let detector = PipecatSmartTurn::with_variant(
SmartTurnVariant::Wavekat(SmartTurnLang::Zh),
)?;
```

The first call for a WaveKat variant downloads the ONNX from
[`wavekat/smart-turn-ONNX`](https://huggingface.co/wavekat/smart-turn-ONNX)
and caches it under `$HF_HOME/hub/` (default `~/.cache/huggingface/hub/`).
For offline builds, set `WAVEKAT_TURN_MODEL_DIR` to a directory containing
`<lang>/smart-turn-cpu.onnx` to skip the download.

## Important Notes

- **8 kHz telephony audio must be upsampled to 16 kHz** before passing to
Expand Down
7 changes: 7 additions & 0 deletions crates/wavekat-turn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ build = "build.rs"
default = []
pipecat = ["dep:ort", "dep:ndarray", "dep:realfft", "dep:ureq"]
livekit = ["dep:ort", "dep:ndarray"]
# WaveKat language-specialized Smart Turn fine-tunes, fetched from HuggingFace
# at runtime via `hf-hub`. The language is chosen at runtime through
# `SmartTurnVariant::Wavekat(SmartTurnLang::…)`.
wavekat-smart-turn = ["pipecat", "dep:hf-hub"]

[dependencies]
wavekat-core = "0.0.4"
Expand All @@ -26,6 +30,9 @@ thiserror = "2"
ort = { version = "2.0.0-rc.12", optional = true, features = ["ndarray"] }
ndarray = { version = "0.17", optional = true }
realfft = { version = "3", optional = true }
# Runtime HuggingFace downloads for WaveKat fine-tunes (gated on
# `wavekat-smart-turn`). A blocking ureq backend keeps us off tokio.
hf-hub = { version = "0.5", optional = true, default-features = false, features = ["ureq"] }

[build-dependencies]
ureq = { version = "3", optional = true }
Expand Down
14 changes: 13 additions & 1 deletion crates/wavekat-turn/src/audio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,21 @@
//!
//! These backends operate directly on raw audio frames and do not
//! require an upstream ASR transcript.
//!
//! [`PipecatSmartTurn`] is the entry point; [`SmartTurnVariant`] selects
//! which set of weights to load (upstream Pipecat vs WaveKat fine-tunes).
//! When the `wavekat-smart-turn` feature is enabled, [`SmartTurnLang`]
//! enumerates the language-specialized fine-tunes available on
//! HuggingFace.

#[cfg(feature = "pipecat")]
mod pipecat;

#[cfg(feature = "wavekat-smart-turn")]
pub(crate) mod wavekat_download;

#[cfg(feature = "pipecat")]
pub use pipecat::PipecatSmartTurn;
pub use pipecat::{PipecatSmartTurn, SmartTurnVariant};

#[cfg(feature = "wavekat-smart-turn")]
pub use pipecat::SmartTurnLang;
64 changes: 63 additions & 1 deletion crates/wavekat-turn/src/audio/pipecat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,40 @@ use realfft::{RealFftPlanner, RealToComplex};
use crate::onnx;
use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState};

// ---------------------------------------------------------------------------
// Model variants
// ---------------------------------------------------------------------------

/// Language for a WaveKat fine-tune of Pipecat Smart Turn.
///
/// Each variant resolves to a `<lang>/smart-turn-cpu.onnx` file inside the
/// language-agnostic HuggingFace repo `wavekat/smart-turn-ONNX`. The set is
/// marked `#[non_exhaustive]` because adding a new language must not be a
/// breaking change.
#[cfg(feature = "wavekat-smart-turn")]
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SmartTurnLang {
/// Mandarin Chinese.
Zh,
}

/// Which set of Smart Turn weights to load.
///
/// All variants share the same architecture (Whisper-Tiny encoder + binary
/// classification head) and ONNX tensor contract — only the weights differ.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SmartTurnVariant {
/// Upstream multilingual Pipecat Smart Turn v3 (embedded in the crate).
PipecatV3,
/// WaveKat language-specialized fine-tune. Resolved at runtime through
/// HuggingFace (cached under `$HF_HOME/hub/`) and overridable via
/// `WAVEKAT_TURN_MODEL_DIR`.
#[cfg(feature = "wavekat-smart-turn")]
Wavekat(SmartTurnLang),
}

// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -373,6 +407,10 @@ fn prepare_audio(samples: &[f32]) -> Vec<f32> {

/// Pipecat Smart Turn v3 detector.
///
/// Wraps the Smart Turn v3 architecture (Whisper-Tiny encoder + binary
/// classification head). Use [`new`] for the embedded upstream weights, or
/// [`with_variant`] to pick a WaveKat fine-tune at runtime.
///
/// Buffers up to 8 seconds of audio internally. Call [`push_audio`] with
/// every incoming 16 kHz frame, then call [`predict`] when the VAD fires
/// end-of-speech to get a [`TurnPrediction`].
Expand All @@ -392,6 +430,8 @@ fn prepare_audio(samples: &[f32]) -> Vec<f32> {
/// # }
/// ```
///
/// [`new`]: Self::new
/// [`with_variant`]: Self::with_variant
/// [`push_audio`]: AudioTurnDetector::push_audio
/// [`predict`]: AudioTurnDetector::predict
pub struct PipecatSmartTurn {
Expand All @@ -409,12 +449,34 @@ unsafe impl Send for PipecatSmartTurn {}
unsafe impl Sync for PipecatSmartTurn {}

impl PipecatSmartTurn {
/// Load the Smart Turn v3.2 model embedded at compile time.
/// Load the upstream Pipecat Smart Turn v3.2 model embedded at compile time.
///
/// Equivalent to [`with_variant(SmartTurnVariant::PipecatV3)`](Self::with_variant).
pub fn new() -> Result<Self, TurnError> {
let session = onnx::session_from_memory(MODEL_BYTES)?;
Ok(Self::build(session))
}

/// Load a specific variant of the Smart Turn model.
///
/// - [`SmartTurnVariant::PipecatV3`] uses the embedded ONNX bytes — no
/// network required.
/// - [`SmartTurnVariant::Wavekat`] (when the `wavekat-smart-turn` feature
/// is enabled) downloads the corresponding language file from the
/// `wavekat/smart-turn-ONNX` HuggingFace repo and caches it under
/// `$HF_HOME/hub/`. Set `WAVEKAT_TURN_MODEL_DIR` to point at a
/// pre-populated directory (offline / CI use).
pub fn with_variant(variant: SmartTurnVariant) -> Result<Self, TurnError> {
match variant {
SmartTurnVariant::PipecatV3 => Self::new(),
#[cfg(feature = "wavekat-smart-turn")]
SmartTurnVariant::Wavekat(lang) => {
let path = crate::audio::wavekat_download::resolve_model(lang)?;
Self::from_file(path)
}
}
}

/// Load a model from a custom path on disk.
///
/// Useful for CI environments that supply the model file separately, or
Expand Down
66 changes: 66 additions & 0 deletions crates/wavekat-turn/src/audio/wavekat_download.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//! Runtime download of WaveKat-trained Smart Turn weights from HuggingFace.
//!
//! Mirrors the `wavekat-tts` pattern: one language-agnostic HF repo with
//! per-language subdirectories, a dated `REVISION` pinned in code so that
//! model updates ship via a crate release, and a `WAVEKAT_TURN_MODEL_DIR`
//! escape hatch for offline / CI builds.

use std::path::PathBuf;

use hf_hub::api::sync::ApiBuilder;
use hf_hub::{Repo, RepoType};

use super::pipecat::SmartTurnLang;
use crate::error::TurnError;

/// HuggingFace repo holding all WaveKat Smart Turn fine-tunes.
const REPO_ID: &str = "wavekat/smart-turn-ONNX";

/// Pinned model revision. Bumping this string is the way to ship updated
/// weights to consumers — same pattern as `wavekat-tts`.
const REVISION: &str = "main";

/// Env var that lets callers point at a local directory containing
/// `<lang>/smart-turn-cpu.onnx`, skipping the HuggingFace download entirely.
const LOCAL_DIR_ENV: &str = "WAVEKAT_TURN_MODEL_DIR";

/// Map a language to its file path inside the HF repo.
fn relative_path(lang: SmartTurnLang) -> &'static str {
match lang {
SmartTurnLang::Zh => "zh/smart-turn-cpu.onnx",
}
}

/// Resolve the on-disk path for `lang`, downloading from HuggingFace if needed.
pub(crate) fn resolve_model(lang: SmartTurnLang) -> Result<PathBuf, TurnError> {
let rel = relative_path(lang);

if let Some(dir) = std::env::var_os(LOCAL_DIR_ENV) {
let candidate = PathBuf::from(dir).join(rel);
if !candidate.exists() {
return Err(TurnError::ModelNotLoaded(format!(
"{LOCAL_DIR_ENV} is set but {} does not exist",
candidate.display()
)));
}
return Ok(candidate);
}

let api = ApiBuilder::new()
.with_token(std::env::var("HF_TOKEN").ok())
.build()
.map_err(|e| TurnError::BackendError(format!("failed to build hf-hub client: {e}")))?;

let repo = api.repo(Repo::with_revision(
REPO_ID.to_string(),
RepoType::Model,
REVISION.to_string(),
));

repo.get(rel).map_err(|e| {
TurnError::BackendError(format!(
"failed to download {REPO_ID}@{REVISION}:{rel} from HuggingFace: {e}. \
Set {LOCAL_DIR_ENV} to a directory containing {rel} to skip the download."
))
})
}
9 changes: 8 additions & 1 deletion crates/wavekat-turn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@
//!
//! | Feature | Backend | Input |
//! |---------|---------|-------|
//! | `pipecat` | Pipecat Smart Turn v3 (ONNX) | Audio (16 kHz) |
//! | `pipecat` | Pipecat Smart Turn v3 (ONNX, embedded) | Audio (16 kHz) |
//! | `wavekat-smart-turn` | WaveKat language-specialized fine-tunes (ONNX, runtime download) | Audio (16 kHz) |
//! | `livekit` | LiveKit Turn Detector (ONNX) | Text |
//!
//! `wavekat-smart-turn` implies `pipecat` and adds an `hf-hub` runtime
//! dependency. Weights live in
//! [`wavekat/smart-turn-ONNX`](https://huggingface.co/wavekat/smart-turn-ONNX)
//! and are cached under `$HF_HOME/hub/`. Set `WAVEKAT_TURN_MODEL_DIR` to a
//! directory containing `<lang>/smart-turn-cpu.onnx` to skip the download.

pub mod controller;
pub mod error;
Expand Down
Loading