From 1e59781fb477961146857460b9506d92a5bc6492 Mon Sep 17 00:00:00 2001 From: dberkin1 Date: Fri, 13 Feb 2026 20:52:24 +0300 Subject: [PATCH 1/5] Add TTS evals --- src/openbench/cli/commands/evaluate.py | 21 ++ src/openbench/dataset/__init__.py | 3 + src/openbench/dataset/dataset_aliases.py | 14 + src/openbench/dataset/dataset_registry.py | 2 + .../dataset/dataset_speech_generation.py | 110 ++++++ .../word_error_metrics/word_error_metrics.py | 1 + src/openbench/pipeline/__init__.py | 1 + src/openbench/pipeline/pipeline_aliases.py | 26 ++ .../pipeline/speech_generation/__init__.py | 12 + .../pipeline/speech_generation/common.py | 110 ++++++ .../speech_generation_wkp.py | 316 ++++++++++++++++++ src/openbench/runner/benchmark.py | 2 + src/openbench/types.py | 1 + 13 files changed, 619 insertions(+) create mode 100644 src/openbench/dataset/dataset_speech_generation.py create mode 100644 src/openbench/pipeline/speech_generation/__init__.py create mode 100644 src/openbench/pipeline/speech_generation/common.py create mode 100644 src/openbench/pipeline/speech_generation/speech_generation_wkp.py diff --git a/src/openbench/cli/commands/evaluate.py b/src/openbench/cli/commands/evaluate.py index 1b4f5f7..c35eb01 100644 --- a/src/openbench/cli/commands/evaluate.py +++ b/src/openbench/cli/commands/evaluate.py @@ -176,6 +176,7 @@ def run_alias_mode( wandb_tags: list[str] | None, use_keywords: bool | None, force_language: bool, + pipeline_config: list[str] | None, verbose: bool, ) -> BenchmarkResult: """Run evaluation using pipeline and dataset aliases.""" @@ -208,6 +209,19 @@ def run_alias_mode( if verbose: typer.echo("✅ Force language: enabled") + # Handle generic pipeline config overrides (key=value pairs) + if pipeline_config: + for item in pipeline_config: + if "=" not in item: + raise typer.BadParameter( + f"Invalid --pipeline-config format: '{item}'. " + f"Expected key=value" + ) + key, value = item.split("=", 1) + pipeline_config_override[key] = value + if verbose: + typer.echo(f"✅ Config override: {key}={value}") + pipeline = PipelineRegistry.create_pipeline(pipeline_name, config=pipeline_config_override) ######### Build Benchmark Config ######### @@ -345,6 +359,12 @@ def evaluate( "--force-language", help="Force language hinting for compatible pipelines", ), + pipeline_config: list[str] | None = typer.Option( + None, + "--pipeline-config", + "-pc", + help="Override pipeline config values as key=value pairs (e.g. --pipeline-config speaker=serena)", + ), verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output"), ) -> None: """Run evaluation benchmarks. @@ -406,6 +426,7 @@ def evaluate( wandb_tags=wandb_tags, use_keywords=use_keywords, force_language=force_language, + pipeline_config=pipeline_config, verbose=verbose, ) display_result(result) diff --git a/src/openbench/dataset/__init__.py b/src/openbench/dataset/__init__.py index 8c90fae..540942e 100644 --- a/src/openbench/dataset/__init__.py +++ b/src/openbench/dataset/__init__.py @@ -6,6 +6,7 @@ from .dataset_diarization import DiarizationDataset, DiarizationSample from .dataset_orchestration import OrchestrationDataset, OrchestrationSample from .dataset_registry import DatasetRegistry +from .dataset_speech_generation import SpeechGenerationDataset, SpeechGenerationSample from .dataset_streaming_transcription import StreamingDataset, StreamingSample from .dataset_transcription import TranscriptionDataset, TranscriptionSample @@ -24,11 +25,13 @@ "TranscriptionDataset", "StreamingDataset", "OrchestrationDataset", + "SpeechGenerationDataset", # Sample types "DiarizationSample", "TranscriptionSample", "StreamingSample", "OrchestrationSample", + "SpeechGenerationSample", # Registry "DatasetRegistry", ] diff --git a/src/openbench/dataset/dataset_aliases.py b/src/openbench/dataset/dataset_aliases.py index c37af42..9a7ad1b 100644 --- a/src/openbench/dataset/dataset_aliases.py +++ b/src/openbench/dataset/dataset_aliases.py @@ -554,6 +554,20 @@ def register_dataset_aliases() -> None: description="Common Voice dataset for transcription evaluation with up to 400 samples per language this subset contains only russian", ) + ########## SPEECH GENERATION ########## + + DatasetRegistry.register_alias( + "customer-service-tts-prompts-vocalized", + DatasetConfig( + dataset_id="argmaxinc/customer-service-tts-prompts-vocalized", + split="validation", + ), + supported_pipeline_types={ + PipelineType.SPEECH_GENERATION, + }, + description="Customer service TTS prompts with vocalized audio for speech generation evaluation.", + ) + ########## STREAMING TRANSCRIPTION ########## DatasetRegistry.register_alias( diff --git a/src/openbench/dataset/dataset_registry.py b/src/openbench/dataset/dataset_registry.py index f73ae17..1b40e3e 100644 --- a/src/openbench/dataset/dataset_registry.py +++ b/src/openbench/dataset/dataset_registry.py @@ -8,6 +8,7 @@ from .dataset_base import BaseDataset, DatasetConfig from .dataset_diarization import DiarizationDataset from .dataset_orchestration import OrchestrationDataset +from .dataset_speech_generation import SpeechGenerationDataset from .dataset_streaming_transcription import StreamingDataset from .dataset_transcription import TranscriptionDataset @@ -139,3 +140,4 @@ def has_alias(cls, alias: str) -> bool: DatasetRegistry.register(PipelineType.ORCHESTRATION, OrchestrationDataset) DatasetRegistry.register(PipelineType.STREAMING_TRANSCRIPTION, StreamingDataset) DatasetRegistry.register(PipelineType.TRANSCRIPTION, TranscriptionDataset) +DatasetRegistry.register(PipelineType.SPEECH_GENERATION, SpeechGenerationDataset) diff --git a/src/openbench/dataset/dataset_speech_generation.py b/src/openbench/dataset/dataset_speech_generation.py new file mode 100644 index 0000000..d9e2748 --- /dev/null +++ b/src/openbench/dataset/dataset_speech_generation.py @@ -0,0 +1,110 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +import numpy as np +from pydantic import Field +from typing_extensions import TypedDict + +from ..pipeline_prediction import Transcript +from .dataset_base import BaseDataset, BaseSample + + +class SpeechGenerationExtraInfo(TypedDict, total=False): + """Extra info for speech generation samples.""" + + language: str + + +class SpeechGenerationRow(TypedDict): + """Expected row structure for speech generation. + + Requires 'text' (the prompt string). No audio needed. + """ + + text: str + + +class SpeechGenerationSample( + BaseSample[Transcript, SpeechGenerationExtraInfo] +): + """Sample for speech generation tasks. + + The reference Transcript is created from the text + prompt. The pipeline generates audio from this text + and transcribes it to compute WER against reference. + """ + + generated_audio_duration: float | None = Field( + default=None, + description=( + "Duration (seconds) of the TTS-generated audio. " + "Set by the pipeline after generation." + ), + ) + + def get_audio_duration(self) -> float: + """Return generated audio duration if available. + + Falls back to the dummy waveform calculation + when the pipeline hasn't set the real duration yet. + """ + if self.generated_audio_duration is not None: + return self.generated_audio_duration + return super().get_audio_duration() + + @property + def text(self) -> str: + """The original text prompt.""" + return self.reference.get_transcript_string() + + +class SpeechGenerationDataset( + BaseDataset[SpeechGenerationSample] +): + """Dataset for speech generation pipelines. + + Expects column: 'text' (the prompt string). + No audio column is required since audio is generated + by the pipeline itself. + """ + + _expected_columns = ["text"] + _sample_class = SpeechGenerationSample + + def _extract_audio_info( + self, row: dict + ) -> tuple[str, np.ndarray, int]: + """Override to provide dummy audio info. + + Speech generation datasets don't have input audio. + We provide a placeholder waveform so the framework + sample structure is satisfied. The pipeline ignores + the waveform entirely. + """ + audio_name = f"sample_{row['idx']}" + # Use audio_name from the row if available + if "audio_name" in row and row["audio_name"]: + audio_name = str(row["audio_name"]) + dummy_waveform = np.zeros(1, dtype=np.float32) + dummy_sample_rate = 16000 + return audio_name, dummy_waveform, dummy_sample_rate + + def prepare_sample( + self, row: SpeechGenerationRow + ) -> tuple[Transcript, SpeechGenerationExtraInfo]: + """Prepare reference from dataset row. + + Splits text prompt into words to create the + reference Transcript. + """ + text = row["text"] + words = text.split() + reference = Transcript.from_words_info( + words=words, + ) + + extra_info: SpeechGenerationExtraInfo = {} + if "language" in row: + extra_info["language"] = row["language"] + + return reference, extra_info diff --git a/src/openbench/metric/word_error_metrics/word_error_metrics.py b/src/openbench/metric/word_error_metrics/word_error_metrics.py index 4cef2a1..24655d5 100644 --- a/src/openbench/metric/word_error_metrics/word_error_metrics.py +++ b/src/openbench/metric/word_error_metrics/word_error_metrics.py @@ -223,6 +223,7 @@ def compute_metric(self, detail: Details) -> float: PipelineType.TRANSCRIPTION, PipelineType.ORCHESTRATION, PipelineType.STREAMING_TRANSCRIPTION, + PipelineType.SPEECH_GENERATION, ), MetricOptions.WER, ) diff --git a/src/openbench/pipeline/__init__.py b/src/openbench/pipeline/__init__.py index 598328b..268d0d2 100644 --- a/src/openbench/pipeline/__init__.py +++ b/src/openbench/pipeline/__init__.py @@ -6,6 +6,7 @@ from .diarization import * from .orchestration import * from .pipeline_registry import PipelineRegistry +from .speech_generation import * from .streaming_transcription import * from .transcription import * diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py index 97287b0..770af12 100644 --- a/src/openbench/pipeline/pipeline_aliases.py +++ b/src/openbench/pipeline/pipeline_aliases.py @@ -25,6 +25,9 @@ WhisperXPipeline, ) from .pipeline_registry import PipelineRegistry +from .speech_generation import ( + WhisperKitSpeechGenerationPipeline, +) from .streaming_transcription import ( AssemblyAIStreamingPipeline, DeepgramStreamingPipeline, @@ -642,6 +645,29 @@ def register_pipeline_aliases() -> None: description="PyannoteAI transcription pipeline (ignores speaker attribution). Uses the precision-2 model with Nvidia Parakeet STT. Requires `PYANNOTE_TOKEN` env var from https://www.pyannote.ai/.", ) + ################# SPEECH GENERATION PIPELINES ################# + + PipelineRegistry.register_alias( + "whisperkit-speech-generation", + WhisperKitSpeechGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "cli_path": os.getenv("WHISPERKIT_CLI_PATH"), + "speaker": "aiden", + "language": "english", + "seed": 10, + "temperature": 0.9, + "top_k": 50, + "max_new_tokens": 245, + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + }, + description="WhisperKit speech generation pipeline. Generates audio from text prompts using whisperkit-cli TTS, " + "then transcribes the generated audio to compute WER against the original prompt. " + "Requires `WHISPERKIT_CLI_PATH` env var pointing to the whisperkit-cli binary.", + ) + ################# STREAMING TRANSCRIPTION PIPELINES ################# PipelineRegistry.register_alias( diff --git a/src/openbench/pipeline/speech_generation/__init__.py b/src/openbench/pipeline/speech_generation/__init__.py new file mode 100644 index 0000000..1a6bce7 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/__init__.py @@ -0,0 +1,12 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +from .common import SpeechGenerationConfig, SpeechGenerationOutput +from .speech_generation_wkp import WhisperKitSpeechGenerationPipeline + + +__all__ = [ + "SpeechGenerationConfig", + "SpeechGenerationOutput", + "WhisperKitSpeechGenerationPipeline", +] diff --git a/src/openbench/pipeline/speech_generation/common.py b/src/openbench/pipeline/speech_generation/common.py new file mode 100644 index 0000000..96b924e --- /dev/null +++ b/src/openbench/pipeline/speech_generation/common.py @@ -0,0 +1,110 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +from pydantic import Field + +from ...pipeline_prediction import Transcript +from ..base import PipelineConfig, PipelineOutput + + +class SpeechGenerationConfig(PipelineConfig): + """Base config for speech generation pipelines.""" + + cli_path: str = Field( + ..., + description=( + "Path to the whisperkit-cli binary " + "(used for TTS generation)." + ), + ) + + # TTS parameters + speaker: str = Field( + default="aiden", + description="Speaker voice for TTS generation.", + ) + language: str = Field( + default="english", + description="Language for TTS generation.", + ) + seed: int | None = Field( + default=None, + description="Random seed for reproducible output.", + ) + temperature: float = Field( + default=0.9, + description="Sampling temperature for TTS.", + ) + top_k: int = Field( + default=50, + description="Top-k sampling for TTS.", + ) + max_new_tokens: int = Field( + default=245, + description="Max RVQ frames to generate.", + ) + models_path: str | None = Field( + default=None, + description="Local model directory for TTS.", + ) + model_repo: str | None = Field( + default=None, + description="HF repo for TTS model download.", + ) + version_dir: str | None = Field( + default=None, + description="TTS model version directory.", + ) + tokenizer: str | None = Field( + default=None, + description="HF tokenizer repo or local path.", + ) + + # Transcription parameters + transcription_cli_path: str | None = Field( + default=None, + description=( + "Path to CLI for transcription. " + "Defaults to cli_path if not set." + ), + ) + transcription_repo_id: str | None = Field( + default=None, + description=( + "HuggingFace repo ID for transcription " + "model (e.g. argmaxinc/parakeetkit-pro)." + ), + ) + transcription_model_variant: str | None = Field( + default=None, + description=( + "Model variant folder within the repo " + "(e.g. nvidia_parakeet-v2_476MB)." + ), + ) + transcription_model_path: str | None = Field( + default=None, + description=( + "Local path to ASR model dir. " + "Overrides repo_id/model_variant." + ), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + +class SpeechGenerationOutput(PipelineOutput[Transcript]): + """Output for speech generation pipelines. + + The prediction is a Transcript of the generated audio + (obtained by transcribing the TTS output). WER is + computed against the original text prompt. + """ + + pass diff --git a/src/openbench/pipeline/speech_generation/speech_generation_wkp.py b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py new file mode 100644 index 0000000..d7191eb --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py @@ -0,0 +1,316 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using WhisperKit CLI. + +Generates TTS audio from text prompts, then transcribes +the generated audio back to text for WER evaluation +against the original prompt. +""" + +import json +import subprocess +import time +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import BaseModel, Field + +from ...dataset.dataset_base import BaseSample +from ...dataset.dataset_speech_generation import ( + SpeechGenerationSample, +) +from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, +) +from ...pipeline_prediction import Transcript +from ..base import ( + Pipeline, + PipelineOutput, + PipelineType, + register_pipeline, +) +from .common import SpeechGenerationConfig, SpeechGenerationOutput + + +logger = get_logger(__name__) + +TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + + +class SpeechGenerationInput(BaseModel): + """Input for the speech generation pipeline.""" + + text: str = Field( + ..., + description="Text prompt to generate speech from.", + ) + audio_name: str = Field( + ..., + description=( + "Unique identifier for this sample " + "(used for temp file naming)." + ), + ) + + +@register_pipeline +class WhisperKitSpeechGenerationPipeline(Pipeline): + """Speech generation pipeline using WhisperKit CLI. + + This pipeline: + 1. Generates audio from text via whisperkit-cli tts + 2. Transcribes audio via WhisperKitPro engine + 3. Returns transcription as Transcript for WER eval + 4. Cleans up temporary audio and report files + """ + + _config_class = SpeechGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[SpeechGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + # Build the WhisperKitPro engine for transcription + # (downloads model once, reuses for all samples) + transcription_engine = self._build_transcription_engine() + + def generate_and_transcribe( + input: SpeechGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir( + parents=True, exist_ok=True + ) + + audio_path = ( + TEMP_TTS_AUDIO_DIR + / f"{input.audio_name}.wav" + ) + + # -- Step 1: Generate audio via TTS -- + tts_cmd = [ + config.cli_path, + "tts", + "--text", input.text, + "--speaker", config.speaker, + "--language", config.language, + "--output-path", str(audio_path), + "--temperature", str(config.temperature), + "--top-k", str(config.top_k), + "--max-new-tokens", + str(config.max_new_tokens), + ] + + if config.seed is not None: + tts_cmd.extend( + ["--seed", str(config.seed)] + ) + if config.models_path is not None: + tts_cmd.extend( + ["--models-path", config.models_path] + ) + if config.model_repo is not None: + tts_cmd.extend( + ["--model-repo", config.model_repo] + ) + if config.version_dir is not None: + tts_cmd.extend( + ["--version-dir", config.version_dir] + ) + if config.tokenizer is not None: + tts_cmd.extend( + ["--tokenizer", config.tokenizer] + ) + + logger.debug( + f"Running TTS: {' '.join(tts_cmd)}" + ) + + tts_result = subprocess.run( + tts_cmd, capture_output=True, text=True + ) + + if tts_result.returncode != 0: + raise RuntimeError( + "whisperkit-cli tts failed " + f"(exit {tts_result.returncode}):\n" + f" stdout: " + f"{tts_result.stdout[:500]}\n" + f" stderr: " + f"{tts_result.stderr[:500]}" + ) + + if not audio_path.exists(): + raise RuntimeError( + "TTS completed but audio file " + f"not found at {audio_path}" + ) + + logger.info( + f"Generated TTS audio: {audio_path}" + ) + + # -- Step 2: Read audio duration before + # transcription (engine may delete file) -- + try: + import soundfile as sf + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = ( + info.duration + ) + except Exception as e: + logger.warning( + f"WAV duration read failed: {e}" + ) + pipeline_ref._last_generated_duration = ( + None + ) + + # -- Step 3: Transcribe via WhisperKitPro -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=False, + ) + engine_output = transcription_engine( + engine_input + ) + + # -- Step 4: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + with json_path.open("r") as f: + data = json.load(f) + all_words, all_starts, all_ends = ( + [], [], [], + ) + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append( + w["start"] + ) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=( + all_starts + if all_starts else None + ), + end=( + all_ends + if all_ends else None + ), + ) + # Clean up report files + json_path.unlink(missing_ok=True) + srt_path = engine_output.srt_report_path + if srt_path: + srt_path.unlink(missing_ok=True) + else: + raise RuntimeError( + "Transcription report not found " + f"at {json_path}" + ) + + logger.info( + "Transcription: " + + transcript.get_transcript_string( + )[:100] + + "..." + ) + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription. + + Uses the same engine as the dedicated + WhisperKitPro transcription pipelines, which + handles model download, caching, and CLI args. + """ + config = self.config + cli_path = ( + config.transcription_cli_path or config.cli_path + ) + + import coremltools as ct + engine_config = WhisperKitProConfig( + repo_id=config.transcription_repo_id, + model_variant=config.transcription_model_variant, + model_dir=config.transcription_model_path, + word_timestamps=config.transcription_word_timestamps, + chunking_strategy=config.transcription_chunking_strategy, + audio_encoder_compute_units=ct.ComputeUnit.CPU_AND_NE, + text_decoder_compute_units=ct.ComputeUnit.CPU_AND_NE, + ) + + return WhisperKitPro( + cli_path=cli_path, + transcription_config=engine_config, + ) + + def __call__( + self, input_sample: BaseSample + ) -> PipelineOutput: + """Run pipeline and set generated audio duration. + + Overrides base __call__ to propagate the real + TTS audio duration back onto the sample so the + runner reports accurate audio_duration and + speed_factor values. + """ + self._last_generated_duration: float | None = None + parsed_input = self.parse_input(input_sample) + start_time = time.perf_counter() + output = self.pipeline(parsed_input) + end_time = time.perf_counter() + prediction_time = end_time - start_time + parsed_output = self.parse_output(output) + if parsed_output.prediction_time is None: + parsed_output.prediction_time = prediction_time + + # Propagate generated audio duration to sample + dur = self._last_generated_duration + logger.info( + f"Generated audio duration: {dur}s" + ) + if ( + isinstance(input_sample, SpeechGenerationSample) + and dur is not None + ): + input_sample.generated_audio_duration = dur + logger.info( + "Set sample duration to " + f"{input_sample.generated_audio_duration}s" + ) + + return parsed_output + + def parse_input( + self, input_sample: SpeechGenerationSample + ) -> SpeechGenerationInput: + """Extract text prompt from the sample.""" + text = ( + input_sample.reference.get_transcript_string() + ) + return SpeechGenerationInput( + text=text, + audio_name=input_sample.audio_name, + ) + + def parse_output( + self, output: Transcript + ) -> SpeechGenerationOutput: + """Wrap transcription into output.""" + return SpeechGenerationOutput(prediction=output) diff --git a/src/openbench/runner/benchmark.py b/src/openbench/runner/benchmark.py index 439c15c..b195862 100644 --- a/src/openbench/runner/benchmark.py +++ b/src/openbench/runner/benchmark.py @@ -33,6 +33,7 @@ PipelineType.TRANSCRIPTION: TranscriptionSampleResult, PipelineType.ORCHESTRATION: TranscriptionSampleResult, PipelineType.STREAMING_TRANSCRIPTION: TranscriptionSampleResult, + PipelineType.SPEECH_GENERATION: TranscriptionSampleResult, } @@ -64,6 +65,7 @@ def __init__(self, config: BenchmarkConfig, pipelines: list[Pipeline]): PipelineType.TRANSCRIPTION: TranscriptionWandbLogger, PipelineType.ORCHESTRATION: TranscriptionWandbLogger, PipelineType.STREAMING_TRANSCRIPTION: TranscriptionWandbLogger, + PipelineType.SPEECH_GENERATION: TranscriptionWandbLogger, } def _get_metrics(self, pipeline: Pipeline) -> dict[str, BaseMetric]: diff --git a/src/openbench/types.py b/src/openbench/types.py index b4bbeaa..b82df22 100644 --- a/src/openbench/types.py +++ b/src/openbench/types.py @@ -12,6 +12,7 @@ class PipelineType(Enum): TRANSCRIPTION = "transcription" ORCHESTRATION = "orchestration" STREAMING_TRANSCRIPTION = "streaming_transcription" + SPEECH_GENERATION = "speech_generation" # All prediction classes that we output should conform to this From 7d3ab4d43ebbf4a1fd417e0bd9d6c11aebf1ad6b Mon Sep 17 00:00:00 2001 From: dberkin1 Date: Fri, 13 Feb 2026 20:58:09 +0300 Subject: [PATCH 2/5] refactor --- src/openbench/cli/commands/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/openbench/cli/commands/evaluate.py b/src/openbench/cli/commands/evaluate.py index c35eb01..d3f172c 100644 --- a/src/openbench/cli/commands/evaluate.py +++ b/src/openbench/cli/commands/evaluate.py @@ -220,7 +220,7 @@ def run_alias_mode( key, value = item.split("=", 1) pipeline_config_override[key] = value if verbose: - typer.echo(f"✅ Config override: {key}={value}") + typer.echo(f"Config override: {key}={value}") pipeline = PipelineRegistry.create_pipeline(pipeline_name, config=pipeline_config_override) From d41818130e97dc24f2ac206bcc8a5b737007a52a Mon Sep 17 00:00:00 2001 From: dberkin1 Date: Fri, 13 Feb 2026 21:05:08 +0300 Subject: [PATCH 3/5] reformatting --- src/openbench/cli/commands/evaluate.py | 5 +- .../dataset/dataset_speech_generation.py | 21 +-- .../pipeline/speech_generation/common.py | 25 +-- .../speech_generation_wkp.py | 151 ++++++------------ 4 files changed, 57 insertions(+), 145 deletions(-) diff --git a/src/openbench/cli/commands/evaluate.py b/src/openbench/cli/commands/evaluate.py index d3f172c..7470c20 100644 --- a/src/openbench/cli/commands/evaluate.py +++ b/src/openbench/cli/commands/evaluate.py @@ -213,10 +213,7 @@ def run_alias_mode( if pipeline_config: for item in pipeline_config: if "=" not in item: - raise typer.BadParameter( - f"Invalid --pipeline-config format: '{item}'. " - f"Expected key=value" - ) + raise typer.BadParameter(f"Invalid --pipeline-config format: '{item}'. Expected key=value") key, value = item.split("=", 1) pipeline_config_override[key] = value if verbose: diff --git a/src/openbench/dataset/dataset_speech_generation.py b/src/openbench/dataset/dataset_speech_generation.py index d9e2748..61d7acd 100644 --- a/src/openbench/dataset/dataset_speech_generation.py +++ b/src/openbench/dataset/dataset_speech_generation.py @@ -24,9 +24,7 @@ class SpeechGenerationRow(TypedDict): text: str -class SpeechGenerationSample( - BaseSample[Transcript, SpeechGenerationExtraInfo] -): +class SpeechGenerationSample(BaseSample[Transcript, SpeechGenerationExtraInfo]): """Sample for speech generation tasks. The reference Transcript is created from the text @@ -36,10 +34,7 @@ class SpeechGenerationSample( generated_audio_duration: float | None = Field( default=None, - description=( - "Duration (seconds) of the TTS-generated audio. " - "Set by the pipeline after generation." - ), + description=("Duration (seconds) of the TTS-generated audio. Set by the pipeline after generation."), ) def get_audio_duration(self) -> float: @@ -58,9 +53,7 @@ def text(self) -> str: return self.reference.get_transcript_string() -class SpeechGenerationDataset( - BaseDataset[SpeechGenerationSample] -): +class SpeechGenerationDataset(BaseDataset[SpeechGenerationSample]): """Dataset for speech generation pipelines. Expects column: 'text' (the prompt string). @@ -71,9 +64,7 @@ class SpeechGenerationDataset( _expected_columns = ["text"] _sample_class = SpeechGenerationSample - def _extract_audio_info( - self, row: dict - ) -> tuple[str, np.ndarray, int]: + def _extract_audio_info(self, row: dict) -> tuple[str, np.ndarray, int]: """Override to provide dummy audio info. Speech generation datasets don't have input audio. @@ -89,9 +80,7 @@ def _extract_audio_info( dummy_sample_rate = 16000 return audio_name, dummy_waveform, dummy_sample_rate - def prepare_sample( - self, row: SpeechGenerationRow - ) -> tuple[Transcript, SpeechGenerationExtraInfo]: + def prepare_sample(self, row: SpeechGenerationRow) -> tuple[Transcript, SpeechGenerationExtraInfo]: """Prepare reference from dataset row. Splits text prompt into words to create the diff --git a/src/openbench/pipeline/speech_generation/common.py b/src/openbench/pipeline/speech_generation/common.py index 96b924e..fd8b7b4 100644 --- a/src/openbench/pipeline/speech_generation/common.py +++ b/src/openbench/pipeline/speech_generation/common.py @@ -12,10 +12,7 @@ class SpeechGenerationConfig(PipelineConfig): cli_path: str = Field( ..., - description=( - "Path to the whisperkit-cli binary " - "(used for TTS generation)." - ), + description=("Path to the whisperkit-cli binary (used for TTS generation)."), ) # TTS parameters @@ -63,31 +60,19 @@ class SpeechGenerationConfig(PipelineConfig): # Transcription parameters transcription_cli_path: str | None = Field( default=None, - description=( - "Path to CLI for transcription. " - "Defaults to cli_path if not set." - ), + description=("Path to CLI for transcription. Defaults to cli_path if not set."), ) transcription_repo_id: str | None = Field( default=None, - description=( - "HuggingFace repo ID for transcription " - "model (e.g. argmaxinc/parakeetkit-pro)." - ), + description=("HuggingFace repo ID for transcription model (e.g. argmaxinc/parakeetkit-pro)."), ) transcription_model_variant: str | None = Field( default=None, - description=( - "Model variant folder within the repo " - "(e.g. nvidia_parakeet-v2_476MB)." - ), + description=("Model variant folder within the repo (e.g. nvidia_parakeet-v2_476MB)."), ) transcription_model_path: str | None = Field( default=None, - description=( - "Local path to ASR model dir. " - "Overrides repo_id/model_variant." - ), + description=("Local path to ASR model dir. Overrides repo_id/model_variant."), ) transcription_word_timestamps: bool = Field( default=True, diff --git a/src/openbench/pipeline/speech_generation/speech_generation_wkp.py b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py index d7191eb..6476d77 100644 --- a/src/openbench/pipeline/speech_generation/speech_generation_wkp.py +++ b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py @@ -51,10 +51,7 @@ class SpeechGenerationInput(BaseModel): ) audio_name: str = Field( ..., - description=( - "Unique identifier for this sample " - "(used for temp file naming)." - ), + description=("Unique identifier for this sample (used for temp file naming)."), ) @@ -85,57 +82,44 @@ def build_pipeline( def generate_and_transcribe( input: SpeechGenerationInput, ) -> Transcript: - TEMP_TTS_AUDIO_DIR.mkdir( - parents=True, exist_ok=True - ) + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) - audio_path = ( - TEMP_TTS_AUDIO_DIR - / f"{input.audio_name}.wav" - ) + audio_path = TEMP_TTS_AUDIO_DIR / f"{input.audio_name}.wav" # -- Step 1: Generate audio via TTS -- tts_cmd = [ config.cli_path, "tts", - "--text", input.text, - "--speaker", config.speaker, - "--language", config.language, - "--output-path", str(audio_path), - "--temperature", str(config.temperature), - "--top-k", str(config.top_k), + "--text", + input.text, + "--speaker", + config.speaker, + "--language", + config.language, + "--output-path", + str(audio_path), + "--temperature", + str(config.temperature), + "--top-k", + str(config.top_k), "--max-new-tokens", str(config.max_new_tokens), ] if config.seed is not None: - tts_cmd.extend( - ["--seed", str(config.seed)] - ) + tts_cmd.extend(["--seed", str(config.seed)]) if config.models_path is not None: - tts_cmd.extend( - ["--models-path", config.models_path] - ) + tts_cmd.extend(["--models-path", config.models_path]) if config.model_repo is not None: - tts_cmd.extend( - ["--model-repo", config.model_repo] - ) + tts_cmd.extend(["--model-repo", config.model_repo]) if config.version_dir is not None: - tts_cmd.extend( - ["--version-dir", config.version_dir] - ) + tts_cmd.extend(["--version-dir", config.version_dir]) if config.tokenizer is not None: - tts_cmd.extend( - ["--tokenizer", config.tokenizer] - ) + tts_cmd.extend(["--tokenizer", config.tokenizer]) - logger.debug( - f"Running TTS: {' '.join(tts_cmd)}" - ) + logger.debug(f"Running TTS: {' '.join(tts_cmd)}") - tts_result = subprocess.run( - tts_cmd, capture_output=True, text=True - ) + tts_result = subprocess.run(tts_cmd, capture_output=True, text=True) if tts_result.returncode != 0: raise RuntimeError( @@ -148,39 +132,27 @@ def generate_and_transcribe( ) if not audio_path.exists(): - raise RuntimeError( - "TTS completed but audio file " - f"not found at {audio_path}" - ) + raise RuntimeError(f"TTS completed but audio file not found at {audio_path}") - logger.info( - f"Generated TTS audio: {audio_path}" - ) + logger.info(f"Generated TTS audio: {audio_path}") # -- Step 2: Read audio duration before # transcription (engine may delete file) -- try: import soundfile as sf + info = sf.info(str(audio_path)) - pipeline_ref._last_generated_duration = ( - info.duration - ) + pipeline_ref._last_generated_duration = info.duration except Exception as e: - logger.warning( - f"WAV duration read failed: {e}" - ) - pipeline_ref._last_generated_duration = ( - None - ) + logger.warning(f"WAV duration read failed: {e}") + pipeline_ref._last_generated_duration = None # -- Step 3: Transcribe via WhisperKitPro -- engine_input = WhisperKitProInput( audio_path=audio_path, keep_audio=False, ) - engine_output = transcription_engine( - engine_input - ) + engine_output = transcription_engine(engine_input) # -- Step 4: Parse transcription report -- json_path = engine_output.json_report_path @@ -188,27 +160,21 @@ def generate_and_transcribe( with json_path.open("r") as f: data = json.load(f) all_words, all_starts, all_ends = ( - [], [], [], + [], + [], + [], ) for seg in data.get("segments", []): for w in seg.get("words", []): all_words.append(w["word"]) if "start" in w: - all_starts.append( - w["start"] - ) + all_starts.append(w["start"]) if "end" in w: all_ends.append(w["end"]) transcript = Transcript.from_words_info( words=all_words, - start=( - all_starts - if all_starts else None - ), - end=( - all_ends - if all_ends else None - ), + start=(all_starts if all_starts else None), + end=(all_ends if all_ends else None), ) # Clean up report files json_path.unlink(missing_ok=True) @@ -216,17 +182,9 @@ def generate_and_transcribe( if srt_path: srt_path.unlink(missing_ok=True) else: - raise RuntimeError( - "Transcription report not found " - f"at {json_path}" - ) + raise RuntimeError(f"Transcription report not found at {json_path}") - logger.info( - "Transcription: " - + transcript.get_transcript_string( - )[:100] - + "..." - ) + logger.info("Transcription: " + transcript.get_transcript_string()[:100] + "...") return transcript @@ -240,11 +198,10 @@ def _build_transcription_engine(self) -> WhisperKitPro: handles model download, caching, and CLI args. """ config = self.config - cli_path = ( - config.transcription_cli_path or config.cli_path - ) + cli_path = config.transcription_cli_path or config.cli_path import coremltools as ct + engine_config = WhisperKitProConfig( repo_id=config.transcription_repo_id, model_variant=config.transcription_model_variant, @@ -260,9 +217,7 @@ def _build_transcription_engine(self) -> WhisperKitPro: transcription_config=engine_config, ) - def __call__( - self, input_sample: BaseSample - ) -> PipelineOutput: + def __call__(self, input_sample: BaseSample) -> PipelineOutput: """Run pipeline and set generated audio duration. Overrides base __call__ to propagate the real @@ -282,35 +237,21 @@ def __call__( # Propagate generated audio duration to sample dur = self._last_generated_duration - logger.info( - f"Generated audio duration: {dur}s" - ) - if ( - isinstance(input_sample, SpeechGenerationSample) - and dur is not None - ): + logger.info(f"Generated audio duration: {dur}s") + if isinstance(input_sample, SpeechGenerationSample) and dur is not None: input_sample.generated_audio_duration = dur - logger.info( - "Set sample duration to " - f"{input_sample.generated_audio_duration}s" - ) + logger.info(f"Set sample duration to {input_sample.generated_audio_duration}s") return parsed_output - def parse_input( - self, input_sample: SpeechGenerationSample - ) -> SpeechGenerationInput: + def parse_input(self, input_sample: SpeechGenerationSample) -> SpeechGenerationInput: """Extract text prompt from the sample.""" - text = ( - input_sample.reference.get_transcript_string() - ) + text = input_sample.reference.get_transcript_string() return SpeechGenerationInput( text=text, audio_name=input_sample.audio_name, ) - def parse_output( - self, output: Transcript - ) -> SpeechGenerationOutput: + def parse_output(self, output: Transcript) -> SpeechGenerationOutput: """Wrap transcription into output.""" return SpeechGenerationOutput(prediction=output) From ef86575f157e0e8e5494d51a518b4cbcb23c20fd Mon Sep 17 00:00:00 2001 From: dberkin1 Date: Fri, 13 Feb 2026 21:12:03 +0300 Subject: [PATCH 4/5] reformat --- .../pipeline/speech_generation/speech_generation_wkp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/openbench/pipeline/speech_generation/speech_generation_wkp.py b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py index 6476d77..9c93440 100644 --- a/src/openbench/pipeline/speech_generation/speech_generation_wkp.py +++ b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py @@ -237,10 +237,10 @@ def __call__(self, input_sample: BaseSample) -> PipelineOutput: # Propagate generated audio duration to sample dur = self._last_generated_duration - logger.info(f"Generated audio duration: {dur}s") + logger.debug(f"Generated audio duration: {dur}s") if isinstance(input_sample, SpeechGenerationSample) and dur is not None: input_sample.generated_audio_duration = dur - logger.info(f"Set sample duration to {input_sample.generated_audio_duration}s") + logger.debug(f"Set sample duration to {input_sample.generated_audio_duration}s") return parsed_output From 82f40c3504b5eefd5a618ae1f6c6a56ed488d846 Mon Sep 17 00:00:00 2001 From: Berkin Durmus Date: Wed, 11 Mar 2026 17:00:07 +0300 Subject: [PATCH 5/5] Add ElevenLabs speech generation pipeline (#8) Adds a new TTS evaluation pipeline using ElevenLabs' API to generate audio from text prompts, then transcribes with WhisperKitPro for WER. Made-with: Cursor Co-authored-by: dberkin1 --- src/openbench/pipeline/pipeline_aliases.py | 19 ++ .../pipeline/speech_generation/__init__.py | 6 + .../speech_generation_elevenlabs.py | 303 ++++++++++++++++++ 3 files changed, 328 insertions(+) create mode 100644 src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py index 770af12..b925b89 100644 --- a/src/openbench/pipeline/pipeline_aliases.py +++ b/src/openbench/pipeline/pipeline_aliases.py @@ -26,6 +26,7 @@ ) from .pipeline_registry import PipelineRegistry from .speech_generation import ( + ElevenLabsSpeechGenerationPipeline, WhisperKitSpeechGenerationPipeline, ) from .streaming_transcription import ( @@ -668,6 +669,24 @@ def register_pipeline_aliases() -> None: "Requires `WHISPERKIT_CLI_PATH` env var pointing to the whisperkit-cli binary.", ) + PipelineRegistry.register_alias( + "elevenlabs-speech-generation", + ElevenLabsSpeechGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "voice_id": "JBFqnCBsd6RMkjVDRZzb", + "model_id": "eleven_multilingual_v2", + "output_format": "mp3_44100_128", + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + "keep_generated_audio": False, + }, + description="ElevenLabs speech generation pipeline. Generates audio from text prompts using ElevenLabs TTS API, " + "then transcribes the generated audio to compute WER against the original prompt. " + "Requires `ELEVENLABS_API_KEY` and `WHISPERKITPRO_CLI_PATH` env vars.", + ) + ################# STREAMING TRANSCRIPTION PIPELINES ################# PipelineRegistry.register_alias( diff --git a/src/openbench/pipeline/speech_generation/__init__.py b/src/openbench/pipeline/speech_generation/__init__.py index 1a6bce7..cad4e25 100644 --- a/src/openbench/pipeline/speech_generation/__init__.py +++ b/src/openbench/pipeline/speech_generation/__init__.py @@ -2,10 +2,16 @@ # Copyright (C) 2025 Argmax, Inc. All Rights Reserved. from .common import SpeechGenerationConfig, SpeechGenerationOutput +from .speech_generation_elevenlabs import ( + ElevenLabsSpeechGenerationConfig, + ElevenLabsSpeechGenerationPipeline, +) from .speech_generation_wkp import WhisperKitSpeechGenerationPipeline __all__ = [ + "ElevenLabsSpeechGenerationConfig", + "ElevenLabsSpeechGenerationPipeline", "SpeechGenerationConfig", "SpeechGenerationOutput", "WhisperKitSpeechGenerationPipeline", diff --git a/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py b/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py new file mode 100644 index 0000000..1e1af94 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py @@ -0,0 +1,303 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using ElevenLabs TTS API. + +Generates TTS audio from text prompts via ElevenLabs, +then transcribes the generated audio back to text using +WhisperKitPro (Parakeet) for WER evaluation against the +original prompt. +""" + +import os +import time +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import BaseModel, Field + +from ...dataset.dataset_base import BaseSample +from ...dataset.dataset_speech_generation import SpeechGenerationSample +from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, +) +from ...pipeline_prediction import Transcript +from ..base import ( + Pipeline, + PipelineConfig, + PipelineOutput, + PipelineType, + register_pipeline, +) +from .common import SpeechGenerationOutput + +logger = get_logger(__name__) + +TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + + +class ElevenLabsSpeechGenerationConfig(PipelineConfig): + """Config for the ElevenLabs speech generation pipeline.""" + + # ElevenLabs TTS parameters + api_key: str | None = Field( + default=None, + description=( + "ElevenLabs API key. Falls back to " + "ELEVENLABS_API_KEY env var." + ), + ) + voice_id: str = Field( + default="JBFqnCBsd6RMkjVDRZzb", + description="ElevenLabs voice ID.", + ) + model_id: str = Field( + default="eleven_v3", #"eleven_multilingual_v2", + description="ElevenLabs model ID.", + ) + output_format: str = Field( + default="mp3_44100_128", + description="ElevenLabs output audio format.", + ) + + # Transcription parameters (WhisperKitPro / Parakeet) + transcription_cli_path: str = Field( + ..., + description=( + "Path to the whisperkit-cli binary " + "used for transcription." + ), + ) + transcription_repo_id: str | None = Field( + default=None, + description=( + "HuggingFace repo ID for transcription " + "model (e.g. argmaxinc/parakeetkit-pro)." + ), + ) + transcription_model_variant: str | None = Field( + default=None, + description=( + "Model variant folder within the repo " + "(e.g. nvidia_parakeet-v2_476MB)." + ), + ) + transcription_model_path: str | None = Field( + default=None, + description=( + "Local path to ASR model dir. " + "Overrides repo_id/model_variant." + ), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + keep_generated_audio: bool = Field( + default=False, + description=( + "If True, keep the generated TTS audio " + "files instead of deleting them." + ), + ) + + +class ElevenLabsSpeechGenerationInput(BaseModel): + """Input for the ElevenLabs speech generation pipeline.""" + + text: str = Field( + ..., + description="Text prompt to generate speech from.", + ) + audio_name: str = Field( + ..., + description=( + "Unique identifier for this sample " + "(used for temp file naming)." + ), + ) + + +@register_pipeline +class ElevenLabsSpeechGenerationPipeline(Pipeline): + """Speech generation pipeline using ElevenLabs TTS API. + + This pipeline: + 1. Generates audio from text via ElevenLabs text-to-speech API + 2. Transcribes audio via WhisperKitPro engine (Parakeet) + 3. Returns transcription as Transcript for WER eval + 4. Cleans up temporary audio and report files + """ + + _config_class = ElevenLabsSpeechGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[ElevenLabsSpeechGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + transcription_engine = self._build_transcription_engine() + + api_key = config.api_key or os.getenv("ELEVENLABS_API_KEY") + if not api_key: + raise ValueError( + "ElevenLabs API key must be provided " + "via config or ELEVENLABS_API_KEY env var." + ) + + from elevenlabs.client import ElevenLabs + + client = ElevenLabs(api_key=api_key) + + def generate_and_transcribe( + input: ElevenLabsSpeechGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) + + ext = config.output_format.split("_")[0] + audio_path = TEMP_TTS_AUDIO_DIR / f"{input.audio_name}.{ext}" + + # -- Step 1: Generate audio via ElevenLabs API -- + audio_iter = client.text_to_speech.convert( + text=input.text, + voice_id=config.voice_id, + model_id=config.model_id, + output_format=config.output_format, + ) + + with open(audio_path, "wb") as f: + for chunk in audio_iter: + f.write(chunk) + + if not audio_path.exists() or audio_path.stat().st_size == 0: + raise RuntimeError( + "ElevenLabs TTS failed: audio file " + f"missing or empty at {audio_path}" + ) + + logger.info( + f"Generated ElevenLabs TTS audio: {audio_path}" + ) + + # -- Step 2: Read audio duration -- + try: + import soundfile as sf + + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = info.duration + except Exception as e: + logger.warning(f"Audio duration read failed: {e}") + pipeline_ref._last_generated_duration = None + + # -- Step 3: Transcribe via WhisperKitPro (Parakeet) -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=config.keep_generated_audio, + ) + engine_output = transcription_engine(engine_input) + + # -- Step 4: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + import json + + with json_path.open("r") as f: + data = json.load(f) + all_words, all_starts, all_ends = [], [], [] + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append(w["start"]) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=all_starts if all_starts else None, + end=all_ends if all_ends else None, + ) + json_path.unlink(missing_ok=True) + srt_path = engine_output.srt_report_path + if srt_path: + srt_path.unlink(missing_ok=True) + else: + raise RuntimeError( + "Transcription report not found " + f"at {json_path}" + ) + + text_preview = transcript.get_transcript_string()[:100] + logger.info(f"Transcription: {text_preview}...") + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription (Parakeet).""" + config = self.config + + import coremltools as ct + + compute = ct.ComputeUnit.CPU_AND_NE + engine_config = WhisperKitProConfig( + repo_id=config.transcription_repo_id, + model_variant=config.transcription_model_variant, + model_dir=config.transcription_model_path, + word_timestamps=config.transcription_word_timestamps, + chunking_strategy=config.transcription_chunking_strategy, + audio_encoder_compute_units=compute, + text_decoder_compute_units=compute, + ) + + return WhisperKitPro( + cli_path=config.transcription_cli_path, + transcription_config=engine_config, + ) + + def __call__(self, input_sample: BaseSample) -> PipelineOutput: + """Run pipeline and set generated audio duration.""" + self._last_generated_duration: float | None = None + parsed_input = self.parse_input(input_sample) + start_time = time.perf_counter() + output = self.pipeline(parsed_input) + end_time = time.perf_counter() + prediction_time = end_time - start_time + parsed_output = self.parse_output(output) + if parsed_output.prediction_time is None: + parsed_output.prediction_time = prediction_time + + dur = self._last_generated_duration + logger.debug(f"Generated audio duration: {dur}s") + is_sg = isinstance(input_sample, SpeechGenerationSample) + if is_sg and dur is not None: + input_sample.generated_audio_duration = dur + dur_val = input_sample.generated_audio_duration + logger.debug(f"Set sample duration to {dur_val}s") + + return parsed_output + + def parse_input( + self, input_sample: SpeechGenerationSample + ) -> ElevenLabsSpeechGenerationInput: + """Extract text prompt from the sample.""" + text = input_sample.reference.get_transcript_string() + return ElevenLabsSpeechGenerationInput( + text=text, + audio_name=input_sample.audio_name, + ) + + def parse_output(self, output: Transcript) -> SpeechGenerationOutput: + """Wrap transcription into output.""" + return SpeechGenerationOutput(prediction=output)