diff --git a/BENCHMARKS.md b/BENCHMARKS.md index 91801c0..339da7a 100644 --- a/BENCHMARKS.md +++ b/BENCHMARKS.md @@ -549,11 +549,17 @@ -| Dataset | Deepgram
(nova-3) | OpenAI
(whisper-1) | AssemblyAI | Whisper OSS
(large-v3-turbo) | Argmax
(parakeet-v2) | Argmax
(parakeet-v3) | Apple
(SFSpeechRecognizer)
| Apple
(SpeechAnalyzer)| -|--------------------------------------------------|-------------------------|--------------------------|--------------|------------------------------------|----------------------------|----------------------------|---------------------------------------------|----------------------------| -| earnings22-keywords
(no keywords) | 15.34 | 20.69 | 12.58 | 15.4 | 14.69 | 16.89 | 28.42 | 17 | -| earnings22-keywords
(chunk-keywords) | 13.28 | 31.97 | 11.67 | 21.24 | 12.46 | 14.57 | 26.98 | - | -| earnings22-keywords
(file-keywords) | 13.85 | 28.37 | 11.80 | 14.69 | 12.57 | 14.73 | 27.26 | - | +| System | earnings22-keywords
(no keywords) | earnings22-keywords
(chunk-keywords) | earnings22-keywords
(file-keywords) | +|--------------------------------------------|---------------------------------------|------------------------------------------|----------------------------------------| +| Deepgram
(nova-3) | 15.34 | 13.28 | 13.85 | +| OpenAI
(whisper-1) | 20.69 | 31.97 | 28.37 | +| AssemblyAI | 12.58 | 11.67 | 11.80 | +| Whisper OSS
(large-v3-turbo) | 15.4 | 21.24 | 14.69 | +| Argmax
(parakeet-v2) | 14.69 | 12.46 | 12.57 | +| Argmax
(parakeet-v3) | 16.89 | 14.57 | 14.73 | +| ElevenLabs | 10.53 | 9.13 | 9.08 | +| Apple
(SFSpeechRecognizer) | 28.42 | 26.98 | 27.26 | +| Apple
(SpeechAnalyzer) | 17 | - | - |

@@ -576,11 +582,17 @@ If the model predicts 20 keywords and 15 of them match the ground truth, precisi -| Dataset | Deepgram
(nova-3) | OpenAI
(whisper-1) | AssemblyAI | Whisper OSS
(large-v3-turbo) | Argmax
(parakeet-v2) | Argmax
(parakeet-v3) | Apple
(SFSpeechRecognizer)
| Apple
(SpeechAnalyzer)| -|--------------------------------------------------|-------------------------|--------------------------|--------------|------------------------------------|----------------------------|----------------------------|-------------------------------------------------|----------------------------| -| earnings22-keywords
(no keywords) | 0.98 | 0.97 | 0.97 | 0.97 | 0.97 | 0.98 | 1 | 0.99 | -| earnings22-keywords
(chunk-keywords) | 0.99 | 0.98 | 0.99 | 0.96 | 0.98 | 0.98 | 0.99 | - | -| earnings22-keywords
(file-keywords) | 0.96 | 0.93 | 0.96 | 0.94 | 0.96 | 0.95 | 0.99 | - | +| System | earnings22-keywords
(no keywords) | earnings22-keywords
(chunk-keywords) | earnings22-keywords
(file-keywords) | +|--------------------------------------------|---------------------------------------|------------------------------------------|----------------------------------------| +| Deepgram
(nova-3) | 0.98 | 0.99 | 0.96 | +| OpenAI
(whisper-1) | 0.97 | 0.98 | 0.93 | +| AssemblyAI | 0.97 | 0.99 | 0.96 | +| Whisper OSS
(large-v3-turbo) | 0.97 | 0.96 | 0.94 | +| Argmax
(parakeet-v2) | 0.97 | 0.98 | 0.96 | +| Argmax
(parakeet-v3) | 0.98 | 0.98 | 0.95 | +| ElevenLabs | 0.97 | 0.99 | 0.96 | +| Apple
(SFSpeechRecognizer) | 1 | 0.99 | 0.99 | +| Apple
(SpeechAnalyzer) | 0.99 | - | - |

