From 704992ed36e3f268054fa1b15e12169e4b3eff83 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 29 Mar 2026 15:36:37 +0000 Subject: [PATCH 1/6] Initial plan From e4714bccd86e42978821eac2725f1cc075d41c64 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 29 Mar 2026 15:43:46 +0000 Subject: [PATCH 2/6] feat: transcription + speaker diarization pipeline Agent-Logs-Url: https://github.com/celluloid-camp/vision/sessions/9475a454-52b5-4684-80b8-4c467bd8ec3b Co-authored-by: younes200 <198514+younes200@users.noreply.github.com> --- app/core/celery_queue.py | 1 + app/core/config.py | 11 ++ app/core/tasks.py | 150 ++++++++++++++ app/detection/transcribe.py | 371 +++++++++++++++++++++++++++++++++++ app/models/result_models.py | 84 +++++++- app/tests/test_transcribe.py | 206 +++++++++++++++++++ env.example | 11 ++ pyproject.toml | 3 + 8 files changed, 836 insertions(+), 1 deletion(-) create mode 100644 app/detection/transcribe.py create mode 100644 app/tests/test_transcribe.py diff --git a/app/core/celery_queue.py b/app/core/celery_queue.py index 35c039c..345d34b 100644 --- a/app/core/celery_queue.py +++ b/app/core/celery_queue.py @@ -17,6 +17,7 @@ TASK_NAME_BY_JOB_TYPE = { "object_detect": "app.core.tasks.process_object_detect_task", "scene_detect": "app.core.tasks.process_scene_detect_task", + "transcribe": "app.core.tasks.process_transcribe_task", } diff --git a/app/core/config.py b/app/core/config.py index fa39833..6233112 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,3 +19,14 @@ # Processing Configuration MAX_WORKERS = 1 # Only 1 worker since we process one job at a time + +# Transcription Configuration +WHISPER_MODEL_SIZE = os.getenv("WHISPER_MODEL_SIZE", "small") +WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "cpu") +WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8") +WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", None) # None = auto-detect + +# Diarization Configuration +DIARIZATION_ENABLED = os.getenv("DIARIZATION_ENABLED", "true").lower() == "true" +PYANNOTE_AUTH_TOKEN = os.getenv("PYANNOTE_AUTH_TOKEN", None) +PYANNOTE_MODEL = os.getenv("PYANNOTE_MODEL", "pyannote/speaker-diarization-3.1") diff --git a/app/core/tasks.py b/app/core/tasks.py index c6cb7b7..b93c320 100644 --- a/app/core/tasks.py +++ b/app/core/tasks.py @@ -390,3 +390,153 @@ def process_scene_detect_task(self, job_data: dict): ) raise + + +# --------------------------------------------------------------------------- +# Transcribe task +# --------------------------------------------------------------------------- + + +@celery_app.task(bind=True, name="app.core.tasks.process_transcribe_task") +def process_transcribe_task(self, job_data: dict): + """Celery task for audio transcription with optional speaker diarization.""" + job_id = job_data["job_id"] + external_id = job_data["external_id"] + video_url = job_data["video_url"] + params = job_data.get("params", {}) + callback_url = job_data.get("callback_url") + start_time = datetime.now().isoformat() + + from app.core.config import ( + DIARIZATION_ENABLED, + PYANNOTE_AUTH_TOKEN, + PYANNOTE_MODEL, + WHISPER_COMPUTE_TYPE, + WHISPER_DEVICE, + WHISPER_LANGUAGE, + WHISPER_MODEL_SIZE, + ) + + model_size = params.get("model_size", WHISPER_MODEL_SIZE) + device = params.get("device", WHISPER_DEVICE) + compute_type = params.get("compute_type", WHISPER_COMPUTE_TYPE) + language = params.get("language", WHISPER_LANGUAGE) or None + diarization_enabled = params.get("diarization_enabled", DIARIZATION_ENABLED) + num_speakers = params.get("num_speakers") + min_speakers = params.get("min_speakers") + max_speakers = params.get("max_speakers") + auth_token = PYANNOTE_AUTH_TOKEN + pyannote_model = PYANNOTE_MODEL + + self.update_state( + state="PROCESSING", + meta={ + "job_id": job_id, + "external_id": external_id, + "video_url": video_url, + "job_type": "transcribe", + "callback_url": callback_url, + "status": "processing", + "progress": 0.0, + "start_time": start_time, + }, + ) + + try: + output_dir = os.path.join("outputs", external_id) + os.makedirs(output_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_filename = f"transcript_{job_id}_{timestamp}.json" + output_path = os.path.join(output_dir, output_filename) + + # Download remote file; local paths are used directly. + if video_url.startswith(("http://", "https://")): + audio_path = download_video(video_url) + else: + audio_path = video_url + + progress_cb = _make_progress_reporter(self, job_id, external_id, start_time) + + from app.detection.transcribe import run_transcription_pipeline + + result = run_transcription_pipeline( + audio_path=audio_path, + model_size=model_size, + device=device, + compute_type=compute_type, + language=language, + diarization_enabled=diarization_enabled, + auth_token=auth_token, + pyannote_model=pyannote_model, + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + progress_callback=progress_cb, + ) + + result["result_type"] = "transcribe" + + with open(output_path, "w") as f: + json.dump(result, f, indent=2) + + logger.info(f"Transcript saved to: {output_path}") + + metadata = { + "language": result["metadata"]["language"], + "audio_duration_sec": result["metadata"]["audio_duration_sec"], + "processing_time_sec": result["metadata"]["processing_time_sec"], + "segment_count": len(result["segments"]), + "speaker_count": len(result["speakers"]), + } + + end_time = datetime.now().isoformat() + logger.info(f"Transcription job {job_id} completed successfully") + + if callback_url: + _send_callback_sync( + job_id, + external_id, + "transcribe", + callback_url, + "completed", + {"result_path": output_path, "metadata": metadata}, + ) + + # Clean up downloaded file + if video_url.startswith(("http://", "https://")): + try: + os.remove(audio_path) + logger.info(f"Cleaned up temporary file: {audio_path}") + except Exception as e: + logger.warning(f"Failed to clean up temporary file: {str(e)}") + + return { + "job_id": job_id, + "external_id": external_id, + "video_url": video_url, + "job_type": "transcribe", + "callback_url": callback_url, + "status": "completed", + "result_path": output_path, + "start_time": start_time, + "end_time": end_time, + "metadata": metadata, + } + + except Exception as e: + error_msg = str(e) + logger.error(f"Transcription job {job_id} failed: {error_msg}") + logger.error(traceback.format_exc()) + + if callback_url: + _send_callback_sync( + job_id, + external_id, + "transcribe", + callback_url, + "failed", + error=error_msg, + ) + + raise diff --git a/app/detection/transcribe.py b/app/detection/transcribe.py new file mode 100644 index 0000000..a9a6a45 --- /dev/null +++ b/app/detection/transcribe.py @@ -0,0 +1,371 @@ +"""Transcription and speaker diarization pipeline. + +Uses faster-whisper for ASR (CPU INT8) and pyannote.audio for speaker diarization. +Both models run fully self-hosted with no external API calls required. +""" + +import logging +import os +import time +from typing import Callable, List, Optional + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Data helpers +# --------------------------------------------------------------------------- + + +def _overlap(a_start: float, a_end: float, b_start: float, b_end: float) -> float: + """Return the duration of overlap between two time intervals.""" + return max(0.0, min(a_end, b_end) - max(a_start, b_start)) + + +def merge_transcript_with_speakers( + asr_segments: List[dict], + diarization_segments: List[dict], +) -> List[dict]: + """Attach speaker labels to ASR segments using overlap-based matching. + + For each ASR segment the diarization segment with the greatest time overlap + is selected and its speaker label is used. If no diarization segment + overlaps the ASR segment the speaker field is set to ``None``. + + Args: + asr_segments: List of dicts with keys ``start``, ``end``, ``text``, + ``confidence`` and optional ``words``. + diarization_segments: List of dicts with keys ``start``, ``end``, + ``speaker``. + + Returns: + A new list of segment dicts that each include a ``speaker`` key. + """ + merged: List[dict] = [] + for seg in asr_segments: + best_speaker: Optional[str] = None + best_overlap = 0.0 + for d_seg in diarization_segments: + ov = _overlap(seg["start"], seg["end"], d_seg["start"], d_seg["end"]) + if ov > best_overlap: + best_overlap = ov + best_speaker = d_seg["speaker"] + result = dict(seg) + result["speaker"] = best_speaker + merged.append(result) + return merged + + +def aggregate_speakers(merged_segments: List[dict]) -> List[dict]: + """Compute total speaking time per speaker from merged segments. + + Args: + merged_segments: Output of :func:`merge_transcript_with_speakers`. + + Returns: + List of dicts ``{"label": str, "total_speaking_time_sec": float}`` + sorted by label. + """ + totals: dict[str, float] = {} + for seg in merged_segments: + speaker = seg.get("speaker") + if speaker is None: + continue + duration = seg["end"] - seg["start"] + totals[speaker] = totals.get(speaker, 0.0) + duration + return [ + {"label": label, "total_speaking_time_sec": round(total, 3)} + for label, total in sorted(totals.items()) + ] + + +# --------------------------------------------------------------------------- +# ASR: faster-whisper +# --------------------------------------------------------------------------- + + +def transcribe_audio( + audio_path: str, + model_size: str = "small", + device: str = "cpu", + compute_type: str = "int8", + language: Optional[str] = None, + progress_callback: Optional[Callable[[float], None]] = None, +) -> dict: + """Transcribe audio using faster-whisper on CPU. + + Args: + audio_path: Path to the audio/video file to transcribe. + model_size: Whisper model size (``tiny``, ``base``, ``small``, + ``medium``, ``large-v2``, …). + device: Inference device (``"cpu"`` or ``"cuda"``). + compute_type: Quantisation type (e.g. ``"int8"``, ``"float16"``). + language: ISO-639-1 language code, or ``None`` for auto-detection. + progress_callback: Optional callable receiving a progress percentage + (0–100). + + Returns: + Dict with keys ``segments`` (list of segment dicts), ``language`` + (detected/forced language string) and ``audio_duration_sec`` (float). + + Raises: + ImportError: If ``faster-whisper`` is not installed. + FileNotFoundError: If ``audio_path`` does not exist. + """ + try: + from faster_whisper import WhisperModel # type: ignore + except ImportError as exc: + raise ImportError( + "faster-whisper is required for transcription. " + "Install it with: pip install faster-whisper" + ) from exc + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + logger.info( + "Loading Whisper model '%s' (device=%s, compute_type=%s)", + model_size, + device, + compute_type, + ) + model = WhisperModel(model_size, device=device, compute_type=compute_type) + + transcribe_kwargs: dict = { + "word_timestamps": True, + } + if language: + transcribe_kwargs["language"] = language + + logger.info("Starting transcription of %s", audio_path) + t0 = time.time() + raw_segments, info = model.transcribe(audio_path, **transcribe_kwargs) + + segments: List[dict] = [] + seg_list = list(raw_segments) # materialise the generator + total = len(seg_list) or 1 + for idx, seg in enumerate(seg_list): + words = None + if seg.words: + words = [ + { + "word": w.word, + "start": round(w.start, 3), + "end": round(w.end, 3), + "probability": round(w.probability, 4), + } + for w in seg.words + ] + segments.append( + { + "id": idx, + "start": round(seg.start, 3), + "end": round(seg.end, 3), + "text": seg.text.strip(), + "confidence": round(float(getattr(seg, "avg_logprob", None) or 0.0), 4) if getattr(seg, "avg_logprob", None) is not None else None, + "words": words, + } + ) + if progress_callback: + progress_callback(min(95.0, (idx + 1) / total * 95.0)) + + elapsed = time.time() - t0 + logger.info( + "Transcription completed in %.1fs – %d segments, language=%s", + elapsed, + len(segments), + info.language, + ) + + if progress_callback: + progress_callback(100.0) + + return { + "segments": segments, + "language": info.language, + "audio_duration_sec": round(info.duration or 0.0, 3), + } + + +# --------------------------------------------------------------------------- +# Diarization: pyannote.audio +# --------------------------------------------------------------------------- + + +def diarize_audio( + audio_path: str, + auth_token: Optional[str] = None, + model_name: str = "pyannote/speaker-diarization-3.1", + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, +) -> List[dict]: + """Run speaker diarization using pyannote.audio. + + Args: + audio_path: Path to the audio/video file. + auth_token: HuggingFace access token required to download the + pyannote model on first use. If ``None`` the value is read + from the ``PYANNOTE_AUTH_TOKEN`` environment variable. + model_name: pyannote pipeline identifier on the HuggingFace Hub. + num_speakers: Exact number of speakers (overrides min/max). + min_speakers: Minimum number of speakers hint. + max_speakers: Maximum number of speakers hint. + + Returns: + List of dicts with keys ``start``, ``end``, ``speaker``. + + Raises: + ImportError: If ``pyannote.audio`` is not installed. + FileNotFoundError: If ``audio_path`` does not exist. + RuntimeError: If the pipeline cannot be loaded (e.g. missing token). + """ + try: + from pyannote.audio import Pipeline # type: ignore + except ImportError as exc: + raise ImportError( + "pyannote.audio is required for speaker diarization. " + "Install it with: pip install pyannote.audio" + ) from exc + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + token = auth_token or os.getenv("PYANNOTE_AUTH_TOKEN") + + logger.info("Loading pyannote pipeline '%s'", model_name) + try: + pipeline = Pipeline.from_pretrained(model_name, use_auth_token=token) + except Exception as exc: + raise RuntimeError( + f"Failed to load pyannote pipeline '{model_name}': {exc}. " + "Make sure PYANNOTE_AUTH_TOKEN is set and you have accepted the " + "model licence on huggingface.co." + ) from exc + + diarize_kwargs: dict = {} + if num_speakers is not None: + diarize_kwargs["num_speakers"] = num_speakers + else: + if min_speakers is not None: + diarize_kwargs["min_speakers"] = min_speakers + if max_speakers is not None: + diarize_kwargs["max_speakers"] = max_speakers + + logger.info("Running speaker diarization on %s", audio_path) + t0 = time.time() + diarization = pipeline(audio_path, **diarize_kwargs) + elapsed = time.time() - t0 + logger.info("Diarization completed in %.1fs", elapsed) + + segments: List[dict] = [] + for turn, _, speaker in diarization.itertracks(yield_label=True): + segments.append( + { + "start": round(turn.start, 3), + "end": round(turn.end, 3), + "speaker": speaker, + } + ) + return segments + + +# --------------------------------------------------------------------------- +# Combined pipeline +# --------------------------------------------------------------------------- + + +def run_transcription_pipeline( + audio_path: str, + model_size: str = "small", + device: str = "cpu", + compute_type: str = "int8", + language: Optional[str] = None, + diarization_enabled: bool = True, + auth_token: Optional[str] = None, + pyannote_model: str = "pyannote/speaker-diarization-3.1", + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + progress_callback: Optional[Callable[[float], None]] = None, +) -> dict: + """Run the full transcription + diarization pipeline. + + Executes ASR first (reports 0–70 % progress), then optionally diarization + (70–90 %), then merges and aggregates (90–100 %). + + Args: + audio_path: Path to the audio/video file. + model_size: Whisper model size. + device: Inference device. + compute_type: Quantisation type. + language: Forced language code, or ``None`` for auto-detection. + diarization_enabled: Whether to run speaker diarization. + auth_token: HuggingFace token for pyannote model. + pyannote_model: pyannote pipeline identifier. + num_speakers: Exact speaker count hint. + min_speakers: Minimum speaker count hint. + max_speakers: Maximum speaker count hint. + progress_callback: Optional callable receiving progress 0–100. + + Returns: + Dict matching the agreed output schema: + ``{"metadata": {...}, "segments": [...], "speakers": [...], + "diarization": [...]}``. + """ + pipeline_start = time.time() + + def asr_progress(pct: float) -> None: + if progress_callback: + progress_callback(pct * 0.70) + + asr_result = transcribe_audio( + audio_path, + model_size=model_size, + device=device, + compute_type=compute_type, + language=language, + progress_callback=asr_progress, + ) + + diarization_segments: List[dict] = [] + if diarization_enabled: + try: + if progress_callback: + progress_callback(70.0) + diarization_segments = diarize_audio( + audio_path, + auth_token=auth_token, + model_name=pyannote_model, + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + if progress_callback: + progress_callback(90.0) + except Exception as exc: + logger.warning( + "Diarization failed – continuing with transcript only. Error: %s", exc + ) + + merged = merge_transcript_with_speakers(asr_result["segments"], diarization_segments) + speakers = aggregate_speakers(merged) + + processing_time = round(time.time() - pipeline_start, 3) + + return { + "metadata": { + "engine": "faster-whisper+pyannote" if diarization_segments else "faster-whisper", + "asr_backend": "faster-whisper", + "diarization_backend": "pyannote" if diarization_segments else None, + "device": device, + "compute_type": compute_type, + "asr_model": model_size, + "language": asr_result["language"], + "audio_duration_sec": asr_result["audio_duration_sec"], + "processing_time_sec": processing_time, + }, + "segments": merged, + "speakers": speakers, + "diarization": diarization_segments, + } diff --git a/app/models/result_models.py b/app/models/result_models.py index 1d49ff4..0262988 100644 --- a/app/models/result_models.py +++ b/app/models/result_models.py @@ -22,6 +22,7 @@ class JobType(str, Enum): object_detect = "object_detect" scene_detect = "scene_detect" + transcribe = "transcribe" def _validate_video_url(v: str) -> str: @@ -102,6 +103,31 @@ class SceneDetectParams(BaseModel): ) +class TranscribeParams(BaseModel): + model_size: str = Field( + "small", description="Whisper model size (tiny, base, small, medium, large-v2, large-v3)" + ) + device: str = Field("cpu", description="Inference device (cpu or cuda)") + compute_type: str = Field( + "int8", description="Quantisation type (int8, float16, float32)" + ) + language: Optional[str] = Field( + None, description="ISO-639-1 language code, or null for auto-detection" + ) + diarization_enabled: bool = Field( + True, description="Whether to run speaker diarization" + ) + num_speakers: Optional[int] = Field( + None, ge=1, description="Exact number of speakers (overrides min/max)" + ) + min_speakers: Optional[int] = Field( + None, ge=1, description="Minimum number of speakers hint" + ) + max_speakers: Optional[int] = Field( + None, ge=1, description="Maximum number of speakers hint" + ) + + # --------------------------------------------------------------------------- # Create-job request (single model, params validated by job_type) # --------------------------------------------------------------------------- @@ -114,7 +140,7 @@ class CreateJobRequest(BaseModel): callback_url: Optional[str] = Field( None, description="Callback URL for job completion notifications" ) - params: Union[ObjectDetectParams, SceneDetectParams] = Field( + params: Union[ObjectDetectParams, SceneDetectParams, TranscribeParams] = Field( default_factory=ObjectDetectParams ) @@ -134,6 +160,8 @@ def coerce_params_by_job_type(cls, data): if isinstance(raw_params, dict): if job_type == "scene_detect": data["params"] = SceneDetectParams(**raw_params) + elif job_type == "transcribe": + data["params"] = TranscribeParams(**raw_params) else: data["params"] = ObjectDetectParams(**raw_params) return data @@ -265,6 +293,59 @@ class SceneDetectResultsModel(BaseModel): sprite_fragments: Optional[list[str]] = None +# --------------------------------------------------------------------------- +# Transcription result models +# --------------------------------------------------------------------------- + + +class WordModel(BaseModel): + word: str + start: float + end: float + probability: float + + +class TranscriptSegmentModel(BaseModel): + id: int + start: float + end: float + speaker: Optional[str] = None + text: str + confidence: Optional[float] = None + words: Optional[list[WordModel]] = None + + +class SpeakerSummaryModel(BaseModel): + label: str + total_speaking_time_sec: float + + +class DiarizationSegmentModel(BaseModel): + start: float + end: float + speaker: str + + +class TranscriptionMetadataModel(BaseModel): + engine: str + asr_backend: str = "faster-whisper" + diarization_backend: Optional[str] = None + device: str + compute_type: str + asr_model: str + language: str + audio_duration_sec: float + processing_time_sec: float + + +class TranscribeResultsModel(BaseModel): + result_type: Literal["transcribe"] = "transcribe" + metadata: TranscriptionMetadataModel + segments: list[TranscriptSegmentModel] + speakers: list[SpeakerSummaryModel] + diarization: list[DiarizationSegmentModel] + + # --------------------------------------------------------------------------- # Job results response (polymorphic data) # --------------------------------------------------------------------------- @@ -274,6 +355,7 @@ class SceneDetectResultsModel(BaseModel): Union[ Annotated[DetectionResultsModel, Tag("object_detect")], Annotated[SceneDetectResultsModel, Tag("scene_detect")], + Annotated[TranscribeResultsModel, Tag("transcribe")], ], Discriminator("result_type"), ] diff --git a/app/tests/test_transcribe.py b/app/tests/test_transcribe.py new file mode 100644 index 0000000..c765316 --- /dev/null +++ b/app/tests/test_transcribe.py @@ -0,0 +1,206 @@ +"""Unit tests for the transcription merge / alignment logic. + +These tests exercise the pure-Python helpers in ``app.detection.transcribe`` +without requiring any ML models or audio files. +""" + +import importlib.util +import os +import sys + +import pytest + +# --------------------------------------------------------------------------- +# Load app.detection.transcribe directly to avoid triggering the heavy +# app.detection.__init__.py which requires cv2 / mediapipe etc. +# --------------------------------------------------------------------------- +_TRANSCRIBE_PATH = os.path.join( + os.path.dirname(__file__), "..", "detection", "transcribe.py" +) +_spec = importlib.util.spec_from_file_location("transcribe_module", _TRANSCRIBE_PATH) +_transcribe_mod = importlib.util.module_from_spec(_spec) +sys.modules.setdefault("transcribe_module", _transcribe_mod) +_spec.loader.exec_module(_transcribe_mod) + +_overlap = _transcribe_mod._overlap +merge_transcript_with_speakers = _transcribe_mod.merge_transcript_with_speakers +aggregate_speakers = _transcribe_mod.aggregate_speakers + + +# --------------------------------------------------------------------------- +# _overlap helper +# --------------------------------------------------------------------------- + + +class TestOverlap: + def test_no_overlap_before(self): + assert _overlap(0.0, 1.0, 2.0, 3.0) == 0.0 + + def test_no_overlap_after(self): + assert _overlap(2.0, 3.0, 0.0, 1.0) == 0.0 + + def test_touching_at_boundary(self): + # Segments share a single point – zero-duration overlap + assert _overlap(0.0, 1.0, 1.0, 2.0) == 0.0 + + def test_partial_overlap(self): + assert _overlap(0.0, 2.0, 1.0, 3.0) == pytest.approx(1.0) + + def test_full_containment(self): + # b is fully inside a + assert _overlap(0.0, 4.0, 1.0, 3.0) == pytest.approx(2.0) + + def test_identical_intervals(self): + assert _overlap(1.5, 3.5, 1.5, 3.5) == pytest.approx(2.0) + + +# --------------------------------------------------------------------------- +# merge_transcript_with_speakers +# --------------------------------------------------------------------------- + + +class TestMergeTranscriptWithSpeakers: + def _seg(self, id, start, end, text="hello"): + return { + "id": id, + "start": start, + "end": end, + "text": text, + "confidence": 0.9, + "words": None, + } + + def _diar(self, start, end, speaker): + return {"start": start, "end": end, "speaker": speaker} + + def test_empty_inputs(self): + result = merge_transcript_with_speakers([], []) + assert result == [] + + def test_no_diarization_segments(self): + asr = [self._seg(0, 0.0, 2.0)] + result = merge_transcript_with_speakers(asr, []) + assert len(result) == 1 + assert result[0]["speaker"] is None + + def test_exact_match(self): + asr = [self._seg(0, 0.5, 3.0)] + diar = [self._diar(0.5, 3.0, "SPEAKER_00")] + result = merge_transcript_with_speakers(asr, diar) + assert result[0]["speaker"] == "SPEAKER_00" + + def test_best_overlap_wins(self): + asr = [self._seg(0, 0.0, 4.0)] + diar = [ + self._diar(0.0, 1.0, "SPEAKER_00"), # 1 s overlap + self._diar(1.0, 4.0, "SPEAKER_01"), # 3 s overlap – should win + ] + result = merge_transcript_with_speakers(asr, diar) + assert result[0]["speaker"] == "SPEAKER_01" + + def test_gap_between_segments(self): + # ASR segment falls in a gap between diarization segments + asr = [self._seg(0, 2.0, 3.0)] + diar = [ + self._diar(0.0, 1.5, "SPEAKER_00"), + self._diar(3.5, 5.0, "SPEAKER_01"), + ] + result = merge_transcript_with_speakers(asr, diar) + assert result[0]["speaker"] is None + + def test_crossing_boundaries(self): + """ASR segment spans across a speaker change; best-overlap speaker wins.""" + asr = [self._seg(0, 0.0, 6.0)] + diar = [ + self._diar(0.0, 4.0, "SPEAKER_00"), # 4 s overlap + self._diar(4.0, 8.0, "SPEAKER_01"), # 2 s overlap + ] + result = merge_transcript_with_speakers(asr, diar) + assert result[0]["speaker"] == "SPEAKER_00" + + def test_original_segment_not_mutated(self): + asr = [self._seg(0, 0.0, 2.0)] + original_keys = set(asr[0].keys()) + merge_transcript_with_speakers(asr, [self._diar(0.0, 2.0, "SPEAKER_00")]) + # Original dict must not gain the 'speaker' key + assert set(asr[0].keys()) == original_keys + + def test_multiple_segments_multiple_speakers(self): + asr = [ + self._seg(0, 0.0, 2.0, "Hello"), + self._seg(1, 3.0, 5.0, "World"), + ] + diar = [ + self._diar(0.0, 2.5, "SPEAKER_00"), + self._diar(2.5, 6.0, "SPEAKER_01"), + ] + result = merge_transcript_with_speakers(asr, diar) + assert result[0]["speaker"] == "SPEAKER_00" + assert result[1]["speaker"] == "SPEAKER_01" + + def test_preserves_all_original_fields(self): + asr = [self._seg(0, 1.0, 3.0, "test")] + diar = [self._diar(1.0, 3.0, "SPEAKER_00")] + result = merge_transcript_with_speakers(asr, diar) + seg = result[0] + assert seg["id"] == 0 + assert seg["start"] == 1.0 + assert seg["end"] == 3.0 + assert seg["text"] == "test" + assert seg["confidence"] == 0.9 + assert seg["words"] is None + + +# --------------------------------------------------------------------------- +# aggregate_speakers +# --------------------------------------------------------------------------- + + +class TestAggregateSpeakers: + def test_empty_input(self): + assert aggregate_speakers([]) == [] + + def test_no_speaker_labels(self): + segs = [{"start": 0.0, "end": 2.0, "text": "x", "speaker": None}] + assert aggregate_speakers(segs) == [] + + def test_single_speaker(self): + segs = [ + {"start": 0.0, "end": 2.0, "speaker": "SPEAKER_00"}, + {"start": 3.0, "end": 5.0, "speaker": "SPEAKER_00"}, + ] + result = aggregate_speakers(segs) + assert len(result) == 1 + assert result[0]["label"] == "SPEAKER_00" + assert result[0]["total_speaking_time_sec"] == pytest.approx(4.0) + + def test_two_speakers(self): + segs = [ + {"start": 0.0, "end": 3.0, "speaker": "SPEAKER_00"}, + {"start": 3.5, "end": 5.5, "speaker": "SPEAKER_01"}, + {"start": 6.0, "end": 7.0, "speaker": "SPEAKER_00"}, + ] + result = aggregate_speakers(segs) + totals = {s["label"]: s["total_speaking_time_sec"] for s in result} + assert totals["SPEAKER_00"] == pytest.approx(4.0) + assert totals["SPEAKER_01"] == pytest.approx(2.0) + + def test_output_sorted_by_label(self): + segs = [ + {"start": 0.0, "end": 1.0, "speaker": "SPEAKER_02"}, + {"start": 1.0, "end": 2.0, "speaker": "SPEAKER_00"}, + {"start": 2.0, "end": 3.0, "speaker": "SPEAKER_01"}, + ] + result = aggregate_speakers(segs) + labels = [s["label"] for s in result] + assert labels == sorted(labels) + + def test_missing_speaker_key_skipped(self): + """Segments without a 'speaker' key are tolerated and skipped.""" + segs = [ + {"start": 0.0, "end": 2.0}, # no speaker key + {"start": 2.0, "end": 4.0, "speaker": "SPEAKER_00"}, + ] + result = aggregate_speakers(segs) + assert len(result) == 1 + assert result[0]["label"] == "SPEAKER_00" diff --git a/env.example b/env.example index 75c7b46..4401e63 100644 --- a/env.example +++ b/env.example @@ -12,3 +12,14 @@ FLOWER_DB=/app/flower/flower.db # Logging Configuration LOG_LEVEL=INFO + +# Transcription Configuration +WHISPER_MODEL_SIZE=small +WHISPER_DEVICE=cpu +WHISPER_COMPUTE_TYPE=int8 +# WHISPER_LANGUAGE=en # Uncomment to force a specific language (default: auto-detect) + +# Diarization Configuration +DIARIZATION_ENABLED=true +# PYANNOTE_AUTH_TOKEN=hf_... # HuggingFace token for pyannote model access +PYANNOTE_MODEL=pyannote/speaker-diarization-3.1 diff --git a/pyproject.toml b/pyproject.toml index 743c1e4..bc304f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,9 @@ dependencies = [ "redis>=5.0.1", "uvicorn>=0.34.3", "scenedetect==0.6.7.1", + "faster-whisper>=1.1.1", + "pyannote.audio>=3.3.2", + "ffmpeg-python>=0.2.0", ] [project.optional-dependencies] From 7eed24a046f0a6458193078a9af59d8d5f514f82 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 29 Mar 2026 16:03:03 +0000 Subject: [PATCH 3/6] test: use real video URL in tests and add transcribe job type tests Agent-Logs-Url: https://github.com/celluloid-camp/vision/sessions/c9195ef1-f43a-4d3d-bee5-b943fa77209d Co-authored-by: younes200 <198514+younes200@users.noreply.github.com> --- app/tests/test_api.py | 59 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/app/tests/test_api.py b/app/tests/test_api.py index e69d49f..b7133c1 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -42,11 +42,14 @@ def _wait_for_api(max_wait: int = 60) -> None: pytest.fail(f"API at {BASE_URL} did not become ready within {max_wait}s") +TEST_VIDEO_URL = "https://pt-storage.celluloid.cloud/web-videos/a53d2ce5-0db1-49bf-9daa-be2dafd46ee9-144.mp4" + + def _object_detect_payload(**overrides) -> dict: base = { "job_type": "object_detect", "external_id": "ci-test-project", - "video_url": "http://example.com/video.mp4", + "video_url": TEST_VIDEO_URL, "params": {"similarity_threshold": 0.5}, } base.update(overrides) @@ -57,13 +60,24 @@ def _scene_detect_payload(**overrides) -> dict: base = { "job_type": "scene_detect", "external_id": "ci-test-project-scene", - "video_url": "http://example.com/video.mp4", + "video_url": TEST_VIDEO_URL, "params": {"threshold": 30.0}, } base.update(overrides) return base +def _transcribe_payload(**overrides) -> dict: + base = { + "job_type": "transcribe", + "external_id": "ci-test-project-transcribe", + "video_url": TEST_VIDEO_URL, + "params": {"model_size": "small", "diarization_enabled": False}, + } + base.update(overrides) + return base + + # --------------------------------------------------------------------------- # Session-scoped fixture: wait for the API to be up # --------------------------------------------------------------------------- @@ -326,12 +340,51 @@ def test_scene_detect_default_params(self): json={ "job_type": "scene_detect", "external_id": "ci-scene-defaults", - "video_url": "http://example.com/video.mp4", + "video_url": TEST_VIDEO_URL, }, headers=HEADERS_AUTH, ) assert r.status_code in (202, 422) + def test_transcribe_response_shape(self): + r = requests.post( + f"{BASE_URL}/job/create", + json=_transcribe_payload(external_id="ci-shape-tr"), + headers=HEADERS_AUTH, + ) + if r.status_code == 202: + data = r.json() + assert data["job_type"] == "transcribe" + assert "job_id" in data + assert "status" in data + assert "queue_position" in data + assert "message" in data + + def test_transcribe_default_params(self): + """transcribe works without explicit params (defaults applied).""" + r = requests.post( + f"{BASE_URL}/job/create", + json={ + "job_type": "transcribe", + "external_id": "ci-transcribe-defaults", + "video_url": TEST_VIDEO_URL, + }, + headers=HEADERS_AUTH, + ) + assert r.status_code in (202, 422) + + def test_transcribe_invalid_job_type_returns_422(self): + r = requests.post( + f"{BASE_URL}/job/create", + json={ + "job_type": "invalid_type", + "external_id": "ci-proj", + "video_url": TEST_VIDEO_URL, + }, + headers=HEADERS_AUTH, + ) + assert r.status_code == 422 + # --------------------------------------------------------------------------- # GET /status/{job_id} From 68a1b3508c336bb29e5222af53790706f6a97d07 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 29 Mar 2026 16:21:46 +0000 Subject: [PATCH 4/6] fix: persist PENDING jobs in Redis registry to fix status/dedup CI failures Agent-Logs-Url: https://github.com/celluloid-camp/vision/sessions/11ea3dc8-9436-4591-8472-a6ebaf2eb656 Co-authored-by: younes200 <198514+younes200@users.noreply.github.com> --- app/core/celery_queue.py | 116 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 110 insertions(+), 6 deletions(-) diff --git a/app/core/celery_queue.py b/app/core/celery_queue.py index 345d34b..4aed783 100644 --- a/app/core/celery_queue.py +++ b/app/core/celery_queue.py @@ -1,18 +1,22 @@ """Celery-based job manager relying on Celery APIs only.""" import ast +import json import logging from datetime import datetime from typing import List, Optional +import redis as redis_lib from celery.result import AsyncResult -from app.core.celery_app import CELERY_QUEUE_NAME, celery_app +from app.core.celery_app import CELERY_QUEUE_NAME, REDIS_URL, celery_app from app.models.schemas import JobStatus logger = logging.getLogger(__name__) ESTIMATED_MINUTES_PER_JOB = 5 +# TTL for the lightweight job registry entries (24 h, matches Celery result_expires) +_JOB_REGISTRY_TTL = 86400 TASK_NAME_BY_JOB_TYPE = { "object_detect": "app.core.tasks.process_object_detect_task", @@ -25,8 +29,76 @@ class CeleryJobManager: def __init__(self, queue_name: str = CELERY_QUEUE_NAME): """Initialize Celery job manager.""" self.queue_name = queue_name + self._redis = redis_lib.from_url(REDIS_URL, decode_responses=True) logger.info("Initialized Celery job manager with queue: %s", queue_name) + # ------------------------------------------------------------------ + # Internal helpers for the lightweight job registry stored in Redis. + # These let us find PENDING tasks that aren't yet in the worker's + # active/reserved lists (happens when the worker is busy and has not + # pre-fetched the task yet). + # ------------------------------------------------------------------ + + def _registry_key(self, job_id: str) -> str: + return f"cvision:job:{job_id}" + + def _save_to_registry(self, job: JobStatus): + """Persist a minimal job record so we can look it up while PENDING.""" + payload = { + "job_id": job.job_id, + "external_id": job.external_id, + "video_url": job.video_url, + "job_type": job.job_type, + "callback_url": job.callback_url or "", + "params": json.dumps(job.params or {}), + "status": "queued", + "start_time": datetime.now().isoformat(), + } + try: + self._redis.hset(self._registry_key(job.job_id), mapping=payload) + self._redis.expire(self._registry_key(job.job_id), _JOB_REGISTRY_TTL) + except Exception as e: + logger.warning("Could not write job registry entry %s: %s", job.job_id, e) + + def _load_from_registry(self, job_id: str) -> Optional[JobStatus]: + """Load a job from the Redis registry (fallback for PENDING state).""" + try: + data = self._redis.hgetall(self._registry_key(job_id)) + if not data: + return None + job = JobStatus( + job_id=data["job_id"], + external_id=data.get("external_id", "unknown"), + video_url=data.get("video_url", "unknown"), + job_type=data.get("job_type", "object_detect"), + callback_url=data.get("callback_url") or None, + params=json.loads(data.get("params", "{}")), + ) + job.status = data.get("status", "queued") + raw_start = data.get("start_time") + if raw_start: + try: + job.start_time = datetime.fromisoformat(raw_start) + except ValueError: + pass + return job + except Exception as e: + logger.warning("Could not read job registry entry %s: %s", job_id, e) + return None + + def _all_registry_jobs(self) -> List[JobStatus]: + """Return all queued jobs stored in the registry.""" + jobs: List[JobStatus] = [] + try: + for key in self._redis.scan_iter("cvision:job:*"): + job_id = key.split(":")[-1] + job = self._load_from_registry(job_id) + if job: + jobs.append(job) + except Exception as e: + logger.warning("Could not list registry jobs: %s", e) + return jobs + def ping(self) -> bool: """Test connectivity through Celery control API.""" try: @@ -138,7 +210,10 @@ def get_job_from_celery(self, job_id: str) -> Optional[JobStatus]: payload = self._extract_job_data(task) return self._job_from_payload(job_id, payload, "queued") - return None + # Fall back to our lightweight Redis registry. A task can be in + # PENDING state but not yet visible to the inspector when the worker + # is already busy and hasn't pre-fetched the task yet. + return self._load_from_registry(job_id) except Exception as e: logger.error("Error getting job %s from Celery: %s", job_id, str(e)) return None @@ -148,8 +223,13 @@ def save_job_to_celery(self, job: JobStatus): logger.debug("save_job_to_celery no-op for job %s", job.job_id) def get_all_jobs(self) -> List[JobStatus]: - """Get all visible queued/processing jobs from Celery inspect APIs.""" - jobs: list[JobStatus] = [] + """Get all visible queued/processing jobs from Celery inspect APIs. + + Also includes jobs in the Redis registry that haven't been pre-fetched + by the worker yet (PENDING state not visible to the inspector). + """ + jobs: List[JobStatus] = [] + seen_ids: set[str] = set() try: active, reserved, scheduled = self._inspect_tasks() @@ -158,19 +238,39 @@ def get_all_jobs(self) -> List[JobStatus]: job_id = task.get("id") if job_id: jobs.append(self._job_from_payload(job_id, payload, "processing")) + seen_ids.add(job_id) for task in reserved + scheduled: payload = self._extract_job_data(task) job_id = task.get("id") if job_id: jobs.append(self._job_from_payload(job_id, payload, "queued")) + seen_ids.add(job_id) except Exception as e: logger.error("Error getting all jobs: %s", str(e)) + + # Add queued jobs from the registry that the inspector hasn't seen yet. + for job in self._all_registry_jobs(): + if job.job_id not in seen_ids: + # Only include if Celery hasn't moved the task past PENDING. + result = AsyncResult(job.job_id, app=celery_app) + if result.state == "PENDING": + jobs.append(job) + seen_ids.add(job.job_id) + return jobs def cleanup_stale_jobs(self): - """Compatibility no-op (no Redis job registry).""" - logger.info("No stale job cleanup required (Celery-only mode).") + """Remove registry entries for jobs that have been fully processed.""" + try: + for key in self._redis.scan_iter("cvision:job:*"): + job_id = key.split(":")[-1] + result = AsyncResult(job_id, app=celery_app) + if result.state not in ("PENDING",): + self._redis.delete(key) + logger.info("Cleaned up stale job registry entries.") + except Exception as e: + logger.warning("Error during stale job cleanup: %s", e) def get_queue_status_info(self): """Get current queue status from Celery inspect APIs.""" @@ -216,6 +316,10 @@ def enqueue_job(self, job: JobStatus): queue=self.queue_name, ) + # Persist a lightweight record so PENDING jobs are discoverable + # even before the worker pre-fetches the task. + self._save_to_registry(job) + logger.info( "Enqueued %s job %s to Celery queue %s", job.job_type, From 2a502839bc4d70311e3f61aea0700271198a133a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 29 Mar 2026 16:35:27 +0000 Subject: [PATCH 5/6] chore: regenerate openapi.json with transcribe job type schemas Agent-Logs-Url: https://github.com/celluloid-camp/vision/sessions/96349765-5ad2-4029-b86f-3e9188f82b89 Co-authored-by: younes200 <198514+younes200@users.noreply.github.com> --- openapi.json | 331 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 327 insertions(+), 4 deletions(-) diff --git a/openapi.json b/openapi.json index 9a8dcc2..1c79cdd 100644 --- a/openapi.json +++ b/openapi.json @@ -2,7 +2,7 @@ "openapi": "3.1.0", "info": { "title": "Celluloid Video Analysis API", - "version": "1.8.0" + "version": "1.11.0" }, "paths": { "/health": { @@ -276,6 +276,9 @@ }, { "$ref": "#/components/schemas/SceneDetectParams" + }, + { + "$ref": "#/components/schemas/TranscribeParams" } ], "title": "Params" @@ -496,6 +499,29 @@ "type": "object", "title": "DetectionStatisticsModel" }, + "DiarizationSegmentModel": { + "properties": { + "start": { + "type": "number", + "title": "Start" + }, + "end": { + "type": "number", + "title": "End" + }, + "speaker": { + "type": "string", + "title": "Speaker" + } + }, + "type": "object", + "required": [ + "start", + "end", + "speaker" + ], + "title": "DiarizationSegmentModel" + }, "HTTPValidationError": { "properties": { "detail": { @@ -613,13 +639,17 @@ }, { "$ref": "#/components/schemas/SceneDetectResultsModel" + }, + { + "$ref": "#/components/schemas/TranscribeResultsModel" } ], "discriminator": { "propertyName": "result_type", "mapping": { "object_detect": "#/components/schemas/DetectionResultsModel", - "scene_detect": "#/components/schemas/SceneDetectResultsModel" + "scene_detect": "#/components/schemas/SceneDetectResultsModel", + "transcribe": "#/components/schemas/TranscribeResultsModel" } } }, @@ -761,7 +791,8 @@ "type": "string", "enum": [ "object_detect", - "scene_detect" + "scene_detect", + "transcribe" ], "title": "JobType" }, @@ -1003,6 +1034,24 @@ ], "title": "SceneInfoModel" }, + "SpeakerSummaryModel": { + "properties": { + "label": { + "type": "string", + "title": "Label" + }, + "total_speaking_time_sec": { + "type": "number", + "title": "Total Speaking Time Sec" + } + }, + "type": "object", + "required": [ + "label", + "total_speaking_time_sec" + ], + "title": "SpeakerSummaryModel" + }, "SpriteMetadataModel": { "properties": { "url": { @@ -1024,6 +1073,252 @@ ], "title": "SpriteMetadataModel" }, + "TranscribeParams": { + "properties": { + "model_size": { + "type": "string", + "title": "Model Size", + "description": "Whisper model size (tiny, base, small, medium, large-v2, large-v3)", + "default": "small" + }, + "device": { + "type": "string", + "title": "Device", + "description": "Inference device (cpu or cuda)", + "default": "cpu" + }, + "compute_type": { + "type": "string", + "title": "Compute Type", + "description": "Quantisation type (int8, float16, float32)", + "default": "int8" + }, + "language": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Language", + "description": "ISO-639-1 language code, or null for auto-detection" + }, + "diarization_enabled": { + "type": "boolean", + "title": "Diarization Enabled", + "description": "Whether to run speaker diarization", + "default": true + }, + "num_speakers": { + "anyOf": [ + { + "type": "integer", + "minimum": 1.0 + }, + { + "type": "null" + } + ], + "title": "Num Speakers", + "description": "Exact number of speakers (overrides min/max)" + }, + "min_speakers": { + "anyOf": [ + { + "type": "integer", + "minimum": 1.0 + }, + { + "type": "null" + } + ], + "title": "Min Speakers", + "description": "Minimum number of speakers hint" + }, + "max_speakers": { + "anyOf": [ + { + "type": "integer", + "minimum": 1.0 + }, + { + "type": "null" + } + ], + "title": "Max Speakers", + "description": "Maximum number of speakers hint" + } + }, + "type": "object", + "title": "TranscribeParams" + }, + "TranscribeResultsModel": { + "properties": { + "result_type": { + "type": "string", + "const": "transcribe", + "title": "Result Type", + "default": "transcribe" + }, + "metadata": { + "$ref": "#/components/schemas/TranscriptionMetadataModel" + }, + "segments": { + "items": { + "$ref": "#/components/schemas/TranscriptSegmentModel" + }, + "type": "array", + "title": "Segments" + }, + "speakers": { + "items": { + "$ref": "#/components/schemas/SpeakerSummaryModel" + }, + "type": "array", + "title": "Speakers" + }, + "diarization": { + "items": { + "$ref": "#/components/schemas/DiarizationSegmentModel" + }, + "type": "array", + "title": "Diarization" + } + }, + "type": "object", + "required": [ + "metadata", + "segments", + "speakers", + "diarization" + ], + "title": "TranscribeResultsModel" + }, + "TranscriptSegmentModel": { + "properties": { + "id": { + "type": "integer", + "title": "Id" + }, + "start": { + "type": "number", + "title": "Start" + }, + "end": { + "type": "number", + "title": "End" + }, + "speaker": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Speaker" + }, + "text": { + "type": "string", + "title": "Text" + }, + "confidence": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Confidence" + }, + "words": { + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/WordModel" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Words" + } + }, + "type": "object", + "required": [ + "id", + "start", + "end", + "text" + ], + "title": "TranscriptSegmentModel" + }, + "TranscriptionMetadataModel": { + "properties": { + "engine": { + "type": "string", + "title": "Engine" + }, + "asr_backend": { + "type": "string", + "title": "Asr Backend", + "default": "faster-whisper" + }, + "diarization_backend": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Diarization Backend" + }, + "device": { + "type": "string", + "title": "Device" + }, + "compute_type": { + "type": "string", + "title": "Compute Type" + }, + "asr_model": { + "type": "string", + "title": "Asr Model" + }, + "language": { + "type": "string", + "title": "Language" + }, + "audio_duration_sec": { + "type": "number", + "title": "Audio Duration Sec" + }, + "processing_time_sec": { + "type": "number", + "title": "Processing Time Sec" + } + }, + "type": "object", + "required": [ + "engine", + "device", + "compute_type", + "asr_model", + "language", + "audio_duration_sec", + "processing_time_sec" + ], + "title": "TranscriptionMetadataModel" + }, "ValidationError": { "properties": { "loc": { @@ -1089,6 +1384,34 @@ "source" ], "title": "VideoMetadataModel" + }, + "WordModel": { + "properties": { + "word": { + "type": "string", + "title": "Word" + }, + "start": { + "type": "number", + "title": "Start" + }, + "end": { + "type": "number", + "title": "End" + }, + "probability": { + "type": "number", + "title": "Probability" + } + }, + "type": "object", + "required": [ + "word", + "start", + "end", + "probability" + ], + "title": "WordModel" } }, "securitySchemes": { @@ -1121,4 +1444,4 @@ "description": "Webhook operations." } ] -} \ No newline at end of file +} From ca383f739f7f019e70675ed679c62d08c7623504 Mon Sep 17 00:00:00 2001 From: Younes Date: Mon, 30 Mar 2026 14:24:27 +0100 Subject: [PATCH 6/6] Refactor transcription pipeline and deployment tooling. Add robust diarization handling, model persistence configuration, and deployment/hook updates while removing hardcoded secrets from deployment scripts. Made-with: Cursor --- .gitignore | 3 +- Dockerfile | 1 + app/detection/transcribe.py | 360 +++++++++++++++++- deploy.sh | 3 +- .../src/components/VideoDetectionPlayer.tsx | 15 +- lefthook.yml | 4 +- pyproject.toml | 6 +- 7 files changed, 366 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index bd7137d..2169815 100644 --- a/.gitignore +++ b/.gitignore @@ -190,4 +190,5 @@ app/detection/models/*.tflite tmp/ flower .flower/ -.data/ \ No newline at end of file +.data/ +models/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 6872804..12d36fc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,6 +15,7 @@ RUN echo "Acquire::http::Pipeline-Depth 0;" > /etc/apt/apt.conf.d/99custom && \ # Install OpenCV dependencies and Redis RUN apt-get update && apt-get install -y \ + ffmpeg \ libglib2.0-0 \ libsm6 \ libxext6 \ diff --git a/app/detection/transcribe.py b/app/detection/transcribe.py index a9a6a45..7300da2 100644 --- a/app/detection/transcribe.py +++ b/app/detection/transcribe.py @@ -4,14 +4,258 @@ Both models run fully self-hosted with no external API calls required. """ +import json import logging import os +import tempfile import time +from dataclasses import dataclass +from datetime import datetime +from inspect import signature from typing import Callable, List, Optional +from urllib.parse import urlparse + +from app.core.utils import download_file, ensure_dir logger = logging.getLogger(__name__) +def _get_shared_models_root() -> str: + """Return the shared models root directory. + + Priority: + 1) CELLULOID_MODELS_DIR env var + 2) /app/models (container default) + 3) local fallback next to this module + """ + env_dir = os.getenv("CELLULOID_MODELS_DIR") + if env_dir: + try: + ensure_dir(env_dir) + return env_dir + except Exception as exc: + logger.warning( + "Could not use CELLULOID_MODELS_DIR='%s': %s. Falling back.", + env_dir, + exc, + ) + + container_models_dir = "/app/models" + if os.path.isdir(container_models_dir): + try: + ensure_dir(container_models_dir) + return container_models_dir + except Exception as exc: + logger.warning( + "Could not use container models dir '%s': %s. Falling back.", + container_models_dir, + exc, + ) + + fallback_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models") + ensure_dir(fallback_dir) + return fallback_dir + + +def _patch_torchaudio_audio_metadata() -> None: + """Patch torchaudio API drift for pyannote 3.x compatibility.""" + try: + import torchaudio # type: ignore + except Exception: + return + + patched = False + + if not hasattr(torchaudio, "AudioMetaData"): + + @dataclass + class _AudioMetaDataShim: + sample_rate: int = 0 + num_frames: int = 0 + num_channels: int = 0 + bits_per_sample: int = 0 + encoding: str = "" + + setattr(torchaudio, "AudioMetaData", _AudioMetaDataShim) + patched = True + + if not hasattr(torchaudio, "list_audio_backends"): + + def _list_audio_backends() -> List[str]: + return ["soundfile"] + + setattr(torchaudio, "list_audio_backends", _list_audio_backends) + patched = True + + if not hasattr(torchaudio, "get_audio_backend"): + setattr(torchaudio, "_compat_audio_backend", "soundfile") + + def _get_audio_backend() -> str: + return str(getattr(torchaudio, "_compat_audio_backend", "soundfile")) + + setattr(torchaudio, "get_audio_backend", _get_audio_backend) + patched = True + + if not hasattr(torchaudio, "set_audio_backend"): + + def _set_audio_backend(backend: Optional[str]) -> None: + if backend in (None, "soundfile", "ffmpeg", "sox"): + setattr(torchaudio, "_compat_audio_backend", backend or "soundfile") + return + raise ValueError(f"Unsupported audio backend: {backend}") + + setattr(torchaudio, "set_audio_backend", _set_audio_backend) + patched = True + + if not hasattr(torchaudio, "info"): + + def _info(uri, backend: Optional[str] = None): # type: ignore[no-untyped-def] + import soundfile as sf # type: ignore + + details = sf.info(uri) + return torchaudio.AudioMetaData( + sample_rate=int(details.samplerate or 0), + num_frames=int(details.frames or 0), + num_channels=int(details.channels or 0), + bits_per_sample=0, + encoding=str(details.format or ""), + ) + + setattr(torchaudio, "info", _info) + patched = True + + if patched: + logger.info("Applied torchaudio compatibility shims for pyannote") + + +def _patch_huggingface_hub_auth_kwargs() -> None: + """Map deprecated `use_auth_token` to `token` for new hub versions.""" + try: + import huggingface_hub as hf # type: ignore + except Exception: + return + + patched = False + + try: + hf_download_params = signature(hf.hf_hub_download).parameters + if "use_auth_token" not in hf_download_params: + original_hf_download = hf.hf_hub_download + + def _hf_hub_download_compat(*args, **kwargs): # type: ignore[no-untyped-def] + if "use_auth_token" in kwargs and "token" not in kwargs: + kwargs["token"] = kwargs.pop("use_auth_token") + else: + kwargs.pop("use_auth_token", None) + return original_hf_download(*args, **kwargs) + + hf.hf_hub_download = _hf_hub_download_compat # type: ignore[assignment] + patched = True + except Exception: + pass + + try: + snapshot_params = signature(hf.snapshot_download).parameters + if "use_auth_token" not in snapshot_params: + original_snapshot = hf.snapshot_download + + def _snapshot_download_compat(*args, **kwargs): # type: ignore[no-untyped-def] + if "use_auth_token" in kwargs and "token" not in kwargs: + kwargs["token"] = kwargs.pop("use_auth_token") + else: + kwargs.pop("use_auth_token", None) + return original_snapshot(*args, **kwargs) + + hf.snapshot_download = _snapshot_download_compat # type: ignore[assignment] + patched = True + except Exception: + pass + + if patched: + logger.info("Applied huggingface_hub auth kwarg compatibility shims") + + +def _get_transcription_models_dir() -> str: + """Return local models directory for transcription backends.""" + models_dir = os.path.join(_get_shared_models_root(), "whisper") + ensure_dir(models_dir) + return models_dir + + +def _get_pyannote_cache_dir() -> str: + """Return local cache directory for pyannote checkpoints.""" + cache_dir = os.path.join(_get_shared_models_root(), "pyannote") + ensure_dir(cache_dir) + return cache_dir + + +def _prepare_diarization_input(audio_path: str) -> tuple[str, bool]: + """Return a diarization-friendly audio path (mono 16k WAV). + + pyannote/torchaudio backends can fail to decode container formats like mp4 + depending on runtime codec support. We normalize to WAV for reliability. + + Returns: + (path, should_cleanup) + """ + ext = os.path.splitext(audio_path)[1].lower() + if ext in {".wav", ".flac", ".ogg"}: + return audio_path, False + + try: + import ffmpeg # type: ignore + except Exception as exc: + raise RuntimeError( + "ffmpeg-python is required to convert media to WAV for diarization." + ) from exc + + with tempfile.NamedTemporaryFile( + prefix="pyannote_", suffix=".wav", delete=False + ) as tmp: + wav_path = tmp.name + + try: + ( + ffmpeg.input(audio_path) + .output(wav_path, ac=1, ar=16000, format="wav") + .overwrite_output() + .run(capture_stdout=True, capture_stderr=True, quiet=True) + ) + return wav_path, True + except Exception as exc: + try: + os.remove(wav_path) + except OSError: + pass + raise RuntimeError( + f"Failed to prepare diarization audio from '{audio_path}': {exc}" + ) + + +def _resolve_faster_whisper_model(model_size: str) -> str: + """Resolve a local faster-whisper model path, downloading once if needed.""" + if os.path.isdir(model_size): + return model_size + + models_dir = _get_transcription_models_dir() + local_model_dir = os.path.join(models_dir, f"faster-whisper-{model_size}") + + # Reuse existing local model if present. + if os.path.isdir(local_model_dir) and os.listdir(local_model_dir): + return local_model_dir + + try: + from huggingface_hub import snapshot_download # type: ignore + except Exception: + # Fallback: let faster-whisper handle download/caching. + return model_size + + repo_id = f"Systran/faster-whisper-{model_size}" + logger.info("Downloading Whisper model '%s' to %s", repo_id, local_model_dir) + snapshot_download(repo_id=repo_id, local_dir=local_model_dir) + return local_model_dir + + # --------------------------------------------------------------------------- # Data helpers # --------------------------------------------------------------------------- @@ -129,7 +373,8 @@ def transcribe_audio( device, compute_type, ) - model = WhisperModel(model_size, device=device, compute_type=compute_type) + resolved_model = _resolve_faster_whisper_model(model_size) + model = WhisperModel(resolved_model, device=device, compute_type=compute_type) transcribe_kwargs: dict = { "word_timestamps": True, @@ -162,7 +407,11 @@ def transcribe_audio( "start": round(seg.start, 3), "end": round(seg.end, 3), "text": seg.text.strip(), - "confidence": round(float(getattr(seg, "avg_logprob", None) or 0.0), 4) if getattr(seg, "avg_logprob", None) is not None else None, + "confidence": ( + round(float(getattr(seg, "avg_logprob", None) or 0.0), 4) + if getattr(seg, "avg_logprob", None) is not None + else None + ), "words": words, } ) @@ -220,6 +469,9 @@ def diarize_audio( FileNotFoundError: If ``audio_path`` does not exist. RuntimeError: If the pipeline cannot be loaded (e.g. missing token). """ + _patch_torchaudio_audio_metadata() + _patch_huggingface_hub_auth_kwargs() + try: from pyannote.audio import Pipeline # type: ignore except ImportError as exc: @@ -231,15 +483,26 @@ def diarize_audio( if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") - token = auth_token or os.getenv("PYANNOTE_AUTH_TOKEN") + token = auth_token or os.getenv("PYANNOTE_AUTH_TOKEN") or os.getenv("HF_TOKEN") logger.info("Loading pyannote pipeline '%s'", model_name) try: - pipeline = Pipeline.from_pretrained(model_name, use_auth_token=token) + from_pretrained_params = signature(Pipeline.from_pretrained).parameters + load_kwargs: dict = {} + + if "cache_dir" in from_pretrained_params: + load_kwargs["cache_dir"] = _get_pyannote_cache_dir() + + if token: + if "token" in from_pretrained_params: + load_kwargs["token"] = token + else: + load_kwargs["use_auth_token"] = token + pipeline = Pipeline.from_pretrained(model_name, **load_kwargs) except Exception as exc: raise RuntimeError( f"Failed to load pyannote pipeline '{model_name}': {exc}. " - "Make sure PYANNOTE_AUTH_TOKEN is set and you have accepted the " + "Make sure PYANNOTE_AUTH_TOKEN (or HF_TOKEN) is set and you have accepted the " "model licence on huggingface.co." ) from exc @@ -254,7 +517,17 @@ def diarize_audio( logger.info("Running speaker diarization on %s", audio_path) t0 = time.time() - diarization = pipeline(audio_path, **diarize_kwargs) + diarization_input, cleanup_input = _prepare_diarization_input(audio_path) + try: + diarization = pipeline(diarization_input, **diarize_kwargs) + finally: + if cleanup_input: + try: + os.remove(diarization_input) + except OSError: + logger.warning( + "Could not remove temporary diarization file: %s", diarization_input + ) elapsed = time.time() - t0 logger.info("Diarization completed in %.1fs", elapsed) @@ -348,14 +621,18 @@ def asr_progress(pct: float) -> None: "Diarization failed – continuing with transcript only. Error: %s", exc ) - merged = merge_transcript_with_speakers(asr_result["segments"], diarization_segments) + merged = merge_transcript_with_speakers( + asr_result["segments"], diarization_segments + ) speakers = aggregate_speakers(merged) processing_time = round(time.time() - pipeline_start, 3) return { "metadata": { - "engine": "faster-whisper+pyannote" if diarization_segments else "faster-whisper", + "engine": ( + "faster-whisper+pyannote" if diarization_segments else "faster-whisper" + ), "asr_backend": "faster-whisper", "diarization_backend": "pyannote" if diarization_segments else None, "device": device, @@ -369,3 +646,70 @@ def asr_progress(pct: float) -> None: "speakers": speakers, "diarization": diarization_segments, } + + +def main() -> int: + """CLI entry point for transcription + diarization.""" + import argparse + + parser = argparse.ArgumentParser( + description="Transcribe audio/video with faster-whisper + pyannote" + ) + parser.add_argument("audio_url", help="URL or local path to an audio/video file") + parser.add_argument( + "--output-dir", + default=None, + help="Output directory (default: tmp next to this file)", + ) + parser.add_argument( + "--model-size", + default="small", + help="Whisper model size (default: small)", + ) + parser.add_argument( + "--language", + default=None, + help="Force language code (default: auto-detect)", + ) + parser.add_argument( + "--disable-diarization", + action="store_true", + help="Run transcription without speaker diarization", + ) + args = parser.parse_args() + + audio_url = args.audio_url + tmp_dir = args.output_dir or os.path.join(os.getcwd(), "tmp") + ensure_dir(tmp_dir) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + if audio_url.startswith(("http://", "https://")): + filename = os.path.basename(urlparse(audio_url).path) + if not filename or "." not in filename: + filename = "audio.mp4" + name, ext = os.path.splitext(filename) + unique_filename = f"{name}_{timestamp}{ext}" + audio_path = os.path.join(tmp_dir, unique_filename) + print(f"Downloading media to: {audio_path}") + download_file(audio_url, audio_path) + else: + audio_path = audio_url + + results = run_transcription_pipeline( + audio_path=audio_path, + model_size=args.model_size, + language=args.language, + diarization_enabled=not args.disable_diarization, + ) + + output_path = os.path.join(tmp_dir, f"transcription_{timestamp}.json") + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + print(f"Transcription results saved to: {output_path}") + return 0 + + +if __name__ == "__main__": + main() diff --git a/deploy.sh b/deploy.sh index 7220cb7..bcdc304 100755 --- a/deploy.sh +++ b/deploy.sh @@ -24,7 +24,7 @@ docker run -d \ -p 5555:5555 \ -v "$(pwd)/outputs:/app/outputs" \ -v "$(pwd)/flower:/app/flower" \ - -v "$(pwd)/models:/app/models:ro" \ + -v "$(pwd)/models:/app/models" \ -e REDIS_URL="redis://host.docker.internal:6379/0" \ -e API_KEY="xxx" \ -e BASE_URL="http://localhost:8081" \ @@ -33,6 +33,7 @@ docker run -d \ -e FLOWER_UNAUTHENTICATED_API="true" \ -e FLOWER_PERSISTENT="true" \ -e FLOWER_DB="/app/flower/flower.db" \ + -e CELLULOID_MODELS_DIR="/app/models" \ celluloid-video-analysis-api # Wait for service to be ready diff --git a/detection-viewer/src/components/VideoDetectionPlayer.tsx b/detection-viewer/src/components/VideoDetectionPlayer.tsx index d5a9dda..c30377f 100644 --- a/detection-viewer/src/components/VideoDetectionPlayer.tsx +++ b/detection-viewer/src/components/VideoDetectionPlayer.tsx @@ -1,6 +1,6 @@ -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' import type { FrameEntry } from '../types' -import { findActiveFrameIndexForPlayback } from '../lib/frameSync' +import { findLastFrameIndexAtOrBefore } from '../lib/frameSync' import { BboxOverlay } from './BboxOverlay' type Props = { @@ -68,21 +68,12 @@ export function VideoDetectionPlayer({ const current = frameIndex >= 0 && frames[frameIndex] ? frames[frameIndex] : undefined - /** Half a source frame at `fps` — only show a keyframe when playback is this close in time. */ - const maxDeltaSec = useMemo( - () => Math.max(0.02, 0.5 / (fps != null && fps > 0 ? fps : 25)), - [fps], - ) - - const maxDeltaRef = useRef(maxDeltaSec) - maxDeltaRef.current = maxDeltaSec - const syncFrameFromVideo = useCallback(() => { const el = videoRef.current const fr = framesRef.current if (!el || fr.length === 0 || skipTimeSync.current) return const t = el.currentTime - const idx = findActiveFrameIndexForPlayback(fr, t, maxDeltaRef.current) + const idx = findLastFrameIndexAtOrBefore(fr, t) if (idx !== frameIndexRef.current) { indexFromPlayback.current = true onFrameIndexChangeRef.current(idx) diff --git a/lefthook.yml b/lefthook.yml index 8ee0ab3..a0718f6 100644 --- a/lefthook.yml +++ b/lefthook.yml @@ -2,7 +2,7 @@ pre-commit: jobs: - name: python-format glob: "*.py" - run: uv run black {staged_files} + run: uv run --extra dev black {staged_files} - name: python-lint glob: "*.py" - run: uv run flake8 {staged_files} + run: uv run --extra dev flake8 {staged_files} diff --git a/pyproject.toml b/pyproject.toml index bc304f8..f27e94b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,6 @@ dependencies = [ "numpy>=1.24.0", "requests>=2.31.0", "Pillow>=10.0.0", - "matplotlib>=3.7.0", "fastapi==0.104.1", "python-dotenv>=1.0.0", "asgiref>=3.8.1", @@ -24,7 +23,10 @@ dependencies = [ "uvicorn>=0.34.3", "scenedetect==0.6.7.1", "faster-whisper>=1.1.1", - "pyannote.audio>=3.3.2", + "pyannote.audio==3.3.2", + "huggingface-hub<1.0", + "torch<2.6", + "torchaudio<2.6", "ffmpeg-python>=0.2.0", ]