
feat: auto-select ONNX Runtime providers (CUDA / CPU) for training, eval, and inference#70

Open
bnovik0v wants to merge 1 commit into livekit:main from bnovik0v:feat/ort-gpu-providers

Conversation

@bnovik0v

The bug

Every ort.InferenceSession in the project is created with providers=["CPUExecutionProvider"] hardcoded, in four places:

  • src/livekit/wakeword/models/feature_extractor.py:42-44 (Mel)
  • src/livekit/wakeword/models/feature_extractor.py:101-103 (Embedding)
  • src/livekit/wakeword/inference/model.py:87-89 (classifier)
  • src/livekit/wakeword/eval/evaluate.py:197 (eval)

Even if a user installs onnxruntime-gpu on a GPU pod, CUDA is never selected. The Swift package already supports pluggable execution providers (ExecutionProvider.coreML), so this gap is Python-specific.

Impact

Feature extraction during augment is the most ORT-heavy stage — roughly 17 model calls per 2 s clip (1 mel + 16 sliding embedding windows). On an RTX 3090 pod this runs on CPU at ~3 clips/sec, so a production config (~150k augmented clips across 6 splits × 2 rounds) takes ~14 hours of CPU time (150,000 clips ÷ 3 clips/s ≈ 50,000 s). With CUDA active, the same workload finished in ~13 minutes on the same pod.

The bug also affects WakeWordModel (the listener) and run_eval, both of which silently stay on CPU even when the host has onnxruntime-gpu installed.

The fix

New helper src/livekit/wakeword/_ort_providers.py with a single function:

def get_providers() -> list[str]:
    # 1. LIVEKIT_WAKEWORD_ORT_PROVIDERS env var (comma-separated) wins if set
    # 2. Otherwise intersect ort.get_available_providers() with
    #    (\"CUDAExecutionProvider\", \"CPUExecutionProvider\") — CUDA first, CPU fallback
    # 3. If neither is available, fall through to whatever ORT reports
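
Fleshed out, a minimal sketch of that logic could look like the following (illustrative; the env var name and the three steps are as described above, but the helper's exact internals in the PR may differ):

import logging
import os

import onnxruntime as ort

logger = logging.getLogger(__name__)

_PREFERRED = ("CUDAExecutionProvider", "CPUExecutionProvider")


def get_providers() -> list[str]:
    # 1. Explicit override wins: comma-separated provider names, passed through as-is.
    raw = os.environ.get("LIVEKIT_WAKEWORD_ORT_PROVIDERS")
    if raw:
        providers = [p.strip() for p in raw.split(",") if p.strip()]
    else:
        available = ort.get_available_providers()
        # 2. Intersect with the preferred pair, CUDA first, CPU as fallback.
        providers = [p for p in _PREFERRED if p in available]
        # 3. Neither preferred provider present: fall through to whatever ORT reports.
        if not providers:
            providers = available
    logger.info("Selected ONNX Runtime providers: %s", providers)
    return providers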

All four call sites now use it. Behaviour on CPU-only installs is unchanged (plain onnxruntime reports only CPUExecutionProvider).
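
At each site the change is a one-line swap (variable names here are illustrative, not the exact code in each file):

from livekit.wakeword._ort_providers import get_providers

# before
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# after
session = ort.InferenceSession(model_path, providers=get_providers())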

The env var escape hatch covers:

  • Forcing CPU on GPU hosts for reproducibility
  • Opting into CoreML / DirectML / ROCm / TensorRT providers without plumbing a config field
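
For example (illustrative shell invocations; the env var name is the one the helper reads):

# Force CPU on a GPU host for a reproducible run
LIVEKIT_WAKEWORD_ORT_PROVIDERS=CPUExecutionProvider livekit-wakeword augment configs/prod.yaml

# Opt into CoreML with CPU fallback (requires an ORT build that ships the provider)
LIVEKIT_WAKEWORD_ORT_PROVIDERS=CoreMLExecutionProvider,CPUExecutionProvider livekit-wakeword augment configs/prod.yaml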

Design decisions:

  • Helper, not duplicated call-site logic — four copies of provider selection would drift.
  • Env var, not config field or constructor arg — a config field only helps YAML-driven training, not library consumers of WakeWordModel; a constructor arg would thread through multiple public APIs. Env var works everywhere with zero surface change.
  • Default preference is CUDA / CPU only — CoreML / DirectML / ROCm / TensorRT are exotic enough to warrant explicit opt-in via the env var.
  • No warning when CPU is selected on a GPU host — can't cheaply detect "GPU hardware present" from Python without importing torch; users who suspect slowness can check the INFO log for the selected provider list.

Packaging

Adds an optional gpu extra in pyproject.toml:

gpu = [\"onnxruntime-gpu>=1.17\"]

README gets a new "GPU acceleration" subsection documenting the switch:

pip uninstall -y onnxruntime
pip install livekit-wakeword[train,eval,export,gpu]

The uninstall step is required because onnxruntime and onnxruntime-gpu share the Python module name — pip cannot keep them side-by-side. The README also points at ONNX Runtime's GPU compatibility matrix since the onnxruntime-gpu wheel bundles specific CUDA toolkit versions.
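
A quick way to confirm which wheel is active after the switch (plain Python, no project code):

import onnxruntime as ort

# With onnxruntime-gpu installed, the list should include "CUDAExecutionProvider";
# the CPU-only wheel will not report it.
print(ort.get_available_providers())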

What's in this PR

File                                               Change
src/livekit/wakeword/_ort_providers.py             New: get_providers() helper
src/livekit/wakeword/models/feature_extractor.py   Use helper in MelSpectrogramFrontend + SpeechEmbedding
src/livekit/wakeword/inference/model.py            Use helper in WakeWordModel.load_model
src/livekit/wakeword/eval/evaluate.py              Use helper in run_eval
pyproject.toml                                     New gpu optional extra
README.md                                          GPU acceleration subsection in the training install flow
tests/test_ort_providers.py                        New: 12 tests — env-var override parsing, auto-detection, filtering, logging, real-call-site smoke on MelSpectrogramFrontend
uv.lock                                            Regenerated for the new extra

Out of scope (deliberately — these are follow-ups)

  • Cross-clip batching in run_extraction: features.py processes one clip at a time. Even on GPU this leaves throughput on the table; a batched path would give another large speedup. Worth its own PR — touches the extraction loop, not provider selection.
  • Splitting augment into augment + extract CLI commands: right now run_augmentation deletes all *_rN.wav files at the top, so if feature extraction fails the user re-runs and loses the augmented clips. A --no-clean flag and/or a separate extract command is a UX concern, not a provider concern.

Verification

uv run ruff check src/livekit/wakeword/_ort_providers.py <touched files>  # clean
uv run ruff format --check <touched files>                                 # clean
uv run mypy src/livekit/wakeword/                                          # same 1 pre-existing error as main
uv run pytest tests/                                                       # 72 passed (60 existing + 12 new)

Reproduction of the original bug:

livekit-wakeword generate configs/prod.yaml
livekit-wakeword augment configs/prod.yaml
# Feature extraction sub-stage runs at ~3 cps on a GPU pod. nvidia-smi shows 0% util.

After this PR, with livekit-wakeword[train,eval,export,gpu] installed (and onnxruntime uninstalled), the same workload saturates the GPU — measured ~200 cps on RTX 3090.

feat: auto-select ONNX Runtime providers (CUDA / CPU) for training, eval, and inference

Every ort.InferenceSession in the project is currently created with
providers=["CPUExecutionProvider"] hardcoded, in four places:

- models/feature_extractor.py:42-44 (Mel)
- models/feature_extractor.py:101-103 (Embedding)
- inference/model.py:87-89 (classifier)
- eval/evaluate.py:197 (eval)

Even if a user installs `onnxruntime-gpu` on a GPU pod, CUDA is never
selected. Feature extraction during augment (the most ORT-heavy stage —
~17 model calls per clip) runs single-threaded on CPU at ~3 clips/sec
on an RTX 3090, making augment the bottleneck of the full pipeline.

This patch centralises provider selection in a new
`_ort_providers.get_providers()` helper:

- Default: intersect `ort.get_available_providers()` with
  ("CUDAExecutionProvider", "CPUExecutionProvider") — CUDA when the GPU
  wheel is installed, CPU otherwise. Zero behaviour change for CPU-only
  installs.
- Override: LIVEKIT_WAKEWORD_ORT_PROVIDERS env var (comma-separated) for
  forcing CPU for reproducibility, or opting into CoreML / DirectML /
  ROCm / TensorRT.
- All four sites now call `get_providers()` instead of hardcoding.

Also adds an optional `gpu` extra pinning `onnxruntime-gpu>=1.17`, and a
README subsection explaining the `pip uninstall onnxruntime` switch
(the CPU and GPU wheels share a Python module name and cannot coexist).

Scope limited to provider selection. Cross-clip batching in
run_extraction and splitting augment into augment + extract CLI
commands are follow-up PRs, not bundled here.
