diff --git a/.gitignore b/.gitignore index ddfc86c..10c049c 100644 --- a/.gitignore +++ b/.gitignore @@ -45,4 +45,6 @@ inference_outputs/ miscellaneous/ # Default openbench-cli output directory -downloaded_datasets/ \ No newline at end of file +downloaded_datasets/ + +*.sh \ No newline at end of file diff --git a/src/openbench/engine/whisperkitpro_engine.py b/src/openbench/engine/whisperkitpro_engine.py index 0103320..69a1d00 100644 --- a/src/openbench/engine/whisperkitpro_engine.py +++ b/src/openbench/engine/whisperkitpro_engine.py @@ -83,24 +83,28 @@ class WhisperKitProConfig(BaseModel): description="The compute units to use for the audio encoder. Default is CPU_AND_NE.", ) text_decoder_compute_units: ct.ComputeUnit = Field( - ct.ComputeUnit.CPU_AND_GPU, + ct.ComputeUnit.CPU_AND_NE, description="The compute units to use for the text decoder. Default is CPU_AND_GPU.", ) diarization: bool = Field( False, description="Whether to perform diarization", ) - orchestration_strategy: Literal["word", "segment"] = Field( - "segment", - description="The orchestration strategy to use either `word` or `segment`", + diarization_mode: Literal["realtime", "prerecorded"] = Field( + "prerecorded", + description="Sortformer streaming mode: `realtime` (1.04s latency) or `prerecorded` (9.84s latency). This is only applicable when `engine` is `sortformer`.", + ) + orchestration_strategy: Literal["segment", "subsegment"] = Field( + "subsegment", + description="The orchestration strategy to use either `segment` or `subsegment`", ) speaker_models_path: str | None = Field( None, description="The path to the speaker models directory", ) - clusterer_version: Literal["pyannote3", "pyannote4"] = Field( - "pyannote4", - description="The version of the clusterer to use", + engine: Literal["pyannote", "sortformer"] = Field( + "pyannote", + description="The engine to use. 
If `sortformer` the diarization model used is Sortformer, otherwise it is pyannote.", ) use_exclusive_reconciliation: bool = Field( False, @@ -158,6 +162,7 @@ def generate_cli_args(self, model_path: Path | None = None) -> list[str]: COMPUTE_UNITS_MAPPER[self.text_decoder_compute_units], "--fast-load", str(self.fast_load).lower(), + "--verbose", ] ) @@ -171,9 +176,15 @@ def generate_cli_args(self, model_path: Path | None = None) -> list[str]: if self.diarization: args.extend(["--diarization"]) args.extend(["--orchestration-strategy", self.orchestration_strategy]) + # Add rttm path args.extend(["--rttm-path", self.rttm_path]) - args.extend(["--clusterer-version", self.clusterer_version]) + args.extend(["--engine", self.engine]) + + # Only add diarization mode if using Sortformer + if self.engine == "sortformer": + args.extend(["--diarization-mode", self.diarization_mode]) + # If speaker models path is provided use it if self.speaker_models_path: args.extend(["--speaker-models-path", self.speaker_models_path]) diff --git a/src/openbench/pipeline/diarization/speakerkit.py b/src/openbench/pipeline/diarization/speakerkit.py index 3718cb3..6dfa908 100644 --- a/src/openbench/pipeline/diarization/speakerkit.py +++ b/src/openbench/pipeline/diarization/speakerkit.py @@ -23,48 +23,68 @@ TEMP_AUDIO_DIR = Path("audio_temp") -class SpeakerKitPipelineConfig(DiarizationPipelineConfig): - cli_path: str = Field(..., description="The absolute path to the SpeakerKit CLI") - clusterer_version: Literal["pyannote3", "pyannote4"] = Field( - "pyannote4", description="The version of the clusterer to use" - ) - model_path: str | None = Field(None, description="The absolute path to the SpeakerKit model") - - class SpeakerKitInput(TypedDict): audio_path: Path output_path: Path num_speakers: int | None -class SpeakerKitCli: - def __init__(self, config: SpeakerKitPipelineConfig): - self.cli_path = config.cli_path - self.model_path = config.model_path - self.clusterer_version = 
config.clusterer_version +class SpeakerKitPipelineConfig(DiarizationPipelineConfig): + cli_path: str = Field(..., description="The absolute path to the SpeakerKit CLI") + model_path: str | None = Field(None, description="The absolute path to the SpeakerKit model directory") + engine: Literal["pyannote", "sortformer"] = Field("pyannote", description="The engine to use") - def __call__(self, speakerkit_input: SpeakerKitInput) -> tuple[Path, float]: + @property + def is_sortformer(self) -> bool: + return self.engine == "sortformer" + + def generate_cli_args(self, inputs: SpeakerKitInput) -> list[str]: cmd = [ self.cli_path, "diarize", "--audio-path", - str(speakerkit_input["audio_path"]), + str(inputs["audio_path"]), "--rttm-path", - str(speakerkit_input["output_path"]), - "--clusterer-version", - self.clusterer_version, + str(inputs["output_path"]), + "--engine", + self.engine, "--verbose", ] - if self.model_path: + if self.model_path is not None: cmd.extend(["--model-path", self.model_path]) - if speakerkit_input["num_speakers"] is not None: - cmd.extend(["--num-speakers", str(speakerkit_input["num_speakers"])]) + if inputs["num_speakers"] is not None: + cmd.extend(["--num-speakers", str(inputs["num_speakers"])]) if "SPEAKERKIT_API_KEY" in os.environ: cmd.extend(["--api-key", os.environ["SPEAKERKIT_API_KEY"]]) + return cmd + + def parse_stdout(self, stdout: str) -> float: + # Default pattern for pyannote models + pattern = r"Model Load Time:\s+\d+\.\d+\s+ms\nTotal Time:\s+(\d+\.\d+)\s+ms" + divisor = 1000.0 + + # if model is sortformer we override the pattern and divisor + if self.is_sortformer: + pattern = r"Prediction time:\s+(\d+\.\d+)\s+seconds" + divisor = 1.0 + + matches = re.search(pattern, stdout) + if matches is None: + raise ValueError(f"Could not parse prediction time from stdout: {stdout!r}") + return float(matches.group(1)) / divisor + + +class SpeakerKitCli: + def __init__(self, config: SpeakerKitPipelineConfig): + self.config = config + + def 
__call__(self, speakerkit_input: SpeakerKitInput) -> tuple[Path, float]: + cmd = self.config.generate_cli_args(speakerkit_input) + try: result = subprocess.run(cmd, check=True, capture_output=True, text=True) logger.debug(f"Diarization CLI stdout:\n{result.stdout}") @@ -81,11 +101,9 @@ def __call__(self, speakerkit_input: SpeakerKitInput) -> tuple[Path, float]: speakerkit_input["audio_path"].unlink() # Parse stdout and take the total time it took to diarize - pattern = r"Model Load Time:\s+\d+\.\d+\s+ms\nTotal Time:\s+(\d+\.\d+)\s+ms" - matches = re.search(pattern, result.stdout) - total_time = float(matches.group(1)) + total_time = self.config.parse_stdout(result.stdout) - return speakerkit_input["output_path"], total_time / 1000 + return speakerkit_input["output_path"], total_time @register_pipeline diff --git a/src/openbench/pipeline/orchestration/orchestration_whisperkitpro.py b/src/openbench/pipeline/orchestration/orchestration_whisperkitpro.py index 5e88083..9609e73 100644 --- a/src/openbench/pipeline/orchestration/orchestration_whisperkitpro.py +++ b/src/openbench/pipeline/orchestration/orchestration_whisperkitpro.py @@ -69,13 +69,17 @@ class WhisperKitProOrchestrationConfig(OrchestrationConfig): ComputeUnit.CPU_AND_NE, description="The compute units to use for the text decoder. Default is CPU_AND_NE.", ) - orchestration_strategy: Literal["word", "segment"] = Field( - "segment", - description="The orchestration strategy to use either `word` or `segment`", + orchestration_strategy: Literal["segment", "subsegment"] = Field( + "subsegment", + description="The orchestration strategy to use either `segment` or `subsegment`", ) - clusterer_version: Literal["pyannote3", "pyannote4"] = Field( - "pyannote4", - description="The version of the clusterer to use", + engine: Literal["pyannote", "sortformer"] = Field( + "pyannote", + description="The engine to use. 
If `sortformer` the diarization model used is Sortformer, otherwise it is pyannote.", + ) + diarization_mode: Literal["realtime", "prerecorded"] = Field( + "prerecorded", + description="Sortformer streaming mode: `realtime` (1.04s latency) or `prerecorded` (9.84s latency). This is only applicable when `engine` is `sortformer`.", ) use_exclusive_reconciliation: bool = Field( False, @@ -107,7 +111,8 @@ def build_pipeline(self) -> WhisperKitPro: chunking_strategy="vad", diarization=True, orchestration_strategy=self.config.orchestration_strategy, - clusterer_version_string=self.config.clusterer_version, + engine=self.config.engine, + diarization_mode=self.config.diarization_mode, use_exclusive_reconciliation=self.config.use_exclusive_reconciliation, fast_load=self.config.fast_load, ) diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py index 97287b0..ce9792b 100644 --- a/src/openbench/pipeline/pipeline_aliases.py +++ b/src/openbench/pipeline/pipeline_aliases.py @@ -112,9 +112,23 @@ def register_pipeline_aliases() -> None: default_config={ "out_dir": "./speakerkit-report", "cli_path": os.getenv("SPEAKERKIT_CLI_PATH"), - "clusterer_version": "pyannote4", + "engine": "pyannote", }, - description="SpeakerKit speaker diarization pipeline. Requires CLI installation and API key. Set `SPEAKERKIT_CLI_PATH` and `SPEAKERKIT_API_KEY` env vars. For access to the CLI binary contact speakerkitpro@argmaxinc.com", + description="SpeakerKit speaker diarization pipeline using community-1 model from pyannote. Requires CLI installation and API key. Set `SPEAKERKIT_CLI_PATH` and `SPEAKERKIT_API_KEY` env vars. 
For access to the CLI binary contact speakerkitpro@argmaxinc.com", + ) + + PipelineRegistry.register_alias( + "speakerkit-sortformer-compressed", + SpeakerKitPipeline, + default_config={ + "out_dir": "./speakerkit-sortformer-report", + "cli_path": os.getenv("SPEAKERKIT_CLI_PATH"), + "engine": "sortformer", + }, + description=( + "SpeakerKit speaker diarization pipeline using Sortformer model compressed to 94MB. Requires CLI installation and API key. " + "Set `SPEAKERKIT_CLI_PATH` and `SPEAKERKIT_API_KEY` env vars. For access to the CLI binary contact speakerkitpro@argmaxinc.com." + ), ) PipelineRegistry.register_alias( @@ -203,8 +217,8 @@ def register_pipeline_aliases() -> None: "repo_id": "argmaxinc/whisperkit-pro", "model_variant": "openai_whisper-tiny", "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), - "orchestration_strategy": "segment", - "clusterer_version_string": "pyannote4", + "orchestration_strategy": "subsegment", + "engine": "pyannote", "use_exclusive_reconciliation": True, }, description="WhisperKitPro orchestration pipeline using the tiny version of the model. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.", @@ -217,8 +231,8 @@ def register_pipeline_aliases() -> None: "repo_id": "argmaxinc/whisperkit-pro", "model_variant": "openai_whisper-large-v3", "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), - "orchestration_strategy": "segment", - "clusterer_version_string": "pyannote4", + "orchestration_strategy": "subsegment", + "engine": "pyannote", "use_exclusive_reconciliation": True, }, description="WhisperKitPro orchestration pipeline using the large-v3 version of the model. 
Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.", @@ -231,8 +245,8 @@ def register_pipeline_aliases() -> None: "repo_id": "argmaxinc/whisperkit-pro", "model_variant": "openai_whisper-large-v3-v20240930", "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), - "orchestration_strategy": "segment", - "clusterer_version_string": "pyannote4", + "orchestration_strategy": "subsegment", + "engine": "pyannote", "use_exclusive_reconciliation": True, }, description="WhisperKitPro orchestration pipeline using the large-v3-v20240930 version of the model (which is the same as large-v3-turbo from OpenAI). Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.", @@ -245,8 +259,8 @@ def register_pipeline_aliases() -> None: "repo_id": "argmaxinc/whisperkit-pro", "model_variant": "openai_whisper-large-v3-v20240930_626MB", "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), - "orchestration_strategy": "segment", - "clusterer_version_string": "pyannote4", + "orchestration_strategy": "subsegment", + "engine": "pyannote", "use_exclusive_reconciliation": True, }, description="WhisperKitPro orchestration pipeline using the large-v3-v20240930 version of the model compressed to 626MB (which is the same as large-v3-turbo from OpenAI). Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.", @@ -259,8 +273,8 @@ def register_pipeline_aliases() -> None: "repo_id": "argmaxinc/parakeetkit-pro", "model_variant": "nvidia_parakeet-v2", "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), - "orchestration_strategy": "segment", - "clusterer_version_string": "pyannote4", + "orchestration_strategy": "subsegment", + "engine": "pyannote", "use_exclusive_reconciliation": True, }, description="WhisperKitPro orchestration pipeline using the parakeet-v2 version of the model. 
Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.", @@ -273,13 +287,30 @@ def register_pipeline_aliases() -> None: "repo_id": "argmaxinc/parakeetkit-pro", "model_variant": "nvidia_parakeet-v2_476MB", "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), - "orchestration_strategy": "segment", - "clusterer_version_string": "pyannote4", + "orchestration_strategy": "subsegment", + "engine": "pyannote", "use_exclusive_reconciliation": True, }, description="WhisperKitPro orchestration pipeline using the parakeet-v2 version of the model compressed to 476MB. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.", ) + PipelineRegistry.register_alias( + "whisperkitpro-orchestration-parakeet-v2-compressed-sortformer-compressed", + WhisperKitProOrchestrationPipeline, + default_config={ + "repo_id": "argmaxinc/parakeetkit-pro", + "model_variant": "nvidia_parakeet-v2_476MB", + "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "orchestration_strategy": "subsegment", + "engine": "sortformer", + "diarization_mode": "prerecorded", + }, + description=( + "WhisperKitPro orchestration pipeline using the parakeet-v2 version of the model compressed to 476MB and using Sortformer for diarization. " + "Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var." 
+ ), + ) + PipelineRegistry.register_alias( "whisperkitpro-orchestration-parakeet-v3", WhisperKitProOrchestrationPipeline, @@ -287,8 +318,8 @@ def register_pipeline_aliases() -> None: "repo_id": "argmaxinc/parakeetkit-pro", "model_variant": "nvidia_parakeet-v3", "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), - "orchestration_strategy": "segment", - "clusterer_version_string": "pyannote4", + "orchestration_strategy": "subsegment", + "engine": "pyannote", "use_exclusive_reconciliation": True, }, description="WhisperKitPro orchestration pipeline using the parakeet-v3 version of the model. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.", @@ -301,13 +332,30 @@ def register_pipeline_aliases() -> None: "repo_id": "argmaxinc/parakeetkit-pro", "model_variant": "nvidia_parakeet-v3_494MB", "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), - "orchestration_strategy": "segment", - "clusterer_version_string": "pyannote4", + "orchestration_strategy": "subsegment", + "engine": "pyannote", "use_exclusive_reconciliation": True, }, description="WhisperKitPro orchestration pipeline using the parakeet-v3 version of the model compressed to 494MB. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.", ) + PipelineRegistry.register_alias( + "whisperkitpro-orchestration-parakeet-v3-compressed-sortformer-compressed", + WhisperKitProOrchestrationPipeline, + default_config={ + "repo_id": "argmaxinc/parakeetkit-pro", + "model_variant": "nvidia_parakeet-v3_494MB", + "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "orchestration_strategy": "subsegment", + "engine": "sortformer", + "diarization_mode": "prerecorded", + }, + description=( + "WhisperKitPro orchestration pipeline using the parakeet-v3 version of the model compressed to 494MB and using Sortformer for diarization. 
" + "Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var." + ), + ) + PipelineRegistry.register_alias( "openai-orchestration", OpenAIOrchestrationPipeline,