feat: auto-select ONNX Runtime providers (CUDA / CPU) for training, eval, and inference #70
Open
bnovik0v wants to merge 1 commit into livekit:main
feat: auto-select ONNX Runtime providers (CUDA / CPU) for training, eval, and inference
Every ort.InferenceSession in the project is currently created with
providers=["CPUExecutionProvider"] hardcoded, in four places:
- models/feature_extractor.py:42-44 (Mel)
- models/feature_extractor.py:101-103 (Embedding)
- inference/model.py:87-89 (classifier)
- eval/evaluate.py:197 (eval)
Even if a user installs `onnxruntime-gpu` on a GPU pod, CUDA is never
selected. Feature extraction during augment (the most ORT-heavy stage —
~17 model calls per clip) runs single-threaded on CPU at ~3 clips/sec
on an RTX 3090, making augment the bottleneck of the full pipeline.
This patch centralises provider selection in a new
`_ort_providers.get_providers()` helper:
- Default: intersect `ort.get_available_providers()` with
("CUDAExecutionProvider", "CPUExecutionProvider") — CUDA when the GPU
wheel is installed, CPU otherwise. Zero behaviour change for CPU-only
installs.
- Override: LIVEKIT_WAKEWORD_ORT_PROVIDERS env var (comma-separated) for
forcing CPU for reproducibility, or opting into CoreML / DirectML /
ROCm / TensorRT.
- All four sites now call `get_providers()` instead of hardcoding.
Also adds an optional `gpu` extra pinning `onnxruntime-gpu>=1.17`, and a
README subsection explaining the `pip uninstall onnxruntime` switch
(the CPU and GPU wheels share a Python module name and cannot coexist).
Scope limited to provider selection. Cross-clip batching in
run_extraction and splitting augment into augment + extract CLI
commands are follow-up PRs, not bundled here.
The bug
Every `ort.InferenceSession` in the project is created with `providers=["CPUExecutionProvider"]` hardcoded, in four places:

- `src/livekit/wakeword/models/feature_extractor.py:42-44` (Mel)
- `src/livekit/wakeword/models/feature_extractor.py:101-103` (Embedding)
- `src/livekit/wakeword/inference/model.py:87-89` (classifier)
- `src/livekit/wakeword/eval/evaluate.py:197` (eval)

Even if a user installs `onnxruntime-gpu` on a GPU pod, CUDA is never selected. The Swift package already supports pluggable execution providers (`ExecutionProvider.coreML`), so this gap is Python-specific.

Impact
Feature extraction during `augment` is the most ORT-heavy stage: roughly 17 model calls per 2 s clip (1 mel + 16 sliding embedding windows). On an RTX 3090 pod this runs on CPU at ~3 clips/sec. A production config (~150k augmented clips across 6 splits × 2 rounds) takes ~14 hours of CPU time; with CUDA active, the same workload finished in ~13 minutes on the same pod.

The bug also affects `WakeWordModel` (the listener) and `run_eval`, both of which silently stay on CPU even when the host has `onnxruntime-gpu` installed.

The fix
New helper `src/livekit/wakeword/_ort_providers.py` with a single function; all four call sites now use it. Behaviour on CPU-only installs is unchanged (plain `onnxruntime` reports only `CPUExecutionProvider`).

The env var escape hatch covers forcing CPU for reproducibility as well as opting into CoreML / DirectML / ROCm / TensorRT.
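As a sketch of the behaviour described above (the `available` parameter is added here purely for testability and is not necessarily part of the actual helper's signature):

```python
import os

_PREFERRED = ("CUDAExecutionProvider", "CPUExecutionProvider")

def get_providers(available=None):
    """Return the ORT provider list: env var override wins, otherwise the
    ordered intersection of the available providers with (CUDA, CPU)."""
    override = os.environ.get("LIVEKIT_WAKEWORD_ORT_PROVIDERS")
    if override:
        # Comma-separated list, e.g. "CoreMLExecutionProvider,CPUExecutionProvider"
        return [p.strip() for p in override.split(",") if p.strip()]
    if available is None:
        import onnxruntime as ort  # deferred so this sketch imports without ORT
        available = ort.get_available_providers()
    return [p for p in _PREFERRED if p in available] or ["CPUExecutionProvider"]
```

A call site then becomes `ort.InferenceSession(path, providers=get_providers())` instead of hardcoding the CPU provider.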
Design decisions:

- Env var rather than a constructor argument on `WakeWordModel`; a constructor arg would thread through multiple public APIs. The env var works everywhere with zero surface change.

Packaging
Adds an optional `gpu` extra in `pyproject.toml`. The README gets a new "GPU acceleration" subsection documenting the switch.
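Sketches of what these two additions plausibly look like (the `>=1.17` pin is from the commit message; the exact section layout and README wording are assumptions, not the actual diff):

```toml
# pyproject.toml (sketch)
[project.optional-dependencies]
gpu = ["onnxruntime-gpu>=1.17"]
```

```sh
pip uninstall -y onnxruntime        # CPU and GPU wheels cannot coexist
pip install "livekit-wakeword[gpu]" # pulls in onnxruntime-gpu
```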
The uninstall step is required because `onnxruntime` and `onnxruntime-gpu` share the Python module name; pip cannot keep them side by side. The README also points at ONNX Runtime's GPU compatibility matrix, since the `onnxruntime-gpu` wheel bundles specific CUDA toolkit versions.

What's in this PR
| File | Change |
| --- | --- |
| `src/livekit/wakeword/_ort_providers.py` | new `get_providers()` helper |
| `src/livekit/wakeword/models/feature_extractor.py` | `MelSpectrogramFrontend` + `SpeechEmbedding` |
| `src/livekit/wakeword/inference/model.py` | `WakeWordModel.load_model` |
| `src/livekit/wakeword/eval/evaluate.py` | `run_eval` |
| `pyproject.toml` | `gpu` optional extra |
| `README.md` | GPU acceleration subsection |
| `tests/test_ort_providers.py` | new tests (includes `MelSpectrogramFrontend`) |
| `uv.lock` | lockfile update |

Out of scope (deliberately: these are follow-ups)
- Cross-clip batching in `run_extraction`: `features.py` processes one clip at a time. Even on GPU this leaves throughput on the table; a batched path would give another large speedup. Worth its own PR, since it touches the extraction loop, not provider selection.
- Splitting `augment` into `augment` + `extract` CLI commands: right now `run_augmentation` deletes all `*_rN.wav` files at the top, so if feature extraction fails the user re-runs and loses the augmented clips. A `--no-clean` flag and/or a separate `extract` command is a UX concern, not a provider concern.

Verification
Reproduction of the original bug:

```sh
livekit-wakeword generate configs/prod.yaml
livekit-wakeword augment configs/prod.yaml
# Feature extraction sub-stage runs at ~3 cps on a GPU pod. nvidia-smi shows 0% util.
```

After this PR, with `livekit-wakeword[train,eval,export,gpu]` installed (and `onnxruntime` uninstalled), the same workload saturates the GPU: measured ~200 cps on an RTX 3090.
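The listed `tests/test_ort_providers.py` is not shown in this description; as an illustration, the selection rule can be unit-tested without a GPU by faking the available-provider list. The names below are illustrative, not the PR's actual tests:

```python
# Self-contained mirror of the selection rule, so the test needs neither a GPU
# nor an onnxruntime install. The real tests presumably import
# livekit.wakeword._ort_providers.get_providers instead.
PREFERRED = ("CUDAExecutionProvider", "CPUExecutionProvider")

def select(available, override=""):
    # Env var override wins; otherwise ordered intersection with PREFERRED.
    if override:
        return [p.strip() for p in override.split(",") if p.strip()]
    return [p for p in PREFERRED if p in available] or ["CPUExecutionProvider"]

# CPU-only wheel: behaviour unchanged
assert select(["CPUExecutionProvider"]) == ["CPUExecutionProvider"]
# GPU wheel: CUDA preferred, CPU kept as fallback
assert select(["CUDAExecutionProvider", "CPUExecutionProvider"]) == [
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]
# Override forces CPU even when CUDA is available (reproducibility runs)
assert select(["CUDAExecutionProvider", "CPUExecutionProvider"],
              override="CPUExecutionProvider") == ["CPUExecutionProvider"]
```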