diff --git a/suber/__main__.py b/suber/__main__.py
index 8eeceae..b8c2840 100644
--- a/suber/__main__.py
+++ b/suber/__main__.py
@@ -12,6 +12,7 @@
 from suber.metrics.suber import calculate_SubER
 from suber.metrics.suber_statistics import SubERStatisticsCollector
 from suber.metrics.sacrebleu_interface import calculate_sacrebleu_metric
+from suber.metrics.bleurt_interface import calculate_bleurt
 from suber.metrics.jiwer_interface import calculate_word_error_rate
 from suber.metrics.cer import calculate_character_error_rate
 from suber.metrics.length_ratio import calculate_length_ratio
@@ -44,6 +45,8 @@ def parse_arguments():
     parser.add_argument("--suber-statistics", action="store_true",
                         help="If set, will create an '#info' field in the output containing statistics about the "
                              "different edit operations used to calculate the SubER score.")
+    parser.add_argument("--bleurt-checkpoint",
+                        help="BLEURT model checkpoint folder, required if (AS-)BLEURT is used as one of the metrics.")
 
     return parser.parse_args()
 
@@ -134,6 +137,10 @@ def main():
             metric_score = calculate_character_error_rate(
                 hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric)
 
+        elif metric == "BLEURT":
+            metric_score = calculate_bleurt(
+                hypothesis=hypothesis_segments_to_use, reference=reference_segments, checkpoint=args.bleurt_checkpoint)
+
         else:
             metric_score = calculate_sacrebleu_metric(
                 hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric,
@@ -153,7 +160,7 @@ def check_metrics(metrics):
         # Our proposed metric:
         "SubER", "SubER-cased",
         # Established ASR and MT metrics, requiring aligned hypothesis-references segments:
-        "WER", "CER", "BLEU", "TER", "chrF",
+        "WER", "CER", "BLEU", "TER", "chrF", "BLEURT",
         # Cased and punctuated variants of the above:
         "WER-cased", "CER-cased",
         # Segmentation-aware variants of the above that include line breaks as tokens:
@@ -162,7 +169,7 @@
         # proposed by Karakanta et al. https://aclanthology.org/2020.iwslt-1.26.pdf
         "TER-br",
         # With an "AS-" prefix, the metric is computed after Levenshtein alignment of hypothesis and reference:
-        "AS-WER", "AS-CER", "AS-BLEU", "AS-TER", "AS-chrF", "AS-WER-cased", "AS-CER-cased", "AS-WER-seg",
+        "AS-WER", "AS-CER", "AS-BLEU", "AS-TER", "AS-chrF", "AS-BLEURT", "AS-WER-cased", "AS-CER-cased", "AS-WER-seg",
         "AS-BLEU-seg", "AS-TER-seg", "AS-TER-br",
         # With an "t-" prefix, the metric is computed after time alignment of hypothesis and reference:
         "t-WER", "t-CER", "t-BLEU", "t-TER", "t-chrF", "t-WER-cased", "t-CER-cased", "t-WER-seg", "t-BLEU-seg",
diff --git a/suber/metrics/bleurt_interface.py b/suber/metrics/bleurt_interface.py
new file mode 100644
index 0000000..4a46f4c
--- /dev/null
+++ b/suber/metrics/bleurt_interface.py
@@ -0,0 +1,26 @@
+from typing import List
+
+from suber.data_types import Segment
+from suber.utilities import segment_to_string
+
+
+def calculate_bleurt(hypothesis: List[Segment], reference: List[Segment], checkpoint=None) -> float:
+
+    from bleurt import score  # Local import to make dependency optional.
+
+    if not checkpoint:
+        raise ValueError(
+            "BLEURT checkpoint needs to be downloaded and specified via --bleurt-checkpoint. "
+            "See https://github.com/google-research/bleurt/blob/master/README.md")
+
+    score.logging.set_verbosity("INFO")
+
+    hypothesis_strings = [segment_to_string(segment) for segment in hypothesis]
+    reference_strings = [segment_to_string(segment) for segment in reference]
+
+    scorer = score.BleurtScorer(checkpoint)
+    scores = scorer.score(references=reference_strings, candidates=hypothesis_strings)
+
+    average_score = sum(scores) / len(scores)
+
+    return round(average_score, 3)