diff --git a/backend/parakeet/stream_handler.py b/backend/parakeet/stream_handler.py index dff7640d2f2..988f3ecb22f 100644 --- a/backend/parakeet/stream_handler.py +++ b/backend/parakeet/stream_handler.py @@ -23,13 +23,13 @@ from langdetect import detect as langdetect_detect from langdetect.lang_detect_exception import LangDetectException from scipy.spatial.distance import cdist -from transcribe import transcribe_file, _stream_model as _asr_model, INFERENCE_MODE as _INFERENCE_MODE - -try: - from pyannote.audio import Model as _PyannoteModel, Inference as _PyannoteInference -except ImportError: - _PyannoteModel = None - _PyannoteInference = None +from transcribe import ( + transcribe_file, + _stream_model as _asr_model, + INFERENCE_MODE as _INFERENCE_MODE, + get_builtin_embedding_model, + wav_bytes_to_waveform, +) logger = logging.getLogger(__name__) @@ -45,35 +45,6 @@ SPEAKER_EMBEDDING_URL = os.getenv("HOSTED_SPEAKER_EMBEDDING_API_URL", "") MIN_EMBEDDING_AUDIO_S = 0.5 -_embedding_model = None -_embedding_lock = threading.Lock() - - -def _get_builtin_embedding_model(): - global _embedding_model - if _embedding_model is not None: - return _embedding_model - with _embedding_lock: - if _embedding_model is not None: - return _embedding_model - try: - if _PyannoteModel is None or _PyannoteInference is None: - logger.warning("pyannote.audio not installed, built-in embedding unavailable") - return None - model = _PyannoteModel.from_pretrained( - "pyannote/wespeaker-voxceleb-resnet34-LM", token=os.getenv("HUGGINGFACE_TOKEN") - ) - inference = _PyannoteInference(model, window="whole") - if _torch is not None: - device = _torch.device("cuda" if _torch.cuda.is_available() else "cpu") - inference.to(device) - _embedding_model = inference - logger.info("Built-in speaker embedding model loaded (wespeaker-voxceleb-resnet34-LM)") - return _embedding_model - except Exception as e: - logger.warning(f"Could not load built-in embedding model: {e}") - return None - _vad_model = None _vad_lock = threading.Lock() @@ -88,11 +59,6 @@ def _get_builtin_embedding_model(): except ImportError: _torch = None -try: - import torchaudio -except ImportError: - torchaudio = None - def _make_divisible_by(num, factor: int) -> int: return (num // factor) * factor @@ -728,7 +694,7 @@ def _assign_speaker(self, pcm: bytes, start: float, end: float) -> str: return f"SPEAKER_{self._last_speaker}" def _get_embedding(self, wav_bytes: bytes): - model = _get_builtin_embedding_model() + model = get_builtin_embedding_model() if model is not None: return self._get_embedding_builtin(wav_bytes, model) if SPEAKER_EMBEDDING_URL: @@ -737,8 +703,7 @@ def _get_embedding(self, wav_bytes: bytes): def _get_embedding_builtin(self, wav_bytes: bytes, model): try: - buf = io.BytesIO(wav_bytes) - waveform, sample_rate = torchaudio.load(buf) + waveform, sample_rate = wav_bytes_to_waveform(wav_bytes) dur = waveform.shape[1] / sample_rate if dur < MIN_EMBEDDING_AUDIO_S: return None diff --git a/backend/parakeet/transcribe.py b/backend/parakeet/transcribe.py index 42d95b45160..e6389c8059d 100644 --- a/backend/parakeet/transcribe.py +++ b/backend/parakeet/transcribe.py @@ -1,6 +1,7 @@ import io import os import logging +import threading import wave as _wave import httpx @@ -24,6 +25,67 @@ except ImportError: nemo_asr = None +try: + import torch as _torch +except ImportError: + _torch = None + +try: + from pyannote.audio import Model as _PyannoteModel, Inference as _PyannoteInference +except ImportError: + _PyannoteModel = None + _PyannoteInference = None + +_embedding_model = None +_embedding_lock = threading.Lock() + + +def get_builtin_embedding_model(): + global _embedding_model + if _embedding_model is not None: + return _embedding_model + with _embedding_lock: + if _embedding_model is not None: + return _embedding_model + try: + if _PyannoteModel is None or _PyannoteInference is None: + logger.warning("pyannote.audio not installed, built-in embedding unavailable") + return None + model = _PyannoteModel.from_pretrained( + "pyannote/wespeaker-voxceleb-resnet34-LM", token=os.getenv("HUGGINGFACE_TOKEN") + ) + inference = _PyannoteInference(model, window="whole") + if _torch is not None and _torch.cuda.is_available(): + inference.to(_torch.device("cuda")) + _embedding_model = inference + logger.info("Built-in speaker embedding model loaded (wespeaker-voxceleb-resnet34-LM)") + return _embedding_model + except Exception as e: + logger.warning(f"Could not load built-in embedding model: {e}") + return None + + +def wav_bytes_to_waveform(wav_bytes: bytes): + buf = io.BytesIO(wav_bytes) + with _wave.open(buf, "rb") as wf: + sr = wf.getframerate() + nch = wf.getnchannels() + sw = wf.getsampwidth() + pcm = wf.readframes(wf.getnframes()) + + if sw == 1: + samples = np.frombuffer(pcm, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0 + elif sw == 2: + samples = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0 + elif sw == 4: + samples = np.frombuffer(pcm, dtype=np.int32).astype(np.float32) / 2147483648.0 + else: + raise ValueError(f"Unsupported WAV sample width: {sw} bytes") + if nch > 1: + samples = samples.reshape(-1, nch).mean(axis=1) + waveform = _torch.from_numpy(samples).unsqueeze(0) + return waveform, sr + def set_gpu_worker(worker) -> None: global _gpu_worker @@ -197,7 +259,7 @@ def _transcribe_nim(file_path: str): def _diarize_segments(file_path: str, base: dict) -> dict: - if not SPEAKER_EMBEDDING_URL: + if not SPEAKER_EMBEDDING_URL and get_builtin_embedding_model() is None: for seg in base["segments"]: seg["speaker"] = "SPEAKER_0" return base @@ -270,7 +332,33 @@ def _extract_segment_wav(wav_bytes: bytes, start: float, end: float) -> bytes: def _get_embedding(wav_bytes: bytes): + model = get_builtin_embedding_model() + if model is not None: + emb = _get_embedding_builtin(wav_bytes, model) + if emb is not None: + return emb + if SPEAKER_EMBEDDING_URL: + return _get_embedding_http(wav_bytes) + return None + + +def _get_embedding_builtin(wav_bytes: bytes, model): + try: + waveform, sample_rate = wav_bytes_to_waveform(wav_bytes) + dur = waveform.shape[1] / sample_rate + if dur < MIN_SEGMENT_DURATION: + return None + emb = model({"waveform": waveform, "sample_rate": sample_rate}) + emb = np.array(emb, dtype=np.float32) + if emb.ndim == 1: + emb = emb.reshape(1, -1) + return emb + except Exception as e: + logger.warning(f"Built-in embedding failed: {e}") + return None + +def _get_embedding_http(wav_bytes: bytes): try: with httpx.Client(timeout=httpx.Timeout(connect=5.0, read=30.0, write=10.0, pool=5.0)) as client: resp = client.post( @@ -289,5 +377,5 @@ def _get_embedding(wav_bytes: bytes): emb = emb.reshape(1, -1) return emb except Exception as e: - logger.warning(f"Embedding extraction failed: {e}") + logger.warning(f"HTTP embedding failed: {e}") return None diff --git a/backend/test.sh b/backend/test.sh index 3f9b65880e0..bb1febc24d8 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -25,6 +25,7 @@ pytest tests/unit/test_parakeet_stream_session.py -v pytest tests/unit/test_parakeet_gpu_worker.py -v pytest tests/unit/test_parakeet_batch_engine.py -v pytest tests/unit/test_parakeet_batch_routing.py -v +pytest tests/unit/test_parakeet_builtin_embedding.py -v pytest tests/unit/test_parakeet_endpoints.py -v pytest tests/unit/test_audiobuffer_guard.py -v pytest tests/unit/test_memory_leak_buffers.py -v diff --git a/backend/tests/unit/test_parakeet_builtin_embedding.py b/backend/tests/unit/test_parakeet_builtin_embedding.py new file mode 100644 index 00000000000..754359413bb --- /dev/null +++ b/backend/tests/unit/test_parakeet_builtin_embedding.py @@ -0,0 +1,324 @@ +"""Tests for built-in speaker embedding in parakeet transcribe.py. + +Validates that batch diarization uses the built-in wespeaker model first +and falls back to HTTP only when the built-in model is unavailable. +""" + +import io +import os +import struct +import sys +import wave +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +os.environ.setdefault('PARAKEET_INFERENCE_MODE', 'nemo') +os.environ.setdefault('PARAKEET_STREAM_MODEL', '') +os.environ.setdefault('PARAKEET_DEVICE', 'cpu') +os.environ.setdefault('PARAKEET_TORCH_COMPILE', 'false') +os.environ.setdefault('PARAKEET_CUDA_GRAPHS', 'false') + +_torch_mock = MagicMock() +_torch_mock.cuda.is_available.return_value = False +_torch_mock.cuda.is_bf16_supported.return_value = False +_torch_mock.bfloat16 = 'bfloat16' + + +def _torch_from_numpy(arr): + result = MagicMock() + result.unsqueeze.return_value = result + result.shape = [1, len(arr)] + return result + + +_torch_mock.from_numpy = _torch_from_numpy +sys.modules['torch'] = _torch_mock + +for _mod in [ + 'nemo', + 'nemo.collections', + 'nemo.collections.asr', + 'nemo.collections.asr.models', + 'pyannote', + 'pyannote.audio', +]: + if _mod not in sys.modules: + sys.modules[_mod] = MagicMock() + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../parakeet')) + +if 'transcribe' in sys.modules: + _existing = sys.modules['transcribe'] + if not hasattr(_existing, '__file__') or _existing.__file__ is None: + del sys.modules['transcribe'] + +import transcribe # noqa: E402 + + +def _make_wav_bytes(duration_s=1.0, sample_rate=16000, channels=1): + n_samples = int(duration_s * sample_rate) + buf = io.BytesIO() + with wave.open(buf, 'wb') as wf: + wf.setnchannels(channels) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(struct.pack(f'<{n_samples * channels}h', *([1000] * n_samples * channels))) + return buf.getvalue() + + +def _make_wav_bytes_8bit(duration_s=1.0, sample_rate=16000): + n_samples = int(duration_s * sample_rate) + buf = io.BytesIO() + with wave.open(buf, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(1) + wf.setframerate(sample_rate) + wf.writeframes(bytes([128] * n_samples)) + return buf.getvalue() + + +def _make_wav_bytes_32bit(duration_s=1.0, sample_rate=16000): + n_samples = int(duration_s * sample_rate) + buf = io.BytesIO() + with wave.open(buf, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(4) + wf.setframerate(sample_rate) + wf.writeframes(struct.pack(f'<{n_samples}i', *([100000] * n_samples))) + return buf.getvalue() + + +class TestWavBytesToWaveform: + def test_returns_waveform_and_sample_rate(self): + wav = _make_wav_bytes(duration_s=0.5, sample_rate=16000) + waveform, sr = transcribe.wav_bytes_to_waveform(wav) + assert sr == 16000 + assert waveform.shape == [1, 8000] + + def test_stereo_downmix(self): + wav = _make_wav_bytes(duration_s=0.5, sample_rate=16000, channels=2) + waveform, sr = transcribe.wav_bytes_to_waveform(wav) + assert sr == 16000 + assert waveform.shape == [1, 8000] + + def test_8bit_unsigned_pcm(self): + wav = _make_wav_bytes_8bit(duration_s=0.5, sample_rate=16000) + waveform, sr = transcribe.wav_bytes_to_waveform(wav) + assert sr == 16000 + assert waveform.shape == [1, 8000] + + def test_32bit_pcm(self): + wav = _make_wav_bytes_32bit(duration_s=0.5, sample_rate=16000) + waveform, sr = transcribe.wav_bytes_to_waveform(wav) + assert sr == 16000 + assert waveform.shape == [1, 8000] + + def test_unsupported_width_raises(self): + buf = io.BytesIO() + with wave.open(buf, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(3) + wf.setframerate(16000) + wf.writeframes(b'\x00\x00\x00' * 8000) + with pytest.raises(ValueError, match="Unsupported WAV sample width"): + transcribe.wav_bytes_to_waveform(buf.getvalue()) + + +class TestGetEmbedding: + def test_uses_builtin_model_first(self): + fake_model = MagicMock() + fake_model.return_value = np.zeros(256, dtype=np.float32) + wav = _make_wav_bytes(duration_s=1.0) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + with patch.object(transcribe, '_get_embedding_http') as http_mock: + result = transcribe._get_embedding(wav) + + assert result is not None + assert result.shape == (1, 256) + http_mock.assert_not_called() + + def test_falls_back_to_http_when_builtin_unavailable(self): + wav = _make_wav_bytes(duration_s=1.0) + http_emb = np.ones((1, 256), dtype=np.float32) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=None): + with patch.object(transcribe, '_get_embedding_http', return_value=http_emb) as http_mock: + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = 'http://fake-diarizer' + try: + result = transcribe._get_embedding(wav) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + http_mock.assert_called_once_with(wav) + assert result is not None + np.testing.assert_array_equal(result, http_emb) + + def test_falls_back_to_http_when_builtin_fails(self): + fake_model = MagicMock() + fake_model.side_effect = RuntimeError("GPU error") + wav = _make_wav_bytes(duration_s=1.0) + http_emb = np.ones((1, 256), dtype=np.float32) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + with patch.object(transcribe, '_get_embedding_http', return_value=http_emb) as http_mock: + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = 'http://fake-diarizer' + try: + result = transcribe._get_embedding(wav) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + http_mock.assert_called_once_with(wav) + np.testing.assert_array_equal(result, http_emb) + + def test_returns_none_when_no_builtin_no_url(self): + wav = _make_wav_bytes(duration_s=1.0) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=None): + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = '' + try: + result = transcribe._get_embedding(wav) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + assert result is None + + def test_returns_none_when_builtin_fails_and_http_fails(self): + fake_model = MagicMock() + fake_model.side_effect = RuntimeError("GPU error") + wav = _make_wav_bytes(duration_s=1.0) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + with patch.object(transcribe, '_get_embedding_http', return_value=None): + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = 'http://fake-diarizer' + try: + result = transcribe._get_embedding(wav) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + assert result is None + + def test_reshapes_1d_embedding(self): + fake_model = MagicMock() + fake_model.return_value = np.zeros(128, dtype=np.float32) + wav = _make_wav_bytes(duration_s=1.0) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + result = transcribe._get_embedding(wav) + + assert result.shape == (1, 128) + + +class TestGetBuiltinEmbeddingModel: + def test_returns_none_when_pyannote_unavailable(self): + old_model = transcribe._embedding_model + transcribe._embedding_model = None + old_pyannote = transcribe._PyannoteModel + transcribe._PyannoteModel = None + try: + result = transcribe.get_builtin_embedding_model() + assert result is None + finally: + transcribe._PyannoteModel = old_pyannote + transcribe._embedding_model = old_model + + def test_returns_cached_model_without_reload(self): + sentinel = MagicMock() + old_model = transcribe._embedding_model + transcribe._embedding_model = sentinel + try: + result1 = transcribe.get_builtin_embedding_model() + result2 = transcribe.get_builtin_embedding_model() + assert result1 is sentinel + assert result2 is sentinel + finally: + transcribe._embedding_model = old_model + + def test_successful_load_is_cached(self): + old_model = transcribe._embedding_model + old_pyannote_model = transcribe._PyannoteModel + old_pyannote_inference = transcribe._PyannoteInference + fake_inference = MagicMock() + fake_pyannote_model = MagicMock() + fake_pyannote_model.from_pretrained.return_value = MagicMock() + fake_pyannote_inference = MagicMock(return_value=fake_inference) + transcribe._embedding_model = None + transcribe._PyannoteModel = fake_pyannote_model + transcribe._PyannoteInference = fake_pyannote_inference + try: + result = transcribe.get_builtin_embedding_model() + assert result is fake_inference + assert transcribe._embedding_model is fake_inference + finally: + transcribe._PyannoteModel = old_pyannote_model + transcribe._PyannoteInference = old_pyannote_inference + transcribe._embedding_model = old_model + + +class TestEmbeddingBuiltinDuration: + def test_short_audio_below_min_duration_returns_none(self): + fake_model = MagicMock() + wav = _make_wav_bytes(duration_s=0.3) + result = transcribe._get_embedding_builtin(wav, fake_model) + assert result is None + fake_model.assert_not_called() + + def test_audio_at_exact_min_duration_returns_embedding(self): + fake_model = MagicMock() + fake_model.return_value = np.zeros(256, dtype=np.float32) + wav = _make_wav_bytes(duration_s=0.6) + result = transcribe._get_embedding_builtin(wav, fake_model) + assert result is not None + fake_model.assert_called_once() + + def test_audio_just_above_min_duration_returns_embedding(self): + fake_model = MagicMock() + fake_model.return_value = np.zeros(256, dtype=np.float32) + wav = _make_wav_bytes(duration_s=0.7) + result = transcribe._get_embedding_builtin(wav, fake_model) + assert result is not None + fake_model.assert_called_once() + + +class TestDiarizeSegmentsGating: + def test_proceeds_with_builtin_model_even_without_url(self, tmp_path): + wav_path = tmp_path / "test.wav" + wav_bytes = _make_wav_bytes(duration_s=2.0) + wav_path.write_bytes(wav_bytes) + + base = {"text": "hello", "segments": [{"text": "hello", "start": 0.0, "end": 2.0}]} + fake_model = MagicMock() + fake_emb = np.zeros((1, 256), dtype=np.float32) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + with patch.object(transcribe, '_get_embedding', return_value=fake_emb): + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = '' + try: + result = transcribe._diarize_segments(str(wav_path), base) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + assert result["segments"][0].get("speaker") is not None + + def test_skips_diarization_when_no_model_and_no_url(self, tmp_path): + wav_path = tmp_path / "test.wav" + wav_path.write_bytes(_make_wav_bytes(duration_s=1.0)) + + base = {"text": "hi", "segments": [{"text": "hi", "start": 0.0, "end": 1.0}]} + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=None): + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = '' + try: + result = transcribe._diarize_segments(str(wav_path), base) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + assert result["segments"][0]["speaker"] == "SPEAKER_0"