Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions suber/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from suber.metrics.suber import calculate_SubER
from suber.metrics.suber_statistics import SubERStatisticsCollector
from suber.metrics.sacrebleu_interface import calculate_sacrebleu_metric
from suber.metrics.bleurt_interface import calculate_bleurt
from suber.metrics.jiwer_interface import calculate_word_error_rate
from suber.metrics.cer import calculate_character_error_rate
from suber.metrics.length_ratio import calculate_length_ratio
Expand Down Expand Up @@ -44,6 +45,8 @@ def parse_arguments():
parser.add_argument("--suber-statistics", action="store_true",
help="If set, will create an '#info' field in the output containing statistics about the "
"different edit operations used to calculate the SubER score.")
parser.add_argument("--bleurt-checkpoint",
help="BLEURT model checkpoint folder, required if (AS-)BLEURT used as one of the metrics.")

return parser.parse_args()

Expand Down Expand Up @@ -134,6 +137,10 @@ def main():
metric_score = calculate_character_error_rate(
hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric)

elif metric == "BLEURT":
metric_score = calculate_bleurt(
hypothesis=hypothesis_segments_to_use, reference=reference_segments, checkpoint=args.bleurt_checkpoint)

else:
metric_score = calculate_sacrebleu_metric(
hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric,
Expand All @@ -153,7 +160,7 @@ def check_metrics(metrics):
# Our proposed metric:
"SubER", "SubER-cased",
# Established ASR and MT metrics, requiring aligned hypothesis-references segments:
"WER", "CER", "BLEU", "TER", "chrF",
"WER", "CER", "BLEU", "TER", "chrF", "BLEURT",
# Cased and punctuated variants of the above:
"WER-cased", "CER-cased",
# Segmentation-aware variants of the above that include line breaks as tokens:
Expand All @@ -162,7 +169,7 @@ def check_metrics(metrics):
# proposed by Karakanta et al. https://aclanthology.org/2020.iwslt-1.26.pdf
"TER-br",
# With an "AS-" prefix, the metric is computed after Levenshtein alignment of hypothesis and reference:
"AS-WER", "AS-CER", "AS-BLEU", "AS-TER", "AS-chrF", "AS-WER-cased", "AS-CER-cased", "AS-WER-seg",
"AS-WER", "AS-CER", "AS-BLEU", "AS-TER", "AS-chrF", "AS-BLEURT", "AS-WER-cased", "AS-CER-cased", "AS-WER-seg",
"AS-BLEU-seg", "AS-TER-seg", "AS-TER-br",
# With an "t-" prefix, the metric is computed after time alignment of hypothesis and reference:
"t-WER", "t-CER", "t-BLEU", "t-TER", "t-chrF", "t-WER-cased", "t-CER-cased", "t-WER-seg", "t-BLEU-seg",
Expand Down
26 changes: 26 additions & 0 deletions suber/metrics/bleurt_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import List

from suber.data_types import Segment
from suber.utilities import segment_to_string


def calculate_bleurt(hypothesis: List[Segment], reference: List[Segment], checkpoint=None) -> float:
    """Compute the corpus-level BLEURT score as the mean of segment-level scores.

    :param hypothesis: hypothesis segments, aligned one-to-one with ``reference``
    :param reference: reference segments
    :param checkpoint: path to a downloaded BLEURT model checkpoint folder
    :return: mean segment-level BLEURT score, rounded to 3 decimal places
    :raises ValueError: if no checkpoint is given, the inputs are empty, or the
        number of hypothesis and reference segments differs
    """
    # Validate arguments before the optional 'bleurt' import so that the user
    # gets the actionable error message even when the package is not installed.
    if not checkpoint:
        raise ValueError(
            "BLEURT checkpoint needs to be downloaded and specified via --bleurt-checkpoint. "
            "See https://github.com/google-research/bleurt/blob/master/README.md")

    if len(hypothesis) != len(reference):
        raise ValueError("Expected the same number of hypothesis and reference segments.")

    if not reference:
        # Explicit error instead of a ZeroDivisionError in the average below.
        raise ValueError("Cannot compute BLEURT on empty input.")

    from bleurt import score  # Local import to make dependency optional.

    score.logging.set_verbosity("INFO")

    hypothesis_strings = [segment_to_string(segment) for segment in hypothesis]
    reference_strings = [segment_to_string(segment) for segment in reference]

    scorer = score.BleurtScorer(checkpoint)
    scores = scorer.score(references=reference_strings, candidates=hypothesis_strings)

    average_score = sum(scores) / len(scores)

    return round(average_score, 3)