Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,5 @@ app/detection/models/*.tflite
tmp/
flower
.flower/
.data/
.data/
models/
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ RUN echo "Acquire::http::Pipeline-Depth 0;" > /etc/apt/apt.conf.d/99custom && \

# Install OpenCV dependencies and Redis
RUN apt-get update && apt-get install -y \
ffmpeg \
libglib2.0-0 \
libsm6 \
libxext6 \
Expand Down
117 changes: 111 additions & 6 deletions app/core/celery_queue.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,104 @@
"""Celery-based job manager relying on Celery APIs only."""

import ast
import json
import logging
from datetime import datetime
from typing import List, Optional

import redis as redis_lib
from celery.result import AsyncResult

from app.core.celery_app import CELERY_QUEUE_NAME, celery_app
from app.core.celery_app import CELERY_QUEUE_NAME, REDIS_URL, celery_app
from app.models.schemas import JobStatus

logger = logging.getLogger(__name__)

ESTIMATED_MINUTES_PER_JOB = 5
# TTL for the lightweight job registry entries (24 h, matches Celery result_expires)
_JOB_REGISTRY_TTL = 86400

TASK_NAME_BY_JOB_TYPE = {
"object_detect": "app.core.tasks.process_object_detect_task",
"scene_detect": "app.core.tasks.process_scene_detect_task",
"transcribe": "app.core.tasks.process_transcribe_task",
}


class CeleryJobManager:
def __init__(self, queue_name: str = CELERY_QUEUE_NAME):
"""Initialize Celery job manager."""
self.queue_name = queue_name
self._redis = redis_lib.from_url(REDIS_URL, decode_responses=True)
logger.info("Initialized Celery job manager with queue: %s", queue_name)

# ------------------------------------------------------------------
# Internal helpers for the lightweight job registry stored in Redis.
# These let us find PENDING tasks that aren't yet in the worker's
# active/reserved lists (happens when the worker is busy and has not
# pre-fetched the task yet).
# ------------------------------------------------------------------

def _registry_key(self, job_id: str) -> str:
return f"cvision:job:{job_id}"

def _save_to_registry(self, job: JobStatus):
"""Persist a minimal job record so we can look it up while PENDING."""
payload = {
"job_id": job.job_id,
"external_id": job.external_id,
"video_url": job.video_url,
"job_type": job.job_type,
"callback_url": job.callback_url or "",
"params": json.dumps(job.params or {}),
"status": "queued",
"start_time": datetime.now().isoformat(),
}
try:
self._redis.hset(self._registry_key(job.job_id), mapping=payload)
self._redis.expire(self._registry_key(job.job_id), _JOB_REGISTRY_TTL)
except Exception as e:
logger.warning("Could not write job registry entry %s: %s", job.job_id, e)

def _load_from_registry(self, job_id: str) -> Optional[JobStatus]:
"""Load a job from the Redis registry (fallback for PENDING state)."""
try:
data = self._redis.hgetall(self._registry_key(job_id))
if not data:
return None
job = JobStatus(
job_id=data["job_id"],
external_id=data.get("external_id", "unknown"),
video_url=data.get("video_url", "unknown"),
job_type=data.get("job_type", "object_detect"),
callback_url=data.get("callback_url") or None,
params=json.loads(data.get("params", "{}")),
)
job.status = data.get("status", "queued")
raw_start = data.get("start_time")
if raw_start:
try:
job.start_time = datetime.fromisoformat(raw_start)
except ValueError:
pass
return job
except Exception as e:
logger.warning("Could not read job registry entry %s: %s", job_id, e)
return None

def _all_registry_jobs(self) -> List[JobStatus]:
"""Return all queued jobs stored in the registry."""
jobs: List[JobStatus] = []
try:
for key in self._redis.scan_iter("cvision:job:*"):
job_id = key.split(":")[-1]
job = self._load_from_registry(job_id)
if job:
jobs.append(job)
except Exception as e:
logger.warning("Could not list registry jobs: %s", e)
return jobs

def ping(self) -> bool:
"""Test connectivity through Celery control API."""
try:
Expand Down Expand Up @@ -137,7 +210,10 @@ def get_job_from_celery(self, job_id: str) -> Optional[JobStatus]:
payload = self._extract_job_data(task)
return self._job_from_payload(job_id, payload, "queued")

return None
# Fall back to our lightweight Redis registry. A task can be in
# PENDING state but not yet visible to the inspector when the worker
# is already busy and hasn't pre-fetched the task yet.
return self._load_from_registry(job_id)
except Exception as e:
logger.error("Error getting job %s from Celery: %s", job_id, str(e))
return None
Expand All @@ -147,8 +223,13 @@ def save_job_to_celery(self, job: JobStatus):
logger.debug("save_job_to_celery no-op for job %s", job.job_id)

def get_all_jobs(self) -> List[JobStatus]:
"""Get all visible queued/processing jobs from Celery inspect APIs."""
jobs: list[JobStatus] = []
"""Get all visible queued/processing jobs from Celery inspect APIs.

Also includes jobs in the Redis registry that haven't been pre-fetched
by the worker yet (PENDING state not visible to the inspector).
"""
jobs: List[JobStatus] = []
seen_ids: set[str] = set()
try:
active, reserved, scheduled = self._inspect_tasks()

Expand All @@ -157,19 +238,39 @@ def get_all_jobs(self) -> List[JobStatus]:
job_id = task.get("id")
if job_id:
jobs.append(self._job_from_payload(job_id, payload, "processing"))
seen_ids.add(job_id)

for task in reserved + scheduled:
payload = self._extract_job_data(task)
job_id = task.get("id")
if job_id:
jobs.append(self._job_from_payload(job_id, payload, "queued"))
seen_ids.add(job_id)
except Exception as e:
logger.error("Error getting all jobs: %s", str(e))

# Add queued jobs from the registry that the inspector hasn't seen yet.
for job in self._all_registry_jobs():
if job.job_id not in seen_ids:
# Only include if Celery hasn't moved the task past PENDING.
result = AsyncResult(job.job_id, app=celery_app)
if result.state == "PENDING":
jobs.append(job)
seen_ids.add(job.job_id)

return jobs

def cleanup_stale_jobs(self):
"""Compatibility no-op (no Redis job registry)."""
logger.info("No stale job cleanup required (Celery-only mode).")
"""Remove registry entries for jobs that have been fully processed."""
try:
for key in self._redis.scan_iter("cvision:job:*"):
job_id = key.split(":")[-1]
result = AsyncResult(job_id, app=celery_app)
if result.state not in ("PENDING",):
self._redis.delete(key)
logger.info("Cleaned up stale job registry entries.")
except Exception as e:
logger.warning("Error during stale job cleanup: %s", e)

def get_queue_status_info(self):
"""Get current queue status from Celery inspect APIs."""
Expand Down Expand Up @@ -215,6 +316,10 @@ def enqueue_job(self, job: JobStatus):
queue=self.queue_name,
)

# Persist a lightweight record so PENDING jobs are discoverable
# even before the worker pre-fetches the task.
self._save_to_registry(job)

logger.info(
"Enqueued %s job %s to Celery queue %s",
job.job_type,
Expand Down
11 changes: 11 additions & 0 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,14 @@

# Processing Configuration
MAX_WORKERS = 1 # Only 1 worker since we process one job at a time

# Transcription Configuration
WHISPER_MODEL_SIZE = os.getenv("WHISPER_MODEL_SIZE", "small")
WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "cpu")
WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", None) # None = auto-detect

# Diarization Configuration
DIARIZATION_ENABLED = os.getenv("DIARIZATION_ENABLED", "true").lower() == "true"
PYANNOTE_AUTH_TOKEN = os.getenv("PYANNOTE_AUTH_TOKEN", None)
PYANNOTE_MODEL = os.getenv("PYANNOTE_MODEL", "pyannote/speaker-diarization-3.1")
150 changes: 150 additions & 0 deletions app/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,153 @@ def process_scene_detect_task(self, job_data: dict):
)

raise


# ---------------------------------------------------------------------------
# Transcribe task
# ---------------------------------------------------------------------------


@celery_app.task(bind=True, name="app.core.tasks.process_transcribe_task")
def process_transcribe_task(self, job_data: dict):
"""Celery task for audio transcription with optional speaker diarization."""
job_id = job_data["job_id"]
external_id = job_data["external_id"]
video_url = job_data["video_url"]
params = job_data.get("params", {})
callback_url = job_data.get("callback_url")
start_time = datetime.now().isoformat()

from app.core.config import (
DIARIZATION_ENABLED,
PYANNOTE_AUTH_TOKEN,
PYANNOTE_MODEL,
WHISPER_COMPUTE_TYPE,
WHISPER_DEVICE,
WHISPER_LANGUAGE,
WHISPER_MODEL_SIZE,
)

model_size = params.get("model_size", WHISPER_MODEL_SIZE)
device = params.get("device", WHISPER_DEVICE)
compute_type = params.get("compute_type", WHISPER_COMPUTE_TYPE)
language = params.get("language", WHISPER_LANGUAGE) or None
diarization_enabled = params.get("diarization_enabled", DIARIZATION_ENABLED)
num_speakers = params.get("num_speakers")
min_speakers = params.get("min_speakers")
max_speakers = params.get("max_speakers")
auth_token = PYANNOTE_AUTH_TOKEN
pyannote_model = PYANNOTE_MODEL

self.update_state(
state="PROCESSING",
meta={
"job_id": job_id,
"external_id": external_id,
"video_url": video_url,
"job_type": "transcribe",
"callback_url": callback_url,
"status": "processing",
"progress": 0.0,
"start_time": start_time,
},
)

try:
output_dir = os.path.join("outputs", external_id)
os.makedirs(output_dir, exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = f"transcript_{job_id}_{timestamp}.json"
output_path = os.path.join(output_dir, output_filename)

# Download remote file; local paths are used directly.
if video_url.startswith(("http://", "https://")):
audio_path = download_video(video_url)
else:
audio_path = video_url

progress_cb = _make_progress_reporter(self, job_id, external_id, start_time)

from app.detection.transcribe import run_transcription_pipeline

result = run_transcription_pipeline(
audio_path=audio_path,
model_size=model_size,
device=device,
compute_type=compute_type,
language=language,
diarization_enabled=diarization_enabled,
auth_token=auth_token,
pyannote_model=pyannote_model,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
progress_callback=progress_cb,
)

result["result_type"] = "transcribe"

with open(output_path, "w") as f:
json.dump(result, f, indent=2)

logger.info(f"Transcript saved to: {output_path}")

metadata = {
"language": result["metadata"]["language"],
"audio_duration_sec": result["metadata"]["audio_duration_sec"],
"processing_time_sec": result["metadata"]["processing_time_sec"],
"segment_count": len(result["segments"]),
"speaker_count": len(result["speakers"]),
}

end_time = datetime.now().isoformat()
logger.info(f"Transcription job {job_id} completed successfully")

if callback_url:
_send_callback_sync(
job_id,
external_id,
"transcribe",
callback_url,
"completed",
{"result_path": output_path, "metadata": metadata},
)

# Clean up downloaded file
if video_url.startswith(("http://", "https://")):
try:
os.remove(audio_path)
logger.info(f"Cleaned up temporary file: {audio_path}")
except Exception as e:
logger.warning(f"Failed to clean up temporary file: {str(e)}")

return {
"job_id": job_id,
"external_id": external_id,
"video_url": video_url,
"job_type": "transcribe",
"callback_url": callback_url,
"status": "completed",
"result_path": output_path,
"start_time": start_time,
"end_time": end_time,
"metadata": metadata,
}

except Exception as e:
error_msg = str(e)
logger.error(f"Transcription job {job_id} failed: {error_msg}")
logger.error(traceback.format_exc())

if callback_url:
_send_callback_sync(
job_id,
external_id,
"transcribe",
callback_url,
"failed",
error=error_msg,
)

raise
Loading
Loading