Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/openbench/cli/commands/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def run_alias_mode(
wandb_tags: list[str] | None,
use_keywords: bool | None,
force_language: bool,
pipeline_config: list[str] | None,
verbose: bool,
) -> BenchmarkResult:
"""Run evaluation using pipeline and dataset aliases."""
Expand Down Expand Up @@ -208,6 +209,16 @@ def run_alias_mode(
if verbose:
typer.echo("✅ Force language: enabled")

# Handle generic pipeline config overrides (key=value pairs)
if pipeline_config:
for item in pipeline_config:
if "=" not in item:
raise typer.BadParameter(f"Invalid --pipeline-config format: '{item}'. Expected key=value")
key, value = item.split("=", 1)
pipeline_config_override[key] = value
if verbose:
typer.echo(f"Config override: {key}={value}")

pipeline = PipelineRegistry.create_pipeline(pipeline_name, config=pipeline_config_override)

######### Build Benchmark Config #########
Expand Down Expand Up @@ -345,6 +356,12 @@ def evaluate(
"--force-language",
help="Force language hinting for compatible pipelines",
),
pipeline_config: list[str] | None = typer.Option(
None,
"--pipeline-config",
"-pc",
help="Override pipeline config values as key=value pairs (e.g. --pipeline-config speaker=serena)",
),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output"),
) -> None:
"""Run evaluation benchmarks.
Expand Down Expand Up @@ -406,6 +423,7 @@ def evaluate(
wandb_tags=wandb_tags,
use_keywords=use_keywords,
force_language=force_language,
pipeline_config=pipeline_config,
verbose=verbose,
)
display_result(result)
Expand Down
3 changes: 3 additions & 0 deletions src/openbench/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .dataset_diarization import DiarizationDataset, DiarizationSample
from .dataset_orchestration import OrchestrationDataset, OrchestrationSample
from .dataset_registry import DatasetRegistry
from .dataset_speech_generation import SpeechGenerationDataset, SpeechGenerationSample
from .dataset_streaming_transcription import StreamingDataset, StreamingSample
from .dataset_transcription import TranscriptionDataset, TranscriptionSample

Expand All @@ -24,11 +25,13 @@
"TranscriptionDataset",
"StreamingDataset",
"OrchestrationDataset",
"SpeechGenerationDataset",
# Sample types
"DiarizationSample",
"TranscriptionSample",
"StreamingSample",
"OrchestrationSample",
"SpeechGenerationSample",
# Registry
"DatasetRegistry",
]
14 changes: 14 additions & 0 deletions src/openbench/dataset/dataset_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,20 @@ def register_dataset_aliases() -> None:
description="Common Voice dataset for transcription evaluation with up to 400 samples per language this subset contains only russian",
)

########## SPEECH GENERATION ##########

DatasetRegistry.register_alias(
"customer-service-tts-prompts-vocalized",
DatasetConfig(
dataset_id="argmaxinc/customer-service-tts-prompts-vocalized",
split="validation",
),
supported_pipeline_types={
PipelineType.SPEECH_GENERATION,
},
description="Customer service TTS prompts with vocalized audio for speech generation evaluation.",
)

########## STREAMING TRANSCRIPTION ##########

DatasetRegistry.register_alias(
Expand Down
2 changes: 2 additions & 0 deletions src/openbench/dataset/dataset_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .dataset_base import BaseDataset, DatasetConfig
from .dataset_diarization import DiarizationDataset
from .dataset_orchestration import OrchestrationDataset
from .dataset_speech_generation import SpeechGenerationDataset
from .dataset_streaming_transcription import StreamingDataset
from .dataset_transcription import TranscriptionDataset

Expand Down Expand Up @@ -139,3 +140,4 @@ def has_alias(cls, alias: str) -> bool:
# Map each pipeline type to the dataset class the registry should
# instantiate for it; looked up when a benchmark resolves its dataset.
DatasetRegistry.register(PipelineType.ORCHESTRATION, OrchestrationDataset)
DatasetRegistry.register(PipelineType.STREAMING_TRANSCRIPTION, StreamingDataset)
DatasetRegistry.register(PipelineType.TRANSCRIPTION, TranscriptionDataset)
DatasetRegistry.register(PipelineType.SPEECH_GENERATION, SpeechGenerationDataset)
99 changes: 99 additions & 0 deletions src/openbench/dataset/dataset_speech_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.

import numpy as np
from pydantic import Field
from typing_extensions import TypedDict

from ..pipeline_prediction import Transcript
from .dataset_base import BaseDataset, BaseSample


class SpeechGenerationExtraInfo(TypedDict, total=False):
    """Optional per-sample metadata for speech generation.

    All keys are optional (``total=False``). ``language`` carries the
    prompt language when the dataset row provides one.
    """

    language: str


class SpeechGenerationRow(TypedDict):
    """Shape of one dataset row for speech generation.

    Only the required ``text`` column (the TTS prompt string) is part
    of the contract; no audio column is needed. Rows may additionally
    carry columns such as ``language`` at runtime.
    """

    text: str


class SpeechGenerationSample(BaseSample[Transcript, SpeechGenerationExtraInfo]):
    """Sample for speech generation (TTS) tasks.

    The reference ``Transcript`` is built from the text prompt. The
    pipeline synthesizes audio for that prompt and transcribes the
    result, so WER can be computed against this reference.
    """

    # Filled in by the pipeline once TTS has actually produced audio.
    generated_audio_duration: float | None = Field(
        default=None,
        description=("Duration (seconds) of the TTS-generated audio. Set by the pipeline after generation."),
    )

    def get_audio_duration(self) -> float:
        """Return the generated audio duration when known.

        Before the pipeline has run there is no real audio yet, so this
        falls back to the base-class calculation over the placeholder
        waveform.
        """
        duration = self.generated_audio_duration
        if duration is None:
            return super().get_audio_duration()
        return duration

    @property
    def text(self) -> str:
        """The original text prompt, reconstructed from the reference."""
        return self.reference.get_transcript_string()


class SpeechGenerationDataset(BaseDataset[SpeechGenerationSample]):
    """Dataset for speech generation pipelines.

    Expects column: 'text' (the prompt string).
    No audio column is required since audio is generated
    by the pipeline itself.
    """

    _expected_columns = ["text"]
    _sample_class = SpeechGenerationSample

    def _extract_audio_info(self, row: dict) -> tuple[str, np.ndarray, int]:
        """Override to provide dummy audio info.

        Speech generation datasets don't have input audio, so a
        one-sample silent placeholder waveform is returned to satisfy
        the framework sample structure. The pipeline ignores the
        waveform entirely.

        Args:
            row: Raw dataset row. A non-empty ``audio_name`` column,
                when present, names the sample; otherwise the row's
                ``idx`` is used.

        Returns:
            Tuple of ``(audio_name, waveform, sample_rate)``.
        """
        # Prefer an explicit audio_name and only read row["idx"] on the
        # fallback path: the original eagerly built f"sample_{row['idx']}"
        # first, which raised KeyError for rows that carry audio_name
        # but no idx even though the value was then discarded.
        if "audio_name" in row and row["audio_name"]:
            audio_name = str(row["audio_name"])
        else:
            audio_name = f"sample_{row['idx']}"
        dummy_waveform = np.zeros(1, dtype=np.float32)
        dummy_sample_rate = 16000
        return audio_name, dummy_waveform, dummy_sample_rate

    def prepare_sample(self, row: SpeechGenerationRow) -> tuple[Transcript, SpeechGenerationExtraInfo]:
        """Prepare reference from dataset row.

        Splits the text prompt on whitespace to create the reference
        Transcript.

        Args:
            row: Dataset row containing at least a ``text`` column.

        Returns:
            Tuple of ``(reference, extra_info)`` where ``extra_info``
            includes ``language`` when the row provides it.
        """
        words = row["text"].split()
        reference = Transcript.from_words_info(
            words=words,
        )

        extra_info: SpeechGenerationExtraInfo = {}
        if "language" in row:
            extra_info["language"] = row["language"]

        return reference, extra_info
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def compute_metric(self, detail: Details) -> float:
PipelineType.TRANSCRIPTION,
PipelineType.ORCHESTRATION,
PipelineType.STREAMING_TRANSCRIPTION,
PipelineType.SPEECH_GENERATION,
),
MetricOptions.WER,
)
Expand Down
1 change: 1 addition & 0 deletions src/openbench/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .diarization import *
from .orchestration import *
from .pipeline_registry import PipelineRegistry
from .speech_generation import *
from .streaming_transcription import *
from .transcription import *

Expand Down
45 changes: 45 additions & 0 deletions src/openbench/pipeline/pipeline_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
WhisperXPipeline,
)
from .pipeline_registry import PipelineRegistry
from .speech_generation import (
ElevenLabsSpeechGenerationPipeline,
WhisperKitSpeechGenerationPipeline,
)
from .streaming_transcription import (
AssemblyAIStreamingPipeline,
DeepgramStreamingPipeline,
Expand Down Expand Up @@ -642,6 +646,47 @@ def register_pipeline_aliases() -> None:
description="PyannoteAI transcription pipeline (ignores speaker attribution). Uses the precision-2 model with Nvidia Parakeet STT. Requires `PYANNOTE_TOKEN` env var from https://www.pyannote.ai/.",
)

################# SPEECH GENERATION PIPELINES #################

PipelineRegistry.register_alias(
"whisperkit-speech-generation",
WhisperKitSpeechGenerationPipeline,
default_config={
"out_dir": "./speech_generation_results",
"cli_path": os.getenv("WHISPERKIT_CLI_PATH"),
"speaker": "aiden",
"language": "english",
"seed": 10,
"temperature": 0.9,
"top_k": 50,
"max_new_tokens": 245,
"transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"transcription_repo_id": "argmaxinc/parakeetkit-pro",
"transcription_model_variant": "nvidia_parakeet-v2_476MB",
},
description="WhisperKit speech generation pipeline. Generates audio from text prompts using whisperkit-cli TTS, "
"then transcribes the generated audio to compute WER against the original prompt. "
"Requires `WHISPERKIT_CLI_PATH` env var pointing to the whisperkit-cli binary.",
)

PipelineRegistry.register_alias(
"elevenlabs-speech-generation",
ElevenLabsSpeechGenerationPipeline,
default_config={
"out_dir": "./speech_generation_results",
"voice_id": "JBFqnCBsd6RMkjVDRZzb",
"model_id": "eleven_multilingual_v2",
"output_format": "mp3_44100_128",
"transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"transcription_repo_id": "argmaxinc/parakeetkit-pro",
"transcription_model_variant": "nvidia_parakeet-v2_476MB",
"keep_generated_audio": False,
},
description="ElevenLabs speech generation pipeline. Generates audio from text prompts using ElevenLabs TTS API, "
"then transcribes the generated audio to compute WER against the original prompt. "
"Requires `ELEVENLABS_API_KEY` and `WHISPERKITPRO_CLI_PATH` env vars.",
)

################# STREAMING TRANSCRIPTION PIPELINES #################

PipelineRegistry.register_alias(
Expand Down
18 changes: 18 additions & 0 deletions src/openbench/pipeline/speech_generation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.

from .common import SpeechGenerationConfig, SpeechGenerationOutput
from .speech_generation_elevenlabs import (
ElevenLabsSpeechGenerationConfig,
ElevenLabsSpeechGenerationPipeline,
)
from .speech_generation_wkp import WhisperKitSpeechGenerationPipeline


__all__ = [
"ElevenLabsSpeechGenerationConfig",
"ElevenLabsSpeechGenerationPipeline",
"SpeechGenerationConfig",
"SpeechGenerationOutput",
"WhisperKitSpeechGenerationPipeline",
]
95 changes: 95 additions & 0 deletions src/openbench/pipeline/speech_generation/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.

from pydantic import Field

from ...pipeline_prediction import Transcript
from ..base import PipelineConfig, PipelineOutput


class SpeechGenerationConfig(PipelineConfig):
    """Base config for speech generation pipelines.

    Groups the TTS (generation) knobs with the transcription knobs used
    to score the generated audio against the original text prompt.
    """

    # Required: the binary that performs TTS generation.
    cli_path: str = Field(..., description="Path to the whisperkit-cli binary (used for TTS generation).")

    # --- TTS parameters ---
    speaker: str = Field(default="aiden", description="Speaker voice for TTS generation.")
    language: str = Field(default="english", description="Language for TTS generation.")
    seed: int | None = Field(default=None, description="Random seed for reproducible output.")
    temperature: float = Field(default=0.9, description="Sampling temperature for TTS.")
    top_k: int = Field(default=50, description="Top-k sampling for TTS.")
    max_new_tokens: int = Field(default=245, description="Max RVQ frames to generate.")
    models_path: str | None = Field(default=None, description="Local model directory for TTS.")
    model_repo: str | None = Field(default=None, description="HF repo for TTS model download.")
    version_dir: str | None = Field(default=None, description="TTS model version directory.")
    tokenizer: str | None = Field(default=None, description="HF tokenizer repo or local path.")

    # --- Transcription parameters (WER scoring of the generated audio) ---
    transcription_cli_path: str | None = Field(
        default=None,
        description="Path to CLI for transcription. Defaults to cli_path if not set.",
    )
    transcription_repo_id: str | None = Field(
        default=None,
        description="HuggingFace repo ID for transcription model (e.g. argmaxinc/parakeetkit-pro).",
    )
    transcription_model_variant: str | None = Field(
        default=None,
        description="Model variant folder within the repo (e.g. nvidia_parakeet-v2_476MB).",
    )
    transcription_model_path: str | None = Field(
        default=None,
        description="Local path to ASR model dir. Overrides repo_id/model_variant.",
    )
    transcription_word_timestamps: bool = Field(default=True, description="Include word timestamps.")
    transcription_chunking_strategy: str = Field(default="vad", description="Chunking strategy (none or vad).")


class SpeechGenerationOutput(PipelineOutput[Transcript]):
    """Output of a speech generation pipeline.

    The prediction is a ``Transcript`` obtained by transcribing the
    TTS-generated audio; WER is computed against the original text
    prompt. Adds no fields beyond those of ``PipelineOutput``.
    """
Loading
Loading