diff --git a/src/openbench/pipeline/transcription/transcription_nemo.py b/src/openbench/pipeline/transcription/transcription_nemo.py index 68a1205..e86c0fc 100644 --- a/src/openbench/pipeline/transcription/transcription_nemo.py +++ b/src/openbench/pipeline/transcription/transcription_nemo.py @@ -46,6 +46,22 @@ class NeMoTranscriptionPipelineConfig(TranscriptionConfig): default=0.6, description="Weight of CTC tokens to prevent false accept errors", ) + keyword_threshold: float = Field( + default=-5.0, + description="Threshold for keyword detection score", + ) + blank_threshold: float = Field( + default=0.8, + description="Threshold for blank token probability", + ) + non_blank_threshold: float = Field( + default=0.001, + description="Threshold for non-blank token probability", + ) + intersection_threshold: float = Field( + default=30.0, + description="Threshold for intersection between spotted word and word from alignment (in percentage)", + ) spelling_separator: str = Field( default="_", description="Separator between word and its spellings", @@ -142,6 +158,9 @@ def _transcribe_with_context_biasing(self, audio_path: Path) -> str: beam_threshold=self.config.beam_threshold, cb_weight=self.config.context_score, ctc_ali_token_weight=self.config.ctc_ali_token_weight, + keyword_threshold=self.config.keyword_threshold, + blank_threshold=self.config.blank_threshold, + non_blank_threshold=self.config.non_blank_threshold, ) if not ws_results: @@ -155,6 +174,7 @@ def _transcribe_with_context_biasing(self, audio_path: Path) -> str: self.asr_model, ws_results, decoder_type=self.config.decoder_type, + intersection_threshold=self.config.intersection_threshold, blank_idx=self.blank_idx, print_stats=False, )