Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@ inference_outputs/
miscellaneous/

# Default openbench-cli output directory
downloaded_datasets/
downloaded_datasets/

*.sh
27 changes: 19 additions & 8 deletions src/openbench/engine/whisperkitpro_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,24 +83,28 @@ class WhisperKitProConfig(BaseModel):
description="The compute units to use for the audio encoder. Default is CPU_AND_NE.",
)
text_decoder_compute_units: ct.ComputeUnit = Field(
ct.ComputeUnit.CPU_AND_GPU,
ct.ComputeUnit.CPU_AND_NE,
description="The compute units to use for the text decoder. Default is CPU_AND_GPU.",
)
diarization: bool = Field(
False,
description="Whether to perform diarization",
)
orchestration_strategy: Literal["word", "segment"] = Field(
"segment",
description="The orchestration strategy to use either `word` or `segment`",
diarization_mode: Literal["realtime", "prerecorded"] = Field(
"prerecorded",
description="Sortformer streaming mode: `realtime` (1.04s latency) or `prerecorded` (9.84s latency). This is only applicable when `engine` is `sortformer`.",
)
orchestration_strategy: Literal["segment", "subsegment"] = Field(
"subsegment",
description="The orchestration strategy to use either `segment` or `subsegment`",
)
speaker_models_path: str | None = Field(
None,
description="The path to the speaker models directory",
)
clusterer_version: Literal["pyannote3", "pyannote4"] = Field(
"pyannote4",
description="The version of the clusterer to use",
engine: Literal["pyannote", "sortformer"] = Field(
"pyannote",
description="The engine to use. If `sortformer` the diarization model used is Sortformer, otherwise it is pyannote.",
)
use_exclusive_reconciliation: bool = Field(
False,
Expand Down Expand Up @@ -158,6 +162,7 @@ def generate_cli_args(self, model_path: Path | None = None) -> list[str]:
COMPUTE_UNITS_MAPPER[self.text_decoder_compute_units],
"--fast-load",
str(self.fast_load).lower(),
"--verbose",
]
)

Expand All @@ -171,9 +176,15 @@ def generate_cli_args(self, model_path: Path | None = None) -> list[str]:
if self.diarization:
args.extend(["--diarization"])
args.extend(["--orchestration-strategy", self.orchestration_strategy])

# Add rttm path
args.extend(["--rttm-path", self.rttm_path])
args.extend(["--clusterer-version", self.clusterer_version])
args.extend(["--engine", self.engine])

# Only add diarization mode if using Sortformer
if self.engine == "sortformer":
args.extend(["--diarization-mode", self.diarization_mode])

# If speaker models path is provided use it
if self.speaker_models_path:
args.extend(["--speaker-models-path", self.speaker_models_path])
Expand Down
68 changes: 43 additions & 25 deletions src/openbench/pipeline/diarization/speakerkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,48 +23,68 @@
TEMP_AUDIO_DIR = Path("audio_temp")


class SpeakerKitPipelineConfig(DiarizationPipelineConfig):
cli_path: str = Field(..., description="The absolute path to the SpeakerKit CLI")
clusterer_version: Literal["pyannote3", "pyannote4"] = Field(
"pyannote4", description="The version of the clusterer to use"
)
model_path: str | None = Field(None, description="The absolute path to the SpeakerKit model")


class SpeakerKitInput(TypedDict):
audio_path: Path
output_path: Path
num_speakers: int | None


class SpeakerKitCli:
def __init__(self, config: SpeakerKitPipelineConfig):
self.cli_path = config.cli_path
self.model_path = config.model_path
self.clusterer_version = config.clusterer_version
class SpeakerKitPipelineConfig(DiarizationPipelineConfig):
cli_path: str = Field(..., description="The absolute path to the SpeakerKit CLI")
model_path: str | None = Field(None, description="The absolute path to the SpeakerKit model directory")
engine: Literal["pyannote", "sortformer"] = Field("pyannote", description="The engine to use")

def __call__(self, speakerkit_input: SpeakerKitInput) -> tuple[Path, float]:
@property
def is_sortformer(self) -> bool:
return self.engine == "sortformer"

def generate_cli_args(self, inputs: SpeakerKitInput) -> list[str]:
cmd = [
self.cli_path,
"diarize",
"--audio-path",
str(speakerkit_input["audio_path"]),
str(inputs["audio_path"]),
"--rttm-path",
str(speakerkit_input["output_path"]),
"--clusterer-version",
self.clusterer_version,
str(inputs["output_path"]),
"--engine",
self.engine,
"--verbose",
]

if self.model_path:
if self.model_path is not None:
cmd.extend(["--model-path", self.model_path])

if speakerkit_input["num_speakers"] is not None:
cmd.extend(["--num-speakers", str(speakerkit_input["num_speakers"])])
if inputs["num_speakers"] is not None:
cmd.extend(["--num-speakers", str(inputs["num_speakers"])])

if "SPEAKERKIT_API_KEY" in os.environ:
cmd.extend(["--api-key", os.environ["SPEAKERKIT_API_KEY"]])

return cmd

def parse_stdout(self, stdout: str) -> float:
# Default pattern for pyannote models
pattern = r"Model Load Time:\s+\d+\.\d+\s+ms\nTotal Time:\s+(\d+\.\d+)\s+ms"
divisor = 1000.0

# if model is sortfomer we override the pattern and divisor
if self.is_sortformer:
pattern = r"Prediction time:\s+(\d+\.\d+)\s+seconds"
divisor = 1.0

matches = re.search(pattern, stdout)
if matches is None:
raise ValueError(f"Could not parse prediction time from stdout: {stdout!r}")
return float(matches.group(1)) / divisor


class SpeakerKitCli:
def __init__(self, config: SpeakerKitPipelineConfig):
self.config = config

def __call__(self, speakerkit_input: SpeakerKitInput) -> tuple[Path, float]:
cmd = self.config.generate_cli_args(speakerkit_input)

try:
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
logger.debug(f"Diarization CLI stdout:\n{result.stdout}")
Expand All @@ -81,11 +101,9 @@ def __call__(self, speakerkit_input: SpeakerKitInput) -> tuple[Path, float]:
speakerkit_input["audio_path"].unlink()

# Parse stdout and take the total time it took to diarize
pattern = r"Model Load Time:\s+\d+\.\d+\s+ms\nTotal Time:\s+(\d+\.\d+)\s+ms"
matches = re.search(pattern, result.stdout)
total_time = float(matches.group(1))
total_time = self.config.parse_stdout(result.stdout)

return speakerkit_input["output_path"], total_time / 1000
return speakerkit_input["output_path"], total_time


@register_pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,17 @@ class WhisperKitProOrchestrationConfig(OrchestrationConfig):
ComputeUnit.CPU_AND_NE,
description="The compute units to use for the text decoder. Default is CPU_AND_NE.",
)
orchestration_strategy: Literal["word", "segment"] = Field(
"segment",
description="The orchestration strategy to use either `word` or `segment`",
orchestration_strategy: Literal["segment", "subsegment"] = Field(
"subsegment",
description="The orchestration strategy to use either `segment` or `subsegment`",
)
clusterer_version: Literal["pyannote3", "pyannote4"] = Field(
"pyannote4",
description="The version of the clusterer to use",
engine: Literal["pyannote", "sortformer"] = Field(
"pyannote",
description="The engine to use. If `sortformer` the diarization model used is Sortformer, otherwise it is pyannote.",
)
diarization_mode: Literal["realtime", "prerecorded"] = Field(
"prerecorded",
description="Sortformer streaming mode: `realtime` (1.04s latency) or `prerecorded` (9.84s latency). This is only applicable when `engine` is `sortformer`.",
)
use_exclusive_reconciliation: bool = Field(
False,
Expand Down Expand Up @@ -107,7 +111,8 @@ def build_pipeline(self) -> WhisperKitPro:
chunking_strategy="vad",
diarization=True,
orchestration_strategy=self.config.orchestration_strategy,
clusterer_version_string=self.config.clusterer_version,
engine=self.config.engine,
diarization_mode=self.config.diarization_mode,
use_exclusive_reconciliation=self.config.use_exclusive_reconciliation,
fast_load=self.config.fast_load,
)
Expand Down
84 changes: 66 additions & 18 deletions src/openbench/pipeline/pipeline_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,23 @@ def register_pipeline_aliases() -> None:
default_config={
"out_dir": "./speakerkit-report",
"cli_path": os.getenv("SPEAKERKIT_CLI_PATH"),
"clusterer_version": "pyannote4",
"engine": "pyannote",
},
description="SpeakerKit speaker diarization pipeline. Requires CLI installation and API key. Set `SPEAKERKIT_CLI_PATH` and `SPEAKERKIT_API_KEY` env vars. For access to the CLI binary contact speakerkitpro@argmaxinc.com",
description="SpeakerKit speaker diarization pipeline using community-1 model from pyannote. Requires CLI installation and API key. Set `SPEAKERKIT_CLI_PATH` and `SPEAKERKIT_API_KEY` env vars. For access to the CLI binary contact speakerkitpro@argmaxinc.com",
)

PipelineRegistry.register_alias(
"speakerkit-sortformer-compressed",
SpeakerKitPipeline,
default_config={
"out_dir": "./speakerkit-sortformer-report",
"cli_path": os.getenv("SPEAKERKIT_CLI_PATH"),
"engine": "sortformer",
},
description=(
"SpeakerKit speaker diarization pipeline using Sortformer model compressed to 94MB. Requires CLI installation and API key. "
"Set `SPEAKERKIT_CLI_PATH` and `SPEAKERKIT_API_KEY` env vars. For access to the CLI binary contact speakerkitpro@argmaxinc.com."
),
)

PipelineRegistry.register_alias(
Expand Down Expand Up @@ -203,8 +217,8 @@ def register_pipeline_aliases() -> None:
"repo_id": "argmaxinc/whisperkit-pro",
"model_variant": "openai_whisper-tiny",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "segment",
"clusterer_version_string": "pyannote4",
"orchestration_strategy": "subsegment",
"engine": "pyannote",
"use_exclusive_reconciliation": True,
},
description="WhisperKitPro orchestration pipeline using the tiny version of the model. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.",
Expand All @@ -217,8 +231,8 @@ def register_pipeline_aliases() -> None:
"repo_id": "argmaxinc/whisperkit-pro",
"model_variant": "openai_whisper-large-v3",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "segment",
"clusterer_version_string": "pyannote4",
"orchestration_strategy": "subsegment",
"engine": "pyannote",
"use_exclusive_reconciliation": True,
},
description="WhisperKitPro orchestration pipeline using the large-v3 version of the model. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.",
Expand All @@ -231,8 +245,8 @@ def register_pipeline_aliases() -> None:
"repo_id": "argmaxinc/whisperkit-pro",
"model_variant": "openai_whisper-large-v3-v20240930",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "segment",
"clusterer_version_string": "pyannote4",
"orchestration_strategy": "subsegment",
"engine": "pyannote",
"use_exclusive_reconciliation": True,
},
description="WhisperKitPro orchestration pipeline using the large-v3-v20240930 version of the model (which is the same as large-v3-turbo from OpenAI). Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.",
Expand All @@ -245,8 +259,8 @@ def register_pipeline_aliases() -> None:
"repo_id": "argmaxinc/whisperkit-pro",
"model_variant": "openai_whisper-large-v3-v20240930_626MB",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "segment",
"clusterer_version_string": "pyannote4",
"orchestration_strategy": "subsegment",
"engine": "pyannote",
"use_exclusive_reconciliation": True,
},
description="WhisperKitPro orchestration pipeline using the large-v3-v20240930 version of the model compressed to 626MB (which is the same as large-v3-turbo from OpenAI). Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.",
Expand All @@ -259,8 +273,8 @@ def register_pipeline_aliases() -> None:
"repo_id": "argmaxinc/parakeetkit-pro",
"model_variant": "nvidia_parakeet-v2",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "segment",
"clusterer_version_string": "pyannote4",
"orchestration_strategy": "subsegment",
"engine": "pyannote",
"use_exclusive_reconciliation": True,
},
description="WhisperKitPro orchestration pipeline using the parakeet-v2 version of the model. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.",
Expand All @@ -273,22 +287,39 @@ def register_pipeline_aliases() -> None:
"repo_id": "argmaxinc/parakeetkit-pro",
"model_variant": "nvidia_parakeet-v2_476MB",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "segment",
"clusterer_version_string": "pyannote4",
"orchestration_strategy": "subsegment",
"engine": "pyannote",
"use_exclusive_reconciliation": True,
},
description="WhisperKitPro orchestration pipeline using the parakeet-v2 version of the model compressed to 476MB. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.",
)

