diff --git a/whisperx/asr.py b/whisperx/asr.py index c47651675..0b47127da 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -401,7 +401,11 @@ def load_model( if vad_method == "silero": vad_model = Silero(**default_vad_options) elif vad_method == "pyannote": - vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options) + if device == 'cuda': + device_vad = f'cuda:{device_index}' + else: + device_vad = device + vad_model = Pyannote(torch.device(device_vad), use_auth_token=None, **default_vad_options) else: raise ValueError(f"Invalid vad_method: {vad_method}")