CUDA backend (NVIDIA/WSL2) + faster-than-real-time STT #7
HorizonXP wants to merge 35 commits into antirez:main
Conversation
Why WSL, though, and not native CUDA under Windows? I was also trying to get CUDA to work with Gemini; I'm not there yet, unfortunately. However, I have a working Windows CPU AVX512 build (PowerShell build script) here: https://github.com/Danmoreng/voxtral.c/tree/windows-support
Would you mind if I merge your CUDA implementation into my fork?
Fair point. To be candid, this approach reflects what I had available and what I'm comfortable running locally. The primary goal was simply to leverage the GPU where possible and prove out the core CUDA path. My hope was that this would still be reasonably portable, at least across Linux and WSL, and that it could form a solid base for broader CUDA support. I don't consider this complete yet, and I'm totally fine if you decide to merge it or not in its current form. You've already pointed out a real gap that I agree is worth addressing. I suspect it's better handled as a follow-up PR once the core CUDA functionality has landed. Supporting all CUDA environments cleanly (Linux, WSL, Windows) is doable, but the toolchains and constraints differ enough that it will likely require explicit platform conditionals and some careful structuring. My thinking was to get the fundamentals right first (correctness and meaningful GPU acceleration) and then iterate toward platform-specific polish once we know the shape of the solution is sound.
Correctness fix: if the CUDA-full decoder falls back to CPU, the host KV cache can be stale. We now track host KV validity. Cherry-picked commit: 2433096
|
Awesome, managed to get direct windows build to work on my machine and even added microphone input. works like a charm on my machine (Laptop RTX 5080 16GB). If you're interested, you can check it out here: https://github.com/Danmoreng/voxtral.c/tree/cuda-fork-merge https://github.com/Danmoreng/voxtral.c/blob/cuda-fork-merge/WINDOWS_CUDA_GUIDE.md If you need Microsoft MSVC / CUDA Toolkit under windows for building, I have a powershell script that installs all the requirements for building llama.cpp under windows already here: https://github.com/Danmoreng/llama.cpp-installer/blob/main/install_llama_cpp.ps1 |
- Add CUDA prefill (seq_len>1) to populate KV cache on-device and sync host KV. - Cache cuBLASLt descriptors/layouts for M=1 BF16 GEMMs to reduce per-call overhead. - Document new env VOX_DISABLE_CUDA_PREFILL and update benchmark notes.
- Add dynamic KV-append + dynamic attention kernels for graph capture. - Capture a single-token decoder step graph (opt-in via VOX_CUDA_GRAPHS=1). - Add bf16 cache eviction counter for observability.
- Add fused RMSNorm->BF16 kernel and use it in encoder/decoder attention norms. - Add mul_1p_rows kernel to apply ada_scale across prefill sequences in one launch. - Document CUDA Graphs opt-in and related env flags.
Force-pushed from b9569fa to 1f0efcf.
- Add fused SiLU*mul kernel (best-effort) - Support GEMM beta accumulation and use it to fold residual adds into matmuls - Apply the same fusions to prefill and graph capture paths
Performance: remove decoder prefill DtoH KV copies and skip host KV memmoves when host cache is stale.
Implements optional INT8 quantized LM head for fused top1 logits.
Summary
This PR adds an NVIDIA CUDA backend for voxtral.c (Linux/WSL2 tested) and enables faster-than-real-time speech-to-text on an RTX 3080 Ti. It also adds native Windows build support (PowerShell build + model download scripts) and Windows microphone capture via WASAPI (ported from Danmoreng's work referenced in PR comment #3867041166).
Key knobs:
- VOX_CUDA_FAST=1: convenience preset that enables the best-known decoder speedups by default (CUDA graphs + attention v4 (fused KV append; falls back to v3) + merged projections + device RoPE + fused top1-only logits when alternatives are off), unless explicitly overridden.
- VOX_CUDA_PIPELINE_FULL=1: experimental full CUDA streaming pipeline (keeps adapter embeddings on-device; thread-safe across multiple contexts/streams via serialization).
- VOX_CUDA_LOGITS_FUSED=1: top1-only logits path (avoids materializing the full logits buffer when only the best token id is needed). Enabled by default under VOX_CUDA_FAST=1 (disable with VOX_DISABLE_CUDA_LOGITS_FUSED=1).
- VOX_CUDA_LOGITS_INT8=1: opt-in INT8-quantized LM head for top1-only logits (reduces bandwidth of the vocab x dim projection). Default off; may affect accuracy.
- VOX_CUDA_CUBLASLT_MAX_WS_MB=<MB>: cap the cuBLASLt workspace used for M=1 GEMM algo selection (default: 32). Larger values can enable faster kernels at the cost of persistent VRAM.
- VOX_CUDA_LT_COMPUTE=32F|32F_FAST_16BF|32F_FAST_TF32|32F_FAST_16F: opt-in cuBLASLt compute modes for BF16 M=1 GEMMs (default: 32F). May change outputs slightly; validate with ./scripts/accuracy_regression.sh.
- VOX_DISABLE_CUBLASLT_AUTOTUNE=1: disable best-effort cuBLASLt autotune for repeated M=1 decoder GEMMs (enabled by default under VOX_CUDA_FAST=1; override with VOX_CUDA_CUBLASLT_AUTOTUNE=0/1). This can reduce prefill overhead on very short clips.

Benchmarks (RTX 3080 Ti, WSL2)
Definitions:

- Wall transcribe: wall time excluding model load.
- xRT: times real-time = audio_seconds / wall_seconds (higher is better; > 1.0x is faster-than-real-time).

All timings are from VOX_PRINT_TIMINGS=1, for samples/test_speech.wav and samples/I_have_a_dream.ogg.

Notes:

- Model load is excluded from wall transcribe and is small (hundreds of ms here; includes CUDA driver init on first run).
- VOX_CUDA_FAST=1 primarily accelerates the decoder step loop (graphs + v4 + merged weights).
- VOX_CUDA_LOGITS_INT8=1 is most useful on longer samples: it does a one-time LM-head quantize+upload on first use (INT8 weights are ~384 MiB). On very short clips, that one-time work can outweigh the per-step speedup.
- VOX_CUDA_FAST=1 also enables cuBLASLt autotune for the repeated M=1 decoder GEMMs; on very short clips the one-time tuning shows up as higher prefill. Disable via VOX_DISABLE_CUBLASLT_AUTOTUNE=1 if benchmarking minimal startup latency.

Detailed Timing Breakdown

- test_speech.wav (3.641750s)
- I_have_a_dream.ogg (180.021438s; converted to 16kHz mono WAV for the run)

How To Build / Run (Linux / WSL2)
Build:

- make cuda (requires CUDA Toolkit + nvcc)

Run:

- ./download_model.sh
- VOX_PRINT_TIMINGS=1 ./voxtral -d voxtral-model -i samples/test_speech.wav
- VOX_CUDA_FAST=1 VOX_PRINT_TIMINGS=1 ./voxtral -d voxtral-model -i samples/I_have_a_dream.ogg
- VOX_CUDA_FAST=1 VOX_CUDA_LOGITS_INT8=1 VOX_PRINT_TIMINGS=1 ./voxtral -d voxtral-model -i samples/I_have_a_dream.ogg

Benchmark helper:

- ./scripts/benchmark_backends.sh voxtral-model samples/test_speech.wav
- VOX_BENCH_SKIP_BLAS=1 ./scripts/benchmark_backends.sh voxtral-model samples/I_have_a_dream.ogg
- VOX_BENCH_CUDA_OPTS=1 ...

Windows (Native)
New files:
- WINDOWS_CUDA_GUIDE.md
- build.ps1, download_model.ps1, runtest.ps1
- voxtral_mic_win32.c (WASAPI mic)

Quickstart:
Implementation Notes
- Fused kernel (vox_rms_norm_to_bf16_ada) to combine RMSNorm + (1+ada_scale) + BF16 cast (reduces per-step kernel count).
- Kernels ship as a precompiled nvcc -cubin blob (voxtral_cuda_kernels_cubin.h) to avoid PTX JIT compatibility issues on WSL2.
- Memory knobs: VOX_CUDA_PREFETCH=1, VOX_CUDA_HOSTREG_GIB=<GiB>, async alloc mempool (default on; disable via VOX_DISABLE_CUDA_MEMPOOL=1).

Validation
Ran (WSL2):
- ./scripts/validate_cuda.sh voxtral-model samples/test_speech.wav
- ./scripts/validate_cuda_pipeline_compact.sh voxtral-model samples/antirez_speaking_italian_short.ogg
- ./scripts/stress_cuda_two_streams.sh voxtral-model samples/test_speech.wav
- ./scripts/accuracy_regression.sh voxtral-model samples/test_speech.wav 0
- ./runtest.sh

Credits
- Windows build and microphone support ported from Danmoreng's fork (branch cuda-fork-merge) and adapted into this PR.