@@ -603,11 +615,17 @@ If the ground-truth transcript has 25 keywords and the model correctly finds 15, -| Dataset | Deepgram
(nova-3) | OpenAI
(whisper-1) | AssemblyAI | Whisper OSS
(large-v3-turbo) | Argmax
(parakeet-v2) | Argmax
(parakeet-v3) | Apple
(SFSpeechRecognizer)
| Apple
(SpeechAnalyzer)| -|--------------------------------------------------|-------------------------|--------------------------|--------------|------------------------------------|----------------------------|----------------------------|-------------------------------------------------|----------------------------| -| earnings22-keywords
(no keywords) | 0.61 | 0.53 | 0.55 | 0.53 | 0.47 | 0.45 | 0.26 | 0.39 | -| earnings22-keywords
(chunk-keywords) | 0.89 | 0.7 | 0.69 | 0.77 | 0.85 | 0.82 | 0.45 | - | -| earnings22-keywords
(file-keywords) | 0.83 | 0.79 | 0.68 | 0.82 | 0.82 | 0.8 | 0.4 | - | +| System | earnings22-keywords
(no keywords) | earnings22-keywords
(chunk-keywords) | earnings22-keywords
(file-keywords) | +|--------------------------------------------|---------------------------------------|------------------------------------------|----------------------------------------| +| Deepgram
(nova-3) | 0.61 | 0.89 | 0.83 | +| OpenAI
(whisper-1) | 0.53 | 0.7 | 0.79 | +| AssemblyAI | 0.55 | 0.69 | 0.68 | +| Whisper OSS
(large-v3-turbo) | 0.53 | 0.77 | 0.82 | +| Argmax
(parakeet-v2) | 0.47 | 0.85 | 0.82 | +| Argmax
(parakeet-v3) | 0.45 | 0.82 | 0.8 | +| ElevenLabs | 0.75 | 0.96 | 0.94 | +| Apple
(SFSpeechRecognizer) | 0.26 | 0.45 | 0.4 | +| Apple
(SpeechAnalyzer) | 0.39 | - | - |

@@ -632,11 +650,17 @@ F1 = 2 × (0.75 × 0.6) / (0.75 + 0.6) = **66.7%**, reflecting the model's overa -| Dataset | Deepgram
(nova-3) | OpenAI
(whisper-1) | AssemblyAI | Whisper OSS
(large-v3-turbo) | Argmax
(parakeet-v2) | Argmax
(parakeet-v3) | Apple
SFSpeechRecognizer
(Old API) | Apple
(SpeechAnalyzer)| -|--------------------------------------------------|-------------------------|--------------------------|--------------|------------------------------------|----------------------------|----------------------------|-------------------------------------------------|----------------------------| -| earnings22-keywords
(no keywords) | 0.75 | 0.68 | 0.7 | 0.69 | 0.63 | 0.62 | 0.41 | 0.56 | -| earnings22-keywords
(chunk-keywords) | 0.94 | 0.82 | 0.81 | 0.86 | 0.91 | 0.89 | 0.62 | - | -| earnings22-keywords
(file-keywords) | 0.89 | 0.86 | 0.8 | 0.87 | 0.88 | 0.87 | 0.58 | - | +| System | earnings22-keywords
(no keywords) | earnings22-keywords
(chunk-keywords) | earnings22-keywords
(file-keywords) | +|--------------------------------------------|---------------------------------------|------------------------------------------|----------------------------------------| +| Deepgram
(nova-3) | 0.75 | 0.94 | 0.89 | +| OpenAI
(whisper-1) | 0.68 | 0.82 | 0.86 | +| AssemblyAI | 0.7 | 0.81 | 0.8 | +| Whisper OSS
(large-v3-turbo) | 0.69 | 0.86 | 0.87 | +| Argmax
(parakeet-v2) | 0.63 | 0.91 | 0.88 | +| Argmax
(parakeet-v3) | 0.62 | 0.89 | 0.87 | +| ElevenLabs | 0.84 | 0.97 | 0.95 | +| Apple
(SFSpeechRecognizer) | 0.41 | 0.62 | 0.58 | +| Apple
(SpeechAnalyzer) | 0.56 | - | - |

