diff --git a/src/openbench/cli/commands/evaluate.py b/src/openbench/cli/commands/evaluate.py index 1b4f5f7..7470c20 100644 --- a/src/openbench/cli/commands/evaluate.py +++ b/src/openbench/cli/commands/evaluate.py @@ -176,6 +176,7 @@ def run_alias_mode( wandb_tags: list[str] | None, use_keywords: bool | None, force_language: bool, + pipeline_config: list[str] | None, verbose: bool, ) -> BenchmarkResult: """Run evaluation using pipeline and dataset aliases.""" @@ -208,6 +209,16 @@ def run_alias_mode( if verbose: typer.echo("✅ Force language: enabled") + # Handle generic pipeline config overrides (key=value pairs) + if pipeline_config: + for item in pipeline_config: + if "=" not in item: + raise typer.BadParameter(f"Invalid --pipeline-config format: '{item}'. Expected key=value") + key, value = item.split("=", 1) + pipeline_config_override[key] = value + if verbose: + typer.echo(f"Config override: {key}={value}") + pipeline = PipelineRegistry.create_pipeline(pipeline_name, config=pipeline_config_override) ######### Build Benchmark Config ######### @@ -345,6 +356,12 @@ def evaluate( "--force-language", help="Force language hinting for compatible pipelines", ), + pipeline_config: list[str] | None = typer.Option( + None, + "--pipeline-config", + "-pc", + help="Override pipeline config values as key=value pairs (e.g. --pipeline-config speaker=serena)", + ), verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output"), ) -> None: """Run evaluation benchmarks. 
@@ -406,6 +423,7 @@ def evaluate( wandb_tags=wandb_tags, use_keywords=use_keywords, force_language=force_language, + pipeline_config=pipeline_config, verbose=verbose, ) display_result(result) diff --git a/src/openbench/dataset/__init__.py b/src/openbench/dataset/__init__.py index 8c90fae..540942e 100644 --- a/src/openbench/dataset/__init__.py +++ b/src/openbench/dataset/__init__.py @@ -6,6 +6,7 @@ from .dataset_diarization import DiarizationDataset, DiarizationSample from .dataset_orchestration import OrchestrationDataset, OrchestrationSample from .dataset_registry import DatasetRegistry +from .dataset_speech_generation import SpeechGenerationDataset, SpeechGenerationSample from .dataset_streaming_transcription import StreamingDataset, StreamingSample from .dataset_transcription import TranscriptionDataset, TranscriptionSample @@ -24,11 +25,13 @@ "TranscriptionDataset", "StreamingDataset", "OrchestrationDataset", + "SpeechGenerationDataset", # Sample types "DiarizationSample", "TranscriptionSample", "StreamingSample", "OrchestrationSample", + "SpeechGenerationSample", # Registry "DatasetRegistry", ] diff --git a/src/openbench/dataset/dataset_aliases.py b/src/openbench/dataset/dataset_aliases.py index c37af42..9a7ad1b 100644 --- a/src/openbench/dataset/dataset_aliases.py +++ b/src/openbench/dataset/dataset_aliases.py @@ -554,6 +554,20 @@ def register_dataset_aliases() -> None: description="Common Voice dataset for transcription evaluation with up to 400 samples per language this subset contains only russian", ) + ########## SPEECH GENERATION ########## + + DatasetRegistry.register_alias( + "customer-service-tts-prompts-vocalized", + DatasetConfig( + dataset_id="argmaxinc/customer-service-tts-prompts-vocalized", + split="validation", + ), + supported_pipeline_types={ + PipelineType.SPEECH_GENERATION, + }, + description="Customer service TTS prompts with vocalized audio for speech generation evaluation.", + ) + ########## STREAMING TRANSCRIPTION ########## 
DatasetRegistry.register_alias( diff --git a/src/openbench/dataset/dataset_registry.py b/src/openbench/dataset/dataset_registry.py index f73ae17..1b40e3e 100644 --- a/src/openbench/dataset/dataset_registry.py +++ b/src/openbench/dataset/dataset_registry.py @@ -8,6 +8,7 @@ from .dataset_base import BaseDataset, DatasetConfig from .dataset_diarization import DiarizationDataset from .dataset_orchestration import OrchestrationDataset +from .dataset_speech_generation import SpeechGenerationDataset from .dataset_streaming_transcription import StreamingDataset from .dataset_transcription import TranscriptionDataset @@ -139,3 +140,4 @@ def has_alias(cls, alias: str) -> bool: DatasetRegistry.register(PipelineType.ORCHESTRATION, OrchestrationDataset) DatasetRegistry.register(PipelineType.STREAMING_TRANSCRIPTION, StreamingDataset) DatasetRegistry.register(PipelineType.TRANSCRIPTION, TranscriptionDataset) +DatasetRegistry.register(PipelineType.SPEECH_GENERATION, SpeechGenerationDataset) diff --git a/src/openbench/dataset/dataset_speech_generation.py b/src/openbench/dataset/dataset_speech_generation.py new file mode 100644 index 0000000..61d7acd --- /dev/null +++ b/src/openbench/dataset/dataset_speech_generation.py @@ -0,0 +1,99 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +import numpy as np +from pydantic import Field +from typing_extensions import TypedDict + +from ..pipeline_prediction import Transcript +from .dataset_base import BaseDataset, BaseSample + + +class SpeechGenerationExtraInfo(TypedDict, total=False): + """Extra info for speech generation samples.""" + + language: str + + +class SpeechGenerationRow(TypedDict): + """Expected row structure for speech generation. + + Requires 'text' (the prompt string). No audio needed. + """ + + text: str + + +class SpeechGenerationSample(BaseSample[Transcript, SpeechGenerationExtraInfo]): + """Sample for speech generation tasks. 
+ + The reference Transcript is created from the text + prompt. The pipeline generates audio from this text + and transcribes it to compute WER against reference. + """ + + generated_audio_duration: float | None = Field( + default=None, + description=("Duration (seconds) of the TTS-generated audio. Set by the pipeline after generation."), + ) + + def get_audio_duration(self) -> float: + """Return generated audio duration if available. + + Falls back to the dummy waveform calculation + when the pipeline hasn't set the real duration yet. + """ + if self.generated_audio_duration is not None: + return self.generated_audio_duration + return super().get_audio_duration() + + @property + def text(self) -> str: + """The original text prompt.""" + return self.reference.get_transcript_string() + + +class SpeechGenerationDataset(BaseDataset[SpeechGenerationSample]): + """Dataset for speech generation pipelines. + + Expects column: 'text' (the prompt string). + No audio column is required since audio is generated + by the pipeline itself. + """ + + _expected_columns = ["text"] + _sample_class = SpeechGenerationSample + + def _extract_audio_info(self, row: dict) -> tuple[str, np.ndarray, int]: + """Override to provide dummy audio info. + + Speech generation datasets don't have input audio. + We provide a placeholder waveform so the framework + sample structure is satisfied. The pipeline ignores + the waveform entirely. + """ + audio_name = f"sample_{row['idx']}" + # Use audio_name from the row if available + if "audio_name" in row and row["audio_name"]: + audio_name = str(row["audio_name"]) + dummy_waveform = np.zeros(1, dtype=np.float32) + dummy_sample_rate = 16000 + return audio_name, dummy_waveform, dummy_sample_rate + + def prepare_sample(self, row: SpeechGenerationRow) -> tuple[Transcript, SpeechGenerationExtraInfo]: + """Prepare reference from dataset row. + + Splits text prompt into words to create the + reference Transcript. 
+ """ + text = row["text"] + words = text.split() + reference = Transcript.from_words_info( + words=words, + ) + + extra_info: SpeechGenerationExtraInfo = {} + if "language" in row: + extra_info["language"] = row["language"] + + return reference, extra_info diff --git a/src/openbench/metric/word_error_metrics/word_error_metrics.py b/src/openbench/metric/word_error_metrics/word_error_metrics.py index 4cef2a1..24655d5 100644 --- a/src/openbench/metric/word_error_metrics/word_error_metrics.py +++ b/src/openbench/metric/word_error_metrics/word_error_metrics.py @@ -223,6 +223,7 @@ def compute_metric(self, detail: Details) -> float: PipelineType.TRANSCRIPTION, PipelineType.ORCHESTRATION, PipelineType.STREAMING_TRANSCRIPTION, + PipelineType.SPEECH_GENERATION, ), MetricOptions.WER, ) diff --git a/src/openbench/pipeline/__init__.py b/src/openbench/pipeline/__init__.py index 598328b..268d0d2 100644 --- a/src/openbench/pipeline/__init__.py +++ b/src/openbench/pipeline/__init__.py @@ -6,6 +6,7 @@ from .diarization import * from .orchestration import * from .pipeline_registry import PipelineRegistry +from .speech_generation import * from .streaming_transcription import * from .transcription import * diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py index 97287b0..b925b89 100644 --- a/src/openbench/pipeline/pipeline_aliases.py +++ b/src/openbench/pipeline/pipeline_aliases.py @@ -25,6 +25,10 @@ WhisperXPipeline, ) from .pipeline_registry import PipelineRegistry +from .speech_generation import ( + ElevenLabsSpeechGenerationPipeline, + WhisperKitSpeechGenerationPipeline, +) from .streaming_transcription import ( AssemblyAIStreamingPipeline, DeepgramStreamingPipeline, @@ -642,6 +646,47 @@ def register_pipeline_aliases() -> None: description="PyannoteAI transcription pipeline (ignores speaker attribution). Uses the precision-2 model with Nvidia Parakeet STT. 
Requires `PYANNOTE_TOKEN` env var from https://www.pyannote.ai/.", ) + ################# SPEECH GENERATION PIPELINES ################# + + PipelineRegistry.register_alias( + "whisperkit-speech-generation", + WhisperKitSpeechGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "cli_path": os.getenv("WHISPERKIT_CLI_PATH"), + "speaker": "aiden", + "language": "english", + "seed": 10, + "temperature": 0.9, + "top_k": 50, + "max_new_tokens": 245, + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + }, + description="WhisperKit speech generation pipeline. Generates audio from text prompts using whisperkit-cli TTS, " + "then transcribes the generated audio to compute WER against the original prompt. " + "Requires `WHISPERKIT_CLI_PATH` env var pointing to the whisperkit-cli binary.", + ) + + PipelineRegistry.register_alias( + "elevenlabs-speech-generation", + ElevenLabsSpeechGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "voice_id": "JBFqnCBsd6RMkjVDRZzb", + "model_id": "eleven_multilingual_v2", + "output_format": "mp3_44100_128", + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + "keep_generated_audio": False, + }, + description="ElevenLabs speech generation pipeline. Generates audio from text prompts using ElevenLabs TTS API, " + "then transcribes the generated audio to compute WER against the original prompt. 
" + "Requires `ELEVENLABS_API_KEY` and `WHISPERKITPRO_CLI_PATH` env vars.", + ) + ################# STREAMING TRANSCRIPTION PIPELINES ################# PipelineRegistry.register_alias( diff --git a/src/openbench/pipeline/speech_generation/__init__.py b/src/openbench/pipeline/speech_generation/__init__.py new file mode 100644 index 0000000..cad4e25 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/__init__.py @@ -0,0 +1,18 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +from .common import SpeechGenerationConfig, SpeechGenerationOutput +from .speech_generation_elevenlabs import ( + ElevenLabsSpeechGenerationConfig, + ElevenLabsSpeechGenerationPipeline, +) +from .speech_generation_wkp import WhisperKitSpeechGenerationPipeline + + +__all__ = [ + "ElevenLabsSpeechGenerationConfig", + "ElevenLabsSpeechGenerationPipeline", + "SpeechGenerationConfig", + "SpeechGenerationOutput", + "WhisperKitSpeechGenerationPipeline", +] diff --git a/src/openbench/pipeline/speech_generation/common.py b/src/openbench/pipeline/speech_generation/common.py new file mode 100644 index 0000000..fd8b7b4 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/common.py @@ -0,0 +1,95 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. 
+ +from pydantic import Field + +from ...pipeline_prediction import Transcript +from ..base import PipelineConfig, PipelineOutput + + +class SpeechGenerationConfig(PipelineConfig): + """Base config for speech generation pipelines.""" + + cli_path: str = Field( + ..., + description=("Path to the whisperkit-cli binary (used for TTS generation)."), + ) + + # TTS parameters + speaker: str = Field( + default="aiden", + description="Speaker voice for TTS generation.", + ) + language: str = Field( + default="english", + description="Language for TTS generation.", + ) + seed: int | None = Field( + default=None, + description="Random seed for reproducible output.", + ) + temperature: float = Field( + default=0.9, + description="Sampling temperature for TTS.", + ) + top_k: int = Field( + default=50, + description="Top-k sampling for TTS.", + ) + max_new_tokens: int = Field( + default=245, + description="Max RVQ frames to generate.", + ) + models_path: str | None = Field( + default=None, + description="Local model directory for TTS.", + ) + model_repo: str | None = Field( + default=None, + description="HF repo for TTS model download.", + ) + version_dir: str | None = Field( + default=None, + description="TTS model version directory.", + ) + tokenizer: str | None = Field( + default=None, + description="HF tokenizer repo or local path.", + ) + + # Transcription parameters + transcription_cli_path: str | None = Field( + default=None, + description=("Path to CLI for transcription. Defaults to cli_path if not set."), + ) + transcription_repo_id: str | None = Field( + default=None, + description=("HuggingFace repo ID for transcription model (e.g. argmaxinc/parakeetkit-pro)."), + ) + transcription_model_variant: str | None = Field( + default=None, + description=("Model variant folder within the repo (e.g. nvidia_parakeet-v2_476MB)."), + ) + transcription_model_path: str | None = Field( + default=None, + description=("Local path to ASR model dir. 
Overrides repo_id/model_variant."), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + +class SpeechGenerationOutput(PipelineOutput[Transcript]): + """Output for speech generation pipelines. + + The prediction is a Transcript of the generated audio + (obtained by transcribing the TTS output). WER is + computed against the original text prompt. + """ + + pass diff --git a/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py b/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py new file mode 100644 index 0000000..1e1af94 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py @@ -0,0 +1,303 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using ElevenLabs TTS API. + +Generates TTS audio from text prompts via ElevenLabs, +then transcribes the generated audio back to text using +WhisperKitPro (Parakeet) for WER evaluation against the +original prompt. 
+""" + import os + import time + from pathlib import Path + from typing import Callable + + from argmaxtools.utils import get_logger + from pydantic import BaseModel, Field + + from ...dataset.dataset_base import BaseSample + from ...dataset.dataset_speech_generation import SpeechGenerationSample + from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, + ) + from ...pipeline_prediction import Transcript + from ..base import ( + Pipeline, + PipelineConfig, + PipelineOutput, + PipelineType, + register_pipeline, + ) + from .common import SpeechGenerationOutput + + logger = get_logger(__name__) + + TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + + + class ElevenLabsSpeechGenerationConfig(PipelineConfig): + """Config for the ElevenLabs speech generation pipeline.""" + + # ElevenLabs TTS parameters + api_key: str | None = Field( + default=None, + description=( + "ElevenLabs API key. Falls back to " + "ELEVENLABS_API_KEY env var." + ), + ) + voice_id: str = Field( + default="JBFqnCBsd6RMkjVDRZzb", + description="ElevenLabs voice ID.", + ) + model_id: str = Field( + default="eleven_v3", + description="ElevenLabs model ID.", + ) + output_format: str = Field( + default="mp3_44100_128", + description="ElevenLabs output audio format.", + ) + + # Transcription parameters (WhisperKitPro / Parakeet) + transcription_cli_path: str = Field( + ..., + description=( + "Path to the whisperkit-cli binary " + "used for transcription." + ), + ) + transcription_repo_id: str | None = Field( + default=None, + description=( + "HuggingFace repo ID for transcription " + "model (e.g. argmaxinc/parakeetkit-pro)." + ), + ) + transcription_model_variant: str | None = Field( + default=None, + description=( + "Model variant folder within the repo " + "(e.g. nvidia_parakeet-v2_476MB)." + ), + ) + transcription_model_path: str | None = Field( + default=None, + description=( + "Local path to ASR model dir. 
" + "Overrides repo_id/model_variant." + ), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + keep_generated_audio: bool = Field( + default=False, + description=( + "If True, keep the generated TTS audio " + "files instead of deleting them." + ), + ) + + +class ElevenLabsSpeechGenerationInput(BaseModel): + """Input for the ElevenLabs speech generation pipeline.""" + + text: str = Field( + ..., + description="Text prompt to generate speech from.", + ) + audio_name: str = Field( + ..., + description=( + "Unique identifier for this sample " + "(used for temp file naming)." + ), + ) + + +@register_pipeline +class ElevenLabsSpeechGenerationPipeline(Pipeline): + """Speech generation pipeline using ElevenLabs TTS API. + + This pipeline: + 1. Generates audio from text via ElevenLabs text-to-speech API + 2. Transcribes audio via WhisperKitPro engine (Parakeet) + 3. Returns transcription as Transcript for WER eval + 4. Cleans up temporary audio and report files + """ + + _config_class = ElevenLabsSpeechGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[ElevenLabsSpeechGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + transcription_engine = self._build_transcription_engine() + + api_key = config.api_key or os.getenv("ELEVENLABS_API_KEY") + if not api_key: + raise ValueError( + "ElevenLabs API key must be provided " + "via config or ELEVENLABS_API_KEY env var." 
+ ) + + from elevenlabs.client import ElevenLabs + + client = ElevenLabs(api_key=api_key) + + def generate_and_transcribe( + input: ElevenLabsSpeechGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) + + ext = config.output_format.split("_")[0] + audio_path = TEMP_TTS_AUDIO_DIR / f"{input.audio_name}.{ext}" + + # -- Step 1: Generate audio via ElevenLabs API -- + audio_iter = client.text_to_speech.convert( + text=input.text, + voice_id=config.voice_id, + model_id=config.model_id, + output_format=config.output_format, + ) + + with open(audio_path, "wb") as f: + for chunk in audio_iter: + f.write(chunk) + + if not audio_path.exists() or audio_path.stat().st_size == 0: + raise RuntimeError( + "ElevenLabs TTS failed: audio file " + f"missing or empty at {audio_path}" + ) + + logger.info( + f"Generated ElevenLabs TTS audio: {audio_path}" + ) + + # -- Step 2: Read audio duration -- + try: + import soundfile as sf + + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = info.duration + except Exception as e: + logger.warning(f"Audio duration read failed: {e}") + pipeline_ref._last_generated_duration = None + + # -- Step 3: Transcribe via WhisperKitPro (Parakeet) -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=config.keep_generated_audio, + ) + engine_output = transcription_engine(engine_input) + + # -- Step 4: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + import json + + with json_path.open("r") as f: + data = json.load(f) + all_words, all_starts, all_ends = [], [], [] + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append(w["start"]) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=all_starts if all_starts else None, + end=all_ends if all_ends else None, + ) + 
json_path.unlink(missing_ok=True) + srt_path = engine_output.srt_report_path + if srt_path: + srt_path.unlink(missing_ok=True) + else: + raise RuntimeError( + "Transcription report not found " + f"at {json_path}" + ) + + text_preview = transcript.get_transcript_string()[:100] + logger.info(f"Transcription: {text_preview}...") + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription (Parakeet).""" + config = self.config + + import coremltools as ct + + compute = ct.ComputeUnit.CPU_AND_NE + engine_config = WhisperKitProConfig( + repo_id=config.transcription_repo_id, + model_variant=config.transcription_model_variant, + model_dir=config.transcription_model_path, + word_timestamps=config.transcription_word_timestamps, + chunking_strategy=config.transcription_chunking_strategy, + audio_encoder_compute_units=compute, + text_decoder_compute_units=compute, + ) + + return WhisperKitPro( + cli_path=config.transcription_cli_path, + transcription_config=engine_config, + ) + + def __call__(self, input_sample: BaseSample) -> PipelineOutput: + """Run pipeline and set generated audio duration.""" + self._last_generated_duration: float | None = None + parsed_input = self.parse_input(input_sample) + start_time = time.perf_counter() + output = self.pipeline(parsed_input) + end_time = time.perf_counter() + prediction_time = end_time - start_time + parsed_output = self.parse_output(output) + if parsed_output.prediction_time is None: + parsed_output.prediction_time = prediction_time + + dur = self._last_generated_duration + logger.debug(f"Generated audio duration: {dur}s") + is_sg = isinstance(input_sample, SpeechGenerationSample) + if is_sg and dur is not None: + input_sample.generated_audio_duration = dur + dur_val = input_sample.generated_audio_duration + logger.debug(f"Set sample duration to {dur_val}s") + + return parsed_output + + def parse_input( + self, input_sample: 
SpeechGenerationSample + ) -> ElevenLabsSpeechGenerationInput: + """Extract text prompt from the sample.""" + text = input_sample.reference.get_transcript_string() + return ElevenLabsSpeechGenerationInput( + text=text, + audio_name=input_sample.audio_name, + ) + + def parse_output(self, output: Transcript) -> SpeechGenerationOutput: + """Wrap transcription into output.""" + return SpeechGenerationOutput(prediction=output) diff --git a/src/openbench/pipeline/speech_generation/speech_generation_wkp.py b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py new file mode 100644 index 0000000..9c93440 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py @@ -0,0 +1,257 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using WhisperKit CLI. + +Generates TTS audio from text prompts, then transcribes +the generated audio back to text for WER evaluation +against the original prompt. 
+""" + +import json +import subprocess +import time +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import BaseModel, Field + +from ...dataset.dataset_base import BaseSample +from ...dataset.dataset_speech_generation import ( + SpeechGenerationSample, +) +from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, +) +from ...pipeline_prediction import Transcript +from ..base import ( + Pipeline, + PipelineOutput, + PipelineType, + register_pipeline, +) +from .common import SpeechGenerationConfig, SpeechGenerationOutput + + +logger = get_logger(__name__) + +TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + + +class SpeechGenerationInput(BaseModel): + """Input for the speech generation pipeline.""" + + text: str = Field( + ..., + description="Text prompt to generate speech from.", + ) + audio_name: str = Field( + ..., + description=("Unique identifier for this sample (used for temp file naming)."), + ) + + +@register_pipeline +class WhisperKitSpeechGenerationPipeline(Pipeline): + """Speech generation pipeline using WhisperKit CLI. + + This pipeline: + 1. Generates audio from text via whisperkit-cli tts + 2. Transcribes audio via WhisperKitPro engine + 3. Returns transcription as Transcript for WER eval + 4. 
Cleans up temporary audio and report files + """ + + _config_class = SpeechGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[SpeechGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + # Build the WhisperKitPro engine for transcription + # (downloads model once, reuses for all samples) + transcription_engine = self._build_transcription_engine() + + def generate_and_transcribe( + input: SpeechGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) + + audio_path = TEMP_TTS_AUDIO_DIR / f"{input.audio_name}.wav" + + # -- Step 1: Generate audio via TTS -- + tts_cmd = [ + config.cli_path, + "tts", + "--text", + input.text, + "--speaker", + config.speaker, + "--language", + config.language, + "--output-path", + str(audio_path), + "--temperature", + str(config.temperature), + "--top-k", + str(config.top_k), + "--max-new-tokens", + str(config.max_new_tokens), + ] + + if config.seed is not None: + tts_cmd.extend(["--seed", str(config.seed)]) + if config.models_path is not None: + tts_cmd.extend(["--models-path", config.models_path]) + if config.model_repo is not None: + tts_cmd.extend(["--model-repo", config.model_repo]) + if config.version_dir is not None: + tts_cmd.extend(["--version-dir", config.version_dir]) + if config.tokenizer is not None: + tts_cmd.extend(["--tokenizer", config.tokenizer]) + + logger.debug(f"Running TTS: {' '.join(tts_cmd)}") + + tts_result = subprocess.run(tts_cmd, capture_output=True, text=True) + + if tts_result.returncode != 0: + raise RuntimeError( + "whisperkit-cli tts failed " + f"(exit {tts_result.returncode}):\n" + f" stdout: " + f"{tts_result.stdout[:500]}\n" + f" stderr: " + f"{tts_result.stderr[:500]}" + ) + + if not audio_path.exists(): + raise RuntimeError(f"TTS completed but audio file not found at {audio_path}") + + logger.info(f"Generated TTS audio: {audio_path}") + + # -- Step 2: Read audio duration before 
+ # transcription (engine may delete file) -- + try: + import soundfile as sf + + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = info.duration + except Exception as e: + logger.warning(f"WAV duration read failed: {e}") + pipeline_ref._last_generated_duration = None + + # -- Step 3: Transcribe via WhisperKitPro -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=False, + ) + engine_output = transcription_engine(engine_input) + + # -- Step 4: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + with json_path.open("r") as f: + data = json.load(f) + all_words, all_starts, all_ends = ( + [], + [], + [], + ) + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append(w["start"]) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=(all_starts if all_starts else None), + end=(all_ends if all_ends else None), + ) + # Clean up report files + json_path.unlink(missing_ok=True) + srt_path = engine_output.srt_report_path + if srt_path: + srt_path.unlink(missing_ok=True) + else: + raise RuntimeError(f"Transcription report not found at {json_path}") + + logger.info("Transcription: " + transcript.get_transcript_string()[:100] + "...") + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription. + + Uses the same engine as the dedicated + WhisperKitPro transcription pipelines, which + handles model download, caching, and CLI args. 
+ """ + config = self.config + cli_path = config.transcription_cli_path or config.cli_path + + import coremltools as ct + + engine_config = WhisperKitProConfig( + repo_id=config.transcription_repo_id, + model_variant=config.transcription_model_variant, + model_dir=config.transcription_model_path, + word_timestamps=config.transcription_word_timestamps, + chunking_strategy=config.transcription_chunking_strategy, + audio_encoder_compute_units=ct.ComputeUnit.CPU_AND_NE, + text_decoder_compute_units=ct.ComputeUnit.CPU_AND_NE, + ) + + return WhisperKitPro( + cli_path=cli_path, + transcription_config=engine_config, + ) + + def __call__(self, input_sample: BaseSample) -> PipelineOutput: + """Run pipeline and set generated audio duration. + + Overrides base __call__ to propagate the real + TTS audio duration back onto the sample so the + runner reports accurate audio_duration and + speed_factor values. + """ + self._last_generated_duration: float | None = None + parsed_input = self.parse_input(input_sample) + start_time = time.perf_counter() + output = self.pipeline(parsed_input) + end_time = time.perf_counter() + prediction_time = end_time - start_time + parsed_output = self.parse_output(output) + if parsed_output.prediction_time is None: + parsed_output.prediction_time = prediction_time + + # Propagate generated audio duration to sample + dur = self._last_generated_duration + logger.debug(f"Generated audio duration: {dur}s") + if isinstance(input_sample, SpeechGenerationSample) and dur is not None: + input_sample.generated_audio_duration = dur + logger.debug(f"Set sample duration to {input_sample.generated_audio_duration}s") + + return parsed_output + + def parse_input(self, input_sample: SpeechGenerationSample) -> SpeechGenerationInput: + """Extract text prompt from the sample.""" + text = input_sample.reference.get_transcript_string() + return SpeechGenerationInput( + text=text, + audio_name=input_sample.audio_name, + ) + + def parse_output(self, output: Transcript) -> 
SpeechGenerationOutput: + """Wrap transcription into output.""" + return SpeechGenerationOutput(prediction=output) diff --git a/src/openbench/runner/benchmark.py b/src/openbench/runner/benchmark.py index 439c15c..b195862 100644 --- a/src/openbench/runner/benchmark.py +++ b/src/openbench/runner/benchmark.py @@ -33,6 +33,7 @@ PipelineType.TRANSCRIPTION: TranscriptionSampleResult, PipelineType.ORCHESTRATION: TranscriptionSampleResult, PipelineType.STREAMING_TRANSCRIPTION: TranscriptionSampleResult, + PipelineType.SPEECH_GENERATION: TranscriptionSampleResult, } @@ -64,6 +65,7 @@ def __init__(self, config: BenchmarkConfig, pipelines: list[Pipeline]): PipelineType.TRANSCRIPTION: TranscriptionWandbLogger, PipelineType.ORCHESTRATION: TranscriptionWandbLogger, PipelineType.STREAMING_TRANSCRIPTION: TranscriptionWandbLogger, + PipelineType.SPEECH_GENERATION: TranscriptionWandbLogger, } def _get_metrics(self, pipeline: Pipeline) -> dict[str, BaseMetric]: diff --git a/src/openbench/types.py b/src/openbench/types.py index b4bbeaa..b82df22 100644 --- a/src/openbench/types.py +++ b/src/openbench/types.py @@ -12,6 +12,7 @@ class PipelineType(Enum): TRANSCRIPTION = "transcription" ORCHESTRATION = "orchestration" STREAMING_TRANSCRIPTION = "streaming_transcription" + SPEECH_GENERATION = "speech_generation" # All prediction classes that we output should conform to this