From 6625b62f7dab4efb5a8beb336bc507236591dfb8 Mon Sep 17 00:00:00 2001 From: Eduardo Pacheco Date: Tue, 13 Jan 2026 15:14:39 -0300 Subject: [PATCH 1/2] add: pyannote orchestration --- .../PyannoteOrchestration.yaml | 5 + .../PyannoteTranscription.yaml | 5 + src/openbench/engine/__init__.py | 14 + src/openbench/engine/pyannote_engine.py | 382 ++++++++++++++++++ .../pipeline/diarization/pyannote_api.py | 222 +--------- .../pipeline/orchestration/__init__.py | 3 + .../orchestration/orchestration_pyannote.py | 90 +++++ src/openbench/pipeline/pipeline_aliases.py | 24 ++ .../pipeline/transcription/__init__.py | 3 + .../transcription/transcription_pyannote.py | 88 ++++ 10 files changed, 622 insertions(+), 214 deletions(-) create mode 100644 config/pipeline_configs/PyannoteOrchestration.yaml create mode 100644 config/pipeline_configs/PyannoteTranscription.yaml create mode 100644 src/openbench/engine/pyannote_engine.py create mode 100644 src/openbench/pipeline/orchestration/orchestration_pyannote.py create mode 100644 src/openbench/pipeline/transcription/transcription_pyannote.py diff --git a/config/pipeline_configs/PyannoteOrchestration.yaml b/config/pipeline_configs/PyannoteOrchestration.yaml new file mode 100644 index 0000000..53b3373 --- /dev/null +++ b/config/pipeline_configs/PyannoteOrchestration.yaml @@ -0,0 +1,5 @@ +PyannoteOrchestrationPipeline: + config: + out_dir: ./pyannote-orchestration + timeout: 3600 # 1 hour + request_buffer: 30 diff --git a/config/pipeline_configs/PyannoteTranscription.yaml b/config/pipeline_configs/PyannoteTranscription.yaml new file mode 100644 index 0000000..99e7995 --- /dev/null +++ b/config/pipeline_configs/PyannoteTranscription.yaml @@ -0,0 +1,5 @@ +PyannoteTranscriptionPipeline: + config: + out_dir: ./pyannote-transcription + timeout: 3600 # 1 hour + request_buffer: 30 diff --git a/src/openbench/engine/__init__.py b/src/openbench/engine/__init__.py index 9f30aa1..9739e04 100644 --- a/src/openbench/engine/__init__.py +++ b/src/openbench/engine/__init__.py @@ -1,5 +1,13 @@ from .deepgram_engine import DeepgramApi, DeepgramApiResponse from .openai_engine import OpenAIApi +from .pyannote_engine import ( + PyannoteAIApi, + PyannoteApiDiarizationOutput, + PyannoteApiOrchestrationOutput, + PyannoteApiSegment, + PyannoteApiTurn, + PyannoteApiWord, +) from .whisperkitpro_engine import ( WhisperKitPro, WhisperKitProConfig, @@ -12,6 +20,12 @@ "DeepgramApi", "DeepgramApiResponse", "OpenAIApi", + "PyannoteAIApi", + "PyannoteApiDiarizationOutput", + "PyannoteApiOrchestrationOutput", + "PyannoteApiSegment", + "PyannoteApiTurn", + "PyannoteApiWord", "WhisperKitPro", "WhisperKitProInput", "WhisperKitProOutput", diff --git a/src/openbench/engine/pyannote_engine.py b/src/openbench/engine/pyannote_engine.py new file mode 100644 index 0000000..b73a92f --- /dev/null +++ b/src/openbench/engine/pyannote_engine.py @@ -0,0 +1,382 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +"""PyannoteAI API engine for diarization and transcription.""" + +import os +import time +from datetime import datetime +from pathlib import Path + +import requests +from argmaxtools.utils import get_logger +from pyannote.core import Segment +from pydantic import BaseModel, Field, model_validator + +from ..pipeline_prediction import DiarizationAnnotation + + +__all__ = [ + "PyannoteAIApi", + "PyannoteApiDiarizationOutput", + "PyannoteApiOrchestrationOutput", + "PyannoteApiSegment", + "PyannoteApiWord", + "PyannoteApiTurn", +] + +logger = get_logger(__name__) + + +def to_camel(string: str) -> str: + """Convert snake_case to camelCase.""" + components = string.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +# Response models for diarization segments +class PyannoteApiSegment(BaseModel): + """A single diarization segment from PyannoteAI API.""" + + speaker: str + start: float + end: float + + +class PyannoteApiDiarization(BaseModel): + """Diarization output containing a list of segments.""" + + diarization: list[PyannoteApiSegment] + + def to_pyannote_annotation(self) -> DiarizationAnnotation: + """Convert to pyannote DiarizationAnnotation format.""" + annotation = DiarizationAnnotation() + for segment in self.diarization: + annotation[Segment(segment.start, segment.end)] = segment.speaker + return annotation + + +# Response models for transcription (word-level and turn-level) +class PyannoteApiWord(BaseModel): + """A single word from PyannoteAI transcription output.""" + + start: float + end: float + text: str + speaker: str + + class Config: + populate_by_name = True + + +class PyannoteApiTurn(BaseModel): + """A speaker turn from PyannoteAI transcription output.""" + + start: float + end: float + text: str + speaker: str + + class Config: + populate_by_name = True + + +class PyannoteApiOrchestrationData(BaseModel): + """Output data containing diarization and transcription results.""" + + diarization: list[PyannoteApiSegment] + word_level_transcription: list[PyannoteApiWord] = Field(alias="wordLevelTranscription") + turn_level_transcription: list[PyannoteApiTurn] = Field(alias="turnLevelTranscription") + + class Config: + populate_by_name = True + + def to_pyannote_annotation(self) -> DiarizationAnnotation: + """Convert diarization to pyannote DiarizationAnnotation format.""" + annotation = DiarizationAnnotation() + for segment in self.diarization: + annotation[Segment(segment.start, segment.end)] = segment.speaker + return annotation + + +# Base output class with common fields +class PyannoteApiBaseOutput(BaseModel): + """Base output model with common job metadata.""" + + job_id: str = Field( + description="The id of the job that was submitted to pyannote-ai", + ) + status: str = Field( + description="The status of the job that was submitted to pyannote-ai", + ) + created_at: datetime = Field( + description="The time the job was created", + ) + updated_at: datetime | None = Field( + description="The time the job was updated. For some reason it can be None.", + ) + job_polling_elapsed_time: float = Field( + description="The time it took to poll the job results", + ) + + class Config: + alias_generator = to_camel + populate_by_name = True + + @model_validator(mode="before") + @classmethod + def parse_shape(cls, data: dict) -> dict: + # Handle 'Z' suffix (Zulu/UTC timezone) which fromisoformat doesn't support directly + if isinstance(data.get("createdAt"), str): + created_at = data["createdAt"].replace("Z", "+00:00") + data["createdAt"] = datetime.fromisoformat(created_at) + if isinstance(data.get("updatedAt"), str) and data["updatedAt"] is not None: + updated_at = data["updatedAt"].replace("Z", "+00:00") + data["updatedAt"] = datetime.fromisoformat(updated_at) + return data + + def get_elapsed_time(self) -> float: + """Get the elapsed time for the job.""" + if self.updated_at is not None: + return (self.updated_at - self.created_at).total_seconds() + return self.job_polling_elapsed_time + + +class PyannoteApiDiarizationOutput(PyannoteApiBaseOutput): + """Output model for diarization-only jobs.""" + + output: PyannoteApiDiarization = Field( + description="The diarization output of the job", + ) + + +class PyannoteApiOrchestrationOutput(PyannoteApiBaseOutput): + """Output model for jobs with transcription enabled.""" + + output: PyannoteApiOrchestrationData = Field( + description="The diarization and transcription output of the job", + ) + + +class PyannoteAIApi: + """ + PyannoteAI API client for diarization and transcription. + + Expects the environment variable `PYANNOTE_TOKEN` to be set with a valid pyannote-ai token. + + Args: + timeout: Timeout for job polling in seconds + request_buffer: Buffer for request rate limiting + transcription: Whether to enable transcription (STT) in addition to diarization + """ + + diarization_url = "https://api.pyannote.ai/v1/diarize" + media_url = "https://api.pyannote.ai/v1/media/input" + jobs_url = "https://api.pyannote.ai/v1/jobs" + + def __init__( + self, + timeout: int = 1800, + request_buffer: int = 30, + transcription: bool = False, + ) -> None: + self.timeout = timeout + self.request_buffer = request_buffer + self.transcription = transcription + + # Check that the API key is set + if not os.getenv("PYANNOTE_TOKEN"): + raise ValueError("`PYANNOTE_TOKEN` environment variable is not set") + + def get_presigned_url(self, audio_path: str) -> str: + """ + Get a presigned URL for uploading audio to PyannoteAI temporary storage. + + Args: + audio_path: Path to the local audio file + + Returns: + The media URL for the uploaded audio + """ + logger.debug(f"Getting presigned url for {audio_path}") + name = Path(audio_path).with_suffix(".wav").name + # For some reason if the name has underscores, it will fail + name = "".join([n.capitalize() for n in name.split("_")]) + # Pushing audio file to temporary storage from pyannote-ai + audio_url = f"media://example/{name}" + logger.debug(f"Audio url: {audio_url}") + body = {"url": audio_url} + # Post request to get the presigned url associated with the `audio_url` + response = requests.post( + url=self.media_url, + headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"}, + json=body, + ) + response.raise_for_status() + data = response.json() + presigned_url = data["url"] + logger.debug(f"Presigned url: {presigned_url}") + + # Upload the audio file to the presigned url + # Audio should be < 24hrs and < 1GB + with open(audio_path, "rb") as audio_file: + requests.put( + url=presigned_url, + data=audio_file, + ) + logger.debug(f"Audio file uploaded to {presigned_url}") + return audio_url + + def diarize( + self, + audio_url: str, + num_speakers: int | None = None, + transcription: bool | None = None, + ) -> requests.Response: + """ + Submit a diarization job to PyannoteAI. + + Args: + audio_url: The media URL of the uploaded audio + num_speakers: Optional number of speakers hint + transcription: Whether to enable transcription (overrides instance setting) + + Returns: + The response from the diarization endpoint + """ + data = {"url": audio_url} + + if num_speakers is not None: + data["numSpeakers"] = num_speakers + + # Use instance transcription setting if not overridden + enable_transcription = transcription if transcription is not None else self.transcription + if enable_transcription: + data["transcription"] = True + + response = requests.post( + self.diarization_url, + headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"}, + json=data, + ) + response.raise_for_status() + return response + + def get_job_results( + self, + diarization_response: requests.Response, + transcription: bool | None = None, + ) -> PyannoteApiDiarizationOutput | PyannoteApiOrchestrationOutput: + """ + Poll for job results until completion. + + Args: + diarization_response: The response from the diarization endpoint + transcription: Whether transcription was enabled (determines output type) + + Returns: + Either PyannoteApiDiarizationOutput or PyannoteApiOrchestrationOutput + """ + data = diarization_response.json() + headers = diarization_response.headers + job_id = data["jobId"] + logger.debug(f"Starting to poll results for job {job_id}") + + # Get rate limit info from headers with fallback defaults + remaining_requests = int(headers.get("X-RateLimit-Remaining", 30)) + rate_limit = int(headers.get("X-RateLimit-Limit", 30)) + reset_time = int(headers.get("X-RateLimit-Reset", 0)) + + logger.debug( + f"Initial rate limits - Remaining: {remaining_requests}, Limit: {rate_limit}, Reset: {reset_time}s" + ) + + # Use instance transcription setting if not overridden + enable_transcription = transcription if transcription is not None else self.transcription + + start_time = time.time() + elapsed_time = 0 + while elapsed_time < self.timeout: + try: + # Check if we need to wait for rate limit reset + if remaining_requests <= self.request_buffer: + logger.debug( + f"Running low on requests ({remaining_requests} remaining). Waiting {reset_time}s for reset" + ) + time.sleep(reset_time) + remaining_requests = rate_limit + + logger.debug(f"Polling job {job_id}") + response = requests.get( + url=f"{self.jobs_url}/{job_id}", + headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"}, + ) + response.raise_for_status() + + # Update rate limit information + remaining_requests = int(response.headers.get("X-RateLimit-Remaining", remaining_requests)) + reset_time = int(response.headers.get("X-RateLimit-Reset", reset_time)) + # Add a small buffer to avoid hitting rate limits + safe_remaining = max(1, remaining_requests - self.request_buffer) + delay = reset_time / safe_remaining + logger.debug( + f"Rate limit info - Remaining: {remaining_requests}, Reset: {reset_time}s, Delay: {delay * 1000:.0f}ms" + ) + + job_data = response.json() + job_status = job_data["status"] + logger.debug(f"Job {job_id} status: {job_status}") + + if job_status == "succeeded": + elapsed_time = time.time() - start_time + logger.debug(f"Job {job_id} completed successfully after {elapsed_time:.1f}s") + job_data["jobPollingElapsedTime"] = elapsed_time + + # Return appropriate output type based on transcription flag + if enable_transcription: + return PyannoteApiOrchestrationOutput.model_validate(job_data) + else: + return PyannoteApiDiarizationOutput.model_validate(job_data) + + elif job_status == "failed": + error_msg = job_data.get("error", "No error message provided") + logger.error(f"Job {job_id} failed: {error_msg}") + raise Exception(f"Job failed with error: {error_msg}") + elif job_status == "canceled": + logger.error(f"Job {job_id} was canceled") + raise Exception("Job was canceled") + + elapsed_time = time.time() - start_time + logger.debug(f"Waiting {delay * 1000:.0f}ms before next request") + time.sleep(delay) + + except requests.exceptions.RequestException as e: + logger.error(f"API request failed for job {job_id}: {str(e)}") + raise RuntimeError(f"API request failed: {str(e)}") + + logger.error(f"Job {job_id} timed out after {elapsed_time:.1f}s") + raise TimeoutError(f"Job timed out after {elapsed_time:.1f} seconds") + + def __call__( + self, + audio_path: str, + num_speakers: int | None = None, + transcription: bool | None = None, + ) -> PyannoteApiDiarizationOutput | PyannoteApiOrchestrationOutput: + """ + Process an audio file with diarization and optional transcription. + + Args: + audio_path: Path to the local audio file + num_speakers: Optional number of speakers hint + transcription: Whether to enable transcription (overrides instance setting) + + Returns: + Either PyannoteApiDiarizationOutput or PyannoteApiOrchestrationOutput + """ + # Use instance transcription setting if not overridden + enable_transcription = transcription if transcription is not None else self.transcription + + audio_url = self.get_presigned_url(audio_path) + diarization_response = self.diarize(audio_url, num_speakers, enable_transcription) + return self.get_job_results(diarization_response, enable_transcription) diff --git a/src/openbench/pipeline/diarization/pyannote_api.py b/src/openbench/pipeline/diarization/pyannote_api.py index 0d7c24f..c9fd830 100644 --- a/src/openbench/pipeline/diarization/pyannote_api.py +++ b/src/openbench/pipeline/diarization/pyannote_api.py @@ -1,19 +1,14 @@ # For licensing see accompanying LICENSE.md file. # Copyright (C) 2025 Argmax, Inc. All Rights Reserved. -import os -import time -from datetime import datetime from pathlib import Path from typing import Callable -import requests from argmaxtools.utils import get_logger -from pyannote.core import Segment -from pydantic import BaseModel, Field, model_validator +from pydantic import Field from ...dataset import DiarizationSample -from ...pipeline_prediction import DiarizationAnnotation +from ...engine import PyannoteAIApi, PyannoteApiDiarizationOutput from ..base import Pipeline, PipelineType, register_pipeline from .common import DiarizationOutput, DiarizationPipelineConfig @@ -23,208 +18,6 @@ logger = get_logger(__name__) -class PyannoteApiSegment(BaseModel): - speaker: str - start: float - end: float - - -class PyannoteApiDiarization(BaseModel): - diarization: list[PyannoteApiSegment] - - def to_pyannote_annotation(self) -> DiarizationAnnotation: - annotation = DiarizationAnnotation() - for segment in self.diarization: - annotation[Segment(segment.start, segment.end)] = segment.speaker - return annotation - - -def to_camel(string: str) -> str: - components = string.split("_") - return components[0] + "".join(x.title() for x in components[1:]) - - -class PyannoteApiOutput(BaseModel): - job_id: str = Field( - description="The id of the job that was submitted to pyannote-ai", - ) - status: str = Field( - description="The status of the job that was submitted to pyannote-ai", - ) - created_at: datetime = Field( - description="The time the job was created", - ) - updated_at: datetime | None = Field( - description="The time the job was updated. For some reason it can be None.", - ) - job_polling_elapsed_time: float = Field( - description="The time it took to poll the job results", - ) - output: PyannoteApiDiarization = Field( - description="The output of the job", - ) - - class Config: - alias_generator = to_camel - validate_by_name = True - - @model_validator(mode="before") - @classmethod - def parse_shape(cls, data: dict) -> dict: - if isinstance(data["createdAt"], str): - data["createdAt"] = datetime.fromisoformat(data["createdAt"]) - if isinstance(data["updatedAt"], str) and data["updatedAt"] is not None: - data["updatedAt"] = datetime.fromisoformat(data["updatedAt"]) - return data - - def get_elapsed_time(self) -> float: - if self.updated_at is not None: - return (self.updated_at - self.created_at).total_seconds() - return self.job_polling_elapsed_time - - -# Expects a env variable `PYANNOTE_TOKEN` to be set with a valid pyannote-ai token -class PyannoteApi: - diarization_url = "https://api.pyannote.ai/v1/diarize" - media_url = "https://api.pyannote.ai/v1/media/input" - jobs_url = "https://api.pyannote.ai/v1/jobs" - - def __init__( - self, - timeout: int = 1800, - request_buffer: int = 30, - ) -> None: - self.timeout = timeout - self.request_buffer = request_buffer - - def get_presigned_url(self, audio_path: str) -> str: - # We need to push the audio file to temporary storage from pyannote-ai - # we could also push it to S3 or other storage, but this is the easiest way - # to avoid setting up a storage bucket - logger.debug(f"Getting presigned url for {audio_path}") - name = Path(audio_path).with_suffix(".wav").name - # For some reason if the name has underscores, it will fail - name = "".join([n.capitalize() for n in name.split("_")]) - # Pushing audio file to temporary storage from pyannote-ai - audio_url = f"media://example/{name}" - logger.debug(f"Audio url: {audio_url}") - body = {"url": audio_url} - # Post request to get the presigned url associated with the `audio_url` - response = requests.post( - url=self.media_url, - headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"}, - json=body, - ) - response.raise_for_status() - data = response.json() - presigned_url = data["url"] - logger.debug(f"Presigned url: {presigned_url}") - - # Upload the audio file to the presigned url - # Audio should be < 24hrs and < 1GB - with open(audio_path, "rb") as audio_file: - requests.put( - url=presigned_url, - data=audio_file, - ) - logger.debug(f"Audio file uploaded to {presigned_url}") - return audio_url - - def diarize(self, audio_url: str, num_speakers: int | None = None) -> str: - # We could also pass a Webhook to get the result, but we can poll the job status - # and get the response. This is easier to implement although polling can hit - # rate limits. - data = {"url": audio_url} - - if num_speakers is not None: - data["numSpeakers"] = num_speakers - - response = requests.post( - self.diarization_url, - headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"}, - json=data, - ) - response.raise_for_status() - return response - - def get_job_results(self, diarization_response: requests.Response) -> PyannoteApiOutput: - data = diarization_response.json() - headers = diarization_response.headers - job_id = data["jobId"] - logger.debug(f"Starting to poll results for job {job_id}") - - # Get rate limit info from headers with fallback defaults - remaining_requests = int(headers.get("X-RateLimit-Remaining", 30)) - rate_limit = int(headers.get("X-RateLimit-Limit", 30)) - reset_time = int(headers.get("X-RateLimit-Reset", 0)) - - logger.debug( - f"Initial rate limits - Remaining: {remaining_requests}, Limit: {rate_limit}, Reset: {reset_time}s" - ) - - start_time = time.time() - elapsed_time = 0 - while elapsed_time < self.timeout: - try: - # Check if we need to wait for rate limit reset - if remaining_requests <= self.request_buffer: - logger.debug( - f"Running low on requests ({remaining_requests} remaining). Waiting {reset_time}s for reset" - ) - time.sleep(reset_time) - remaining_requests = rate_limit - - logger.debug(f"Polling job {job_id}") - response = requests.get( - url=f"{self.jobs_url}/{job_id}", - headers={"Authorization": f"Bearer {os.environ['PYANNOTE_TOKEN']}"}, - ) - response.raise_for_status() - - # Update rate limit information - remaining_requests = int(response.headers.get("X-RateLimit-Remaining", remaining_requests)) - reset_time = int(response.headers.get("X-RateLimit-Reset", reset_time)) - # Add a small buffer to avoid hitting rate limits - safe_remaining = max(1, remaining_requests - self.request_buffer) - delay = reset_time / safe_remaining - logger.debug( - f"Rate limit info - Remaining: {remaining_requests}, Reset: {reset_time}s, Delay: {delay * 1000:.0f}ms" - ) - - job_data = response.json() - job_status = job_data["status"] - logger.debug(f"Job {job_id} status: {job_status}") - - if job_status == "succeeded": - elapsed_time = time.time() - start_time - logger.debug(f"Job {job_id} completed successfully after {elapsed_time:.1f}s") - job_data["jobPollingElapsedTime"] = elapsed_time - return PyannoteApiOutput.model_validate(job_data) - elif job_status == "failed": - error_msg = job_data.get("error", "No error message provided") - logger.error(f"Job {job_id} failed: {error_msg}") - raise Exception(f"Job failed with error: {error_msg}") - elif job_status == "canceled": - logger.error(f"Job {job_id} was canceled") - raise Exception("Job was canceled") - - elapsed_time = time.time() - start_time - logger.debug(f"Waiting {delay * 1000:.0f}ms before next request") - time.sleep(delay) - - except requests.exceptions.RequestException as e: - logger.error(f"API request failed for job {job_id}: {str(e)}") - raise RuntimeError(f"API request failed: {str(e)}") - - logger.error(f"Job {job_id} timed out after {elapsed_time:.1f}s") - raise TimeoutError(f"Job timed out after {elapsed_time:.1f} seconds") - - def __call__(self, audio_path: str, num_speakers: int | None = None) -> PyannoteApiOutput: - audio_url = self.get_presigned_url(audio_path) - diarization_response = self.diarize(audio_url, num_speakers) - return self.get_job_results(diarization_response) - - class PyannoteApiConfig(DiarizationPipelineConfig): timeout: int = Field( default=1800, @@ -246,10 +39,11 @@ class PyannoteApiPipeline(Pipeline): def build_pipeline( self, - ) -> Callable[[dict[str, str | int | None]], PyannoteApiOutput]: - api = PyannoteApi( + ) -> Callable[[dict[str, str | int | None]], PyannoteApiDiarizationOutput]: + api = PyannoteAIApi( timeout=self.config.timeout, request_buffer=self.config.request_buffer, + transcription=False, ) return lambda input_sample: api( audio_path=input_sample["audio_path"], @@ -262,10 +56,10 @@ def parse_input(self, input_sample: DiarizationSample) -> dict[str, str | int | self._audio_path = audio_path return {"audio_path": str(audio_path)} - def parse_output(self, output: PyannoteApiOutput) -> DiarizationOutput: - output = DiarizationOutput( + def parse_output(self, output: PyannoteApiDiarizationOutput) -> DiarizationOutput: + result = DiarizationOutput( prediction=output.output.to_pyannote_annotation(), ) # remove audio from temp self._audio_path.unlink() - return output + return result diff --git a/src/openbench/pipeline/orchestration/__init__.py b/src/openbench/pipeline/orchestration/__init__.py index 2d7bd44..dcb66d0 100644 --- a/src/openbench/pipeline/orchestration/__init__.py +++ b/src/openbench/pipeline/orchestration/__init__.py @@ -4,6 +4,7 @@ from .nemo import NeMoMTParakeetPipeline, NeMoMTParakeetPipelineConfig from .orchestration_deepgram import DeepgramOrchestrationPipeline, DeepgramOrchestrationPipelineConfig from .orchestration_openai import OpenAIOrchestrationPipeline, OpenAIOrchestrationPipelineConfig +from .orchestration_pyannote import PyannoteOrchestrationPipeline, PyannoteOrchestrationPipelineConfig from .orchestration_whisperkitpro import WhisperKitProOrchestrationConfig, WhisperKitProOrchestrationPipeline from .whisperx import WhisperXPipeline, WhisperXPipelineConfig @@ -19,4 +20,6 @@ "OpenAIOrchestrationPipelineConfig", "NeMoMTParakeetPipeline", "NeMoMTParakeetPipelineConfig", + "PyannoteOrchestrationPipeline", + "PyannoteOrchestrationPipelineConfig", ] diff --git a/src/openbench/pipeline/orchestration/orchestration_pyannote.py b/src/openbench/pipeline/orchestration/orchestration_pyannote.py new file mode 100644 index 0000000..ebf713c --- /dev/null +++ b/src/openbench/pipeline/orchestration/orchestration_pyannote.py @@ -0,0 +1,90 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +"""PyannoteAI orchestration pipeline (diarization + transcription with speaker attribution).""" + +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import Field + +from ...dataset import OrchestrationSample +from ...engine import PyannoteAIApi, PyannoteApiOrchestrationOutput +from ...pipeline_prediction import Transcript +from ..base import Pipeline, PipelineType, register_pipeline +from .common import OrchestrationConfig, OrchestrationOutput + + +__all__ = ["PyannoteOrchestrationPipeline", "PyannoteOrchestrationPipelineConfig"] + +logger = get_logger(__name__) + +TEMP_AUDIO_DIR = Path("audio_temp") + + +class PyannoteOrchestrationPipelineConfig(OrchestrationConfig): + """Configuration for PyannoteAI orchestration pipeline.""" + + timeout: int = Field( + default=1800, + description="Timeout for the orchestration job in seconds", + ) + request_buffer: int = Field( + default=30, + description="Buffer for the request rate limit", + ) + + +@register_pipeline +class PyannoteOrchestrationPipeline(Pipeline): + """ + PyannoteAI orchestration pipeline. + + Uses the PyannoteAI API with transcription enabled to get both diarization + and speaker-attributed transcription results. + """ + + _config_class = PyannoteOrchestrationPipelineConfig + pipeline_type = PipelineType.ORCHESTRATION + + def build_pipeline( + self, + ) -> Callable[[Path], PyannoteApiOrchestrationOutput]: + api = PyannoteAIApi( + timeout=self.config.timeout, + request_buffer=self.config.request_buffer, + transcription=True, + ) + + def orchestrate(audio_path: Path) -> PyannoteApiOrchestrationOutput: + response = api(audio_path=str(audio_path)) + # Remove temporary audio path + audio_path.unlink(missing_ok=True) + return response + + return orchestrate + + def parse_input(self, input_sample: OrchestrationSample) -> Path: + """Save audio to temporary directory for processing.""" + return input_sample.save_audio(TEMP_AUDIO_DIR) + + def parse_output(self, output: PyannoteApiOrchestrationOutput) -> OrchestrationOutput: + """ + Parse the PyannoteAI response into a Transcript with speaker attribution. + + Uses word-level transcription to get precise word timings with speaker labels. + """ + # Extract words from word-level transcription with speaker attribution + transcript = Transcript.from_words_info( + words=[word.text for word in output.output.word_level_transcription], + start=[word.start for word in output.output.word_level_transcription], + end=[word.end for word in output.output.word_level_transcription], + speaker=[word.speaker for word in output.output.word_level_transcription], + ) + + return OrchestrationOutput( + prediction=transcript, + diarization_output=None, + transcription_output=None, + ) diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py index a89f4b9..862c9db 100644 --- a/src/openbench/pipeline/pipeline_aliases.py +++ b/src/openbench/pipeline/pipeline_aliases.py @@ -18,6 +18,7 @@ DeepgramOrchestrationPipeline, NeMoMTParakeetPipeline, OpenAIOrchestrationPipeline, + PyannoteOrchestrationPipeline, WhisperKitProOrchestrationPipeline, WhisperXPipeline, ) @@ -35,6 +36,7 @@ GroqTranscriptionPipeline, NeMoTranscriptionPipeline, OpenAITranscriptionPipeline, + PyannoteTranscriptionPipeline, SpeechAnalyzerPipeline, WhisperKitProTranscriptionPipeline, WhisperKitTranscriptionPipeline, @@ -303,6 +305,17 @@ def register_pipeline_aliases() -> None: description="NeMo Multi-Talker Parakeet orchestration pipeline (diarization + transcription).", ) + PipelineRegistry.register_alias( + "pyannote-orchestration", + PyannoteOrchestrationPipeline, + default_config={ + "out_dir": "./pyannote_orchestration_results", + "timeout": 3600, + "request_buffer": 30, + }, + description="PyannoteAI orchestration pipeline (diarization + transcription). Uses the precision-2 model with Nvidia Parakeet STT. Requires `PYANNOTE_TOKEN` env var from https://www.pyannote.ai/.", + ) + ################# TRANSCRIPTION PIPELINES ################# PipelineRegistry.register_alias( @@ -585,6 +598,17 @@ def register_pipeline_aliases() -> None: description="AssemblyAI transcription pipeline with keyword boosting support. Requires API key from https://www.assemblyai.com/. Set `ASSEMBLYAI_API_KEY` env var.", ) + PipelineRegistry.register_alias( + "pyannote-transcription", + PyannoteTranscriptionPipeline, + default_config={ + "out_dir": "./pyannote_transcription_results", + "timeout": 3600, + "request_buffer": 30, + }, + description="PyannoteAI transcription pipeline (ignores speaker attribution). Uses the precision-2 model with Nvidia Parakeet STT. Requires `PYANNOTE_TOKEN` env var from https://www.pyannote.ai/.", + ) + ################# STREAMING TRANSCRIPTION PIPELINES ################# PipelineRegistry.register_alias( diff --git a/src/openbench/pipeline/transcription/__init__.py b/src/openbench/pipeline/transcription/__init__.py index 75f3dc8..e8a79f4 100644 --- a/src/openbench/pipeline/transcription/__init__.py +++ b/src/openbench/pipeline/transcription/__init__.py @@ -9,6 +9,7 @@ from .transcription_nemo import NeMoTranscriptionPipeline, NeMoTranscriptionPipelineConfig from .transcription_openai import OpenAITranscriptionPipeline, OpenAITranscriptionPipelineConfig from .transcription_oss_whisper import WhisperOSSTranscriptionPipeline, WhisperOSSTranscriptionPipelineConfig +from .transcription_pyannote import PyannoteTranscriptionPipeline, PyannoteTranscriptionPipelineConfig from .transcription_whisperkitpro import WhisperKitProTranscriptionConfig, WhisperKitProTranscriptionPipeline from .whisperkit import WhisperKitTranscriptionConfig, WhisperKitTranscriptionPipeline @@ -31,4 +32,6 @@ "DeepgramTranscriptionPipelineConfig", "NeMoTranscriptionPipeline", "NeMoTranscriptionPipelineConfig", + "PyannoteTranscriptionPipeline", + "PyannoteTranscriptionPipelineConfig", ] diff --git a/src/openbench/pipeline/transcription/transcription_pyannote.py b/src/openbench/pipeline/transcription/transcription_pyannote.py new file mode 100644 index 0000000..8078cd3 --- /dev/null +++ b/src/openbench/pipeline/transcription/transcription_pyannote.py @@ -0,0 +1,88 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +"""PyannoteAI transcription pipeline (ignores speaker attribution).""" + +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import Field + +from ...dataset import TranscriptionSample +from ...engine import PyannoteAIApi, PyannoteApiOrchestrationOutput +from ...pipeline_prediction import Transcript +from ..base import Pipeline, PipelineType, register_pipeline +from .common import TranscriptionConfig, TranscriptionOutput + + +__all__ = ["PyannoteTranscriptionPipeline", "PyannoteTranscriptionPipelineConfig"] + +logger = get_logger(__name__) + +TEMP_AUDIO_DIR = Path("audio_temp") + + +class PyannoteTranscriptionPipelineConfig(TranscriptionConfig): + """Configuration for PyannoteAI transcription pipeline.""" + + timeout: int = Field( + default=1800, + description="Timeout for the transcription job in seconds", + ) + request_buffer: int = Field( + default=30, + description="Buffer for the request rate limit", + ) + + +@register_pipeline +class PyannoteTranscriptionPipeline(Pipeline): + """ + PyannoteAI transcription pipeline. + + Uses the PyannoteAI API with transcription enabled, but ignores speaker + attribution in the output. This is useful for datasets that don't have + speaker labels. + """ + + _config_class = PyannoteTranscriptionPipelineConfig + pipeline_type = PipelineType.TRANSCRIPTION + + def build_pipeline( + self, + ) -> Callable[[Path], PyannoteApiOrchestrationOutput]: + api = PyannoteAIApi( + timeout=self.config.timeout, + request_buffer=self.config.request_buffer, + transcription=True, + ) + + def transcribe(audio_path: Path) -> PyannoteApiOrchestrationOutput: + response = api(audio_path=str(audio_path)) + # Remove temporary audio path + audio_path.unlink(missing_ok=True) + return response + + return transcribe + + def parse_input(self, input_sample: TranscriptionSample) -> Path: + """Save audio to temporary directory for processing.""" + return input_sample.save_audio(TEMP_AUDIO_DIR) + + def parse_output(self, output: PyannoteApiOrchestrationOutput) -> TranscriptionOutput: + """ + Parse the PyannoteAI response into a Transcript without speaker attribution. + + The word-level transcription contains speaker information, but we ignore it + here since this is a transcription-only pipeline. + """ + # Extract words from word-level transcription, ignoring speaker info + transcript = Transcript.from_words_info( + words=[word.text for word in output.output.word_level_transcription], + start=[word.start for word in output.output.word_level_transcription], + end=[word.end for word in output.output.word_level_transcription], + speaker=None, # Ignore speaker attribution for transcription pipeline + ) + + return TranscriptionOutput(prediction=transcript) From 6a26b20ebb128ac0fc26d77418cd618be47516de Mon Sep 17 00:00:00 2001 From: Eduardo Pacheco Date: Tue, 13 Jan 2026 15:34:07 -0300 Subject: [PATCH 2/2] fix: year in copyright --- src/openbench/engine/pyannote_engine.py | 2 +- src/openbench/pipeline/orchestration/orchestration_pyannote.py | 2 +- src/openbench/pipeline/transcription/transcription_pyannote.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/openbench/engine/pyannote_engine.py b/src/openbench/engine/pyannote_engine.py index b73a92f..2d3456c 100644 --- a/src/openbench/engine/pyannote_engine.py +++ b/src/openbench/engine/pyannote_engine.py @@ -1,5 +1,5 @@ # For licensing see accompanying LICENSE.md file. -# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. +# Copyright (C) 2026 Argmax, Inc. All Rights Reserved. """PyannoteAI API engine for diarization and transcription.""" diff --git a/src/openbench/pipeline/orchestration/orchestration_pyannote.py b/src/openbench/pipeline/orchestration/orchestration_pyannote.py index ebf713c..b9b3c14 100644 --- a/src/openbench/pipeline/orchestration/orchestration_pyannote.py +++ b/src/openbench/pipeline/orchestration/orchestration_pyannote.py @@ -1,5 +1,5 @@ # For licensing see accompanying LICENSE.md file. -# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. +# Copyright (C) 2026 Argmax, Inc. All Rights Reserved. """PyannoteAI orchestration pipeline (diarization + transcription with speaker attribution).""" diff --git a/src/openbench/pipeline/transcription/transcription_pyannote.py b/src/openbench/pipeline/transcription/transcription_pyannote.py index 8078cd3..ab956e5 100644 --- a/src/openbench/pipeline/transcription/transcription_pyannote.py +++ b/src/openbench/pipeline/transcription/transcription_pyannote.py @@ -1,5 +1,5 @@ # For licensing see accompanying LICENSE.md file. -# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. +# Copyright (C) 2026 Argmax, Inc. All Rights Reserved. """PyannoteAI transcription pipeline (ignores speaker attribution)."""