diff --git a/config/pipeline_configs/ElevenLabsTranscriptionPipeline.yaml b/config/pipeline_configs/ElevenLabsTranscriptionPipeline.yaml new file mode 100644 index 0000000..4bec5f2 --- /dev/null +++ b/config/pipeline_configs/ElevenLabsTranscriptionPipeline.yaml @@ -0,0 +1,5 @@ +ElevenLabsTranscriptionPipeline: + config: + model_id: "scribe_v2" + use_keywords: true + diff --git a/pyproject.toml b/pyproject.toml index 9213bc2..c8c1120 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "openai>=2.7.1", "meeteval>=0.4.3", "nemo-toolkit[asr]>=2.6.0", + "elevenlabs>=2.30.0", ] [project.scripts] diff --git a/src/openbench/cli/commands/inference.py b/src/openbench/cli/commands/inference.py index d16a05d..5955925 100644 --- a/src/openbench/cli/commands/inference.py +++ b/src/openbench/cli/commands/inference.py @@ -88,7 +88,7 @@ def get_dummy_sample( waveform=waveform, sample_rate=sample_rate, extra_info={}, - reference=Transcript.from_words_info(words=["dummy"]), + reference=Transcript.from_words_info(words=["dummy"], speaker=["SPEAKER_0"]), ) elif pipeline_type == PipelineType.TRANSCRIPTION: return TranscriptionSample( diff --git a/src/openbench/engine/__init__.py b/src/openbench/engine/__init__.py index 9f30aa1..2dec8fe 100644 --- a/src/openbench/engine/__init__.py +++ b/src/openbench/engine/__init__.py @@ -1,4 +1,5 @@ from .deepgram_engine import DeepgramApi, DeepgramApiResponse +from .elevenlabs_engine import ElevenLabsApi, ElevenLabsApiResponse from .openai_engine import OpenAIApi from .whisperkitpro_engine import ( WhisperKitPro, @@ -11,6 +12,8 @@ __all__ = [ "DeepgramApi", "DeepgramApiResponse", + "ElevenLabsApi", + "ElevenLabsApiResponse", "OpenAIApi", "WhisperKitPro", "WhisperKitProInput", diff --git a/src/openbench/engine/elevenlabs_engine.py b/src/openbench/engine/elevenlabs_engine.py new file mode 100644 index 0000000..05f9dda --- /dev/null +++ b/src/openbench/engine/elevenlabs_engine.py @@ -0,0 +1,109 @@ +import os +from pathlib import Path + +from argmaxtools.utils import get_logger +from elevenlabs.client import ElevenLabs +from pydantic import BaseModel, model_validator + + +logger = get_logger(__name__) + + +class ElevenLabsApiResponse(BaseModel): + """Response from ElevenLabs speech-to-text API.""" + + words: list[str] + speakers: list[str] + start: list[float] + end: list[float] + + @property + def transcript(self) -> str: + return " ".join(self.words) + + @model_validator(mode="after") + def validate_lengths(self) -> "ElevenLabsApiResponse": + if ( + len(self.words) != len(self.speakers) + or len(self.words) != len(self.start) + or len(self.words) != len(self.end) + ): + raise ValueError("All lists must be of the same length") + return self + + +class ElevenLabsApi: + """ElevenLabs Speech-to-Text API wrapper.""" + + def __init__( + self, + model_id: str = "scribe_v2", + timeout: float = 300, + ): + self.model_id = model_id + self.timeout = timeout + + api_key = os.getenv("ELEVENLABS_API_KEY") + if not api_key: + raise ValueError("`ELEVENLABS_API_KEY` is not set") + + self.client = ElevenLabs(api_key=api_key, timeout=timeout) + + def transcribe( + self, + audio_path: Path | str, + keyterms: list[str] | None = None, + language_code: str | None = None, + diarize: bool = False, + num_speakers: int | None = None, + ) -> ElevenLabsApiResponse: + """Transcribe an audio file using ElevenLabs API. + + Args: + audio_path: Path to the audio file + keyterms: List of keywords to boost recognition + language_code: Language code (e.g., 'eng') + diarize: Whether to enable speaker diarization + num_speakers: Maximum number of speakers + + Returns: + ElevenLabsApiResponse with words, speakers, and timestamps + """ + if isinstance(audio_path, str): + audio_path = Path(audio_path) + + with audio_path.open("rb") as f: + audio_data = f.read() + + kwargs = { + "model_id": self.model_id, + "file": audio_data, + } + + if keyterms: + kwargs["keyterms"] = keyterms + logger.debug(f"Using keyterms: {keyterms}") + + if language_code: + kwargs["language_code"] = language_code + logger.debug(f"Using language: {language_code}") + + if diarize: + kwargs["diarize"] = True + logger.debug("Diarization enabled") + + if num_speakers is not None: + kwargs["num_speakers"] = num_speakers + logger.debug(f"Max speakers: {num_speakers}") + + response = self.client.speech_to_text.convert(**kwargs) + + # ElevenLabs returns whitespace as separate "words" - filter them out + words = [w for w in response.words if w.text and w.text.strip()] + + return ElevenLabsApiResponse( + words=[w.text for w in words], + speakers=[str(w.speaker_id) for w in words], + start=[float(w.start) for w in words], + end=[float(w.end) for w in words], + ) diff --git a/src/openbench/pipeline/diarization/__init__.py b/src/openbench/pipeline/diarization/__init__.py index c0842ba..3c3ac9f 100644 --- a/src/openbench/pipeline/diarization/__init__.py +++ b/src/openbench/pipeline/diarization/__init__.py @@ -4,6 +4,7 @@ from .aws import * from .common import * from .diarization_deepgram import * +from .elevenlabs import * from .nemo import * from .picovoice import * from .pyannote import * diff --git a/src/openbench/pipeline/diarization/elevenlabs.py b/src/openbench/pipeline/diarization/elevenlabs.py new file mode 100644 index 0000000..6a56a6c --- /dev/null +++ b/src/openbench/pipeline/diarization/elevenlabs.py @@ -0,0 +1,65 @@ +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pyannote.core import Segment +from pydantic import Field + +from ...dataset import DiarizationSample +from ...engine import ElevenLabsApi, ElevenLabsApiResponse +from ...pipeline_prediction import DiarizationAnnotation +from ..base import Pipeline, PipelineType, register_pipeline +from .common import DiarizationOutput, DiarizationPipelineConfig + + +__all__ = ["ElevenLabsDiarizationPipeline", "ElevenLabsDiarizationPipelineConfig"] + +TEMP_AUDIO_DIR = Path("audio_temp") + +logger = get_logger(__name__) + + +class ElevenLabsDiarizationPipelineConfig(DiarizationPipelineConfig): + model_id: str = Field( + default="scribe_v2", + description="The ElevenLabs speech-to-text model to use", + ) + num_speakers: int | None = Field( + default=None, + description="Maximum number of speakers (helps with diarization). Max 32.", + ) + + +@register_pipeline +class ElevenLabsDiarizationPipeline(Pipeline): + _config_class = ElevenLabsDiarizationPipelineConfig + pipeline_type = PipelineType.DIARIZATION + + def build_pipeline(self) -> Callable[[Path], ElevenLabsApiResponse]: + api = ElevenLabsApi(model_id=self.config.model_id) + + num_speakers = None + if self.config.use_exact_num_speakers: + num_speakers = self.config.num_speakers + + def transcribe(audio_path: Path) -> ElevenLabsApiResponse: + response = api.transcribe( + audio_path=audio_path, + diarize=True, + num_speakers=num_speakers, + ) + # Remove temporary audio path + audio_path.unlink(missing_ok=True) + return response + + return transcribe + + def parse_input(self, input_sample: DiarizationSample) -> Path: + return input_sample.save_audio(output_dir=TEMP_AUDIO_DIR) + + def parse_output(self, output: ElevenLabsApiResponse) -> DiarizationOutput: + annotation = DiarizationAnnotation() + for word, speaker, start, end in zip(output.words, output.speakers, output.start, output.end): + annotation[Segment(start, end)] = f"SPEAKER_{speaker}" + + return DiarizationOutput(prediction=annotation) diff --git a/src/openbench/pipeline/orchestration/__init__.py b/src/openbench/pipeline/orchestration/__init__.py index 2d7bd44..b7fee67 100644 --- a/src/openbench/pipeline/orchestration/__init__.py +++ b/src/openbench/pipeline/orchestration/__init__.py @@ -3,6 +3,7 @@ from .nemo import NeMoMTParakeetPipeline, NeMoMTParakeetPipelineConfig from .orchestration_deepgram import DeepgramOrchestrationPipeline, DeepgramOrchestrationPipelineConfig +from .orchestration_elevenlabs import ElevenLabsOrchestrationPipeline, ElevenLabsOrchestrationPipelineConfig from .orchestration_openai import OpenAIOrchestrationPipeline, OpenAIOrchestrationPipelineConfig from .orchestration_whisperkitpro import WhisperKitProOrchestrationConfig, WhisperKitProOrchestrationPipeline from .whisperx import WhisperXPipeline, WhisperXPipelineConfig @@ -11,6 +12,8 @@ __all__ = [ "DeepgramOrchestrationPipeline", "DeepgramOrchestrationPipelineConfig", + "ElevenLabsOrchestrationPipeline", + "ElevenLabsOrchestrationPipelineConfig", "WhisperXPipeline", "WhisperXPipelineConfig", "WhisperKitProOrchestrationPipeline", diff --git a/src/openbench/pipeline/orchestration/orchestration_elevenlabs.py b/src/openbench/pipeline/orchestration/orchestration_elevenlabs.py new file mode 100644 index 0000000..3c101fe --- /dev/null +++ b/src/openbench/pipeline/orchestration/orchestration_elevenlabs.py @@ -0,0 +1,70 @@ +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import Field + +from ...dataset import OrchestrationSample +from ...engine import ElevenLabsApi, ElevenLabsApiResponse +from ...pipeline import Pipeline, register_pipeline +from ...pipeline_prediction import Transcript +from ...types import PipelineType +from .common import OrchestrationConfig, OrchestrationOutput + + +logger = get_logger(__name__) + +TEMP_AUDIO_DIR = Path("temp_audio_dir") + + +class ElevenLabsOrchestrationPipelineConfig(OrchestrationConfig): + model_id: str = Field( + default="scribe_v2", + description="The ElevenLabs speech-to-text model to use", + ) + num_speakers: int | None = Field( + default=None, + description="Maximum number of speakers (helps with diarization). Max 32.", + ) + + +@register_pipeline +class ElevenLabsOrchestrationPipeline(Pipeline): + _config_class = ElevenLabsOrchestrationPipelineConfig + pipeline_type = PipelineType.ORCHESTRATION + + def build_pipeline(self) -> Callable[[Path], ElevenLabsApiResponse]: + api = ElevenLabsApi(model_id=self.config.model_id) + + def orchestrate(audio_path: Path) -> ElevenLabsApiResponse: + response = api.transcribe( + audio_path=audio_path, + language_code=self.current_language, + diarize=True, + num_speakers=self.config.num_speakers, + ) + # Remove temporary audio path + audio_path.unlink(missing_ok=True) + return response + + return orchestrate + + def parse_input(self, input_sample: OrchestrationSample) -> Path: + """Override to extract language from sample before processing.""" + self.current_language = None + if self.config.force_language: + self.current_language = input_sample.language + + return input_sample.save_audio(TEMP_AUDIO_DIR) + + def parse_output(self, output: ElevenLabsApiResponse) -> OrchestrationOutput: + return OrchestrationOutput( + prediction=Transcript.from_words_info( + words=output.words, + speaker=output.speakers, + start=output.start, + end=output.end, + ), + diarization_output=None, + transcription_output=None, + ) diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py index a89f4b9..d653f6f 100644 --- a/src/openbench/pipeline/pipeline_aliases.py +++ b/src/openbench/pipeline/pipeline_aliases.py @@ -8,6 +8,7 @@ from .diarization import ( AWSTranscribePipeline, DeepgramDiarizationPipeline, + ElevenLabsDiarizationPipeline, NeMoSortformerPipeline, PicovoicePipeline, PyannoteApiPipeline, @@ -16,6 +17,7 @@ ) from .orchestration import ( DeepgramOrchestrationPipeline, + ElevenLabsOrchestrationPipeline, NeMoMTParakeetPipeline, OpenAIOrchestrationPipeline, WhisperKitProOrchestrationPipeline, @@ -32,6 +34,7 @@ from .transcription import ( AssemblyAITranscriptionPipeline, DeepgramTranscriptionPipeline, + ElevenLabsTranscriptionPipeline, GroqTranscriptionPipeline, NeMoTranscriptionPipeline, OpenAITranscriptionPipeline, @@ -131,6 +134,16 @@ def register_pipeline_aliases() -> None: description="Deepgram diarization pipeline. Requires API key from https://www.deepgram.com/. Set `DEEPGRAM_API_KEY` env var.", ) + PipelineRegistry.register_alias( + "elevenlabs-diarization", + ElevenLabsDiarizationPipeline, + default_config={ + "out_dir": "./elevenlabs_diarization_results", + "model_id": "scribe_v2", + }, + description="ElevenLabs diarization pipeline. Requires API key from https://elevenlabs.io/. Set `ELEVENLABS_API_KEY` env var.", + ) + ################# ORCHESTRATION PIPELINES ################# PipelineRegistry.register_alias( @@ -171,6 +184,16 @@ def register_pipeline_aliases() -> None: description="Deepgram orchestration pipeline. Requires API key from https://www.deepgram.com/. Set `DEEPGRAM_API_KEY` env var.", ) + PipelineRegistry.register_alias( + "elevenlabs-orchestration", + ElevenLabsOrchestrationPipeline, + default_config={ + "out_dir": "./elevenlabs_orchestration_results", + "model_id": "scribe_v2", + }, + description="ElevenLabs orchestration pipeline with diarization. Requires API key from https://elevenlabs.io/. Set `ELEVENLABS_API_KEY` env var.", + ) + PipelineRegistry.register_alias( "whisperkitpro-orchestration-tiny", WhisperKitProOrchestrationPipeline, @@ -585,6 +608,16 @@ def register_pipeline_aliases() -> None: description="AssemblyAI transcription pipeline with keyword boosting support. Requires API key from https://www.assemblyai.com/. Set `ASSEMBLYAI_API_KEY` env var.", ) + PipelineRegistry.register_alias( + "elevenlabs-transcription", + ElevenLabsTranscriptionPipeline, + default_config={ + "model_id": "scribe_v2", + "use_keywords": False, + }, + description="ElevenLabs transcription pipeline with keyterm prompting support. Requires API key from https://elevenlabs.io/. Set `ELEVENLABS_API_KEY` env var.", + ) + ################# STREAMING TRANSCRIPTION PIPELINES ################# PipelineRegistry.register_alias( diff --git a/src/openbench/pipeline/transcription/__init__.py b/src/openbench/pipeline/transcription/__init__.py index 75f3dc8..3fe5900 100644 --- a/src/openbench/pipeline/transcription/__init__.py +++ b/src/openbench/pipeline/transcription/__init__.py @@ -5,6 +5,7 @@ from .common import TranscriptionOutput from .transcription_assemblyai import AssemblyAITranscriptionPipeline, AssemblyAITranscriptionPipelineConfig from .transcription_deepgram import DeepgramTranscriptionPipeline, DeepgramTranscriptionPipelineConfig +from .transcription_elevenlabs import ElevenLabsTranscriptionPipeline, ElevenLabsTranscriptionPipelineConfig from .transcription_groq import GroqTranscriptionConfig, GroqTranscriptionPipeline from .transcription_nemo import NeMoTranscriptionPipeline, NeMoTranscriptionPipelineConfig from .transcription_openai import OpenAITranscriptionPipeline, OpenAITranscriptionPipelineConfig @@ -29,6 +30,8 @@ "WhisperOSSTranscriptionPipelineConfig", "DeepgramTranscriptionPipeline", "DeepgramTranscriptionPipelineConfig", + "ElevenLabsTranscriptionPipeline", + "ElevenLabsTranscriptionPipelineConfig", "NeMoTranscriptionPipeline", "NeMoTranscriptionPipelineConfig", ] diff --git a/src/openbench/pipeline/transcription/transcription_elevenlabs.py b/src/openbench/pipeline/transcription/transcription_elevenlabs.py new file mode 100644 index 0000000..8379af4 --- /dev/null +++ b/src/openbench/pipeline/transcription/transcription_elevenlabs.py @@ -0,0 +1,70 @@ +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import Field + +from ...dataset import TranscriptionSample +from ...engine import ElevenLabsApi, ElevenLabsApiResponse +from ...pipeline import Pipeline, register_pipeline +from ...pipeline_prediction import Transcript +from ...types import PipelineType +from .common import TranscriptionConfig, TranscriptionOutput + + +logger = get_logger(__name__) + +TEMP_AUDIO_DIR = Path("temp_audio_dir") + + +class ElevenLabsTranscriptionPipelineConfig(TranscriptionConfig): + model_id: str = Field( + default="scribe_v2", + description="The ElevenLabs speech-to-text model to use", + ) + + +@register_pipeline +class ElevenLabsTranscriptionPipeline(Pipeline): + _config_class = ElevenLabsTranscriptionPipelineConfig + pipeline_type = PipelineType.TRANSCRIPTION + + def build_pipeline(self) -> Callable[[Path], ElevenLabsApiResponse]: + api = ElevenLabsApi(model_id=self.config.model_id) + + def transcribe(audio_path: Path) -> ElevenLabsApiResponse: + response = api.transcribe( + audio_path=audio_path, + keyterms=self.current_keywords, + language_code=self.current_language, + diarize=False, + ) + # Remove temporary audio path + audio_path.unlink(missing_ok=True) + return response + + return transcribe + + def parse_input(self, input_sample: TranscriptionSample) -> Path: + """Override to extract keywords from sample before processing.""" + self.current_keywords = None + if self.config.use_keywords: + keywords = input_sample.extra_info.get("dictionary", []) + if keywords: + self.current_keywords = keywords + + self.current_language = None + if self.config.force_language: + self.current_language = input_sample.language + + return input_sample.save_audio(TEMP_AUDIO_DIR) + + def parse_output(self, output: ElevenLabsApiResponse) -> TranscriptionOutput: + return TranscriptionOutput( + prediction=Transcript.from_words_info( + words=output.words, + speaker=None, + start=output.start, + end=output.end, + ) + ) diff --git a/uv.lock b/uv.lock index 3b9f95b..ba02fae 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10, <3.13" resolution-markers = [ "python_full_version >= '3.12' and sys_platform == 'linux'", @@ -41,6 +41,7 @@ wheels = [ name = "aenum" version = "3.1.16" source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/09/7a/61ed58e8be9e30c3fe518899cc78c284896d246d51381bab59b5db11e1f3/aenum-3.1.16.tar.gz", hash = "sha256:bfaf9589bdb418ee3a986d85750c7318d9d2839c1b1a1d6fe8fc53ec201cf140", size = 137693, upload-time = "2026-01-12T22:34:38.819Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/e3/52/6ad8f63ec8da1bf40f96996d25d5b650fdd38f5975f8c813732c47388f18/aenum-3.1.16-py3-none-any.whl", hash = "sha256:9035092855a98e41b66e3d0998bd7b96280e85ceb3a04cc035636138a1943eaf", size = 165627, upload-time = "2025-04-25T03:17:58.89Z" }, ] @@ -1108,6 +1109,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/62/9773de14fe6c45c23649e98b83231fffd7b9892b6cf863251dc2afa73643/einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737", size = 64359, upload-time = "2025-02-09T03:17:01.998Z" }, ] +[[package]] +name = "elevenlabs" +version = "2.30.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "pydantic" }, + { name = "pydantic-core" }, + { name = "requests" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/a0/e7761b7a5a73f3933480ff3bdf87fcc9228ab8bf776fc08d1af0e4ff2f89/elevenlabs-2.30.0.tar.gz", hash = "sha256:a6a0474e045b93475fcd5f5829b67438d5a6aef9698b6f8758e7148ac03c2b12", size = 438124, upload-time = "2026-01-13T09:25:21.352Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/16/7854a5d114f23fe3e2778b3b7838f4e63973f4ab6f0d71247ce0bcb9038f/elevenlabs-2.30.0-py3-none-any.whl", hash = "sha256:eeee92703a27e7ecd0e8ba4d547f730bcef4ecb02485efa59aeecab3a904d024", size = 1171600, upload-time = "2026-01-13T09:25:19.515Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.0" @@ -3299,6 +3317,7 @@ dependencies = [ { name = "boto3" }, { name = "datasets" }, { name = "deepgram-sdk" }, + { name = "elevenlabs" }, { name = "groq" }, { name = "hdbscan" }, { name = "hydra-core" }, @@ -3351,6 +3370,7 @@ requires-dist = [ { name = "boto3", specifier = ">=1.36.20,<2" }, { name = "datasets", specifier = ">=3.1.0,<4" }, { name = "deepgram-sdk", specifier = ">=4.8.0,<5" }, + { name = "elevenlabs", specifier = ">=2.30.0" }, { name = "groq", specifier = ">=0.31.0" }, { name = "hdbscan", specifier = ">=0.8.40,<0.9" }, { name = "hydra-core", specifier = ">=1.3.2,<2" },