diff --git a/config/pipeline_configs/PyannoteOrchestration.yaml b/config/pipeline_configs/PyannoteOrchestration.yaml
new file mode 100644
index 0000000..53b3373
--- /dev/null
+++ b/config/pipeline_configs/PyannoteOrchestration.yaml
@@ -0,0 +1,5 @@
+PyannoteOrchestrationPipeline:
+  config:
+    out_dir: ./pyannote-orchestration
+    timeout: 3600 # 1 hour
+    request_buffer: 30
diff --git a/config/pipeline_configs/PyannoteTranscription.yaml b/config/pipeline_configs/PyannoteTranscription.yaml
new file mode 100644
index 0000000..99e7995
--- /dev/null
+++ b/config/pipeline_configs/PyannoteTranscription.yaml
@@ -0,0 +1,5 @@
+PyannoteTranscriptionPipeline:
+  config:
+    out_dir: ./pyannote-transcription
+    timeout: 3600 # 1 hour
+    request_buffer: 30
diff --git a/src/openbench/engine/__init__.py b/src/openbench/engine/__init__.py
index 2dec8fe..8fe912c 100644
--- a/src/openbench/engine/__init__.py
+++ b/src/openbench/engine/__init__.py
@@ -1,6 +1,14 @@
 from .deepgram_engine import DeepgramApi, DeepgramApiResponse
 from .elevenlabs_engine import ElevenLabsApi, ElevenLabsApiResponse
 from .openai_engine import OpenAIApi
