From f7920bd1e599268172c12ef9a676c45a515652b4 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 21 Jun 2026 08:01:22 +0000 Subject: [PATCH 1/4] Use built-in wespeaker model for batch diarization embeddings (#8081) Batch /v2/transcribe was making external HTTP calls to the diarizer service for every audio segment (~18 req/sec at peak). The streaming path already loads wespeaker-voxceleb-resnet34-LM locally but the batch path never used it. Changes: - Move embedding model singleton and WAV loader into transcribe.py (avoids circular import since stream_handler imports from transcribe) - Batch _get_embedding() now tries built-in model first, HTTP fallback - stream_handler.py imports shared helpers instead of duplicating them - Replace torchaudio.load() with wave+numpy+torch (torchaudio is a stub in the Docker image) - 9 new unit tests covering built-in priority, HTTP fallback, and gating Co-Authored-By: Claude Opus 4.6 --- backend/parakeet/stream_handler.py | 53 +---- backend/parakeet/transcribe.py | 95 ++++++++- backend/test.sh | 1 + .../unit/test_parakeet_builtin_embedding.py | 193 ++++++++++++++++++ 4 files changed, 296 insertions(+), 46 deletions(-) create mode 100644 backend/tests/unit/test_parakeet_builtin_embedding.py diff --git a/backend/parakeet/stream_handler.py b/backend/parakeet/stream_handler.py index dff7640d2f2..988f3ecb22f 100644 --- a/backend/parakeet/stream_handler.py +++ b/backend/parakeet/stream_handler.py @@ -23,13 +23,13 @@ from langdetect import detect as langdetect_detect from langdetect.lang_detect_exception import LangDetectException from scipy.spatial.distance import cdist -from transcribe import transcribe_file, _stream_model as _asr_model, INFERENCE_MODE as _INFERENCE_MODE - -try: - from pyannote.audio import Model as _PyannoteModel, Inference as _PyannoteInference -except ImportError: - _PyannoteModel = None - _PyannoteInference = None +from transcribe import ( + transcribe_file, + _stream_model as _asr_model, + INFERENCE_MODE as _INFERENCE_MODE, + get_builtin_embedding_model, + wav_bytes_to_waveform, +) logger = logging.getLogger(__name__) @@ -45,35 +45,6 @@ SPEAKER_EMBEDDING_URL = os.getenv("HOSTED_SPEAKER_EMBEDDING_API_URL", "") MIN_EMBEDDING_AUDIO_S = 0.5 -_embedding_model = None -_embedding_lock = threading.Lock() - - -def _get_builtin_embedding_model(): - global _embedding_model - if _embedding_model is not None: - return _embedding_model - with _embedding_lock: - if _embedding_model is not None: - return _embedding_model - try: - if _PyannoteModel is None or _PyannoteInference is None: - logger.warning("pyannote.audio not installed, built-in embedding unavailable") - return None - model = _PyannoteModel.from_pretrained( - "pyannote/wespeaker-voxceleb-resnet34-LM", token=os.getenv("HUGGINGFACE_TOKEN") - ) - inference = _PyannoteInference(model, window="whole") - if _torch is not None: - device = _torch.device("cuda" if _torch.cuda.is_available() else "cpu") - inference.to(device) - _embedding_model = inference - logger.info("Built-in speaker embedding model loaded (wespeaker-voxceleb-resnet34-LM)") - return _embedding_model - except Exception as e: - logger.warning(f"Could not load built-in embedding model: {e}") - return None - _vad_model = None _vad_lock = threading.Lock() @@ -88,11 +59,6 @@ def _get_builtin_embedding_model(): except ImportError: _torch = None -try: - import torchaudio -except ImportError: - torchaudio = None - def _make_divisible_by(num, factor: int) -> int: return (num // factor) * factor @@ -728,7 +694,7 @@ def _assign_speaker(self, pcm: bytes, start: float, end: float) -> str: return f"SPEAKER_{self._last_speaker}" def _get_embedding(self, wav_bytes: bytes): - model = _get_builtin_embedding_model() + model = get_builtin_embedding_model() if model is not None: return self._get_embedding_builtin(wav_bytes, model) if SPEAKER_EMBEDDING_URL: @@ -737,8 +703,7 @@ def _get_embedding(self, wav_bytes: bytes): def _get_embedding_builtin(self, wav_bytes: bytes, model): try: - buf = io.BytesIO(wav_bytes) - waveform, sample_rate = torchaudio.load(buf) + waveform, sample_rate = wav_bytes_to_waveform(wav_bytes) dur = waveform.shape[1] / sample_rate if dur < MIN_EMBEDDING_AUDIO_S: return None diff --git a/backend/parakeet/transcribe.py b/backend/parakeet/transcribe.py index 42d95b45160..4ebc85032c0 100644 --- a/backend/parakeet/transcribe.py +++ b/backend/parakeet/transcribe.py @@ -1,6 +1,7 @@ import io import os import logging +import threading import wave as _wave import httpx @@ -24,6 +25,70 @@ except ImportError: nemo_asr = None +try: + import torch as _torch +except ImportError: + _torch = None + +try: + from pyannote.audio import Model as _PyannoteModel, Inference as _PyannoteInference +except ImportError: + _PyannoteModel = None + _PyannoteInference = None + +_embedding_model = None +_embedding_lock = threading.Lock() + + +def get_builtin_embedding_model(): + global _embedding_model + if _embedding_model is not None: + return _embedding_model + with _embedding_lock: + if _embedding_model is not None: + return _embedding_model + try: + if _PyannoteModel is None or _PyannoteInference is None: + logger.warning("pyannote.audio not installed, built-in embedding unavailable") + return None + model = _PyannoteModel.from_pretrained( + "pyannote/wespeaker-voxceleb-resnet34-LM", token=os.getenv("HUGGINGFACE_TOKEN") + ) + inference = _PyannoteInference(model, window="whole") + if _torch is not None and _torch.cuda.is_available(): + inference.to(_torch.device("cuda")) + _embedding_model = inference + logger.info("Built-in speaker embedding model loaded (wespeaker-voxceleb-resnet34-LM)") + return _embedding_model + except Exception as e: + logger.warning(f"Could not load built-in embedding model: {e}") + return None + + +def wav_bytes_to_waveform(wav_bytes: bytes): + buf = io.BytesIO(wav_bytes) + with _wave.open(buf, "rb") as wf: + sr = wf.getframerate() + nch = wf.getnchannels() + sw = wf.getsampwidth() + pcm = wf.readframes(wf.getnframes()) + + if sw == 2: + dtype = np.int16 + divisor = 32768.0 + elif sw == 4: + dtype = np.int32 + divisor = 2147483648.0 + else: + dtype = np.int16 + divisor = 32768.0 + + samples = np.frombuffer(pcm, dtype=dtype).astype(np.float32) / divisor + if nch > 1: + samples = samples.reshape(-1, nch).mean(axis=1) + waveform = _torch.from_numpy(samples).unsqueeze(0) + return waveform, sr + def set_gpu_worker(worker) -> None: global _gpu_worker @@ -197,7 +262,7 @@ def _transcribe_nim(file_path: str): def _diarize_segments(file_path: str, base: dict) -> dict: - if not SPEAKER_EMBEDDING_URL: + if not SPEAKER_EMBEDDING_URL and get_builtin_embedding_model() is None: for seg in base["segments"]: seg["speaker"] = "SPEAKER_0" return base @@ -270,7 +335,33 @@ def _extract_segment_wav(wav_bytes: bytes, start: float, end: float) -> bytes: def _get_embedding(wav_bytes: bytes): + model = get_builtin_embedding_model() + if model is not None: + emb = _get_embedding_builtin(wav_bytes, model) + if emb is not None: + return emb + if SPEAKER_EMBEDDING_URL: + return _get_embedding_http(wav_bytes) + return None + + +def _get_embedding_builtin(wav_bytes: bytes, model): + try: + waveform, sample_rate = wav_bytes_to_waveform(wav_bytes) + dur = waveform.shape[1] / sample_rate + if dur < MIN_SEGMENT_DURATION: + return None + emb = model({"waveform": waveform, "sample_rate": sample_rate}) + emb = np.array(emb, dtype=np.float32) + if emb.ndim == 1: + emb = emb.reshape(1, -1) + return emb + except Exception as e: + logger.warning(f"Built-in embedding failed: {e}") + return None + +def _get_embedding_http(wav_bytes: bytes): try: with httpx.Client(timeout=httpx.Timeout(connect=5.0, read=30.0, write=10.0, pool=5.0)) as client: resp = client.post( @@ -289,5 +380,5 @@ def _get_embedding(wav_bytes: bytes): emb = emb.reshape(1, -1) return emb except Exception as e: - logger.warning(f"Embedding extraction failed: {e}") + logger.warning(f"HTTP embedding failed: {e}") return None diff --git a/backend/test.sh b/backend/test.sh index 3f9b65880e0..bb1febc24d8 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -25,6 +25,7 @@ pytest tests/unit/test_parakeet_stream_session.py -v pytest tests/unit/test_parakeet_gpu_worker.py -v pytest tests/unit/test_parakeet_batch_engine.py -v pytest tests/unit/test_parakeet_batch_routing.py -v +pytest tests/unit/test_parakeet_builtin_embedding.py -v pytest tests/unit/test_parakeet_endpoints.py -v pytest tests/unit/test_audiobuffer_guard.py -v pytest tests/unit/test_memory_leak_buffers.py -v diff --git a/backend/tests/unit/test_parakeet_builtin_embedding.py b/backend/tests/unit/test_parakeet_builtin_embedding.py new file mode 100644 index 00000000000..39295464fe1 --- /dev/null +++ b/backend/tests/unit/test_parakeet_builtin_embedding.py @@ -0,0 +1,193 @@ +"""Tests for built-in speaker embedding in parakeet transcribe.py. + +Validates that batch diarization uses the built-in wespeaker model first +and falls back to HTTP only when the built-in model is unavailable. +""" + +import io +import os +import struct +import sys +import wave +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +os.environ.setdefault('PARAKEET_INFERENCE_MODE', 'nemo') +os.environ.setdefault('PARAKEET_STREAM_MODEL', '') +os.environ.setdefault('PARAKEET_DEVICE', 'cpu') +os.environ.setdefault('PARAKEET_TORCH_COMPILE', 'false') +os.environ.setdefault('PARAKEET_CUDA_GRAPHS', 'false') + +_torch_mock = MagicMock() +_torch_mock.cuda.is_available.return_value = False +_torch_mock.cuda.is_bf16_supported.return_value = False +_torch_mock.bfloat16 = 'bfloat16' + + +def _torch_from_numpy(arr): + result = MagicMock() + result.unsqueeze.return_value = result + result.shape = [1, len(arr)] + return result + + +_torch_mock.from_numpy = _torch_from_numpy +sys.modules['torch'] = _torch_mock + +for _mod in [ + 'nemo', + 'nemo.collections', + 'nemo.collections.asr', + 'nemo.collections.asr.models', + 'pyannote', + 'pyannote.audio', +]: + if _mod not in sys.modules: + sys.modules[_mod] = MagicMock() + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../parakeet')) + +if 'transcribe' in sys.modules: + _existing = sys.modules['transcribe'] + if not hasattr(_existing, '__file__') or _existing.__file__ is None: + del sys.modules['transcribe'] + +import transcribe # noqa: E402 + + +def _make_wav_bytes(duration_s=1.0, sample_rate=16000, channels=1): + n_samples = int(duration_s * sample_rate) + buf = io.BytesIO() + with wave.open(buf, 'wb') as wf: + wf.setnchannels(channels) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(struct.pack(f'<{n_samples * channels}h', *([1000] * n_samples * channels))) + return buf.getvalue() + + +class TestWavBytesToWaveform: + def test_returns_waveform_and_sample_rate(self): + wav = _make_wav_bytes(duration_s=0.5, sample_rate=16000) + waveform, sr = transcribe.wav_bytes_to_waveform(wav) + assert sr == 16000 + assert waveform.shape == [1, 8000] + + def test_stereo_downmix(self): + wav = _make_wav_bytes(duration_s=0.5, sample_rate=16000, channels=2) + waveform, sr = transcribe.wav_bytes_to_waveform(wav) + assert sr == 16000 + assert waveform.shape == [1, 8000] + + +class TestGetEmbedding: + def test_uses_builtin_model_first(self): + fake_model = MagicMock() + fake_model.return_value = np.zeros(256, dtype=np.float32) + wav = _make_wav_bytes(duration_s=1.0) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + with patch.object(transcribe, '_get_embedding_http') as http_mock: + result = transcribe._get_embedding(wav) + + assert result is not None + assert result.shape == (1, 256) + http_mock.assert_not_called() + + def test_falls_back_to_http_when_builtin_unavailable(self): + wav = _make_wav_bytes(duration_s=1.0) + http_emb = np.ones((1, 256), dtype=np.float32) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=None): + with patch.object(transcribe, '_get_embedding_http', return_value=http_emb) as http_mock: + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = 'http://fake-diarizer' + try: + result = transcribe._get_embedding(wav) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + http_mock.assert_called_once_with(wav) + assert result is not None + np.testing.assert_array_equal(result, http_emb) + + def test_falls_back_to_http_when_builtin_fails(self): + fake_model = MagicMock() + fake_model.side_effect = RuntimeError("GPU error") + wav = _make_wav_bytes(duration_s=1.0) + http_emb = np.ones((1, 256), dtype=np.float32) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + with patch.object(transcribe, '_get_embedding_http', return_value=http_emb) as http_mock: + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = 'http://fake-diarizer' + try: + result = transcribe._get_embedding(wav) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + http_mock.assert_called_once_with(wav) + np.testing.assert_array_equal(result, http_emb) + + def test_returns_none_when_both_fail(self): + wav = _make_wav_bytes(duration_s=1.0) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=None): + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = '' + try: + result = transcribe._get_embedding(wav) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + assert result is None + + def test_reshapes_1d_embedding(self): + fake_model = MagicMock() + fake_model.return_value = np.zeros(128, dtype=np.float32) + wav = _make_wav_bytes(duration_s=1.0) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + result = transcribe._get_embedding(wav) + + assert result.shape == (1, 128) + + +class TestDiarizeSegmentsGating: + def test_proceeds_with_builtin_model_even_without_url(self, tmp_path): + wav_path = tmp_path / "test.wav" + wav_bytes = _make_wav_bytes(duration_s=2.0) + wav_path.write_bytes(wav_bytes) + + base = {"text": "hello", "segments": [{"text": "hello", "start": 0.0, "end": 2.0}]} + fake_model = MagicMock() + fake_emb = np.zeros((1, 256), dtype=np.float32) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + with patch.object(transcribe, '_get_embedding', return_value=fake_emb): + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = '' + try: + result = transcribe._diarize_segments(str(wav_path), base) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + assert result["segments"][0].get("speaker") is not None + + def test_skips_diarization_when_no_model_and_no_url(self, tmp_path): + wav_path = tmp_path / "test.wav" + wav_path.write_bytes(_make_wav_bytes(duration_s=1.0)) + + base = {"text": "hi", "segments": [{"text": "hi", "start": 0.0, "end": 1.0}]} + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=None): + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = '' + try: + result = transcribe._diarize_segments(str(wav_path), base) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + assert result["segments"][0]["speaker"] == "SPEAKER_0" From 19f6b45fd4b3fa2a1409beb42dc1f24ac574b141 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 21 Jun 2026 08:05:36 +0000 Subject: [PATCH 2/4] Handle WAV sample width edge cases in wav_bytes_to_waveform Add 8-bit unsigned PCM and 32-bit PCM support. Raise ValueError for unsupported widths (e.g. 24-bit) so _get_embedding_builtin returns None and falls back to HTTP instead of producing corrupted waveforms. Co-Authored-By: Claude Opus 4.6 --- backend/parakeet/transcribe.py | 15 +++---- .../unit/test_parakeet_builtin_embedding.py | 44 +++++++++++++++++++ 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/backend/parakeet/transcribe.py b/backend/parakeet/transcribe.py index 4ebc85032c0..e6389c8059d 100644 --- a/backend/parakeet/transcribe.py +++ b/backend/parakeet/transcribe.py @@ -73,17 +73,14 @@ def wav_bytes_to_waveform(wav_bytes: bytes): sw = wf.getsampwidth() pcm = wf.readframes(wf.getnframes()) - if sw == 2: - dtype = np.int16 - divisor = 32768.0 + if sw == 1: + samples = np.frombuffer(pcm, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0 + elif sw == 2: + samples = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0 elif sw == 4: - dtype = np.int32 - divisor = 2147483648.0 + samples = np.frombuffer(pcm, dtype=np.int32).astype(np.float32) / 2147483648.0 else: - dtype = np.int16 - divisor = 32768.0 - - samples = np.frombuffer(pcm, dtype=dtype).astype(np.float32) / divisor + raise ValueError(f"Unsupported WAV sample width: {sw} bytes") if nch > 1: samples = samples.reshape(-1, nch).mean(axis=1) waveform = _torch.from_numpy(samples).unsqueeze(0) diff --git a/backend/tests/unit/test_parakeet_builtin_embedding.py b/backend/tests/unit/test_parakeet_builtin_embedding.py index 39295464fe1..5bf96ae76ed 100644 --- a/backend/tests/unit/test_parakeet_builtin_embedding.py +++ b/backend/tests/unit/test_parakeet_builtin_embedding.py @@ -68,6 +68,28 @@ def _make_wav_bytes(duration_s=1.0, sample_rate=16000, channels=1): return buf.getvalue() +def _make_wav_bytes_8bit(duration_s=1.0, sample_rate=16000): + n_samples = int(duration_s * sample_rate) + buf = io.BytesIO() + with wave.open(buf, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(1) + wf.setframerate(sample_rate) + wf.writeframes(bytes([128] * n_samples)) + return buf.getvalue() + + +def _make_wav_bytes_32bit(duration_s=1.0, sample_rate=16000): + n_samples = int(duration_s * sample_rate) + buf = io.BytesIO() + with wave.open(buf, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(4) + wf.setframerate(sample_rate) + wf.writeframes(struct.pack(f'<{n_samples}i', *([100000] * n_samples))) + return buf.getvalue() + + class TestWavBytesToWaveform: def test_returns_waveform_and_sample_rate(self): wav = _make_wav_bytes(duration_s=0.5, sample_rate=16000) @@ -81,6 +103,28 @@ def test_stereo_downmix(self): assert sr == 16000 assert waveform.shape == [1, 8000] + def test_8bit_unsigned_pcm(self): + wav = _make_wav_bytes_8bit(duration_s=0.5, sample_rate=16000) + waveform, sr = transcribe.wav_bytes_to_waveform(wav) + assert sr == 16000 + assert waveform.shape == [1, 8000] + + def test_32bit_pcm(self): + wav = _make_wav_bytes_32bit(duration_s=0.5, sample_rate=16000) + waveform, sr = transcribe.wav_bytes_to_waveform(wav) + assert sr == 16000 + assert waveform.shape == [1, 8000] + + def test_unsupported_width_raises(self): + buf = io.BytesIO() + with wave.open(buf, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(3) + wf.setframerate(16000) + wf.writeframes(b'\x00\x00\x00' * 8000) + with pytest.raises(ValueError, match="Unsupported WAV sample width"): + transcribe.wav_bytes_to_waveform(buf.getvalue()) + class TestGetEmbedding: def test_uses_builtin_model_first(self): From e24c287d14d7cdd1cbeca725242ea97dfa65d0f1 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 21 Jun 2026 08:09:39 +0000 Subject: [PATCH 3/4] Add tester-requested coverage: both-fail, model cache, duration boundary - test_returns_none_when_builtin_fails_and_http_fails: both paths fail - TestGetBuiltinEmbeddingModel: pyannote unavailable returns None, cached model returned without re-loading - TestEmbeddingBuiltinDuration: short audio below MIN_SEGMENT_DURATION returns None without calling model, at-duration audio proceeds Co-Authored-By: Claude Opus 4.6 --- .../unit/test_parakeet_builtin_embedding.py | 58 ++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/backend/tests/unit/test_parakeet_builtin_embedding.py b/backend/tests/unit/test_parakeet_builtin_embedding.py index 5bf96ae76ed..6a358786ded 100644 --- a/backend/tests/unit/test_parakeet_builtin_embedding.py +++ b/backend/tests/unit/test_parakeet_builtin_embedding.py @@ -175,7 +175,7 @@ def test_falls_back_to_http_when_builtin_fails(self): http_mock.assert_called_once_with(wav) np.testing.assert_array_equal(result, http_emb) - def test_returns_none_when_both_fail(self): + def test_returns_none_when_no_builtin_no_url(self): wav = _make_wav_bytes(duration_s=1.0) with patch.object(transcribe, 'get_builtin_embedding_model', return_value=None): @@ -188,6 +188,22 @@ def test_returns_none_when_both_fail(self): assert result is None + def test_returns_none_when_builtin_fails_and_http_fails(self): + fake_model = MagicMock() + fake_model.side_effect = RuntimeError("GPU error") + wav = _make_wav_bytes(duration_s=1.0) + + with patch.object(transcribe, 'get_builtin_embedding_model', return_value=fake_model): + with patch.object(transcribe, '_get_embedding_http', return_value=None): + old_url = transcribe.SPEAKER_EMBEDDING_URL + transcribe.SPEAKER_EMBEDDING_URL = 'http://fake-diarizer' + try: + result = transcribe._get_embedding(wav) + finally: + transcribe.SPEAKER_EMBEDDING_URL = old_url + + assert result is None + def test_reshapes_1d_embedding(self): fake_model = MagicMock() fake_model.return_value = np.zeros(128, dtype=np.float32) @@ -199,6 +215,46 @@ def test_reshapes_1d_embedding(self): assert result.shape == (1, 128) +class TestGetBuiltinEmbeddingModel: + def test_returns_none_when_pyannote_unavailable(self): + old_model = transcribe._embedding_model + transcribe._embedding_model = None + old_pyannote = transcribe._PyannoteModel + transcribe._PyannoteModel = None + try: + result = transcribe.get_builtin_embedding_model() + assert result is None + finally: + transcribe._PyannoteModel = old_pyannote + transcribe._embedding_model = old_model + + def test_returns_cached_model(self): + sentinel = MagicMock() + old_model = transcribe._embedding_model + transcribe._embedding_model = sentinel + try: + assert transcribe.get_builtin_embedding_model() is sentinel + finally: + transcribe._embedding_model = old_model + + +class TestEmbeddingBuiltinDuration: + def test_short_audio_below_min_duration_returns_none(self): + fake_model = MagicMock() + wav = _make_wav_bytes(duration_s=0.3) + result = transcribe._get_embedding_builtin(wav, fake_model) + assert result is None + fake_model.assert_not_called() + + def test_audio_at_min_duration_returns_embedding(self): + fake_model = MagicMock() + fake_model.return_value = np.zeros(256, dtype=np.float32) + wav = _make_wav_bytes(duration_s=0.7) + result = transcribe._get_embedding_builtin(wav, fake_model) + assert result is not None + fake_model.assert_called_once() + + class TestDiarizeSegmentsGating: def test_proceeds_with_builtin_model_even_without_url(self, tmp_path): wav_path = tmp_path / "test.wav" From 2303e0d6683297b140c588ddcbe243e090b99120 Mon Sep 17 00:00:00 2001 From: beastoin Date: Sun, 21 Jun 2026 08:14:29 +0000 Subject: [PATCH 4/4] Fix tester feedback: exact boundary tests and model cache verification - test_audio_at_exact_min_duration: use 0.6s (MIN_SEGMENT_DURATION) - test_audio_just_above_min_duration: use 0.7s - test_successful_load_is_cached: verify pyannote load result is stored - test_returns_cached_model_without_reload: verify cached across calls Co-Authored-By: Claude Opus 4.6 --- .../unit/test_parakeet_builtin_embedding.py | 37 +++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/backend/tests/unit/test_parakeet_builtin_embedding.py b/backend/tests/unit/test_parakeet_builtin_embedding.py index 6a358786ded..754359413bb 100644 --- a/backend/tests/unit/test_parakeet_builtin_embedding.py +++ b/backend/tests/unit/test_parakeet_builtin_embedding.py @@ -228,15 +228,38 @@ def test_returns_none_when_pyannote_unavailable(self): transcribe._PyannoteModel = old_pyannote transcribe._embedding_model = old_model - def test_returns_cached_model(self): + def test_returns_cached_model_without_reload(self): sentinel = MagicMock() old_model = transcribe._embedding_model transcribe._embedding_model = sentinel try: - assert transcribe.get_builtin_embedding_model() is sentinel + result1 = transcribe.get_builtin_embedding_model() + result2 = transcribe.get_builtin_embedding_model() + assert result1 is sentinel + assert result2 is sentinel finally: transcribe._embedding_model = old_model + def test_successful_load_is_cached(self): + old_model = transcribe._embedding_model + old_pyannote_model = transcribe._PyannoteModel + old_pyannote_inference = transcribe._PyannoteInference + fake_inference = MagicMock() + fake_pyannote_model = MagicMock() + fake_pyannote_model.from_pretrained.return_value = MagicMock() + fake_pyannote_inference = MagicMock(return_value=fake_inference) + transcribe._embedding_model = None + transcribe._PyannoteModel = fake_pyannote_model + transcribe._PyannoteInference = fake_pyannote_inference + try: + result = transcribe.get_builtin_embedding_model() + assert result is fake_inference + assert transcribe._embedding_model is fake_inference + finally: + transcribe._PyannoteModel = old_pyannote_model + transcribe._PyannoteInference = old_pyannote_inference + transcribe._embedding_model = old_model + class TestEmbeddingBuiltinDuration: def test_short_audio_below_min_duration_returns_none(self): @@ -246,7 +269,15 @@ def test_short_audio_below_min_duration_returns_none(self): assert result is None fake_model.assert_not_called() - def test_audio_at_min_duration_returns_embedding(self): + def test_audio_at_exact_min_duration_returns_embedding(self): + fake_model = MagicMock() + fake_model.return_value = np.zeros(256, dtype=np.float32) + wav = _make_wav_bytes(duration_s=0.6) + result = transcribe._get_embedding_builtin(wav, fake_model) + assert result is not None + fake_model.assert_called_once() + + def test_audio_just_above_min_duration_returns_embedding(self): fake_model = MagicMock() fake_model.return_value = np.zeros(256, dtype=np.float32) wav = _make_wav_bytes(duration_s=0.7)