PipelineRegistry.register_alias(
"whisperkitpro-orchestration-parakeet-v2-compressed-sortformer-compressed",
WhisperKitProOrchestrationPipeline,
default_config={
"repo_id": "argmaxinc/parakeetkit-pro",
"model_variant": "nvidia_parakeet-v2_476MB",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "subsegment",
"engine": "sortformer",
"diarization_mode": "prerecorded",
},
description=(
"WhisperKitPro orchestration pipeline using the parakeet-v2 version of the model compressed to 476MB and using Sortformer for diarization. "
"Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var."
),
)

PipelineRegistry.register_alias(
"whisperkitpro-orchestration-parakeet-v3",
WhisperKitProOrchestrationPipeline,
default_config={
"repo_id": "argmaxinc/parakeetkit-pro",
"model_variant": "nvidia_parakeet-v3",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "segment",
"clusterer_version_string": "pyannote4",
"orchestration_strategy": "subsegment",
"engine": "pyannote",
"use_exclusive_reconciliation": True,
},
description="WhisperKitPro orchestration pipeline using the parakeet-v3 version of the model. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.",
Expand All @@ -301,13 +332,30 @@ def register_pipeline_aliases() -> None:
"repo_id": "argmaxinc/parakeetkit-pro",
"model_variant": "nvidia_parakeet-v3_494MB",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "segment",
"clusterer_version_string": "pyannote4",
"orchestration_strategy": "subsegment",
"engine": "pyannote",
"use_exclusive_reconciliation": True,
},
description="WhisperKitPro orchestration pipeline using the parakeet-v3 version of the model compressed to 494MB. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.",
)

PipelineRegistry.register_alias(
"whisperkitpro-orchestration-parakeet-v3-compressed-sortformer-compressed",
WhisperKitProOrchestrationPipeline,
default_config={
"repo_id": "argmaxinc/parakeetkit-pro",
"model_variant": "nvidia_parakeet-v3_494MB",
"cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"),
"orchestration_strategy": "subsegment",
"engine": "sortformer",
"diarization_mode": "prerecorded",
},
description=(
"WhisperKitPro orchestration pipeline using the parakeet-v3 version of the model compressed to 494MB and using Sortformer for diarization. "
"Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var."
),
)

PipelineRegistry.register_alias(
"openai-orchestration",
OpenAIOrchestrationPipeline,
Expand Down
Loading