+from .pyannote_engine import (
+    PyannoteAIApi,
+    PyannoteApiDiarizationOutput,
+    PyannoteApiOrchestrationOutput,
+    PyannoteApiSegment,
+    PyannoteApiTurn,
+    PyannoteApiWord,
+)
 from .whisperkitpro_engine import (
     WhisperKitPro,
     WhisperKitProConfig,
@@ -15,6 +23,12 @@
     "ElevenLabsApi",
     "ElevenLabsApiResponse",
     "OpenAIApi",
+    "PyannoteAIApi",
+    "PyannoteApiDiarizationOutput",
+    "PyannoteApiOrchestrationOutput",
+    "PyannoteApiSegment",
+    "PyannoteApiTurn",
+    "PyannoteApiWord",
     "WhisperKitPro",
     "WhisperKitProInput",
     "WhisperKitProOutput",
diff --git a/src/openbench/engine/pyannote_engine.py b/src/openbench/engine/pyannote_engine.py
new file mode 100644
index 0000000..2d3456c
--- /dev/null
+++ b/src/openbench/engine/pyannote_engine.py
@@ -0,0 +1,383 @@
+# For licensing see accompanying LICENSE.md file.
+# Copyright (C) 2026 Argmax, Inc. All Rights Reserved.
+
+"""PyannoteAI API engine for diarization and transcription."""
+
+import os
+import time
+from datetime import datetime
+from pathlib import Path
+
+import requests
+from argmaxtools.utils import get_logger
+from pyannote.core import Segment
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+from ..pipeline_prediction import DiarizationAnnotation
+
+
+__all__ = [
+    "PyannoteAIApi",
+    "PyannoteApiDiarizationOutput",
+    "PyannoteApiOrchestrationOutput",
+    "PyannoteApiSegment",
+    "PyannoteApiWord",
+    "PyannoteApiTurn",
+]
+
+logger = get_logger(__name__)
+
+
+def to_camel(string: str) -> str:
+    """Convert snake_case to camelCase."""
+    components = string.split("_")
+    return components[0] + "".join(x.title() for x in components[1:])
+
+
+# Response models for diarization segments
+class PyannoteApiSegment(BaseModel):
+    """A single diarization segment from PyannoteAI API."""
+
+    speaker: str
+    start: float
+    end: float
+
+
+class PyannoteApiDiarization(BaseModel):
+    """Diarization output containing a list of segments."""
+
+    diarization: list[PyannoteApiSegment]
+
+    def to_pyannote_annotation(self) -> DiarizationAnnotation:
+        """Convert to pyannote DiarizationAnnotation format."""
+        annotation = DiarizationAnnotation()
+        for segment in self.diarization:
+            annotation[Segment(segment.start, segment.end)] = segment.speaker
+        return annotation
+
+
+# Response models for transcription (word-level and turn-level)
+class PyannoteApiWord(BaseModel):
+    """A single word from PyannoteAI transcription output."""
+
+    start: float
+    end: float
+    text: str
+    speaker: str
+
+    model_config = ConfigDict(populate_by_name=True)
+
+
+class PyannoteApiTurn(BaseModel):
+    """A speaker turn from PyannoteAI transcription output."""
+
+    start: float
+    end: float
+    text: str
+    speaker: str
+
+    model_config = ConfigDict(populate_by_name=True)
+
+
+class PyannoteApiOrchestrationData(BaseModel):
+    """Output data containing diarization and transcription results."""
+
+    diarization: list[PyannoteApiSegment]
+    word_level_transcription: list[PyannoteApiWord] = Field(alias="wordLevelTranscription")
+    turn_level_transcription: list[PyannoteApiTurn] = Field(alias="turnLevelTranscription")
+
+    model_config = ConfigDict(populate_by_name=True)
+
+    def to_pyannote_annotation(self) -> DiarizationAnnotation:
+        """Convert diarization to pyannote DiarizationAnnotation format."""
+        annotation = DiarizationAnnotation()
+        for segment in self.diarization:
+            annotation[Segment(segment.start, segment.end)] = segment.speaker
+        return annotation
+
+
+# Base output class with common fields
+class PyannoteApiBaseOutput(BaseModel):
+    """Base output model with common job metadata."""
+
+    job_id: str = Field(
+        description="The id of the job that was submitted to pyannote-ai",
+    )
+    status: str = Field(
+        description="The status of the job that was submitted to pyannote-ai",
+    )
+    created_at: datetime = Field(
+        description="The time the job was created",
+    )
+    updated_at: datetime | None = Field(
+        description="The time the job was updated. For some reason it can be None.",
+    )
+    job_polling_elapsed_time: float = Field(
+        description="The time it took to poll the job results",
+    )
+
+    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
+
+    @model_validator(mode="before")
+    @classmethod
+    def parse_shape(cls, data: dict) -> dict:
+        """Normalise `Z`-suffixed ISO-8601 timestamps before field validation."""
+        if not isinstance(data, dict):
+            # Non-mapping input (e.g. an existing model instance) is left to pydantic
+            return data
+        # Shallow copy so the caller's payload is not mutated
+        data = dict(data)
+        # `datetime.fromisoformat` rejects a trailing `Z` (UTC) before Python 3.11
+        for key in ("createdAt", "updatedAt"):
+            value = data.get(key)
+            if isinstance(value, str):
+                data[key] = datetime.fromisoformat(value.replace("Z", "+00:00"))
+        return data
+
+    def get_elapsed_time(self) -> float:
+        """Get the elapsed time for the job."""
+        if self.updated_at is not None:
+            return (self.updated_at - self.created_at).total_seconds()
+        return self.job_polling_elapsed_time
+
+
+class PyannoteApiDiarizationOutput(PyannoteApiBaseOutput):
+    """Output model for diarization-only jobs."""
+
+    output: PyannoteApiDiarization = Field(
+        description="The diarization output of the job",
+    )
+
+
+class PyannoteApiOrchestrationOutput(PyannoteApiBaseOutput):
+    """Output model for jobs with transcription enabled."""
+
+    output: PyannoteApiOrchestrationData = Field(
+        description="The diarization and transcription output of the job",
+    )
+
+
+class PyannoteAIApi:
+    """
+    PyannoteAI API client for diarization and transcription.
+
+    Expects the environment variable `PYANNOTE_TOKEN` to be set with a valid pyannote-ai token.
+
+    Args:
+        timeout: Timeout for job polling in seconds
+        request_buffer: Buffer for request rate limiting
+        transcription: Whether to enable transcription (STT) in addition to diarization
+    """
+
+    diarization_url = "https://api.pyannote.ai/v1/diarize"
+    media_url = "https://api.pyannote.ai/v1/media/input"
+    jobs_url = "https://api.pyannote.ai/v1/jobs"
+
+    def __init__(
+        self,
+        timeout: int = 1800,
+        request_buffer: int = 30,
+        transcription: bool = False,
+    ) -> None:
+        self.timeout = timeout
+        self.request_buffer = request_buffer
+        self.transcription = transcription
+
+        # Check that the API key is set
+        if not os.getenv("PYANNOTE_TOKEN"):
+            raise ValueError("`PYANNOTE_TOKEN` environment variable is not set")
+
+    def get_presigned_url(self, audio_path: str) -> str:
+        """
+        Get a presigned URL for uploading audio to PyannoteAI temporary storage.
+
+        Args:
+            audio_path: Path to the local audio file
+
+        Returns:
+            The media URL for the uploaded audio
+        """
+        logger.debug(f"Getting presigned url for {audio_path}")
+        name = Path(audio_path).with_suffix(".wav").name
+        # For some reason if the name has underscores, it will fail
+        name = "".join([n.capitalize() for n in name.split("_")])
+        # Pushing audio file to temporary storage from pyannote-ai
+        audio_url = f"media://example/{name}"
+        logger.debug(f"Audio url: {audio_url}")
+        body = {"url": audio_url}
+        # Post request to get the presigned url associated with the `audio_url`
+        response = requests.post(
+            url=self.media_url,
+            headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"},
+            json=body,
+        )
+        response.raise_for_status()
+        data = response.json()
+        presigned_url = data["url"]
+        logger.debug(f"Presigned url: {presigned_url}")
+
+        # Upload the audio file to the presigned url
+        # Audio should be < 24hrs and < 1GB
+        with open(audio_path, "rb") as audio_file:
+            upload_response = requests.put(
+                url=presigned_url,
+                data=audio_file,
+            )
+        # Surface a rejected upload here rather than as a confusing job failure later
+        upload_response.raise_for_status()
+        logger.debug(f"Audio file uploaded to {presigned_url}")
+        return audio_url
+
+    def diarize(
+        self,
+        audio_url: str,
+        num_speakers: int | None = None,
+        transcription: bool | None = None,
+    ) -> requests.Response:
+        """
+        Submit a diarization job to PyannoteAI.
+
+        Args:
+            audio_url: The media URL of the uploaded audio
+            num_speakers: Optional number of speakers hint
+            transcription: Whether to enable transcription (overrides instance setting)
+
+        Returns:
+            The response from the diarization endpoint
+        """
+        data = {"url": audio_url}
+
+        if num_speakers is not None:
+            data["numSpeakers"] = num_speakers
+
+        # Use instance transcription setting if not overridden
+        enable_transcription = transcription if transcription is not None else self.transcription
+        if enable_transcription:
+            data["transcription"] = True
+
+        response = requests.post(
+            self.diarization_url,
+            headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"},
+            json=data,
+        )
+        response.raise_for_status()
+        return response
+
+    def get_job_results(
+        self,
+        diarization_response: requests.Response,
+        transcription: bool | None = None,
+    ) -> PyannoteApiDiarizationOutput | PyannoteApiOrchestrationOutput:
+        """
+        Poll for job results until completion.
+
+        Args:
+            diarization_response: The response from the diarization endpoint
+            transcription: Whether transcription was enabled (determines output type)
+
+        Returns:
+            Either PyannoteApiDiarizationOutput or PyannoteApiOrchestrationOutput
+        """
+        data = diarization_response.json()
+        headers = diarization_response.headers
+        job_id = data["jobId"]
+        logger.debug(f"Starting to poll results for job {job_id}")
+
+        # Get rate limit info from headers with fallback defaults
+        remaining_requests = int(headers.get("X-RateLimit-Remaining", 30))
+        rate_limit = int(headers.get("X-RateLimit-Limit", 30))
+        reset_time = int(headers.get("X-RateLimit-Reset", 0))
+
+        logger.debug(
+            f"Initial rate limits - Remaining: {remaining_requests}, Limit: {rate_limit}, Reset: {reset_time}s"
+        )
+
+        # Use instance transcription setting if not overridden
+        enable_transcription = transcription if transcription is not None else self.transcription
+
+        start_time = time.time()
+        elapsed_time = 0
+        while elapsed_time < self.timeout:
+            try:
+                # Check if we need to wait for rate limit reset
+                if remaining_requests <= self.request_buffer:
+                    logger.debug(
+                        f"Running low on requests ({remaining_requests} remaining). Waiting {reset_time}s for reset"
+                    )
+                    time.sleep(reset_time)
+                    remaining_requests = rate_limit
+
+                logger.debug(f"Polling job {job_id}")
+                response = requests.get(
+                    url=f"{self.jobs_url}/{job_id}",
+                    headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"},
+                )
+                response.raise_for_status()
+
+                # Update rate limit information
+                remaining_requests = int(response.headers.get("X-RateLimit-Remaining", remaining_requests))
+                reset_time = int(response.headers.get("X-RateLimit-Reset", reset_time))
+                # Add a small buffer to avoid hitting rate limits
+                safe_remaining = max(1, remaining_requests - self.request_buffer)
+                delay = reset_time / safe_remaining
+                logger.debug(
+                    f"Rate limit info - Remaining: {remaining_requests}, Reset: {reset_time}s, Delay: {delay * 1000:.0f}ms"
+                )
+
+                job_data = response.json()
+                job_status = job_data["status"]
+                logger.debug(f"Job {job_id} status: {job_status}")
+
+                if job_status == "succeeded":
+                    elapsed_time = time.time() - start_time
+                    logger.debug(f"Job {job_id} completed successfully after {elapsed_time:.1f}s")
+                    job_data["jobPollingElapsedTime"] = elapsed_time
+
+                    # Return appropriate output type based on transcription flag
+                    if enable_transcription:
+                        return PyannoteApiOrchestrationOutput.model_validate(job_data)
+                    else:
+                        return PyannoteApiDiarizationOutput.model_validate(job_data)
+
+                elif job_status == "failed":
+                    error_msg = job_data.get("error", "No error message provided")
+                    logger.error(f"Job {job_id} failed: {error_msg}")
+                    raise Exception(f"Job failed with error: {error_msg}")
+                elif job_status == "canceled":
+                    logger.error(f"Job {job_id} was canceled")
+                    raise Exception("Job was canceled")
+
+                elapsed_time = time.time() - start_time
+                logger.debug(f"Waiting {delay * 1000:.0f}ms before next request")
+                time.sleep(delay)
+
+            except requests.exceptions.RequestException as e:
+                logger.error(f"API request failed for job {job_id}: {str(e)}")
+                raise RuntimeError(f"API request failed: {str(e)}") from e
+
+        logger.error(f"Job {job_id} timed out after {elapsed_time:.1f}s")
+        raise TimeoutError(f"Job timed out after {elapsed_time:.1f} seconds")
+
+    def __call__(
+        self,
+        audio_path: str,
+        num_speakers: int | None = None,
+        transcription: bool | None = None,
+    ) -> PyannoteApiDiarizationOutput | PyannoteApiOrchestrationOutput:
+        """
+        Process an audio file with diarization and optional transcription.
+
+        Args:
+            audio_path: Path to the local audio file
+            num_speakers: Optional number of speakers hint
+            transcription: Whether to enable transcription (overrides instance setting)
+
+        Returns:
+            Either PyannoteApiDiarizationOutput or PyannoteApiOrchestrationOutput
+        """
+        # Use instance transcription setting if not overridden
+        enable_transcription = transcription if transcription is not None else self.transcription
+
+        audio_url = self.get_presigned_url(audio_path)
+        diarization_response = self.diarize(audio_url, num_speakers, enable_transcription)
+        return self.get_job_results(diarization_response, enable_transcription)
diff --git a/src/openbench/pipeline/diarization/pyannote_api.py b/src/openbench/pipeline/diarization/pyannote_api.py
index 0d7c24f..c9fd830 100644
--- a/src/openbench/pipeline/diarization/pyannote_api.py
+++ b/src/openbench/pipeline/diarization/pyannote_api.py
@@ -1,19 +1,14 @@
 # For licensing see accompanying LICENSE.md file.
 # Copyright (C) 2025 Argmax, Inc. All Rights Reserved.
 
-import os
-import time
-from datetime import datetime
 from pathlib import Path
 from typing import Callable
 
-import requests
 from argmaxtools.utils import get_logger
-from pyannote.core import Segment
-from pydantic import BaseModel, Field, model_validator
+from pydantic import Field
 
 from ...dataset import DiarizationSample
-from ...pipeline_prediction import DiarizationAnnotation
+from ...engine import PyannoteAIApi, PyannoteApiDiarizationOutput
 from ..base import Pipeline, PipelineType, register_pipeline
 from .common import DiarizationOutput, DiarizationPipelineConfig
 
@@ -23,208 +18,6 @@
 logger = get_logger(__name__)
 
 
-class PyannoteApiSegment(BaseModel):
-    speaker: str
-    start: float
-    end: float
-
-
-class PyannoteApiDiarization(BaseModel):
-    diarization: list[PyannoteApiSegment]
-
-    def to_pyannote_annotation(self) -> DiarizationAnnotation:
-        annotation = DiarizationAnnotation()
-        for segment in self.diarization:
-            annotation[Segment(segment.start, segment.end)] = segment.speaker
-        return annotation
-
-
-def to_camel(string: str) -> str:
-    components = string.split("_")
-    return components[0] + "".join(x.title() for x in components[1:])
-
-
-class PyannoteApiOutput(BaseModel):
-    job_id: str = Field(
-        description="The id of the job that was submitted to pyannote-ai",
-    )
-    status: str = Field(
-        description="The status of the job that was submitted to pyannote-ai",
-    )
-    created_at: datetime = Field(
-        description="The time the job was created",
-    )
-    updated_at: datetime | None = Field(
-        description="The time the job was updated. For some reason it can be None.",
-    )
-    job_polling_elapsed_time: float = Field(
-        description="The time it took to poll the job results",
-    )
-    output: PyannoteApiDiarization = Field(
-        description="The output of the job",
-    )
-
-    class Config:
-        alias_generator = to_camel
-        validate_by_name = True
-
-    @model_validator(mode="before")
-    @classmethod
-    def parse_shape(cls, data: dict) -> dict:
-        if isinstance(data["createdAt"], str):
-            data["createdAt"] = datetime.fromisoformat(data["createdAt"])
-        if isinstance(data["updatedAt"], str) and data["updatedAt"] is not None:
-            data["updatedAt"] = datetime.fromisoformat(data["updatedAt"])
-        return data
-
-    def get_elapsed_time(self) -> float:
-        if self.updated_at is not None:
-            return (self.updated_at - self.created_at).total_seconds()
-        return self.job_polling_elapsed_time
-
-
-# Expects a env variable `PYANNOTE_TOKEN` to be set with a valid pyannote-ai token
-class PyannoteApi:
-    diarization_url = "https://api.pyannote.ai/v1/diarize"
-    media_url = "https://api.pyannote.ai/v1/media/input"
-    jobs_url = "https://api.pyannote.ai/v1/jobs"
-
-    def __init__(
-        self,
-        timeout: int = 1800,
-        request_buffer: int = 30,
-    ) -> None:
-        self.timeout = timeout
-        self.request_buffer = request_buffer
-
-    def get_presigned_url(self, audio_path: str) -> str:
-        # We need to push the audio file to temporary storage from pyannote-ai
-        # we could also push it to S3 or other storage, but this is the easiest way
-        # to avoid setting up a storage bucket
-        logger.debug(f"Getting presigned url for {audio_path}")
-        name = Path(audio_path).with_suffix(".wav").name
-        # For some reason if the name has underscores, it will fail
-        name = "".join([n.capitalize() for n in name.split("_")])
-        # Pushing audio file to temporary storage from pyannote-ai
-        audio_url = f"media://example/{name}"
-        logger.debug(f"Audio url: {audio_url}")
-        body = {"url": audio_url}
-        # Post request to get the presigned url associated with the `audio_url`
-        response = requests.post(
-            url=self.media_url,
-            headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"},
-            json=body,
-        )
-        response.raise_for_status()
-        data = response.json()
-        presigned_url = data["url"]
-        logger.debug(f"Presigned url: {presigned_url}")
-
-        # Upload the audio file to the presigned url
-        # Audio should be < 24hrs and < 1GB
-        with open(audio_path, "rb") as audio_file:
-            requests.put(
-                url=presigned_url,
-                data=audio_file,
-            )
-        logger.debug(f"Audio file uploaded to {presigned_url}")
-        return audio_url
-
-    def diarize(self, audio_url: str, num_speakers: int | None = None) -> str:
-        # We could also pass a Webhook to get the result, but we can poll the job status
-        # and get the response. This is easier to implement although polling can hit
-        # rate limits.
-        data = {"url": audio_url}
-
-        if num_speakers is not None:
-            data["numSpeakers"] = num_speakers
-
-        response = requests.post(
-            self.diarization_url,
-            headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"},
-            json=data,
-        )
-        response.raise_for_status()
-        return response
-
-    def get_job_results(self, diarization_response: requests.Response) -> PyannoteApiOutput:
-        data = diarization_response.json()
-        headers = diarization_response.headers
-        job_id = data["jobId"]
-        logger.debug(f"Starting to poll results for job {job_id}")
-
-        # Get rate limit info from headers with fallback defaults
-        remaining_requests = int(headers.get("X-RateLimit-Remaining", 30))
-        rate_limit = int(headers.get("X-RateLimit-Limit", 30))
-        reset_time = int(headers.get("X-RateLimit-Reset", 0))
-
-        logger.debug(
-            f"Initial rate limits - Remaining: {remaining_requests}, Limit: {rate_limit}, Reset: {reset_time}s"
-        )
-
-        start_time = time.time()
-        elapsed_time = 0
-        while elapsed_time < self.timeout:
-            try:
-                # Check if we need to wait for rate limit reset
-                if remaining_requests <= self.request_buffer:
-                    logger.debug(
-                        f"Running low on requests ({remaining_requests} remaining). Waiting {reset_time}s for reset"
-                    )
-                    time.sleep(reset_time)
-                    remaining_requests = rate_limit
-
-                logger.debug(f"Polling job {job_id}")
-                response = requests.get(
-                    url=f"{self.jobs_url}/{job_id}",
-                    headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"},
-                )
-                response.raise_for_status()
-
-                # Update rate limit information
-                remaining_requests = int(response.headers.get("X-RateLimit-Remaining", remaining_requests))
-                reset_time = int(response.headers.get("X-RateLimit-Reset", reset_time))
-                # Add a small buffer to avoid hitting rate limits
-                safe_remaining = max(1, remaining_requests - self.request_buffer)
-                delay = reset_time / safe_remaining
-                logger.debug(
-                    f"Rate limit info - Remaining: {remaining_requests}, Reset: {reset_time}s, Delay: {delay * 1000:.0f}ms"
-                )
-
-                job_data = response.json()
-                job_status = job_data["status"]
-                logger.debug(f"Job {job_id} status: {job_status}")
-
-                if job_status == "succeeded":
-                    elapsed_time = time.time() - start_time
-                    logger.debug(f"Job {job_id} completed successfully after {elapsed_time:.1f}s")
-                    job_data["jobPollingElapsedTime"] = elapsed_time
-                    return PyannoteApiOutput.model_validate(job_data)
-                elif job_status == "failed":
-                    error_msg = job_data.get("error", "No error message provided")
-                    logger.error(f"Job {job_id} failed: {error_msg}")
-                    raise Exception(f"Job failed with error: {error_msg}")
-                elif job_status == "canceled":
-                    logger.error(f"Job {job_id} was canceled")
-                    raise Exception("Job was canceled")
-
-                elapsed_time = time.time() - start_time
-                logger.debug(f"Waiting {delay * 1000:.0f}ms before next request")
-                time.sleep(delay)
-
-            except requests.exceptions.RequestException as e:
-                logger.error(f"API request failed for job {job_id}: {str(e)}")
-                raise RuntimeError(f"API request failed: {str(e)}")
-
-        logger.error(f"Job {job_id} timed out after {elapsed_time:.1f}s")
-        raise TimeoutError(f"Job timed out after {elapsed_time:.1f} seconds")
-
-    def __call__(self, audio_path: str, num_speakers: int | None = None) -> PyannoteApiOutput:
-        audio_url = self.get_presigned_url(audio_path)
-        diarization_response = self.diarize(audio_url, num_speakers)
-        return self.get_job_results(diarization_response)
-
-
 class PyannoteApiConfig(DiarizationPipelineConfig):
     timeout: int = Field(
         default=1800,
@@ -246,10 +39,11 @@ class PyannoteApiPipeline(Pipeline):
 
     def build_pipeline(
         self,
-    ) -> Callable[[dict[str, str | int | None]], PyannoteApiOutput]:
-        api = PyannoteApi(
+    ) -> Callable[[dict[str, str | int | None]], PyannoteApiDiarizationOutput]:
+        api = PyannoteAIApi(
             timeout=self.config.timeout,
             request_buffer=self.config.request_buffer,
+            transcription=False,
         )
         return lambda input_sample: api(
             audio_path=input_sample["audio_path"],
@@ -262,10 +56,10 @@ def parse_input(self, input_sample: DiarizationSample) -> dict[str, str | int |
         self._audio_path = audio_path
         return {"audio_path": str(audio_path)}
 
-    def parse_output(self, output: PyannoteApiOutput) -> DiarizationOutput:
-        output = DiarizationOutput(
+    def parse_output(self, output: PyannoteApiDiarizationOutput) -> DiarizationOutput:
+        result = DiarizationOutput(
             prediction=output.output.to_pyannote_annotation(),
         )
         # remove audio from temp
         self._audio_path.unlink()
-        return output
+        return result
diff --git a/src/openbench/pipeline/orchestration/__init__.py b/src/openbench/pipeline/orchestration/__init__.py
index b7fee67..0e7870f 100644
--- a/src/openbench/pipeline/orchestration/__init__.py
+++ b/src/openbench/pipeline/orchestration/__init__.py
@@ -5,6 +5,7 @@
 from .orchestration_deepgram import DeepgramOrchestrationPipeline, DeepgramOrchestrationPipelineConfig
 from .orchestration_elevenlabs import ElevenLabsOrchestrationPipeline, ElevenLabsOrchestrationPipelineConfig
 from .orchestration_openai import OpenAIOrchestrationPipeline, OpenAIOrchestrationPipelineConfig
+from .orchestration_pyannote import PyannoteOrchestrationPipeline, PyannoteOrchestrationPipelineConfig
 from .orchestration_whisperkitpro import WhisperKitProOrchestrationConfig, WhisperKitProOrchestrationPipeline
 from .whisperx import WhisperXPipeline, WhisperXPipelineConfig
 
@@ -22,4 +23,6 @@
     "OpenAIOrchestrationPipelineConfig",
     "NeMoMTParakeetPipeline",
     "NeMoMTParakeetPipelineConfig",
+    "PyannoteOrchestrationPipeline",
+    "PyannoteOrchestrationPipelineConfig",
 ]
diff --git a/src/openbench/pipeline/orchestration/orchestration_pyannote.py b/src/openbench/pipeline/orchestration/orchestration_pyannote.py
new file mode 100644
index 0000000..b9b3c14
--- /dev/null
+++ b/src/openbench/pipeline/orchestration/orchestration_pyannote.py
@@ -0,0 +1,90 @@
+# For licensing see accompanying LICENSE.md file.
+# Copyright (C) 2026 Argmax, Inc. All Rights Reserved.
+
+"""PyannoteAI orchestration pipeline (diarization + transcription with speaker attribution)."""
+
+from pathlib import Path
+from typing import Callable
+
+from argmaxtools.utils import get_logger
+from pydantic import Field
+
+from ...dataset import OrchestrationSample
+from ...engine import PyannoteAIApi, PyannoteApiOrchestrationOutput
+from ...pipeline_prediction import Transcript
+from ..base import Pipeline, PipelineType, register_pipeline
+from .common import OrchestrationConfig, OrchestrationOutput
+
+
+__all__ = ["PyannoteOrchestrationPipeline", "PyannoteOrchestrationPipelineConfig"]
+
+logger = get_logger(__name__)
+
+TEMP_AUDIO_DIR = Path("audio_temp")
+
+
+class PyannoteOrchestrationPipelineConfig(OrchestrationConfig):
+    """Configuration for PyannoteAI orchestration pipeline."""
+
+    timeout: int = Field(
+        default=1800,
+        description="Timeout for the orchestration job in seconds",
+    )
+    request_buffer: int = Field(
+        default=30,
+        description="Buffer for the request rate limit",
+    )
+
+
+@register_pipeline
+class PyannoteOrchestrationPipeline(Pipeline):
+    """
+    PyannoteAI orchestration pipeline.
+
+    Uses the PyannoteAI API with transcription enabled to get both diarization
+    and speaker-attributed transcription results.
+    """
+
+    _config_class = PyannoteOrchestrationPipelineConfig
+    pipeline_type = PipelineType.ORCHESTRATION
+
+    def build_pipeline(
+        self,
+    ) -> Callable[[Path], PyannoteApiOrchestrationOutput]:
+        api = PyannoteAIApi(
+            timeout=self.config.timeout,
+            request_buffer=self.config.request_buffer,
+            transcription=True,
+        )
+
+        def orchestrate(audio_path: Path) -> PyannoteApiOrchestrationOutput:
+            try:
+                return api(audio_path=str(audio_path))
+            finally:  # always remove the temporary audio file, even when the API call fails
+                audio_path.unlink(missing_ok=True)
+
+        return orchestrate
+
+    def parse_input(self, input_sample: OrchestrationSample) -> Path:
+        """Save audio to temporary directory for processing."""
+        return input_sample.save_audio(TEMP_AUDIO_DIR)
+
+    def parse_output(self, output: PyannoteApiOrchestrationOutput) -> OrchestrationOutput:
+        """
+        Parse the PyannoteAI response into a Transcript with speaker attribution.
+
+        Uses word-level transcription to get precise word timings with speaker labels.
+        """
+        # Extract words from word-level transcription with speaker attribution
+        transcript = Transcript.from_words_info(
+            words=[word.text for word in output.output.word_level_transcription],
+            start=[word.start for word in output.output.word_level_transcription],
+            end=[word.end for word in output.output.word_level_transcription],
+            speaker=[word.speaker for word in output.output.word_level_transcription],
+        )
+
+        return OrchestrationOutput(
+            prediction=transcript,
+            diarization_output=None,
+            transcription_output=None,
+        )
diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py
index d653f6f..97287b0 100644
--- a/src/openbench/pipeline/pipeline_aliases.py
+++ b/src/openbench/pipeline/pipeline_aliases.py
@@ -20,6 +20,7 @@
     ElevenLabsOrchestrationPipeline,
     NeMoMTParakeetPipeline,
     OpenAIOrchestrationPipeline,
+    PyannoteOrchestrationPipeline,
     WhisperKitProOrchestrationPipeline,
     WhisperXPipeline,
 )
@@ -38,6 +39,7 @@
     GroqTranscriptionPipeline,
     NeMoTranscriptionPipeline,
     OpenAITranscriptionPipeline,
+    PyannoteTranscriptionPipeline,
     SpeechAnalyzerPipeline,
     WhisperKitProTranscriptionPipeline,
     WhisperKitTranscriptionPipeline,
@@ -326,6 +328,17 @@ def register_pipeline_aliases() -> None:
         description="NeMo Multi-Talker Parakeet orchestration pipeline (diarization + transcription).",
     )
 
+    PipelineRegistry.register_alias(
+        "pyannote-orchestration",
+        PyannoteOrchestrationPipeline,
+        default_config={
+            "out_dir": "./pyannote_orchestration_results",
+            "timeout": 3600,
+            "request_buffer": 30,
+        },
+        description="PyannoteAI orchestration pipeline (diarization + transcription). Uses the precision-2 model with Nvidia Parakeet STT. Requires `PYANNOTE_TOKEN` env var from https://www.pyannote.ai/.",
+    )
+
     ################# TRANSCRIPTION PIPELINES #################
 
     PipelineRegistry.register_alias(
@@ -618,6 +631,17 @@ def register_pipeline_aliases() -> None:
         description="ElevenLabs transcription pipeline with keyterm prompting support. Requires API key from https://elevenlabs.io/. Set `ELEVENLABS_API_KEY` env var.",
     )
 
+    PipelineRegistry.register_alias(
+        "pyannote-transcription",
+        PyannoteTranscriptionPipeline,
+        default_config={
+            "out_dir": "./pyannote_transcription_results",
+            "timeout": 3600,
+            "request_buffer": 30,
+        },
+        description="PyannoteAI transcription pipeline (ignores speaker attribution). Uses the precision-2 model with Nvidia Parakeet STT. Requires `PYANNOTE_TOKEN` env var from https://www.pyannote.ai/.",
+    )
+
     ################# STREAMING TRANSCRIPTION PIPELINES #################
 
     PipelineRegistry.register_alias(
diff --git a/src/openbench/pipeline/transcription/__init__.py b/src/openbench/pipeline/transcription/__init__.py
index 3fe5900..f465293 100644
--- a/src/openbench/pipeline/transcription/__init__.py
+++ b/src/openbench/pipeline/transcription/__init__.py
@@ -10,6 +10,7 @@
 from .transcription_nemo import NeMoTranscriptionPipeline, NeMoTranscriptionPipelineConfig
 from .transcription_openai import OpenAITranscriptionPipeline, OpenAITranscriptionPipelineConfig
 from .transcription_oss_whisper import WhisperOSSTranscriptionPipeline, WhisperOSSTranscriptionPipelineConfig
+from .transcription_pyannote import PyannoteTranscriptionPipeline, PyannoteTranscriptionPipelineConfig
 from .transcription_whisperkitpro import WhisperKitProTranscriptionConfig, WhisperKitProTranscriptionPipeline
 from .whisperkit import WhisperKitTranscriptionConfig, WhisperKitTranscriptionPipeline
 
@@ -34,4 +35,6 @@
     "ElevenLabsTranscriptionPipelineConfig",
     "NeMoTranscriptionPipeline",
     "NeMoTranscriptionPipelineConfig",
+    "PyannoteTranscriptionPipeline",
+    "PyannoteTranscriptionPipelineConfig",
 ]
diff --git a/src/openbench/pipeline/transcription/transcription_pyannote.py b/src/openbench/pipeline/transcription/transcription_pyannote.py
new file mode 100644
index 0000000..ab956e5
--- /dev/null
+++ b/src/openbench/pipeline/transcription/transcription_pyannote.py
@@ -0,0 +1,88 @@
+# For licensing see accompanying LICENSE.md file.
+# Copyright (C) 2026 Argmax, Inc. All Rights Reserved.
+
+"""PyannoteAI transcription pipeline (ignores speaker attribution)."""
+
+from pathlib import Path
+from typing import Callable
+
+from argmaxtools.utils import get_logger
+from pydantic import Field
+
+from ...dataset import TranscriptionSample
+from ...engine import PyannoteAIApi, PyannoteApiOrchestrationOutput
+from ...pipeline_prediction import Transcript
+from ..base import Pipeline, PipelineType, register_pipeline
+from .common import TranscriptionConfig, TranscriptionOutput
+
+
+__all__ = ["PyannoteTranscriptionPipeline", "PyannoteTranscriptionPipelineConfig"]
+
+logger = get_logger(__name__)
+
+TEMP_AUDIO_DIR = Path("audio_temp")
+
+
+class PyannoteTranscriptionPipelineConfig(TranscriptionConfig):
+    """Configuration for PyannoteAI transcription pipeline."""
+
+    timeout: int = Field(
+        default=1800,
+        description="Timeout for the transcription job in seconds",
+    )
+    request_buffer: int = Field(
+        default=30,
+        description="Buffer for the request rate limit",
+    )
+
+
+@register_pipeline
+class PyannoteTranscriptionPipeline(Pipeline):
+    """
+    PyannoteAI transcription pipeline.
+
+    Uses the PyannoteAI API with transcription enabled, but ignores speaker
+    attribution in the output. This is useful for datasets that don't have
+    speaker labels.
+    """
+
+    _config_class = PyannoteTranscriptionPipelineConfig
+    pipeline_type = PipelineType.TRANSCRIPTION
+
+    def build_pipeline(
+        self,
+    ) -> Callable[[Path], PyannoteApiOrchestrationOutput]:
+        api = PyannoteAIApi(
+            timeout=self.config.timeout,
+            request_buffer=self.config.request_buffer,
+            transcription=True,
+        )
+
+        def transcribe(audio_path: Path) -> PyannoteApiOrchestrationOutput:
+            try:
+                return api(audio_path=str(audio_path))
+            finally:  # always remove the temporary audio file, even when the API call fails
+                audio_path.unlink(missing_ok=True)
+
+        return transcribe
+
+    def parse_input(self, input_sample: TranscriptionSample) -> Path:
+        """Save audio to temporary directory for processing."""
+        return input_sample.save_audio(TEMP_AUDIO_DIR)
+
+    def parse_output(self, output: PyannoteApiOrchestrationOutput) -> TranscriptionOutput:
+        """
+        Parse the PyannoteAI response into a Transcript without speaker attribution.
+
+        The word-level transcription contains speaker information, but we ignore it
+        here since this is a transcription-only pipeline.
+        """
+        # Extract words from word-level transcription, ignoring speaker info
+        transcript = Transcript.from_words_info(
+            words=[word.text for word in output.output.word_level_transcription],
+            start=[word.start for word in output.output.word_level_transcription],
+            end=[word.end for word in output.output.word_level_transcription],
+            speaker=None,  # Ignore speaker attribution for transcription pipeline
+        )
+
+        return TranscriptionOutput(prediction=transcript)