diff --git a/config/benchmark_config/datasets/callhome_english.yaml b/config/benchmark_config/datasets/callhome_english.yaml
new file mode 100644
index 0000000..c6315a3
--- /dev/null
+++ b/config/benchmark_config/datasets/callhome_english.yaml
@@ -0,0 +1,3 @@
+callhome_english:
+  dataset_id: argmaxinc/callhome-english
+  split: test
diff --git a/config/benchmark_config/metrics/cpwer.yaml b/config/benchmark_config/metrics/cpwer.yaml
new file mode 100644
index 0000000..c10d6c9
--- /dev/null
+++ b/config/benchmark_config/metrics/cpwer.yaml
@@ -0,0 +1,3 @@
+# Dummy argument to make the config file valid
+cpwer:
+  skip_overlap: false
\ No newline at end of file
diff --git a/config/pipeline_configs/DeepgramStreamingOrchestrationPipeline.yaml b/config/pipeline_configs/DeepgramStreamingOrchestrationPipeline.yaml
new file mode 100644
index 0000000..6d3eae2
--- /dev/null
+++ b/config/pipeline_configs/DeepgramStreamingOrchestrationPipeline.yaml
@@ -0,0 +1,8 @@
+DeepgramStreamingOrchestrationPipeline:
+  pipeline_config:
+    sample_rate: 16000
+    channels: 1
+    sample_width: 2
+    realtime_resolution: 0.020
+    model_version: "nova-3"
+    enable_diarization: true
diff --git a/src/openbench/metric/word_error_metrics/word_error_metrics.py b/src/openbench/metric/word_error_metrics/word_error_metrics.py
index 4cef2a1..fc4cadb 100644
--- a/src/openbench/metric/word_error_metrics/word_error_metrics.py
+++ b/src/openbench/metric/word_error_metrics/word_error_metrics.py
@@ -297,7 +297,10 @@ def compute_metric(self, detail: Details) -> float:
         return (S + D + I) / N if N > 0 else 0.0
 
 
-@MetricRegistry.register_metric(PipelineType.ORCHESTRATION, MetricOptions.CPWER)
+@MetricRegistry.register_metric(
+    PipelineType.ORCHESTRATION,
+    MetricOptions.CPWER,
+)
 class ConcatenatedMinimumPermutationWER(BaseWordErrorMetric):
     """Concatenated minimum-Permutation Word Error Rate (cpWER) implementation.
 
diff --git a/src/openbench/pipeline/orchestration/__init__.py b/src/openbench/pipeline/orchestration/__init__.py
index 2d7bd44..4b39f6c 100644
--- a/src/openbench/pipeline/orchestration/__init__.py
+++ b/src/openbench/pipeline/orchestration/__init__.py
@@ -3,6 +3,10 @@
 from .nemo import NeMoMTParakeetPipeline, NeMoMTParakeetPipelineConfig
 from .orchestration_deepgram import DeepgramOrchestrationPipeline, DeepgramOrchestrationPipelineConfig
+from .orchestration_deepgram_streaming import (
+    DeepgramStreamingOrchestrationPipeline,
+    DeepgramStreamingOrchestrationPipelineConfig,
+)
 from .orchestration_openai import OpenAIOrchestrationPipeline, OpenAIOrchestrationPipelineConfig
 from .orchestration_whisperkitpro import WhisperKitProOrchestrationConfig, WhisperKitProOrchestrationPipeline
 from .whisperx import WhisperXPipeline, WhisperXPipelineConfig
@@ -11,6 +15,8 @@
 __all__ = [
     "DeepgramOrchestrationPipeline",
     "DeepgramOrchestrationPipelineConfig",
+    "DeepgramStreamingOrchestrationPipeline",
+    "DeepgramStreamingOrchestrationPipelineConfig",
     "WhisperXPipeline",
     "WhisperXPipelineConfig",
     "WhisperKitProOrchestrationPipeline",
diff --git a/src/openbench/pipeline/orchestration/orchestration_deepgram_streaming.py b/src/openbench/pipeline/orchestration/orchestration_deepgram_streaming.py
new file mode 100644
index 0000000..1d2be74
--- /dev/null
+++ b/src/openbench/pipeline/orchestration/orchestration_deepgram_streaming.py
@@ -0,0 +1,86 @@
+# For licensing see accompanying LICENSE.md file.
+# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.
+
+from types import SimpleNamespace
+
+import numpy as np
+from pydantic import Field
+
+from ...dataset import OrchestrationSample
+from ...pipeline import Pipeline, PipelineConfig, register_pipeline
+from ...pipeline_prediction import Transcript, Word
+from ...types import PipelineType
+from ..streaming_transcription.deepgram import DeepgramApi
+from .common import OrchestrationOutput
+
+
+class DeepgramStreamingOrchestrationPipelineConfig(PipelineConfig):
+    sample_rate: int = Field(default=16000, description="Sample rate of the audio in Hz")
+    channels: int = Field(default=1, description="Number of audio channels")
+    sample_width: int = Field(default=2, description="Sample width in bytes")
+    realtime_resolution: float = Field(default=0.020, description="Seconds of audio sent per streaming chunk")
+    model_version: str = Field(
+        default="nova-3", description="The model to use for real-time transcription with diarization"
+    )
+    enable_diarization: bool = Field(default=True, description="Whether to enable speaker diarization")
+
+
+@register_pipeline
+class DeepgramStreamingOrchestrationPipeline(Pipeline):
+    _config_class = DeepgramStreamingOrchestrationPipelineConfig
+    pipeline_type = PipelineType.ORCHESTRATION
+
+    def build_pipeline(self):
+        """Build the Deepgram streaming API client with diarization enabled."""
+        # Bundle the pipeline config into the plain namespace DeepgramApi expects
+        api_config = SimpleNamespace(
+            channels=self.config.channels,
+            sample_width=self.config.sample_width,
+            sample_rate=self.config.sample_rate,
+            realtime_resolution=self.config.realtime_resolution,
+            model_version=self.config.model_version,
+            enable_diarization=self.config.enable_diarization,
+        )
+
+        pipeline = DeepgramApi(api_config)
+        return pipeline
+
+    def parse_input(self, input_sample: OrchestrationSample):
+        """Convert the audio waveform to 16-bit PCM bytes for streaming."""
+        y = input_sample.waveform
+        y_int16 = (y * 32767).astype(np.int16)
+        audio_data_byte = y_int16.T.tobytes()
+        return audio_data_byte
+
+    def parse_output(self, output) -> OrchestrationOutput:
+        """Parse the API output into a speaker-attributed transcript."""
+        # Extract words with speaker info when diarization is enabled
+        words = []
+
+        if "words_with_speakers" in output and output["words_with_speakers"]:
+            # Populated by diarization-enabled streaming
+            for word_info in output["words_with_speakers"]:
+                words.append(
+                    Word(
+                        word=word_info.get("word", ""),
+                        start=word_info.get("start"),
+                        end=word_info.get("end"),
+                        speaker=word_info.get("speaker"),
+                    )
+                )
+        else:
+            # Speaker labels are required for orchestration pipelines
+            raise ValueError(
+                "No speaker diarization data available. "
+                "Orchestration pipelines require speaker labels. "
+                "Ensure 'enable_diarization' is set to True in the pipeline config."
+            )
+
+        # Create the final transcript with speaker-attributed words
+        transcript = Transcript(words=words)
+
+        return OrchestrationOutput(
+            prediction=transcript,
+            transcription_output=None,
+            diarization_output=None,
+        )
diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py
index b6583a4..02f6ed9 100644
--- a/src/openbench/pipeline/pipeline_aliases.py
+++ b/src/openbench/pipeline/pipeline_aliases.py
@@ -16,6 +16,7 @@
 )
 from .orchestration import (
     DeepgramOrchestrationPipeline,
+    DeepgramStreamingOrchestrationPipeline,
     NeMoMTParakeetPipeline,
     OpenAIOrchestrationPipeline,
     WhisperKitProOrchestrationPipeline,
@@ -171,6 +172,20 @@ def register_pipeline_aliases() -> None:
         description="Deepgram orchestration pipeline. Requires API key from https://www.deepgram.com/. Set `DEEPGRAM_API_KEY` env var.",
     )
 
+    PipelineRegistry.register_alias(
+        "deepgram-streaming-orchestration",
+        DeepgramStreamingOrchestrationPipeline,
+        default_config={
+            "sample_rate": 16000,
+            "channels": 1,
+            "sample_width": 2,
+            "realtime_resolution": 0.020,
+            "model_version": "nova-3",
+            "enable_diarization": True,
+        },
+        description="Deepgram streaming orchestration pipeline with diarization enabled. Requires API key from https://www.deepgram.com/. Set `DEEPGRAM_API_KEY` env var.",
+    )
+
     PipelineRegistry.register_alias(
         "whisperkitpro-orchestration-tiny",
         WhisperKitProOrchestrationPipeline,
diff --git a/src/openbench/pipeline/streaming_transcription/deepgram.py b/src/openbench/pipeline/streaming_transcription/deepgram.py
index ae3bdd6..2cc90d3 100644
--- a/src/openbench/pipeline/streaming_transcription/deepgram.py
+++ b/src/openbench/pipeline/streaming_transcription/deepgram.py
@@ -26,18 +26,20 @@ class DeepgramApi:
 
     def __init__(self, cfg) -> None:
-        self.realtime_resolution = 0.020
-        self.model_version = "nova-3"
+        self.realtime_resolution = getattr(cfg, "realtime_resolution", 0.020)
+        self.model_version = getattr(cfg, "model_version", "nova-3")
         self.api_key = os.getenv("DEEPGRAM_API_KEY")
         assert self.api_key is not None, "Please set API key in environment"
         self.channels = cfg.channels
         self.sample_width = cfg.sample_width
         self.sample_rate = cfg.sample_rate
         self.host_url = os.getenv("DEEPGRAM_HOST_URL", "wss://api.deepgram.com")
+        self.enable_diarization = getattr(cfg, "enable_diarization", False)
 
     async def run(self, data, key, channels, sample_width, sample_rate):
-        """Connect to the Deepgram real-time streaming endpoint, stream the data
-        in real-time, and print out the responses from the server.
+        """Connect to the Deepgram real-time streaming endpoint.
+
+        Stream the data in real time and print out the responses from the server.
 
         This uses a pre-recorded file as an example. It mimics a real-time
         connection by sending `REALTIME_RESOLUTION` seconds of audio every
@@ -62,9 +64,23 @@ async def run(self, data, key, channels, sample_width, sample_rate):
         confirmed_interim_transcripts = []
         model_timestamps_hypothesis = []
         model_timestamps_confirmed = []
-        # Connect to the real-time streaming endpoint, attaching our API key.
+        words_with_speakers = []
+
+        # Build the connection URL, enabling diarization when requested
+        url = (
+            f"{self.host_url}/v1/listen?"
+            f"model={self.model_version}&"
+            f"channels={channels}&"
+            f"sample_rate={sample_rate}&"
+            f"encoding=linear16&"
+            f"interim_results=true"
+        )
+        if self.enable_diarization:
+            url += "&diarize=true"
+
+        # Connect to the real-time streaming endpoint, attaching our API key
         async with websockets.connect(
-            f"{self.host_url}/v1/listen?model={self.model_version}&channels={channels}&sample_rate={sample_rate}&encoding=linear16&interim_results=true",
+            url,
             additional_headers={
                 "Authorization": "Token {}".format(key),
             },
@@ -91,8 +107,8 @@ async def sender(ws):
                 await ws.send(json.dumps({"type": "CloseStream"}))
 
             async def receiver(ws):
-                """Print out the messages received from the server."""
-                nonlocal audio_cursor
+                """Process the messages received from the server."""
+                nonlocal audio_cursor, words_with_speakers
                 global transcript
                 global interim_transcripts
                 global audio_cursor_l
@@ -105,26 +121,37 @@ async def receiver(ws):
                 async for msg in ws:
                     msg = json.loads(msg)
                     if "request_id" in msg:
-                        # This is the final metadata message. It gets sent as the
-                        # very last message by Deepgram during a clean shutdown.
+                        # This is the final metadata message, sent last during a clean shutdown.
                         # There is no transcript in it.
                         continue
-                    if msg["channel"]["alternatives"][0]["transcript"] != "":
+                    alternatives = msg["channel"]["alternatives"][0]
+                    if alternatives["transcript"] != "":
                         if not msg["is_final"]:
                             audio_cursor_l.append(audio_cursor)
-                            model_timestamps_hypothesis.append(msg["channel"]["alternatives"][0]["words"])
-                            interim_transcripts.append(
-                                transcript + " " + msg["channel"]["alternatives"][0]["transcript"]
-                            )
-                            logger.debug(
-                                "\n" + "Transcription: " + transcript + msg["channel"]["alternatives"][0]["transcript"]
-                            )
+                            model_timestamps_hypothesis.append(alternatives["words"])
+                            interim_transcripts.append(transcript + " " + alternatives["transcript"])
+                            logger.debug("\n" + "Transcription: " + transcript + alternatives["transcript"])
                         elif msg["is_final"]:
                             confirmed_audio_cursor_l.append(audio_cursor)
-                            transcript = transcript + " " + msg["channel"]["alternatives"][0]["transcript"]
+                            transcript = transcript + " " + alternatives["transcript"]
                             confirmed_interim_transcripts.append(transcript)
-                            model_timestamps_confirmed.append(msg["channel"]["alternatives"][0]["words"])
+                            words = alternatives["words"]
+                            model_timestamps_confirmed.append(words)
+
+                            # Collect speaker info when diarization is enabled
+                            if self.enable_diarization:
+                                for word_info in words:
+                                    if "speaker" in word_info:
+                                        speaker_label = f"SPEAKER_{word_info['speaker']}"
+                                        words_with_speakers.append(
+                                            {
+                                                "word": word_info.get("word", ""),
+                                                "speaker": speaker_label,
+                                                "start": word_info.get("start", 0),
+                                                "end": word_info.get("end", 0),
+                                            }
+                                        )
 
             await asyncio.gather(sender(ws), receiver(ws))
 
         return (
@@ -135,6 +162,7 @@ async def receiver(ws):
             confirmed_audio_cursor_l,
             model_timestamps_hypothesis,
             model_timestamps_confirmed,
+            words_with_speakers,
         )
 
     def __call__(self, sample):
@@ -147,6 +175,7 @@ def __call__(self, sample):
             confirmed_audio_cursor_l,
             model_timestamps_hypothesis,
             model_timestamps_confirmed,
+            words_with_speakers,
         ) = asyncio.get_event_loop().run_until_complete(
             self.run(sample, self.api_key, self.channels, self.sample_width, self.sample_rate)
         )
@@ -154,10 +183,11 @@ def __call__(self, sample):
         return {
             "transcript": transcript,
             "interim_transcripts": interim_transcripts,
             "audio_cursor": audio_cursor_l,
             "confirmed_interim_transcripts": confirmed_interim_transcripts,
             "confirmed_audio_cursor": confirmed_audio_cursor_l,
             "model_timestamps_hypothesis": model_timestamps_hypothesis,
             "model_timestamps_confirmed": model_timestamps_confirmed,
+            "words_with_speakers": words_with_speakers,
         }
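
Usage sketch: a minimal way to exercise the new streaming client directly. This is a sketch under stated assumptions, not part of the patch: it assumes `DEEPGRAM_API_KEY` is exported and the `openbench` package is importable; the `SimpleNamespace` config below mirrors what `build_pipeline` constructs, and real audio would replace the silence buffer.

    import os
    from types import SimpleNamespace

    import numpy as np

    from openbench.pipeline.streaming_transcription.deepgram import DeepgramApi

    # DeepgramApi asserts that DEEPGRAM_API_KEY is present at construction time.
    assert os.getenv("DEEPGRAM_API_KEY"), "export DEEPGRAM_API_KEY first"

    cfg = SimpleNamespace(
        channels=1,
        sample_width=2,           # bytes per sample, i.e. 16-bit PCM
        sample_rate=16000,
        realtime_resolution=0.020,
        model_version="nova-3",
        enable_diarization=True,  # appends &diarize=true to the websocket URL
    )
    api = DeepgramApi(cfg)

    # One second of silence as int16 PCM bytes, matching parse_input above.
    waveform = np.zeros(16000, dtype=np.float32)
    pcm_bytes = (waveform * 32767).astype(np.int16).tobytes()

    result = api(pcm_bytes)
    print(result["transcript"])
    for w in result["words_with_speakers"]:
        print(w["speaker"], w["word"], w["start"], w["end"])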