From 2d9ce44329ae73af2520196d31cd14b6192ace44 Mon Sep 17 00:00:00 2001 From: Jean Du <37294470+duj12@users.noreply.github.com> Date: Wed, 2 Jul 2025 14:07:59 +0800 Subject: [PATCH] fix(asr): load VAD model on correct CUDA device (#835) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix(asr): load VAD model on correct CUDA device Previously, the VAD sub‐model was always initialized on the default CUDA device (cuda:0), even when a higher device_index was specified. This change sets `device_vad` to `cuda:{device_index}` whenever `device == 'cuda'`, while falling back to the original `device` string for non‐CUDA cases. This ensures the VAD model is loaded on the intended GPU. Co-authored-by: dujing Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com> --- whisperx/asr.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index c47651675..0b47127da 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -401,7 +401,11 @@ def load_model( if vad_method == "silero": vad_model = Silero(**default_vad_options) elif vad_method == "pyannote": - vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options) + if device == 'cuda': + device_vad = f'cuda:{device_index}' + else: + device_vad = device + vad_model = Pyannote(torch.device(device_vad), use_auth_token=None, **default_vad_options) else: raise ValueError(f"Invalid vad_method: {vad_method}")